TensorFlow:有没有办法将冻结图转换为检查点模型?

kwo*_*sin 13 python tensorflow

可以将检查点模型转换为冻结图(.ckpt文件到.pb文件).但是,有没有一种将pb文件再次转换为检查点文件的反向方法?

我想它需要将常量转换回变量 - 有没有办法将正确的常量识别为变量并将它们恢复为检查点模型?

目前支持将变量转换为常量:https://www.tensorflow.org/api_docs/python/tf/graph_util/convert_variables_to_constants

但不是相反.

这里提出了类似的问题:Tensorflow:将常数张量从预训练的Vgg模型转换为变量

但该解决方案依赖于使用ckpt模型来恢复权重变量.有没有办法从PB文件而不是检查点文件中恢复权重变量?这对于重量修剪可能很有用.

Max*_* Wu 5

有一种方法可以通过图形编辑器将常量转换回 TensorFlow 中的可训练变量。但是,您需要指定要转换的节点,因为我不确定是否有办法以可靠的方式自动检测这一点。

步骤如下:

第 1 步:加载冻结图

我们将.pb文件加载到图形对象中。

import tensorflow as tf

# Load protobuf as graph, given filepath
def load_pb(path_to_pb):
    with tf.gfile.GFile(path_to_pb, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
        return graph

tf_graph = load_pb('frozen_graph.pb')
Run Code Online (Sandbox Code Playgroud)

步骤2:找到需要转换的常量

以下是列出图中节点名称的两种方法:

  • 使用此脚本来打印它们
  • print([n.name for n in tf_graph.as_graph_def().node])

您想要转换的节点可能被命名为“Const”。可以肯定的是,最好将图表加载到Netron中,以查看哪些张量存储了可训练权重。通常,可以安全地假设所有 const 节点都曾经是变量。

识别出这些节点后,让我们将它们的名称存储到列表中:

to_convert = [...] # names of tensors to convert
Run Code Online (Sandbox Code Playgroud)

步骤 3:将常量转换为变量

运行此代码以转换您指定的常量。它本质上是为每个常量创建相应的变量,并使用 GraphEditor 从图表中取消常量,并挂上变量。

import numpy as np
import tensorflow as tf
import tensorflow.contrib.graph_editor as ge

const_var_name_pairs = []
with tf_graph.as_default() as g:

    for name in to_convert:
        tensor = g.get_tensor_by_name('{}:0'.format(name))
        with tf.Session() as sess:
            tensor_as_numpy_array = sess.run(tensor)
        var_shape = tensor.get_shape()
        # Give each variable a name that doesn't already exist in the graph
        var_name = '{}_turned_var'.format(name)
        # Create TensorFlow variable initialized by values of original const.
        var = tf.get_variable(name=var_name, dtype='float32', shape=var_shape, \  
                      initializer=tf.constant_initializer(tensor_as_numpy_array))
        # We want to keep track of our variables names for later.
        const_var_name_pairs.append((name, var_name))

    # At this point, we added a bunch of tf.Variables to the graph, but they're
    # not connected to anything.

    # The magic: we use TF Graph Editor to swap the Constant nodes' outputs with
    # the outputs of our newly created Variables.

    for const_name, var_name in const_var_name_pairs:
        const_op = g.get_operation_by_name(const_name)
        var_reader_op = g.get_operation_by_name(var_name + '/read')
        ge.swap_outputs(ge.sgv(const_op), ge.sgv(var_reader_op))
Run Code Online (Sandbox Code Playgroud)

第 4 步:将结果另存为.ckpt

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        save_path = tf.train.Saver().save(sess, 'model.ckpt')
        print("Model saved in path: %s" % save_path)
Run Code Online (Sandbox Code Playgroud)

还有中提琴!您应该在这一点上完成:)我自己能够完成这项工作,并验证模型权重已保留 - 唯一的区别是该图现在是可训练的。如果有任何问题,请告诉我。


Alm*_*vid 1

如果您有构建网络的源代码,则可以相对容易地完成,因为冻结图方法不会更改卷积/完全连接的名称,因此您基本上可以研究该图并将常量操作与其变量相匹配匹配并仅加载具有常量值的变量。

如果您没有构建网络的代码,仍然可以完成它,但它并不简单。

例如,您可以搜索图中的所有节点并查找常量类型的操作,然后在找到常量类型的所有操作后,您可以查看该操作是否连接到卷积/完全连接,例如..(或者您可以只是转换所有常量取决于你)。

找到要转换为变量的常量后,您可以将变量添加到保存常量值的图形中,然后使用 Tensorflow图形编辑器重新连接const 操作与变量之间的连接(使用reroute_ts方法)。

完成后,您可以保存图表,当您再次加载它时,您将拥有变量(但请注意,常量仍将保留在图表中,但可以通过例如图表转换工具进行优化)