Tensorflow:如何通过 tf.gather 传播梯度?

use*_*261 5 tensorflow

我在尝试针对表示聚集索引的变量传播我的损失函数的梯度时遇到了一些问题,类似于在空间转换器网络中所做的工作(https://github.com/tensorflow/models/blob/ master/transformer/spatial_transformer.py)。我觉得我可能遗漏了一些非常简单的东西。这是我想要做的一个简化的玩具示例:

import tensorflow as tf
import numpy as np
lf = np.array([1.0,2.0,3.0])
lf_b = 2.0
sess = tf.InteractiveSession()
x = tf.placeholder(tf.float32, shape=(3))
pt = tf.Variable(0, name='point')
y_ = tf.placeholder(tf.float32, shape=())
sess.run(tf.initialize_all_variables())
y = tf.gather(x, pt)
data_loss = tf.reduce_mean(tf.squared_difference(y,y_))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(data_loss)
Run Code Online (Sandbox Code Playgroud)

目前,这会返回错误:ValueError: No gradients provided for any variable

cha*_*255 0

问题是你无法区分,因为 pt 是整数。它在 x 占位符中选择一个索引,因此它没有导数。通常,当您执行此操作时,您将输入一个整数并使用它来选择浮点值。你正在以相反的方式做这件事。