scikit learn:train_test_split,我可以确保在不同的数据集上进行相同的拆分

Ziq*_*iqi 8 scikit-learn

据我所知,train_test_split方法将数据集拆分为随机序列和测试子集.并且使用random_state = int可以确保每次调用方法时我们在此数据集上都有相同的拆分.

我的问题略有不同.

我有两个数据集,A和B,它们包含相同的示例集,每个数据集中出现的这些示例的顺序也相同.但它们的关键区别在于每个数据集中的exmaples使用不同的功能集.

我想测试一下A中使用的功能是否比B中使用的功能更好.所以我想确保当我在A和B上调用train_test_split时,我可以在两个数据集上获得相同的分割,以便比较是有意义的.

这可能吗?我是否只需要确保两个数据集的两个方法调用中的random_state都相同?

谢谢

eqz*_*qzx 11

是的,随机状态就足够了.

>>> X, y = np.arange(10).reshape((5, 2)), range(5)
>>> X2 = np.hstack((X,X))
>>> X_train, X_test, _, _ = train_test_split(X,y, test_size=0.33, random_state=42)
>>> X_train2, X_test2, _, _ = train_test_split(X2,y, test_size=0.33, random_state=42)
>>> X_train
array([[4, 5],
       [0, 1],
       [6, 7]])
>>> X_train2
array([[4, 5, 4, 5],
       [0, 1, 0, 1],
       [6, 7, 6, 7]])
>>> X_test
array([[2, 3],
       [8, 9]])
>>> X_test2
array([[2, 3, 2, 3],
       [8, 9, 8, 9]])
Run Code Online (Sandbox Code Playgroud)


pim*_*314 6

查看train_test_split函数的代码,它在每次调用时在函数内设置随机种子。所以每次都会导致相同的分裂。我们可以检查这是否非常简单

X1 = np.random.random((200, 5))
X2 = np.random.random((200, 5))
y = np.arange(200)

X1_train, X1_test, y1_train, y1_test = model_selection.train_test_split(X1, y,
                                                                        test_size=0.1,
                                                                        random_state=42)
X2_train, X2_test, y2_train, y2_test = model_selection.train_test_split(X1, y,
                                                                        test_size=0.1,
                                                                        random_state=42)

print np.all(y1_train == y2_train)
print np.all(y1_test == y2_test)
Run Code Online (Sandbox Code Playgroud)

哪些输出:

True
True
Run Code Online (Sandbox Code Playgroud)

哪个好!解决此问题的另一种方法是在所有特征上创建一个训练和测试拆分,然后在训练之前拆分您的特征。但是,如果您处于需要同时执行这两项操作的奇怪情况(有时使用相似矩阵,您不希望在训练集中测试特征),那么您可以使用该StratifiedShuffleSplit函数返回属于数据的索引到每一组。例如:

n_splits = 1 
sss = model_selection.StratifiedShuffleSplit(n_splits=n_splits, 
                                             test_size=0.1,
                                             random_state=42)
train_idx, test_idx = list(sss.split(X, y))[0]
Run Code Online (Sandbox Code Playgroud)