您能否反转 PyTorch 神经网络并激活输出中的输入?

Rea*_*lar 6 python neural-network pytorch

我们能否激活 NN 的输出以深入了解神经元如何连接到输入特征?

如果我从 PyTorch 教程中获取一个基本的 NN 示例。这是一个f(x,y)训练示例的示例。

import torch

N, D_in, H, D_out = 64, 1000, 100, 10

x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)

loss_fn = torch.nn.MSELoss(reduction='sum')

learning_rate = 1e-4
for t in range(500):
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    model.zero_grad()
    loss.backward()
    with torch.no_grad():
        for param in model.parameters():
            param -= learning_rate * param.grad
Run Code Online (Sandbox Code Playgroud)

在我完成训练网络以y根据x输入进行预测之后。是否可以反转经过训练的 NN,以便它现在可以x根据y输入进行预测?

我不希望y匹配训练输出的原始输入y。所以我希望看到模型激活了哪些特征来匹配xy.

如果可能,那么如何在Sequential不破坏所有权重和连接的情况下重新排列模型?

a_g*_*est 7

这是可能的,但仅适用于非常特殊的情况。对于前馈网络 ( Sequential),每一层都需要是可逆的;这意味着以下参数分别适用于每一层。与一层相关的变换是y = activation(W*x + b)其中W是权重矩阵和b偏置向量。为了解决这个问题,x我们需要执行以下步骤:

  1. 撤销activation; 但并非所有激活函数都有反函数。例如,该ReLU函数在 上没有反函数(-inf, 0)tanh另一方面,如果我们使用它,我们可以使用它的逆,即0.5 * log((1 + x) / (1 - x))
  2. 求解;W*x = inverse_activation(y) - bx要存在唯一的解决方案,W必须具有相似的行和列秩并且det(W)必须非零。我们可以通过选择特定的网络架构来控制前者,而后者则取决于训练过程。

因此,对于可逆的神经网络,它必须具有非常具体的架构:所有层必须具有相同数量的输入和输出神经元(即平方权重矩阵),并且激活函数都需要是可逆的。

代码:使用 PyTorch,我们必须手动进行网络反演,无论是求解线性方程组还是找到逆激活函数。考虑以下 1 层神经网络的示例(因为这些步骤分别适用于每一层,因此将其扩展到 1 层以上是微不足道的):

import torch

N = 10  # number of samples
n = 3   # number of neurons per layer

x = torch.randn(N, n)

model = torch.nn.Sequential(
    torch.nn.Linear(n, n), torch.nn.Tanh()
)

y = model(x)

z = y  # use 'z' for the reverse result, start with the model's output 'y'.
for step in list(model.children())[::-1]:
    if isinstance(step, torch.nn.Linear):
        z = z - step.bias[None, ...]
        z = z[..., None]  # 'torch.solve' requires N column vectors (i.e. shape (N, n, 1)).
        z = torch.solve(z, step.weight)[0]
        z = torch.squeeze(z)  # remove the extra dimension that we've added for 'torch.solve'.
    elif isinstance(step, torch.nn.Tanh):
        z = 0.5 * torch.log((1 + z) / (1 - z))

print('Agreement between x and z: ', torch.dist(x, z))
Run Code Online (Sandbox Code Playgroud)