因此,我将按照有关自定义数据集的文档中的本教程进行操作。我使用的是 MNIST 数据集,而不是教程中的奇特数据集。这是我写的类的扩展Dataset:
class KaggleMNIST(Dataset):
def __init__(self, csv_file, transform=None):
self.pixel_frame = pd.read_csv(csv_file)
self.transform = transform
def __len__(self):
return len(self.pixel_frame)
def __getitem__(self, index):
if torch.is_tensor(index):
index = index.tolist()
image = self.pixel_frame.iloc[index, 1:]
image = np.array([image])
if self.transform:
image = self.transform(image)
return image
Run Code Online (Sandbox Code Playgroud)
它有效,直到我尝试对其使用转换:
tsf = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
trainset = KaggleMNIST('train/train.csv', transform=tsf)
image0 = trainset[0]
Run Code Online (Sandbox Code Playgroud)
我查看了堆栈跟踪,看起来规范化正在这行代码中发生:
c:\program files\python38\lib\site-packages\torchvision\transforms\functional.py in normalize(tensor, mean, std, inplace)
--> 218 tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
Run Code Online (Sandbox Code Playgroud)
所以我不明白为什么要除以零,因为std应该是 0.5,远远接近一个小值。
感谢您的帮助!
编辑:
这并没有回答我的问题,但我发现如果我更改这些代码行:
image = self.pixel_frame.iloc[index, 1:]
image = np.array([image])
Run Code Online (Sandbox Code Playgroud)
到
image = self.pixel_frame.iloc[index, 1:].to_numpy(dtype='float64').reshape(1, -1)
Run Code Online (Sandbox Code Playgroud)
本质上,确保数据类型float64解决了问题。我仍然不确定为什么这个问题首先存在,所以我仍然很高兴得到一个解释清楚的答案!
dtype读取的数据为int64
img = np.array([pixel_frame.iloc[0, 1:]])
img.dtype
# output
dtype('int64')
Run Code Online (Sandbox Code Playgroud)
这会强制将平均值和标准差转换为int640.5,当标准差为 0.5 时,它会变为 0,并引发以下错误:
>>> tsf(img)
ValueError: std evaluated to zero after conversion to torch.int64, leading to division by zero.
Run Code Online (Sandbox Code Playgroud)
这是因为平均值和标准差dtype在标准化期间转换为数据集。
def normalize(tensor, mean, std, inplace=False):
...
dtype = tensor.dtype
mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
if (std == 0).any():
raise ValueError('std evaluated to zero after conversion to {}, leading to division by zero.'.format(dtype))
Run Code Online (Sandbox Code Playgroud)
这就是为什么转换数据类型来float修复错误。
| 归档时间: |
|
| 查看次数: |
5067 次 |
| 最近记录: |