我最近遇到了将模型保存到更大尺寸的问题.我正在使用tensorflow 1.4
以前,我用过
tf.train.string_input_producer() 和 tf.train.batch()
从文本文件加载图像.在训练中,
tf.train.start_queue_runners() 和 tf.train.Coordinator()
用于向网络提供数据.在这种情况下,每次我使用保存模型
saver.save(sess, checkpoint_path, global_step=iters)
只给了我一个小尺寸的文件,即一个名为model.ckpt-1000.data-00000-of-00001且1.6MB的文件.
现在,我用
tf.data.Dataset.from_tensor_slices()
将图像提供给输入placeholder,保存的模型变为290MB.但我不知道为什么.我怀疑张量流也saver将数据集保存到模型中.如果是这样,如何删除它们以使其变小,并且仅保存网络的权重.
这不是网络依赖,因为我尝试了两个网络,他们都是这样的.
我用谷歌搜索,但遗憾的是没有看到任何与此问题相关的灵感.(或者这不是问题,只是我不知道怎么办?)
非常感谢您的任何想法和帮助!
我初始化数据集的方法是:
1.First生成的numpy.array数据集:
self.train_hr, self.train_lr = cifar10.load_dataset(sess)
Run Code Online (Sandbox Code Playgroud)
例如,初始数据集是numpy.array [8000,32,32,3].我传入sess这个函数是因为在函数中,我做了tf.image.resize_images()并sess.run()用来生成numpy.array.回报self.train_hr和self.train_lr是numpy.array在形状[8000,64,64,3].
然后我创建了数据集:
self.img_hr = tf.placeholder(tf.float32)
self.img_lr = tf.placeholder(tf.float32)
dataset = tf.data.Dataset.from_tensor_slices((self.img_hr, self.img_lr))
dataset = dataset.repeat(conf.num_epoch).shuffle(buffer_size=conf.shuffle_size).batch(conf.batch_size)
self.iterator = dataset.make_initializable_iterator()
self.next_batch = self.iterator.get_next()
Run Code Online (Sandbox Code Playgroud)
然后我初始化了网络和数据集,做了培训并保存了模型:
self.labels = tf.placeholder(tf.float32,
shape=[conf.batch_size, conf.hr_size, conf.hr_size, conf.img_channel])
self.inputs = tf.placeholder(tf.float32,
shape=[conf.batch_size, conf.lr_size, conf.lr_size, conf.img_channel])
self.net = Net(self.labels, self.inputs, mask_type=conf.mask_type,
is_linear_only=conf.linear_mapping_only, scope='sr_spc')
sess.run(self.iterator.initializer,
feed_dict={self.img_hr: self.train_hr, self.img_lr: self.train_lr})
while True:
hr_img, lr_img = sess.run(self.next_batch)
_, loss, summary_str = sess.run([train_op, self.net.loss, summary_op],
feed_dict={self.labels: hr_img, self.inputs: lr_img})
...
...
checkpoint_path = os.path.join(conf.model_dir, 'model.ckpt')
saver.save(sess, checkpoint_path, global_step=iters)
Run Code Online (Sandbox Code Playgroud)
所有sess都是同一个实例.
tf.constant我怀疑您从数据集中创建了一个张量流常量,这可以解释为什么数据集与图形一起存储。有一个可初始化的数据集,您可以feed_dict在运行时使用它来输入数据。这是需要配置的几行额外代码,但它可能是您想要使用的。
https://www.tensorflow.org/programmers_guide/datasets
请注意,Python 包装器中会自动为您创建常量。以下语句是等效的:
tf.Variable(42)
tf.Variable(tf.constant(42))
Run Code Online (Sandbox Code Playgroud)