我正在尝试在 PyTorch 中运行以下代码：
import numpy as np
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
class H5Dataset(data.Dataset):
    """Draw class-balanced random batches from a pair of HDF5 files.

    ``trainx_path`` must contain a dataset named ``'X'`` (features) and
    ``trainy_path`` a dataset named ``'y'`` (targets).  Sample ``i`` is
    ``X[i]``; the targets are reshaped to ``(len(X), -1)`` so that row ``i``
    holds the label value(s) of sample ``i``.

    NOTE(review): ``__getitem__`` interprets its argument as a *batch size*,
    not a sample index -- ``ds[k]`` returns ``k`` randomly chosen samples,
    half from the first half of the file and half from the second.  This
    differs from the usual map-style ``torch.utils.data.Dataset`` contract
    (one sample per index); a ``DataLoader`` with automatic batching would
    call it with sample indices ``0..len-1``.  Confirm this is intended.

    NOTE(review): ``h5py`` is used here but is not among the visible module
    imports; make sure ``import h5py`` exists at module level.
    """

    def __init__(self, trainx_path, trainy_path):
        super().__init__()
        # Open read-only: h5py < 3.0 defaulted to append mode, which can
        # silently create an empty file when the path is wrong.
        self._x_file = h5py.File(trainx_path, 'r')
        self._y_file = h5py.File(trainy_path, 'r')
        # Index (rather than .get) so a missing dataset fails here with a
        # KeyError instead of leaving ``None`` to blow up later.
        self.data = self._x_file['X']
        self.target = self._y_file['y']
        # Targets are materialized once, on first use (see _label_rows).
        self._labels = None

    def _label_rows(self):
        """Return the targets reshaped to ``(len(self.data), -1)``, cached."""
        labels = getattr(self, '_labels', None)
        if labels is None:
            labels = np.asarray(self.target).reshape(len(self.data), -1)
            self._labels = labels
        return labels

    def __getitem__(self, size):
        """Return a random, class-balanced batch of ``size`` samples.

        The first ``size // 2`` samples are drawn without replacement from
        the first half of the file and the remaining ``size - size // 2``
        from the second half (presumably one class per half -- confirm
        against the data).  For odd sizes the extra sample comes from the
        second half.

        Args:
            size: Number of samples in the batch, ``0 <= size <= len(self)``.

        Returns:
            ``(features, labels)`` as float tensors.  ``features`` stacks
            ``X[i]`` for the chosen indices; ``labels`` has shape
            ``(size, k)`` where ``k`` is the number of target values per
            sample.

        Raises:
            IndexError: if ``size`` is negative or exceeds ``len(self)``.
        """
        n_total = len(self.data)
        if not 0 <= size <= n_total:
            raise IndexError(
                'batch size {} out of range for {} samples'.format(
                    size, n_total))
        half = n_total // 2
        n_first = size // 2
        # Give the remainder to the second half so odd sizes still yield
        # exactly ``size`` samples.
        n_second = size - n_first
        first = np.random.permutation(half)[:n_first]
        second = half + np.random.permutation(n_total - half)[:n_second]
        index = np.concatenate([first, second])

        train_labels = self._label_rows()[index]
        # Read rows one at a time to keep the shuffled order (h5py list
        # indexing expects increasing indices).
        train_batch = np.array([self.data[i] for i in index.tolist()])
        return (torch.from_numpy(train_batch).float(),
                torch.from_numpy(train_labels).float())

    def __len__(self):
        """Return the number of samples (first-axis length of ``X``)."""
        return len(self.data)

    def close(self):
        """Close the HDF5 files opened by ``__init__``, if any."""
        for handle in (getattr(self, '_x_file', None),
                       getattr(self, '_y_file', None)):
            if handle is not None:
                handle.close()
dataset = H5Dataset('//content//drive//My …Run Code Online (Sandbox Code Playgroud)