Arg*_*rgs · Tags: python, keras, tensorflow, attention-model, tensorflow2.0
The relevant values are:
rnn_size: 512
batch_size: 128
rnn_inputs: Tensor("embedding_lookup/Identity_1:0", shape=(?, ?, 128), dtype=float32)
sequence_length: Tensor("inputs_length:0", shape=(?,), dtype=int32)
cell_fw: <tensorflow.python.keras.layers.legacy_rnn.rnn_cell_impl.DropoutWrapper object at 0x7f4f534eb6d0>
cell_bw: <tensorflow.python.keras.layers.legacy_rnn.rnn_cell_impl.DropoutWrapper object at 0x7f4f534eb910>
Getting the enc_state value:
enc_output, enc_state = tf.compat.v1.nn.bidirectional_dynamic_rnn(
    cell_fw, cell_bw, rnn_inputs, sequence_length, dtype=tf.float32)
The resulting enc_state is:
enc_state: LSTMStateTuple(c=<tf.Tensor 'RNN_Encoder_Cell_2D/encoder_1/bidirectional_rnn/fw/fw/while/Exit_3:0' shape=(?, 512) dtype=float32>, h=<tf.Tensor 'RNN_Encoder_Cell_2D/encoder_1/bidirectional_rnn/fw/fw/while/Exit_4:0' shape=(?, 512) dtype=float32>)
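The tensor names (.../fw/fw/...) show that this is the forward-direction state, an LSTMStateTuple holding c and h, each of shape (batch, 512). For comparison, here is a minimal sketch of a TF2-native encoder that produces the same kind of states without tf.compat.v1. It assumes Keras layers are acceptable in place of the DropoutWrapper cells (LSTM's own `dropout` argument could play that role) and uses small random inputs as hypothetical stand-ins for the embedding lookup.

```python
import tensorflow as tf

# Illustrative sizes: rnn_size and embedding dim match the question; batch and time are small
batch_size, max_time, embed_dim, rnn_size = 4, 7, 128, 512

rnn_inputs = tf.random.normal([batch_size, max_time, embed_dim])  # stand-in for embedding_lookup
sequence_length = tf.constant([7, 5, 3, 7], dtype=tf.int32)

# Boolean mask so padded steps are skipped, similar to sequence_length in the TF1 call
mask = tf.sequence_mask(sequence_length, maxlen=max_time)

encoder = tf.keras.layers.Bidirectional(
    tf.keras.layers.LSTM(rnn_size, return_sequences=True, return_state=True),
    merge_mode="concat")

# With return_state=True the layer returns [outputs, fw_h, fw_c, bw_h, bw_c]
enc_output, fw_h, fw_c, bw_h, bw_c = encoder(rnn_inputs, mask=mask)

print(enc_output.shape)        # (4, 7, 1024): forward and backward outputs concatenated
print(fw_h.shape, fw_c.shape)  # (4, 512) (4, 512)
```

Note the ordering: Keras returns h before c, whereas LSTMStateTuple stores (c, h).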
TF1 code:
initial_state = tf.contrib.seq2seq.DynamicAttentionWrapperState(
    enc_state, _zero_state_tensors(rnn_size, batch_size, tf.float32))
Converted to TF2:
initial_state = tfa.seq2seq.AttentionWrapper(enc_state,_zero_state_tensors(rnn_size, batch_size, tf.float32))
I get this error:
TypeError Traceback (most recent call last)
<ipython-input-54-d87646b9df5d> in <module>()
8 threshold)
9 model = build_graph(keep_probability, rnn_size, num_layers, batch_size,
---> 10 learning_rate, embedding_size, direction)
11 train(model, epochs, log_string)
6 frames
/usr/local/lib/python3.7/dist-packages/typeguard/__init__.py in check_type(argname, value, expected_type, memo)
596 raise TypeError(
597 'type of {} must be {}; got {} instead'.
--> 598 format(argname, qualified_name(expected_type), qualified_name(value)))
599 elif isinstance(expected_type, TypeVar):
600 # Only happens on < 3.6
TypeError: type of argument "cell" must be tensorflow.python.keras.engine.base_layer.Layer; got tensorflow.python.keras.layers.legacy_rnn.rnn_cell_impl.LSTMStateTuple instead
Could you also explain the last line of the error, i.e.:
TypeError: type of argument "cell" must be tensorflow.python.keras.engine.base_layer.Layer; got tensorflow.python.keras.layers.legacy_rnn.rnn_cell_impl.LSTMStateTuple instead
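The traceback shows the message comes from typeguard, which checks call arguments against their type annotations at runtime. tfa.seq2seq.AttentionWrapper takes `cell` as its first parameter and `attention_mechanism` as its second, and `cell` must be a Keras Layer. The converted line passes enc_state positionally, so it is bound to `cell`; an LSTMStateTuple is only a container of (c, h) tensors, not a Layer, hence the TypeError. The sketch below reproduces the same kind of check with hypothetical stand-in classes. They are not the real TensorFlow or typeguard code, just enough structure to show why the argument is rejected.

```python
from collections import namedtuple

# Hypothetical stand-ins, NOT the real TensorFlow classes
class Layer:                      # plays the role of tf.keras.layers.Layer
    pass

class LSTMCell(Layer):            # an RNN cell is a Layer subclass
    pass

# A state tuple is plain data, not a Layer
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])

def attention_wrapper(cell, attention_mechanism):
    """Hypothetical mimic of a runtime check on an annotated `cell: Layer` parameter."""
    if not isinstance(cell, Layer):
        raise TypeError('type of argument "cell" must be Layer; got {} instead'
                        .format(type(cell).__name__))
    return cell, attention_mechanism

enc_state = LSTMStateTuple(c=0.0, h=0.0)
message = None
try:
    # Positional call: enc_state is bound to the `cell` parameter
    attention_wrapper(enc_state, None)
except TypeError as err:
    message = str(err)

print(message)  # type of argument "cell" must be Layer; got LSTMStateTuple instead

cell, _ = attention_wrapper(LSTMCell(), None)  # a real cell passes the check
```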