pal*_*asb · 6 · tags: python, tensorflow, tensorflow2.0
I have typing.NamedTuple values in a tf.data.Dataset. Here is an example.
# You can run all the code in this question by pasting all
# the code blocks consecutively into a Python file
import tensorflow as tf
from typing import *
from random import *
from pprint import *

class Coord(NamedTuple):
    x: float
    y: float
    @classmethod
    def random(cls): return cls(gauss(10., 1.), gauss(10., 1.))

class Box(NamedTuple):
    min: Coord
    max: Coord
    @classmethod
    def random(cls): return cls(Coord.random(), Coord.random())

class Boxes(NamedTuple):
    boxes: List[Box]
    @classmethod
    def random(cls): return cls([Box.random() for _ in range(randint(3, 5))])

def test_dataset():
    for _ in range(randint(3, 5)): yield Boxes.random()

tf_dataset = tf.data.Dataset.from_generator(test_dataset, output_types=(tf.float32,))
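As you know, tf.data.Dataset.from_generator() converts the dataset elements (originally of type Boxes) into single-element tuples containing a tf.Tensor of shape (None, 2, 2). For example, one element of the dataset might be: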
(<tf.Tensor: shape=(4, 2, 2), dtype=float32, numpy=
array([[[11.642379,  9.937152],
        [ 8.998009,  8.387287]],

       [[10.649337, 10.028358],
        [ 8.507834,  9.84779 ]],

       [[11.10263 , 11.3706  ],
        [ 9.20623 , 10.44905 ]],

       [[ 9.591406,  9.560486],
        [ 9.461394,  9.256082]]], dtype=float32)>,)
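The nesting that produces this shape can be reproduced in plain Python. The sketch below uses a hypothetical `boxes_to_nested` helper (not part of TensorFlow) to lay a `Boxes` value out as a list of pairs of pairs, which is the structure that becomes one `(n, 2, 2)` tensor: axis 0 indexes the box, axis 1 selects `min`/`max`, and axis 2 selects `x`/`y`. `shape_of` is likewise an illustrative helper that assumes a regular (non-ragged) nested list.

```python
from typing import List, NamedTuple

class Coord(NamedTuple):
    x: float
    y: float

class Box(NamedTuple):
    min: Coord
    max: Coord

class Boxes(NamedTuple):
    boxes: List[Box]

def boxes_to_nested(boxes: Boxes) -> List[List[List[float]]]:
    # Axis 0: box index; axis 1: 0 = min, 1 = max; axis 2: 0 = x, 1 = y.
    return [[[b.min.x, b.min.y], [b.max.x, b.max.y]] for b in boxes.boxes]

def shape_of(nested) -> tuple:
    # Shape of a regular (non-ragged) nested list, analogous to a tensor shape.
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0] if nested else None
    return tuple(dims)

example = Boxes([Box(Coord(1.0, 2.0), Coord(3.0, 4.0)) for _ in range(4)])
nested = boxes_to_nested(example)
print(shape_of(nested))  # (4, 2, 2)
print(nested[0])         # [[1.0, 2.0], [3.0, 4.0]]
```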
I have regular Python functions, not annotated with @tf.function, that transform the data in its original types, such as the following:
def flip_boxes(boxes: Boxes):
    def flip_coord(c: Coord): return Coord(-c.x, c.y)
    def flip_box(b: Box): return Box(flip_coord(b.min), flip_coord(b.max))
    return Boxes(boxes=list(map(flip_box, boxes.boxes)))
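For this particular function, the effect on the `(n, 2, 2)` layout is simple: every `x` (index 0 on the last axis) is negated and every `y` is kept, i.e. an elementwise multiply by `[-1, 1]` broadcast over the last axis. The plain-Python sketch below checks that equivalence; `flip_nested` is a hypothetical helper for illustration only. In TensorFlow the same idea would be a single elementwise multiply on the whole tensor, which needs no Python callback, but it has to be written by hand for each function rather than derived from `flip_boxes`.

```python
from typing import List, NamedTuple

class Coord(NamedTuple):
    x: float
    y: float

class Box(NamedTuple):
    min: Coord
    max: Coord

class Boxes(NamedTuple):
    boxes: List[Box]

def flip_boxes(boxes: Boxes) -> Boxes:
    # Same logic as the question's flip_boxes: negate x, keep y.
    def flip_coord(c: Coord) -> Coord:
        return Coord(-c.x, c.y)
    def flip_box(b: Box) -> Box:
        return Box(flip_coord(b.min), flip_coord(b.max))
    return Boxes(boxes=list(map(flip_box, boxes.boxes)))

def flip_nested(nested: List[List[List[float]]]) -> List[List[List[float]]]:
    # Elementwise multiply by [-1, 1] along the last axis.
    sign = [-1.0, 1.0]
    return [[[v * s for v, s in zip(pair, sign)] for pair in box] for box in nested]

boxes = Boxes([Box(Coord(1.0, 2.0), Coord(3.0, 4.0)),
               Box(Coord(-5.0, 6.0), Coord(7.0, -8.0))])
flipped = flip_boxes(boxes)
print(flipped.boxes[0])
# Box(min=Coord(x=-1.0, y=2.0), max=Coord(x=-3.0, y=4.0))

# The vectorized form agrees with the per-object form on the nested layout.
as_nested = [[list(b.min), list(b.max)] for b in boxes.boxes]
assert flip_nested(as_nested) == [[list(b.min), list(b.max)] for b in flipped.boxes]
```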
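I want to apply this Python function (and other similar ones) to this tf.data.Dataset via tf.data.Dataset.map(map_func). Dataset.map expects map_func to be a function that takes the members of the dataset's element type in their tf.Tensor form. The original element type is Boxes, which has a single member, originally boxes: List[Box]. When the dataset is created, that list is converted into the (4, 2, 2)-shaped tensor above. When tf.data.Dataset.map() calls map_func, it does not pass the tuple itself; the Tensor is passed directly as the first argument to map_func. (If Boxes had more members, they would be passed to map_func as separate arguments, not as a single tuple.)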
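The calling convention described above can be modelled in a few lines of plain Python. This is a simplified sketch of the behaviour, not TensorFlow's actual implementation: `call_like_dataset_map` is a hypothetical name, and it only mimics how a plain-tuple element is spread into positional arguments while any other element is passed as a single argument.

```python
from typing import Any, Callable

def call_like_dataset_map(map_func: Callable[..., Any], element: Any) -> Any:
    # Simplified model: a plain tuple element is unpacked into positional
    # arguments; anything else is passed as the single argument.
    if type(element) is tuple:
        return map_func(*element)
    return map_func(element)

# An element built with output_types=(tf.float32,) is a 1-tuple, so map_func
# receives the tensor itself (a nested list stands in for it here).
one_member = ([[[1.0, 2.0], [3.0, 4.0]]],)
print(call_like_dataset_map(lambda t: len(t), one_member))  # 1

# With two members, map_func would need two parameters.
two_members = ("boxes", "labels")
print(call_like_dataset_map(lambda a, b: (b, a), two_members))  # ('labels', 'boxes')
```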
Question: what adapter function do I need to implement so that regular Python functions (like flip_boxes) can be used with tf.data.Dataset.map()?
I tried iterating over the input tf.Tensor, and using tf.split, to recover a List[Box] from it, but I ran into the error messages listed as comments below.
# Question: How do I implement this function?
def to_tf_mappable_function(fn: Callable) -> Callable:
    def function(tensor: tf.Tensor):
        boxes: List[Box] = [Box(Coord(10.0, 0.0), Coord(10.0, 0.0)), Box(Coord(10.0, 0.0), Coord(10.0, 0.0))]
        # TODO calculate `boxes` from `tensor`, not use this dummy constant above
        # Trivial Python code does not work, it results in this error on the commented-out line:
        # OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed:
        # AutoGraph is disabled in this function. Try decorating it directly with @tf.function.
        # boxes = [Box(Coord(row[0][0], row[0][1]), Coord(row[1][0], row[1][1])) for row in tensor]
        # Decorating any of flip_boxes, to_tf_mappable_function and to_tf_mappable_function.function
        # does not eliminate the error.
        # I thought tf.split might help, but it results in this error on the commented-out line:
        # ValueError: Rank-0 tensors are not supported as the num_or_size_splits argument to split.
        # Argument provided: Tensor("cond/Identity:0", shape=(), dtype=int32)
        # boxes = tf.split(tensor, len(tensor))
        return fn(Boxes(boxes))
    return function

tf_dataset = tf_dataset.map(to_tf_mappable_function(flip_boxes))
# The line above should be morally equivalent to `dataset = map(flip_boxes, dataset)`,
# given a `dataset: Iterable[Boxes]` and the builtin `map` function in Python.
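Maybe I'm not asking the right question, but please bear with me.
* The high-level task is to apply functions like flip_boxes to a tf.data.Dataset efficiently.
* Where I'm stuck is recovering a List[Box] from a tf.Tensor whose shape mirrors the list of box coordinates exactly, so maybe my question should be narrowed to just that.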
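Once the tensor's values are available as Python data (for example via `t.numpy().tolist()` inside an eagerly executed function), rebuilding the `List[Box]` is just indexing along the three axes. The sketch below assumes a regular `(n, 2, 2)` nested list; `boxes_from_nested` and `boxes_to_nested` are hypothetical helper names, and the round trip shows that the two layouts carry the same information.

```python
from typing import List, NamedTuple

class Coord(NamedTuple):
    x: float
    y: float

class Box(NamedTuple):
    min: Coord
    max: Coord

class Boxes(NamedTuple):
    boxes: List[Box]

Nested = List[List[List[float]]]

def boxes_from_nested(nested: Nested) -> Boxes:
    # nested[i][0] is box i's min corner, nested[i][1] its max corner;
    # the last index picks x (0) or y (1).
    return Boxes([Box(Coord(*row[0]), Coord(*row[1])) for row in nested])

def boxes_to_nested(boxes: Boxes) -> Nested:
    # Inverse of boxes_from_nested.
    return [[[b.min.x, b.min.y], [b.max.x, b.max.y]] for b in boxes.boxes]

# Values copied from the first two boxes in the tensor printed above.
data = [[[11.642379, 9.937152], [8.998009, 8.387287]],
        [[10.649337, 10.028358], [8.507834, 9.84779]]]
recovered = boxes_from_nested(data)
print(recovered.boxes[0].min)  # Coord(x=11.642379, y=9.937152)
assert boxes_to_nested(recovered) == data
```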
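I'm not sure whether you're looking for something more general, but for the exact problem you posed here, this seems to be one possible way to do it: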
# Helper function to translate from tensor back to Boxes type
def boxes_from_tensor(t: tf.Tensor) -> Boxes:
    n_boxes = t.shape[0]
    t = t.numpy()
    boxes = Boxes(boxes=[Box(Coord(t[i,0,0], t[i,0,1]), Coord(t[i,1,0], t[i,1,1])) for i in range(n_boxes)])
    return boxes

def to_tf_mappable_function(fn: Callable) -> Callable:
    def function(tensor: tf.Tensor):
        # tf.py_function runs the lambda eagerly, so t.numpy() is available inside it
        return tf.py_function(lambda t: fn(boxes_from_tensor(t)), [tensor], tensor.dtype)
    return function

tf_dataset = tf_dataset.map(to_tf_mappable_function(flip_boxes))
list(tf_dataset)
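The function passed to `tf.py_function` above runs eagerly, so its body is ordinary Python: convert the incoming values to `Boxes`, call `fn`, and convert the result back to the `(n, 2, 2)` layout. The sketch below isolates that body in plain Python so it can be checked without TensorFlow; `make_nested_adapter` is a hypothetical name, and the nested list stands in for what `t.numpy().tolist()` would return. One caveat worth keeping in mind on the TensorFlow side: the output of `tf.py_function` has an unknown static shape, so later code that needs `(None, 2, 2)` may have to restore it with `set_shape`.

```python
from typing import Callable, List, NamedTuple

class Coord(NamedTuple):
    x: float
    y: float

class Box(NamedTuple):
    min: Coord
    max: Coord

class Boxes(NamedTuple):
    boxes: List[Box]

Nested = List[List[List[float]]]

def boxes_from_nested(nested: Nested) -> Boxes:
    return Boxes([Box(Coord(*row[0]), Coord(*row[1])) for row in nested])

def boxes_to_nested(boxes: Boxes) -> Nested:
    return [[[b.min.x, b.min.y], [b.max.x, b.max.y]] for b in boxes.boxes]

def make_nested_adapter(fn: Callable[[Boxes], Boxes]) -> Callable[[Nested], Nested]:
    # Wrap a Boxes -> Boxes function so it accepts and returns the (n, 2, 2)
    # nested layout; this mirrors the part that would run inside
    # tf.py_function after t.numpy().tolist().
    def adapted(nested: Nested) -> Nested:
        return boxes_to_nested(fn(boxes_from_nested(nested)))
    return adapted

def flip_boxes(boxes: Boxes) -> Boxes:
    # Same behaviour as the question's flip_boxes: negate x, keep y.
    return Boxes([Box(Coord(-b.min.x, b.min.y), Coord(-b.max.x, b.max.y))
                  for b in boxes.boxes])

adapted_flip = make_nested_adapter(flip_boxes)
print(adapted_flip([[[1.0, 2.0], [3.0, 4.0]]]))  # [[[-1.0, 2.0], [-3.0, 4.0]]]
```

Converting the result back explicitly, instead of handing a `Boxes` tuple to TensorFlow for conversion, makes the output layout visible in one place.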