bec*_*cko 3 python python-3.x keras tf.keras tensorflow2.0
我在 Tensorflow 2.0 中使用 Keras API。
例如,假设我想在我的模型中有两个密集层,称为layer1和layer2。但我想绑定它们的权重,这样权重矩阵 inlayer1总是等于 的权重矩阵的转置layer2。
我怎样才能做到这一点?
您可以为此定义一个自定义 Keras 层,您可以在其中传递参考Dense层。
自定义密集层:
class CustomDense(Layer):
def __init__(self, reference_layer):
super(CustomDense, self).__init__()
self.ref_layer = reference_layer
def call(self, inputs):
weights = self.ref_layer.get_weights()[0]
bias = self.ref_layer.get_weights()[1]
weights = tf.transpose(weights)
x = tf.linalg.matmul(inputs, weights) + bias
return x
Run Code Online (Sandbox Code Playgroud)
现在您使用Functional-API将此层添加到您的模型中。
inp = Input(shape=(5))
dense = Dense(5)
transposed_dense = CustomDense(dense)
#model
x = dense(inp)
x = transposed_dense(x)
model = Model(inputs=inp, outputs=x)
model.summary()
'''
Model: "model_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_2 (InputLayer) [(None, 5)] 0
_________________________________________________________________
dense_1 (Dense) (None, 5) 30
_________________________________________________________________
custom_dense_1 (CustomDense) (None, 5) 30
=================================================================
Total params: 30
Trainable params: 30
Non-trainable params: 0
_________________________________________________________________
'''
Run Code Online (Sandbox Code Playgroud)
如您所见dense,custom_dense共享了 30 个参数。这里custom_dense只是使用dense层的转置权重进行密集操作,它没有自己的参数。
编辑 1:在评论中回答问题(子分类层如何获得 #params?):
层类跟踪传递给它的__init__方法的所有对象。
transposed_dense._layers
# [<tensorflow.python.keras.layers.core.Dense at 0x7fc3e0874f28>]
Run Code Online (Sandbox Code Playgroud)
以上参数将给出正在跟踪的依赖层。所有子属性权重都可以视为:
transposed_dense._gather_children_attribute("weights")
#[<tf.Variable 'dense_9/kernel:0' shape=(10, 5) dtype=float32>,
# <tf.Variable 'dense_9/bias:0' shape=(5,) dtype=float32>]
Run Code Online (Sandbox Code Playgroud)
因此,当我们model.summary()在内部调用 It 时调用count_params()each Layer,它计算所有trainable_variable包括 self 和 children 属性。
| 归档时间: |
|
| 查看次数: |
527 次 |
| 最近记录: |