小编rom*_*zky的帖子

如何在`tensorflow.keras`中替换`keras.layers.merge._Merge`

我想使用tf.kerasAPI创建自定义合并层。但是,新 API 隐藏了keras.layers.merge._Merge我想要继承的类。

这样做的目的是创建一个可以对两个不同层的输出执行加权求和/合并的层。之前,在keraspython API(不是包含在 中的那个tensorflow.keras)我可以从keras.layers.merge._Merge类继承,现在不能从tensorflow.keras.

在我可以做到这一点之前

class RandomWeightedAverage(keras.layers.merge._Merge):
    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size
    def _merge_function(self, inputs):
        alpha = K.random_uniform((self.batch_size, 1, 1, 1))
        return (alpha * inputs[0]) + ((1 - alpha) * inputs[1])
Run Code Online (Sandbox Code Playgroud)

现在我不能使用相同的逻辑,如果使用 tensorflow.keras

class RandomWeightedAverage(tf.keras.layers.merge._Merge):
    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size
    def _merge_function(self, inputs):
        alpha = K.random_uniform((self.batch_size, 1, 1, 1))
        return (alpha * inputs[0]) + ((1 - alpha) * inputs[1])
Run Code Online (Sandbox Code Playgroud)

生产

AttributeError: module …
Run Code Online (Sandbox Code Playgroud)

python keras tensorflow tf.keras

4
推荐指数
1
解决办法
3942
查看次数

标签 统计

keras ×1

python ×1

tensorflow ×1

tf.keras ×1