我目前正在研究一个文本分类问题,需要我们将文本分类为四个标签之一。编码后 y 值应该是[0,1,2,3]
预测标签之一。
然而,这个模型做出的预测似乎在 (0,1) 范围内,我有点困惑?此外,谁能澄清这是 ANN 还是 RNN?TensorFlow 零经验,仍在苦苦挣扎......
model = Sequential()
model.add(Dense(16, activation='relu'))
model.add(Dense(4, activation='softmax'))
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
from sklearn.preprocessing import LabelEncoder
#encode the label
label_encoder = LabelEncoder()
y_train=np.array(label_encoder.fit_transform(train_labels))
x_train=np.array(train_features)
y_true=np.array(label_encoder.fit_transform(dev_label))
#fit the model
model.fit(x_train,y_train,epochs=1)
y_pred=model.predict(dev_features)
Run Code Online (Sandbox Code Playgroud)
和错误消息:Classification metrics can't handle a mix of multiclass and continuous-multioutput targets
假设目标列有 4 个唯一值:red, blue, green, yellow
并且语料库被转换为 TF-IDF 值。前 3 行如下所示:
字_1 | 字_2 | 目标 |
---|---|---|
0.567 | 0.897 | 红色的 |
0.098 | 0.238 | 蓝色的 |
0.66 | 0.786 | 绿色的 |
对目标进行one-hot 编码后,您的目标看起来像以下形式的数组:
array[[1. 0. 0. 0.], <- category 'red'
[0. 1. 0. 0.], <- category 'blue'
[0. 0. 1. 0.]...] <- category 'green'
Run Code Online (Sandbox Code Playgroud)
这里,目标列的大小为 (n_samples, n_targets),即 (n,4)。在这种情况下,最终激活必须是sigmoid
or softmax
,并且您将以损失来训练您的模型categorical_crossentropy
。这里回答你的问题的代码是:
model.add(Dense(4, activation='sigmoid'))
model.compile(optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy'])
Run Code Online (Sandbox Code Playgroud)
对目标进行标签编码后,您的目标看起来像以下形式的数组:
array([1, 2, 3 ...])
Run Code Online (Sandbox Code Playgroud)
具有大小为 (n_targets) 的一维数组。这里的代码是:
model.add(Dense(4, activation='softmax'))
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
Run Code Online (Sandbox Code Playgroud)
您看到的这些数字是给定输入样本的每个类别的概率。例如,[[0.4846592 0.5153408]] 表示给定样本属于 0 类的概率约为 0.48,属于 1 类的概率约为 0.51。因此,您想选择概率最高的类别,因此您可以使用 np.argmax 来查找哪个索引(即 0 或 1)是最大的索引:
import numpy as np
pred_class = np.argmax(y_pred, axis=-1)
Run Code Online (Sandbox Code Playgroud)
此外,这与模型的损失函数无关。这些概率由模型中的最后一层给出,它很可能使用 softmax 作为激活函数,将输出标准化为概率分布。 来源
sparse_categorical_crossentropy
作为损失函数。categorical_crossentropy
归档时间: |
|
查看次数: |
838 次 |
最近记录: |