如何使用 TensorFlow 2 数据集从 URL 加载图像

Mar*_*cel 2 python image-processing conv-neural-network tensorflow2.0

我想使用 TensorFlow 2 数据集对象将图像提供给 CNN。我的图像位于 AWS S3 上,但我将在示例中使用来自 Wikipedia 的图像(问题是相同的)。

image_urls = [
    'https://upload.wikimedia.org/wikipedia/commons/6/60/Matterhorn_from_Domh%C3%BCtte_-_2.jpg',
    'https://upload.wikimedia.org/wikipedia/commons/6/6e/Matterhorn_from_Klein_Matterhorn.jpg',
]
dataset = tf.data.Dataset.from_tensor_slices(image_urls)

def read_image_from_url(url):
    img_array = None
    with urlopen(url) as request:
        img_array = np.asarray(bytearray(request.read()), dtype=np.uint8)
    img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  #as RGB image (cv2 is BGR by default)
Run Code Online (Sandbox Code Playgroud)

当我使用数据集的一个元素测试我的函数时,它可以工作:

url = next(iter(dataset)).numpy().decode('utf-8')
img = read_image_from_url(url)
plt.imshow(img)
Run Code Online (Sandbox Code Playgroud)

但是当我将我的函数映射到数据集以创建一个为图像提供服务的新数据集时,它失败了:

dataset_images = dataset.map(lambda x: read_image_from_url(x.numpy().decode('utf-8')))

AttributeError: in converted code:

    <ipython-input-6-e8eb89833196>:2 None  *
        map_func=lambda x: read_image_from_url(x.numpy().decode('utf-8')),

    AttributeError: 'Tensor' object has no attribute 'numpy'
Run Code Online (Sandbox Code Playgroud)

显然,当使用next或迭代时,数据集提供了不同的数据类型map。知道我该如何解决这个问题吗?

Fre*_*ode 9

嗯,这比它需要的要困难得多:

import tensorflow as tf
import numpy as np 
import cv2
from urllib.request import urlopen
import matplotlib.pyplot as plt
image_urls = [
    'https://upload.wikimedia.org/wikipedia/commons/6/60/Matterhorn_from_Domh%C3%BCtte_-_2.jpg',
    'https://upload.wikimedia.org/wikipedia/commons/6/6e/Matterhorn_from_Klein_Matterhorn.jpg',
]
dataset = tf.data.Dataset.from_tensor_slices(image_urls)

def get(url):
    with urlopen(str(url.numpy().decode("utf-8"))) as request:
        img_array = np.asarray(bytearray(request.read()), dtype=np.uint8)
    img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

def read_image_from_url(url):
    return tf.py_function(get, [url], tf.uint8)


dataset_images = dataset.map(lambda x: read_image_from_url(x))

for d in dataset_images:
  print(d)

Run Code Online (Sandbox Code Playgroud)

为什么第一个工作,然后在tf.Dataset?那么tf.Dataset在确定graph mode没有eager mode像第一个。图形模式更快,并且tf.Dataset针对速度进行了优化,因此很有意义。你不能.numpy()在图形模式下做,因为一切都应该在tensorflowops 中定义。py_func将一个 python 函数封装在一个tf.Operation在 中执行的函数中eager mode,这正是我们所需要的。

注意:我尝试过tf.keras.utils.get_file(),但遇到了与您在此处描述的类似问题。希望这可以帮助!