`get_variable()`不识别tf.estimator的现有变量

Gab*_*Chu 6 tensorflow tensorflow-estimator

这里有人提出这个问题,区别在于我的问题是关注的Estimator.

一些上下文:我们使用估计器训练了一个模型,并在Estimator中定义了一些变量input_fn,该函数将数据预处理到批处理.现在,我们正在进行预测.在预测期间,我们使用相同的input_fn方法读入和处理数据.但得到错误说变量(word_embeddings)不存在(变量存在于chkp图中),这里是相关的代码位input_fn:

with tf.variable_scope('vocabulary', reuse=tf.AUTO_REUSE):
    if mode == tf.estimator.ModeKeys.TRAIN:
        word_to_index, word_to_vec = load_embedding(graph_params["word_to_vec"])
        word_embeddings = tf.get_variable(initializer=tf.constant(word_to_vec, dtype=tf.float32),
                                          trainable=False,
                                          name="word_to_vec",
                                          dtype=tf.float32)
    else:
        word_embeddings = tf.get_variable("word_to_vec", dtype=tf.float32)
Run Code Online (Sandbox Code Playgroud)

基本上,当它处于预测模式时,else调用它来加载检查点中的变量.未能识别此变量表示a)范围的不当使用; b)图表未恢复.只要reuse设置得当,我认为范围不重要.

我怀疑这是因为图表尚未恢复input_fn阶段.通常,通过调用saver.restore(sess, "/tmp/model.ckpt") 引用来恢复图形.对估算器源代码的调查并没有得到任何与恢复有关的内容,最好的镜头是MonitoredSession,一个训练的包装器.它已经从最初的问题中伸展出来了,如果我走在正确的道路上,我就没有信心,如果有人有任何见解,我在这里寻求帮助.

我的问题的一行摘要:图表是如何在内部tf.estimator,通过input_fn或恢复的model_fn

abc*_*ire 1

您好,我认为您出现错误只是因为您没有在 tf.get_variable (在预测时)中指定形状,即使要恢复变量,您似乎也需要指定形状。

我使用简单的线性回归估计器进行了以下测试,只需预测 x + 5

def input_fn(mode):
    def _input_fn():
        with tf.variable_scope('all_input_fn', reuse=tf.AUTO_REUSE):
            if mode == tf.estimator.ModeKeys.TRAIN:
                var_to_follow = tf.get_variable('var_to_follow', initializer=tf.constant(20))
                x_data = np.random.randn(1000)
                labels = x_data + 5
                return {'x':x_data}, labels
            elif mode == tf.estimator.ModeKeys.PREDICT:
                var_to_follow = tf.get_variable("var_to_follow", dtype=tf.int32, shape=[])
                return {'x':[0,10,100,var_to_follow]}
    return _input_fn

featcols = [tf.feature_column.numeric_column('x')]
model = tf.estimator.LinearRegressor(featcols, './outdir')
Run Code Online (Sandbox Code Playgroud)

这段代码工作得很好,const 的值为 20,并且为了好玩,在我的测试集中使用它来确认:p

但是,如果删除 shape=[] ,它就会中断,您还可以提供另一个初始值设定项,例如 tf.constant(500) ,一切都会正常工作,并且将使用 20 。

通过跑步

model.train(input_fn(tf.estimator.ModeKeys.TRAIN), max_steps=10000)
Run Code Online (Sandbox Code Playgroud)

preds = model.predict(input_fn(tf.estimator.ModeKeys.PREDICT))
print(next(preds))
Run Code Online (Sandbox Code Playgroud)

您可以可视化该图表,您将看到 a) 范围正常,b) 图表已恢复。

希望对你有帮助。