I want to save a tensor in TFRecord format. I tried this:
import tensorflow as tf
x = tf.constant([[2.0, 3.0, 3.0],
                 [1.0, 5.0, 9.0]], dtype='float32')
x2 = tf.io.serialize_tensor(x)
# I understand that I can parse it using this:
# x3 = tf.io.parse_tensor(x2, 'float32')
record_file = 'temp.tfrecord'
with tf.io.TFRecordWriter(record_file) as writer:
    writer.write(x2)
It gave me this error:
TypeError: write(): incompatible function arguments. The following argument types are supported:
1. (self: tensorflow.python._pywrap_record_io.RecordWriter, record: str) -> None
I know this may be a basic question, but I read the guide on the TensorFlow website and searched StackOverflow without finding an answer.
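The problem is that you need to pass the actual value of the tensor x2, not the tensor object itself. tf.io.serialize_tensor returns a string tensor, whereas write() expects a bytes record, which .numpy() gives you: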
import tensorflow as tf
x = tf.constant([[2.0, 3.0, 3.0],
                 [1.0, 5.0, 9.0]], dtype='float32')
x2 = tf.io.serialize_tensor(x)
record_file = 'temp.tfrecord'
with tf.io.TFRecordWriter(record_file) as writer:
    # Get value with .numpy()
    writer.write(x2.numpy())
# Read from file
parse_tensor_f32 = lambda x: tf.io.parse_tensor(x, tf.float32)
ds = (tf.data.TFRecordDataset('temp.tfrecord')
      .map(parse_tensor_f32))
for x3 in ds:
    tf.print(x3)
# [[2 3 3]
#  [1 5 9]]
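To make the type mismatch concrete, here is a minimal sketch (assuming TensorFlow 2.x with eager execution) that inspects what tf.io.serialize_tensor returns and then writes more than one tensor to the same file. The second tensor, the temporary file location, and the names tensors, serialized and restored are made up for illustration and are not part of the question.

```python
import os
import tempfile

import tensorflow as tf

# Two example tensors; the second one is hypothetical data for illustration.
tensors = [
    tf.constant([[2.0, 3.0, 3.0], [1.0, 5.0, 9.0]], dtype=tf.float32),
    tf.constant([[0.5, 1.5, 2.5], [3.5, 4.5, 5.5]], dtype=tf.float32),
]

# serialize_tensor returns a scalar tf.Tensor of dtype tf.string,
# not a Python bytes object, which is why write(x2) is rejected.
serialized = tf.io.serialize_tensor(tensors[0])
print(serialized.dtype)
print(type(serialized.numpy()))

# Write one record per tensor, passing the bytes value each time.
record_file = os.path.join(tempfile.mkdtemp(), "demo.tfrecord")
with tf.io.TFRecordWriter(record_file) as writer:
    for t in tensors:
        writer.write(tf.io.serialize_tensor(t).numpy())

# Read the records back in eager mode and parse each one.
restored = [
    tf.io.parse_tensor(raw, out_type=tf.float32)
    for raw in tf.data.TFRecordDataset(record_file)
]
for t in restored:
    tf.print(t)
```

Each call to write() stores one record, so the records come back in the order they were written.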
Recent versions of TensorFlow also currently have an experimental tf.data.experimental.TFRecordWriter that can do this:
import tensorflow as tf
x = tf.constant([[2.0, 3.0, 3.0],
                 [1.0, 5.0, 9.0]], dtype='float32')
# Write to file
ds = (tf.data.Dataset.from_tensors(x)
      .map(tf.io.serialize_tensor))
writer = tf.data.experimental.TFRecordWriter('temp.tfrecord')
writer.write(ds)
# Read from file
parse_tensor_f32 = lambda x: tf.io.parse_tensor(x, tf.float32)
ds2 = (tf.data.TFRecordDataset('temp.tfrecord')
       .map(parse_tensor_f32))
for x2 in ds2:
    tf.print(x2)
# [[2 3 3]
#  [1 5 9]]
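One thing to be aware of when reading back through Dataset.map: tf.io.parse_tensor cannot infer a static shape inside the mapped function, so the dataset elements end up with an unknown shape. The sketch below, assuming TensorFlow 2.x and that every record holds a float32 tensor of shape (2, 3), shows one way to restore the shape with tf.ensure_shape. The helper parse_with_shape and the temporary file path are hypothetical names chosen for this example.

```python
import os
import tempfile

import tensorflow as tf

x = tf.constant([[2.0, 3.0, 3.0], [1.0, 5.0, 9.0]], dtype=tf.float32)

# Write a single record, using the .numpy() approach from the answer.
record_file = os.path.join(tempfile.mkdtemp(), "shape_demo.tfrecord")
with tf.io.TFRecordWriter(record_file) as writer:
    writer.write(tf.io.serialize_tensor(x).numpy())


def parse_with_shape(raw):
    # Hypothetical helper: parse the record, then assert the shape we
    # assume every record has, so downstream ops see a static shape.
    tensor = tf.io.parse_tensor(raw, out_type=tf.float32)
    return tf.ensure_shape(tensor, (2, 3))


ds_unknown = tf.data.TFRecordDataset(record_file).map(
    lambda raw: tf.io.parse_tensor(raw, out_type=tf.float32))
ds_shaped = tf.data.TFRecordDataset(record_file).map(parse_with_shape)

# Compare the element specs of the two datasets.
print(ds_unknown.element_spec)
print(ds_shaped.element_spec)
```

If the records do not all share one shape, a partially unknown shape such as (None, 3) can be passed to tf.ensure_shape instead.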