embedding_bag 在 PyTorch 中的具体工作原理

nil*_*rif 1 embedding python-embedding neural-network torch pytorch

在 PyTorch 中,torch.nn.function.embedding_bag 似乎是负责执行嵌入查找的实际工作的主要函数。在 PyTorch 的文档中,提到 embedding_bag 可以完成其工作 > 无需实例化中间嵌入。这究竟意味着什么?这是否意味着例如当模式为“sum”时它会进行就地求和?或者它只是意味着在调用 embedding_bag 时不会产生额外的张量,但仍然从系统的角度来看,所有中间行向量已经被提取到处理器中以用于计算最终张量?

小智 6

在最简单的情况下,torch.nn.functional.embedding_bag从概念上讲是一个两步过程。第一步是创建嵌入,第二步是减少(总和/平均值/最大值,根据“模式”参数)跨维度 0 的嵌入输出。因此,您可以通过调用 得到与 embedding_bag 给出的相同结果torch.nn.functional.embedding,其次是torch.sum/mean/max。在下面的示例中,embedding_bag_resembedding_mean_res是相等的。

>>> weight = torch.randn(3, 4)
>>> weight
tensor([[ 0.3987,  1.6173,  0.4912,  1.5001],
        [ 0.2418,  1.5810, -1.3191,  0.0081],
        [ 0.0931,  0.4102,  0.3003,  0.2288]])
>>> indices = torch.tensor([2, 1])
>>> embedding_res = torch.nn.functional.embedding(indices, weight)
>>> embedding_res
tensor([[ 0.0931,  0.4102,  0.3003,  0.2288],
        [ 0.2418,  1.5810, -1.3191,  0.0081]])
>>> embedding_mean_res = embedding_res.mean(dim=0, keepdim=True)
>>> embedding_mean_res
tensor([[ 0.1674,  0.9956, -0.5094,  0.1185]])
>>> embedding_bag_res = torch.nn.functional.embedding_bag(indices, weight, torch.tensor([0]), mode='mean')
>>> embedding_bag_res
tensor([[ 0.1674,  0.9956, -0.5094,  0.1185]])
Run Code Online (Sandbox Code Playgroud)

然而,概念性​​的两步过程并不能反映它的实际实施方式。由于embedding_bag不需要返回中间结果,因此它实际上不会生成用于嵌入的 Tensor 对象。它只是直接计算归约,weight根据input参数中的索引从参数中提取适当的数据。避免创建嵌入张量可以实现更好的性能。

所以你的问题的答案(如果我理解正确的话)

这只是意味着在调用 embedding_bag 时不会产生额外的张量,但从系统的角度来看,所有中间行向量都已经被提取到处理器中以用于计算最终张量?

是是的。