将 Keras 模型应用于符号张量导致 TF2.0 内存泄漏

wei*_*ler 4 python memory-leaks out-of-memory keras tensorflow

tldr:我的实现的内存使用量显然随着通过它的样本数量的增加而增加,但网络/样本馈送中不应该有任何内容关心到目前为止传递了多少样本。


当通过功能 API 创建的自定义 Keras 模型传递大量高维数据时,我观察到GPU 内存使用量随着观察到的实例数量的不断增加而不断增长。以下是通过网络传递样本过程的最小示例:

sequence_length = 100
batch_size = 128

env = gym.make("ShadowHand-v1")
_, _, joint = build_shadow_brain(env, bs=batch_size)
optimizer: tf.keras.optimizers.Optimizer = tf.keras.optimizers.SGD()

start_time = time.time()
for t in tqdm(range(sequence_length), disable=False):
    sample_batch = (
        tf.random.normal([batch_size, 1, 200, 200, 3]),
        tf.random.normal([batch_size, 1, 48]),
        tf.random.normal([batch_size, 1, 92]),
        tf.random.normal([batch_size, 1, 7])
    )

    with tf.GradientTape() as tape:
        out, v = joint(sample_batch)
        loss = tf.reduce_mean(out - v)

    grads = tape.gradient(loss, joint.trainable_variables)
    optimizer.apply_gradients(zip(grads, joint.trainable_variables))
    joint.reset_states()

print(f"Execution Time: {time.time() - start_time}")
Run Code Online (Sandbox Code Playgroud)

我知道,考虑到批量大小,这是一个很大的样本,但是如果它对于我的 GPU 来说太大,我预计会立即出现 OOM 错误,并且我还假设 6GB 的 VRAM 实际上就足够了。这是因为只有在 33 个实例之后才会出现 OOM 错误,这让我怀疑内存使用量不断增加。

