Tensorflow compute_output_shape() 不适用于自定义层

Cha*_*mer 5 tensorflow

我在 Keras 中创建了一个自定义层(称为 GraphGather),但输出张量打印为:

张量(“graph_gather/tanh:0”,形状=(?,?),dtype=float32)

由于某种原因,形状被返回为 (?,?),这导致下一个密集层引发以下错误:

ValueError:Dense应定义输入的最后一个维度。找到了None

GraphGather层代码如下:

class GraphGather(tf.keras.layers.Layer):

  def __init__(self, batch_size, num_mols_in_batch, activation_fn=None, **kwargs):
    self.batch_size = batch_size
    self.num_mols_in_batch = num_mols_in_batch
    self.activation_fn = activation_fn
    super(GraphGather, self).__init__(**kwargs)

  def build(self, input_shape):
    super(GraphGather, self).build(input_shape)

 def call(self, x, **kwargs):
    # some operations (most of def call omitted)
    out_tensor = result_of_operations() # this line is pseudo code
    if self.activation_fn is not None:
      out_tensor = self.activation_fn(out_tensor)
    out_tensor = out_tensor
    return out_tensor

  def compute_output_shape(self, input_shape):
    return (self.num_mols_in_batch, 2 * input_shape[0][-1])}
Run Code Online (Sandbox Code Playgroud)

I have also tried hardcoding compute_output_shape to be: python def compute_output_shape(self, input_shape): return (64, 150) ``` 然而打印时的输出张量仍然是

张量(“graph_gather/tanh:0”,形状=(?,?),dtype=float32)

这会导致上面写的 ValueError 。


系统信息

  • 已编写自定义代码
  • **操作系统平台和发行版*:Linux Ubuntu 16.04
  • TensorFlow 版本(使用下面的命令):1.5.0
  • Python 版本:3.5.5

小智 5

我有同样的问题。我的解决方法是在 call 方法中添加以下几行:

input_shape = tf.shape(x)
Run Code Online (Sandbox Code Playgroud)

进而:

return tf.reshape(out_tensor, self.compute_output_shape(input_shape))
Run Code Online (Sandbox Code Playgroud)

我还没有遇到任何问题。