Joe*_*e_P 4 regression machine-learning neural-network keras
我正在 keras (python, backend: tensorflow) 中训练一个神经网络作为回归。因此,我的输出层不包含激活函数,我使用均方误差作为我的损失函数。
我的问题是:我想确保所有输出估计的总和(几乎)等于所有实际标签的总和。
我的意思是:我想确保不仅 (y_real)^i ~ (y_predict)^i 对于每个训练示例 i,而且还保证 sum(y_real) = sum(y_predict),对所有 i 求和。常规线性回归使添加此限制变得足够简单,但我没有看到神经网络有任何类似的东西。我可以将最终结果乘以 sum(y_real) / sum(y_predict),但如果我不想损害个人预测,恐怕这不是理想的方法。
我还有什么其他选择?
(我无法共享我的数据,也无法使用不同的数据轻松重现该问题,但这是按要求使用的代码:)
from keras.models import Sequential
from keras.layers import Dense
model = Sequential()
model.add(Dense(128, activation = 'relu', input_dim = 459))
model.add(Dense(32, activation = 'relu'))
model.add(Dense(1))
model.compile(loss = 'mean_squared_error',
optimizer = 'adam')
model.fit(X_train, Y_train, epochs = 5, validation_data = (X_val,
Y_val), batch_size = 128)
Run Code Online (Sandbox Code Playgroud)
从优化的角度来看,您希望为问题引入等式约束。您正在寻找网络权重,以便预测y1_hat, y2_hat and y3_hat最小化标签的均方误差y1, y2, y3。此外,您希望保留以下内容:
sum(y1, y2, y3) = sum(y1_hat, y2_hat, y3_hat)
Run Code Online (Sandbox Code Playgroud)
因为您使用的是神经网络,所以您希望以这样一种方式强加此约束,即您仍然可以使用反向传播来训练网络。
一种方法是在损失函数中添加一个项来惩罚sum(y1, y2, y3)和之间的差异sum(y1_hat, y2_hat, y3_hat)。
最小工作示例:
import numpy as np
import keras.backend as K
from keras.layers import Dense, Input
from keras.models import Model
# Some random training data and labels
features = np.random.rand(100, 20)
labels = np.random.rand(100, 3)
# Simple neural net with three outputs
input_layer = Input((20,))
hidden_layer = Dense(16)(input_layer)
output_layer = Dense(3)(hidden_layer)
# Model
model = Model(inputs=input_layer, outputs=output_layer)
# Write a custom loss function
def custom_loss(y_true, y_pred):
# Normal MSE loss
mse = K.mean(K.square(y_true-y_pred), axis=-1)
# Loss that penalizes differences between sum(predictions) and sum(labels)
sum_constraint = K.square(K.sum(y_pred, axis=-1) - K.sum(y_true, axis=-1))
return(mse+sum_constraint)
# Compile with custom loss
model.compile(loss=custom_loss, optimizer='sgd')
model.fit(features, labels, epochs=1, verbose=1)
Run Code Online (Sandbox Code Playgroud)
请注意,这以“软”方式强加约束,而不是作为硬约束。你仍然会得到偏差,但网络应该以很小的方式学习权重。