类型错误:“NoneType”和“float”的实例之间不支持“>”

Oui*_*med 15 python machine-learning typeerror keras tensorflow

我有这段代码,它在 python 3 中引发了一个错误,这样的比较可以在 python 2 上工作,我该如何更改它?

import tensorflow as tf 
def train_set():
    class MyCallBacks(tf.keras.callbacks.Callback):
        def on_epoch_end(self,epoch,logs={}):
            if(logs.get('acc')>0.95):
                print('the training will stop !')
                self.model.stop_training=True
    callbacks=MyCallBacks()
    mnist_dataset=tf.keras.datasets.mnist 
    (x_train,y_train),(x_test,y_test)=mnist_dataset.load_data()
    x_train=x_train/255.0
    x_test=x_test/255.0
    classifier=tf.keras.Sequential([
                                    tf.keras.layers.Flatten(input_shape=(28,28)),
                                    tf.keras.layers.Dense(512,activation=tf.nn.relu),
                                    tf.keras.layers.Dense(10,activation=tf.nn.softmax)
                                    ])
    classifier.compile(
                        optimizer='sgd',
                        loss='sparse_categorical_crossentropy',
                        metrics=['accuracy']
                       )    
    history=classifier.fit(x_train,y_train,epochs=20,callbacks=[callbacks])
    return history.epoch,history.history['acc'][-1]
train_set()
Run Code Online (Sandbox Code Playgroud)

Has*_*eeb 17

TensorFlow 2.0

DESIRED_ACCURACY = 0.979

class myCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epochs, logs={}) :
        if(logs.get('acc') is not None and logs.get('acc') >= DESIRED_ACCURACY) :
            print('\nReached 99.9% accuracy so cancelling training!')
            self.model.stop_training = True

callbacks = myCallback()
Run Code Online (Sandbox Code Playgroud)


Log*_*son 12

似乎您的错误类似于Keras 中的带有回调的异常 - Tensorflow 2.0 - Python 尝试替换 logs.get('acc')logs.get('accuracy')


Mat*_*ava 10

它在 Python2 中有效,因为在 Python2 中您可以与之进行比较Nonefloat但这在 Python3 中是不可能的。

这条线

logs.get('acc')
Run Code Online (Sandbox Code Playgroud)

返回None,这是您的问题。

快速解决方案是将条件替换为

if logs.get('acc') is not None and logs.get('acc') > 0.95:
Run Code Online (Sandbox Code Playgroud)

如果logs.get('acc')是,None则上述条件将被短路,第二部分logs.get('acc') > 0.95将不会被评估,因此不会导致上述错误。


小智 6

使用“acc”而不是“accuracy”,您无需更改。


Mai*_*sad 5

在你的回调中尝试这个:

class myCallback(tf.keras.callbacks.Callback):
      def on_epoch_end(self, epoch, logs={}):
        print("---",logs,"---")
        '''
        if(logs.get('acc')>=0.99):
          print("Reached 99% accuracy so cancelling training!")
        '''
Run Code Online (Sandbox Code Playgroud)

它给了我这个 --- {'loss': 0.18487292938232422, 'acc': 0.94411665} ---

我有acc所以我用过,如果有accuracy我会用accuracy。所以记录一下你有什么然后使用它。

TF一直在经历重大的变化,所以稳打稳打是可以的,非常安全。