Kar*_*yar 5 tensorflow tensorflow-datasets
我正在尝试使用 TF.dataset.map 来移植这个旧代码,因为我收到了弃用警告。
从 TFRecord 文件读取一组自定义原型的旧代码:
record_iterator = tf.python_io.tf_record_iterator(path=filename)
for record in record_iterator:
example = MyProto()
example.ParseFromString(record)
Run Code Online (Sandbox Code Playgroud)
我正在尝试使用急切模式和地图,但出现此错误。
def parse_proto(string):
proto_object = MyProto()
proto_object.ParseFromString(string)
Run Code Online (Sandbox Code Playgroud)
dataset = tf.data.TFRecordDataset(dataset_paths)
parsed_protos = raw_tf_dataset.map(parse_proto)
Run Code Online (Sandbox Code Playgroud)
这段代码的工作原理:
for raw_record in raw_tf_dataset:
proto_object = MyProto()
proto_object.ParseFromString(raw_record.numpy())
Run Code Online (Sandbox Code Playgroud)
但地图给了我一个错误:
TypeError: a bytes-like object is required, not 'Tensor'
Run Code Online (Sandbox Code Playgroud)
使用参数映射的函数结果并将它们视为字符串的正确方法是什么?
小智 3
您需要从张量中提取字符串并在map函数中使用。以下是在代码中实现此目的的步骤。
tf.py_function(get_path, [x], [tf.float32])。您可以在此处找到有关 tf.py_function 的更多信息。在 中tf.py_function,第一个参数是函数的名称map,第二个参数是要传递给函数的元素map,最后一个参数是返回类型。bytes.decode(file_path.numpy())。所以修改你的程序如下,
parsed_protos = raw_tf_dataset.map(parse_proto)
Run Code Online (Sandbox Code Playgroud)
到
parsed_protos = raw_tf_dataset.map(lambda x: tf.py_function(parse_proto, [x], [function return type]))
Run Code Online (Sandbox Code Playgroud)
还修改parse_proto如下,
def parse_proto(string):
proto_object = MyProto()
proto_object.ParseFromString(string)
Run Code Online (Sandbox Code Playgroud)
到
def parse_proto(string):
proto_object = MyProto()
proto_object.ParseFromString(bytes.decode(string.numpy()))
Run Code Online (Sandbox Code Playgroud)
在下面的简单程序中,我们使用tf.data.Dataset.list_files读取图像的路径。接下来,在该map函数中,我们将使用该函数读取图像load_img,然后执行该tf.image.central_crop函数来裁剪图像的中心部分。
代码 -
%tensorflow_version 2.x
import tensorflow as tf
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array, array_to_img
from matplotlib import pyplot as plt
import numpy as np
def load_file_and_process(path):
image = load_img(bytes.decode(path.numpy()), target_size=(224, 224))
image = img_to_array(image)
image = tf.image.central_crop(image, np.random.uniform(0.50, 1.00))
return image
train_dataset = tf.data.Dataset.list_files('/content/bird.jpg')
train_dataset = train_dataset.map(lambda x: tf.py_function(load_file_and_process, [x], [tf.float32]))
for f in train_dataset:
for l in f:
image = np.array(array_to_img(l))
plt.imshow(image)
Run Code Online (Sandbox Code Playgroud)
输出 -
希望这能回答您的问题。快乐学习。
| 归档时间: |
|
| 查看次数: |
4769 次 |
| 最近记录: |