如何计算参数的谱范数?

vai*_*ijr 3 python pytorch

当我做,

import torch, torch.nn as nn
x = nn.Linear(3, 3)
y = torch.nn.utils.spectral_norm(x)
Run Code Online (Sandbox Code Playgroud)

然后它给出四个不同的权重矩阵,

y.weight_u

tensor([ 0.6534, -0.1644,  0.7390])
Run Code Online (Sandbox Code Playgroud)

y.weight_orig

Parameter containing:
tensor([[ 0.2538,  0.3196,  0.3380],
        [ 0.4946,  0.0519,  0.1022],
        [-0.5549, -0.0401,  0.1654]], requires_grad=True)
Run Code Online (Sandbox Code Playgroud)

y.weight_v

tensor([-0.3650,  0.2870,  0.8857])
Run Code Online (Sandbox Code Playgroud)

y.weight

tensor([[ 0.5556,  0.6997,  0.7399],
        [ 1.0827,  0.1137,  0.2237],
        [-1.2149, -0.0878,  0.3622]], grad_fn=<DivBackward0>)
Run Code Online (Sandbox Code Playgroud)

这四个矩阵是如何计算的?

jod*_*dag 5

我刚刚阅读了有关此方法的论文,该论文可以在 arxiv 上找到。如果您有适当的数学背景,我建议您阅读它。有关描述 u 和 v 是什么的幂算法,请参见附录 A。

也就是说,我将在这里进行总结。

首先,您应该知道矩阵的谱范数是最大奇异值。作者建议找到权重矩阵 的谱范数W,然后除以W它的谱范数以使其接近1(该决定的理由在论文中)。

虽然我们可以torch.svd用来找到奇异值的精确估计,但他们使用一种称为“幂迭代”的快速(但不精确)方法。长话短说,在weight_uweight_v是对应于W的最大奇异值的左,右奇异向量他们是有用的粗略近似值,因为相关的奇异值,即频谱规范,W是等于u.transpose(1,0) @ W @ v,如果uv是实际的左/right 的奇异向量W

  • y.weight_orig 包含图层中的原始值。
  • y.weight_u是 的第一个左奇异向量的近似值y.weight_orig
  • y.weight_v是 的第一个右奇异向量的近似值y.weight_orig
  • y.weight是更新后的权重矩阵,y.weight_orig除以它的近似谱范数。

我们可以通过证明实际的左右奇异向量几乎平行于y.weight_uy.weight_v

import torch
import torch.nn as nn

# pytorch default is 1
n_power_iterations = 1

y = nn.Linear(3,3)
y = nn.utils.spectral_norm(y, n_power_iterations=n_power_iterations)

# spectral normalization is performed during forward pre hook for technical reasons, we
# need to send something through the layer to ensure normalization is applied
# NOTE: After this is performed, x.weight is changed in place!
_ = y(torch.randn(1,3))

# test svd vs. spectral_norm u/v estimates
u,s,v = torch.svd(y.weight_orig)
cos_err_u = 1.0 - torch.abs(torch.dot(y.weight_u, u[:, 0])).item()
cos_err_v = 1.0 - torch.abs(torch.dot(y.weight_v, v[:, 0])).item()
print('u-estimate cosine error:', cos_err_u)
print('v-estimate cosine error:', cos_err_v)

# singular values
actual_orig_sn = s[0].item()
approx_orig_sn = (y.weight_u @ y.weight_orig @ y.weight_v).item()
print('Actual original spectral norm:', actual_orig_sn)
print('Approximate original spectral norm:', approx_orig_sn)

# updated weights singular values
u,s_new,v = torch.svd(y.weight.data, compute_uv=False)
actual_sn = s_new[0].item()
print('Actual updated spectral norm:', actual_sn)
print('Desired updated spectral norm: 1.0')
Run Code Online (Sandbox Code Playgroud)

这导致

u-estimate cosine error: 0.00764310359954834
v-estimate cosine error: 0.034041762351989746
Actual original spectral norm: 0.8086231350898743
Approximate original spectral norm: 0.7871124148368835
Actual updated spectral norm: 1.0273288488388062
Desired updated spectral norm: 1.0
Run Code Online (Sandbox Code Playgroud)

增加n_power_iterations参数将以计算时间为代价增加估计的准确性。