将模型从 3 通道 (RGB) 重新训练为 4 通道 (RGBA),我可以使用 3 通道权重吗?

Ric*_* De 1 python rgba pytorch

我需要将模型从 RGB 扩展到 RGBA。我可以处理模型上的代码重写,但我不想从头开始重新训练整个模型,而是希望从 3 个通道权重 + 零开始。

有没有简单的方法可以将手电筒的 3 个通道权重保存更改为 4 个?

jod*_*dag 5

是的,你可以做一点“模型手术”。假设模型的输入仅由卷积层直接处理,那么您可以将该卷积层替换为另一个设置为 的in_channels4。然后,您可以将权重设置为零并从原始转换层复制旧权重(以及偏差,如果适用)。

例如,假设我们有一个如下所示的简单模型

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 10, kernel_size=3, padding=1, bias=True)
        self.conv2 = nn.Conv2d(10, 5, kernel_size=3, padding=1, bias=True)
        self.linear = nn.Linear(125, 1)
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        return self.linear(x.flatten(start_dim=1))

model = SimpleModel()
Run Code Online (Sandbox Code Playgroud)

假设此时模型已完成训练,我们可以按如下方式执行手术

y_rgb = torch.randn(1, 3, 5, 5)

# get performance on initial z_rgb
z_rgb = model(y_rgb)

# perform model surgery
with torch.no_grad():
    new_conv1 = nn.Conv2d(4, 10, kernel_size=3, padding=1, bias=True)
    new_conv1.weight.zero_()
    new_conv1.weight[:,:3,...]=model.conv1.weight
    new_conv1.bias.copy_(model.conv1.bias)
    model.conv1 = new_conv1

# add a random alpha channel to y_rgba
y_alpha = torch.randn(1,1,5,5)
y_rgba = torch.cat([y_rgb, y_alpha], dim=1)

# get results on rgba model
z_rgba = model(y_rgba)

# compare z_rgb and z_rgba, print mean-square difference
z_err = ((z_rgba-z_rgb)**2).mean().item()
print('Err:', z_err)

# save results to a new file
torch.save(model.state_dict(), 'checkpoint_rgba.pt')
Run Code Online (Sandbox Code Playgroud)

这应该给你零或非常接近于零的误差。

当然,如果您的第一个转换层中没有bias术语,则无需复制该术语。

假设您已经保存了新的状态字典,那么您可能需要更新模型类定义,以便您的输入卷积层采用 4 个通道输入而不是 3 个。然后下次您可以直接加载新的状态字典,而无需额外的步骤。


现在,直接在模型上进行手术并不是绝对必要的。尽管我更喜欢它,因为我发现它更容易验证正确性。

假设您保存了 RGB 模型的状态字典,您也可以直接修改状态字典。

# assuming you saved RGB model using torch.save(model.state_dict(), 'checkpoint_rgb.pt')
state_dict = torch.load('checkpoint_rgb.pt')
old_weight = state_dict['conv1.weight']
state_dict['conv1.weight'] = torch.zeros(
    old_weight.shape[0],
    old_weight.shape[1]+1,
    old_weight.shape[2],
    old_weight.shape[3]
).type_as(old_weight)
state_dict['conv1.weight'][:,:3,...] = old_weight
torch.save(state_dict, 'checkpoint_rgba.pt')
Run Code Online (Sandbox Code Playgroud)