TensorFlow:tf.contrib.data API中的"无法按值捕获有状态节点"

the*_*eck 7 tensorflow tensorflow-datasets

对于转移学习,通常使用网络作为特征提取器来创建特征的数据集,在该数据集上训练另一个分类器(例如,SVM).

我想使用Dataset API(tf.contrib.data)和实现它dataset.map():

# feature_extractor will create a CNN on top of the given tensor
def features(feature_extractor, ...):
    dataset = inputs(...)  # This creates a dataset of (image, label) pairs

    def map_example(image, label):
        features = feature_extractor(image, trainable=False)
        #  Leaving out initialization from a checkpoint here... 
        return features, label

    dataset = dataset.map(map_example)

    return dataset
Run Code Online (Sandbox Code Playgroud)

在为数据集创建迭代器时,执行此操作会失败.

ValueError: Cannot capture a stateful node by value.
Run Code Online (Sandbox Code Playgroud)

这是事实,网络的内核和偏见是变量,因此是有状态的.对于这个特殊的例子,他们不一定是这样.

有没有办法让Ops和特定tf.Variable对象无状态?

因为我正在使用tf.layers我不能简单地将它们创建为常量,并且设置trainable=False既不会创建常量,也不会将变量添加到GraphKeys.TRAINABLE_VARIABLES集合中.

mrr*_*rry 14

不幸的是,tf.Variable本质上是有状态的.但是,只有在Dataset.make_one_shot_iterator()用于创建迭代器时才会出现此错误.*为了避免此问题,您可以使用Dataset.make_initializable_iterator()警告,在为输入管道中使用的对象运行初始化程序之后,还必须iterator.initializer在返回的迭代器上运行.tf.Variable


*此限制的原因是它用于封装数据集定义的实现细节Dataset.make_one_shot_iterator()和正在进行的TensorFlow函数(Defun)支持.由于使用查找表和变量等有状态资源比我们最初想象的更受欢迎,我们正在研究如何放松这种限制.