我正在努力在iOS上运行MNIST的CNN推理.Apple开始提供一个很好的代码示例. https://developer.apple.com/library/content/samplecode/MPSCNNHelloWorld/Introduction/Intro.html#//apple_ref/doc/uid/TP40017482-Intro-DontLinkElementID_2
但是,当我尝试使用MPS 实现更复杂的CNN模型(例如https://github.com/fchollet/keras/blob/master/examples/mnist_cnn.py)时,我发现没有"Flatten"类过滤.
我查看了MPS框架,找到了重塑或更改维度的功能,但我找不到合适的维度.(例如,MPSImageConversion似乎仅用于转换颜色,而不是用于维度.
如果有人知道Flatten的过滤器或如何将多维图像转换为1D图像,请告诉我.
我正在通过 TensorFlow实现这个(https://github.com/fchollet/keras/blob/master/examples/mnist_cnn.py)。我的代码如下。
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
def weight_variable(shape):
initial = tf.truncated_normal(shape, stddev=0.1)
return tf.Variable(initial)
def bias_variable(shape):
initial = tf.constant(0.1, shape=shape)
return tf.Variable(initial)
def conv2d(x, W):
return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
def max_pool_2x2(x):
return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
strides=[1, 2, 2, 1], padding='SAME')
if __name__ == '__main__':
mnist = input_data.read_data_sets('data', one_hot=True)
x = tf.placeholder("float", shape=[None, 784])
y_ = tf.placeholder("float", shape=[None, 10])
sess = tf.InteractiveSession()
x_image = tf.reshape(x, [-1,28,28,1])
W_conv1 = weight_variable([3, …Run Code Online (Sandbox Code Playgroud)