Tensorflow:预测 4 个标签中的 1 个以进行文本分类

hel*_*len 2 python tensorflow

我目前正在研究一个文本分类问题,需要我们将文本分类为四个标签之一。编码后 y 值应该是[0,1,2,3]预测标签之一。

然而,这个模型做出的预测似乎在 (0,1) 范围内,我有点困惑?此外,谁能澄清这是 ANN 还是 RNN?TensorFlow 零经验,仍在苦苦挣扎......

model = Sequential()
model.add(Dense(16, activation='relu'))
model.add(Dense(4, activation='softmax'))
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
              
from sklearn.preprocessing import LabelEncoder
#encode the label
label_encoder = LabelEncoder()
y_train=np.array(label_encoder.fit_transform(train_labels))
x_train=np.array(train_features)
y_true=np.array(label_encoder.fit_transform(dev_label))
#fit the model
model.fit(x_train,y_train,epochs=1)
y_pred=model.predict(dev_features)
Run Code Online (Sandbox Code Playgroud)

和错误消息:Classification metrics can't handle a mix of multiclass and continuous-multioutput targets

在此输入图像描述

iam*_*sha 6

假设目标列有 4 个唯一值:red, blue, green, yellow并且语料库被转换为 TF-IDF 值。前 3 行如下所示:

字_1 字_2 目标
0.567 0.897 红色的
0.098 0.238 蓝色的
0.66 0.786 绿色的

一次性编码

对目标进行one-hot 编码后,您的目标看起来像以下形式的数组:

array[[1. 0. 0. 0.], <- category 'red'
[0. 1. 0. 0.], <- category 'blue'
[0. 0. 1. 0.]...] <- category 'green'
Run Code Online (Sandbox Code Playgroud)

这里,目标列的大小为 (n_samples, n_targets),即 (n,4)。在这种情况下,最终激活必须是sigmoidor softmax,并且您将以损失来训练您的模型categorical_crossentropy。这里回答你的问题的代码是:

model.add(Dense(4, activation='sigmoid'))
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])
Run Code Online (Sandbox Code Playgroud)

标签编码

对目标进行标签编码后,您的目标看起来像以下形式的数组:

array([1, 2, 3 ...])
Run Code Online (Sandbox Code Playgroud)

具有大小为 (n_targets) 的一维数组。这里的代码是:

model.add(Dense(4, activation='softmax'))
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
Run Code Online (Sandbox Code Playgroud)

预言

您看到的这些数字是给定输入样本的每个类别的概率。例如,[[0.4846592 0.5153408]] 表示给定样本属于 0 类的概率约为 0.48,属于 1 类的概率约为 0.51。因此,您想选择概率最高的类别,因此您可以使用 np.argmax 来查找哪个索引(即 0 或 1)是最大的索引:

import numpy as np

pred_class = np.argmax(y_pred, axis=-1) 
Run Code Online (Sandbox Code Playgroud)

此外,这与模型的损失函数无关。这些概率由模型中的最后一层给出,它很可能使用 softmax 作为激活函数,将输出标准化为概率分布。 来源

结论

  • 您得到的错误是因为损失函数使用不正确。
  • 如果您有1D 整数编码或 LabelEncoded 目标,则应使用sparse_categorical_crossentropy作为损失函数。
  • 如果您对目标进行了 one-hot 编码以获得 2D 形状(n_samples,n_class),则应该使用categorical_crossentropy