自定义回调中的访问丢失和模型

Ell*_*lla 4 python callback keras tensorflow

我读了这个https://www.tensorflow.org/guide/keras/custom_callback,但我不知道如何获取所有其他参数。

这是我的代码

 (hits, ndcgs) = evaluate_model(model, testRatings, testNegatives, topK, evaluation_threads)
  hr, ndcg, loss = np.array(hits).mean(), np.array(ndcgs).mean(), hist.history['loss'][0]
  print('Iteration %d [%.1f s]: HR = %.4f, NDCG = %.4f, loss = %.4f [%.1f s]' 
                  % (epoch,  t2-t1, hr, ndcg, loss, time()-t2))
 if hr > best_hr:
     best_hr, best_ndcg, best_iter = hr, ndcg, epoch
 if args.out > 0:
     model.save(model_out_file, overwrite=True)
Run Code Online (Sandbox Code Playgroud)

正如你所看到的model,我需要histmodel.save。有没有办法在自定义回调中使用这三个参数?这样我就可以将所有这些写入自定义回调中。

class CustomCallback(keras.callbacks.Callback):

   def on_epoch_end(self, logs=None):
       keys = list(logs.keys())
       print("Stop training; got log keys: {}".format(keys))
Run Code Online (Sandbox Code Playgroud)

Les*_*rel 6

该模型是的属性tf.keras.callbacks.Callback,因此您可以直接使用 访问它self.model。要访问损失的值,您可以使用传递给 的方法的tf.keras.callbacks.Callback“logs”对象,该对象将包含名为“loss”的键。

如果您需要访问其他变量(在训练期间不会改变),那么您可以将它们设置为回调的实例变量,并通过定义函数在构造回调期间添加它们__init__

class CustomCallback(keras.callbacks.Callback):
   def __init__(self, testRatings, testNegatives, topK, evaluation_threads):
       super().__init__()
       self.testRatings = testRatings
       self.testNegatives = testNegatives
       self.topK = topK
       self.evaluation_threads = evaluation_threads

   def on_epoch_end(self, epoch, logs=None):
       logs = logs or {}
       current_loss = logs.get("loss")
       if current_loss:
           print("my_loss: ", current_loss)
       print("my_model", self.model)
       # the attributes are accessble with self
       print("my topK atributes", self.topK)

# you can then create the callback by passing the correct attributes
my_callback = CustomCallback(testRatings, testNegatives, topK, evaluation_threads)
Run Code Online (Sandbox Code Playgroud)

注意:如果您想要做的是评估每个时期之间的模型,并在模型获得最佳指标时保存模型,我建议您查看:

  • fit功能,您实际上可以在其中提供测试集
  • 指标模块,提供将在训练集和测试集上计算的指标
  • 回调,将在每个时期保存模型,如果提供选项,则保持最佳权ModelCheckpointsave_best_only