小编Lad*_*lao的帖子

Keras:网络无法使用fit_generator()进行训练

我在大型数据集上使用Keras(使用MagnaTagATune数据集进行音乐自动标记)。因此,我尝试将fit_generator()功能与自定义数据生成器一起使用。但是损失函数和指标的价值在培训过程中不会改变。看来我的网络根本没有训练。

当我使用fit()函数而不是fit_generator()时,一切都很好,但是我无法将整个数据集保留在内存中。

我已经尝试了Theano和TensorFlow后端

主要代码:

if __name__ == '__main__':
    model = models.FCN4()
    model.compile(optimizer='adam',
                  loss='binary_crossentropy',
                  metrics=['accuracy', 'categorical_accuracy', 'precision', 'recall'])
    gen = mttutils.generator_v2(csv_path, melgrams_dir)
    history = model.fit_generator(gen.generate(0,750),
                                  samples_per_epoch=750,
                                  nb_epoch=80,
                                  validation_data=gen.generate(750,1000,False),
                                  nb_val_samples=250)
    # RESULTS SAVING
    np.save(output_history, history.history)
    model.save(output_model)
Run Code Online (Sandbox Code Playgroud)

generator_v2类:

genres = ['guitar', 'classical', 'slow', 'techno', 'strings', 'drums', 'electronic', 'rock', 'fast',
        'piano', 'ambient', 'beat', 'violin', 'vocal', 'synth', 'female', 'indian', 'opera', 'male', 'singing',
        'vocals', 'no vocals', 'harpsichord', 'loud', 'quiet', 'flute', 'woman', 'male vocal', 'no vocal',
        'pop', 'soft', 'sitar', 'solo', 'man', 'classic', 'choir', 'voice', 'new age', 'dance', …
Run Code Online (Sandbox Code Playgroud)

python deep-learning keras

5
推荐指数
1
解决办法
4201
查看次数

标签 统计

deep-learning ×1

keras ×1

python ×1