加载预先训练的word2vec以在Estimator model_fn中初始化embedding_lookup

Rob*_*cok 5 word2vec tensorflow google-cloud-ml-engine

我正在解决文本分类问题.我使用Estimator自己的类定义了我的分类器model_fn.我想使用Google的预训练word2vec嵌入作为初始值,然后针对手头的任务进一步优化它.

我看到这篇文章:在TensorFlow中使用预先训练的单词嵌入(word2vec或Glove),
它解释了如何在'原始'TensorFlow代码中进行处理.但是,我真的很喜欢Estimator上课.

作为扩展,我想在Cloud ML Engine上训练此代码,是否有一种传递具有初始值的相当大的文件的好方法?

假设我们有类似的东西:

def build_model_fn():
    def _model_fn(features, labels, mode, params):
        input_layer = features['feat'] #shape=[-1, params["sequence_length"]]
        #... what goes here to initialize W

        embedded = tf.nn.embedding_lookup(W, input_layer)
        ...
        return predictions

estimator = tf.contrib.learn.Estimator(
    model_fn=build_model_fn(),
    model_dir=MODEL_DIR,
    params=params)
estimator.fit(input_fn=read_data, max_steps=2500)
Run Code Online (Sandbox Code Playgroud)

Eli*_*xby 9

嵌入通常足够大,唯一可行的方法是使用它们来初始化tf.Variable图形中的a.这将允许您利用分布式等中的param服务器.

对于这个(以及其他任何东西),我建议你使用新的"核心"估算器,tf.estimator.Estimator因为这将使事情变得更容易.

从你提供的链接中的答案,并知道我们想要一个不是常数的变量,我们可以采取方法:

(2)使用feed dict初始化变量,或(3)从检查点加载变量


我将首先介绍选项(3),因为它更容易,更好:

在您的model_fn简单中,只需使用调用Tensor返回的变量初始化变量即可tf.contrib.framework.load_variable.这需要:

  1. 您的嵌入有一个有效的TF检查点
  2. 您知道检查点中嵌入变量的完全限定名称.

代码非常简单:

def model_fn(mode, features, labels, hparams):
  embeddings = tf.Variable(tf.contrib.framework.load_variable(
      'gs://my-bucket/word2vec_checkpoints/',
      'a/fully/qualified/scope/embeddings'
  ))
  ....
  return tf.estimator.EstimatorSpec(...)
Run Code Online (Sandbox Code Playgroud)

但是,如果嵌入不是由另一个TF模型生成的,那么这种方法对你不起作用,因此选项(2).


对于(2),我们需要使用tf.train.Scaffold它本质上是一个配置对象,它包含启动a的所有选项tf.Session(由于许多原因故意隐藏该估计器).

你可以Scaffoldtf.train.EstimatorSpec你的回归中指定一个model_fn.

我们创造我们的model_fn一个占位符,使之成为我们的嵌入变量初始化操作,然后传递一个init_feed_dict通过Scaffold.例如

def model_fn(mode, features, labels, hparams):
  embed_ph = tf.placeholder(
      shape=[hparams.vocab_size, hparams.embedding_size], 
      dtype=tf.float32)
  embeddings = tf.Variable(embed_ph)
  # Define your model
  return tf.estimator.EstimatorSpec(
      ..., # normal EstimatorSpec args
      scaffold=tf.train.Scaffold(init_feed_dict={embed_ph: my_embedding_numpy_array})
  )
Run Code Online (Sandbox Code Playgroud)

这里发生的是init_feed_dictembed_ph在运行时填充占位符的值,然后允许embeddings.initialization_op(占位符的赋值)运行.