使用 ColumnTransformer 在 Sklearn 管道中的自定义转换器中访问 pandas 数据帧的列名称?

AJ *_* AJ 5 python pipeline dataframe pandas scikit-learn

我需要在使用列名称的管道中使用自定义转换器。但是,之前的管道转换将数据帧转换为 numpy 数组。我知道在管道拟合我可以从列转换器对象中检索列名称,但我需要拟合步骤中访问列名称。下面示例中的自定义转换器是一个简单的最小示例,仅用于说明,而不是真正的转换。

import pandas as pd

from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, MinMaxScaler
from sklearn.base import BaseEstimator, TransformerMixin


class MyCustomTransformer(BaseEstimator, TransformerMixin):
    def my_custom_transformation(self, X):
        """
        Parameters
        ----------
        X: pandas dataframe
        """
        columns_to_keep = [col for col in X.columns if col.endswith(('_a', '_b'))]
        return columns_to_keep
    
    def fit(self, X, y=None):
        self.columns_to_keep = self.my_custom_transformation(X)
        return self

    def transform(self, X, y=None):
        return X[self.columns]

numeric_transformer = Pipeline(steps=[('minmax_scaler', MinMaxScaler())])
categorical_transformer = Pipeline(steps=[('onehot_encoder', OneHotEncoder(sparse=False))])

column_transformer = ColumnTransformer(transformers=[
    ('numeric_transformer', numeric_transformer, ['num']),
    ('categorical_transformer', categorical_transformer, ['cat']),
])

pipeline = Pipeline(steps=[
    ('column_transformer', column_transformer),
    ('my_custom_transformer', MyCustomTransformer())
])

df = pd.DataFrame(data={'num': [1,2,3], 'cat':['a', 'b', 'c']})
pipeline.fit(data_df)
Run Code Online (Sandbox Code Playgroud)

理想的结果是:

transformed_df = pipeline.transform(df)
print(transformed_df)
>>>    num    cat_a    cat_b
    0    0        1        0
    1  0.5        0        1
    2    1        0        0
Run Code Online (Sandbox Code Playgroud)

column_transformer 中的转换将数据帧转换为 numpy 数组,然后将其传递给自定义转换器。显然,这会导致错误,因为您无法从 numpy 数组中获取列名。

我无法使用索引来访问列,因为 one-hot 编码可能会导致以前未知的列数。

如果我可以在自定义转换器的 fit 方法中访问 ColumnTransformer 对象,我可以检索列名称,然后创建一个 pandas 数据框以在上面的 fit 方法中使用(?),但我还没有成功找到一种方法这。

任何帮助将非常感激。

Del*_*ine 0

请参阅我建议的实现,ColumnTransformerWithNames以响应如何使用列转换器获取特征名称ColumnTransformer您可以替换对的调用ColumnTransformerWithNames,管道的输出将是带有列名称的 DataFrame =)