小编Jon*_*n R的帖子

使用 tensorflow image_dataset_from_directory 时从数据集中获取标签

我在 python (v3.8.3) 中使用 tensorflow (v2.4) + keras 编写了一个简单的 CNN。我正在尝试优化网络,我想要更多关于它无法预测的信息。我正在尝试添加一个混淆矩阵,我需要提供 tensorflow.math.confusion_matrix() 测试标签。

我的问题是我无法弄清楚如何从 tf.keras.preprocessing.image_dataset_from_directory() 创建的数据集对象访问标签

我的图像组织在以标签为名称的目录中。文档说该函数返回一个 tf.data.Dataset 对象。

If label_mode is None, it yields float32 tensors of shape (batch_size, image_size[0], image_size[1], num_channels), encoding
Run Code Online (Sandbox Code Playgroud)

图像(有关 num_channels 的规则,请参见下文)。否则,它会生成一个元组(图像、标签),其中图像具有形状(batch_size、image_size[0]、image_size[1]、num_channels),标签遵循下面描述的格式。

这是代码:

import tensorflow as tf
from tensorflow.keras import layers
#import matplotlib.pyplot as plt
import numpy as np
import random

import PIL
import PIL.Image

import os
import pathlib

#load the IMAGES
dataDirectory = '/p/home/username/tensorflow/newBirds'

dataDirectory = pathlib.Path(dataDirectory)
imageCount = len(list(dataDirectory.glob('*/*.jpg')))
print('Image count: {0}\n'.format(imageCount))

#test display …
Run Code Online (Sandbox Code Playgroud)

python keras tensorflow

7
推荐指数
2
解决办法
6296
查看次数

标签 统计

keras ×1

python ×1

tensorflow ×1