小编iva*_*van的帖子

Tensorflow CNN 图像增强管道

我正在尝试学习新的 Tensorflow API,但我对在哪里获得输入批处理张量的句柄有点迷茫,这样我就可以使用例如 tf.image 来操作和增强它们。

这是我当前的网络和管道:

trainX, testX, trainY, testY = read_data()
# trainX [num_image, height, width, channels], these are numpy arrays

#...
train_dataset = tf.data.Dataset.from_tensor_slices((trainX, trainY))
test_dataset = tf.data.Dataset.from_tensor_slices((testX, testY))

#...
iterator = tf.data.Iterator.from_structure(train_dataset.output_types, 
                 train_dataset.output_shapes)
features, labels = iterator.get_next()
train_init_op = iterator.make_initializer(train_dataset)
test_init_op = iterator.make_initializer(test_dataset)

#...defining cnn architecture...

# In the train loop
TrainLoop {
   sess.run(train_init_op)  # switching to train data
   sess.run(train_step, ...) # running a train step

   #... 
   sess.run(test_init_op)  # switching to test data
   test_loss = sess.run(loss, ...) …
Run Code Online (Sandbox Code Playgroud)

python machine-learning deep-learning tensorflow

0
推荐指数
1
解决办法
3876
查看次数