了解如何使用 tf.dataset.map()

ash*_*213 6 python tensorflow tensorflow-datasets

我正在将一些最初使用 JPEG 作为输入的代码转换为使用 Matlab MAT 文件。该代码包含以下行:

train_dataset = tf.data.Dataset.list_files(PATH + 'train/*.mat')
train_dataset = train_dataset.shuffle(BUFFER_SIZE) 
train_dataset = train_dataset.map(load_image_train)
Run Code Online (Sandbox Code Playgroud)

如果我循环遍历数据集并在 map() 之前 print() 每个元素,我会得到一组文件路径可见的张量。

然而,在 load_image_train 函数中,情况并非如此, print() 的输出是:

Tensor("add:0", shape=(), dtype=string)

我想使用 scipy.io.loadmat() 函数从我的 mat 文件中获取数据,但它失败了,因为路径是张量而不是字符串。dataset.map() 做了什么似乎使文字字符串值不再可见?如何提取字符串以便将其用作 scipy.io.loadmat() 的输入?

如果这是一个愚蠢的问题,对于 Tensorflow 来说相对较新并且仍在尝试理解,我深表歉意。我能找到的很多相关问题的讨论仅适用于 TF v1。感谢您的任何帮助!

小智 7

在下面的代码中,我用来tf.data.Dataset.list_files读取图像的文件路径。在函数中map,我加载图像并执行crop_central(基本上按照给定的百分比裁剪图像的中心部分,这里我已通过 指定了百分比np.random.uniform(0.50, 1.00))。

正如您正确地提到的,读取文件很困难,因为文件路径是类型的tf.string,并且load_img读取图像文件的或任何其他函数都需要简单的string类型。

所以你可以这样做 -

  1. 您需要用以下内容来装饰您的地图功能您可以在此处tf.py_function(load_file_and_process, [x], [tf.float32]).找到有关它的更多信息。
  2. 您可以使用string来检索。tf.stringbytes.decode(path.numpy()

下面是完整的代码供大家参考。运行此代码时,您可以将其替换为图像路径。

%tensorflow_version 2.x
import tensorflow as tf
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array, array_to_img
from matplotlib import pyplot as plt
import numpy as np

def load_file_and_process(path):
    image = load_img(bytes.decode(path.numpy()), target_size=(224, 224))
    image = img_to_array(image)
    image = tf.image.central_crop(image, np.random.uniform(0.50, 1.00))
    return image

train_dataset = tf.data.Dataset.list_files('/content/bird.jpg')
train_dataset = train_dataset.map(lambda x: tf.py_function(load_file_and_process, [x], [tf.float32]))

for f in train_dataset:
  for l in f:
    image = np.array(array_to_img(l))
    plt.imshow(image)
Run Code Online (Sandbox Code Playgroud)

输出 -

在此输入图像描述

希望这能回答您的问题。快乐学习。