我正在研究一种被提议用于视频分类的变压器模型。我的输入张量的形状为 [batch=16 ,channels=3 ,frames=16, H=224, W=224] ,为了在输入张量上应用补丁嵌入,它使用以下场景:
patch_dim = in_channels * patch_size ** 2
self.to_patch_embedding = nn.Sequential(
Rearrange('b t c (h p1) (w p2) -> b t (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
nn.Linear(patch_dim, dim), ***** (Root of the error)******
)
Run Code Online (Sandbox Code Playgroud)
我使用的参数如下:
patch_size =16
dim = 192
in_channels = 3
Run Code Online (Sandbox Code Playgroud)
不幸的是,我收到与代码中显示的行相对应的以下错误:
Exception has occured: RuntimeError
mat1 and mat2 shapes cannot be multiplied (9408x4096 and 768x192)
Run Code Online (Sandbox Code Playgroud)
我想了很多错误的原因,但我无法找出原因是什么。我该如何解决这个问题?