我是 TensorFlow 的新手,想读取一个逗号分隔值 (csv) 文件,其中包含 2 列,第 1 列是索引,第 2 列是标签字符串。我有以下代码逐行读取 csv 文件中的行,并且我能够使用打印语句正确获取 csv 文件中的数据。但是,我想从字符串标签进行单热编码转换,而不是如何在 TensorFlow 中进行。最终目标是使用 tf.train.batch() 函数,这样我就可以获得一批单热标签向量来训练神经网络。
正如您在下面的代码中看到的,我可以在 TensorFlow 会话中为每个标签条目手动创建一个单热向量。但是如何使用 tf.train.batch() 函数?如果我移动线
label_batch = tf.train.batch([col2], batch_size=5)
Run Code Online (Sandbox Code Playgroud)
进入 TensorFlow 会话块(用 label_one_hot 替换 col2),程序块什么都不做。我试图将 one-hot 向量转换移到 TensorFlow 会话之外,但未能使其正常工作。正确的做法是什么?请帮忙。
label_files = []
label_files.append(LABEL_FILE)
print "label_files: ", label_files
filename_queue = tf.train.string_input_producer(label_files)
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
print "key:", key, ", value:", value
record_defaults = [['default_id'], ['default_label']]
col1, col2 = tf.decode_csv(value, record_defaults=record_defaults)
num_lines = sum(1 for line in open(LABEL_FILE))
label_batch = …Run Code Online (Sandbox Code Playgroud)