sim*_*eon 5 reinforcement-learning neural-network q-learning theano keras
我在Keras中有一个网络,其中有很多输出,但是,我的训练数据一次只能提供单个输出的信息。
目前,我的训练方法是对有问题的输入进行预测,更改我正在训练的特定输出的值,然后进行单批更新。如果我是对的,这与将所有输出的损失设置为零(除了我尝试训练的损失)相同。
有没有更好的办法?我尝试过权重设置,但我正在训练的所有输出都设置为零权重,但是没有给我期望的结果?
我正在使用Theano后端。
Ser*_*hiy 18
假设您想从多个层返回输出,可能来自一些中间层,但您只需要优化一个目标输出。您可以这样做:
inputs = Input(shape=(784,))
x = Dense(64, activation='relu')(inputs)
# you want to extract these values
useful_info = Dense(32, activation='relu', name='useful_info')(x)
# final output. used for loss calculation and optimization
result = Dense(1, activation='softmax', name='result')(useful_info)
Run Code Online (Sandbox Code Playgroud)
None为额外输出:给出None的输出,你不想使用损耗计算和优化
model = Model(inputs=inputs, outputs=[result, useful_info])
model.compile(optimizer='rmsprop',
loss=['categorical_crossentropy', None],
metrics=['accuracy'])
Run Code Online (Sandbox Code Playgroud)
model.fit(my_inputs, {'result': train_labels}, epochs=.., batch_size=...)
# this also works:
#model.fit(my_inputs, [train_labels], epochs=.., batch_size=...)
Run Code Online (Sandbox Code Playgroud)
拥有一个模型,您只能运行predict一次以获得所需的所有输出:
predicted_labels, useful_info = model.predict(new_x)
Run Code Online (Sandbox Code Playgroud)
为了实现这一目标,我最终使用了“函数式 API”。您基本上可以使用相同的输入层和隐藏层但使用不同的输出层来创建多个模型。
例如:
https://keras.io/getting-started/function-api-guide/
from keras.layers import Input, Dense
from keras.models import Model
# This returns a tensor
inputs = Input(shape=(784,))
# a layer instance is callable on a tensor, and returns a tensor
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
predictions_A = Dense(1, activation='softmax')(x)
predictions_B = Dense(1, activation='softmax')(x)
# This creates a model that includes
# the Input layer and three Dense layers
modelA = Model(inputs=inputs, outputs=predictions_A)
modelA.compile(optimizer='rmsprop',
loss='categorical_crossentropy',
metrics=['accuracy'])
modelB = Model(inputs=inputs, outputs=predictions_B)
modelB.compile(optimizer='rmsprop',
loss='categorical_crossentropy',
metrics=['accuracy'])
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
1524 次 |
| 最近记录: |