尽管对 keras 中的零填充小批量 LSTM 训练进行掩码支持,但零预测

Des*_*eux 5 python deep-learning lstm keras mini-batch

问题陈述

我在带有 tensorflow 后端(tf 版本 1.13.1)的 keras 中训练多对多 LSTM,用于标记文本序列,以使用预训练的 GloVe 嵌入来预测序列中每个元素的标记。我的训练方案涉及小批量随机梯度下降,每个小批量矩阵按列填充零以确保输入到网络的长度相等。

至关重要的是,由于任务和数据的性质对我的小批量进行自定义限制,我没有使用 keras 嵌入层。我的目标是为我的零填充单元实现屏蔽机制,以确保损失计算不会将这些单元虚假地视为真正的数据点。

方法

正如 keras文档中所解释的,keras 提供了三种设置掩蔽层的方法:

  1. keras.layers.Embedding图层配置mask_zeroTrue
  2. 添加keras.layers.Masking图层;
  3. 调用循环层时手动传递掩码参数。

因为我没有使用嵌入层来编码我的训练数据,所以我无法使用带有掩码嵌入层的选项 (1)。因此,我选择了 (2) 并在初始化我的模型后立即添加了一个遮罩层。然而,这种变化似乎没有产生影响。事实上,不仅我的模型的准确性没有提高,在预测阶段模型仍然产生零预测。为什么我的屏蔽层不屏蔽零填充的单元格?这是否与在我的密集层中指定 3 个类而不是 2 个(因此将 0 作为单独的类包括在内)的事实有关?

现有资源的限制

已经提出并回答了类似的问题,但我无法使用它们来解决我的问题。虽然这篇文章没有得到直接回应,但评论中提到的一个链接帖子侧重于如何预处理数据以分配掩码值,这在这里没有争议。然而,掩蔽层初始化与此处使用的相同。这篇文章提到了同样的问题 - 遮罩层对性能没有影响 - 答案以与我相同的方式定义遮罩层,但再次侧重于将特定值转换为遮罩值。最后,这篇文章中的答案提供了相同的层初始化,而无需进一步详细说明。

玩具数据生成

为了重现我的问题,我生成了一个包含两个类 (1,2) 的玩具 10 批次数据集。批是一个可变长度的序列,后填充零,最大长度为 20 个嵌入,每个嵌入向量由 5 个单元组成,因此input_shape=(20,5). 这两个类的嵌入值是从不同但部分重叠的截断正态分布生成的,从而为网络创建了一个可学习但并非微不足道的问题。我在下面包含了玩具数据,以便您可以重现该问题。

import pandas as pd
from keras.models import Sequential
from keras.layers import LSTM, Dense, TimeDistributed, Bidirectional, Dropout, Masking
from keras import optimizers

# *** model initialization ***

model = Sequential()
model.add(Masking(mask_value=0., input_shape=(20, 5))) # <- masking layer here
model.add(Bidirectional(LSTM(20, return_sequences=True), input_shape=(20, 5)))
model.add(Dropout(0.2))
model.add(TimeDistributed(Dense(3, activation='sigmoid')))

sgd = optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='sparse_categorical_crossentropy', optimizer=sgd, metrics=['mse'])

# *** model training ***

for epoch in range(10):

    for X,y in data_train:

        X = X.reshape(1, 20, 5)
        y = y.reshape(1, 20, 1)

        history = model.fit(X, y, epochs=1, batch_size=20, verbose=0)

# *** model prediction ***

preds = pd.DataFrame(columns=['true', 'pred'])

for index, (X,y) in enumerate(data_test):
    X = X.reshape(1, 20, 5)
    y = y.reshape(1, 20, 1)

    y_pred = model.predict_classes(X, verbose=0)

    df = pd.DataFrame(columns=['true', 'pred'])

    df['true'] = [y[0, i][0] for i in range(20)]
    df['pred'] = [y_pred[0, i] for i in range(20)]

    preds = preds.append(df, ignore_index=True)

