tak*_*lex · Tags: python-3.x, tensorflow, tfrecord, tensorflow-datasets
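I'm trying to perform ZCA whitening on a TensorFlow dataset. To do this, I extract the data from the dataset as a tensor, perform the whitening, and then create another dataset from the result.
I followed the example here to get the dataset from a TFRecordDataset as a NumPy array, skipping the part where the tensor is evaluated.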
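For reference, ZCA whitening is usually written as follows. Here N flattened images are stacked as the rows of X (shape N × D, with D = H·W·C), and ε is a small constant for numerical stability (the code below uses 1e-6). Mean-centering is part of the standard definition.

```latex
\begin{aligned}
\bar{X} &= X - \mathbf{1}\mu^\top, \qquad \mu = \tfrac{1}{N} X^\top \mathbf{1} \\
\Sigma &= \tfrac{1}{N}\,\bar{X}^\top \bar{X} = U \operatorname{diag}(s)\, U^\top \\
W_{\mathrm{ZCA}} &= U \operatorname{diag}\!\left(\frac{1}{\sqrt{s + \epsilon}}\right) U^\top \\
X_{\mathrm{white}} &= \bar{X}\, W_{\mathrm{ZCA}}
\end{aligned}
```

Because W_ZCA is symmetric, the whitened covariance is U diag(s / (s + ε)) Uᵀ ≈ I. The output also stays aligned with the original pixel axes, which is what distinguishes ZCA from plain PCA whitening.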
get_single_element throws this error:
```
Traceback (most recent call last):
  File "/Users/takeoffs/Code/takeoffs_ai/test_pipeline_local.py", line 239, in <module>
    validation_steps=val_steps, callbacks=callbacks)
  File "/Users/takeoffs/Code/takeoffs_ai/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 780, in fit
    steps_name='steps_per_epoch')
  File "/Users/takeoffs/Code/takeoffs_ai/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training_arrays.py", line 198, in model_iteration
    val_iterator = _get_iterator(val_inputs, model._distribution_strategy)
  File "/Users/takeoffs/Code/takeoffs_ai/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training_arrays.py", line 517, in _get_iterator
    return training_utils.get_iterator(inputs)
  File "/Users/takeoffs/Code/takeoffs_ai/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training_utils.py", line 1315, in get_iterator
    initialize_iterator(iterator)
  File "/Users/takeoffs/Code/takeoffs_ai/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training_utils.py", line 1322, in initialize_iterator
    K.get_session((init_op,)).run(init_op)
  File "/Users/takeoffs/Code/takeoffs_ai/venv/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 950, in run
    run_metadata_ptr)
  File "/Users/takeoffs/Code/takeoffs_ai/venv/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1173, in _run
    feed_dict_tensor, options, run_metadata)
  File "/Users/takeoffs/Code/takeoffs_ai/venv/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1350, in _do_run
    run_metadata)
  File "/Users/takeoffs/Code/takeoffs_ai/venv/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1370, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Dataset had more than one element.
     [[node DatasetToSingleElement_1 (defined at /test_pipeline_local.py:88) ]]
```
Oddly, according to the post linked above, batch() should return a dataset containing a single element.
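Dataset.batch(n) groups consecutive elements into batches of n. The final batch is shorter unless drop_remainder=True. So the batched dataset has a single element only when n is at least the number of records. The pure-Python model below illustrates this. Note that batch and get_single_element here are hypothetical stand-ins written for illustration, not the TensorFlow functions, and the 57-record count is made up.

```python
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def batch(elements: Iterable[T], batch_size: int,
          drop_remainder: bool = False) -> Iterator[List[T]]:
    """Hypothetical model of Dataset.batch: group consecutive elements."""
    current: List[T] = []
    for item in elements:
        current.append(item)
        if len(current) == batch_size:
            yield current
            current = []
    # The final, shorter batch is kept unless drop_remainder is set
    if current and not drop_remainder:
        yield current


def get_single_element(dataset: Iterable[T]) -> T:
    """Hypothetical model of get_single_element: require exactly one element."""
    elements = list(dataset)
    if not elements:
        raise ValueError("Dataset was empty.")
    if len(elements) > 1:
        raise ValueError("Dataset had more than one element.")
    return elements[0]


records = range(57)  # hypothetical record count, larger than N = 20
print([len(b) for b in batch(records, 20)])         # [20, 20, 17]
print(len(get_single_element(batch(records, 57))))  # 57
```

Assume the TFRecord file holds more than 20 records. Then batching with N = 20 produces several elements, which matches the "Dataset had more than one element" error. Batching with the full record count, or first limiting the dataset to N records (for example with Dataset.take(N)), would leave exactly one element.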
Here is the code I'm running. For local testing, I've hardcoded the batch size (N) to 20.
```python
import tensorflow as tf
from tensorflow.python.lib.io.file_io import FileIO


def _tfrec_ds(tfrec_path, restore_shape, dtype):
    """Reads in a tf record dataset

    Args:
        tfrec_path (str): Str for path to a tfrecord file
        restore_shape (tuple(int)): shape to transform data to
        dtype (tf.DType): dtype of the serialized tensors

    Returns:
        ds: a dataset of float32 tensors
    """
    ds = tf.data.TFRecordDataset(tfrec_path)

    def parse(x):
        # parse_tensor expects a serialized TensorProto of the given dtype
        result = tf.parse_tensor(x, out_type=dtype)
        result = tf.reshape(result, restore_shape)
        result = tf.cast(result, tf.float32)
        return result

    ds = ds.map(parse, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return ds


def get_data_zip(in_dir,
                 num_samples_fname,
                 x_shape,
                 y_shape,
                 batch_size=5,
                 dtype=tf.float32,
                 X_fname="X.tfrec",
                 y_fname="y.tfrec",
                 augment=True):
    # Normalize the directory path before building any file paths
    if not in_dir.endswith("/"):
        in_dir += "/"

    # Get number of samples
    with FileIO(in_dir + num_samples_fname, "r") as f:
        N = int(f.readlines()[0])

    # Hardcoded for local testing; overrides the count read above
    N = 20

    def zca(x):
        """Returns the batch x (shape [N] + x_shape) with ZCA-whitened pixels."""
        flat_x = tf.reshape(x, (N, x_shape[0] * x_shape[1] * x_shape[2]))
        # ZCA assumes zero-mean features
        flat_x = flat_x - tf.reduce_mean(flat_x, axis=0)
        sigma = tf.matmul(flat_x, flat_x, transpose_a=True) / N
        # tf.linalg.svd returns (s, u, v), in that order
        s, u, _ = tf.linalg.svd(sigma)
        s_inv = 1. / tf.math.sqrt(s + 1e-6)
        # U diag(s_inv) U^T: scale the columns of u, then multiply by u^T
        principal_components = tf.matmul(u * s_inv, u, transpose_b=True)
        whitex = tf.matmul(flat_x, principal_components)
        batch_shape = [N] + list(x_shape)
        return tf.reshape(whitex, batch_shape)

    # Load in TFRecordDatasets
    X_path = in_dir + X_fname
    y_path = in_dir + y_fname
    X = _tfrec_ds(X_path, x_shape, dtype)
    y = _tfrec_ds(y_path, y_shape, dtype)

    buffer_size = 500
    shuffle_seed = 8

    # Perform ZCA whitening on the whole dataset as one batch
    dataset = X.batch(N)
    whole_dataset_tensors = tf.data.experimental.get_single_element(dataset)
    whole_dataset_tensors = zca(whole_dataset_tensors)
    X = tf.data.Dataset.from_tensor_slices(whole_dataset_tensors)

    # Shuffle, repeat and batch
    Xy = tf.data.Dataset.zip((X, y))
    Xy = (Xy.apply(tf.data.experimental.shuffle_and_repeat(buffer_size=buffer_size,
                                                           seed=shuffle_seed))
            .batch(batch_size)
            .prefetch(tf.data.experimental.AUTOTUNE))
    return Xy, N
```
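Below is a minimal NumPy sketch of the ZCA computation. It can serve as a reference to check a TensorFlow version against, assuming the whole batch fits in memory. The function name zca_whiten is hypothetical. The sketch centers the features, forms the covariance with a matrix product, and applies U diag(1/sqrt(s + eps)) Uᵀ with matrix multiplications rather than elementwise products.

```python
import numpy as np


def zca_whiten(x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """ZCA-whiten a batch of images of shape (N, H, W, C); returns the same shape."""
    n = x.shape[0]
    flat = x.reshape(n, -1).astype(np.float64)
    flat = flat - flat.mean(axis=0)            # ZCA assumes zero-mean features
    sigma = flat.T @ flat / n                  # (D, D) feature covariance
    u, s, _ = np.linalg.svd(sigma)             # symmetric PSD: sigma = U diag(s) U^T
    w = (u * (1.0 / np.sqrt(s + eps))) @ u.T   # scale columns of U, then rotate back
    return (flat @ w).reshape(x.shape)
```

np.linalg.svd returns (u, s, vh), while tf.linalg.svd returns (s, u, v), so the unpacking order differs between the two libraries. Also note that the covariance is D × D. For 32 × 32 × 3 images that is already a 3072 × 3072 matrix.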