Sklearn StratifiedKFold:ValueError:支持的目标类型是:('binary','multiclass').得到'多标签指标'

jKr*_*aut 18 python machine-learning scikit-learn cross-validation keras

使用Sklearn分层kfold分割,当我尝试使用多类分割时,我收到错误(见下文).当我尝试使用二进制分割时,它没有问题.

num_classes = len(np.unique(y_train))
y_train_categorical = keras.utils.to_categorical(y_train, num_classes)
kf=StratifiedKFold(n_splits=5, shuffle=True, random_state=999)

# splitting data into different folds
for i, (train_index, val_index) in enumerate(kf.split(x_train, y_train_categorical)):
    x_train_kf, x_val_kf = x_train[train_index], x_train[val_index]
    y_train_kf, y_val_kf = y_train[train_index], y_train[val_index]

ValueError: Supported target types are: ('binary', 'multiclass'). Got 'multilabel-indicator' instead.
Run Code Online (Sandbox Code Playgroud)

des*_*aut 15

keras.utils.to_categorical产生一个热门编码的类向量,即multilabel-indicator错误消息中提到的.StratifiedKFold不适合用于此类输入; 从split方法文档:

split(X,y,groups = None)

[...]

y:类似数组,形状(n_samples,)

监督学习问题的目标变量.分层基于y标签完成.

即你y必须是你的班级标签的一维数组.

基本上,你要做的只是颠倒操作的顺序:首先拆分(使用你的初始y_train),然后转换to_categorical.

  • @Minion这是不正确的;StratifiedKFold`会注意“ *通过保留每个类别的样本百分比来进行折叠*”([docs](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.StratifiedKFold。 html))。在非常特殊的情况下,其中某些类别的代表性不足,显然建议格外小心(和手动检查),但此处的答案仅针对一般情况,而不针对其他假设情况... (2认同)

小智 11

如果您的目标变量是连续的,则使用简单的 KFold 交叉验证而不是 StratifiedKFold。

from sklearn.model_selection import KFold
kfold = KFold(n_splits=5, shuffle=True, random_state=42)
Run Code Online (Sandbox Code Playgroud)


小智 10

打电话给split()这样的:

for i, (train_index, val_index) in enumerate(kf.split(x_train, y_train_categorical.argmax(1))):
    x_train_kf, x_val_kf = x_train[train_index], x_train[val_index]
    y_train_kf, y_val_kf = y_train[train_index], y_train[val_index]
Run Code Online (Sandbox Code Playgroud)


noc*_*mbi 5

我遇到了同样的问题,发现可以使用此util功能检查目标的类型:

from sklearn.utils.multiclass import type_of_target
type_of_target(y)

'multilabel-indicator'
Run Code Online (Sandbox Code Playgroud)

从其文档字符串:

  • 'binary':y包含<= 2个离散值,为1d或列向量。
  • 'multiclass':y包含两个以上的离散值,不是序列序列,是1d或列向量。
  • 'multiclass-multioutput':y是一个二维数组,包含两个以上的离散值,不是序列序列,且两个维度的大小均大于1。
  • “ multilabel-indicator”:y是一个标签指示符矩阵,是一个二维数组,至少两列,最多2个唯一值。

随着LabelEncoder您可以将您的类成数字的一维数组(给你的目标标签是在categoricals的一维数组/对象):

from sklearn.preprocessing import LabelEncoder

label_encoder = LabelEncoder()
y = label_encoder.fit_transform(target_labels)
Run Code Online (Sandbox Code Playgroud)