我在 python (v3.8.3) 中使用 tensorflow (v2.4) + keras 编写了一个简单的 CNN。我正在尝试优化网络,我想要更多关于它无法预测的信息。我正在尝试添加一个混淆矩阵,我需要提供 tensorflow.math.confusion_matrix() 测试标签。
我的问题是我无法弄清楚如何从 tf.keras.preprocessing.image_dataset_from_directory() 创建的数据集对象访问标签
我的图像组织在以标签为名称的目录中。文档说该函数返回一个 tf.data.Dataset 对象。
Run Code Online (Sandbox Code Playgroud)If label_mode is None, it yields float32 tensors of shape (batch_size, image_size[0], image_size[1], num_channels), encoding图像(有关 num_channels 的规则,请参见下文)。否则,它会生成一个元组(图像、标签),其中图像具有形状(batch_size、image_size[0]、image_size[1]、num_channels),标签遵循下面描述的格式。
这是代码:
import tensorflow as tf
from tensorflow.keras import layers
#import matplotlib.pyplot as plt
import numpy as np
import random
import PIL
import PIL.Image
import os
import pathlib
#load the IMAGES
dataDirectory = '/p/home/username/tensorflow/newBirds'
dataDirectory = pathlib.Path(dataDirectory)
imageCount = len(list(dataDirectory.glob('*/*.jpg')))
print('Image count: {0}\n'.format(imageCount))
#test display …Run Code Online (Sandbox Code Playgroud)