我在 Keras 中创建了一个自定义层(称为 GraphGather),但输出张量打印为:
张量(“graph_gather/tanh:0”,形状=(?,?),dtype=float32)
由于某种原因,形状被返回为 (?,?),这导致下一个密集层引发以下错误:
ValueError:
Dense
应定义输入的最后一个维度。找到了None
。
GraphGather层代码如下:
class GraphGather(tf.keras.layers.Layer):
def __init__(self, batch_size, num_mols_in_batch, activation_fn=None, **kwargs):
self.batch_size = batch_size
self.num_mols_in_batch = num_mols_in_batch
self.activation_fn = activation_fn
super(GraphGather, self).__init__(**kwargs)
def build(self, input_shape):
super(GraphGather, self).build(input_shape)
def call(self, x, **kwargs):
# some operations (most of def call omitted)
out_tensor = result_of_operations() # this line is pseudo code
if self.activation_fn is not None:
out_tensor = self.activation_fn(out_tensor)
out_tensor = out_tensor
return out_tensor
def compute_output_shape(self, input_shape):
return (self.num_mols_in_batch, 2 * input_shape[0][-1])}
Run Code Online (Sandbox Code Playgroud)
I have also tried hardcoding compute_output_shape to be:
python def compute_output_shape(self, input_shape): return (64, 150) ``` 然而打印时的输出张量仍然是
张量(“graph_gather/tanh:0”,形状=(?,?),dtype=float32)
这会导致上面写的 ValueError 。
小智 5
我有同样的问题。我的解决方法是在 call 方法中添加以下几行:
input_shape = tf.shape(x)
Run Code Online (Sandbox Code Playgroud)
进而:
return tf.reshape(out_tensor, self.compute_output_shape(input_shape))
Run Code Online (Sandbox Code Playgroud)
我还没有遇到任何问题。
归档时间: |
|
查看次数: |
2303 次 |
最近记录: |