gcs*_*osh 12 debugging keras tensorflow
我正在编写一个自定义目标来训练Keras(带有tensorflow后端)模型,但我需要调试一些中间计算.为简单起见,假设我有:
def custom_loss(y_pred, y_true):
diff = y_pred - y_true
return K.square(diff)
Run Code Online (Sandbox Code Playgroud)
我无法找到一种简单的方法来访问,例如,在训练期间中间变量diff或其形状.在这个简单的例子中,我知道我可以返回diff来打印它的值,但是我的实际损失更复杂,我不能在没有编译错误的情况下返回中间值.
有没有一种简单的方法来调试Keras中的中间变量?
据我所知,这不是Keras解决的问题,因此您必须采用特定于后端的功能.既Theano和TensorFlow有Print
那些身份的节点(即,它们将返回输入节点),并具有打印输入(或输入的一些张量)的副作用的节点.
Theano的例子:
diff = y_pred - y_true
diff = theano.printing.Print('shape of diff', attrs=['shape'])(diff)
return K.square(diff)
Run Code Online (Sandbox Code Playgroud)
TensorFlow示例:
diff = y_pred - y_true
diff = tf.Print(diff, [tf.shape(diff)])
return K.square(diff)
Run Code Online (Sandbox Code Playgroud)
请注意,这仅适用于中间值.Keras希望传递给其他层的张量具有特定属性,例如_keras_shape
.后端处理的值,即通过Print
,通常没有该属性.要解决此问题,您可以将调试语句包装在一个Lambda
层中.
归档时间: |
|
查看次数: |
6021 次 |
最近记录: |