Div*_*oML 0 numpy pytorch tensor
我编写了一个自定义的 pytorch Dataset
,该__getitem__()
函数返回一个 shape 的张量(250, 150)
,然后我用来DataLoader
生成一批批大小为 10 的数据。我的意图是拥有一个 shape 的批,(2500, 150)
作为这 10 个张量沿维度 0 的串联,但输出的DataLoader
已有形状(10, 250, 150)
。如何将 的输出转换为沿维度 0 串联的DataLoader
形状?(2500, 150)
PyTorch DataLoader 始终会在第 0 个索引处添加额外的批次维度。所以,如果你得到一个形状的张量(10, 250, 150)
,你可以简单地重塑它
# x is of shape (10, 250, 150)
x_ = x.view(-1, 150)
# x_ is of shape (2500, 150)
Run Code Online (Sandbox Code Playgroud)
或者,更正确地说,您可以为数据加载器提供自定义整理器
def custom_collate(batch):
# each item in batch is (250, 150) as returned by __getitem__
return torch.cat(batch, 0)
dl = DataLoader(dataset, batch_size=10, collate_fn=custom_collate, ...)
Run Code Online (Sandbox Code Playgroud)
这将在数据加载器本身中创建适当大小的张量,因此不需要使用 进行任何后处理.view()
。
归档时间: |
|
查看次数: |
1672 次 |
最近记录: |