小编Alm*_*evi的帖子

如何从 DataLoader 获取样本的文件名?

我需要用我训练的卷积神经网络的数据测试结果编写一个文件。数据包括语音数据收集。文件格式需要为“文件名,预测”,但我很难提取文件名。我像这样加载数据:

import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

TEST_DATA_PATH = ...

trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

test_dataset = torchvision.datasets.MNIST(
    root=TEST_DATA_PATH,
    train=False,
    transform=trans,
    download=True
)

test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)
Run Code Online (Sandbox Code Playgroud)

我正在尝试按如下方式写入文件:

f = open("test_y", "w")
with torch.no_grad():
    for i, (images, labels) in enumerate(test_loader, 0):
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        file = os.listdir(TEST_DATA_PATH + "/all")[i]
        format = file + ", " + str(predicted.item()) + '\n'
        f.write(format)
f.close()
Run Code Online (Sandbox Code Playgroud)

问题os.listdir(TESTH_DATA_PATH + "/all")[i]在于它与加载的文件顺序不同步test_loader。我能做什么?

python machine-learning pytorch torchvision

5
推荐指数
1
解决办法
1万
查看次数

标签 统计

machine-learning ×1

python ×1

pytorch ×1

torchvision ×1