sample() 和 rsample() 有什么区别?

vai*_*ijr 15 python random pytorch

当我在PyTorch分布样,都samplersample似乎给了类似的结果:

import torch, seaborn as sns

x = torch.distributions.Normal(torch.tensor([0.0]), torch.tensor([1.0]))
Run Code Online (Sandbox Code Playgroud)
在此处输入图片说明 在此处输入图片说明
sns.distplot(x.sample((100000,))) sns.distplot(x.rsample((100000,)))

什么时候用sample(),什么时候用rsample()

sta*_*iet 23

sample():从概率分布中随机抽样。所以,我们不能反向传播,因为它是随机的!(计算图被截断)。

sample查看源代码torch.distributions.normal.Normal

def sample(self, sample_shape=torch.Size()):
    shape = self._extended_shape(sample_shape)
    with torch.no_grad():
        return torch.normal(self.loc.expand(shape), self.scale.expand(shape))
Run Code Online (Sandbox Code Playgroud)

torch.normal返回随机数张量。此外,torch.no_grad()上下文可以防止计算图进一步增长。

你看,我们不能反向传播。返回的张量sample()仅包含一些数字,而不是整个计算图。


那么,什么是rsample()

通过使用rsample,我们可以反向传播,因为它使计算图保持活动状态。

如何?通过将随机性放在单独的参数中。这称为“重新参数化技巧”。

r样本:使用r参数化技巧进行采样。

eps源码中有:

def rsample(self, sample_shape=torch.Size()):
    shape = self._extended_shape(sample_shape)
    eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
    return self.loc + eps * self.scale

    # `self.loc` is the mean and `self.scale` is the standard deviation.
Run Code Online (Sandbox Code Playgroud)

eps是负责采样随机性的单独参数。

看返回值:平均值+ eps*标准差

eps依赖于您想要区分的参数。

所以,现在你可以自由地反向传播(=微分),因为eps参数改变时不会改变。

(如果我们改变参数,重新参数化样本的分布因为self.locself.scale改变而改变,但 的分布eps不会改变。)

请注意,采样的随机性来自于 的随机采样eps。计算图本身存在随机性。一旦eps选定,就固定下来了。eps(采样后,元素的分布是固定的。)

例如,在强化学习中的 SAC(Soft Actor-Critic)算法的实现中,eps可能由与单个小批量动作相对应的元素组成(并且一个动作可能由许多元素组成)。


Sha*_*hai 11

使用rsample允许路径导数

实现这些随机/策略梯度的另一种方法是使用该rsample()方法中的重新参数化技巧,其中参数化随机变量可以通过无参数随机变量的参数化确定性函数构建。因此,重新参数化的样本变得可微。

  • 变分自动编码器使用重新参数化技巧来允许通过均值和标准差层进行反向传播。 (8认同)
  • 你能举一个不涉及强化学习的例子吗? (3认同)