# convert true labels to int & drop padded rows (where y_true=0)
preds['true'] = [int(label) for label in preds['true']]
preds = preds[preds['true']!=0]
Run Code Online (Sandbox Code Playgroud)

这是带有掩码的模型的摘要:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
masking_2 (Masking)          (None, 20, 5)             0         
_________________________________________________________________
bidirectional_4 (Bidirection (None, 20, 40)            4160      
_________________________________________________________________
dropout_4 (Dropout)          (None, 20, 40)            0         
_________________________________________________________________
time_distributed_4 (TimeDist (None, 20, 3)             123       
=================================================================
Total params: 4,283
Trainable params: 4,283
Non-trainable params: 0
Run Code Online (Sandbox Code Playgroud)

我训练了一个模型,一个模型没有掩蔽层,并使用以下方法计算精度:

np.round(sum(preds['true']==preds['pred'])/len(preds)*100,1)
Run Code Online (Sandbox Code Playgroud)

没有遮罩的模型准确率为 53.3%,有遮罩的模型准确率为 33.3%。更令人惊讶的是,在两个模型中,我一直将零作为预测标签。为什么屏蔽层无法忽略零填充单元格?


重现问题的数据

data_train = list(zip(X_batches_train, y_batches_train))
data_test = list(zip(X_batches_test, y_batches_test))
Run Code Online (Sandbox Code Playgroud)

X_batches_train

[array([[-1.00612917,  1.47313952,  2.68021318,  1.54875809,  0.98385996,
          1.49465265,  0.60429106,  1.12396908, -0.24041602,  1.77266187,
          0.1961381 ,  1.28019637,  1.78803092,  2.05151245,  0.93606708,
          0.51554755,  0.        ,  0.        ,  0.        ,  0.        ],
        [-0.97596563,  2.04536053,  0.88367922,  1.013342  , -0.16605355,
          3.02994344,  2.04080806, -0.25153046, -0.5964068 ,  2.9607247 ,
         -0.49722121,  0.02734492,  2.16949987,  2.77367066,  0.15628842,
          2.19823207,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.31546283,  3.27420503,  3.23550769, -0.63724013,  0.89150128,
          0.69774266,  2.76627308, -0.58408384, -0.45681779,  1.98843041,
         -0.31850477,  0.83729882,  0.45471165,  3.61974147, -1.45610756,
          1.35217453,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 1.03329532,  1.97471646,  1.33949611,  1.22857243, -1.46890642,
          1.74105506,  1.40969261,  0.52465603, -0.18895266,  2.81025597,
          2.64901037, -0.83415186,  0.76956826,  1.48730868, -0.16190164,
          2.24389007,  0.        ,  0.        ,  0.        ,  0.        ],
        [-1.0676654 ,  3.08429323,  1.7601179 ,  0.85448051,  1.15537064,
          2.82487842,  0.27891413,  0.57842569, -0.62392063,  1.00343057,
          1.15348843, -0.37650332,  3.37355345,  2.22285473,  0.43444434,
          0.15743873,  0.        ,  0.        ,  0.        ,  0.        ]]),
 array([[ 1.05258873, -0.17897376, -0.99932932, -1.02854121,  0.85159208,
          2.32349131,  1.96526709, -0.08398597, -0.69474809,  1.32820222,
          1.19514151,  1.56814867,  0.86013263,  1.48342922,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.1920635 , -0.48702788,  1.24353985, -1.3864121 ,  0.16713229,
          3.10134683,  0.61658271, -0.63360643,  0.86000807,  2.74876157,
          2.87604877,  0.16339724,  2.87595396,  3.2846962 ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.1380241 , -0.76783029,  0.18814436, -1.18165209, -0.02981728,
          1.49908113,  0.61521007, -0.98191097,  0.31250199,  1.39015803,
          3.16213211, -0.70891214,  3.83881766,  1.92683533,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 1.39080778, -0.59179216,  0.80348201,  0.64638205, -1.40144268,
          1.49751413,  3.0092166 ,  1.33099666,  1.43714841,  2.90734268,
          3.09688943,  0.32934884,  1.14592787,  1.58152023,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [-0.77164353,  0.50293096,  0.0717377 ,  0.14487556, -0.90246591,
          2.32612179,  1.98628857,  1.29683166, -0.12399569,  2.60184685,
          3.20136653,  0.44056647,  0.98283455,  1.79026663,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ]]),
 array([[-0.93359914,  2.31840281,  0.55691601,  1.90930758, -1.58260431,
         -1.05801881,  3.28012523,  3.84105406, -1.2127093 ,  0.00490079,
          1.28149304,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [-1.03105486,  2.7703693 ,  0.16751813,  1.12127987, -0.44070271,
         -0.0789227 ,  2.79008301,  1.11456745,  1.13982551, -1.10128658,
          0.87430834,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [-0.69710668,  1.72702833, -2.62599502,  2.34730002,  0.77756661,
          0.16415884,  3.30712178,  1.67331828, -0.44022431,  0.56837829,
          1.1566811 ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [-0.71845983,  1.79908544,  0.37385522,  1.3870915 , -1.48823234,
         -1.487419  ,  3.0879945 ,  1.74617784, -0.91538815, -0.24244522,
          0.81393954,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [-1.38501563,  3.73330047, -0.52494265,  2.37133716, -0.24546709,
         -0.28360782,  2.89384717,  2.42891743,  0.40144022, -1.21850571,
          2.00370751,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ]]),
 array([[ 1.27989188,  1.16254538, -0.06889142,  1.84133355,  1.3234908 ,
          1.29611702,  2.0019294 , -0.03220116,  1.1085194 ,  1.96495985,
          1.68544302,  1.94503544,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 1.3004439 ,  2.48768923,  0.59809607,  2.38155155,  2.78705889,
          1.67018683,  0.21731778, -0.59277191,  2.87427207,  2.63950475,
          2.39211459,  0.93083423,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 2.39239371,  0.30900383, -0.97307155,  1.98100711,  0.30613735,
          1.12827171,  0.16987791,  0.31959096,  1.30366416,  1.45881023,
          2.45668401,  0.5218711 ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.0826574 ,  2.05100254,  0.013161  ,  2.95120798,  1.15730011,
          0.75537024,  0.13708569, -0.44922143,  0.64834001,  2.50640862,
          2.00349347,  3.35573624,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.47135124,  2.10258532,  0.70212032,  2.56063126,  1.62466971,
          2.64026892,  0.21309489, -0.57752813,  2.21335957,  0.20453233,
          0.03106993,  3.01167822,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ]]),
 array([[-0.42125521,  0.54016939,  1.63016057,  2.01555253, -0.10961255,
         -0.42549555,  1.55793753, -0.0998756 ,  0.36417335,  3.37126414,
          1.62151191,  2.84084192,  0.10831384,  0.89293054, -0.08671363,
          0.49340353,  0.        ,  0.        ,  0.        ,  0.        ],
        [-0.37615411,  2.00581062,  2.30426605,  2.02205839,  0.65871664,
          1.34478836, -0.55379752, -1.42787727,  0.59732227,  0.84969282,
          0.54345723,  0.95849568, -0.17131602, -0.70425277, -0.5337757 ,
          1.78207229,  0.        ,  0.        ,  0.        ,  0.        ],
        [-0.13863276,  1.71490034,  2.02677925,  2.60608619,  0.26916522,
          0.35928298, -1.26521844, -0.59859219,  1.19162219,  1.64565259,
          1.16787165,  2.95245196,  0.48681084,  1.66621053,  0.918077  ,
         -1.10583747,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.87763797,  2.38740754,  2.9111822 ,  2.21184069,  0.78091173,
         -0.53270909,  0.40100338, -0.83375593,  0.9860009 ,  2.43898437,
         -0.64499989,  2.95092003, -1.52360727,  0.44640918,  0.78131922,
         -0.24401283,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.92615066,  3.45437746,  3.28808981,  2.87207404, -1.60027223,
         -1.14164941, -1.63807699,  0.33084805,  2.92963629,  3.51170824,
         -0.3286093 ,  2.19108385,  0.97812366, -1.82565766, -0.34034678,
         -2.0485913 ,  0.        ,  0.        ,  0.        ,  0.        ]]),
 array([[ 1.96438618e+00,  1.88104784e-01,  1.61114494e+00,
          6.99567690e-04,  2.55271963e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00],
        [ 2.41578815e+00, -5.70625661e-01,  2.15545894e+00,
         -1.80948908e+00,  1.62049331e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00],
        [ 1.97017040e+00, -1.62556528e+00,  2.49469152e+00,
          4.18785985e-02,  2.61875866e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00],
        [ 3.14277819e+00,  3.01098398e-02,  7.40376369e-01,
          1.76517344e+00,  2.68922918e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00],
        [ 2.06250296e+00,  4.67605528e-01,  1.55927230e+00,
          1.85788889e-01,  1.30359922e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00]]),
 array([[ 1.22152427,  3.74926839,  0.64415552,  2.35268329,  1.98754653,
          2.89384829,  0.44589817,  3.94228743,  2.72405657,  0.86222004,
          0.68681903,  3.89952458,  1.43454512,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [-0.02203262,  0.95065123,  0.71669023,  0.02919391,  2.30714524,
          1.91843002,  0.73611294,  1.20560482,  0.85206836, -0.74221506,
         -0.72886308,  2.39872927, -0.95841402,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.55775319,  0.33773314,  0.79932151,  1.94966883,  3.2113281 ,
          2.70768249, -0.69745554,  1.23208345,  1.66199957,  1.69894081,
          0.13124461,  1.93256147, -0.17787952,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.45089205,  2.62430534, -1.9517961 ,  2.24040577,  1.75642049,
          1.94962325,  0.26796497,  2.28418304,  1.44944487,  0.28723885,
         -0.81081633,  1.54840214,  0.82652939,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 1.27678173,  1.17204606, -0.24738322,  1.02761617,  1.81060444,
          2.37830861,  0.55260134,  2.50046334,  1.04652821,  0.03467176,
         -2.07336654,  1.2628897 ,  0.61604732,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ]]),
 array([[ 3.86138405,  2.35068317, -1.90187438,  0.600788  ,  0.18011722,
          1.3469559 , -0.54708828,  1.83798823, -0.01957845,  2.88713217,
          3.1724991 ,  2.90802072,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 1.26785642,  0.51076756,  0.32070756,  2.33758816,  2.08146669,
         -0.60796736,  0.93777509,  2.70474711,  0.44785738,  1.61720609,
          1.52890594,  3.03072971,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 3.30219394,  3.1515445 ,  1.16550716,  2.07489374,  0.66441859,
          0.97529244,  0.35176367,  1.22593639, -1.80698271,  1.19936482,
          3.34017172,  2.15960657,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 2.34839018,  2.24827352, -1.61070856,  2.81044265, -1.21423372,
          0.24633846, -0.82196609,  2.28616568,  0.033922  ,  2.7557593 ,
          1.16178372,  3.66959512,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 1.32913219,  1.63231852,  0.58642744,  1.55873546,  0.86354741,
          2.06654246, -0.44036504,  3.22723595,  1.33279468,  0.05975892,
          2.48518999,  3.44690602,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ]]),
 array([[ 0.61424344, -1.03068819, -1.47929328,  2.91514641,  2.06867196,
          1.90384921, -0.45835234,  1.22054782,  0.67931536,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 2.76480464,  1.12442631, -2.36004758,  2.91912726,  1.67891181,
          3.76873596, -0.93874096, -0.32397781, -0.55732374,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.39953353, -1.26828104,  0.44482517,  2.85604975,  3.08891062,
          2.60268725, -0.15785176,  1.58549879, -0.32948578,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 1.65156484, -1.56545168, -1.42771206,  2.74216475,  1.8758154 ,
          3.51169147,  0.18353058, -0.14704149,  0.00442783,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 1.27736372,  0.37407608, -1.25713475,  0.53171176,  1.53714914,
          0.21015523, -1.06850669, -0.09755327, -0.92373834,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ]]),
 array([[-1.39160433,  0.21014669, -0.89792475,  2.6702794 ,  1.54610601,
          0.84699037,  2.96726482,  1.84236946,  0.02211578,  0.32842575,
          1.02718924,  1.78447936, -1.20056829,  2.26699318, -0.23156537,
          2.50124959,  1.93372501,  0.10264369, -1.70813962,  0.        ],
        [ 0.38823591, -1.30348049, -0.31599117,  2.60044143,  2.32929389,
          1.40348483,  3.25758736,  1.92210728, -0.34150988, -1.22336921,
          2.3567069 ,  1.75456835,  0.28295694,  0.68114898, -0.457843  ,
          1.83372069,  2.10177851, -0.26664178, -0.26549595,  0.        ],
        [ 0.08540346,  0.71507504,  1.78164285,  3.04418137,  1.52975256,
          3.55159169,  3.21396003,  3.22720346,  0.68147142,  0.12466013,
         -0.4122895 ,  1.97986653,  1.51671949,  2.06096825, -0.6765908 ,
          2.00145086,  1.73723014,  0.50186043, -2.27525744,  0.        ],
        [ 0.00632717,  0.3050794 , -0.33167875,  1.48109172,  0.19653696,
          1.97504239,  2.51595821,  1.74499313, -1.65198805, -1.04424953,
         -0.23786945,  1.18639347, -0.03568057,  3.82541131,  2.84039446,
          2.88325909,  1.79827675, -0.80230291,  0.08165052,  0.        ],
        [ 0.89980086,  0.34690991, -0.60806566,  1.69472308,  1.38043417,
          0.97139487,  0.21977176,  1.01340944, -1.69946943, -0.01775586,
         -0.35851919,  1.81115864,  1.15105661,  1.21410373,  1.50667558,
          1.70155313,  3.1410754 , -0.54806167, -0.51879299,  0.        ]])]
Run Code Online (Sandbox Code Playgroud)

y_batches_train

[array([1., 2., 2., 1., 1., 2., 2., 1., 1., 2., 1., 1., 2., 2., 1.

Mac*_*ner 6

第一个问题:重塑后的 X 数据不是您所期望的。如果你看一下重塑后的第一个样本,它是:

array([[[-1.00612917,  1.47313952,  2.68021318,  1.54875809,
          0.98385996],
        [ 1.49465265,  0.60429106,  1.12396908, -0.24041602,
          1.77266187],
        [ 0.1961381 ,  1.28019637,  1.78803092,  2.05151245,
          0.93606708],
        [ 0.51554755,  0.        ,  0.        ,  0.        ,
          0.        ],
        [-0.97596563,  2.04536053,  0.88367922,  1.013342  ,
         -0.16605355],
        [ 3.02994344,  2.04080806, -0.25153046, -0.5964068 ,
          2.9607247 ],
        [-0.49722121,  0.02734492,  2.16949987,  2.77367066,
          0.15628842],
        [ 2.19823207,  0.        ,  0.        ,  0.        ,
          0.        ],
        [ 0.31546283,  3.27420503,  3.23550769, -0.63724013,
          0.89150128],
        [ 0.69774266,  2.76627308, -0.58408384, -0.45681779,
          1.98843041],
        [-0.31850477,  0.83729882,  0.45471165,  3.61974147,
         -1.45610756],
        [ 1.35217453,  0.        ,  0.        ,  0.        ,
          0.        ],
        [ 1.03329532,  1.97471646,  1.33949611,  1.22857243,
         -1.46890642],
        [ 1.74105506,  1.40969261,  0.52465603, -0.18895266,
          2.81025597],
        [ 2.64901037, -0.83415186,  0.76956826,  1.48730868,
         -0.16190164],
        [ 2.24389007,  0.        ,  0.        ,  0.        ,
          0.        ],
        [-1.0676654 ,  3.08429323,  1.7601179 ,  0.85448051,
          1.15537064],
        [ 2.82487842,  0.27891413,  0.57842569, -0.62392063,
          1.00343057],
        [ 1.15348843, -0.37650332,  3.37355345,  2.22285473,
          0.43444434],
        [ 0.15743873,  0.        ,  0.        ,  0.        ,
          0.        ]]])
Run Code Online (Sandbox Code Playgroud)

所以实际上没有时间步被屏蔽,因为Masking层只屏蔽所有特征都为0的时间步,所以上面的20个时间步没有被屏蔽,因为它们都不是完全0的。

对于 Masking 层,为了确保将 mask 成功传播到输出层,您可以执行以下操作:

for i, l in enumerate(model.layers):
    print(f'layer {i}: {l}')
    print(f'has input mask: {l.input_mask}')
    print(f'has output mask: {l.output_mask}')

layer 0: <tensorflow.python.keras.layers.core.Masking object at 0x6417b7f60>
has input mask: None
has output mask: Tensor("masking/Identity_1:0", shape=(None, 20), dtype=bool)
layer 1: <tensorflow.python.keras.layers.wrappers.Bidirectional object at 0x641e25cf8>
has input mask: Tensor("masking/Identity_1:0", shape=(None, 20), dtype=bool)
has output mask: Tensor("masking/Identity_1:0", shape=(None, 20), dtype=bool)
layer 2: <tensorflow.python.keras.layers.core.Dropout object at 0x641814128>
has input mask: Tensor("masking/Identity_1:0", shape=(None, 20), dtype=bool)
has output mask: Tensor("masking/Identity_1:0", shape=(None, 20), dtype=bool)
layer 3: <tensorflow.python.keras.layers.wrappers.TimeDistributed object at 0x6433b6ba8>
has input mask: Tensor("masking/Identity_1:0", shape=(None, 20), dtype=bool)
has output mask: Tensor("time_distributed/Reshape_3:0", shape=(None, 20), dtype=bool)
Run Code Online (Sandbox Code Playgroud)

所以你可以看到最后一层也有output_mask,这意味着掩码已成功传播。您似乎对Keras中Masking的工作原理有误解,它实际上会生成一个掩码,这是一个布尔数组,掩码的形状是(无,时间步长),因为在您的模型定义中,时间步长维度始终保持相同,因此掩码将传播到末尾而不发生任何变化。然后,当 Keras 计算损失时(当然还有计算梯度时),掩码值为 False 的时间步将被忽略。Masking 层不会改变输出值,当然你的模型仍然会预测类 0,它唯一做的就是生成一个布尔数组,指示应跳过哪个时间步并将其传递到末尾(如果所有层都接受面具)。

因此,您可以做的就是按如下所示更改模型定义的一行,并将 y_labels 移动 1,这意味着您当前的类:

0 -> 0(因为这些时间步的损失将被忽略,对模型的训练没有贡献,所以是0还是1并不重要)

1 -> 0

2 -> 1

# I would prefer softmax if doing classification
# here we only need to specify 2 classes
# and actually TimeDistributed can be thrown away (at least in recent Keras versions)
model.add(TimeDistributed(Dense(2, activation='softmax')))
Run Code Online (Sandbox Code Playgroud)

您还可以在这里查看我的答案/sf/answers/4151970371/,以了解如何使用/不使用掩蔽来计算损失。