如何在 pytest 单元测试中比较 XGBoost 模型对象,一个初始化/安装,另一个从文件读取?

Jam*_*ams 6 python pytest xgboost

我对用于封装 XGBoost 模型的类进行了简单测试。为了测试这个类,我训练了一个 XGBoost 模型并将其保存到文件中,我想使用这个经过训练的模型,我将从文件中读取该模型来测试我的模型训练代码。我不确定如何才能最好地将使用已知参数/数据训练的 XGBoost 模型与我保存到文件的模型进行比较。例如,我训练并保存了一个 XGBoost 模型,如下所示:

# specify parameters to use for training the XGBoost model
params = {
    'max_depth': 6,  # the maximum depth of each tree
    'eta': 0.25,  # the training step for each iteration
    'silent': 1,  # logging mode - quiet
    'objective': 'reg:tweedie',
    'booster': 'gbtree',
    'subsample': 0.7,
    'gamma': 0.3,  # regularization parameter
    'colsample_bytree': 0.2,
    'rate_drop': 0.3,
    'skip_drop': 0.2,
    'early_stopping_rounds': 10,
    'eval_metric': ['rmse', 'mae'],  # error evaluation for multiclass training
}

# split X and y into train and test sets
features_train, features_test, target_train, target_test = \
    train_test_split(features, target, test_size=test_percentage, random_state=31)

# package the dataset splits as input for XGBoost
dtrain = xgb.DMatrix(features_train, label=target_train)
dtest = xgb.DMatrix(features_test, label=target_test)
evallist = [(dtest, 'eval'), (dtrain, 'train')]

# train the XGBoost model
xgbooster = xgb.train(params, dtrain, training_iterations, evallist, verbose_eval=0)
pickle.dump(xgbooster, open("/path/to/fitted_model.dat", "wb"))
Run Code Online (Sandbox Code Playgroud)

在我的模型类的(pytest)单元测试中,我想测试我是否按预期训练模型,因此我从文件中读取这个保存的模型,以便与应该匹配的模型进行比较:

def test_xgboost_fit():

    features_train_df = pd.read_csv("/path/to/features_train.csv"))
    labels_train_df = pd.read_csv("/path/to/labels_train.csv"))
    fixture_xgbooster = pickle.load(open("/path/to/fitted_model.dat", "rb"))

    # train/fit the model
    xgbooster = mymodelclass.XGBoostModel()
    xgbooster.fit(features_train_df, labels_train_df)

    # compare the trained model against the expected model read from file
    assert xgbooster.model == fixture_xgbooster
Run Code Online (Sandbox Code Playgroud)

在这里使用双等于似乎不足以进行比较(否则我还有其他问题,因为它表明具有相同参数并配备相同训练数据的两个模型不相等)。

我应该如何在测试中进行这种比较?或者有更好的方法来测试这段代码吗?