Dataset API does not pass dimensionality information for its output tensor when using py_func

Jia*_*nbo 6 python tensorflow tensorflow-datasets

To reproduce my problem, try this first (mapping with py_func):

import tensorflow as tf
import numpy as np
def image_parser(image_name):
    a = np.array([1.0,2.0,3.0], dtype=np.float32)
    return a

images = [[1,2,3],[4,5,6]]
im_dataset = tf.data.Dataset.from_tensor_slices(images)
im_dataset = im_dataset.map(lambda image:tuple(tf.py_func(image_parser, [image], [tf.float32])), num_parallel_calls = 2)
im_dataset = im_dataset.prefetch(4)
iterator = im_dataset.make_initializable_iterator()
print(im_dataset.output_shapes)

It will give you (TensorShape(None),)

However, if you try this (using direct tensorflow mapping instead of py_func):

import tensorflow as tf
import numpy as np

def image_parser(image_name):
    return image_name

images = [[1,2,3],[4,5,6]]
im_dataset = tf.data.Dataset.from_tensor_slices(images)
im_dataset = im_dataset.map(image_parser)
im_dataset = im_dataset.prefetch(4)
iterator = im_dataset.make_initializable_iterator()
print(im_dataset.output_shapes)

It will give you the exact tensor dimension (3,)

Oli*_*rot 7

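This is a general problem with tf.py_func, because TensorFlow cannot infer the output shape by itself; see, for example, this answer.

If you need the shape, you can set it yourself by moving the tf.py_func call inside the parse function: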

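import numpy as np
import tensorflow as tf

def parser(x):
    # dtype must match the output type passed to tf.py_func (tf.float32)
    a = np.array([1.0, 2.0, 3.0], dtype=np.float32)
    y = tf.py_func(lambda: a, [], tf.float32)
    # py_func output has an unknown shape, so declare it explicitly
    y.set_shape((3,))
    return y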

dataset = tf.data.Dataset.range(10)
dataset = dataset.map(parser)
print(dataset.output_shapes)  # will correctly print (3,)
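
Applied to the pipeline from the question, the same idea might look like the sketch below. It assumes a TensorFlow version that exposes the `tf.compat.v1` namespace (1.13 or later, including 2.x), so it calls `tf.compat.v1.py_func` and `tf.compat.v1.data.get_output_shapes` rather than the older `tf.py_func` and `output_shapes` spellings. The wrapper name `parse_with_shape` is made up for illustration, and the shape `(3,)` is an assumption that must match what `image_parser` really returns.

```python
import numpy as np
import tensorflow as tf


def image_parser(image_name):
    # Stand-in for real image loading: always returns 3 float32 values.
    return np.array([1.0, 2.0, 3.0], dtype=np.float32)


def parse_with_shape(image):
    # Hypothetical wrapper: run the Python function, then declare the
    # static shape that TensorFlow cannot infer on its own.
    parsed = tf.compat.v1.py_func(image_parser, [image], tf.float32)
    parsed.set_shape((3,))  # assumption: image_parser always returns 3 values
    return parsed


images = [[1, 2, 3], [4, 5, 6]]
im_dataset = tf.data.Dataset.from_tensor_slices(images)
im_dataset = im_dataset.map(parse_with_shape, num_parallel_calls=2)
im_dataset = im_dataset.prefetch(4)

# The static shape is now known to the dataset.
output_shapes = tf.compat.v1.data.get_output_shapes(im_dataset)
print(output_shapes)
```

Note that `set_shape` only annotates the graph; it does not reshape data. If `image_parser` ever returns an array of a different shape, the mismatch will only surface when the dataset is iterated.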