mau*_*una 3 python machine-learning keras tensorflow
请考虑以下代码:
import tensorflow as tf
from tensorflow.python.estimator.model_fn import EstimatorSpec
from tensorflow.contrib.keras.api.keras.layers import Dense
def model_fn_1(features, labels, mode):
x = [[1]]
labels = [[10]]
m = tf.constant([[1, 2], [3, 4]], tf.float32)
lookup = tf.nn.embedding_lookup(m, x, name='embedding_lookup')
preds = Dense(1)(lookup)
loss = tf.reduce_mean(labels - preds)
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss, tf.train.get_global_step())
eval_metric_ops = {'accuracy': tf.metrics.accuracy(labels, preds)}
return EstimatorSpec(mode=mode, loss=loss, train_op=train_op, eval_metric_ops=eval_metric_ops)
model_1 = tf.estimator.Estimator(model_fn_1)
model_1.train(input_fn=lambda: None, steps=1)
Run Code Online (Sandbox Code Playgroud)
正如预期的那样,我可以执行model_1.train(input_fn=lambda: None, steps=1)多次,并且训练将从之前的执行继续进行.
现在,请考虑以下代码:
import tensorflow as tf
import numpy as np
from tensorflow.python.estimator.model_fn import EstimatorSpec
from tensorflow.contrib.keras.api.keras.layers import Embedding, Dense
def model_fn_2(features, labels, mode):
x = tf.constant([[1]])
labels = [[10]]
m = np.array([[1, 2], [3, 4]])
m = Embedding(2, 2, weights=[m], input_length=1, name='embedding_lookup')
lookup = m(x)
preds = Dense(1)(lookup)
loss = tf.reduce_mean(labels - preds)
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss, tf.train.get_global_step())
eval_metric_ops = {'accuracy': tf.metrics.accuracy(labels, preds)}
return EstimatorSpec(mode=mode, loss=loss, train_op=train_op, eval_metric_ops=eval_metric_ops)
model_2 = tf.estimator.Estimator(model_fn_2)
model_2.train(input_fn=lambda: None, steps=1)
Run Code Online (Sandbox Code Playgroud)
在这种情况下,我只能执行model_2.train(input_fn=lambda: None, steps=1)一次,当我再次尝试执行它时,我收到以下错误:
ValueError:Fetch参数不能解释为Tensor.(Tensor Tensor("embedding_lookup/embeddings:0",shape =(2,2),dtype = float32_ref)不是此图的元素.)
为什么会发生这种情况,我该如何解决?
这可能是tensorflow keras后端中的错误或不受支持的情况:会话全局缓存并且不会被清除.您可以通过以下方式手动清除它:
from tensorflow.contrib.keras.python.keras.backend import clear_session
clear_session()
Run Code Online (Sandbox Code Playgroud)
......在train调用之间.
原因很简单:第二次train调用使用新节点构建一个新图形,但是引擎盖下会话保存了上一个图形,这使得它们不兼容.
更新.在最新的tensorflow中,keras已被移动到另一个包中,现在它看起来更简单:
from keras.backend import clear_session
clear_session()
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
2777 次 |
| 最近记录: |