Gab*_*Chu 6 tensorflow tensorflow-estimator
这里有人提出这个问题,区别在于我的问题是关注的Estimator
.
一些上下文:我们使用估计器训练了一个模型,并在Estimator中定义了一些变量input_fn
,该函数将数据预处理到批处理.现在,我们正在进行预测.在预测期间,我们使用相同的input_fn
方法读入和处理数据.但得到错误说变量(word_embeddings)不存在(变量存在于chkp图中),这里是相关的代码位input_fn
:
with tf.variable_scope('vocabulary', reuse=tf.AUTO_REUSE):
if mode == tf.estimator.ModeKeys.TRAIN:
word_to_index, word_to_vec = load_embedding(graph_params["word_to_vec"])
word_embeddings = tf.get_variable(initializer=tf.constant(word_to_vec, dtype=tf.float32),
trainable=False,
name="word_to_vec",
dtype=tf.float32)
else:
word_embeddings = tf.get_variable("word_to_vec", dtype=tf.float32)
Run Code Online (Sandbox Code Playgroud)
基本上,当它处于预测模式时,else
调用它来加载检查点中的变量.未能识别此变量表示a)范围的不当使用; b)图表未恢复.只要reuse
设置得当,我认为范围不重要.
我怀疑这是因为图表尚未恢复input_fn
阶段.通常,通过调用saver.restore(sess, "/tmp/model.ckpt")
引用来恢复图形.对估算器源代码的调查并没有得到任何与恢复有关的内容,最好的镜头是MonitoredSession,一个训练的包装器.它已经从最初的问题中伸展出来了,如果我走在正确的道路上,我就没有信心,如果有人有任何见解,我在这里寻求帮助.
我的问题的一行摘要:图表是如何在内部tf.estimator
,通过input_fn
或恢复的model_fn
?
您好,我认为您出现错误只是因为您没有在 tf.get_variable (在预测时)中指定形状,即使要恢复变量,您似乎也需要指定形状。
我使用简单的线性回归估计器进行了以下测试,只需预测 x + 5
def input_fn(mode):
def _input_fn():
with tf.variable_scope('all_input_fn', reuse=tf.AUTO_REUSE):
if mode == tf.estimator.ModeKeys.TRAIN:
var_to_follow = tf.get_variable('var_to_follow', initializer=tf.constant(20))
x_data = np.random.randn(1000)
labels = x_data + 5
return {'x':x_data}, labels
elif mode == tf.estimator.ModeKeys.PREDICT:
var_to_follow = tf.get_variable("var_to_follow", dtype=tf.int32, shape=[])
return {'x':[0,10,100,var_to_follow]}
return _input_fn
featcols = [tf.feature_column.numeric_column('x')]
model = tf.estimator.LinearRegressor(featcols, './outdir')
Run Code Online (Sandbox Code Playgroud)
这段代码工作得很好,const 的值为 20,并且为了好玩,在我的测试集中使用它来确认:p
但是,如果删除 shape=[] ,它就会中断,您还可以提供另一个初始值设定项,例如 tf.constant(500) ,一切都会正常工作,并且将使用 20 。
通过跑步
model.train(input_fn(tf.estimator.ModeKeys.TRAIN), max_steps=10000)
Run Code Online (Sandbox Code Playgroud)
和
preds = model.predict(input_fn(tf.estimator.ModeKeys.PREDICT))
print(next(preds))
Run Code Online (Sandbox Code Playgroud)
您可以可视化该图表,您将看到 a) 范围正常,b) 图表已恢复。
希望对你有帮助。
归档时间: |
|
查看次数: |
224 次 |
最近记录: |