Sta*_*tal 2 tensorflow tensorflow-estimator
我试图在tensorflow模型中使用现有的嵌入,嵌入的大小大于2Gb,这使得我最初的尝试不成功:
embedding_var = tf.get_variable(
"embeddings",
shape=GLOVE_MATRIX.shape,
initializer=tf.constant_initializer(np.array(GLOVE_MATRIX))
)
Run Code Online (Sandbox Code Playgroud)
这给了我这个错误:
Cannot create a tensor proto whose content is larger than 2GB.
Run Code Online (Sandbox Code Playgroud)
我正在使用基于Estimator API的AWS SageMaker,并且会话中实际运行的图形发生在场景后面,因此我不确定如何初始化一些占位符以进行嵌入.如果有人能够在EstimatorAPI方面分享如何进行这种初始化的方式会很有帮助.
Oli*_*rot 12
如果指定initializer参数to tf.get_variable(),则初始值GLOVE_MATRIX将存储在图形中并超过2Gb.一个好的答案解释了如何一般地加载嵌入.
这是我们使用初始化程序的第一个例子,图形def大约是4Mb,因为它将(1000, 1000)矩阵存储在其中.
size = 1000
initial_value = np.random.randn(size, size)
x = tf.get_variable("x", [size, size], initializer=tf.constant_initializer(initial_value))
sess = tf.Session()
sess.run(x.initializer)
assert np.allclose(sess.run(x), initial_value)
graph = tf.get_default_graph()
print(graph.as_graph_def().ByteSize()) # should be 4000394
Run Code Online (Sandbox Code Playgroud)
这是一个更好的版本,我们不存储它:
size = 1000
initial_value = np.random.randn(size, size)
x = tf.get_variable("x", [size, size])
sess = tf.Session()
sess.run(x.initializer, {x.initial_value: initial_value})
assert np.allclose(sess.run(x), initial_value)
graph = tf.get_default_graph()
print(graph.as_graph_def().ByteSize()) # should be 1203
Run Code Online (Sandbox Code Playgroud)
对于Estimators,我们无法直接访问Session.初始化嵌入的方法可以是使用tf.train.Scaffold.您可以向其传递一个参数,init_fn在该参数中初始化嵌入变量,而不保存图形def中的实际值.
def model_fn(features, labels, mode):
size = 10
initial_value = np.random.randn(size, size).astype(np.float32)
x = tf.get_variable("x", [size, size])
def init_fn(scaffold, sess):
sess.run(x.initializer, {x.initial_value: initial_value})
scaffold = tf.train.Scaffold(init_fn=init_fn)
loss = ...
train_op = ...
return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op, scaffold=scaffold)
Run Code Online (Sandbox Code Playgroud)
使用内置脚手架的一个好处是,它只会在您第一次调用时初始化嵌入train_input_fn.对于将来的呼叫,它将不会再次运行init_fn.
| 归档时间: |
|
| 查看次数: |
1559 次 |
| 最近记录: |