Bagging分类器的'max_samples'关键字如何影响每个基本估算器使用的样本数?

hkh*_*are 5 machine-learning scikit-learn grid-search

我想了解Bagging分类器的max_samples值如何影响每个基本估算器使用的样本数.

这是GridSearch输出:

GridSearchCV(cv=5, error_score='raise',
       estimator=BaggingClassifier(base_estimator=DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
            max_features=None, max_leaf_nodes=None, min_samples_leaf=1,
            min_samples_split=2, min_weight_fraction_leaf=0.0,
            presort=False, random_state=1, spl... n_estimators=100, n_jobs=-1, oob_score=False,
         random_state=1, verbose=2, warm_start=False),
       fit_params={}, iid=True, n_jobs=-1,
       param_grid={'max_features': [0.6, 0.8, 1.0], 'max_samples': [0.6, 0.8, 1.0]},
       pre_dispatch='2*n_jobs', refit=True, scoring=None, verbose=2)
Run Code Online (Sandbox Code Playgroud)

在这里,我发现最好的参数是什么:

print gs5.best_score_, gs5.best_params_
0.828282828283 {'max_features': 0.6, 'max_samples': 1.0}
Run Code Online (Sandbox Code Playgroud)

现在,我正在挑选出最佳的网格搜索估算器,并尝试查看特定Bagging分类器在其100个基本决策树估算器集中使用的样本数.

val=[]
for i in np.arange(100):
    x = np.bincount(gs5.best_estimator_.estimators_samples_[i])[1]
    val.append(x)
print np.max(val)
print np.mean(val), np.std(val)

587
563.92 10.3399032877
Run Code Online (Sandbox Code Playgroud)

现在,训练集的大小是891.由于CV是5,891*0.8 = 712.8应该进入每个Bagging分类器评估,并且因为max_samples是1.0,891*0.5*1.0 = 712.8应该是每个基数的样本数估算器,还是接近它的东西?

那么,为什么数字在564 +/- 10范围内,最大值587在计算时应该接近712?谢谢.

bpa*_*hev 5

经过更多的研究,我想我已经弄清楚发生了什么。GridSearchCV对训练数据使用交叉验证来确定最佳参数,但它返回的估计器适合整个训练集,而不是 CV 折叠之一。这是有道理的,因为更多的训练数据通常更好。

因此,您从 GridSearchCV 返回的BaggingClassifier适合包含 891 个数据样本的完整数据集。确实如此,当 max_sample=1. 时,每个基估计器将从训练集中随机抽取 891 个样本。但是,默认情况下,样本是通过替换绘制的,因此由于重复,唯一样本的数量将少于样本总数。如果要不替换地绘制,将BaggingClassifier 的bootstrap 关键字设置为false。

现在,在没有替换的情况下绘制时,我们应该期望不同样本的数量与数据集的大小有多接近?

基于这个问题,当从一组 n 个样本中抽取 n 个带有替换的样本时,不同样本的预期数量是 n * (1-(n-1)/n) ^ n。当我们将 891 插入其中时,我们得到

>>> 891 * (1.- (890./891)**891)
563.4034437025824
Run Code Online (Sandbox Code Playgroud)

预期的样本数 (563.4) 非常接近您观察到的平均值 (563.8),因此似乎没有发生任何异常。