在推理过程中从自定义 Tensorflow/Keras 层中提取中间变量 (TF 2.0)

Mar*_*cin 5 python deep-learning keras tensorflow

一点背景:

我主要使用 Tensorflow 2.0 的 Keras 功能模型位实现了 NLP 分类模型。模型架构是一个非常简单的 LSTM 网络,在 LSTM 和密集输出层之间添加了一个注意力层。注意层来自这个 Kaggle 内核(从第 51 行开始)。

我将训练好的模型包装在一个简单的 Flask 应用程序中,并获得了相当准确的预测。除了预测特定输入的类别外,我还输出来自上述注意力层的注意力权重向量“a”的值,以便我可以可视化应用于输入序列的权重。

我目前提取注意力权重变量的方法有效,但似乎效率低得令人难以置信,因为我正在预测输出类,然后使用中间 Keras 模型手动计算注意力向量。在 Flask 应用程序中,推理看起来像这样:

# Load the trained model
model = tf.keras.models.load_model('saved_model.h5')

# Extract the trained weights and biases of the trained attention layer
attention_weights = model.get_layer('attention').get_weights()

# Create an intermediate model that outputs the activations of the LSTM layer
intermediate_model = tf.keras.Model(inputs=model.input, outputs=model.get_layer('bi-lstm').output)

# Predict the output class using the trained model
model_score = model.predict(input)

# Obtain LSTM activations by predicting the output again using the intermediate model
lstm_activations = intermediate_model.predict(input)

# Use the intermediate LSTM activations and the trained model attention layer weights and biases to calculate the attention vector.  
# Maths from the custom Attention Layer (heavily modified for the sake of brevity)
eij = tf.keras.backend.dot(lstm_activations, attention_weights)
a = tf.keras.backend.exp(eij)
attention_vector = a
Run Code Online (Sandbox Code Playgroud)

我想我应该能够将注意力向量作为模型输出的一部分,但我正在努力弄清楚如何实现这一点。理想情况下,我会在一次前向传递中从自定义注意力层中提取注意力向量,而不是提取各种中间模型值并再次计算。

例如:

model_score = model.predict(input)

model_score[0] # The predicted class label or probability
model_score[1] # The attention vector, a
Run Code Online (Sandbox Code Playgroud)

我想我缺少一些关于 Tensorflow/Keras 如何抛出变量以及何时/如何访问这些值以包含为模型输出的基本知识。任何意见,将不胜感激。

Mar*_*cin 4

经过更多研究后,我设法拼凑出一个可行的解决方案。我将在这里为任何未来看到这篇文章的疲惫的互联网旅行者进行总结。

第一个线索来自这个 github 线程。 那里定义的注意力层似乎建立在前面提到的 Kaggle 内核中的注意力层的基础上。github 用户return_attention向层 init 添加了一个标志,启用后,除了层输出中的加权 RNN 输出向量之外,还包含注意力向量。

我还在同一 github 线程中添加了该用户get_config建议的功能,使我们能够保存和重新加载经过训练的模型。我必须将标志添加到,否则 TF 在尝试使用 加载保存的模型时会抛出列表迭代错误。 return_attentionget_configreturn_attention=True

进行这些更改后,需要更新模型定义以捕获附加层输出。

inputs = Input(shape=(max_sequence_length,))
lstm = Bidirectional(LSTM(lstm1_units, return_sequences=True))(inputs)
# Added 'attention_vector' to capture the second layer output
attention, attention_vector = Attention(max_sequence_length, return_attention=True)(lstm)
x = Dense(dense_units, activation="softmax")(attention)
Run Code Online (Sandbox Code Playgroud)

最后也是最重要的一块拼图来自Stackoverflow 的这个答案。 那里描述的方法允许我们输出多个结果,同时只优化其中之一。代码更改很微妙,但非常重要。我在下面为实现此功能所做的更改中添加了注释。

model = Model(
    inputs=inputs,
    outputs=[x, attention_vector] # Original value:  outputs=x
    )

model.compile(
    loss=['categorical_crossentropy', None], # Original value: loss='categorical_crossentropy'
    optimizer=optimizer,
    metrics=[BinaryAccuracy(name='accuracy')])
Run Code Online (Sandbox Code Playgroud)

完成这些更改后,我重新训练了模型,瞧!现在的输出model.predict()是一个包含分数及其相关注意力向量的列表。

改变的结果是相当戏剧性的。使用这种新方法对 10k 个示例进行推理大约需要 20 分钟。使用中间模型的旧方法需要约 33 分钟才能对同一数据集执行推理。

对于任何感兴趣的人,这是我修改后的注意力层:

from tensorflow.python.keras.layers import Layer
from tensorflow.keras import initializers, regularizers, constraints
from tensorflow.keras import backend as K


class Attention(Layer):
    def __init__(self, step_dim,
                W_regularizer=None, b_regularizer=None,
                W_constraint=None, b_constraint=None,
                bias=True, return_attention=True, **kwargs):
        self.supports_masking = True
        self.init = initializers.get('glorot_uniform')

        self.W_regularizer = regularizers.get(W_regularizer)
        self.b_regularizer = regularizers.get(b_regularizer)

        self.W_constraint = constraints.get(W_constraint)
        self.b_constraint = constraints.get(b_constraint)

        self.bias = bias

        self.step_dim = step_dim
        self.features_dim = 0
        self.return_attention = return_attention
        super(Attention, self).__init__(**kwargs)

    def build(self, input_shape):
        assert len(input_shape) == 3

        self.W = self.add_weight(shape=(input_shape[-1],),
                                 initializer=self.init,
                                 name='{}_W'.format(self.name),
                                 regularizer=self.W_regularizer,
                                 constraint=self.W_constraint)
        self.features_dim = input_shape[-1]

        if self.bias:
            self.b = self.add_weight(shape=(input_shape[1],),
                                     initializer='zero',
                                     name='{}_b'.format(self.name),
                                     regularizer=self.b_regularizer,
                                     constraint=self.b_constraint)
        else:
            self.b = None

        self.built = True

    def compute_mask(self, input, input_mask=None):
        return None

    def call(self, x, mask=None):
        features_dim = self.features_dim
        step_dim = self.step_dim

        eij = K.reshape(K.dot(K.reshape(x, (-1, features_dim)),
                              K.reshape(self.W, (features_dim, 1))), (-1, step_dim))

        if self.bias:
            eij += self.b

        eij = K.tanh(eij)

        a = K.exp(eij)

        if mask is not None:
            a *= K.cast(mask, K.floatx())

        a /= K.cast(K.sum(a, axis=1, keepdims=True) + K.epsilon(), K.floatx())

        a = K.expand_dims(a)
        weighted_input = x * a
        result = K.sum(weighted_input, axis=1)

        if self.return_attention:
            return [result, a]
        return result

    def compute_output_shape(self, input_shape):
        if self.return_attention:
            return [(input_shape[0], self.features_dim),
                    (input_shape[0], input_shape[1])]
        else:
            return input_shape[0], self.features_dim

    def get_config(self):
        config = {
            'step_dim': self.step_dim,
            'W_regularizer': regularizers.serialize(self.W_regularizer),
            'b_regularizer': regularizers.serialize(self.b_regularizer),
            'W_constraint': constraints.serialize(self.W_constraint),
            'b_constraint': constraints.serialize(self.b_constraint),
            'bias': self.bias,
            'return_attention': self.return_attention
        }

        base_config = super(Attention, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
Run Code Online (Sandbox Code Playgroud)