在TensorFlow中重新训练冻结的* .pb模型

sch*_*hil 6 tensorflow

如何导入冻结的protobuf,以便对其进行重新训练?

我在网上找到的所有方法都需要检查点。有没有办法读取protobuf,以便将内核常数和偏差常数转换为变量?


编辑1:这类似于以下问题:如何在图(.pb)中重新训练模型?

我查看了DeepSpeech,该问题的答案中建议使用它。他们似乎有删除的支持initialize_from_frozen_model。我找不到原因。


编辑2:我尝试创建一个新的GraphDef对象,在其中我用变量替换了内核和偏差:

probable_variables = [...] # kernels and biases of Conv2D and MatMul

new_graph_def = tf.GraphDef()

with tf.Session(graph=graph) as sess:
    for n in sess.graph_def.node:

        if n.name in probable_variables:
            # create variable op
            nn = new_graph_def.node.add()
            nn.name = n.name
            nn.op = 'VariableV2'
            nn.attr['dtype'].CopyFrom(attr_value_pb2.AttrValue(type=dtype))
            nn.attr['shape'].CopyFrom(attr_value_pb2.AttrValue(shape=shape))

        else:
            nn = new_model.node.add()
            nn.CopyFrom(n)
Run Code Online (Sandbox Code Playgroud)

不知道我走的路是否正确。不知道如何设置trainable=TrueNodeDef对象。

Fal*_*nUA 8

您提供的代码段实际上是在正确的方向:)


第 1 步:获取先前可训练变量的名称

最棘手的部分是获取以前可训练变量的名称。希望该模型是使用一些高级框架创建的,例如kerasor tf.slim- 他们将其变量很好地包装在诸如conv2d_1/kernel, dense_1/bias, 之类的东西中batch_normalization/gamma

如果您不确定,最有用的方法是将图形可视化...

# read graph definition
with tf.gfile.GFile('frozen.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# now build the graph in the memory and visualize it
with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name="prefix")
    writer = tf.summary.FileWriter('out', graph)
    writer.close()
Run Code Online (Sandbox Code Playgroud)

...与张量板:

$ tensorboard --logdir out/
Run Code Online (Sandbox Code Playgroud)

并亲眼看看图表是什么样子以及命名是什么。


第 2 步:用变量替换常量(有趣的部分:D)

您所需要的只是名为tf.contrib.graph_editor. 现在假设您已经将以前可训练的操作的名称(以前是变量,但现在是Const)存储在probable_variables(如在您的Edit 2 中)。

注:记得之间的差异opstensors以及variables。ops 是图的元素,tensor 是一个包含 ops 结果的缓冲区,variables 是围绕张量的包装器,有 3 个 ops:(assign在初始化变量时调用),read(由其他 ops 调用,例如conv2d),和ref tensor(保存值)。

注2:graph_editor可以运行一个会话之外-你不能让任何在线图形修改!

$ tensorboard --logdir out/
Run Code Online (Sandbox Code Playgroud)

PS:此代码未经测试;但是,我最近经常使用graph_editor和执行网络手术,所以我认为它应该大部分是正确的:)

  • @schil **1.** 你需要你的变量和常量有不同的名字,所以 `_a` 只是提供一个在图中还不存在的名字。您可以使用任何您想要的后缀。**2.** 好吧,如果你更方便的话,你可以。在这里,我们只有 [`tf.get_variable`](https://www.tensorflow.org/api_docs/python/tf/get_variable) 为我们完成这项工作。**3.** 哦,我的错,在这里你需要在 `graph` **内创建变量(这个想法是我们向我们的图表添加新的操作并使用 `graph_editor` 重新路由。将编辑稍后代码:D (2认同)