How to create an ensemble in TensorFlow?

ces*_*ado 15 tensorflow

I am trying to create an ensemble of many trained models. All models share the same graph architecture; only their weights differ. I build the model graph with tf.get_variable, and for this one architecture I have several different checkpoints (with different weights). I want to create one model instance per checkpoint.

How can I load multiple checkpoints without overwriting the previously loaded weights?

Since I create my graph with tf.get_variable, the only way I can build multiple copies of the graph is by passing reuse=True. If I instead try to rename my graph's variables before loading, by wrapping the build method in a new scope (so they are not shared with the other graphs), then loading fails: the new names no longer match the names of the saved weights, so I cannot restore the checkpoint.

Pat*_*wie 3

This requires a few tricks. Let us first save some simple models:

#! /usr/bin/env python
# -*- coding: utf-8 -*-

import argparse
import tensorflow as tf


def build_graph(init_val=0.0):
    x = tf.placeholder(tf.float32)
    w = tf.get_variable('w', initializer=init_val)
    y = x + w
    return x, y


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--init', help='dummy string', type=float)
    parser.add_argument('--path', help='dummy string', type=str)
    args = parser.parse_args()

    x1, y1 = build_graph(args.init)

    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(y1, {x1: 10}))  # outputs: 10 + i

        save_path = saver.save(sess, args.path)
        print("Model saved in path: %s" % save_path)

# python ensemble.py --init 1 --path ./models/model1.chpt
# python ensemble.py --init 2 --path ./models/model2.chpt
# python ensemble.py --init 3 --path ./models/model3.chpt

These models produce the output "10 + i", where i = 1, 2, 3. Note that this script creates, runs, and saves the same graph structure multiple times. Loading these values and restoring each graph individually is folklore and can be done as follows:

#! /usr/bin/env python
# -*- coding: utf-8 -*-

import argparse
import tensorflow as tf


def build_graph(init_val=0.0):
    x = tf.placeholder(tf.float32)
    w = tf.get_variable('w', initializer=init_val)
    y = x + w
    return x, y


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', help='dummy string', type=str)
    args = parser.parse_args()

    x1, y1 = build_graph(-5.)

    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        saver.restore(sess, args.path)
        print("Model loaded from path: %s" % args.path)

        print(sess.run(y1, {x1: 10}))

# python ensemble_load.py --path ./models/model1.chpt  # gives 11
# python ensemble_load.py --path ./models/model2.chpt  # gives 12
# python ensemble_load.py --path ./models/model3.chpt  # gives 13

Again, these produce the expected outputs 11, 12, 13. Now the trick is to create its own variable scope for each model in the ensemble, e.g.:

import numpy as np
import tensorflow as tf


def build_graph(x, init_val=0.0):
    w = tf.get_variable('w', initializer=init_val)
    y = x + w
    return x, y


if __name__ == '__main__':
    models = ['./models/model1.chpt', './models/model2.chpt', './models/model3.chpt']
    x = tf.placeholder(tf.float32)
    outputs = []
    for k, path in enumerate(models):
        # THE VARIABLE SCOPE IS IMPORTANT
        with tf.variable_scope('model_%03i' % (k + 1)):
            outputs.append(build_graph(x, -100 * np.random.rand())[1])

Hence, each model lives under a different variable scope, i.e. we have the variables "model_001/w:0, model_002/w:0, model_003/w:0". Although their sub-graphs are similar (not identical), these variables really are distinct objects. Now the trick is to manage two sets of variables (the graph variables under the current scope and the variables from the checkpoint):

def restore_collection(path, scopename, sess):
    # retrieve all variables under scope
    variables = {v.name: v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scopename)}
    # retrieves all variables in checkpoint
    for var_name, _ in tf.contrib.framework.list_variables(path):
        # get the value of the variable
        var_value = tf.contrib.framework.load_variable(path, var_name)
        # construct expected variablename under new scope
        target_var_name = '%s/%s:0' % (scopename, var_name)
        # reference to variable-tensor
        target_variable = variables[target_var_name]
        # assign old value from checkpoint to new variable
        sess.run(target_variable.assign(var_value))
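The key line in restore_collection is the name reconstruction: a variable scope only prefixes the variable name, and ":0" denotes the first output of the variable op. A minimal sketch of that mapping, detached from TensorFlow:

```python
def scoped_name(scopename, var_name):
    # a variable_scope merely prefixes the name; ':0' is the tensor output index
    return '%s/%s:0' % (scopename, var_name)


print(scoped_name('model_001', 'w'))  # model_001/w:0
```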

The complete solution is:

#! /usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
import tensorflow as tf


def restore_collection(path, scopename, sess):
    # retrieve all variables under scope
    variables = {v.name: v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scopename)}
    # retrieves all variables in checkpoint
    for var_name, _ in tf.contrib.framework.list_variables(path):
        # get the value of the variable
        var_value = tf.contrib.framework.load_variable(path, var_name)
        # construct expected variablename under new scope
        target_var_name = '%s/%s:0' % (scopename, var_name)
        # reference to variable-tensor
        target_variable = variables[target_var_name]
        # assign old value from checkpoint to new variable
        sess.run(target_variable.assign(var_value))


def build_graph(x, init_val=0.0):
    w = tf.get_variable('w', initializer=init_val)
    y = x + w
    return x, y


if __name__ == '__main__':
    models = ['./models/model1.chpt', './models/model2.chpt', './models/model3.chpt']
    x = tf.placeholder(tf.float32)
    outputs = []
    for k, path in enumerate(models):
        with tf.variable_scope('model_%03i' % (k + 1)):
            outputs.append(build_graph(x, -100 * np.random.rand())[1])

    # collect the per-model weight variables (model_001/w:0, model_002/w:0, ...)
    W = [v for v in tf.global_variables() if v.name.endswith('/w:0')]

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        print(sess.run(outputs[0], {x: 10}))  # random output -82.4929
        print(sess.run(outputs[1], {x: 10}))  # random output -63.65792
        print(sess.run(outputs[2], {x: 10}))  # random output -19.888203

        print(sess.run(W[0]))  # randomly initialized value, e.g. -92.4929
        print(sess.run(W[1]))  # randomly initialized value, e.g. -73.65792
        print(sess.run(W[2]))  # randomly initialized value, e.g. -29.888203

        restore_collection(models[0], 'model_001', sess)  # restore all variables from different checkpoints
        restore_collection(models[1], 'model_002', sess)  # restore all variables from different checkpoints
        restore_collection(models[2], 'model_003', sess)  # restore all variables from different checkpoints

        print(sess.run(W[0]))  # old values from different checkpoints: 1.0
        print(sess.run(W[1]))  # old values from different checkpoints: 2.0
        print(sess.run(W[2]))  # old values from different checkpoints: 3.0

        print(sess.run(outputs[0], {x: 10}))  # what we expect: 11.0
        print(sess.run(outputs[1], {x: 10}))  # what we expect: 12.0
        print(sess.run(outputs[2], {x: 10}))  # what we expect: 13.0

# python ensemble_load_all.py

Now, with the outputs list, you can average these values in TensorFlow or do some other kind of ensemble prediction.
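Inside the graph this could be done with, e.g., tf.reduce_mean(tf.stack(outputs), axis=0). Once the per-model predictions have been fetched as NumPy values, the same averaging is (a sketch, using the values 11, 12, 13 from above):

```python
import numpy as np

# per-model predictions for the same input, as fetched from sess.run
predictions = np.array([11.0, 12.0, 13.0])

# simplest ensemble: average the member outputs
ensemble_mean = predictions.mean()
print(ensemble_mean)  # 12.0
```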

EDIT

  • It is much easier to store the model as a NumPy dictionary (npz) and load those values, as in my answer here: https://stackoverflow.com/a/50181741/7443104
  • The code above only illustrates a solution. It contains no sanity checks (e.g. whether a variable really exists). A try-catch might help.
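A minimal sketch of the npz approach (the file name and variable keys here are made up for illustration): dump the fetched weight values into one .npz archive, then load them back and assign each value to its variable, e.g. via sess.run(var.assign(value)):

```python
import numpy as np

# hypothetical weight values, e.g. fetched with sess.run(v) for each variable v
weights = {'w1': np.float32(1.0), 'w2': np.float32(2.0)}

# dump all values into a single .npz archive, keyed by name
np.savez('weights.npz', **weights)

# later: the loaded archive behaves like a dict of arrays
restored = np.load('weights.npz')
print(float(restored['w1']))  # 1.0
```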