Hag*_*Tz. 2 python python-3.x keras tensorflow tensorflow2.x
我正在尝试使用 TensorFlow2 构建一些模型,因此我创建了一个模型类,如下所示:
import tensorflow as tf
class Dummy(tf.keras.Model):
def __init__(self, name="dummy"):
super(Dummy, self).__init__()
self._name = name
self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
def call(self, inputs, training=False):
x = self.dense1(inputs)
return self.dense2(x)
model = Dummy()
model.build(input_shape=(None,5))
Run Code Online (Sandbox Code Playgroud)
现在我想绘制模型,同时使用summary()返回我期望的内容,plot_model(model, show_shapes=True, expand_nested=True)仅返回带有模型名称的块。
如何返回模型的图表?
弗朗索瓦·乔莱说道:
您可以在功能或顺序模型中执行所有这些操作(打印输入/输出形状),因为这些模型是层的静态图。
相反,子类模型是一段 Python 代码(一个调用方法)。这里没有图层图。我们无法知道各层如何相互连接(因为这是在调用主体中定义的,而不是作为显式数据结构),因此我们无法推断输入/输出形状。
对此有两种解决方案:
call函数包装到函数模型中,如下所示:class Subclass(Model):
def __init__(self):
...
def call(self, x):
...
def model(self):
x = Input(shape=(24, 24, 3))
return Model(inputs=[x], outputs=self.call(x))
if __name__ == '__main__':
sub = subclass()
sub.model().summary()
Run Code Online (Sandbox Code Playgroud)
答案取自这里:model.summary() can't print output shape while using subclass model
另外,这是一篇很好的文章:https://medium.com/tensorflow/what-are-symbolic-and-imperative-apis-in-tensorflow-2-0-dfccecb01021
| 归档时间: |
|
| 查看次数: |
1938 次 |
| 最近记录: |