d_g*_*_gg 5 stateful layer keras tensorflow recurrent-neural-network
我想请您帮助创建我的自定义图层。我想做的实际上非常简单:生成一个带有“状态”变量的输出层,即其值在每个批次中更新的张量。
为了使一切更清楚,以下是我想做的事情的片段:
def call(self, inputs)
c = self.constant
m = self.extra_constant
update = inputs*m + c
X_new = self.X_old + update
outputs = X_new
self.X_old = X_new
return outputs
Run Code Online (Sandbox Code Playgroud)
这里的想法很简单:
X_old中初始化为0def__ init__(self, ...)update被计算为层输入的函数X_new)X_old设置为等于X_new,以便在下一批中X_old不再等于零,而是等于X_new前一批。我发现它K.update可以完成这项工作,如示例所示:
X_new = K.update(self.X_old, self.X_old + update)
Run Code Online (Sandbox Code Playgroud)
这里的问题是,如果我尝试将层的输出定义为:
outputs = X_new
return outputs
Run Code Online (Sandbox Code Playgroud)
当我尝试 model.fit() 时,我会收到以下错误:
ValueError: An operation has `None` for gradient. Please make sure that all of your ops have
gradient defined (i.e. are differentiable). Common ops without gradient: K.argmax, K.round, K.eval.
Run Code Online (Sandbox Code Playgroud)
即使我施加了layer.trainable = False并且没有为该层定义任何偏差或权重,我仍然会遇到此错误。另一方面,如果我这样做self.X_old = X_new, 的值X_old不会更新。
你们有解决方案来实现这个吗?我相信这应该没那么难,因为有状态 RNN 也有“类似”的功能。
在此先感谢您的帮助!
有时定义自定义层可能会变得混乱。您重写的某些方法将被调用一次,但它给您的印象是,就像许多其他 OO 库/框架一样,它们将被调用多次。
我的意思是:当您定义一个层并在模型中使用它时,您为重写call方法编写的 python 代码不会在前向或后向传递中直接调用。相反,当您调用 时,它仅被调用一次model.compile。它将 python 代码编译为计算图,而张量在其中流动的图就是训练和预测期间进行计算的内容。
这就是为什么如果你想通过放置一条print语句来调试你的模型,这是行不通的;您需要使用tf.print向图表添加打印命令。
这与您想要的状态变量的情况相同。old + update您需要调用new一个 Keras 函数将该操作添加到图中,而不是简单地分配给。
tf.Variable请注意,张量是不可变的,因此您需要像方法中那样定义状态__init__。
所以我相信这段代码更像是您正在寻找的代码:
class CustomLayer(tf.keras.layers.Layer):
def __init__(self, **kwargs):
super(CustomLayer, self).__init__(**kwargs)
self.state = tf.Variable(tf.zeros((3,3), 'float32'))
self.constant = tf.constant([[1,1,1],[1,0,-1],[-1,0,1]], 'float32')
self.extra_constant = tf.constant([[1,1,1],[1,0,-1],[-1,0,1]], 'float32')
self.trainable = False
def call(self, X):
m = self.constant
c = self.extra_constant
outputs = self.state + tf.matmul(X, m) + c
tf.keras.backend.update(self.state, tf.reduce_sum(outputs, axis=0))
return outputs
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
2204 次 |
| 最近记录: |