Dar*_*ioB 6 python pickle pytest scikit-learn joblib
我构建了一个自定义的 sklearn 管道,如下所示:
pipeline = make_pipeline(
SelectColumnsTransfomer(features_to_use),
ToDummiesTransformer('feature_0', prefix='feat_0', drop_first=True, dtype=bool), # Dummify customer_type
ToDummiesTransformer('feature_1', prefix='feat_1'), # Dummify the feature
ToDummiesTransformer('feature_2', prefix='feat_2'), # Dummify
ToDummiesTransformer('feature_3', prefix='feat_3'), # Dummify
)
pipeline.fit(df)
Run Code Online (Sandbox Code Playgroud)
类SelectColumnsTransfomer和是实现和ToDummiesTransformer的自定义 sklearn 步骤。为了序列化这个对象,我使用BaseEstimatorTransformerMixin
from sklearn.externals import joblib
joblib.dump(pipeline, 'data_pipeline.joblib')
Run Code Online (Sandbox Code Playgroud)
但是当我反序列化时
pipeline = joblib.load('data_pipeline.joblib')
Run Code Online (Sandbox Code Playgroud)
我明白了AttributeError: module '__main__' has no attribute 'SelectColumnsTransfomer'。
我已经阅读了其他类似的问题,并按照此处这篇博文中的说明进行操作,但无法解决问题。我正在复制粘贴这些类,并将它们导入到代码中。如果我创建此练习的简化版本,整个事情都会起作用,出现问题是因为我正在使用 pytest 运行一些测试,当我运行 pytest 时,它似乎看不到我的自定义类,实际上还有其他部分错误
self = <sklearn.externals.joblib.numpy_pickle.NumpyUnpickler object at 0x7f821508a588>, module = '__main__', name = 'SelectColumnsTransfomer'提示我即使在测试中导入也NumpyUnpickler看不到它。SelectColumnsTransfomer
我的测试代码
import pytest
from app.pipeline import * # the pipeline objects
# SelectColumnsTransfomer and ToDummiesTransformer
# are here!
@pytest.fixture(scope="module")
def clf():
pipeline = joblib.load("persistence/data_pipeline.joblib")
return clf
def test_fake(clf):
assert True
Run Code Online (Sandbox Code Playgroud)
当我尝试保存像这样的 Pytorch 类时,我遇到了相同的错误消息:
import torch.nn as nn
class custom(nn.Module):
def __init__(self):
super(custom, self).__init__()
print("Class loaded")
model = custom()
Run Code Online (Sandbox Code Playgroud)
然后使用 Joblib 转储该模型,如下所示:
from joblib import dump
dump(model, 'some_filepath.jobjib')
Run Code Online (Sandbox Code Playgroud)
问题是我在Kaggle内核中运行上面的代码。然后下载转储的文件并尝试在本地使用此脚本加载它:
from joblib import load
model = load(model, 'some_filepath.jobjib')
Run Code Online (Sandbox Code Playgroud)
我解决这个问题的方法是在我的计算机上本地运行所有这些代码片段,而不是创建类并将其转储到 Kaggle 上,而是将其加载到我的本地计算机上。想要在此处添加此内容,因为 @DarioB 对答案的评论让我困惑,他们对“函数”的引用不适用于我的更简单的情况。