How do I save a tensor to a TFRecord?

sam*_*uar 5 python tensorflow

I want to save a tensor in TFRecord format. I tried this:

import tensorflow as tf

x = tf.constant([[2.0, 3.0, 3.0],
                 [1.0, 5.0, 9.0]], dtype='float32')

x2 = tf.io.serialize_tensor(x)

# I understand that I can parse it using this:
# x3 = tf.io.parse_tensor(x2, 'float32')

record_file = 'temp.tfrecord'
with tf.io.TFRecordWriter(record_file) as writer:
    writer.write(x2)

It gives me this error:

TypeError: write(): incompatible function arguments. The following argument types are supported:
    1. (self: tensorflow.python._pywrap_record_io.RecordWriter, record: str) -> None

I know this is probably a basic question, but I read the guide on the TensorFlow website and searched StackOverflow without finding an answer.

jde*_*esa 5

The problem is that write() needs the actual value held by the tensor x2 (a bytes string), not the tensor object itself:

import tensorflow as tf

x = tf.constant([[2.0, 3.0, 3.0],
                 [1.0, 5.0, 9.0]], dtype='float32')
x2 = tf.io.serialize_tensor(x)
record_file = 'temp.tfrecord'
with tf.io.TFRecordWriter(record_file) as writer:
    # Get value with .numpy()
    writer.write(x2.numpy())
# Read from file
parse_tensor_f32 = lambda x: tf.io.parse_tensor(x, tf.float32)
ds = (tf.data.TFRecordDataset('temp.tfrecord')
      .map(parse_tensor_f32))
for x3 in ds:
    tf.print(x3)
# [[2 3 3]
#  [1 5 9]]
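The same pattern should extend to several tensors, since each write() call appends one record. The following is a minimal sketch, not part of the original answer; it assumes eager execution (TensorFlow 2.x), that all tensors share the float32 dtype so a single parse_tensor call can decode them, and it uses a hypothetical file name multi.tfrecord in a temporary directory.

```python
import os
import tempfile

import tensorflow as tf

# Illustrative tensors; the shapes may differ, the dtype is assumed shared.
tensors = [
    tf.constant([[2.0, 3.0, 3.0],
                 [1.0, 5.0, 9.0]], dtype=tf.float32),
    tf.constant([4.0, 6.0], dtype=tf.float32),
]

# Hypothetical output path, chosen only for this sketch.
record_file = os.path.join(tempfile.mkdtemp(), 'multi.tfrecord')

with tf.io.TFRecordWriter(record_file) as writer:
    for t in tensors:
        # serialize_tensor returns a string tensor; .numpy() yields the bytes
        # that write() expects.
        writer.write(tf.io.serialize_tensor(t).numpy())

# Read the records back in eager mode, one serialized tensor per record.
restored = [tf.io.parse_tensor(raw, tf.float32)
            for raw in tf.data.TFRecordDataset(record_file)]
```

The records come back in the order they were written, so restored lines up element by element with tensors.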

In recent versions of TensorFlow there is currently also an experimental tf.data.experimental.TFRecordWriter that can do this:

import tensorflow as tf

x = tf.constant([[2.0, 3.0, 3.0],
                 [1.0, 5.0, 9.0]], dtype='float32')
# Write to file
ds = (tf.data.Dataset.from_tensors(x)
      .map(tf.io.serialize_tensor))
writer = tf.data.experimental.TFRecordWriter('temp.tfrecord')
writer.write(ds)
# Read from file
parse_tensor_f32 = lambda x: tf.io.parse_tensor(x, tf.float32)
ds2 = (tf.data.TFRecordDataset('temp.tfrecord')
       .map(parse_tensor_f32))
for x2 in ds2:
    tf.print(x2)
# [[2 3 3]
#  [1 5 9]]
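One thing to be aware of when reading back through Dataset.map: tf.io.parse_tensor cannot infer the shape statically, so the dataset elements end up with an unknown shape. If the shape is known in advance, tf.ensure_shape can restore it. This is a sketch under assumptions, not part of the original answer: the helper name parse_fixed_shape, the file name shaped.tfrecord, and the fixed shape [2, 3] are all illustrative.

```python
import os
import tempfile

import tensorflow as tf

x = tf.constant([[2.0, 3.0, 3.0],
                 [1.0, 5.0, 9.0]], dtype=tf.float32)

# Hypothetical output path, chosen only for this sketch.
record_file = os.path.join(tempfile.mkdtemp(), 'shaped.tfrecord')
with tf.io.TFRecordWriter(record_file) as writer:
    writer.write(tf.io.serialize_tensor(x).numpy())


def parse_fixed_shape(raw):
    """Hypothetical helper: decode one record and assert its assumed shape."""
    t = tf.io.parse_tensor(raw, tf.float32)
    # Raises at runtime if a record does not actually have shape [2, 3].
    return tf.ensure_shape(t, [2, 3])


ds = tf.data.TFRecordDataset(record_file).map(parse_fixed_shape)
first = next(iter(ds))
```

With the shape attached, ds.element_spec reports a static shape, which downstream steps such as batching or Keras layers may need.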