Adi*_*hil 2 python machine-learning deep-learning keras tensorflow
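I recently came across the focal loss function and heard that it is mainly used for imbalanced datasets, so I tried this simple focal loss function for Keras, which I found online, on the CIFAR-10 dataset.
I keep running into the error shown at the end. I have tried several ways to fix it without success. I would really appreciate your help. Thank you!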
Focal loss
import keras.backend as K

ALPHA = 0.8
GAMMA = 2

def FocalLoss(targets, inputs, alpha=ALPHA, gamma=GAMMA):
    inputs = K.flatten(inputs)
    targets = K.flatten(targets)
    BCE = K.binary_crossentropy(targets, inputs)
    BCE_EXP = K.exp(-BCE)
    focal_loss = K.mean(alpha * K.pow((1-BCE_EXP), gamma) * BCE)
    return focal_loss
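To see why this formula helps with imbalance: `exp(-BCE)` equals the probability `p_t` that the model assigns to the correct label, so the factor `(1 - p_t)**gamma` is close to 0 for well-classified elements and close to 1 for badly classified ones. The plain-Python sketch below evaluates the question's formula for single elements; the helper names `bce` and `focal_term` are my own, and the clipping constant `EPS = 1e-7` is an assumption chosen to keep `log()` finite, similar in spirit to the clipping Keras does.

```python
import math

ALPHA = 0.8
GAMMA = 2
EPS = 1e-7  # assumed clipping constant so log() never sees exactly 0 or 1

def bce(target, prob):
    """Binary cross-entropy for a single target/probability pair."""
    p = min(max(prob, EPS), 1 - EPS)
    return -(target * math.log(p) + (1 - target) * math.log(1 - p))

def focal_term(target, prob, alpha=ALPHA, gamma=GAMMA):
    """Per-element version of the question's formula: alpha * (1 - exp(-BCE))**gamma * BCE."""
    b = bce(target, prob)
    p_t = math.exp(-b)  # probability the model gives the true label
    return alpha * (1 - p_t) ** gamma * b

# An "easy" element (confident and correct) vs a "hard" one (confident and wrong).
easy = focal_term(1, 0.9)
hard = focal_term(1, 0.1)
print(easy / bce(1, 0.9))  # ≈ 0.008: the easy element's BCE is scaled down about 125x
print(hard / bce(1, 0.1))  # ≈ 0.648: the hard element keeps most of its BCE
```

With `gamma=0` the factor disappears and the loss reduces to `alpha * BCE`; larger `gamma` pushes the easy elements further toward zero.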
Input data
from keras.datasets import cifar10
(xtrain,ytrain),(xtest,ytest) = cifar10.load_data()
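For reference, `cifar10.load_data()` returns `uint8` arrays: images of shape `(N, 32, 32, 3)` and labels of shape `(N, 1)` holding integer class indices 0 to 9. A common preprocessing step is to convert both to `float32` before training. The NumPy sketch below uses small random stand-in arrays with the same dtypes (not real CIFAR-10 data) and shows one way to scale the images and one-hot encode the labels; using `np.eye` for the one-hot step is my choice, not something from the question.

```python
import numpy as np

# Random stand-ins with the same dtypes and trailing shapes as cifar10.load_data();
# the values are placeholders, not real CIFAR-10 images or labels.
rng = np.random.default_rng(0)
xtrain = rng.integers(0, 256, size=(4, 32, 32, 3), dtype=np.uint8)
ytrain = rng.integers(0, 10, size=(4, 1), dtype=np.uint8)

# Scale pixels from 0..255 to 0..1 as float32.
x = xtrain.astype("float32") / 255.0
# One-hot encode the integer class indices into float32 rows of length 10.
y = np.eye(10, dtype="float32")[ytrain.ravel()]

print(x.dtype, x.shape)  # float32 (4, 32, 32, 3)
print(y.dtype, y.shape)  # float32 (4, 10)
```

If you go this route, convert `xtest`/`ytest` the same way and pass the converted arrays to `model.fit`.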
Neural network
from keras.layers import Dense, Conv2D, Flatten, MaxPool2D
from keras.models import Sequential
from keras.optimizers import Adam

model = Sequential([
    Conv2D(filters=64, kernel_size=(27,27), strides=(1,1), input_shape=(32,32,3), padding='same', activation='sigmoid'),
    MaxPool2D(pool_size=(13,13), strides=(1,1), padding='valid'),
    Conv2D(filters=32, kernel_size=(11,11), strides=(1,1), padding='valid', activation='sigmoid'),
    Flatten(),
    Dense(units=600, activation='sigmoid'),
    Dense(units=128, activation='sigmoid'),
    Dense(units=10, activation='softmax')
])
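Because the kernels in this model are unusually large, it can help to work out the layer shapes and parameter counts by hand. The sketch below assumes the standard output-size rules for `'valid'` padding (`(size - kernel) // stride + 1`) and `'same'` padding (`ceil(size / stride)`); the helper `conv_out` is hypothetical, not a Keras function. If I have the rules right, the numbers should agree with `model.summary()`.

```python
def conv_out(size, kernel, stride=1, padding="valid"):
    """Spatial output size of a Conv2D/MaxPool2D layer along one axis (hypothetical helper)."""
    if padding == "same":
        return -(-size // stride)  # ceil(size / stride)
    return (size - kernel) // stride + 1

# Walk the question's model layer by layer, starting from a 32x32x3 input.
h = conv_out(32, 27, padding="same")  # Conv2D(64, 27x27, same) -> 32
h = conv_out(h, 13)                   # MaxPool2D(13x13, stride 1, valid) -> 20
h = conv_out(h, 11)                   # Conv2D(32, 11x11, valid) -> 10
flat = h * h * 32                     # Flatten -> 3200 features

# Weights + biases per trainable layer.
params = {
    "conv1": 27 * 27 * 3 * 64 + 64,
    "conv2": 11 * 11 * 64 * 32 + 32,
    "dense600": flat * 600 + 600,
    "dense128": 600 * 128 + 128,
    "dense10": 128 * 10 + 10,
}
print(h, flat)               # 10 3200
print(sum(params.values()))  # 2386690
```

About 1.9 million of the roughly 2.4 million parameters sit in the first `Dense` layer, because the flattened feature map is still 10x10x32.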
Compile and fit
model.compile(loss=FocalLoss, optimizer=Adam(learning_rate=0.0001), metrics=['accuracy'])
model.fit(xtrain, ytrain, epochs=10, batch_size=120, validation_data=(xtest,ytest), verbose=2)
Error during fit
Epoch 1/10
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-52-52246069690d> in <module>()
----> 1 model.fit(xtrain, ytrain, epochs=10, batch_size=120, validation_data=(xtest,ytest), verbose=2)
10 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
971 except Exception as e: # pylint:disable=broad-except
972 if hasattr(e, "ag_error_metadata"):
--> 973 raise e.ag_error_metadata.to_exception(e)
974 else:
975 raise
TypeError: in user code:
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:806 train_function *
return step_function(self, iterator)
<ipython-input-50-e8cbeb45fe58>:12 FocalLoss *
BCE = K.binary_crossentropy(targets, inputs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/dispatch.py:201 wrapper **
return target(*args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/backend.py:4829 binary_crossentropy
bce = target * math_ops.log(output + epsilon())
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_ops.py:1141 binary_op_wrapper
raise e
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_ops.py:1125 binary_op_wrapper
return func(x, y, name=name)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_ops.py:1457 _mul_dispatch
return multiply(x, y, name=name)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/dispatch.py:201 wrapper
return target(*args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_ops.py:509 multiply
return gen_math_ops.mul(x, y, name)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_math_ops.py:6176 mul
"Mul", x=x, y=y, name=name)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py:506 _apply_op_helper
inferred_from[input_arg.type_attr]))
TypeError: Input 'y' of 'Mul' Op has type float32 that does not match type uint8 of argument 'x'.
Note
xtrain and ytrain have the same data type, namely 'uint8'.
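The problem is the type of your targets: they are uint8, but they need to be cast to float32. K.binary_crossentropy multiplies the targets by the log of the predictions, and TensorFlow does not implicitly promote uint8 to float32, which is exactly the "Input 'y' of 'Mul' Op" error above. I do the cast inside the loss. I also removed the flatten part, which was a mistake: flattening turns the (batch, 10) predictions and the (batch, 1) targets into vectors of different lengths.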
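import keras.backend as K

ALPHA = 0.8
GAMMA = 2

def FocalLoss(targets, inputs, alpha=ALPHA, gamma=GAMMA):
    # ytrain is uint8; cast it so it can be multiplied with the float32 predictions
    targets = K.cast(targets, 'float32')
    BCE = K.binary_crossentropy(targets, inputs)
    BCE_EXP = K.exp(-BCE)
    focal_loss = K.mean(alpha * K.pow((1-BCE_EXP), gamma) * BCE)
    return focal_loss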
Here is the working notebook: https://colab.research.google.com/drive/1E89tggfCvifuoJRdGuXTHuBQPvXFCYN4?usp=sharing
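One caveat worth checking: the cast makes the types line up, but `ytrain` still holds integer class indices (0 to 9) with shape `(N, 1)`, while the softmax output has shape `(N, 10)`. As far as I can tell, `K.binary_crossentropy` then broadcasts the label column across all 10 outputs and treats a value such as 7 as a binary target, which is not what binary cross-entropy expects. A common alternative is a categorical (multi-class) focal loss on one-hot targets, `-alpha * (1 - p)**gamma * log(p)` for the true class. The sketch below is a plain-Python reference of that formula, not the answer's code and not a Keras API; `categorical_focal_loss` is a hypothetical name and `eps=1e-7` is an assumed clipping constant. The same steps map onto `K.clip`, `K.log`, `K.sum` and `K.mean` if you port it to a Keras loss.

```python
import math

def categorical_focal_loss(y_true, y_pred, alpha=0.8, gamma=2.0, eps=1e-7):
    """Mean multi-class focal loss (hypothetical reference implementation).

    y_true: list of one-hot rows; y_pred: list of probability rows (e.g. softmax output).
    alpha/gamma default to the question's values; eps is an assumed clipping constant.
    """
    losses = []
    for t_row, p_row in zip(y_true, y_pred):
        row_loss = 0.0
        for t, p in zip(t_row, p_row):
            p = min(max(p, eps), 1 - eps)  # keep log() finite
            # With one-hot targets only the true class (t == 1) contributes.
            row_loss += -alpha * t * (1 - p) ** gamma * math.log(p)
        losses.append(row_loss)
    return sum(losses) / len(losses)

# Toy 3-class example: the first row is confidently correct, the second is not.
y_true = [[0, 1, 0], [1, 0, 0]]
y_pred = [[0.1, 0.8, 0.1], [0.3, 0.4, 0.3]]
loss = categorical_focal_loss(y_true, y_pred)
print(loss)
```

Setting `alpha=1.0` and `gamma=0.0` reduces this to ordinary categorical cross-entropy, which is a quick sanity check for a ported version.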