don*_*loo 4 python machine-learning tensorflow
我有一个预训练的 TensorFlow 检查点(checkpoint),其中所有参数都是 float32 类型。
如何以 float16 类型加载检查点中的参数?或者,有没有办法修改检查点中参数的数据类型?
下面是我的代码片段:我尝试把 float32 检查点加载到 float16 的计算图中,但出现了类型不匹配错误。
import tensorflow as tf

# Reproduction of the problem: save a graph whose variables are float32, then
# try to restore that checkpoint into the same graph built with float16
# variables. Uses TF1 graph-mode APIs (tf.get_variable, tf.layers, tf.Session).

# --- Build, initialize and save the float32 model. ---
A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float32)
dense = tf.layers.dense(inputs=A, units=3)
varis = tf.trainable_variables(scope=None)
print(varis[1])  # <tf.Variable 'dense/kernel:0' shape=(3, 3) dtype=float32_ref>
# Key each variable by its tensor name so the second graph, which creates
# variables with the same names, maps onto the same checkpoint entries.
assign = {vari.name: vari for vari in varis}
saver = tf.train.Saver(assign)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(dense))
    saver.save(sess, "tmp.ckpt")

# --- Rebuild the same model with float16 variables and try to restore. ---
tf.reset_default_graph()
A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float16)
dense = tf.layers.dense(inputs=A, units=3)
varis = tf.trainable_variables(scope=None)
print(varis[1])  # <tf.Variable 'dense/kernel:0' shape=(3, 3) dtype=float16_ref>
assign = {vari.name: vari for vari in varis}
saver = tf.train.Saver(assign)
with tf.Session() as sess:
    # Fails: the default Saver asks the checkpoint for each tensor in the
    # variable's dtype (half), but the checkpoint stores float.
    saver.restore(sess, "tmp.ckpt")
    print(sess.run(dense))

# errors:
# tensor_name = dense/bias:0; expected dtype half does not equal original dtype float
# tensor_name = dense/kernel:0; expected dtype half does not equal original dtype float
# tensor_name = foo:0; expected dtype half does not equal original dtype float
Run Code Online (Sandbox Code Playgroud)
稍微研究了一下 Saver 的工作原理,看来可以通过传入自定义的 builder 对象来重新定义 Saver 的构建方式。
例如,可以编写一个 builder:先以 tf.float32 读取检查点中的值,
然后再将其转换(cast)为变量的实际类型:
import tensorflow as tf
from tensorflow.python.training.saver import BaseSaverBuilder


class CastFromFloat32SaverBuilder(BaseSaverBuilder):
    """Saver builder that restores float32 checkpoint values into variables
    of another floating-point dtype (e.g. float16) by casting them.

    Use it as ``tf.train.Saver(..., builder=CastFromFloat32SaverBuilder())``.
    Floating-point variables are read from the checkpoint as ``tf.float32``
    and then cast to the variable's own dtype. Non-floating variables are
    read with their own dtype unchanged, so they do not break the restore.
    """

    # Based on tensorflow.python.training.saver.BulkSaverBuilder.bulk_restore
    def bulk_restore(self, filename_tensor, saveables, preferred_shard,
                     restore_sequentially):
        """Build the restore ops for all ``saveables`` with one RestoreV2 op.

        Args:
          filename_tensor: Tensor holding the checkpoint path prefix.
          saveables: Saveable objects whose ``specs`` (name, slice_spec,
            dtype) describe what to restore.
          preferred_shard: Unused; kept to match the BaseSaverBuilder API.
          restore_sequentially: Unused; kept to match the BaseSaverBuilder API.

        Returns:
          A list with one restored tensor per spec, each cast to that spec's
          dtype.
        """
        from tensorflow.python.ops import io_ops

        restore_specs = [(spec.name, spec.slice_spec, spec.dtype)
                         for saveable in saveables
                         for spec in saveable.specs]
        names, slices, dtypes = zip(*restore_specs)
        # Assumes floating-point values were saved as float32. Other dtypes
        # (e.g. integer counters) are requested as-is; base_dtype strips the
        # *_ref variant that reference variables report.
        restore_dtypes = [
            tf.float32 if tf.as_dtype(dt).base_dtype.is_floating else dt
            for dt in dtypes
        ]
        with tf.device("cpu:0"):
            restored = io_ops.restore_v2(filename_tensor, names, slices,
                                         restore_dtypes)
        # tf.cast is a no-op when the restored dtype already matches.
        return [tf.cast(r, dt) for r, dt in zip(restored, dtypes)]
