Tensorflow get_single_element 无法与 tf.data.TFRecordDataset.batch() 配合使用

tak*_*lex 9 python-3.x tensorflow tfrecord tensorflow-datasets

我正在尝试对 Tensorflow 数据集执行 ZCA 白化。为了做到这一点,我尝试从数据集中提取数据作为张量,执行白化,然后创建另一个数据集。

我参照此处的示例,把 TFRecordDataset 中的数据取出为 numpy 数组,只是省略了对张量求值的那一步。

get_single_element 抛出此错误:

Traceback (most recent call last):
  File "/Users/takeoffs/Code/takeoffs_ai/test_pipeline_local.py", line 239, in <module>
    validation_steps=val_steps, callbacks=callbacks)
  File "/Users/takeoffs/Code/takeoffs_ai/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 780, in fit
    steps_name='steps_per_epoch')
  File "/Users/takeoffs/Code/takeoffs_ai/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training_arrays.py", line 198, in model_iteration
    val_iterator = _get_iterator(val_inputs, model._distribution_strategy)
  File "/Users/takeoffs/Code/takeoffs_ai/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training_arrays.py", line 517, in _get_iterator
    return training_utils.get_iterator(inputs)
  File "/Users/takeoffs/Code/takeoffs_ai/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training_utils.py", line 1315, in get_iterator
    initialize_iterator(iterator)
  File "/Users/takeoffs/Code/takeoffs_ai/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training_utils.py", line 1322, in initialize_iterator
    K.get_session((init_op,)).run(init_op)
  File "/Users/takeoffs/Code/takeoffs_ai/venv/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 950, in run
    run_metadata_ptr)
  File "/Users/takeoffs/Code/takeoffs_ai/venv/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1173, in _run
    feed_dict_tensor, options, run_metadata)
  File "/Users/takeoffs/Code/takeoffs_ai/venv/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1350, in _do_run
    run_metadata)
  File "/Users/takeoffs/Code/takeoffs_ai/venv/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1370, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Dataset had more than one element.
     [[node DatasetToSingleElement_1 (defined at /test_pipeline_local.py:88) ]]
Run Code Online (Sandbox Code Playgroud)

奇怪的是,根据上面链接的帖子,batch() 应该返回一个包含单个元素的数据集。

这是我正在运行的代码。为了本地测试目的,我将批量大小硬编码为 20。

def _tfrec_ds(tfrec_path, restore_shape, dtype):
    """Builds a dataset of float32 tensors decoded from a tfrecord file.

    Every record is parsed as a serialized tensor of type ``dtype``,
    reshaped to ``restore_shape`` and finally cast to ``tf.float32``.

    Args:
        tfrec_path (str): Str for path to a tfrecord file
        restore_shape (tuple(int)): shape each parsed tensor is reshaped to
        dtype (TF type): element type the serialized tensors are parsed as

    Returns:
        ds: the mapped dataset
    """
    def _decode(serialized):
        # parse -> reshape -> cast, in that order
        tensor = tf.parse_tensor(serialized, out_type=dtype)
        return tf.cast(tf.reshape(tensor, restore_shape), tf.float32)

    records = tf.data.TFRecordDataset(tfrec_path)
    return records.map(_decode, num_parallel_calls=tf.contrib.data.AUTOTUNE)

def get_data_zip(in_dir,
                 num_samples_fname,
                 x_shape,
                 y_shape,
                 batch_size=5,
                 dtype=tf.float32,
                 X_fname="X.tfrec",
                 y_fname="y.tfrec",
                 augment=True):
    """Builds a shuffled, repeated, batched (X, y) dataset with ZCA-whitened X.

    Args:
        in_dir (str): directory holding the sample-count file and the
            tfrecord files; a trailing "/" is appended when missing
        num_samples_fname (str): name of a text file inside in_dir whose
            first line is parsed as an int
        x_shape (tuple(int)): shape of one X sample; its first three entries
            are multiplied together to flatten a sample for whitening
        y_shape (tuple(int)): shape of one y sample
        batch_size (int): batch size of the returned dataset
        dtype (TF type): element type the serialized tensors are parsed as
        X_fname (str): file name of the X tfrecord inside in_dir
        y_fname (str): file name of the y tfrecord inside in_dir
        augment (bool): not used by this function; kept so existing callers
            keep working

    Returns:
        (Xy, N): the zipped dataset and the sample count used for whitening
        (currently the hard-coded local-testing value, see below)
    """
    # Normalise the directory BEFORE any path is built from it; otherwise the
    # sample-count file is looked up without the separating "/".
    # An empty in_dir is left alone so paths stay relative.
    if in_dir and not in_dir.endswith("/"):
        in_dir += "/"
    #Get number of samples
    with FileIO(in_dir + num_samples_fname, "r") as f:
        N = int(f.readlines()[0])
    # NOTE: local-testing override. It discards the value read above; remove
    # this line to whiten every sample listed in the count file.
    N = 20

    def zca(x):
        '''Returns the batch x (N samples) with ZCA whitened pixels.'''
        num_features = x_shape[0] * x_shape[1] * x_shape[2]
        flat_x = tf.reshape(x, (N, num_features))
        # (features x features) second-moment matrix, averaged over N samples
        sigma = tf.matmul(flat_x, flat_x, transpose_a=True) / N
        # tf.linalg.svd returns (s, u, v): singular values come FIRST,
        # unlike numpy.linalg.svd which returns (u, s, vh).
        s, u, _ = tf.linalg.svd(sigma)
        s_inv = 1. / tf.math.sqrt(s + 1e-6)
        # u * s_inv scales column j of u by s_inv[j], i.e. u @ diag(s_inv);
        # the whitening matrix is then u @ diag(s_inv) @ u^T.
        principal_components = tf.matmul(u * s_inv, u, transpose_b=True)
        # Whitening is a matrix product, not an element-wise multiply.
        whitex = tf.matmul(flat_x, principal_components)
        batch_shape = [N] + list(x_shape)
        return tf.reshape(whitex, batch_shape)

    #Load in TFRecordDatasets
    X_path = in_dir + X_fname
    y_path = in_dir + y_fname
    X = _tfrec_ds(X_path, x_shape, dtype)
    y = _tfrec_ds(y_path, y_shape, dtype)
    buffer_size = 500
    shuffle_seed = 8
    #Perform ZCA whitening
    # get_single_element requires a dataset with exactly one element.
    # batch(N) alone yields several batches when the file holds more than N
    # records ("Dataset had more than one element"), so cap it with take(N).
    dataset = X.take(N).batch(N)
    whole_dataset_tensors = tf.data.experimental.get_single_element(dataset)
    whole_dataset_tensors = zca(whole_dataset_tensors)
    X = tf.data.Dataset.from_tensor_slices(whole_dataset_tensors)
    #Shuffle, repeat and batch
    # zip stops at the shorter input, so y is implicitly limited to N as well
    Xy = tf.data.Dataset.zip((X, y))
    Xy = Xy.apply(tf.data.experimental.shuffle_and_repeat(buffer_size=buffer_size, seed=shuffle_seed))\
        .batch(batch_size).prefetch(tf.contrib.data.AUTOTUNE)
    return Xy, N
Run Code Online (Sandbox Code Playgroud)