如何沿着单个 pytorch 张量的维度连接?

Div*_*oML 0 numpy pytorch tensor

我编写了一个自定义的 pytorch Dataset,该__getitem__()函数返回一个 shape 的张量(250, 150),然后我用来DataLoader生成一批批大小为 10 的数据。我的意图是拥有一个 shape 的批,(2500, 150)作为这 10 个张量沿维度 0 的串联,但输出的DataLoader已有形状(10, 250, 150)。如何将 的输出转换为沿维度 0 串联的DataLoader形状?(2500, 150)

Aya*_*Das 5

PyTorch DataLoader 始终会在第 0 个索引处添加额外的批次维度。所以,如果你得到一个形状的张量(10, 250, 150),你可以简单地重塑它

# x is of shape (10, 250, 150)
x_ = x.view(-1, 150)
# x_ is of shape (2500, 150)
Run Code Online (Sandbox Code Playgroud)

或者,更正确地说,您可以为数据加载器提供自定义整理器

def custom_collate(batch):
    # each item in batch is (250, 150) as returned by __getitem__
    return torch.cat(batch, 0)

dl = DataLoader(dataset, batch_size=10, collate_fn=custom_collate, ...)
Run Code Online (Sandbox Code Playgroud)

这将在数​​据加载器本身中创建适当大小的张量,因此不需要使用 进行任何后处理.view()