使用tf.data.Dataset可以使保存的模型更大

F B*_*Bai 6 python tensorflow

我最近遇到了将模型保存到更大尺寸的问题.我正在使用tensorflow 1.4

以前,我用过

tf.train.string_input_producer()tf.train.batch()

从文本文件加载图像.在训练中,

tf.train.start_queue_runners()tf.train.Coordinator()

用于向网络提供数据.在这种情况下,每次我使用保存模型

saver.save(sess, checkpoint_path, global_step=iters)

只给了我一个小尺寸的文件,即一个名为model.ckpt-1000.data-00000-of-00001且1.6MB的文件.

现在,我用

tf.data.Dataset.from_tensor_slices()

将图像提供给输入placeholder,保存的模型变为290MB.但我不知道为什么.我怀疑张量流也saver将数据集保存到模型中.如果是这样,如何删除它们以使其变小,并且仅保存网络的权重.

这不是网络依赖,因为我尝试了两个网络,他们都是这样的.

我用谷歌搜索,但遗憾的是没有看到任何与此问题相关的灵感.(或者这不是问题,只是我不知道怎么办?)

非常感谢您的任何想法和帮助!

编辑

我初始化数据集的方法是:

1.First生成的numpy.array数据集:

self.train_hr, self.train_lr = cifar10.load_dataset(sess)
Run Code Online (Sandbox Code Playgroud)

例如,初始数据集是numpy.array [8000,32,32,3].我传入sess这个函数是因为在函数中,我做了tf.image.resize_images()sess.run()用来生成numpy.array.回报self.train_hrself.train_lrnumpy.array在形状[8000,64,64,3].

然后我创建了数据集:

self.img_hr = tf.placeholder(tf.float32)
self.img_lr = tf.placeholder(tf.float32)
dataset = tf.data.Dataset.from_tensor_slices((self.img_hr, self.img_lr))
dataset = dataset.repeat(conf.num_epoch).shuffle(buffer_size=conf.shuffle_size).batch(conf.batch_size)
self.iterator = dataset.make_initializable_iterator()
self.next_batch = self.iterator.get_next()
Run Code Online (Sandbox Code Playgroud)

然后我初始化了网络和数据集,做了培训并保存了模型:

self.labels = tf.placeholder(tf.float32,
                                     shape=[conf.batch_size, conf.hr_size, conf.hr_size, conf.img_channel])
self.inputs = tf.placeholder(tf.float32,
                                     shape=[conf.batch_size, conf.lr_size, conf.lr_size, conf.img_channel])
self.net = Net(self.labels, self.inputs, mask_type=conf.mask_type,
                       is_linear_only=conf.linear_mapping_only, scope='sr_spc')

sess.run(self.iterator.initializer,
                         feed_dict={self.img_hr: self.train_hr, self.img_lr: self.train_lr})
while True:
    hr_img, lr_img = sess.run(self.next_batch)
    _, loss, summary_str = sess.run([train_op, self.net.loss, summary_op],
                                    feed_dict={self.labels: hr_img, self.inputs: lr_img})
    ...
    ...
    checkpoint_path = os.path.join(conf.model_dir, 'model.ckpt')
    saver.save(sess, checkpoint_path, global_step=iters)
Run Code Online (Sandbox Code Playgroud)

所有sess都是同一个实例.

Dav*_*rks 2

tf.constant我怀疑您从数据集中创建了一个张量流常量,这可以解释为什么数据集与图形一起存储。有一个可初始化的数据集,您可以feed_dict在运行时使用它来输入数据。这是需要配置的几行额外代码,但它可能是您想要使用的。

https://www.tensorflow.org/programmers_guide/datasets

请注意,Python 包装器中会自动为您创建常量。以下语句是等效的:

tf.Variable(42)
tf.Variable(tf.constant(42))
Run Code Online (Sandbox Code Playgroud)