tf.contrib.data.Dataset does not seem to support SparseTensor

aur*_*oua 5 tensorflow tensorflow-datasets

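I generated Pascal VOC 2007 tfrecords files using the code in the TensorFlow Object Detection API, and I read the data back from the tfrecords with the tf.contrib.data.Dataset API. I first tried a method that does not use the tf.contrib.data.Dataset API, and that code runs fine, but it stops working once I switch to the tf.contrib.data.Dataset API.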

The code without tf.contrib.data.Dataset is as follows:

import tensorflow as tf

if __name__ == '__main__':
    slim_example_decoder = tf.contrib.slim.tfexample_decoder

    features = {"image/height": tf.FixedLenFeature((), tf.int64, default_value=1),
                "image/width": tf.FixedLenFeature((), tf.int64, default_value=1),
                "image/filename": tf.FixedLenFeature((), tf.string, default_value=""),
                "image/source_id": tf.FixedLenFeature((), tf.string, default_value=""),
                "image/key/sha256": tf.FixedLenFeature((), tf.string, default_value=""),
                "image/encoded": tf.FixedLenFeature((), tf.string, default_value=""),
                "image/format": tf.FixedLenFeature((), tf.string, default_value="jpeg"),
                "image/object/bbox/xmin": tf.VarLenFeature(tf.float32),
                "image/object/bbox/xmax": tf.VarLenFeature(tf.float32),
                "image/object/bbox/ymin": tf.VarLenFeature(tf.float32),
                "image/object/bbox/ymax": tf.VarLenFeature(tf.float32),
                "image/object/class/text": tf.VarLenFeature(tf.string),
                "image/object/class/label": tf.VarLenFeature(tf.int64),
                "image/object/difficult": tf.VarLenFeature(tf.int64),
                "image/object/truncated": tf.VarLenFeature(tf.int64),
                "image/object/view": tf.VarLenFeature(tf.int64)}
    items_to_handlers = {
        'image': slim_example_decoder.Image(
            image_key='image/encoded', format_key='image/format', channels=3),
        'source_id': (
            slim_example_decoder.Tensor('image/source_id')),
        'key': (
            slim_example_decoder.Tensor('image/key/sha256')),
        'filename': (
            slim_example_decoder.Tensor('image/filename')),
        # Object boxes and classes.
        'groundtruth_boxes': (
            slim_example_decoder.BoundingBox(
                ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/')),
        'groundtruth_classes': (
            slim_example_decoder.Tensor('image/object/class/label')),
        'groundtruth_difficult': (
            slim_example_decoder.Tensor('image/object/difficult')),
        'image/object/truncated': (
            slim_example_decoder.Tensor('image/object/truncated')),
        'image/object/view': (
            slim_example_decoder.Tensor('image/object/view')),
    }
    decoder = slim_example_decoder.TFExampleDecoder(features, items_to_handlers)
    keys = decoder.list_items()
    for example in tf.python_io.tf_record_iterator(
          "/home/aurora/workspaces/data/tfrecords_data/oxford_pet/pet_train.record"):
        serialized_example = tf.reshape(example, shape=[])
        tensors = decoder.decode(serialized_example, items=keys)
        tensor_dict = dict(zip(keys, tensors))
        tensor_dict['image'].set_shape([None, None, 3])
        print(tensor_dict)

The output of the code above is:

{'image': <tf.Tensor 'case/If_1/Merge:0' shape=(?, ?, 3) dtype=uint8>,
 'filename': <tf.Tensor 'Reshape_2:0' shape=() dtype=string>,
 'groundtruth_boxes': <tf.Tensor 'transpose:0' shape=(?, 4) dtype=float32>,
 'key': <tf.Tensor 'Reshape_5:0' shape=() dtype=string>,
 'image/object/truncated': <tf.Tensor 'SparseToDense:0' shape=(?,) dtype=int64>,
 'groundtruth_classes': <tf.Tensor 'SparseToDense_2:0' shape=(?,) dtype=int64>,
 'image/object/view': <tf.Tensor 'SparseToDense_1:0' shape=(?,) dtype=int64>,
 'source_id': <tf.Tensor 'Reshape_6:0' shape=() dtype=string>,
 'groundtruth_difficult': <tf.Tensor 'SparseToDense_3:0' shape=(?,) dtype=int64>}
...

The code with tf.contrib.data.Dataset:

import tensorflow as tf
from tensorflow.contrib.data import Iterator

slim_example_decoder = tf.contrib.slim.tfexample_decoder

flags = tf.app.flags
flags.DEFINE_string('data_dir',
  '/home/aurora/workspaces/data/tfrecords_data/voc_dataset/trainval.tfrecords',
  'tfrecords file output path')
flags.DEFINE_integer('batch_size', 32, 'training batch size')
flags.DEFINE_integer('capacity', 10000, 'shuffle buffer size')
FLAGS = flags.FLAGS

features = {"image/height": tf.FixedLenFeature((), tf.int64, default_value=1),
        "image/width": tf.FixedLenFeature((), tf.int64, default_value=1),
        "image/filename": tf.FixedLenFeature((), tf.string, default_value=""),
        "image/source_id": tf.FixedLenFeature((), tf.string, default_value=""),
        "image/key/sha256": tf.FixedLenFeature((), tf.string, default_value=""),
        "image/encoded": tf.FixedLenFeature((), tf.string, default_value=""),
        "image/format": tf.FixedLenFeature((), tf.string, default_value="jpeg"),
        "image/object/bbox/xmin": tf.VarLenFeature(tf.float32),
        "image/object/bbox/xmax": tf.VarLenFeature(tf.float32),
        "image/object/bbox/ymin": tf.VarLenFeature(tf.float32),
        "image/object/bbox/ymax": tf.VarLenFeature(tf.float32),
        "image/object/class/text": tf.VarLenFeature(tf.string),
        "image/object/class/label": tf.VarLenFeature(tf.int64),
        "image/object/difficult": tf.VarLenFeature(tf.int64),
        "image/object/truncated": tf.VarLenFeature(tf.int64),
        "image/object/view": tf.VarLenFeature(tf.int64)
      }

items_to_handlers = {
    'image': slim_example_decoder.Image(
        image_key='image/encoded', format_key='image/format', channels=3),
    'source_id': (
        slim_example_decoder.Tensor('image/source_id')),
    'key': (
        slim_example_decoder.Tensor('image/key/sha256')),
    'filename': (
        slim_example_decoder.Tensor('image/filename')),
    # Object boxes and classes.
    'groundtruth_boxes': (
        slim_example_decoder.BoundingBox(
            ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/')),
    'groundtruth_classes': (
        slim_example_decoder.Tensor('image/object/class/label')),
    'groundtruth_difficult': (
        slim_example_decoder.Tensor('image/object/difficult')),
    'image/object/truncated': (
        slim_example_decoder.Tensor('image/object/truncated')),
    'image/object/view': (
        slim_example_decoder.Tensor('image/object/view')),
    }
decoder = slim_example_decoder.TFExampleDecoder(features, items_to_handlers)
keys = decoder.list_items()


def _parse_function_train(example):
  serialized_example = tf.reshape(example, shape=[])
  tensors = decoder.decode(serialized_example, items=keys)
  tensor_dict = dict(zip(keys, tensors))
  tensor_dict['image'].set_shape([None, None, 3])
  print(tensor_dict)
  return tensor_dict


if __name__ == '__main__':
    train_dataset = tf.contrib.data.TFRecordDataset(FLAGS.data_dir)
    train_dataset = train_dataset.map(_parse_function_train)
    train_dataset = train_dataset.repeat(1)
    train_dataset = train_dataset.batch(FLAGS.batch_size)
    train_dataset = train_dataset.shuffle(buffer_size=FLAGS.capacity)
    iterator = Iterator.from_structure(train_dataset.output_types,
                                   train_dataset.output_shapes)
    next_element = iterator.get_next()
    training_init_op = iterator.make_initializer(train_dataset)

    sess = tf.Session()
    sess.run(training_init_op)
    counter = 0
    while True:
        try:
            sess.run(next_element)
            counter += 1
        except tf.errors.OutOfRangeError:
            print('End of training data in step %d' %counter)
            break

Running the code above reports the following error:

2017-10-09 23:41:43.488439: W tensorflow/core/framework/op_kernel.cc:1192]     Invalid argument: Name: <unknown>, Key: image/object/view, Index: 0.  Data types don't match. Expected type: int64
2017-10-09 23:41:43.488554: W tensorflow/core/framework/op_kernel.cc:1192] Invalid argument: Name: <unknown>, Key: image/object/view, Index: 0.  Data types don't match. Expected type: int64
 [[Node: ParseSingleExample/ParseExample/ParseExample = ParseExample[Ndense=7, Nsparse=9, Tdense=[DT_STRING, DT_STRING, DT_STRING, DT_INT64, DT_STRING, DT_STRING, DT_INT64], dense_shapes=[[], [], [], [], [], [], []], sparse_types=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_INT64, DT_STRING, DT_INT64, DT_INT64, DT_INT64]](ParseSingleExample/ExpandDims, ParseSingleExample/ParseExample/ParseExample/names, ParseSingleExample/ParseExample/ParseExample/sparse_keys_0, ParseSingleExample/ParseExample/ParseExample/sparse_keys_1, ParseSingleExample/ParseExample/ParseExample/sparse_keys_2, ParseSingleExample/ParseExample/ParseExample/sparse_keys_3, ParseSingleExample/ParseExample/ParseExample/sparse_keys_4, ParseSingleExample/ParseExample/ParseExample/sparse_keys_5, ParseSingleExample/ParseExample/ParseExample/sparse_keys_6, ParseSingleExample/ParseExample/ParseExample/sparse_keys_7, ParseSingleExample/ParseExample/ParseExample/sparse_keys_8, ParseSingleExample/ParseExample/ParseExample/dense_keys_0, ParseSingleExample/ParseExample/ParseExample/dense_keys_1, ParseSingleExample/ParseExample/ParseExample/dense_keys_2, ParseSingleExample/ParseExample/ParseExample/dense_keys_3, ParseSingleExample/ParseExample/ParseExample/dense_keys_4, ParseSingleExample/ParseExample/ParseExample/dense_keys_5, ParseSingleExample/ParseExample/ParseExample/dense_keys_6, ParseSingleExample/ParseExample/Reshape, ParseSingleExample/ParseExample/Reshape_1, ParseSingleExample/ParseExample/Reshape_2, ParseSingleExample/ParseExample/Reshape_3, ParseSingleExample/ParseExample/Reshape_4, ParseSingleExample/ParseExample/Reshape_5, ParseSingleExample/ParseExample/Reshape_6)]]
Traceback (most recent call last):
  File "/usr/software/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1327, in _do_call
return fn(*args)
  File "/usr/software/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1306, in _run_fn
status, run_metadata)
  File "/usr/software/anaconda3/lib/python3.5/contextlib.py", line 66, in __exit__
next(self.gen)
  File "/usr/software/anaconda3/lib/python3.5/site-packages/tensorflow/python/framework/errors_impl.py", line 466, in raise_exception_on_not_ok_status
pywrap_tensorflow.TF_GetCode(status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: Name: <unknown>, Key: image/object/view, Index: 0.  Data types don't match. Expected type: int64
 [[Node: ParseSingleExample/ParseExample/ParseExample = ParseExample[Ndense=7, Nsparse=9, Tdense=[DT_STRING, DT_STRING, DT_STRING, DT_INT64, DT_STRING, DT_STRING, DT_INT64], dense_shapes=[[], [], [], [], [], [], []], sparse_types=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_INT64, DT_STRING, DT_INT64, DT_INT64, DT_INT64]](ParseSingleExample/ExpandDims, ParseSingleExample/ParseExample/ParseExample/names, ParseSingleExample/ParseExample/ParseExample/sparse_keys_0, ParseSingleExample/ParseExample/ParseExample/sparse_keys_1, ParseSingleExample/ParseExample/ParseExample/sparse_keys_2, ParseSingleExample/ParseExample/ParseExample/sparse_keys_3, ParseSingleExample/ParseExample/ParseExample/sparse_keys_4, ParseSingleExample/ParseExample/ParseExample/sparse_keys_5, ParseSingleExample/ParseExample/ParseExample/sparse_keys_6, ParseSingleExample/ParseExample/ParseExample/sparse_keys_7, ParseSingleExample/ParseExample/ParseExample/sparse_keys_8, ParseSingleExample/ParseExample/ParseExample/dense_keys_0, ParseSingleExample/ParseExample/ParseExample/dense_keys_1, ParseSingleExample/ParseExample/ParseExample/dense_keys_2, ParseSingleExample/ParseExample/ParseExample/dense_keys_3, ParseSingleExample/ParseExample/ParseExample/dense_keys_4, ParseSingleExample/ParseExample/ParseExample/dense_keys_5, ParseSingleExample/ParseExample/ParseExample/dense_keys_6, ParseSingleExample/ParseExample/Reshape, ParseSingleExample/ParseExample/Reshape_1, ParseSingleExample/ParseExample/Reshape_2, ParseSingleExample/ParseExample/Reshape_3, ParseSingleExample/ParseExample/Reshape_4, ParseSingleExample/ParseExample/Reshape_5, ParseSingleExample/ParseExample/Reshape_6)]]
 [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?], [?,?,4], [?,?], [?,?], [?,?,?,3], [?,?], [?,?], [?], [?]], output_types=[DT_STRING, DT_FLOAT, DT_INT64, DT_INT64, DT_UINT8, DT_INT64, DT_INT64, DT_STRING, DT_STRING], _device="/job:localhost/replica:0/task:0/cpu:0"](Iterator)]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/aurora/workspaces/PycharmProjects/object_detection_models/datasets/voc_dataset/voc_tfrecords_decode_test.py", line 83, in <module>
sess.run(next_element)
  File "/usr/software/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 895, in run
run_metadata_ptr)
  File "/usr/software/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1124, in _run
feed_dict_tensor, options, run_metadata)
  File "/usr/software/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1321, in _do_run
options, run_metadata)
  File "/usr/software/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1340, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Name: <unknown>, Key: image/object/view, Index: 0.  Data types don't match. Expected type: int64
 [[Node: ParseSingleExample/ParseExample/ParseExample = ParseExample[Ndense=7, Nsparse=9, Tdense=[DT_STRING, DT_STRING, DT_STRING, DT_INT64, DT_STRING, DT_STRING, DT_INT64], dense_shapes=[[], [], [], [], [], [], []], sparse_types=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_INT64, DT_STRING, DT_INT64, DT_INT64, DT_INT64]](ParseSingleExample/ExpandDims, ParseSingleExample/ParseExample/ParseExample/names, ParseSingleExample/ParseExample/ParseExample/sparse_keys_0, ParseSingleExample/ParseExample/ParseExample/sparse_keys_1, ParseSingleExample/ParseExample/ParseExample/sparse_keys_2, ParseSingleExample/ParseExample/ParseExample/sparse_keys_3, ParseSingleExample/ParseExample/ParseExample/sparse_keys_4, ParseSingleExample/ParseExample/ParseExample/sparse_keys_5, ParseSingleExample/ParseExample/ParseExample/sparse_keys_6, ParseSingleExample/ParseExample/ParseExample/sparse_keys_7, ParseSingleExample/ParseExample/ParseExample/sparse_keys_8, ParseSingleExample/ParseExample/ParseExample/dense_keys_0, ParseSingleExample/ParseExample/ParseExample/dense_keys_1, ParseSingleExample/ParseExample/ParseExample/dense_keys_2, ParseSingleExample/ParseExample/ParseExample/dense_keys_3, ParseSingleExample/ParseExample/ParseExample/dense_keys_4, ParseSingleExample/ParseExample/ParseExample/dense_keys_5, ParseSingleExample/ParseExample/ParseExample/dense_keys_6, ParseSingleExample/ParseExample/Reshape, ParseSingleExample/ParseExample/Reshape_1, ParseSingleExample/ParseExample/Reshape_2, ParseSingleExample/ParseExample/Reshape_3, ParseSingleExample/ParseExample/Reshape_4, ParseSingleExample/ParseExample/Reshape_5, ParseSingleExample/ParseExample/Reshape_6)]]
 [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?], [?,?,4], [?,?], [?,?], [?,?,?,3], [?,?], [?,?], [?], [?]], output_types=[DT_STRING, DT_FLOAT, DT_INT64, DT_INT64, DT_UINT8, DT_INT64, DT_INT64, DT_STRING, DT_STRING], _device="/job:localhost/replica:0/task:0/cpu:0"](Iterator)]]

For the code that generates the tfrecords files, see create_pascal_tf_record.py.

mrr*_*rry 5

EDIT (2018/01/25): Support for tf.SparseTensor was added to tf.data in TensorFlow 1.5. The code in the question should work in TensorFlow 1.5 or later.


Up to TF 1.4, the tf.contrib.data API did not support tf.SparseTensor objects in dataset elements. There are a couple of workarounds:

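  1. (Harder, but probably faster.) If a tf.SparseTensor object st represents a variable-length list of features, you may be able to return st.values instead of st from the map() function. Note that you would then probably need to pad the results using Dataset.padded_batch() instead of Dataset.batch().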

  2. (Easier, but probably slower.) In the _parse_function_train() function, iterate over tensor_dict and produce a new version in which every tf.SparseTensor object has been converted to a tf.Tensor using tf.serialize_sparse():

    # Maps each sparse key to its original dtype, which is needed to
    # deserialize it later.
    # NOTE: You could probably infer these from `keys`.
    sparse_keys = {}

    def _parse_function_train(example):
      serialized_example = tf.reshape(example, shape=[])
      tensors = decoder.decode(serialized_example, items=keys)
      tensor_dict = dict(zip(keys, tensors))
      tensor_dict['image'].set_shape([None, None, 3])

      rewritten_tensor_dict = {}
      for key, value in tensor_dict.items():
        if isinstance(value, tf.SparseTensor):
          # Serialize the sparse tensor into a dense string tensor.
          rewritten_tensor_dict[key] = tf.serialize_sparse(value)
          sparse_keys[key] = value.dtype
        else:
          rewritten_tensor_dict[key] = value
      return rewritten_tensor_dict

    Then, after getting the next_element dictionary from iterator.get_next(), you can reverse this conversion using tf.deserialize_many_sparse():

    next_element = iterator.get_next()

    for key, dtype in sparse_keys.items():
      # Pass the serialized tensor itself (not the key) plus its original dtype.
      next_element[key] = tf.deserialize_many_sparse(next_element[key], dtype=dtype)
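
To make the first workaround (returning st.values and padding the batch) a little more concrete, here is a minimal pure-Python sketch of what padding a batch of variable-length value lists amounts to conceptually. It is only an illustration: pad_batch and the sample labels are hypothetical, and the code neither calls TensorFlow nor reproduces how Dataset.padded_batch() is implemented.

```python
def pad_batch(value_lists, pad_value=0):
    """Pad variable-length lists to the length of the longest one.

    Hypothetical helper for illustration only; not part of TensorFlow.
    """
    max_len = max((len(values) for values in value_lists), default=0)
    return [list(values) + [pad_value] * (max_len - len(values))
            for values in value_lists]


# Hypothetical per-image class labels, as they might look after taking
# `st.values`: each image contains a different number of objects.
labels_per_image = [[12, 7], [3], [5, 5, 9]]

batch = pad_batch(labels_per_image, pad_value=0)
for row in batch:
    print(row)
```

After padding, every row has the same length. Dataset.batch() needs all elements in a batch to have the same shape, which is why variable-length values call for Dataset.padded_batch() instead.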