相关疑难解决方法(0)

Tensorflow 保存子类模型,该模型具有 call() 方法的多个参数

我正在关注 tensorflow 神经机器翻译教程:https ://www.tensorflow.org/tutorials/text/nmt_with_attention

我正在尝试保存作为 tf.keras.Model 子类的 Encoder 和 Decoder 模型,并在训练和推理期间正常工作,但是我想保存模型。当我尝试这样做时,我收到以下错误:

TypeError: call() missing 1 required positional argument: 'initial_state'
Run Code Online (Sandbox Code Playgroud)

这是代码:

class Encoder(tf.keras.Model):
    def __init__(self, vocab_size, embedding_matrix, n_units, batch_size):
        super(Encoder, self).__init__()
        self.n_units = n_units
        self.batch_size = batch_size

        self.embedding = Embedding(vocab_size, embedding_matrix.shape[1], weights=[embedding_matrix], trainable=True, mask_zero=True)
        self.lstm = LSTM(n_units, return_sequences=True, return_state=True, recurrent_initializer="glorot_uniform")

    def call(self, input_utterence, initial_state):
        input_embed = self.embedding(input_utterence)
        encoder_states, h1, c1 = self.lstm(input_embed, initial_state=initial_state)
        return encoder_states, h1, c1

    def create_initial_state(self):
        return tf.zeros((self.batch_size, self.n_units))

encoder = Encoder(vocab_size, embedding_matrix, LSTM_DIM, BATCH_SIZE)
# do …
Run Code Online (Sandbox Code Playgroud)

python machine-learning keras tensorflow tf.keras

3
推荐指数
1
解决办法
896
查看次数

标签 统计

keras ×1

machine-learning ×1

python ×1

tensorflow ×1

tf.keras ×1