在不复制内存的情况下重复 pytorch 张量

mcb*_*mcb 7 memory pytorch tensor

是否pytorch支持在不分配更多内存的情况下重复张量?

假设我们有一个张量

t = torch.ones((1,1000,1000))
t10 = t.repeat(10,1,1)
Run Code Online (Sandbox Code Playgroud)

重复t10 次需要占用 10 倍的内存。有没有办法在t10不分配更多内存的情况下创建张量?

是一个相关的问题,但没有答案。

jod*_*dag 8

您可以使用 torch.expand

t = torch.ones((1, 1000, 1000))
t10 = t.expand(10, 1000, 1000)
Run Code Online (Sandbox Code Playgroud)

请记住,t10只是对 的引用t。例如,对 的更改t10[0,0,0]将导致t[0,0,0]和 的每个成员发生相同的更改t10[:,0,0]

除了直接访问之外,对其执行的大多数操作t10都会导致内存被复制,这将破坏引用并导致使用更多内存。例如:更改设备(.cpu().to(device=...).cuda())、更改数据类型(.float().long().to(dtype=...))或使用.contiguous().