zhe*_*ang 5 distribution pytorch
给定一个包含 N 个点的张量,用 [x,y] 表示,我想在每个点周围创建一个 2D 高斯分布,将它们绘制在一个空的特征图上。
例如,左图显示了一个给定点(在特征图上注册为像素,其值设置为 1)。右图在其周围添加了 2D 高斯分布。
我如何为每个点添加这样的分布?pytorch中有相关的API吗?
您可以使用MultivariateNormal多元正态样本进行采样。
>>> h, w = 200, 200
>>> fmap = torch.zeros(h, w)
Run Code Online (Sandbox Code Playgroud)
填充fmap原点:
>>> pts = torch.rand(20, 2)
>>> pts *= torch.tensor([h, w])
>>> x, y = pts.T.long()
>>> x, y = x.clip(0, h), y.clip(0, w)
>>> fmap[x, y] = 1
Run Code Online (Sandbox Code Playgroud)
接下来,我们可以从以下分布中采样(您可以相应地调整协方差矩阵):
>>> sampler = MultivariateNormal(pts.T, 10*torch.eye(len(pts)))
>>> for x in range(10):
... x, y = sampler.sample()
... x, y = x.clip(0, h).long(), y.clip(0, w).long()
... fmap[x, y] = 1
Run Code Online (Sandbox Code Playgroud)
结果,你可能会得到类似的结果:
这没有得到足够好的记录,但您可以将样本形状传递给函数sample。这允许您在每次调用时对多个点进行采样,即您只需要一个点即可填充画布。
这是一个可以从中提取的函数MultivariateNormal:
def multivariate_normal_sampler(mean, cov, k):
sampler = MultivariateNormal(mean, cov)
return sampler.sample((k,)).swapaxes(0,1).flatten(1)
Run Code Online (Sandbox Code Playgroud)
然后你可以将其称为:
>>> x, y = multivariate_normal_sampler(mean=pts.T, cov=50*torch.eye(len(pts)), k=1000)
Run Code Online (Sandbox Code Playgroud)
剪辑样本:
>>> x, y = x.clip(0, h-1).long(), y.clip(0, w-1).long()
Run Code Online (Sandbox Code Playgroud)
最后插入fmap并绘制:
>>> fmap[x, y] += .1
Run Code Online (Sandbox Code Playgroud)
这是一个预览示例:
效用函数可表示为 torch.distributions.multivariate_normal.MultivariateNormal
或者,您可以根据概率密度函数 ( pdf )计算密度值,而不是从正态分布中采样:
原点:
>>> h, w = 50, 50
>>> x0, y0 = torch.rand(2, 20)
>>> origins = torch.stack((x0*h, y0*w)).T
Run Code Online (Sandbox Code Playgroud)
定义高斯二维 pdf:
def gaussian_2d(x=0, y=0, mx=0, my=0, sx=1, sy=1):
return 1 / (2*math.pi*sx*sy) * \
torch.exp(-((x - mx)**2 / (2*sx**2) + (y - my)**2 / (2*sy**2)))
Run Code Online (Sandbox Code Playgroud)
构建网格并累积每个原点的高斯:
x = torch.linspace(0, h, h)
y = torch.linspace(0, w, w)
x, y = torch.meshgrid(x, y)
z = torch.zeros(h, w)
for x0, y0 in origins:
z += gaussian_2d(x, y, mx=x0, my=y0, sx=h/10, sy=w/10)
Run Code Online (Sandbox Code Playgroud)
绘制值网格的代码只需使用matplotlib.pyplot.pcolormesh: plt.pcolormesh(x, y, z)。
| 归档时间: |
|
| 查看次数: |
5349 次 |
| 最近记录: |