Pytorch nn.function.interpolate 使用哪些信息?

Moh*_*mba 6 python interpolation pytorch

img在 PyTorch 中有一个大小的张量bx2xhxw,想使用 对其进行上采样torch.nn.functional.interpolate。但是,在插值时,我不希望通道 1 使用通道 2 中的信息。为此,我应该这样做,

img2 = torch.rand(b,2,2*h,2*w) # create a random torch tensor.
img2[:,0,:,:] = nn.functional.interpolate(img[:,0,:,:], [2*h,2*w], mode='bilinear', align_corners=True)
img2[:,1,:,:] = nn.functional.interpolate(img[:,1,:,:], [2*h,2*w], mode='bilinear', align_corners=True)
img=img2
Run Code Online (Sandbox Code Playgroud)

或者简单地使用

img = nn.functional.interpolate(img, [2*h,2*w], mode='bilinear', align_corners=True)
Run Code Online (Sandbox Code Playgroud)

将解决我的目的。

hkc*_*rex 5

你应该使用(2)。对于所有类型的插值(1D、2D、3D),在第一维和第二维(分别是批次和通道)中没有通信,而它们应该是这样。

简单的例子:

import torch
import torch.nn.functional as F

b = 2
c = 4
h = w = 8

a = torch.randn((b, c, h, w))
a_upsample = F.interpolate(a, [h*2, w*2], mode='bilinear', align_corners=True)

a_mod = a.clone()
a_mod[:, 0] *= 1000
a_mod_upsample = F.interpolate(a_mod, [h*2, w*2], mode='bilinear', align_corners=True)

print(torch.isclose(a_upsample[:,0], a_mod_upsample[:,0]).all())
print(torch.isclose(a_upsample[:,1], a_mod_upsample[:,1]).all())
print(torch.isclose(a_upsample[:,2], a_mod_upsample[:,2]).all())
print(torch.isclose(a_upsample[:,3], a_mod_upsample[:,3]).all())
Run Code Online (Sandbox Code Playgroud)

输出:

tensor(False)
tensor(True)
tensor(True)
tensor(True)
Run Code Online (Sandbox Code Playgroud)

可以看出,第一个通道中的较大变化对其他通道没有影响。