Pytorch 中的批量矩阵乘法 - 与输出维度的处理相混淆

Bil*_*Kid 7 python vectorization batch-processing matrix-multiplication pytorch

我有两个数组:

A
B
Run Code Online (Sandbox Code Playgroud)

数组A包含一批 RGB 图像,形状为:

[batch, Width, Height, 3]
Run Code Online (Sandbox Code Playgroud)

而 ArrayB包含对图像进行“类转换”操作所需的系数,其形状为:

[batch, 4, 4, 3]
Run Code Online (Sandbox Code Playgroud)

简单来说,对单个图像的运算是乘法,输出环境图(normalMap * Coefficients)。

我想要的输出应该保持形状:

[batch, Width, Height, 3]
Run Code Online (Sandbox Code Playgroud)

我尝试使用torch.bmm但失败了。这有可能吗?

pro*_*sti 5

我认为你需要计算 PyTorch 与

BxCxHxW : number of mini-batches, channels, height, width
Run Code Online (Sandbox Code Playgroud)

格式,也可以使用matmul,因为bmm适用于张量或 ndim/dim/rank =3。

我知道你可能会在网上找到这个,但无论如何:

batch1 = torch.randn(10, 3, 20, 10)
batch2 = torch.randn(10, 3, 10, 30)
res = torch.matmul(batch1, batch2)
res.size() # torch.Size([10, 3, 20, 30])
Run Code Online (Sandbox Code Playgroud)