为什么LayerNormBasicLSTMCell比LSTMCell慢得多且准确度低?

Mao*_*hen 9 normalization lstm tensorflow

我最近发现LayerNormBasicLSTMCell是LSTM的一个版本,它实现了Layer Normalization和dropout.因此,我使用LSTMCell替换了我的原始代码和LayerNormBasicLSTMCell.这种变化不仅将测试精度从~96%降低到~92%,而且需要更长的时间(~33小时)进行训练(原始训练时间约为6小时).所有参数都相同:历元数(10),堆叠层数(3),隐藏矢量大小数(250),丢失保持概率(0.5),......硬件也相同.

我的问题是:我在这里做错了什么?

我原来的模型(使用LSTMCell):

# Batch normalization of the raw input
tf_b_VCCs_AMs_BN1 = tf.layers.batch_normalization(
    tf_b_VCCs_AMs, # the input vector, size [#batches, #time_steps, 2]
    axis=-1, # axis that should be normalized 
    training=Flg_training, # Flg_training = True during training, and False during test
    trainable=True,
    name="Inputs_BN"
    )

# Bidirectional dynamic stacked LSTM

##### The part I changed in the new model (start) #####
dropcells = []
for iiLyr in range(3):
    cell_iiLyr = tf.nn.rnn_cell.LSTMCell(num_units=250, state_is_tuple=True)
    dropcells.append(tf.nn.rnn_cell.DropoutWrapper(cell=cell_iiLyr, output_keep_prob=0.5))
##### The part I changed in the new model (end) #####

MultiLyr_cell = tf.nn.rnn_cell.MultiRNNCell(cells=dropcells, state_is_tuple=True)

outputs, states  = tf.nn.bidirectional_dynamic_rnn(
    cell_fw=MultiLyr_cell, 
    cell_bw=MultiLyr_cell,
    dtype=tf.float32,
    sequence_length=tf_b_lens, # the actual lengths of the input sequences (tf_b_VCCs_AMs_BN1)
    inputs=tf_b_VCCs_AMs_BN1,
    scope = "BiLSTM"
    )
Run Code Online (Sandbox Code Playgroud)

我的新模型(使用LayerNormBasicLSTMCell):

...
dropcells = []
for iiLyr in range(3):
    cell_iiLyr = tf.contrib.rnn.LayerNormBasicLSTMCell(
        num_units=250,
        forget_bias=1.0,
        activation=tf.tanh,
        layer_norm=True,
        norm_gain=1.0,
        norm_shift=0.0,
        dropout_keep_prob=0.5
        )
    dropcells.append(cell_iiLyr)
...
Run Code Online (Sandbox Code Playgroud)

Far*_*ian 2

关于训练时间:我看到这篇博文:http://olavnymoen.com/2016/07/07/rnn-batch-normalization。请参阅最后一个图。批量归一化的 lstm 比普通 lstm 慢 3 倍以上。作者认为原因是批量统计计算。

关于准确性:我不知道。

  • 我不明白为什么简单的统计计算(均值或方差)会花费这么多时间(尤其是使用 GPU)。 (2认同)