如何从保存的 XGBoost 模型中获取参数

TX *_*Shi 4 python xgboost

我正在尝试使用以下参数训练 XGBoost 模型:

xgb_params = {
    'objective': 'binary:logistic',
    'eval_metric': 'auc',
    'lambda': 0.8,
    'alpha': 0.4,
    'max_depth': 10,
    'max_delta_step': 1,
    'verbose': True
}
Run Code Online (Sandbox Code Playgroud)

由于我的输入数据太大而无法完全加载到内存中,因此我调整了增量训练:

xgb_clf = xgb.train(xgb_params, input_data, num_boost_round=rounds_per_batch,
                    xgb_model=model_path)
Run Code Online (Sandbox Code Playgroud)

预测代码是

xgb_clf = xgb.XGBClassifier()
booster = xgb.Booster()
booster.load_model(model_path)
xgb_clf._Booster = booster
raw_probas = xgb_clf.predict_proba(x)
Run Code Online (Sandbox Code Playgroud)

结果似乎不错。但是当我尝试调用时xgb_clf.get_xgb_params(),我得到了一个 param dict,其中所有参数都设置为默认值。

我可以猜到根本原因是当我初始化模型时,我没有传入任何参数。所以模型是使用默认值初始化的,但是当它预测时,它使用了一个内部助推器,该助推器已经使用一些预拟合定义的参数。

但是,我想知道有什么方法可以在我将预训练的 booster 模型分配给 XGBClassifier 之后,看到用于训练 booster 的真实参数,而不是用于初始化分类器的参数。

yts*_*aig 6

您似乎在代码中将 sklearn API 与函数式 API 混合在一起,如果您坚持使用其中任何一个,您应该将参数保留在泡菜中。这是使用 sklearn API 的示例。

import pickle
import numpy as np
import xgboost as xgb
from sklearn.datasets import load_digits


digits = load_digits(2)
y = digits['target']
X = digits['data']

xgb_params = {
    'objective': 'binary:logistic',
    'reg_lambda': 0.8,
    'reg_alpha': 0.4,
    'max_depth': 10,
    'max_delta_step': 1,
}
clf = xgb.XGBClassifier(**xgb_params)
clf.fit(X, y, eval_metric='auc', verbose=True)

pickle.dump(clf, open("xgb_temp.pkl", "wb"))
clf2 = pickle.load(open("xgb_temp.pkl", "rb"))

assert np.allclose(clf.predict(X), clf2.predict(X))
print(clf2.get_xgb_params())
Run Code Online (Sandbox Code Playgroud)

产生

{'base_score': 0.5,
 'colsample_bylevel': 1,
 'colsample_bytree': 1,
 'gamma': 0,
 'learning_rate': 0.1,
 'max_delta_step': 1,
 'max_depth': 10,
 'min_child_weight': 1,
 'missing': nan,
 'n_estimators': 100,
 'objective': 'binary:logistic',
 'reg_alpha': 0.4,
 'reg_lambda': 0.8,
 'scale_pos_weight': 1,
 'seed': 0,
 'silent': 1,
 'subsample': 1}
Run Code Online (Sandbox Code Playgroud)