理解Keras模型体系结构(嵌套模型的节点索引)

Tob*_*ann 1 deep-learning keras

此脚本使用小型嵌套模型定义虚拟模型

from keras.layers import Input, Dense
from keras.models import Model
import keras

input_inner = Input(shape=(4,), name='input_inner')
output_inner = Dense(3, name='inner_dense')(input_inner)
inner_model = Model(inputs=input_inner, outputs=output_inner)

input = Input(shape=(5,), name='input')
x = Dense(4, name='dense_1')(input)
x = inner_model(x)
x = Dense(2, name='dense_2')(x)

output = keras.layers.concatenate([x, x], name='concat_1')
model = Model(inputs=input, outputs=output)

print(model.summary())
Run Code Online (Sandbox Code Playgroud)

产生以下输出

Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
input (InputLayer)               (None, 5)             0                                            
____________________________________________________________________________________________________
dense_1 (Dense)                  (None, 4)             24          input[0][0]                      
____________________________________________________________________________________________________
model_1 (Model)                  (None, 3)             15          dense_1[0][0]                    
____________________________________________________________________________________________________
dense_2 (Dense)                  (None, 2)             8           model_1[1][0]                    
____________________________________________________________________________________________________
concat_1 (Concatenate)           (None, 4)             0           dense_2[0][0]                    
                                                                   dense_2[0][0]                    
Run Code Online (Sandbox Code Playgroud)

我的问题涉及Connected to专栏的内容.我知道一个图层可以有多个节点.

本专栏的表示法是layer_name[node_index][tensor_index].

如果我们将其inner_model视为一个层,我希望它只有一个节点,所以我希望dense_2能够连接到model_1[0][0].但实际上它与之相关model_1[1][0].为什么会这样?

Mir*_*ber 5

1.背景

当你说:

如果我们将inner_model视为一个层,我希望它只有一个节点

这是正确的,因为它只有一个节点是网络的一部分.

考虑GitHub的仓库中的model.summary功能.打印连接的函数是print_layer_summary_with_connections(第76行),它只考虑relevant_nodes数组中的节点.不在此阵列中的所有节点都被视为不属于网络,因此该功能会跳过它们.相关的行是88-90行:

if relevant_nodes and node not in relevant_nodes:
    # node is not part of the current network
    continue
Run Code Online (Sandbox Code Playgroud)

你的模特

现在让我们看看您的特定型号会发生什么.首先让我们来定义relevant_nodes:

relevant_nodes = []
for v in model.nodes_by_depth.values():
    relevant_nodes += v
Run Code Online (Sandbox Code Playgroud)

该数组relevant_nodes看起来像:

[<keras.engine.topology.Node at 0x9dfa518>,
 <keras.engine.topology.Node at 0x9dfa278>,
 <keras.engine.topology.Node at 0x9d8bac8>,
 <keras.engine.topology.Node at 0x9d8ba58>,
 <keras.engine.topology.Node at 0x9d74518>]
Run Code Online (Sandbox Code Playgroud)

但是,当我们在每一层打印入站节点时,我们将获得:

for i in model.layers:
    print(i.inbound_nodes)

[<keras.engine.topology.Node object at 0x0000000009D74518>]
[<keras.engine.topology.Node object at 0x0000000009D8BA58>]
[<keras.engine.topology.Node object at 0x0000000009D743C8>, <keras.engine.topology.Node object at 0x0000000009D8BAC8>]
[<keras.engine.topology.Node object at 0x0000000009DFA278>]
[<keras.engine.topology.Node object at 0x0000000009DFA518>]
Run Code Online (Sandbox Code Playgroud)

您可以看到上面列表中只有一个节点没有出现relevant_nodes.这是第三个数组中位置0的节点:

<keras.engine.topology.Node object at 0x0000000009D743C8>
Run Code Online (Sandbox Code Playgroud)

它不被认为是模型的一部分,因此没有出现在relevant_nodes.此数组中位置1的节点确实出现relevant_nodes,这就是您将其视为的原因model_1[1][0].

原因

原因基本上就是这条线x=inner_model(input).即使你运行的小型号,如下所示:

input_inner = Input(shape=(4,), name='input_inner')
output_inner = Dense(3, name='inner_dense')(input_inner)
inner_model = Model(inputs=input_inner, outputs=output_inner)

input = Input(shape=(5,), name='input')
output = inner_model(input)

model = Model(inputs=input, outputs=output)
Run Code Online (Sandbox Code Playgroud)

您将看到relevant_nodes包含两个元素,而via

for i in model.layers:
        print(i.inbound_nodes)
Run Code Online (Sandbox Code Playgroud)

你会得到三个节点.

这是因为第1层(上面较小的模型)有两个节点,但只有第二个节点被认为是模型的一部分.特别是,如果您在第1层的每个节点上打印输入layer.get_input_at(node_index),您将获得:

print(model.layers[1].get_input_at(0))
print(model.layers[1].get_input_at(1))

#prints
/input_inner
/input
Run Code Online (Sandbox Code Playgroud)

4.评论中的问题解答

1)您是否也知道这个不相关的节点有什么用/它来自哪里?

此节点似乎是在应用期间创建的"内部节点" inner_model.特别是,如果您在三个节点中的每一个上打印输入和输出形状(在上面的小模型中),您将得到:

nodes=[model.layers[0].inbound_nodes[0],model.layers[1].inbound_nodes[0],model.layers[1].inbound_nodes[1]]
for i in nodes:
    print(i.input_shapes)
    print(i.output_shapes)
    print(" ")

#prints
[(None, 5)]
[(None, 5)]

[(None, 4)]
[(None, 3)]

[(None, 5)]
[(None, 3)]
Run Code Online (Sandbox Code Playgroud)

所以你可以看到中间节点的形状(没有出现在相关节点列表中的形状)对应于中的形状inner_model.

2)具有n个输出节点的内部模型是否总是将它们与节点索引1到n而不是0到n-1表示?

我不确定是否总是如此,因为我猜有几种输出节点节点存在各种可能性,但如果我考虑以下上述小模型的非常自然的概括,情况确实如此:

input_inner = Input(shape=(4,), name='input_inner')
output_inner = Dense(3, name='inner_dense')(input_inner)
inner_model = Model(inputs=input_inner, outputs=output_inner)

input = Input(shape=(5,), name='input')
output = inner_model(input)
output = inner_model(output)

model = Model(inputs=input, outputs=output)

print(model.summary())
Run Code Online (Sandbox Code Playgroud)

在这里我只是添加output = inner_model(output)到小模型中.相关节点列表是

[<keras.engine.topology.Node at 0xd10c390>,
 <keras.engine.topology.Node at 0xd10c9b0>,
 <keras.engine.topology.Node at 0xd10ca20>]
Run Code Online (Sandbox Code Playgroud)

以及所有入站节点的列表

[<keras.engine.topology.Node object at 0x000000000D10CA20>]
[<keras.engine.topology.Node object at 0x000000000D10C588>, <keras.engine.topology.Node object at 0x000000000D10C9B0>, <keras.engine.topology.Node object at 0x000000000D10C390>]
Run Code Online (Sandbox Code Playgroud)

实际上,节点索引是1和2,正如您在评论中提到的那样.如果我添加另一个output = inner_model(output),节点索引为1,2,3等等,它将继续类似.