Tensorflow:tf.argmax()作为预测还是最大值?

bks*_*shi 1 tensorflow

我正在学习tensorflow,在各种示例中,我已经看到了从我们使用的logit获得预测tf.argmax(logits, 1)。根据我所理解的logits是概率值,tf.argmax()并将给出指定轴上最大值的索引。但是,如何使用索引代替概率值。我们不应该将最大值用作预测吗?

但是我已经看到上面的代码可以正常工作。我确定我在这里缺少一些基础知识。有人可以举例说明吗?

nes*_*uno 6

通常logits是分类网络的输出张量,其内容为未归一化(不在0和1之间缩放)的概率。

tf.argmax 给您沿指定轴的最大值索引。

您可以将其转换logits为伪概率(这只是张量,其值的总和为1),并将其作为输入送入argmax:

top = tf.argmax(tf.nn.softmax(logits), 1)
Run Code Online (Sandbox Code Playgroud)

但是最后,结果与直接馈入非标准化概率相同:

top = tf.argmax(logits, 1)
Run Code Online (Sandbox Code Playgroud)

但是,您必须使用argmax才能了解网络为该输入预测的类,这是唯一的方法,您不能仅使用概率(归一化或非归一化)。

只需考虑如下logits张量:

logits = [ [ 10, 500, -1, 0.5, 12 ] ]
Run Code Online (Sandbox Code Playgroud)

张量形状为[1,5]。仅查看张量值,就可以轻松地了解到,具有最高置信度的类是与位置1相关联的类,其值为500。

如何提取最高价值的头寸?您必须使用argmax:

top = tf.argmax(logits, 1)
Run Code Online (Sandbox Code Playgroud)

执行后将返回值1

摘要:logits的值是Scores,索引是Classes。通过使用argmax(),您可以获得predicted class