keras:返回model.summary()与scikit学习包装器

Chr*_*her 0 python summary wrapper scikit-learn keras

在使用keras时,我了解到使用包装器会对keras和scikit学习api请求产生不利影响。我对同时拥有这两种解决方案感兴趣。

变体1:scikit包装

from keras.wrappers.scikit_learn import KerasClassifier

    def model():
        model = Sequential()
        model.add(Dense(10, input_dim=4, activation='relu'))
        model.add(Dense(3, activation='softmax'))
        model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
        return model

estimator = KerasClassifier(build_fn=model, epochs=100, batch_size=5)
model.fit(X, y)
Run Code Online (Sandbox Code Playgroud)

->这使我可以打印scikit命令,例如precision_score()或category_report()。但是,model.summary()不起作用:

AttributeError:“ KerasClassifier”对象没有属性“ summary”

形式2:无包装

model = Sequential()
model.add(Dense(10, input_dim=4, activation='relu'))
model.add(Dense(3, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(X, y, epochs=100, batch_size=5)
Run Code Online (Sandbox Code Playgroud)

->这使我可以打印model.summary()而不是scikit命令。

ValueError:不允许使用y的混合类型,类型为{'multiclass','multilabel-indicator'}

有办法同时使用两者吗?

Viv*_*mar 6

KerasClassifier只是实际包装的一个包装Modelkeras因此可以将keras api的实际方法路由到scikit中使用的方法,因此可以与scikit实用程序结合使用。但是在内部,它仅使用可通过使用访问的模型estimator.model

以上示例说明:

from keras.models import Sequential
from keras.layers import Dense
from keras.wrappers.scikit_learn import KerasClassifier
from sklearn.datasets import make_classification
def model():
    model = Sequential()
    model.add(Dense(10, input_dim=20, activation='relu'))
    model.add(Dense(2, activation='softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model

estimator = KerasClassifier(build_fn=model, epochs=100, batch_size=5)
X, y = make_classification()
estimator.fit(X, y)

# This is what you need
estimator.model.summary()
Run Code Online (Sandbox Code Playgroud)

输出为:

Layer (type)                 Output Shape              Param #   
=================================================================
dense_9 (Dense)              (None, 10)                210       
_________________________________________________________________
dense_10 (Dense)             (None, 2)                 22        
=================================================================
Total params: 232
Trainable params: 232
Non-trainable params: 0
_________________________________________________________________
Run Code Online (Sandbox Code Playgroud)