如何在 Tensorflow 2.x Keras 自定义层中使用多个输入?

Sal*_*lat 9 python layer keras tensorflow

我正在尝试在 Tensorflow-Keras 的自定义层中使用多个输入。用法可以是任何东西,现在它被定义为将蒙版与图像相乘。我已经搜索过,我能找到的唯一答案是 TF 1.x,所以它没有任何好处。

class mul(layers.Layer):
def __init__(self, **kwargs):
    super().__init__(**kwargs)
    # I've added pass because this is the simplest form I can come up with.
    pass

def call(self, inputs):
    # magic happens here and multiplications occur
    return(Z)
Run Code Online (Sandbox Code Playgroud)

Mat*_*gro 11

编辑:从 TensorFlow v2.3/2.4 开始,合同是使用该call方法的输入列表。对于keras(不是tf.keras)我认为下面的答案仍然适用。

实现多个输入是在call你的类的方法中完成的,有两种选择:

  • 列表输入,这里的inputs参数应该是一个包含所有输入的列表,这里的优点是它可以是可变大小的。您可以索引列表,或使用=运算符解压缩参数:

      def call(self, inputs):
          Z = inputs[0] * inputs[1]
    
          #Alternate
          input1, input2 = inputs
          Z = input1 * input2
    
          return Z
    
    Run Code Online (Sandbox Code Playgroud)
  • call方法中有多个输入参数,但在定义层时参数的数量是固定的:

      def call(self, input1, input2):
          Z = input1 * input2
    
          return Z
    
    Run Code Online (Sandbox Code Playgroud)

无论您选择哪种方法来实现这取决于您是否需要固定大小或可变大小的参数数量。当然,每个方法都改变了层必须被调用的方式,或者通过传递参数列表,或者通过在函数调用中一个接一个传递参数。

您也可以*args在第一种方法中使用,以允许call具有可变数量参数的方法,但总体上 keras 自己的具有多个输入(如ConcatenateAdd)的层是使用​​列表实现的。

  • 您必须使用列表,而不是多个参数。请参阅此“文档”:https://github.com/tensorflow/tensorflow/blob/v2.4.0/tensorflow/python/keras/engine/base_layer.py#L930-L941 (2认同)
  • 多个输入参数破坏了 `tf.keras.Layer.call()` (https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer#call) 的契约,该契约明确指出了 `inputs`应该是多个输入张量的列表/元组。 (2认同)