小编Cha*_*mer的帖子

Tensorflow compute_output_shape() 不适用于自定义层

我在 Keras 中创建了一个自定义层(称为 GraphGather),但输出张量打印为:

张量(“graph_gather/tanh:0”,形状=(?,?),dtype=float32)

由于某种原因,形状被返回为 (?,?),这导致下一个密集层引发以下错误:

ValueError:Dense应定义输入的最后一个维度。找到了None

GraphGather层代码如下:

class GraphGather(tf.keras.layers.Layer):

  def __init__(self, batch_size, num_mols_in_batch, activation_fn=None, **kwargs):
    self.batch_size = batch_size
    self.num_mols_in_batch = num_mols_in_batch
    self.activation_fn = activation_fn
    super(GraphGather, self).__init__(**kwargs)

  def build(self, input_shape):
    super(GraphGather, self).build(input_shape)

 def call(self, x, **kwargs):
    # some operations (most of def call omitted)
    out_tensor = result_of_operations() # this line is pseudo code
    if self.activation_fn is not None:
      out_tensor = self.activation_fn(out_tensor)
    out_tensor = out_tensor
    return out_tensor

  def compute_output_shape(self, input_shape):
    return (self.num_mols_in_batch, 2 * input_shape[0][-1])}
Run Code Online (Sandbox Code Playgroud)

I …

tensorflow

5
推荐指数
1
解决办法
2303
查看次数

标签 统计

tensorflow ×1