我正在努力使用Tensorflow的修剪库,但没有找到许多有用的示例,因此我正在寻找帮助以修剪在MNIST数据集上训练的简单模型。如果有人可以帮助解决我的问题或提供如何在MNIST上使用该库的示例,我将不胜感激。
我的代码的前半部分是非常标准的,除了我的模型有2个隐藏的图层,其宽度layers.masked_fully_connected为300单位,用于修剪。
import tensorflow as tf
from tensorflow.contrib.model_pruning.python import pruning
from tensorflow.contrib.model_pruning.python.layers import layers
from tensorflow.examples.tutorials.mnist import input_data
# Import dataset
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# Define Placeholders
image = tf.placeholder(tf.float32, [None, 784])
label = tf.placeholder(tf.float32, [None, 10])
# Define the model
layer1 = layers.masked_fully_connected(image, 300)
layer2 = layers.masked_fully_connected(layer1, 300)
logits = tf.contrib.layers.fully_connected(layer2, 10, tf.nn.relu)
# Loss function
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=label))
# Training op
train_op = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss)
# Accuracy ops
correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(label, 1))
accuracy …Run Code Online (Sandbox Code Playgroud)