Rit*_*esh 59
torch.gather
通过沿输入维度获取每一行的值,从输入张量创建一个新的张量dim
.torch.LongTensor
传入的值in ,index
指定从每个'row'中取出的值.输出张量的维数与索引张量的维度相同.以下官方文档的说明更清楚地解释了:
(注意:在图示中,索引从1而不是0开始).
在第一个例子中,给定尺寸为沿行(从上到下),所以对(1,1)的位置result
,它需要从该行值index
的src
是1
.在(1,1)源值是1
这样,输出1
在(1,1)in result
.类似地,对于(2,2)从用于索引的行值src
是 3
.在(3,2)中的值src
是8
,因此输出8
等.
类似地,对于第二个例子,索引是沿着列,并因此在所述的(2,2)位置result
从索引,列值src
是3
,所以在(2,3)从src
,6
取并输出到result
在(2,2 )
cle*_*ros 45
的torch.gather
功能(或torch.Tensor.gather
)是一种多指数选择方法.从官方文档中查看以下示例:
t = torch.tensor([[1,2],[3,4]])
r = torch.gather(t, 1, torch.tensor([[0,0],[1,0]]))
# r now holds:
# tensor([[ 1, 1],
# [ 4, 3]])
Run Code Online (Sandbox Code Playgroud)
让我们从遍历不同参数的语义开始:第一个参数input
是我们想要从中选择元素的源张量.第二个dim
是我们想要收集的尺寸(或tensorflow/numpy中的轴).最后,index
是索引的指数input
.至于操作的语义,这是官方文档解释它的方式:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
Run Code Online (Sandbox Code Playgroud)
让我们来看看这个例子.
输入张量是[[1, 2], [3, 4]]
,而昏暗的参数是1
,即我们想从第二维收集.第二维的指数以[0, 0]
和表示[1, 0]
.
当我们"跳过"第一个维度(我们想要收集的维度是1
)时,结果的第一个维度被隐式地给出为第一个维度index
.这意味着索引包含第二维或列索引,但不包含行索引.这些是由index
张量本身的指数给出的.对于该示例,这意味着输出将在其第一行中具有input
张量的第一行的元素的选择,如张量的第一行的第一行所给出的index
.由于列索引由下式给出[0, 0]
,因此我们选择输入的第一行的第一个元素两次,得到[1, 1]
.类似地,结果的第二行input
的元素是通过张量的第二行的元素索引张量的第二行的index
结果,导致[4, 3]
.
为了进一步说明这一点,让我们在示例中交换维度:
t = torch.tensor([[1,2],[3,4]])
r = torch.gather(t, 0, torch.tensor([[0,0],[1,0]]))
# r now holds:
# tensor([[ 1, 2],
# [ 3, 2]])
Run Code Online (Sandbox Code Playgroud)
如您所见,现在沿第一维收集指数.
对于你提到的例子,
current_Q_values = Q(obs_batch).gather(1, act_batch.unsqueeze(1))
Run Code Online (Sandbox Code Playgroud)
gather
将通过批处理动作列表索引q值的行(即,一批q值中的每样本q值).结果将与您完成以下操作相同(尽管它比循环快得多):
q_vals = []
for qv, ac in zip(Q(obs_batch), act_batch):
q_vals.append(qv[ac])
q_vals = torch.cat(q_vals, dim=0)
Run Code Online (Sandbox Code Playgroud)
Mar*_*mer 30
@Ritesh 和 @cleros 给出了很好的答案(有很多赞成票),但在阅读之后我仍然有点困惑,我知道为什么。这篇文章可能会帮助像我这样的人。
对于这些各种各样的行和列的练习,我认为这真的有助于使用非方形物体,让我们有更大的4x3的开始source
(torch.Size([4, 3])
使用)source = torch.tensor([[1,2,3], [4,5,6], [7,8,9], [10,11,12]])
。这会给我们
\\ This is the source tensor
tensor([[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[10, 11, 12]])
Run Code Online (Sandbox Code Playgroud)
现在让我们开始沿列 ( dim=1
) 和 create 建立索引index = torch.tensor([[0,0],[1,1],[2,2],[0,1]])
,它是一个列表列表。这里的关键:因为我们的尺寸为列,而声源4
行,则index
必须包含4
列表!我们需要每一行的列表。跑步source.gather(dim=1, index=index)
会给我们
tensor([[ 1, 1],
[ 5, 5],
[ 9, 9],
[10, 11]])
Run Code Online (Sandbox Code Playgroud)
因此,其中的每个列表都index
为我们提供了从中提取值的列。index
( [0,0]
)的第一个列表告诉我们要查看 的第一行source
并取该行的第一列(它是零索引)两次,即[1,1]
. index
( [1,1]
)的第二个列表告诉我们要查看 的第二行source
并取该行的第二列两次,即[5,5]
。跳转到index
( [0,1]
)的第 4 个列表,它要求我们查看 的第 4 行和最后一行source
,要求我们取第 1 列 ( 10
),然后取第 2 列 ( 11
),它给出了[10,11]
。
这里有一件很棒的事情:您的每个列表index
都必须具有相同的长度,但它们可以随您的需要而长!例如, with index = torch.tensor([[0,1,2,1,0],[2,1,0,1,2],[1,2,0,2,1],[1,0,2,0,1]])
,source.gather(dim=1, index=index)
会给我们
tensor([[ 1, 2, 3, 2, 1],
[ 6, 5, 4, 5, 6],
[ 8, 9, 7, 9, 8],
[11, 10, 12, 10, 11]])
Run Code Online (Sandbox Code Playgroud)
输出将始终具有与 相同的行数source
,但列数将等于 中每个列表的长度index
。例如,index
( [2,1,0,1,2]
)的第 2 个列表去 和 的第 2 行source
,分别拉动第 3、2、1、2 和 3 项,即[6,5,4,5,6]
。请注意,中每个元素的值index
必须小于source
(在本例中3
)的列数,否则会出现out of bounds
错误。
切换到dim=0
,我们现在将使用行而不是列。使用相同的source
,我们现在需要一个index
其中每个列表的长度等于source
. 为什么?因为列表中的每个元素都代表source
我们逐列移动时的行。
因此,index = torch.tensor([[0,0,0],[0,1,2],[1,2,3],[3,2,0]])
将有然后source.gather(dim=0, index=index)
给我们
tensor([[ 1, 2, 3],
[ 1, 5, 9],
[ 4, 8, 12],
[10, 8, 3]])
Run Code Online (Sandbox Code Playgroud)
查看index
( [0,0,0]
)中的第一个列表,我们可以看到我们正在 3 列中移动,source
选择每列的第一个元素(它的零索引),即[1,2,3]
. index
( [0,1,2]
) 中的第二个列表告诉我们在分别取第 1、第 2 和第 3 项的列中移动,即[1,5,9]
。等等。
随着dim=1
我们index
必须有一批名单等于行数source
,但只要或短,随你喜欢每个列表可能。使用dim=0
,我们中的每个列表index
的长度必须与 中的列数相同source
,但我们现在可以拥有任意数量的列表。index
但是, 中的每个值都需要小于中的行数source
(在本例中4
)。
例如,index = torch.tensor([[0,0,0],[1,1,1],[2,2,2],[3,3,3],[0,1,2],[1,2,3],[3,2,0]])
就会source.gather(dim=0, index=index)
给我们
tensor([[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[10, 11, 12],
[ 1, 5, 9],
[ 4, 8, 12],
[10, 8, 3]])
Run Code Online (Sandbox Code Playgroud)
与dim=1
输出总是具有相同的行数为的source
,虽然列数将等于在列表的长度index
。中的列表数index
必须等于 中的行数source
。index
但是, 中的每个值都需要小于 中的列数source
。
随着dim=0
输出始终具有相同数量的列source
,但行数将等于列表的数量index
。中每个列表的长度index
必须等于 中的列数source
。index
但是, 中的每个值都需要小于 中的行数source
。
二维就是这样。超越这一点将遵循相同的模式。
Hon*_*iao 22
这是基于@Ritesh 的回答(感谢@Ritesh!)和一些真实的代码。
\nAPItorch.gather
是
torch.gather(input, dim, index, *, sparse_grad=False, out=None) \xe2\x86\x92 Tensor
实施例1
\n什么时候dim = 0
,
dim = 0\ninput = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])\nindex = torch.tensor([[0, 1, 2], [1, 2, 0]])\n\noutput = torch.gather(input, dim, index)\n# tensor([[10, 14, 18],\n# [13, 17, 12]])\n
Run Code Online (Sandbox Code Playgroud)\n实施例2
\n什么时候dim = 1
,
dim = 1\ninput = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])\nindex = torch.tensor([[0, 1], [1, 2], [2, 0]])\n\noutput = torch.gather(input, dim, index)\n# tensor([[10, 11],\n# [14, 15],\n# [18, 16]])\n
Run Code Online (Sandbox Code Playgroud)\n
归档时间: |
|
查看次数: |
15543 次 |
最近记录: |