我在使用Tensorflow时还很陌生,这是我在谷歌搜索时发现的代码示例。我试图冻结该图,但是它说我需要输入正确的输出节点。由于我是新手,所以我很难理解它。如何在此代码中找到我的输出节点,还是需要冻结整个图?
import tensorflow as tf
import numpy as np
import sys
class Seq2Seq(object):
def __init__(self, xseq_len, yseq_len,
xvocab_size, yvocab_size,
emb_dim, num_layers, ckpt_path,
lr=0.0001,
epochs=200, model_name='seq2seq_model'):
# attach these arguments to self
self.xseq_len = xseq_len
self.yseq_len = yseq_len
self.ckpt_path = ckpt_path
self.epochs = epochs
self.model_name = model_name
# build thy graph
# attach any part of the graph that needs to be exposed, to the self
def __graph__():
# placeholders
tf.reset_default_graph()
# encoder inputs : list of indices of length xseq_len …Run Code Online (Sandbox Code Playgroud)