
How to prefetch data using a custom Python function in TensorFlow

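I am trying to prefetch training data to hide I/O latency. I would like to write custom Python code that loads data from disk and preprocesses it (for example, by adding a context window). In other words, one thread does the data preprocessing while another does the training. Is this possible in TensorFlow?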
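To make the intended pattern concrete, here is a minimal sketch of the producer/consumer arrangement described above, written with Python's standard library only (`queue.Queue` and `threading.Thread`) rather than TensorFlow's own queues. The names `load_and_preprocess`, `producer`, `batches`, and `run` are hypothetical, and the doubling step is only a stand-in for real disk loading and context-window preprocessing.

```python
import queue
import threading

_DONE = object()  # sentinel marking the end of the data


def load_and_preprocess(raw_records):
    """Hypothetical stand-in for reading from disk and preprocessing."""
    for record in raw_records:
        yield [x * 2.0 for x in record]


def producer(raw_records, q):
    # Runs in a background thread; q.put blocks while the queue is full.
    for example in load_and_preprocess(raw_records):
        q.put(example)
    q.put(_DONE)


def batches(q, batch_size):
    # Runs in the training thread; q.get blocks until an example is ready.
    batch = []
    while True:
        item = q.get()
        if item is _DONE:
            break
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch  # final, possibly smaller batch


def run(raw_records, batch_size=5, capacity=200):
    q = queue.Queue(maxsize=capacity)  # bounded, so the loader cannot run far ahead
    worker = threading.Thread(target=producer, args=(raw_records, q), daemon=True)
    worker.start()
    sizes = [len(b) for b in batches(q, batch_size)]  # a training step would go here
    worker.join()
    return sizes
```

The bounded queue is the piece that hides latency: the loader keeps filling it while the consumer is busy, and blocks once `capacity` examples are waiting. A TensorFlow queue fed from a Python thread plays the same role, with the dequeue happening inside the graph.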

Update: I have a working example based on @mrry's example.

import numpy as np
import tensorflow as tf
import threading

BATCH_SIZE = 5
TRAINING_ITERS = 4100

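# Placeholders for a single example (TensorFlow 1.x graph-mode API).
feature_input = tf.placeholder(tf.float32, shape=[128])
label_input = tf.placeholder(tf.float32, shape=[128])

# Queue holding up to 200 (label, feature) pairs, filled by a background Python thread.
q = tf.FIFOQueue(200, [tf.float32, tf.float32], shapes=[[128], [128]])
enqueue_op = q.enqueue([label_input, feature_input])

# Dequeue one batch; the order matches the enqueue order (label, then feature).
label_batch, feature_batch = q.dequeue_many(BATCH_SIZE)
c = tf.reshape(feature_batch, [BATCH_SIZE, 128]) + tf.reshape(label_batch, [BATCH_SIZE, 128])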

sess = tf.Session()

def load_and_enqueue(sess, enqueue_op, coord):
  # Open in binary mode: the files hold raw float32 values.
  with open('dummy_data/features.bin', 'rb') as feature_file, open('dummy_data/labels.bin', 'rb') as label_file:
    while not coord.should_stop():
      feature_array = np.fromfile(feature_file, np.float32, 128)
      if feature_array.shape[0] == 0:
        print('reach end of …

python multithreading latency prefetch tensorflow

Votes: 39 · Answers: 2 · Views: 10k
