如何加快此Keras Attention计算?

mod*_*itt 13 python vectorization keras tensorflow

我写了一个自定义的keras层为AttentiveLSTMCellAttentiveLSTM(RNN)与keras'线方法RNNs.该注意机制由Bahdanau描述,其中,在编码器/解码器模型中,从编码器的所有输出和解码器的当前隐藏状态创建"上下文"矢量.然后,我将每个时间步的上下文向量附加到输入.

该模型用于制作Dialog Agent,但与架构中的NMT模型(类似任务)非常相似.

然而,在添加这种注意机制时,我已经放慢了我的网络5倍的训练速度,我真的想知道如何编写代码的一部分,这样可以更有效地减慢它的速度.

计算的主要内容在这里完成:

h_tm1 = states[0]  # previous memory state
c_tm1 = states[1]  # previous carry state

# attention mechanism

# repeat the hidden state to the length of the sequence
_stm = K.repeat(h_tm1, self.annotation_timesteps)

# multiplty the weight matrix with the repeated (current) hidden state
_Wxstm = K.dot(_stm, self.kernel_w)

# calculate the attention probabilities
# self._uh is of shape (batch, timestep, self.units)
et = K.dot(activations.tanh(_Wxstm + self._uh), K.expand_dims(self.kernel_v))

at = K.exp(et)
at_sum = K.sum(at, axis=1)
at_sum_repeated = K.repeat(at_sum, self.annotation_timesteps)
at /= at_sum_repeated  # vector of size (batchsize, timesteps, 1)

# calculate the context vector
context = K.squeeze(K.batch_dot(at, self.annotations, axes=1), axis=1)

# append the context vector to the inputs
inputs = K.concatenate([inputs, context])
Run Code Online (Sandbox Code Playgroud)

在(一次)的call方法中AttentiveLSTMCell.

完整的代码可以在这里找到.如果有必要提供一些数据和方法来与模型进行交互,那么我就可以做到.

有任何想法吗?当然,如果这里有一些聪明的话,我会在GPU上进行训练.

小智 1

我建议使用 relu 而不是 tanh 来训练模型,因为此操作的计算速度要快得多。这将节省您的计算时间,顺序为训练示例*每个示例的平均序列长度*时期数。

另外,我会评估附加上下文向量的性能改进,请记住这会减慢其他参数的迭代周期。如果它没有给你带来太大的改善,那么可能值得尝试其他方法。