小编Ale*_*kov的帖子

在gpu上预加载整个数据集以用于训练Keras模型

我有一个特定的情况,网络相对较小,对于收敛和泛化问题,我应该保持小批量(例如256),这导致每个时期处理数百个批次.

不幸的是,在这种情况下,批量加载和损失计算成为瓶颈(正如timeline工具告诉我的那样).

在tensorflow中,您可以编写类似这样的内容来加载gpu上的数据:

with tf.device('/gpu:0'):
    train_data = tf.constant(train_data_numpy)
Run Code Online (Sandbox Code Playgroud)

但如果我传递train_data给Keras Model.predictModel.fit函数,我会收到以下错误:

keras/engine/training.pyc in predict(self, x, batch_size, verbose)
   1515         f = self.predict_function
   1516         return self._predict_loop(f, ins,
-> 1517                                   batch_size=batch_size, verbose=verbose)
   1518 
   1519     def train_on_batch(self, x, y,

keras/engine/training.pyc in _predict_loop(self, f, ins, batch_size, verbose)
   1129         if verbose == 1:
   1130             progbar = Progbar(target=samples)
-> 1131         batches = _make_batches(samples, batch_size)
   1132         index_array = np.arange(samples)
   1133         for batch_index, (batch_start, batch_end) in enumerate(batches):

keras/engine/training.pyc in _make_batches(size, batch_size)
    368         A …
Run Code Online (Sandbox Code Playgroud)

python keras tensorflow

8
推荐指数
1
解决办法
1065
查看次数

标签 统计

keras ×1

python ×1

tensorflow ×1