Difference between StratifiedKFold and StratifiedShuffleSplit in sklearn

gab*_*how 40 python scikit-learn cross-validation

As the title says, I would like to know the difference between

StratifiedKFold with the parameter shuffle=True

StratifiedKFold(n_splits=10, shuffle=True, random_state=0)

and StratifiedShuffleSplit

StratifiedShuffleSplit(n_splits=10, test_size='default', train_size=None, random_state=0)

What is the advantage of using StratifiedShuffleSplit?

Ken*_*yme 54

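In KFold, the test sets never overlap, even with shuffling. With KFold and shuffle=True, the data is shuffled once at the start and then divided into the desired number of splits. The test data is always one of the splits, and the train data is the rest.

In ShuffleSplit, the data is shuffled every time and then split. This means the test sets may overlap between splits.

See this block for an example of the difference. Note the overlap of elements in the test sets for ShuffleSplit.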

splits = 5

tx = range(10)
ty = [0] * 5 + [1] * 5

from sklearn.model_selection import StratifiedShuffleSplit, StratifiedKFold

kfold = StratifiedKFold(n_splits=splits, shuffle=True, random_state=42)
shufflesplit = StratifiedShuffleSplit(n_splits=splits, random_state=42, test_size=2)

print("KFold")
for train_index, test_index in kfold.split(tx, ty):
    print("TRAIN:", train_index, "TEST:", test_index)

print("Shuffle Split")
for train_index, test_index in shufflesplit.split(tx, ty):
    print("TRAIN:", train_index, "TEST:", test_index)

Output:

KFold
TRAIN: [0 2 3 4 5 6 7 9] TEST: [1 8]
TRAIN: [0 1 2 3 5 7 8 9] TEST: [4 6]
TRAIN: [0 1 3 4 5 6 8 9] TEST: [2 7]
TRAIN: [1 2 3 4 6 7 8 9] TEST: [0 5]
TRAIN: [0 1 2 4 5 6 7 8] TEST: [3 9]
Shuffle Split
TRAIN: [8 4 1 0 6 5 7 2] TEST: [3 9]
TRAIN: [7 0 3 9 4 5 1 6] TEST: [8 2]
TRAIN: [1 2 5 6 4 8 9 0] TEST: [3 7]
TRAIN: [4 6 7 8 3 5 1 2] TEST: [9 0]
TRAIN: [7 2 6 5 4 3 0 9] TEST: [1 8]

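As for when to use them, I tend to use KFold for any cross-validation, and I use ShuffleSplit with a split of 2 for my train/test set splits. But I am sure there are other use cases for both.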
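One way to read the train/test use case above is a single stratified hold-out split. The following is a minimal sketch of that idea, assuming n_splits=1 and a 20% test fraction; the toy data and both parameter values are illustrative choices, not taken from the answer.

```python
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit

# Toy data (illustrative only): 20 samples, two balanced classes.
X = np.arange(20).reshape(-1, 1)
y = np.array([0] * 10 + [1] * 10)

# n_splits=1 yields exactly one train/test partition;
# test_size=0.2 holds out 20% of the samples.
splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=0)
train_idx, test_idx = next(splitter.split(X, y))

X_train, X_test = X[train_idx], X[test_idx]
y_train, y_test = y[train_idx], y[test_idx]

print("train size:", len(train_idx), "test size:", len(test_idx))
print("test class counts:", np.bincount(y_test))
```

Because the split is stratified, the held-out set keeps the 50/50 class balance of the full data.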


Cat*_*lts 34

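@Ken Syme already has a very good answer. I just want to add something.

  • StratifiedKFold is a variation of KFold. With shuffle=True, StratifiedKFold first shuffles your data, then splits it into n_splits parts, and that is all. It then uses each part in turn as a test set. Note that it shuffles the data only once, before splitting.

With shuffle=True, the data is shuffled according to your random_state; if random_state is left as None, NumPy's global random state (np.random) is used. With the default shuffle=False, no shuffling takes place. For example, take n_splits=4 and data with 3 classes (labels) for y (the dependent variable): the 4 test sets cover all the data without any overlap.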

[Image: the 4 StratifiedKFold test sets, covering all the data without overlap]
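The role of random_state described above can be checked directly. This is a small sketch on hypothetical toy data (12 samples, 3 classes, n_splits=4, seed 7 are all assumptions for illustration): two splitters built with the same seed should produce the same folds, and the folds together cover every sample once.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Toy data (illustrative only): 12 samples, 3 classes with 4 samples each.
X = np.arange(12).reshape(-1, 1)
y = np.array([0] * 4 + [1] * 4 + [2] * 4)

def folds(splitter):
    """Return the test indices of every fold as plain lists."""
    return [test.tolist() for _, test in splitter.split(X, y)]

# Same seed twice: the single up-front shuffle is reproducible.
a = folds(StratifiedKFold(n_splits=4, shuffle=True, random_state=7))
b = folds(StratifiedKFold(n_splits=4, shuffle=True, random_state=7))

# Default shuffle=False: no shuffling at all, so no seed is needed.
plain = folds(StratifiedKFold(n_splits=4))

print("same seed gives same folds:", a == b)
print("shuffled test folds:  ", a)
print("unshuffled test folds:", plain)
```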

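  • On the other hand, StratifiedShuffleSplit is a variation of ShuffleSplit. First, StratifiedShuffleSplit shuffles your data and then splits it into a train part and a test part (their sizes are set by test_size and train_size, not by n_splits). It is not done yet, however: it then repeats the same shuffle-and-split process n_splits - 1 more times to get n_splits - 1 other test sets. Look at the picture below, which uses the same data: this time the 4 test sets do not cover all the data, i.e. there are overlaps among the test sets.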

[Image: the 4 StratifiedShuffleSplit test sets, which overlap and do not cover all the data]

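So the difference is that StratifiedKFold shuffles and splits only once, so its test sets do not overlap, whereas StratifiedShuffleSplit shuffles each time before splitting and does so n_splits times, so its test sets can overlap.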

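  • Note: both methods use "stratified folds" (which is why "stratified" appears in both names). This means each part preserves the same percentage of samples of each class (label) as the original data. You can read more in the cross_validation documentation.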

  • Perfect explanation!! (4 upvotes)

Bla*_*ven 11

[Image: example output of KFold, StratifiedKFold, and StratifiedShuffleSplit]

The graphical output above is an extension of @Ken Syme's code:

from sklearn.model_selection import KFold, StratifiedKFold, StratifiedShuffleSplit
SEED = 43
SPLIT = 3

X_train = [0,1,2,3,4,5,6,7,8]
y_train = [0,0,0,0,0,0,1,1,1]   # note 6,7,8 are labelled class '1'

print("KFold, shuffle=False (default)")
kf = KFold(n_splits=SPLIT)  # random_state is only valid together with shuffle=True
for train_index, test_index in kf.split(X_train, y_train):
    print("TRAIN:", train_index, "TEST:", test_index)

print("KFold, shuffle=True")
kf = KFold(n_splits=SPLIT, shuffle=True, random_state=SEED)
for train_index, test_index in kf.split(X_train, y_train):
    print("TRAIN:", train_index, "TEST:", test_index)

print("\nStratifiedKFold, shuffle=False (default)")
skf = StratifiedKFold(n_splits=SPLIT)  # random_state is only valid together with shuffle=True
for train_index, test_index in skf.split(X_train, y_train):
    print("TRAIN:", train_index, "TEST:", test_index)
    
print("StratifiedKFold, shuffle=True")
skf = StratifiedKFold(n_splits=SPLIT, shuffle=True, random_state=SEED)
for train_index, test_index in skf.split(X_train, y_train):
    print("TRAIN:", train_index, "TEST:", test_index)
    
print("\nStratifiedShuffleSplit")
sss = StratifiedShuffleSplit(n_splits=SPLIT, random_state=SEED, test_size=3)
for train_index, test_index in sss.split(X_train, y_train):
    print("TRAIN:", train_index, "TEST:", test_index)

print("\nStratifiedShuffleSplit (can customise test_size)")
sss = StratifiedShuffleSplit(n_splits=SPLIT, random_state=SEED, test_size=2)
for train_index, test_index in sss.split(X_train, y_train):
    print("TRAIN:", train_index, "TEST:", test_index)