从检查点还原时,如何更改参数的数据类型?

don*_*loo 4 python machine-learning tensorflow

我有一个预先训练的Tensorflow检查点,其中所有参数都是float32数据类型。

如何将检查点参数加载为float16?还是有办法修改检查点的数据类型?

以下是我的代码片段,试图将float32检查点加载到float16图中,但出现类型不匹配错误。

import tensorflow as tf

# Phase 1: build a float32 graph and write its trainable variables to disk.
A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float32)
dense = tf.layers.dense(inputs=A, units=3)
trainables = tf.trainable_variables(scope=None)
print(trainables[1])  # <tf.Variable 'dense/kernel:0' shape=(3, 3) dtype=float32_ref>
# The checkpoint keys are the full variable names, e.g. 'dense/kernel:0'.
name_to_var = {v.name: v for v in trainables}
saver = tf.train.Saver(name_to_var)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(dense))
    save_path = saver.save(sess, "tmp.ckpt")

# Phase 2: rebuild the same graph with float16 variables and try to restore
# the float32 checkpoint into it.
tf.reset_default_graph()
A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float16)
dense = tf.layers.dense(inputs=A, units=3)
trainables = tf.trainable_variables(scope=None)
print(trainables[1])  # <tf.Variable 'dense/kernel:0' shape=(3, 3) dtype=float16_ref>
name_to_var = {v.name: v for v in trainables}
saver = tf.train.Saver(name_to_var)

with tf.Session() as sess:
    # This restore is the call that raises the dtype-mismatch errors listed
    # after this snippet: the graph expects half, the checkpoint holds float.
    saver.restore(sess, "tmp.ckpt")
    print(sess.run(dense))

# errors:
# tensor_name = dense/bias:0; expected dtype half does not equal original dtype float
# tensor_name = dense/kernel:0; expected dtype half does not equal original dtype float
# tensor_name = foo:0; expected dtype half does not equal original dtype float
Run Code Online (Sandbox Code Playgroud)

jde*_*esa 6

稍微研究了一下 Saver 的工作方式后,看来可以通过 builder 对象重新定义它的构建过程。例如,您可以编写一个构建器,先将值按 tf.float32 加载,然后再强制转换为变量的实际类型:

import tensorflow as tf
from tensorflow.python.training.saver import BaseSaverBuilder

class CastFromFloat32SaverBuilder(BaseSaverBuilder):
  """Saver builder that restores float32 checkpoint values into other float dtypes.

  Floating-point variables are read from the checkpoint as ``tf.float32`` and
  then cast to the dtype of the variable being restored. Variables whose dtype
  is not floating-point are read with their own dtype and returned uncast, so
  that they do not trigger a dtype-mismatch error on restore.

  Pass an instance as ``builder=`` when constructing ``tf.train.Saver``.
  """

  # Based on tensorflow.python.training.saver.BulkSaverBuilder.bulk_restore
  def bulk_restore(self, filename_tensor, saveables, preferred_shard,
                   restore_sequentially):
    """Restore all saveables with a single op, casting float32 values as needed.

    Args:
      filename_tensor: Tensor holding the checkpoint path, passed to restore_v2.
      saveables: Iterable of saveable objects; each exposes ``specs`` whose
        entries carry ``name``, ``slice_spec`` and ``dtype``.
      preferred_shard: Unused by this implementation.
      restore_sequentially: Unused by this implementation.

    Returns:
      A list of tensors, one per spec, in the order the specs were visited.
    """
    from tensorflow.python.ops import io_ops
    restore_specs = []
    for saveable in saveables:
      for spec in saveable.specs:
        restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
    names, slices, dtypes = zip(*restore_specs)
    # Only floating-point targets are assumed to be stored as float32 in the
    # checkpoint; any other dtype is read back exactly as declared.
    needs_cast = [tf.as_dtype(dt).base_dtype.is_floating for dt in dtypes]
    restore_dtypes = [
        tf.float32 if cast else dt for cast, dt in zip(needs_cast, dtypes)
    ]
    with tf.device("cpu:0"):
      restored = io_ops.restore_v2(filename_tensor, names, slices, restore_dtypes)
      return [
          tf.cast(r, dt) if cast else r
          for r, dt, cast in zip(restored, dtypes, needs_cast)
      ]
