分类决策树中的学习曲线是什么意思?

zyw*_*l11 2 decision-tree grid-search

我在分析中使用了分类决策树。首先,我将整个数据分为训练和测试 - 60%:40%。然后我在训练集上使用 GridSearch 来获得最佳评分模型 (max_深度 = 7)。然后我在交叉验证集和训练集上绘制了学习曲线。这是我得到的图表。看起来两条线是重叠的。那么它告诉我什么?我的模型没有过度拟合吗?一般来说,为什么我们在分析中需要学习曲线?

链接到我的学习曲线图片

多谢!

ana*_*s95 6

学习曲线显示了估计器针对不同数量的训练样本的验证和训练分数。它是一个工具,可以找出我们从添加更多训练数据中获益多少,以及估计器是否会因方差误差或偏差误差而遭受更多损失。

机器学习曲线可用于多种目的,包括比较不同的算法、在设计过程中选择模型参数、调整优化以提高收敛性以及确定用于训练的数据量。

您没有充分利用学习曲线工具,因为您从非常高的训练规模开始,它不允许您很好地看到模型的行为。

这是一个示例,显示了一个图,其中您开始使用较小的训练大小进行分析,而另一个图则从非常大的训练大小(您的案例)开始进行分析。为此,您只需改变 sklearn.model_selection.learning_curve 的 train_sizes 参数即可。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.svm import SVC
from get_csv_data import HandleData
from sklearn.model_selection import learning_curve
from sklearn.model_selection import ShuffleSplit

def plot_learning_curve(estimator, X, y, ax=None, ylim=(0.5, 1.01), cv=None, n_jobs=4, train_sizes=np.linspace(.1, 1.0, 5)):

    train_sizes, train_scores, test_scores = \
        learning_curve(estimator, X, y, cv=cv, n_jobs=n_jobs, train_sizes=train_sizes)
              
    train_scores_mean = np.mean(train_scores, axis=1)
    train_scores_std = np.std(train_scores, axis=1)
    test_scores_mean = np.mean(test_scores, axis=1)
    test_scores_std = np.std(test_scores, axis=1)

    # Plot learning curve
    if ylim is not None:
        ax.set_ylim(*ylim)
    ax.set_xlabel("Training examples")
    ax.set_ylabel("Score")
    ax.plot(train_sizes, train_scores_mean, 'o-', color="r", label="Training score")
    ax.plot(train_sizes, test_scores_mean, 'o-', color="g", label="Cross-validation score")
    ax.legend(loc="best")

    return plt

fig, (ax1, ax2) = plt.subplots(1, 2)

data = HandleData(oneHotFlag=False)
#get the data
X, y = data.get_synthatic_data()

cv = ShuffleSplit(n_splits=10, test_size=0.2, random_state=0)
estimator = SVC()
plot_learning_curve(estimator, X, y, ax = ax1, cv=cv, train_sizes=np.linspace(.1, 1.0, 5))
plot_learning_curve(estimator, X, y, ax = ax2, cv=cv, train_sizes=np.linspace(.5, 1.0, 5))

plt.show()
Run Code Online (Sandbox Code Playgroud)

输出: