log_prob 有什么作用?

cer*_*rou 15 probability-distribution pytorch

在一些(例如机器学习)库中,我们可以找到log_prob函数。它有什么作用,它与常规服用有何不同log有何?

例如,这段代码的目的是什么:

dist = Normal(mean, std)
sample = dist.sample()
logprob = dist.log_prob(sample)
Run Code Online (Sandbox Code Playgroud)

随后,为什么我们要先取一个日志,然后对结果值取幂,而不是直接评估它:

prob = torch.exp(dist.log_prob(sample))
Run Code Online (Sandbox Code Playgroud)

lum*_*uri 14

log_prob取(某些动作)概率的对数。例子:

action_logits = torch.rand(5)
action_probs = F.softmax(action_logits, dim=-1)
action_probs
Run Code Online (Sandbox Code Playgroud)

返回:

张量([0.1457, 0.2831, 0.1569, 0.2221, 0.1922])

然后:

dist = Categorical(action_probs)
action = dist.sample()
print(dist.log_prob(action), torch.log(action_probs[action]))
Run Code Online (Sandbox Code Playgroud)

返回:

张量(-1.8519) 张量(-1.8519)


use*_*967 6

正如您自己的答案所提到的,log_prob返回密度或概率的对数。在这里,我将解决您问题中的其余要点:

  • 这与 有log什么不同?分发没有方法log。如果他们这样做了,最接近的可能解释确实是这样的,log_prob但它不会是一个非常精确的名称,因为如果引出了“什么的日志”的问题?分布具有多个数字属性(例如其均值、期望等),概率或密度只是其中之一,因此名称会不明确。

这同样不适用于该Tensor.log()方法(这可能是您想到的),因为Tensor它本身就是一个我们可以取对数的数学量。

  • 为什么只取概率的对数来对它取幂?您以后可能不需要对它求幂。例如,如果您有概率日志pq,那么您可以直接计算log(p * q)log(p) + log(q),避免中间取幂。这在数值上更稳定(避免下溢),因为概率可能会变得非常接近于零,而它们的日志却没有。一般来说,加法也比乘法更有效,它的导数也更简单。https://en.wikipedia.org/wiki/Log_probability 上有一篇关于这些主题的好文章。


Lib*_*ang 6

回答

logprob = dist.log_prob(sample)表示得到一个实验样本( )在特定分布( )下的对数概率( )。logprobsampledist

(理解起来很困难,需要一段时间才能理解下面的解释。)

解释

(我们用一个简单的例子来理解它的作用是什么log_prob?)

正向测试

首先,使用 中的均匀分布生成概率a[0, 1]

import torch.distributions as D
import torch

a = torch.empty(1).uniform_(0, 1)
a
# OUTPUT: tensor([0.3291])
Run Code Online (Sandbox Code Playgroud)

基于这个概率 和D.Bernoulli,我们可以实例化伯努利分布 b=D.Bernoulli(a)(这意味着每个伯努利实验的结果 ,b.sample()要么1具有概率a=0.3291 ,要么0具有概率1-a=0.6709),

b = D.Bernoulli(a)
b
# OUTPUT: Bernoulli()
Run Code Online (Sandbox Code Playgroud)

我们可以通过一个伯努利实验来验证这一点,以获得一个样本 c(保持c概率0.32911,而0.6709概率为0),

c = b.sample()
c
# OUTPUT: tensor([0.])
Run Code Online (Sandbox Code Playgroud)

利用伯努利分布b和样本c,我们可以得到(伯努利实验样本)在分布(0.3291 为 TRUE 的特定伯努利分布)下的对数概率cb,(或者正式地,计算的概率密度/质量函数的对数)值( c) )

b.log_prob(c)
b 
# OUTPUT: tensor([-0.3991])
Run Code Online (Sandbox Code Playgroud)

向后验证

我们已经知道,每个样本的概率0(对于一个实验,概率可以简单地视为其概率密度/质量函数)为0.6709,因此我们可以log_prob通过以下方式验证结果:

torch.log(torch.tensor(0.6709))
# OUTPUT: tensor(-0.3991)
Run Code Online (Sandbox Code Playgroud)

c它等于under的对数概率b。(完成的!)

希望它对您有用。


cer*_*rou 5

部分答案是log_prob返回在给定样本值处评估的概率密度/质量函数的对数。