我是一名新手程序员,正在尝试遵循本指南。但是,我遇到了一个问题。该指南说将损失函数定义为:
def loss(labels, logits):
return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)
Run Code Online (Sandbox Code Playgroud)
这给了我以下错误:
sparse_categorical_crossentropy()获得了意外的关键字参数'from_logits'
我的意思是这from_logits是函数中未指定的参数,该参数受文档支持,该文档tf.keras.losses.sparse_categorical_crossentropy()只有两个可能的输入。
有没有一种方法可以指定正在使用的日志,或者甚至是必要的?
在学习本教程时,我遇到了同样的问题。我更改了代码
def loss(labels, logits):
return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)
Run Code Online (Sandbox Code Playgroud)
至
def loss(labels, logits):
return tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)
Run Code Online (Sandbox Code Playgroud)
这样就解决了问题,而不必每晚安装tf。
该from_logits参数是在Tensorflow 1.13中引入的。
您可以将 1.12 和 1.13 与以下网址进行比较:
https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/python/keras/losses.py
https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/keras/losses.py
Run Code Online (Sandbox Code Playgroud)
在撰写本文时 1.13 尚未发布。这就是为什么本教程以该行开头
!pip install -q tf-nightly
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
2111 次 |
| 最近记录: |