Run Code Online (Sandbox Code Playgroud)
请注意,这里假设检查点中所有待恢复的变量都以 tf.float32 存储。
你可以按需调整这个 builder 以适应自己的用例,例如通过构造函数传入源数据类型等。使用这种方法,只需在第二个 Saver 中使用上述 builder,就能让示例正常运行:
import tensorflow as tf

# Demonstration: save a float32 model, then restore that checkpoint into a
# float16 copy of the same model. CastFromFloat32SaverBuilder must already be
# defined in this module (see the builder class in the answer).

# Checkpoint prefix shared by the saving and the restoring graph.
CKPT_PATH = "ckpt/tmp.ckpt"

# Graph 1: float32 variables; initialize, evaluate and save.
with tf.Graph().as_default(), tf.Session() as sess:
    A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float32)
    dense = tf.layers.dense(inputs=A, units=3)
    varis = tf.trainable_variables(scope=None)
    # Key variables by tensor name so that graph 2, which creates variables
    # with the same names, maps onto the same checkpoint entries.
    assign = {vari.name: vari for vari in varis}
    saver = tf.train.Saver(assign)
    sess.run(tf.global_variables_initializer())
    print('Value to save:')
    print(sess.run(dense))
    saver.save(sess, CKPT_PATH)

# Graph 2: same structure with float16 variables. They are filled from the
# checkpoint through the casting builder instead of being initialized.
with tf.Graph().as_default(), tf.Session() as sess:
    A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float16)
    dense = tf.layers.dense(inputs=A, units=3)
    varis = tf.trainable_variables(scope=None)
    assign = {vari.name: vari for vari in varis}
    saver = tf.train.Saver(assign, builder=CastFromFloat32SaverBuilder())
    saver.restore(sess, CKPT_PATH)
    print('Restored value:')
    print(sess.run(dense))
Run Code Online (Sandbox Code Playgroud)
输出:(注:原页面此处应为上面脚本的打印结果,但抓取时被误替换成了前文 builder 类的重复代码。)
# NOTE(review): this block repeats the CastFromFloat32SaverBuilder definition
# given earlier in the answer. It sits under the "Output:" heading, where the
# script's printed values would be expected, so the actual output appears to
# have been lost when the page was scraped.
import tensorflow as tf
from tensorflow.python.training.saver import BaseSaverBuilder


class CastFromFloat32SaverBuilder(BaseSaverBuilder):
    """Saver builder that restores float32 checkpoint values into variables
    of another floating-point dtype (e.g. float16) by casting them.

    Use it as ``tf.train.Saver(..., builder=CastFromFloat32SaverBuilder())``.
    Floating-point variables are read from the checkpoint as ``tf.float32``
    and then cast to the variable's own dtype. Non-floating variables are
    read with their own dtype unchanged, so they do not break the restore.
    """

    # Based on tensorflow.python.training.saver.BulkSaverBuilder.bulk_restore
    def bulk_restore(self, filename_tensor, saveables, preferred_shard,
                     restore_sequentially):
        """Build the restore ops for all ``saveables`` with one RestoreV2 op.

        Args:
          filename_tensor: Tensor holding the checkpoint path prefix.
          saveables: Saveable objects whose ``specs`` (name, slice_spec,
            dtype) describe what to restore.
          preferred_shard: Unused; kept to match the BaseSaverBuilder API.
          restore_sequentially: Unused; kept to match the BaseSaverBuilder API.

        Returns:
          A list with one restored tensor per spec, each cast to that spec's
          dtype.
        """
        from tensorflow.python.ops import io_ops

        restore_specs = [(spec.name, spec.slice_spec, spec.dtype)
                         for saveable in saveables
                         for spec in saveable.specs]
        names, slices, dtypes = zip(*restore_specs)
        # Assumes floating-point values were saved as float32. Other dtypes
        # (e.g. integer counters) are requested as-is; base_dtype strips the
        # *_ref variant that reference variables report.
        restore_dtypes = [
            tf.float32 if tf.as_dtype(dt).base_dtype.is_floating else dt
            for dt in dtypes
        ]
        with tf.device("cpu:0"):
            restored = io_ops.restore_v2(filename_tensor, names, slices,
                                         restore_dtypes)
        # tf.cast is a no-op when the restored dtype already matches.
        return [tf.cast(r, dt) for r, dt in zip(restored, dtypes)]
Run Code Online (Sandbox Code Playgroud)
归档时间:
查看次数:236 次
最近记录: