聚集函数在外行术语中对pytorch有什么作用?

ami*_*ion 26 pytorch

我已经通过官方文件这个,但很难理解发生了什么.

我试图理解DQN源代码,它使用第197行的收集功能.

有人可以用简单的术语解释聚集函数的作用吗?这个功能的目的是什么?

Rit*_*esh 59

torch.gather通过沿输入维度获取每一行的值,从输入张量创建一个新的张量dim.torch.LongTensor传入的值in ,index指定从每个'row'中取出的值.输出张量的维数与索引张量的维度相同.以下官方文档的说明更清楚地解释了: 来自文档的Pictoral表示

(注意:在图示中,索引从1而不是0开始).

在第一个例子中,给定尺寸为沿行(从上到下),所以对(1,1)的位置result,它需要从该行值indexsrc1.在(1,1)源值是1这样,输出1在(1,1)in result.类似地,对于(2,2)从用于索引的行值src3.在(3,2)中的值src8,因此输出8等.

类似地,对于第二个例子,索引是沿着列,并因此在所述的(2,2)位置result从索引,列值src3,所以在(2,3)从src,6取并输出到result在(2,2 )

  • 谢谢。这确实是“一幅画值得一千个单词”的例子 (5认同)
  • 这是最好的答案。谢谢。 (2认同)

cle*_*ros 45

torch.gather功能(或torch.Tensor.gather)是一种多指数选择方法.从官方文档中查看以下示例:

t = torch.tensor([[1,2],[3,4]])
r = torch.gather(t, 1, torch.tensor([[0,0],[1,0]]))
# r now holds:
# tensor([[ 1,  1],
#        [ 4,  3]])
Run Code Online (Sandbox Code Playgroud)

让我们从遍历不同参数的语义开始:第一个参数input是我们想要从中选择元素的源张量.第二个dim是我们想要收集的尺寸(或tensorflow/numpy中的轴).最后,index是索引的指数input.至于操作的语义,这是官方文档解释它的方式:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2
Run Code Online (Sandbox Code Playgroud)

让我们来看看这个例子.

输入张量是[[1, 2], [3, 4]],而昏暗的参数是1,即我们想从第二维收集.第二维的指数以[0, 0]和表示[1, 0].

当我们"跳过"第一个维度(我们想要收集的维度是1)时,结果的第一个维度被隐式地给出为第一个维度index.这意味着索引包含第二维或列索引,但不包含行索引.这些是由index张量本身的指数给出的.对于该示例,这意味着输出将在其第一行中具有input张量的第一行的元素的选择,如张量的第一行的第一行所给出的index.由于列索引由下式给出[0, 0],因此我们选择输入的第一行的第一个元素两次,得到[1, 1].类似地,结果的第二行input的元素是通过张量的第二行的元素索引张量的第二行的index结果,导致[4, 3].

为了进一步说明这一点,让我们在示例中交换维度:

t = torch.tensor([[1,2],[3,4]])
r = torch.gather(t, 0, torch.tensor([[0,0],[1,0]]))
# r now holds:
# tensor([[ 1,  2],
#        [ 3,  2]])
Run Code Online (Sandbox Code Playgroud)

如您所见,现在沿第一维收集指数.

对于你提到的例子,

current_Q_values = Q(obs_batch).gather(1, act_batch.unsqueeze(1))
Run Code Online (Sandbox Code Playgroud)

gather将通过批处理动作列表索引q值的行(即,一批q值中的每样本q值).结果将与您完成以下操作相同(尽管它比循环快得多):

q_vals = []
for qv, ac in zip(Q(obs_batch), act_batch):
    q_vals.append(qv[ac])
q_vals = torch.cat(q_vals, dim=0)
Run Code Online (Sandbox Code Playgroud)


Mar*_*mer 30

@Ritesh 和 @cleros 给出了很好的答案(有很多赞成票),但在阅读之后我仍然有点困惑,我知道为什么。这篇文章可能会帮助像我这样的人。

对于这些各种各样的行和列的练习,我认为这真的有助于使用非方形物体,让我们有更大的4x3的开始sourcetorch.Size([4, 3])使用)source = torch.tensor([[1,2,3], [4,5,6], [7,8,9], [10,11,12]])。这会给我们

\\ This is the source tensor
tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]])
Run Code Online (Sandbox Code Playgroud)

现在让我们开始沿列 ( dim=1) 和 create 建立索引index = torch.tensor([[0,0],[1,1],[2,2],[0,1]]),它是一个列表列表。这里的关键:因为我们的尺寸为列,而声源4行,则index必须包含4列表!我们需要每一行的列表。跑步source.gather(dim=1, index=index)会给我们

tensor([[ 1,  1],
        [ 5,  5],
        [ 9,  9],
        [10, 11]])
