在body中创建变量时正确使用tf.while_loop

Ale*_*lia 6 python tensorflow

我在Tensorflow中使用了一个while_loop,以迭代张量并提取给定维度上的特定切片.对于每个步骤,我需要使用解码器RNN来生成输出符号序列.我正在使用tf.contrib.seq2seq中提供的代码,特别是tf.contrib.seq2seq.dynamic_decode.代码类似于以下内容:

def decoder_condition(i, data, source_seq_len, ta_outputs):
    return tf.less(i, max_loop_len)

def decode_body(i, data, source_seq_len, ta_outputs):
    curr_data = data[:, i, :]
    curr_source_seq_len = source_seq_len[:, i, :]
    attention_mechanism = tf.contrib.seq2seq.LuongAttention(
        2 * self.opt["encoder_rnn_h_size"],
        curr_data,
        memory_sequence_length=curr_source_seq_len
    )
    cell = GRUCell(num_units)
    cell = AttentionWrapper(cell, attention_mechanism)
    # ... other code that initialises all the variables required
    # for the RNN decoder
    outputs = tf.contrib.seq2seq.dynamic_decode(
        decoder,
        maximum_iterations=self.opt["max_sys_seq_len"],
        swap_memory=True
    )
    with tf.control_dependencies([outputs)]:
        ta_outputs = ta_outputs.write(i, outputs)

    return i+1, data, ta_outputs

 loop_index = tf.constant(0)
 gen_outputs = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)
 outputs = tf.while_loop(
      decoder_condition,
      decoder_body,
      loop_vars=[
          loop_index,
          data,
          data_source_len,
          ta_outputs
      ],
      swap_memory=True,
      back_prop=True, 
      parallel_iterations=1
)
Run Code Online (Sandbox Code Playgroud)

正如您所看到的,我创建了不同的对象,这些对象特别依赖于当前步骤i的输入.我正在tf.AUTO_REUSE我当前的变量范围中使用,即使我正在创建不同的对象,也会重用变量.不幸的是,我的解码器似乎没有正确训练,因为它不断产生不正确的值.我已经检查了解码器RNN的输入数据,一切都正确.我怀疑在TensorFlow如何管理TensorArray和while_loop方面,我做得不好.

所以我的主要问题是:

  1. TensorFlow是否正确传播了在while循环中创建的每个变量的渐变?
  2. 是否有可能在while循环中创建依赖于使用循环索引获得的Tensor的特定切片的对象?
  3. backprop参数是否保证在训练期间传播梯度?推理期间应该设置为False吗?
  4. 一般来说,是否有任何健全性检查,我可以用来发现我的实施中可能出现的错误?

谢谢!

更新:不确定为什么,但似乎存在一个与此有关的问题,这与在while循环中调用自定义操作的可能性有关,如下所述:https://github.com/tensorflow/tensorflow/issues/13616.不幸的是,我不知道TensorFlow的内部是否足以判断它是否与此完全相关.

更新2:我解决了使用PyTorch :)

Ale*_*sos 0

(1) 是

(2) 是的,只需使用循环索引对张量进行切片

(3) 在普通用例中不需要设置 backprop=False

(4) 使用 ML 模型的常见操作(玩具数据集、单独测试部件等)

重新更新2,尝试使用eagerexecution或tf.contrib.autograph;两者都应该让你用普通的 python 编写 while 循环。