我在大型数据集上使用Keras(使用MagnaTagATune数据集进行音乐自动标记)。因此,我尝试将fit_generator()功能与自定义数据生成器一起使用。但是损失函数和指标的价值在培训过程中不会改变。看来我的网络根本没有训练。
当我使用fit()函数而不是fit_generator()时,一切都很好,但是我无法将整个数据集保留在内存中。
我已经尝试了Theano和TensorFlow后端
主要代码:
if __name__ == '__main__':
model = models.FCN4()
model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy', 'categorical_accuracy', 'precision', 'recall'])
gen = mttutils.generator_v2(csv_path, melgrams_dir)
history = model.fit_generator(gen.generate(0,750),
samples_per_epoch=750,
nb_epoch=80,
validation_data=gen.generate(750,1000,False),
nb_val_samples=250)
# RESULTS SAVING
np.save(output_history, history.history)
model.save(output_model)
Run Code Online (Sandbox Code Playgroud)
generator_v2类:
genres = ['guitar', 'classical', 'slow', 'techno', 'strings', 'drums', 'electronic', 'rock', 'fast',
'piano', 'ambient', 'beat', 'violin', 'vocal', 'synth', 'female', 'indian', 'opera', 'male', 'singing',
'vocals', 'no vocals', 'harpsichord', 'loud', 'quiet', 'flute', 'woman', 'male vocal', 'no vocal',
'pop', 'soft', 'sitar', 'solo', 'man', 'classic', 'choir', 'voice', 'new age', 'dance', …Run Code Online (Sandbox Code Playgroud)