Run Code Online (Sandbox Code Playgroud)

因此,其中的每个列表都index为我们提供了从中提取值的列。index( [0,0])的第一个列表告诉我们要查看 的第一行source并取该行的第一列(它是零索引)两次,即[1,1]. index( [1,1])的第二个列表告诉我们要查看 的第二行source并取该行的第二列两次,即[5,5]。跳转到index( [0,1])的第 4 个列表,它要求我们查看 的第 4 行和最后一行source,要求我们取第 1 列 ( 10),然后取第 2 列 ( 11),它给出了[10,11]

这里有一件很棒的事情:您的每个列表index都必须具有相同的长度,但它们可以随您的需要而长!例如, with index = torch.tensor([[0,1,2,1,0],[2,1,0,1,2],[1,2,0,2,1],[1,0,2,0,1]])source.gather(dim=1, index=index)会给我们

tensor([[ 1,  2,  3,  2,  1],
        [ 6,  5,  4,  5,  6],
        [ 8,  9,  7,  9,  8],
        [11, 10, 12, 10, 11]])
Run Code Online (Sandbox Code Playgroud)

输出将始终具有与 相同的行数source,但列数将等于 中每个列表的长度index。例如,index( [2,1,0,1,2])的第 2 个列表去 和 的第 2 行source,分别拉动第 3、2、1、2 和 3 项,即[6,5,4,5,6]。请注意,中每个元素的值index必须小于source(在本例中3)的列数,否则会出现out of bounds错误。

切换到dim=0,我们现在将使用行而不是列。使用相同的source,我们现在需要一个index其中每个列表的长度等于source. 为什么?因为列表中的每个元素都代表source我们逐列移动时的行。

因此,index = torch.tensor([[0,0,0],[0,1,2],[1,2,3],[3,2,0]])将有然后source.gather(dim=0, index=index)给我们

tensor([[ 1,  2,  3],
        [ 1,  5,  9],
        [ 4,  8, 12],
        [10,  8,  3]])
Run Code Online (Sandbox Code Playgroud)

查看index( [0,0,0])中的第一个列表,我们可以看到我们正在 3 列中移动,source选择每列的第一个元素(它的零索引),即[1,2,3]. index( [0,1,2]) 中的第二个列表告诉我们在分别取第 1、第 2 和第 3 项的列中移动,即[1,5,9]。等等。

随着dim=1我们index必须有一批名单等于行数source,但只要或短,随你喜欢每个列表可能。使用dim=0,我们中的每个列表index的长度必须与 中的列数相同source,但我们现在可以拥有任意数量的列表。index但是, 中的每个值都需要小于中的行数source(在本例中4)。

例如,index = torch.tensor([[0,0,0],[1,1,1],[2,2,2],[3,3,3],[0,1,2],[1,2,3],[3,2,0]])就会source.gather(dim=0, index=index)给我们

tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12],
        [ 1,  5,  9],
        [ 4,  8, 12],
        [10,  8,  3]])
Run Code Online (Sandbox Code Playgroud)

dim=1输出总是具有相同的行数为的source,虽然列数将等于在列表的长度index。中的列表数index必须等于 中的行数sourceindex但是, 中的每个值都需要小于 中的列数source

随着dim=0输出始终具有相同数量的列source,但行数将等于列表的数量index。中每个列表的长度index必须等于 中的列数sourceindex但是, 中的每个值都需要小于 中的行数source

二维就是这样。超越这一点将遵循相同的模式。

  • 很棒的答案。您对每个维度所需内容的描述帮助我更轻松地可视化操作。 (2认同)
  • 太棒了。对我来说,关键确实是你用粗体指出的。使用非方阵也非常有帮助。非常感激! (2认同)
  • 最佳答案 国际海事组织 (2认同)

Hon*_*iao 22

这是基于@Ritesh 的回答(感谢@Ritesh!)和一些真实的代码。

\n

APItorch.gather

\n

torch.gather(input, dim, index, *, sparse_grad=False, out=None) \xe2\x86\x92 Tensor

\n

实施例1

\n

什么时候dim = 0

\n

在此输入图像描述

\n
dim = 0\ninput = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])\nindex = torch.tensor([[0, 1, 2], [1, 2, 0]])\n\noutput = torch.gather(input, dim, index)\n# tensor([[10, 14, 18],\n#         [13, 17, 12]])\n
Run Code Online (Sandbox Code Playgroud)\n

实施例2

\n

什么时候dim = 1

\n

在此输入图像描述

\n
dim = 1\ninput = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])\nindex = torch.tensor([[0, 1], [1, 2], [2, 0]])\n\noutput = torch.gather(input, dim, index)\n# tensor([[10, 11],\n#         [14, 15],\n#         [18, 16]])\n
Run Code Online (Sandbox Code Playgroud)\n