如何在序列特征上应用tf.map_fn?收到错误消息:TensorArray dtype是字符串,但是Op试图写dtype uint8

Sye*_*med 4 tensorflow

我正在编写一个序列到将视频映射到文本的序列模型。我在SequenceExample原型的序列功能中将视频的帧编码为JPEG字符串。构建输入管道时,我正在执行以下操作以获取解码的jpeg数组:

encoded_video, caption = parse_sequence_example(
                    serialized_sequence_example,
                    video_feature="video/frames",
                    caption_feature="video/caption_ids")
decoded_video = tf.map_fn(lambda x: tf.image.decode_jpeg(x, channels=3), encoded_video)
Run Code Online (Sandbox Code Playgroud)

但是,我收到以下错误:

InvalidArgumentError (see above for traceback): TensorArray dtype is string but Op is trying to write dtype uint8.
Run Code Online (Sandbox Code Playgroud)

我的目标是image = tf.image.convert_image_dtype(image, dtype=tf.float32)在解码后应用它,以使uint8的像素值介于[0,255]之间并在[0,1]之间浮动。

我尝试执行以下操作:

decoded_video = tf.map_fn(lambda x: tf.image.decode_jpeg(x, channels=3), encoded_video, dtype=tf.uint8)
converted_video = tf.map_fn(lambda x: tf.image.convert_image_dtype(x, dtype=tf.float32), decoded_video)
Run Code Online (Sandbox Code Playgroud)

但是,我仍然遇到相同的错误。任何人都不知道会有什么问题吗?提前致谢。

Sye*_*med 5

没关系。只需在以下行中显式添加tf.float32的dtype:

converted_video = tf.map_fn(lambda x: tf.image.convert_image_dtype(x, dtype=tf.float32), decoded_video, dtype=tf.float32)
Run Code Online (Sandbox Code Playgroud)