对于转移学习,通常使用网络作为特征提取器来创建特征的数据集,在该数据集上训练另一个分类器(例如,SVM).
我想使用Dataset API(tf.contrib.data)和实现它dataset.map():
# feature_extractor will create a CNN on top of the given tensor
def features(feature_extractor, ...):
dataset = inputs(...) # This creates a dataset of (image, label) pairs
def map_example(image, label):
features = feature_extractor(image, trainable=False)
# Leaving out initialization from a checkpoint here...
return features, label
dataset = dataset.map(map_example)
return dataset
Run Code Online (Sandbox Code Playgroud)
在为数据集创建迭代器时,执行此操作会失败.
ValueError: Cannot capture a stateful node by value.
Run Code Online (Sandbox Code Playgroud)
这是事实,网络的内核和偏见是变量,因此是有状态的.对于这个特殊的例子,他们不一定是这样.
有没有办法让Ops和特定tf.Variable对象无状态?
因为我正在使用tf.layers我不能简单地将它们创建为常量,并且设置trainable=False既不会创建常量,也不会将变量添加到GraphKeys.TRAINABLE_VARIABLES集合中.