How to use different feature matrices for sklearn.ensemble.StackingClassifier (with class inheritance)?

O.r*_*rka 3 python statistics classification machine-learning scikit-learn

I have a dataset where there are several different types of data for each sample, and I'd like to use a separate model for each data type and combine them in sklearn.ensemble.StackingClassifier. However, StackingClassifier passes the same feature matrix to every base algorithm and then sends their predicted probabilities to the meta classifier.

Is there a way to specify particular feature matrices (representing the same samples) that correspond with specific algorithms in the StackingClassifier?

If not, how can you use class inheritance of a StackingClassifier to adapt to this type of functionality?

Below is a very quick and inelegant example (for demonstration only, not for practical use) that uses 2 feature sets (sepal features and petal features from iris) from the same samples (the iris samples). Each feature set gets a different algorithm, and the resulting probabilities are used as input to the meta classifier.

Doing it this way is very tedious...

import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import AdaBoostClassifier
from sklearn.model_selection import train_test_split

# Data
X_sepal = pd.DataFrame({'sepal_length': {'iris_0': 5.1,'iris_1': 4.9,'iris_2': 4.7,'iris_3': 4.6,'iris_4': 5.0,'iris_5': 5.4,'iris_6': 4.6,'iris_7': 5.0,'iris_8': 4.4,'iris_9': 4.9,'iris_10': 5.4,'iris_11': 4.8,'iris_12': 4.8,'iris_13': 4.3,'iris_14': 5.8,'iris_15': 5.7,'iris_16': 5.4,'iris_17': 5.1,'iris_18': 5.7,'iris_19': 5.1,'iris_20': 5.4,'iris_21': 5.1,'iris_22': 4.6,'iris_23': 5.1,'iris_24': 4.8,'iris_25': 5.0,'iris_26': 5.0,'iris_27': 5.2,'iris_28': 5.2,'iris_29': 4.7,'iris_30': 4.8,'iris_31': 5.4,'iris_32': 5.2,'iris_33': 5.5,'iris_34': 4.9,'iris_35': 5.0,'iris_36': 5.5,'iris_37': 4.9,'iris_38': 4.4,'iris_39': 5.1,'iris_40': 5.0,'iris_41': 4.5,'iris_42': 4.4,'iris_43': 5.0,'iris_44': 5.1,'iris_45': 4.8,'iris_46': 5.1,'iris_47': 4.6,'iris_48': 5.3,'iris_49': 5.0,'iris_50': 7.0,'iris_51': 6.4,'iris_52': 6.9,'iris_53': 5.5,'iris_54': 6.5,'iris_55': 5.7,'iris_56': 6.3,'iris_57': 4.9,'iris_58': 6.6,'iris_59': 5.2,'iris_60': 5.0,'iris_61': 5.9,'iris_62': 6.0,'iris_63': 6.1,'iris_64': 5.6,'iris_65': 6.7,'iris_66': 5.6,'iris_67': 5.8,'iris_68': 6.2,'iris_69': 5.6,'iris_70': 5.9,'iris_71': 6.1,'iris_72': 6.3,'iris_73': 6.1,'iris_74': 6.4,'iris_75': 6.6,'iris_76': 6.8,'iris_77': 6.7,'iris_78': 6.0,'iris_79': 5.7,'iris_80': 5.5,'iris_81': 5.5,'iris_82': 5.8,'iris_83': 6.0,'iris_84': 5.4,'iris_85': 6.0,'iris_86': 6.7,'iris_87': 6.3,'iris_88': 5.6,'iris_89': 5.5,'iris_90': 5.5,'iris_91': 6.1,'iris_92': 5.8,'iris_93': 5.0,'iris_94': 5.6,'iris_95': 5.7,'iris_96': 5.7,'iris_97': 6.2,'iris_98': 5.1,'iris_99': 5.7,'iris_100': 6.3,'iris_101': 5.8,'iris_102': 7.1,'iris_103': 6.3,'iris_104': 6.5,'iris_105': 7.6,'iris_106': 4.9,'iris_107': 7.3,'iris_108': 6.7,'iris_109': 7.2,'iris_110': 6.5,'iris_111': 6.4,'iris_112': 6.8,'iris_113': 5.7,'iris_114': 5.8,'iris_115': 6.4,'iris_116': 6.5,'iris_117': 7.7,'iris_118': 7.7,'iris_119': 6.0,'iris_120': 6.9,'iris_121': 5.6,'iris_122': 7.7,'iris_123': 6.3,'iris_124': 6.7,'iris_125': 7.2,'iris_126': 6.2,'iris_127': 6.1,'iris_128': 
6.4,'iris_129': 7.2,'iris_130': 7.4,'iris_131': 7.9,'iris_132': 6.4,'iris_133': 6.3,'iris_134': 6.1,'iris_135': 7.7,'iris_136': 6.3,'iris_137': 6.4,'iris_138': 6.0,'iris_139': 6.9,'iris_140': 6.7,'iris_141': 6.9,'iris_142': 5.8,'iris_143': 6.8,'iris_144': 6.7,'iris_145': 6.7,'iris_146': 6.3,'iris_147': 6.5,'iris_148': 6.2,'iris_149': 5.9},'sepal_width': {'iris_0': 3.5,'iris_1': 3.0,'iris_2': 3.2,'iris_3': 3.1,'iris_4': 3.6,'iris_5': 3.9,'iris_6': 3.4,'iris_7': 3.4,'iris_8': 2.9,'iris_9': 3.1,'iris_10': 3.7,'iris_11': 3.4,'iris_12': 3.0,'iris_13': 3.0,'iris_14': 4.0,'iris_15': 4.4,'iris_16': 3.9,'iris_17': 3.5,'iris_18': 3.8,'iris_19': 3.8,'iris_20': 3.4,'iris_21': 3.7,'iris_22': 3.6,'iris_23': 3.3,'iris_24': 3.4,'iris_25': 3.0,'iris_26': 3.4,'iris_27': 3.5,'iris_28': 3.4,'iris_29': 3.2,'iris_30': 3.1,'iris_31': 3.4,'iris_32': 4.1,'iris_33': 4.2,'iris_34': 3.1,'iris_35': 3.2,'iris_36': 3.5,'iris_37': 3.6,'iris_38': 3.0,'iris_39': 3.4,'iris_40': 3.5,'iris_41': 2.3,'iris_42': 3.2,'iris_43': 3.5,'iris_44': 3.8,'iris_45': 3.0,'iris_46': 3.8,'iris_47': 3.2,'iris_48': 3.7,'iris_49': 3.3,'iris_50': 3.2,'iris_51': 3.2,'iris_52': 3.1,'iris_53': 2.3,'iris_54': 2.8,'iris_55': 2.8,'iris_56': 3.3,'iris_57': 2.4,'iris_58': 2.9,'iris_59': 2.7,'iris_60': 2.0,'iris_61': 3.0,'iris_62': 2.2,'iris_63': 2.9,'iris_64': 2.9,'iris_65': 3.1,'iris_66': 3.0,'iris_67': 2.7,'iris_68': 2.2,'iris_69': 2.5,'iris_70': 3.2,'iris_71': 2.8,'iris_72': 2.5,'iris_73': 2.8,'iris_74': 2.9,'iris_75': 3.0,'iris_76': 2.8,'iris_77': 3.0,'iris_78': 2.9,'iris_79': 2.6,'iris_80': 2.4,'iris_81': 2.4,'iris_82': 2.7,'iris_83': 2.7,'iris_84': 3.0,'iris_85': 3.4,'iris_86': 3.1,'iris_87': 2.3,'iris_88': 3.0,'iris_89': 2.5,'iris_90': 2.6,'iris_91': 3.0,'iris_92': 2.6,'iris_93': 2.3,'iris_94': 2.7,'iris_95': 3.0,'iris_96': 2.9,'iris_97': 2.9,'iris_98': 2.5,'iris_99': 2.8,'iris_100': 3.3,'iris_101': 2.7,'iris_102': 3.0,'iris_103': 2.9,'iris_104': 3.0,'iris_105': 3.0,'iris_106': 2.5,'iris_107': 2.9,'iris_108': 
2.5,'iris_109': 3.6,'iris_110': 3.2,'iris_111': 2.7,'iris_112': 3.0,'iris_113': 2.5,'iris_114': 2.8,'iris_115': 3.2,'iris_116': 3.0,'iris_117': 3.8,'iris_118': 2.6,'iris_119': 2.2,'iris_120': 3.2,'iris_121': 2.8,'iris_122': 2.8,'iris_123': 2.7,'iris_124': 3.3,'iris_125': 3.2,'iris_126': 2.8,'iris_127': 3.0,'iris_128': 2.8,'iris_129': 3.0,'iris_130': 2.8,'iris_131': 3.8,'iris_132': 2.8,'iris_133': 2.8,'iris_134': 2.6,'iris_135': 3.0,'iris_136': 3.4,'iris_137': 3.1,'iris_138': 3.0,'iris_139': 3.1,'iris_140': 3.1,'iris_141': 3.1,'iris_142': 2.7,'iris_143': 3.2,'iris_144': 3.3,'iris_145': 3.0,'iris_146': 2.5,'iris_147': 3.0,'iris_148': 3.4,'iris_149': 3.0}})
X_petal = pd.DataFrame({'petal_length': {'iris_0': 1.4,'iris_1': 1.4,'iris_2': 1.3,'iris_3': 1.5,'iris_4': 1.4,'iris_5': 1.7,'iris_6': 1.4,'iris_7': 1.5,'iris_8': 1.4,'iris_9': 1.5,'iris_10': 1.5,'iris_11': 1.6,'iris_12': 1.4,'iris_13': 1.1,'iris_14': 1.2,'iris_15': 1.5,'iris_16': 1.3,'iris_17': 1.4,'iris_18': 1.7,'iris_19': 1.5,'iris_20': 1.7,'iris_21': 1.5,'iris_22': 1.0,'iris_23': 1.7,'iris_24': 1.9,'iris_25': 1.6,'iris_26': 1.6,'iris_27': 1.5,'iris_28': 1.4,'iris_29': 1.6,'iris_30': 1.6,'iris_31': 1.5,'iris_32': 1.5,'iris_33': 1.4,'iris_34': 1.5,'iris_35': 1.2,'iris_36': 1.3,'iris_37': 1.4,'iris_38': 1.3,'iris_39': 1.5,'iris_40': 1.3,'iris_41': 1.3,'iris_42': 1.3,'iris_43': 1.6,'iris_44': 1.9,'iris_45': 1.4,'iris_46': 1.6,'iris_47': 1.4,'iris_48': 1.5,'iris_49': 1.4,'iris_50': 4.7,'iris_51': 4.5,'iris_52': 4.9,'iris_53': 4.0,'iris_54': 4.6,'iris_55': 4.5,'iris_56': 4.7,'iris_57': 3.3,'iris_58': 4.6,'iris_59': 3.9,'iris_60': 3.5,'iris_61': 4.2,'iris_62': 4.0,'iris_63': 4.7,'iris_64': 3.6,'iris_65': 4.4,'iris_66': 4.5,'iris_67': 4.1,'iris_68': 4.5,'iris_69': 3.9,'iris_70': 4.8,'iris_71': 4.0,'iris_72': 4.9,'iris_73': 4.7,'iris_74': 4.3,'iris_75': 4.4,'iris_76': 4.8,'iris_77': 5.0,'iris_78': 4.5,'iris_79': 3.5,'iris_80': 3.8,'iris_81': 3.7,'iris_82': 3.9,'iris_83': 5.1,'iris_84': 4.5,'iris_85': 4.5,'iris_86': 4.7,'iris_87': 4.4,'iris_88': 4.1,'iris_89': 4.0,'iris_90': 4.4,'iris_91': 4.6,'iris_92': 4.0,'iris_93': 3.3,'iris_94': 4.2,'iris_95': 4.2,'iris_96': 4.2,'iris_97': 4.3,'iris_98': 3.0,'iris_99': 4.1,'iris_100': 6.0,'iris_101': 5.1,'iris_102': 5.9,'iris_103': 5.6,'iris_104': 5.8,'iris_105': 6.6,'iris_106': 4.5,'iris_107': 6.3,'iris_108': 5.8,'iris_109': 6.1,'iris_110': 5.1,'iris_111': 5.3,'iris_112': 5.5,'iris_113': 5.0,'iris_114': 5.1,'iris_115': 5.3,'iris_116': 5.5,'iris_117': 6.7,'iris_118': 6.9,'iris_119': 5.0,'iris_120': 5.7,'iris_121': 4.9,'iris_122': 6.7,'iris_123': 4.9,'iris_124': 5.7,'iris_125': 6.0,'iris_126': 4.8,'iris_127': 4.9,'iris_128': 
5.6,'iris_129': 5.8,'iris_130': 6.1,'iris_131': 6.4,'iris_132': 5.6,'iris_133': 5.1,'iris_134': 5.6,'iris_135': 6.1,'iris_136': 5.6,'iris_137': 5.5,'iris_138': 4.8,'iris_139': 5.4,'iris_140': 5.6,'iris_141': 5.1,'iris_142': 5.1,'iris_143': 5.9,'iris_144': 5.7,'iris_145': 5.2,'iris_146': 5.0,'iris_147': 5.2,'iris_148': 5.4,'iris_149': 5.1},'petal_width': {'iris_0': 0.2,'iris_1': 0.2,'iris_2': 0.2,'iris_3': 0.2,'iris_4': 0.2,'iris_5': 0.4,'iris_6': 0.3,'iris_7': 0.2,'iris_8': 0.2,'iris_9': 0.1,'iris_10': 0.2,'iris_11': 0.2,'iris_12': 0.1,'iris_13': 0.1,'iris_14': 0.2,'iris_15': 0.4,'iris_16': 0.4,'iris_17': 0.3,'iris_18': 0.3,'iris_19': 0.3,'iris_20': 0.2,'iris_21': 0.4,'iris_22': 0.2,'iris_23': 0.5,'iris_24': 0.2,'iris_25': 0.2,'iris_26': 0.4,'iris_27': 0.2,'iris_28': 0.2,'iris_29': 0.2,'iris_30': 0.2,'iris_31': 0.4,'iris_32': 0.1,'iris_33': 0.2,'iris_34': 0.2,'iris_35': 0.2,'iris_36': 0.2,'iris_37': 0.1,'iris_38': 0.2,'iris_39': 0.2,'iris_40': 0.3,'iris_41': 0.3,'iris_42': 0.2,'iris_43': 0.6,'iris_44': 0.4,'iris_45': 0.3,'iris_46': 0.2,'iris_47': 0.2,'iris_48': 0.2,'iris_49': 0.2,'iris_50': 1.4,'iris_51': 1.5,'iris_52': 1.5,'iris_53': 1.3,'iris_54': 1.5,'iris_55': 1.3,'iris_56': 1.6,'iris_57': 1.0,'iris_58': 1.3,'iris_59': 1.4,'iris_60': 1.0,'iris_61': 1.5,'iris_62': 1.0,'iris_63': 1.4,'iris_64': 1.3,'iris_65': 1.4,'iris_66': 1.5,'iris_67': 1.0,'iris_68': 1.5,'iris_69': 1.1,'iris_70': 1.8,'iris_71': 1.3,'iris_72': 1.5,'iris_73': 1.2,'iris_74': 1.3,'iris_75': 1.4,'iris_76': 1.4,'iris_77': 1.7,'iris_78': 1.5,'iris_79': 1.0,'iris_80': 1.1,'iris_81': 1.0,'iris_82': 1.2,'iris_83': 1.6,'iris_84': 1.5,'iris_85': 1.6,'iris_86': 1.5,'iris_87': 1.3,'iris_88': 1.3,'iris_89': 1.3,'iris_90': 1.2,'iris_91': 1.4,'iris_92': 1.2,'iris_93': 1.0,'iris_94': 1.3,'iris_95': 1.2,'iris_96': 1.3,'iris_97': 1.3,'iris_98': 1.1,'iris_99': 1.3,'iris_100': 2.5,'iris_101': 1.9,'iris_102': 2.1,'iris_103': 1.8,'iris_104': 2.2,'iris_105': 2.1,'iris_106': 1.7,'iris_107': 1.8,'iris_108': 
1.8,'iris_109': 2.5,'iris_110': 2.0,'iris_111': 1.9,'iris_112': 2.1,'iris_113': 2.0,'iris_114': 2.4,'iris_115': 2.3,'iris_116': 1.8,'iris_117': 2.2,'iris_118': 2.3,'iris_119': 1.5,'iris_120': 2.3,'iris_121': 2.0,'iris_122': 2.0,'iris_123': 1.8,'iris_124': 2.1,'iris_125': 1.8,'iris_126': 1.8,'iris_127': 1.8,'iris_128': 2.1,'iris_129': 1.6,'iris_130': 1.9,'iris_131': 2.0,'iris_132': 2.2,'iris_133': 1.5,'iris_134': 1.4,'iris_135': 2.3,'iris_136': 2.4,'iris_137': 1.8,'iris_138': 1.8,'iris_139': 2.1,'iris_140': 2.4,'iris_141': 2.3,'iris_142': 1.9,'iris_143': 2.3,'iris_144': 2.5,'iris_145': 2.3,'iris_146': 1.9,'iris_147': 2.0,'iris_148': 2.3,'iris_149': 1.8}})
y_iris = pd.Series({'iris_0': 'setosa','iris_1': 'setosa','iris_2': 'setosa','iris_3': 'setosa','iris_4': 'setosa','iris_5': 'setosa','iris_6': 'setosa','iris_7': 'setosa','iris_8': 'setosa','iris_9': 'setosa','iris_10': 'setosa','iris_11': 'setosa','iris_12': 'setosa','iris_13': 'setosa','iris_14': 'setosa','iris_15': 'setosa','iris_16': 'setosa','iris_17': 'setosa','iris_18': 'setosa','iris_19': 'setosa','iris_20': 'setosa','iris_21': 'setosa','iris_22': 'setosa','iris_23': 'setosa','iris_24': 'setosa','iris_25': 'setosa','iris_26': 'setosa','iris_27': 'setosa','iris_28': 'setosa','iris_29': 'setosa','iris_30': 'setosa','iris_31': 'setosa','iris_32': 'setosa','iris_33': 'setosa','iris_34': 'setosa','iris_35': 'setosa','iris_36': 'setosa','iris_37': 'setosa','iris_38': 'setosa','iris_39': 'setosa','iris_40': 'setosa','iris_41': 'setosa','iris_42': 'setosa','iris_43': 'setosa','iris_44': 'setosa','iris_45': 'setosa','iris_46': 'setosa','iris_47': 'setosa','iris_48': 'setosa','iris_49': 'setosa','iris_50': 'versicolor','iris_51': 'versicolor','iris_52': 'versicolor','iris_53': 'versicolor','iris_54': 'versicolor','iris_55': 'versicolor','iris_56': 'versicolor','iris_57': 'versicolor','iris_58': 'versicolor','iris_59': 'versicolor','iris_60': 'versicolor','iris_61': 'versicolor','iris_62': 'versicolor','iris_63': 'versicolor','iris_64': 'versicolor','iris_65': 'versicolor','iris_66': 'versicolor','iris_67': 'versicolor','iris_68': 'versicolor','iris_69': 'versicolor','iris_70': 'versicolor','iris_71': 'versicolor','iris_72': 'versicolor','iris_73': 'versicolor','iris_74': 'versicolor','iris_75': 'versicolor','iris_76': 'versicolor','iris_77': 'versicolor','iris_78': 'versicolor','iris_79': 'versicolor','iris_80': 'versicolor','iris_81': 'versicolor','iris_82': 'versicolor','iris_83': 'versicolor','iris_84': 'versicolor','iris_85': 'versicolor','iris_86': 'versicolor','iris_87': 'versicolor','iris_88': 'versicolor','iris_89': 'versicolor','iris_90': 
'versicolor','iris_91': 'versicolor','iris_92': 'versicolor','iris_93': 'versicolor','iris_94': 'versicolor','iris_95': 'versicolor','iris_96': 'versicolor','iris_97': 'versicolor','iris_98': 'versicolor','iris_99': 'versicolor','iris_100': 'virginica','iris_101': 'virginica','iris_102': 'virginica','iris_103': 'virginica','iris_104': 'virginica','iris_105': 'virginica','iris_106': 'virginica','iris_107': 'virginica','iris_108': 'virginica','iris_109': 'virginica','iris_110': 'virginica','iris_111': 'virginica','iris_112': 'virginica','iris_113': 'virginica','iris_114': 'virginica','iris_115': 'virginica','iris_116': 'virginica','iris_117': 'virginica','iris_118': 'virginica','iris_119': 'virginica','iris_120': 'virginica','iris_121': 'virginica','iris_122': 'virginica','iris_123': 'virginica','iris_124': 'virginica','iris_125': 'virginica','iris_126': 'virginica','iris_127': 'virginica','iris_128': 'virginica','iris_129': 'virginica','iris_130': 'virginica','iris_131': 'virginica','iris_132': 'virginica','iris_133': 'virginica','iris_134': 'virginica','iris_135': 'virginica','iris_136': 'virginica','iris_137': 'virginica','iris_138': 'virginica','iris_139': 'virginica','iris_140': 'virginica','iris_141': 'virginica','iris_142': 'virginica','iris_143': 'virginica','iris_144': 'virginica','iris_145': 'virginica','iris_146': 'virginica','iris_147': 'virginica','iris_148': 'virginica','iris_149': 'virginica'})

# Training/Testing
idx_training, idx_testing = train_test_split(y_iris.index, stratify=y_iris, random_state=0)

# Classifiers
# Note: scikit-learn >= 1.2 renames AdaBoostClassifier's `base_estimator` to `estimator`
clf_sepal = AdaBoostClassifier(base_estimator=LinearSVC(random_state=0), random_state=0, algorithm='SAMME')
clf_petal = RandomForestClassifier(random_state=0)
clf_meta = LogisticRegression(random_state=0)

# Fitting base classifiers
clf_sepal.fit(X_sepal.loc[idx_training], y_iris.loc[idx_training])
clf_petal.fit(X_petal.loc[idx_training], y_iris.loc[idx_training])

# Fitting meta classifier
clf_meta.fit(
    X=pd.concat([
        pd.DataFrame(clf_sepal.predict_proba(X_sepal.loc[idx_training]), index=idx_training, columns=pd.Index(clf_sepal.classes_).map(lambda j: "sepal__{}".format(j))),
        pd.DataFrame(clf_petal.predict_proba(X_petal.loc[idx_training]), index=idx_training, columns=pd.Index(clf_petal.classes_).map(lambda j: "petal__{}".format(j))),
    ], axis=1),
    y=y_iris.loc[idx_training],
)

# Predicting with meta classifier
y_hat = pd.Series(
    clf_meta.predict(
        X=pd.concat([
            pd.DataFrame(clf_sepal.predict_proba(X_sepal.loc[idx_testing]), index=idx_testing, columns=pd.Index(clf_sepal.classes_).map(lambda j: "sepal__{}".format(j))),
            pd.DataFrame(clf_petal.predict_proba(X_petal.loc[idx_testing]), index=idx_testing, columns=pd.Index(clf_petal.classes_).map(lambda j: "petal__{}".format(j))),
        ], axis=1),
    ),
    index=idx_testing,
)

print("Accuracy on test set:", np.mean(y_hat == y_iris.loc[idx_testing]))
# Accuracy on test set: 0.9736842105263158

Ben*_*ger 5

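You can do the column selection as part of a pipeline for each base estimator. One way to do that is a ColumnTransformer, which is a bit verbose for the purpose, but the alternatives I know of (FunctionTransformer, for example) are a little less robust.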

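from sklearn.compose import ColumnTransformer
from sklearn.ensemble import StackingClassifier
from sklearn.pipeline import Pipeline

sepal_cols = ['sepal_length', 'sepal_width']
petal_cols = ['petal_length', 'petal_width']

# A single feature matrix with all columns: iris as loaded from sklearn,
# or (as here) the column-wise concatenation of the question's two frames
X = pd.concat([X_sepal, X_petal], axis=1)
X_train, X_test = X.loc[idx_training], X.loc[idx_testing]
y_train = y_iris.loc[idx_training]

# Each pipeline keeps only its own columns, then applies its classifier
pipe_sepal = Pipeline([
    ('select', ColumnTransformer([('sel', 'passthrough', sepal_cols)], remainder='drop')),  # remainder='drop' is the default, but I've included it for clarity
    ('clf', clf_sepal)
])
pipe_petal = Pipeline([
    ('select', ColumnTransformer([('sel', 'passthrough', petal_cols)], remainder='drop')),
    ('clf', clf_petal)
])

stack = StackingClassifier(
    estimators=[
        ('sepal', pipe_sepal),
        ('petal', pipe_petal),
    ],
    final_estimator=clf_meta,
    # ...
)

stack.fit(X_train, y_train)
y_hat = stack.predict(X_test)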
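For comparison, the FunctionTransformer alternative could look roughly like the sketch below. It is only an illustration under some assumptions: it loads iris from sklearn as a DataFrame (so the column names are sklearn's, not the ones in the question), `select_columns` and `make_branch` are hypothetical helper names, and the classifier choices are arbitrary. It relies on the input staying a pandas DataFrame all the way to the selector, which is part of why I consider it less robust.

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier, StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer

# Assumption: iris as a DataFrame; sklearn names the columns
# 'sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)'
X, y = load_iris(return_X_y=True, as_frame=True)
sepal_cols = [c for c in X.columns if c.startswith("sepal")]
petal_cols = [c for c in X.columns if c.startswith("petal")]


def select_columns(frame, columns):
    """Hypothetical helper: keep only the named DataFrame columns."""
    return frame[columns]


def make_branch(columns, clf):
    """Hypothetical helper: pipeline that selects `columns`, then applies `clf`."""
    selector = FunctionTransformer(select_columns, kw_args={"columns": columns})
    return Pipeline([("select", selector), ("clf", clf)])


stack = StackingClassifier(
    estimators=[
        ("sepal", make_branch(sepal_cols, LogisticRegression(max_iter=1000))),
        ("petal", make_branch(petal_cols, RandomForestClassifier(random_state=0))),
    ],
    final_estimator=LogisticRegression(max_iter=1000),
)

X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0)
stack.fit(X_train, y_train)
print("Accuracy on test set:", stack.score(X_test, y_test))
```

If the data arrives as a NumPy array instead, the selector would need positional indices rather than column names.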

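Your manual approach, in addition to being tedious, is not statistically sound: your base models make predictions on their own training set, and those predictions are then used as input to the meta-estimator. That generally leads the meta-estimator to favor the most overfit base estimator; I think your high test score (which makes it look as though this worked) is just because iris is relatively easy?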
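If you do want to keep a manual version, a sounder variant trains the meta-estimator on out-of-fold predictions, which is what StackingClassifier does internally via cross_val_predict. The following is a minimal sketch, assuming iris loaded from sklearn as an array (columns 0-1 sepal, 2-3 petal); the `branches` list, the 5-fold split, and the classifier choices are illustrative assumptions rather than anything from the question.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_predict, train_test_split

X, y = load_iris(return_X_y=True)  # columns 0-1: sepal, columns 2-3: petal
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0)

# (name, column indices, classifier); illustrative choices only
branches = [
    ("sepal", [0, 1], LogisticRegression(max_iter=1000)),
    ("petal", [2, 3], RandomForestClassifier(random_state=0)),
]
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)

# Training meta-features: every row is predicted by a model that never saw it
meta_train = np.hstack([
    cross_val_predict(clf, X_train[:, cols], y_train, cv=cv, method="predict_proba")
    for _, cols, clf in branches
])
clf_meta = LogisticRegression(max_iter=1000).fit(meta_train, y_train)

# For prediction, refit each base model on the full training set
meta_test = np.hstack([
    clf.fit(X_train[:, cols], y_train).predict_proba(X_test[:, cols])
    for _, cols, clf in branches
])
print("Accuracy on test set:", clf_meta.score(meta_test, y_test))
```

The same `cv` object is reused for both branches so that the out-of-fold rows line up across feature sets.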