Rob*_*cok 5 word2vec tensorflow google-cloud-ml-engine
我正在解决文本分类问题.我使用Estimator自己的类定义了我的分类器model_fn.我想使用Google的预训练word2vec嵌入作为初始值,然后针对手头的任务进一步优化它.
我看到这篇文章:在TensorFlow中使用预先训练的单词嵌入(word2vec或Glove),
它解释了如何在'原始'TensorFlow代码中进行处理.但是,我真的很喜欢Estimator上课.
作为扩展,我想在Cloud ML Engine上训练此代码,是否有一种传递具有初始值的相当大的文件的好方法?
假设我们有类似的东西:
def build_model_fn():
def _model_fn(features, labels, mode, params):
input_layer = features['feat'] #shape=[-1, params["sequence_length"]]
#... what goes here to initialize W
embedded = tf.nn.embedding_lookup(W, input_layer)
...
return predictions
estimator = tf.contrib.learn.Estimator(
model_fn=build_model_fn(),
model_dir=MODEL_DIR,
params=params)
estimator.fit(input_fn=read_data, max_steps=2500)
Run Code Online (Sandbox Code Playgroud)
嵌入通常足够大,唯一可行的方法是使用它们来初始化tf.Variable图形中的a.这将允许您利用分布式等中的param服务器.
对于这个(以及其他任何东西),我建议你使用新的"核心"估算器,tf.estimator.Estimator因为这将使事情变得更容易.
从你提供的链接中的答案,并知道我们想要一个不是常数的变量,我们可以采取方法:
(2)使用feed dict初始化变量,或(3)从检查点加载变量
我将首先介绍选项(3),因为它更容易,更好:
在您的model_fn简单中,只需使用调用Tensor返回的变量初始化变量即可tf.contrib.framework.load_variable.这需要:
代码非常简单:
def model_fn(mode, features, labels, hparams):
embeddings = tf.Variable(tf.contrib.framework.load_variable(
'gs://my-bucket/word2vec_checkpoints/',
'a/fully/qualified/scope/embeddings'
))
....
return tf.estimator.EstimatorSpec(...)
Run Code Online (Sandbox Code Playgroud)
但是,如果嵌入不是由另一个TF模型生成的,那么这种方法对你不起作用,因此选项(2).
对于(2),我们需要使用tf.train.Scaffold它本质上是一个配置对象,它包含启动a的所有选项tf.Session(由于许多原因故意隐藏该估计器).
你可以Scaffold在tf.train.EstimatorSpec你的回归中指定一个model_fn.
我们创造我们的model_fn一个占位符,使之成为我们的嵌入变量初始化操作,然后传递一个init_feed_dict通过Scaffold.例如
def model_fn(mode, features, labels, hparams):
embed_ph = tf.placeholder(
shape=[hparams.vocab_size, hparams.embedding_size],
dtype=tf.float32)
embeddings = tf.Variable(embed_ph)
# Define your model
return tf.estimator.EstimatorSpec(
..., # normal EstimatorSpec args
scaffold=tf.train.Scaffold(init_feed_dict={embed_ph: my_embedding_numpy_array})
)
Run Code Online (Sandbox Code Playgroud)
这里发生的是init_feed_dict将embed_ph在运行时填充占位符的值,然后允许embeddings.initialization_op(占位符的赋值)运行.
| 归档时间: |
|
| 查看次数: |
1657 次 |
| 最近记录: |