小编the*_*eck的帖子

TensorFlow:tf.contrib.data API中的"无法按值捕获有状态节点"

对于转移学习,通常使用网络作为特征提取器来创建特征的数据集,在该数据集上训练另一个分类器(例如,SVM).

我想使用Dataset API(tf.contrib.data)和实现它dataset.map():

# feature_extractor will create a CNN on top of the given tensor
def features(feature_extractor, ...):
    dataset = inputs(...)  # This creates a dataset of (image, label) pairs

    def map_example(image, label):
        features = feature_extractor(image, trainable=False)
        #  Leaving out initialization from a checkpoint here... 
        return features, label

    dataset = dataset.map(map_example)

    return dataset
Run Code Online (Sandbox Code Playgroud)

在为数据集创建迭代器时,执行此操作会失败.

ValueError: Cannot capture a stateful node by value.
Run Code Online (Sandbox Code Playgroud)

这是事实,网络的内核和偏见是变量,因此是有状态的.对于这个特殊的例子,他们不一定是这样.

有没有办法让Ops和特定tf.Variable对象无状态?

因为我正在使用tf.layers我不能简单地将它们创建为常量,并且设置trainable=False既不会创建常量,也不会将变量添加到GraphKeys.TRAINABLE_VARIABLES集合中.

tensorflow tensorflow-datasets

7
推荐指数
1
解决办法
2787
查看次数

标签 统计

tensorflow ×1

tensorflow-datasets ×1