Run Code Online (Sandbox Code Playgroud)

请注意,这假设检查点中所有被还原的变量均为 tf.float32。您可以根据需要调整构建器以适应自己的用例,例如在构造函数中传入源数据类型等。使用此方法,只需在第二个 Saver 中使用上述构建器,即可使您的示例正常工作:

import tensorflow as tf

# Step 1: in a fresh graph, build the float32 model and save its variables.
with tf.Graph().as_default():
    with tf.Session() as sess:
        A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float32)
        dense = tf.layers.dense(inputs=A, units=3)
        trainables = tf.trainable_variables(scope=None)
        # Checkpoint keys are the full variable names (e.g. 'dense/kernel:0').
        name_to_var = {v.name: v for v in trainables}
        saver = tf.train.Saver(name_to_var)
        sess.run(tf.global_variables_initializer())
        print('Value to save:')
        print(sess.run(dense))
        save_path = saver.save(sess, "ckpt/tmp.ckpt")

# Step 2: in another fresh graph, build the float16 model and restore the
# float32 checkpoint through the casting builder.
with tf.Graph().as_default():
    with tf.Session() as sess:
        A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float16)
        dense = tf.layers.dense(inputs=A, units=3)
        trainables = tf.trainable_variables(scope=None)
        name_to_var = {v.name: v for v in trainables}
        # The custom builder replaces the default restore logic of the saver.
        saver = tf.train.Saver(name_to_var, builder=CastFromFloat32SaverBuilder())
        saver.restore(sess, "ckpt/tmp.ckpt")
        print('Restored value:')
        print(sess.run(dense))
Run Code Online (Sandbox Code Playgroud)

输出:

import tensorflow as tf
from tensorflow.python.training.saver import BaseSaverBuilder

class CastFromFloat32SaverBuilder(BaseSaverBuilder):
  """Saver builder that restores float32 checkpoint values into other float dtypes.

  Floating-point variables are read from the checkpoint as ``tf.float32`` and
  then cast to the dtype of the variable being restored. Variables whose dtype
  is not floating-point are read with their own dtype and returned uncast, so
  that they do not trigger a dtype-mismatch error on restore.

  Pass an instance as ``builder=`` when constructing ``tf.train.Saver``.
  """

  # Based on tensorflow.python.training.saver.BulkSaverBuilder.bulk_restore
  def bulk_restore(self, filename_tensor, saveables, preferred_shard,
                   restore_sequentially):
    """Restore all saveables with a single op, casting float32 values as needed.

    Args:
      filename_tensor: Tensor holding the checkpoint path, passed to restore_v2.
      saveables: Iterable of saveable objects; each exposes ``specs`` whose
        entries carry ``name``, ``slice_spec`` and ``dtype``.
      preferred_shard: Unused by this implementation.
      restore_sequentially: Unused by this implementation.

    Returns:
      A list of tensors, one per spec, in the order the specs were visited.
    """
    from tensorflow.python.ops import io_ops
    restore_specs = []
    for saveable in saveables:
      for spec in saveable.specs:
        restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
    names, slices, dtypes = zip(*restore_specs)
    # Only floating-point targets are assumed to be stored as float32 in the
    # checkpoint; any other dtype is read back exactly as declared.
    needs_cast = [tf.as_dtype(dt).base_dtype.is_floating for dt in dtypes]
    restore_dtypes = [
        tf.float32 if cast else dt for cast, dt in zip(needs_cast, dtypes)
    ]
    with tf.device("cpu:0"):
      restored = io_ops.restore_v2(filename_tensor, names, slices, restore_dtypes)
      return [
          tf.cast(r, dt) if cast else r
          for r, dt, cast in zip(restored, dtypes, needs_cast)
      ]
Run Code Online (Sandbox Code Playgroud)