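I am trying to move our input pipeline to the TensorFlow Dataset API. To do so, we converted the images and labels to TFRecords. We then read the TFRecords back through the Dataset API and check whether the data we read matches the original data. So far so good. Below is the code that reads the TFRecords into the dataset: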
def _parse_function2(proto):
    # Define your tfrecord again. Remember that you saved your image as a string.
    keys_to_features = {
        "im_path": tf.FixedLenSequenceFeature([], tf.string, allow_missing=True),
        "im_shape": tf.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
        "score_shape": tf.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
        "geo_shape": tf.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
        "im_patches": tf.FixedLenSequenceFeature([], tf.string, allow_missing=True),
        "score_patches": tf.FixedLenSequenceFeature([], tf.string, allow_missing=True),
        "geo_patches": tf.FixedLenSequenceFeature([], tf.string, allow_missing=True),
    }
    # Load one example
    parsed_features = tf.parse_single_example(serialized=proto, features=keys_to_features)
    parsed_features['im_patches'] = parsed_features['im_patches'][0]
    parsed_features['score_patches'] = parsed_features['score_patches'][0]
    parsed_features['geo_patches'] = parsed_features['geo_patches'][0]
    parsed_features['im_patches'] = tf.decode_raw(parsed_features['im_patches'], tf.uint8)
    parsed_features['im_patches'] = tf.reshape(parsed_features['im_patches'], parsed_features['im_shape'])
    …
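For readers following along, here is a minimal, library-free sketch of the round trip the parse function relies on: an image is stored as a raw byte string plus a separate shape, and reading it back means decoding the bytes and reshaping them with that stored shape (the role played by `tf.decode_raw` and `tf.reshape` above). It does not use TensorFlow. The helper names `flatten_to_bytes` and `restore_from_bytes` and the tiny 2x2x3 sample image are assumptions made up for illustration, not part of the original pipeline.

```python
def flatten_to_bytes(image):
    """Flatten a rows x cols x channels nested list of uint8 values.

    Returns the raw bytes and the shape, which must be stored alongside
    the bytes because the byte string alone carries no dimensions.
    """
    shape = [len(image), len(image[0]), len(image[0][0])]
    flat = [value for row in image for pixel in row for value in pixel]
    return bytes(flat), shape


def restore_from_bytes(raw, shape):
    """Rebuild the nested list from raw bytes using the stored shape."""
    rows, cols, channels = shape
    if len(raw) != rows * cols * channels:
        raise ValueError("byte count does not match the stored shape")
    values = list(raw)  # iterating over bytes yields ints in 0..255
    return [
        [values[(r * cols + c) * channels:(r * cols + c + 1) * channels]
         for c in range(cols)]
        for r in range(rows)
    ]


# Hypothetical 2x2 image with 3 channels, used only for this sketch.
original = [[[0, 10, 20], [30, 40, 50]], [[60, 70, 80], [90, 100, 255]]]
raw, shape = flatten_to_bytes(original)
restored = restore_from_bytes(raw, shape)
print(shape, restored == original)
```

The equality check at the end is the same kind of comparison described in the question: the data read back should match the data that was written. If the stored shape and the byte count disagree, the sketch raises an error instead of silently returning a wrongly shaped result.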