如何在Tensorflow中进行切片分配

use*_*700 21 python-2.7 tensorflow

我发现Tensorflow提供了scatter_update()为0维度中的张量切片赋值.例如,如果张量T是三维的,我可以赋值v[1, :, :]T[i, :, :].

a = tf.Variable(tf.zeros([10,36,36]))   
value = np.ones([1,36,36])   
d = tf.scatter_update(a,[0],value)

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    print a.eval()
    sess.run(d)
    print a.eval()
Run Code Online (Sandbox Code Playgroud)

但是如何赋值v[1,1,:]T[i,j,:]

a = tf.Variable(tf.zeros([10,36,36]))   
value1 = np.random.randn(1,1,36)    
e = tf.scatter_update(a,[0],value1) #Error

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    print a.eval()
    sess.rum(e)
    print a.eval()
Run Code Online (Sandbox Code Playgroud)

是否有TF提供的其他功能或简单的方法?

jde*_*esa 36

目前,您可以在TensorFlow中为变量执行切片分配.它没有特定的命名功能,但您可以选择一个切片并调用assign它:

my_var = my_var[4:8].assign(tf.zeros(4))
Run Code Online (Sandbox Code Playgroud)

首先,请注意(在查看文档之后)assign,即使应用于切片,返回值似乎始终是应用更新后对整个变量的引用.

编辑:以下信息是弃用,不精确或总是错误的.事实是返回的值assign是一个可以很容易使用的张量,并且已经将依赖项合并到赋值中,因此简单地评估或在进一步的操作中使用它将确保它在不需要显式tf.control_dependencies块的情况下执行.


另请注意,这只会将赋值操作添加到图形中,但除非显式执行或设置为某些其他操作的依赖项,否则不会运行它.一个好的做法是在tf.control_dependencies上下文中使用它:

with tf.control_dependencies([my_var[4:8].assign(tf.zeros(4))]):
    my_var = tf.identity(my_var)
Run Code Online (Sandbox Code Playgroud)

您可以在TensorFlow问题#4638中阅读更多相关信息.


Sor*_*vux 9

我相信你所需要的是票#206中assign_slice_update讨论的.但它尚不可用.

更新:现在已经实现了.请参阅jdehesa的回答:https://stackoverflow.com/a/43139565/6531137


assign_slice_update(或scatter_nd())可用之前,您可以构建所需行的块,其中包含您不想修改的值以及要更新的所需值,如下所示:

import tensorflow as tf

a = tf.Variable(tf.ones([10,36,36]))

i = 3
j = 5

# Gather values inside the a[i,...] block that are not on column j
idx_before = tf.concat(1, [tf.reshape(tf.tile(tf.Variable([i]), [j]), [-1, 1]), tf.reshape(tf.range(j), [-1, 1])])
values_before = tf.gather_nd(a, idx_before)
idx_after = tf.concat(1, [tf.reshape(tf.tile(tf.Variable([i]), [36-j-1]), [-1, 1]), tf.reshape(tf.range(j+1, 36), [-1, 1])])
values_after = tf.gather_nd(a, idx_after)

# Build a subset of tensor `a` with the values that should not be touched and the values to update
block = tf.concat(0, [values_before, 5*tf.ones([1, 36]), values_after])

d = tf.scatter_update(a, i, block)

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    sess.run(d)
    print(a.eval()[3,4:7,:]) # Print a subset of the tensor to verify
Run Code Online (Sandbox Code Playgroud)

该示例生成一个张量并执行a[i,j,:] = 5.大多数复杂性在于获取我们不想修改的值a[i,~j,:](否则scatter_update()将替换这些值).

如果要T[i,k,:] = a[1,1,:]按照要求执行,则需要5*tf.ones([1, 36])在前面的示例中替换tf.gather_nd(a, [[1, 1]]).

另一种方法是从中创建tf.select()所需元素的掩码并将其分配回变量,如下所示:

import tensorflow as tf

a = tf.Variable(tf.zeros([10,36,36]))

i = tf.Variable([3])
j = tf.Variable([5])

# Build a mask using indices to perform [i,j,:]
atleast_2d = lambda x: tf.reshape(x, [-1, 1])
indices = tf.concat(1, [atleast_2d(tf.tile(i, [36])), atleast_2d(tf.tile(j, [36])), atleast_2d(tf.range(36))])
mask = tf.cast(tf.sparse_to_dense(indices, [10, 36, 36], 1), tf.bool)

to_update = 5*tf.ones_like(a)
out = a.assign( tf.select(mask, to_update, a) ) 

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    sess.run(out)
    print(a.eval()[2:5,5,:])
Run Code Online (Sandbox Code Playgroud)

它在内存方面可能效率较低,因为它需要两倍的内存来处理a类似的to_update变量,但您可以轻松地修改最后一个示例以从tf.select(...)节点获得梯度保留操作.您可能还有兴趣查看其他StackOverflow问题:TensorFlow中的张量值的条件分配.

那些不雅的扭曲应该被替换为对TensorFlow功能的调用,因为它变得可用.