scikit-learn StratifiedShuffleSplit KeyError with index

zin*_*rim 5 python-3.x pandas scikit-learn

This is my pandas DataFrame lots_not_preprocessed_usd:

<class 'pandas.core.frame.DataFrame'>
Index: 78718 entries, 2017-09-12T18-38-38-076065 to 2017-10-02T07-29-40-245031
Data columns (total 20 columns):
created_year              78718 non-null float64
price                     78718 non-null float64
........
decade                    78718 non-null int64
dtypes: float64(8), int64(1), object(11)
memory usage: 12.6+ MB

head(1):

artist_name_normalized  house   created_year    description exhibited_in    exhibited_in_museums    height  images  max_estimated_price min_estimated_price price   provenance  provenance_estate_of    sale_date   sale_id sale_title  style   title   width   decade
    key                                                                             
    2017-09-12T18-38-38-076065  NaN c11 1862.0  An Album and a small Quantity of unframed Draw...   NaN NaN NaN NaN 535.031166  267.515583  845.349242  NaN NaN 1998-06-21  8033    OILS, WATERCOLOURS & DRAWINGS FROM 18TH - 20TH...   watercolor painting An Album and a small Quantity of unframed Draw...   NaN 186

My script:

from sklearn.model_selection import StratifiedShuffleSplit

split = StratifiedShuffleSplit(n_splits=1, test_size =0.2, random_state=42)
for train_index, test_index  in split.split(lots_not_preprocessed_usd, lots_not_preprocessed_usd['decade']):
    strat_train_set = lots_not_preprocessed_usd.loc[train_index]
    strat_test_set  = lots_not_preprocessed_usd.loc[test_index]

I get this error:

KeyError                                  Traceback (most recent call last)
<ipython-input-224-cee2389254f2> in <module>()
      3 split = StratifiedShuffleSplit(n_splits=1, test_size =0.2, random_state=42)
      4 for train_index, test_index  in split.split(lots_not_preprocessed_usd, lots_not_preprocessed_usd['decade']):
----> 5     strat_train_set = lots_not_preprocessed_usd.loc[train_index]
      6     strat_test_set  = lots_not_preprocessed_usd.loc[test_index]

......

KeyError: 'None of [[32199 67509 69003 ..., 44204  2809 56726]] are in the [index]'

There seems to be a problem with my index (e.g. 2017-09-12T18-38-38-076065), but I don't understand what it is. Where is the problem?

If I use a different split function, it works as expected:

from sklearn.model_selection import train_test_split

train_set, test_set = train_test_split(lots_not_preprocessed_usd, test_size=0.2, random_state=42)

Flo*_*oor 6

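With .loc, the row indexer must contain labels from the DataFrame's own index, so use .iloc when you want to select rows by plain integer position instead. In the for loop, train_index and test_index are not labels from your index (strings such as 2017-09-12T18-38-38-076065): split.split(X, y) returns NumPy arrays of integer positions in random order. .loc therefore looks for rows labelled 32199, 67509, ..., finds none of them, and raises the KeyError. train_test_split works because it returns the DataFrame subsets directly and selects the rows by position internally.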
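A minimal sketch that reproduces the failure and the fix on a small frame. The frame, its string keys and the decade values are hypothetical stand-ins for the question's data; only StratifiedShuffleSplit, .loc and .iloc come from the original.

```python
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit

# Hypothetical toy frame: string labels stand in for the question's timestamp keys.
keys = [f"2017-09-12T18-38-{i:02d}" for i in range(20)]
df = pd.DataFrame({"decade": [186, 187, 188, 189] * 5}, index=keys)

split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
train_index, test_index = next(split.split(df, df["decade"]))

# The splitter yields integer *positions* (0..len(df)-1), not index labels.
print(train_index.dtype.kind)  # i

try:
    df.loc[train_index]  # label lookup for 0, 1, ... -> none of them exist
    loc_raised = False
except KeyError as err:
    loc_raised = True
    print("KeyError:", err)

train = df.iloc[train_index]  # positional lookup works
test = df.iloc[test_index]
print(len(train), len(test))  # 16 4
```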

...
for train_index, test_index  in split.split(lots_not_preprocessed_usd, lots_not_preprocessed_usd['decade']):
    strat_train_set = lots_not_preprocessed_usd.iloc[train_index]
    strat_test_set  = lots_not_preprocessed_usd.iloc[test_index]
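If you would rather keep using .loc, one option (a sketch, not the only way) is to translate the positions into labels first with df.index[positions]. This assumes the index labels are unique; with duplicate labels, .loc would return every row sharing a label. The toy frame below is hypothetical.

```python
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit

# Hypothetical toy frame with unique string labels, as in the question.
df = pd.DataFrame(
    {"decade": [186, 187] * 5},
    index=[f"key-{i}" for i in range(10)],
)

split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
for train_index, test_index in split.split(df, df["decade"]):
    # Map integer positions to the matching index labels, then use .loc.
    strat_train_set = df.loc[df.index[train_index]]
    strat_test_set = df.loc[df.index[test_index]]

# With unique labels this selects exactly the same rows as the .iloc version.
print(strat_train_set.equals(df.iloc[train_index]))  # True
```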

Example:

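import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit

# DatetimeIndex: .loc expects dates here, but split() yields integer positions
lots_not_preprocessed_usd = pd.DataFrame({'some':np.random.randint(5,10,100),'decade':np.random.randint(5,10,100)},index= pd.date_range('5-10-15',periods=100))

split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
for train_index, test_index  in split.split(lots_not_preprocessed_usd, lots_not_preprocessed_usd['decade']):
    strat_train_set = lots_not_preprocessed_usd.iloc[train_index]
    strat_test_set  = lots_not_preprocessed_usd.iloc[test_index]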

Sample output:

strat_train_set.head()
            decade  some
2015-08-02       6     7
2015-06-14       7     6
2015-08-14       7     9
2015-06-25       9     5
2015-05-15       7     9
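If only a single split is needed, train_test_split also accepts a stratify argument (internally it uses StratifiedShuffleSplit), which gives a stratified split while returning the DataFrame parts directly, so no index lookup is needed on your side. A minimal sketch on a hypothetical toy frame:

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical toy frame with string labels, as in the question.
df = pd.DataFrame(
    {"decade": [186, 187, 188, 189] * 5, "price": range(20)},
    index=[f"key-{i}" for i in range(20)],
)

# stratify= keeps the class proportions of df["decade"] in both parts.
train_set, test_set = train_test_split(
    df, test_size=0.2, random_state=42, stratify=df["decade"]
)

print(train_set["decade"].value_counts().sort_index().tolist())  # [4, 4, 4, 4]
print(test_set["decade"].value_counts().sort_index().tolist())   # [1, 1, 1, 1]
```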