请参阅以下我的模型的Keras 摘要:

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
visual_input (InputLayer)       [(32, None, 200, 200 0                                            
__________________________________________________________________________________________________
proprioceptive_input (InputLaye [(32, None, 48)]     0                                            
__________________________________________________________________________________________________
somatosensory_input (InputLayer [(32, None, 92)]     0                                            
__________________________________________________________________________________________________
time_distributed (TimeDistribut (None, None, 64)     272032      visual_input[0][0]               
__________________________________________________________________________________________________
time_distributed_1 (TimeDistrib (None, None, 8)      848         proprioceptive_input[0][0]       
__________________________________________________________________________________________________
time_distributed_2 (TimeDistrib (None, None, 8)      3032        somatosensory_input[0][0]        
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, None, 80)     0           time_distributed[0][0]           
                                                                 time_distributed_1[0][0]         
                                                                 time_distributed_2[0][0]         
__________________________________________________________________________________________________
time_distributed_3 (TimeDistrib (None, None, 48)     3888        concatenate[0][0]                
__________________________________________________________________________________________________
time_distributed_4 (TimeDistrib (None, None, 48)     0           time_distributed_3[0][0]         
__________________________________________________________________________________________________
time_distributed_5 (TimeDistrib (None, None, 32)     1568        time_distributed_4[0][0]         
__________________________________________________________________________________________________
time_distributed_6 (TimeDistrib (None, None, 32)     0           time_distributed_5[0][0]         
__________________________________________________________________________________________________
goal_input (InputLayer)         [(32, None, 7)]      0                                            
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (32, None, 39)       0           time_distributed_6[0][0]         
                                                                 goal_input[0][0]                 
__________________________________________________________________________________________________
lstm (LSTM)                     (32, 32)             9216        concatenate_1[0][0]              
__________________________________________________________________________________________________
dense_10 (Dense)                (32, 20)             660         lstm[0][0]                       
__________________________________________________________________________________________________
dense_11 (Dense)                (32, 20)             660         lstm[0][0]                       
__________________________________________________________________________________________________
activation (Activation)         (32, 20)             0           dense_10[0][0]                   
__________________________________________________________________________________________________
activation_1 (Activation)       (32, 20)             0           dense_11[0][0]                   
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (32, 40)             0           activation[0][0]                 
                                                                 activation_1[0][0]               
__________________________________________________________________________________________________
dense_12 (Dense)                (32, 1)              33          lstm[0][0]                       
==================================================================================================
Total params: 291,937
Trainable params: 291,937
Non-trainable params: 0
__________________________________________________________________________________________________
Run Code Online (Sandbox Code Playgroud)

正如您所看到的,该网络中有一个 LSTM 层。它通常应该是有状态的,但是我已经将其关闭,因为我认为问题就在那里。事实上我已经尝试了以下方法,但没有消除问题

  • 状态转变
  • 完全删除 LSTM
  • 不计算任何梯度
  • 在每个实例之后重建模型

现在我对这个问题的潜在原因的想法已经结束了。

我还强制将进程转移到 CPU 上并检查标准内存(此处不会发生 OOM,因为我的 RAM 比 VRAM 大得多)。有趣的是,内存使用量上下跳跃,但呈上升趋势。对于每个实例,都会占用大约 2GB 内存,但是在获取下一个样本之前释放内存时,只会释放比所占用内存少大约 200MB 的内存。

编辑 1:正如评论中提到的,问题可能是在输入上调用模型会添加到计算图中。但是我不能使用,joint.predict()因为我需要计算梯度。

编辑 2:我更仔细地监控了内存的增长,实际上每次迭代都会保留一些内存,正如您在前 9 个步骤中看到的那样:

0: 8744054784
1: 8885506048
2: 9015111680
3: 9143611392
4: 9272619008
5: 9405591552
6: 9516531712
7: 9647988736
8: 9785032704
Run Code Online (Sandbox Code Playgroud)

这是在批量大小为 32 的情况下完成的。 1 的大小sample_batch256 * (200 * 200 * 3 + 48 + 92 + 7) * 32 = 984244224位(精度是float32),这或多或少表明问题确实是当样本通过网络传递时,样本被添加到图中,因为它是正如@MatiasValdenegro 所建议的,具有象征意义。所以我想现在的问题可以归结为“如何使张量非符号化”(如果这确实是一件事的话)。

免责声明:我知道您无法使用给定的代码重现该问题,因为缺少组件,但我无法在此处提供完整的项目代码。

wei*_*ler 6

我花了一段时间,但现在我已经解决了这个问题。正如我之前编辑过的问题:问题是 Keras 的功能 API 似乎将每个样本添加到计算图中,而没有删除迭代后我们不再需要的输入。似乎没有简单的方法可以显式删除它,但是装饰tf.function器可以解决这个问题

以上面的代码为例,它可以按如下方式应用:

sequence_length = 100
batch_size = 256

env = gym.make("ShadowHand-v1")
_, _, joint = build_shadow_brain(env, bs=batch_size)
plot_model(joint, to_file="model.png")
optimizer: tf.keras.optimizers.Optimizer = tf.keras.optimizers.SGD()

@tf.function
def _train():
    start_time = time.time()

    for _ in tqdm(range(sequence_length), disable=False):
        sample_batch = (tf.convert_to_tensor(tf.random.normal([batch_size, 4, 224, 224, 3])),
                        tf.convert_to_tensor(tf.random.normal([batch_size, 4, 48])),
                        tf.convert_to_tensor(tf.random.normal([batch_size, 4, 92])),
                        tf.convert_to_tensor(tf.random.normal([batch_size, 4, 7])))

        with tf.GradientTape() as tape:
            out, v = joint(sample_batch, training=True)
            loss = tf.reduce_mean(out - v)

        grads = tape.gradient(loss, joint.trainable_variables)
        optimizer.apply_gradients(zip(grads, joint.trainable_variables))

    print(f"Execution Time: {time.time() - start_time}")

_train()
Run Code Online (Sandbox Code Playgroud)

也就是说,训练循环可以在带有tf.function装饰器的函数中传送。这意味着训练将以图形模式执行,并且出于某种原因,这消除了问题,很可能是因为图形将在函数结束后被转储。有关更多信息,tf.function请参阅有关该主题的TF2.0 指南。