Error when converting code from TF1 to TF2

Arg*_*rgs · 6 · tags: python, keras, tensorflow, attention-model, tensorflow2.0

The relevant values are:

rnn_size: 512
batch_size: 128


rnn_inputs: Tensor("embedding_lookup/Identity_1:0", shape=(?, ?, 128), dtype=float32)
sequence_length: Tensor("inputs_length:0", shape=(?,), dtype=int32)
cell_fw: <tensorflow.python.keras.layers.legacy_rnn.rnn_cell_impl.DropoutWrapper object at 0x7f4f534eb6d0>
cell_bw: <tensorflow.python.keras.layers.legacy_rnn.rnn_cell_impl.DropoutWrapper object at 0x7f4f534eb910>

The enc_state value is obtained with:

enc_output, enc_state = tf.compat.v1.nn.bidirectional_dynamic_rnn(
    cell_fw, cell_bw, rnn_inputs, sequence_length, dtype=tf.float32)

The resulting enc_state is:

enc_state: LSTMStateTuple(c=<tf.Tensor 'RNN_Encoder_Cell_2D/encoder_1/bidirectional_rnn/fw/fw/while/Exit_3:0' shape=(?, 512) dtype=float32>, h=<tf.Tensor 'RNN_Encoder_Cell_2D/encoder_1/bidirectional_rnn/fw/fw/while/Exit_4:0' shape=(?, 512) dtype=float32>)
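For comparison, a minimal sketch of a native TF2/Keras counterpart to the bidirectional_dynamic_rnn call, assuming the encoder can be a plain Keras LSTM instead of the DropoutWrapper-wrapped cells (the dropout settings are not reproduced). `tf.keras.layers.Bidirectional` with `return_state=True` returns the sequence output followed by the forward h, c and backward h, c states, and the sequence_length vector becomes a boolean mask. The batch size, sequence length and random inputs are illustrative placeholders, not the question's data.

```python
import tensorflow as tf

rnn_size = 512        # from the question
embedding_size = 128  # last dim of rnn_inputs in the question
batch_size = 4        # illustrative; the question uses 128
max_time = 10         # illustrative padded sequence length

# Placeholder stand-ins for the question's rnn_inputs and sequence_length.
rnn_inputs = tf.random.normal([batch_size, max_time, embedding_size])
sequence_length = tf.constant([10, 7, 5, 3], dtype=tf.int32)

# One LSTM per direction; Bidirectional creates the backward copy itself.
# (LSTM's dropout / recurrent_dropout arguments are one possible replacement
# for DropoutWrapper, but they are left out here.)
encoder = tf.keras.layers.Bidirectional(
    tf.keras.layers.LSTM(rnn_size, return_sequences=True, return_state=True))

# Boolean mask that plays the role of sequence_length: padded steps are skipped.
mask = tf.sequence_mask(sequence_length, maxlen=max_time)

# Output is concatenated over both directions; Keras LSTM states are ordered (h, c).
enc_output, fw_h, fw_c, bw_h, bw_c = encoder(rnn_inputs, mask=mask)
```

Note the state ordering: Keras LSTM layers report (h, c), whereas the LSTMStateTuple printed above is (c, h).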

TF1 code:

initial_state = tf.contrib.seq2seq.DynamicAttentionWrapperState(enc_state,
                                                                _zero_state_tensors(rnn_size, 
                                                                                    batch_size, 
                                                                                    tf.float32))

Converted to TF2:

initial_state = tfa.seq2seq.AttentionWrapper(enc_state,_zero_state_tensors(rnn_size, batch_size, tf.float32))

This raises the following error:


TypeError                                 Traceback (most recent call last)
<ipython-input-54-d87646b9df5d> in <module>()
      8                                                     threshold) 
      9             model = build_graph(keep_probability, rnn_size, num_layers, batch_size, 
---> 10                                 learning_rate, embedding_size, direction)
     11             train(model, epochs, log_string)

6 frames
/usr/local/lib/python3.7/dist-packages/typeguard/__init__.py in check_type(argname, value, expected_type, memo)
    596                 raise TypeError(
    597                     'type of {} must be {}; got {} instead'.
--> 598                     format(argname, qualified_name(expected_type), qualified_name(value)))
    599     elif isinstance(expected_type, TypeVar):
    600         # Only happens on < 3.6

TypeError: type of argument "cell" must be tensorflow.python.keras.engine.base_layer.Layer; got tensorflow.python.keras.layers.legacy_rnn.rnn_cell_impl.LSTMStateTuple instead

Could you also explain the last line of the error, i.e.:

    TypeError: type of argument "cell" must be tensorflow.python.keras.engine.base_layer.Layer; got tensorflow.python.keras.layers.legacy_rnn.rnn_cell_impl.LSTMStateTuple instead
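The last line comes from typeguard, which TensorFlow Addons uses to check type annotations at call time (the traceback ends inside `typeguard.check_type`). `AttentionWrapper.__init__` annotates its first parameter, `cell`, as a Keras `Layer`, but the first value passed was `enc_state`, an LSTMStateTuple, i.e. a namedtuple holding the `c` and `h` tensors rather than a layer. The stdlib-only sketch below imitates that check with hypothetical stand-ins (`Layer`, `LSTMCellLike`, `AttentionWrapperLike`, `typechecked`) so the mechanism is visible without TensorFlow; it is not typeguard's real implementation.

```python
import functools
import inspect
from collections import namedtuple


class Layer:
    """Hypothetical stand-in for tf.keras.layers.Layer."""


class LSTMCellLike(Layer):
    """Hypothetical stand-in for an RNN cell, which is a Layer subclass."""


# Same shape as the LSTMStateTuple in the error: a plain namedtuple of (c, h).
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])


def typechecked(func):
    """Tiny imitation of a runtime annotation check (not typeguard's code)."""
    sig = inspect.signature(func)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        for name, value in bound.arguments.items():
            expected = func.__annotations__.get(name)
            # Only plain classes are checked here; unannotated args are skipped.
            if isinstance(expected, type) and not isinstance(value, expected):
                raise TypeError(
                    'type of argument "{}" must be {}; got {} instead'.format(
                        name, expected.__qualname__, type(value).__qualname__))
        return func(*args, **kwargs)

    return wrapper


class AttentionWrapperLike:
    @typechecked
    def __init__(self, cell: Layer, attention_mechanism=None):
        self.cell = cell


# Passing a state where a cell is expected reproduces the shape of the error.
enc_state = LSTMStateTuple(c=0.0, h=0.0)
try:
    AttentionWrapperLike(enc_state)
except TypeError as err:
    message = str(err)

# Passing a Layer subclass as the cell satisfies the check.
wrapper = AttentionWrapperLike(LSTMCellLike())
```

In the real message the class names are fully qualified (tensorflow.python.keras.engine.base_layer.Layer and tensorflow.python.keras.layers.legacy_rnn.rnn_cell_impl.LSTMStateTuple), which is why it is so long. In short: a state was passed where a cell was expected, so the cell should go first and enc_state should be supplied through the wrapper's initial state.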