ash*_*ids 5 deep-learning tensorflow
我在 TensorFlow 代码中使用tf.data.Dataset
for input_fn
。我需要分别读取所有频道,因为它们存储在不同的文件中。下面的代码显示了加载和处理功能。
def load_images_as_tensor(image_paths, dtype=np.uint8):
n_channels = 6
image_paths = image_paths
data = np.ndarray(shape=(512, 512, n_channels), dtype=dtype)
for ix, img_path in enumerate(image_paths):
data[:, :, ix] = load_image(img_path)
return(data)
def process(img, pixel_stats=GLOBAL_PIXEL_STATS, use_bfloat16 = True):
if pixel_stats is not None:
mean, std = pixel_stats
img = (tf.cast(img, tf.float32) - mean) / std
if use_bfloat16:
img = tf.image.convert_image_dtype(img, dtype=tf.bfloat16)
img = img.set_shape([512, 512, 6])
return(img)
def input_fn(params):
data = tf.data.Dataset.from_tensor_slices(tmp2)
data = data.map(lambda x: tf.py_func(load_images_as_tensor,[x], tf.uint8))
data = data.map(lambda x: process(x, GLOBAL_PIXEL_STATS, True))
return(data)
Run Code Online (Sandbox Code Playgroud)
由于使用了tf.py_func
,TensorFlow 无法获得输入的形状。在 github 上经历了一个 TensorFlow 问题后,我认为手动指定形状可以解决错误。我仍然收到错误
类型错误:传递给 Dataset.map() 的函数不支持返回值:无。
归档时间: |
|
查看次数: |
1172 次 |
最近记录: |