Iterating over a tf.Tensor inside a tf function to build a list of NamedTuple-based dataset items

pal*_*asb 6 python tensorflow tensorflow2.0

I have a dataset whose elements are defined with typing.NamedTuple types, and I load it into a tf.data.Dataset. Here is an example.

# You can run all the code in this question by pasting all
# the code blocks consecutively into a Python file

import tensorflow as tf
from typing import *
from random import *
from pprint import *

class Coord(NamedTuple):
    x: float
    y: float

    @classmethod
    def random(cls): return cls(gauss(10., 1.), gauss(10., 1.))

class Box(NamedTuple):
    min: Coord
    max: Coord

    @classmethod
    def random(cls): return cls(Coord.random(), Coord.random())

class Boxes(NamedTuple):
    boxes: List[Box]

    @classmethod
    def random(cls): return cls([Box.random() for _ in range(randint(3, 5))])

def test_dataset():
    for _ in range(randint(3, 5)): yield Boxes.random()

tf_dataset = tf.data.Dataset.from_generator(test_dataset, output_types=(tf.float32,))

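As you know, tf.data.Dataset.from_generator() converts each dataset element (originally of type Boxes) into a single-element tuple holding a tf.Tensor of shape (None, 2, 2). For example, one element of the dataset might be the following: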

(<tf.Tensor: shape=(4, 2, 2), dtype=float32, numpy=
array([[[11.642379,  9.937152],
        [ 8.998009,  8.387287]],

       [[10.649337, 10.028358],
        [ 8.507834,  9.84779 ]],

       [[11.10263 , 11.3706  ],
        [ 9.20623 , 10.44905 ]],

       [[ 9.591406,  9.560486],
        [ 9.461394,  9.256082]]], dtype=float32)>,)
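One way to see where the (None, 2, 2) shape comes from: NamedTuple instances are ordinary tuples, so a Boxes value is already a regular nested sequence. The sketch below uses only the standard library and redefines the classes without the random helpers. nested_shape is a helper made up for this illustration, not a TensorFlow function, and it assumes the nesting is regular (every box has the same layout).

```python
from typing import List, NamedTuple


class Coord(NamedTuple):
    x: float
    y: float


class Box(NamedTuple):
    min: Coord
    max: Coord


class Boxes(NamedTuple):
    boxes: List[Box]


def nested_shape(value) -> tuple:
    """Return the shape of a regular nested list/tuple structure."""
    shape = []
    while isinstance(value, (list, tuple)):
        shape.append(len(value))
        if not value:
            break  # an empty level has no first element to descend into
        value = value[0]
    return tuple(shape)


example = Boxes([
    Box(Coord(1.0, 2.0), Coord(3.0, 4.0)),
    Box(Coord(5.0, 6.0), Coord(7.0, 8.0)),
    Box(Coord(9.0, 10.0), Coord(11.0, 12.0)),
])

# The single `boxes` member nests as boxes x (min, max) x (x, y).
print(nested_shape(example.boxes))  # (3, 2, 2)
# The Boxes wrapper itself adds the outer one-element tuple.
print(nested_shape(example))        # (1, 3, 2, 2)
```

The leading dimension is the number of boxes, which varies per element, hence the None in the dataset's shape.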

I have regular Python functions, not annotated with @tf.function, that transform data of the original types, such as the following:

def flip_boxes(boxes: Boxes):
    def flip_coord(c: Coord): return Coord(-c.x, c.y)
    def flip_box(b: Box): return Box(flip_coord(b.min), flip_coord(b.max))
    return Boxes(boxes=list(map(flip_box, boxes.boxes)))

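I would like to apply this Python function (and others like it) to this tf.data.Dataset through tf.data.Dataset.map(map_func). Dataset.map expects map_func to be a function that receives the members of the dataset element type in their tf.Tensor form. The original element type is Boxes, which has a single member, originally boxes: List[Box]. When the dataset is created, that list is converted to the (4, 2, 2)-shaped tensor shown above. tf.data.Dataset.map() does not convert it back when calling map_func; the Tensor is passed directly as the first argument to map_func. (If Boxes had more members, they would be passed to map_func as separate arguments, not as a single tuple.)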
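To illustrate the argument-passing behavior described above, here is a minimal sketch that is separate from the Boxes example. The dataset `pairs` and the function `add` are made up for illustration, and the sketch assumes TensorFlow 2.x with eager execution enabled.

```python
import tensorflow as tf

# Hypothetical dataset whose elements are 2-tuples of scalar tensors:
# (1, 3) and (2, 4).
pairs = tf.data.Dataset.from_tensor_slices(([1, 2], [3, 4]))


def add(a, b):
    # map() passes the two tuple components as separate positional
    # arguments, each one a tf.Tensor.
    return a + b


sums = [int(v.numpy()) for v in pairs.map(add)]
print(sums)  # [4, 6]
```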

Question: What adapter function do I need to implement so that regular Python functions (such as flip_boxes) can be used with tf.data.Dataset.map()?

I tried iterating over the input tf.Tensor and using tf.split to recover a List[Box] from it, but I ran into the error messages listed as comments below.

# Question: How do I implement this function?
def to_tf_mappable_function(fn: Callable) -> Callable:

    def function(tensor: tf.Tensor):
        boxes: List[Box] = [Box(Coord(10.0, 0.0), Coord(10.0, 0.0)), Box(Coord(10.0, 0.0), Coord(10.0, 0.0))]
        # TODO calculate `boxes` from `tensor`, not use this dummy constant above

        # Trivial Python code does not work, it results in this error on the commented-out line:
        #   OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed:
        #   AutoGraph is disabled in this function. Try decorating it directly with @tf.function.
        # boxes = [Box(Coord(row[0][0], row[0][1]), Coord(row[1][0], row[1][1])) for row in tensor]
        # Decorating any of flip_boxes, to_tf_mappable_function and to_tf_mappable_function.function
        # does not eliminate the error.

        # I thought tf.split might help, but it results in this error on the commented-out line:
        #   ValueError: Rank-0 tensors are not supported as the num_or_size_splits argument to split.
        #   Argument provided: Tensor("cond/Identity:0", shape=(), dtype=int32)
        # boxes = tf.split(tensor, len(tensor))

        return fn(Boxes(boxes))

    return function

tf_dataset = tf_dataset.map(to_tf_mappable_function(flip_boxes))
# The line above should be morally equivalent to `dataset = map(flip_boxes, dataset)`,
# given a `dataset: Iterable[Boxes]` and the builtin `map` function in Python.

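Maybe I'm not asking the right question, but please cut me some slack.

* The high-level task is to apply flip_boxes-like functions to a tf.data.Dataset in an efficient way.
* Where I'm stuck is recovering a List[Box] from a tf.Tensor whose shape exactly matches a list of box coordinates, so maybe my question should be limited to that problem.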

Ale*_*rov 0

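I'm not sure whether you are looking for something more general, but for the exact problem you posed here, this seems to be one possible way to do it: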

# Helper function to translate from tensor back to Boxes type
def boxes_from_tensor(t: tf.Tensor) -> Boxes:
    n_boxes = t.shape[0]
    t = t.numpy()
    boxes = Boxes(boxes=[Box(Coord(t[i,0,0], t[i,0,1]), Coord(t[i,1,0], t[i,1,1])) for i in range(n_boxes)])
    return boxes

def to_tf_mappable_function(fn: Callable) -> Callable:
    def function(tensor: tf.Tensor):
        return tf.py_function(lambda t: fn(boxes_from_tensor(t)), [tensor], tensor.dtype)
    return function

tf_dataset = tf_dataset.map(to_tf_mappable_function(flip_boxes))
list(tf_dataset)
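If a transformation can be written with tensor operations, a possible alternative is to skip the NamedTuple round trip entirely. tf.py_function runs ordinary Python code eagerly, so functions wrapped in it do not benefit from graph execution. The sketch below is not part of the approach above: flip_boxes_tensor is a hypothetical name, and it assumes the (n, 2, 2) layout from the question, that is, boxes x (min, max) x (x, y).

```python
import tensorflow as tf


def flip_boxes_tensor(boxes: tf.Tensor) -> tf.Tensor:
    # boxes has shape (n, 2, 2). Broadcasting over the last axis
    # multiplies every x by -1 and every y by +1, which mirrors
    # Coord(-c.x, c.y) in flip_boxes.
    return boxes * tf.constant([-1.0, 1.0], dtype=boxes.dtype)


# A tiny stand-in dataset with one element holding a single box.
ds = tf.data.Dataset.from_tensors(tf.constant([[[1.0, 2.0], [3.0, 4.0]]]))
flipped = next(iter(ds.map(flip_boxes_tensor)))
print(flipped.numpy().tolist())  # [[[-1.0, 2.0], [-3.0, 4.0]]]
```

This only helps for functions that can be rewritten this way. For arbitrary Python code operating on Boxes, the tf.py_function adapter remains the more general option.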