计算 rmse 时 Pytorch 掩码缺失值

raz*_*c92 4 pytorch

我正在尝试计算两个火炬张量的均方根误差。我想忽略/屏蔽标签为 0(缺失值)的行。我如何修改这一行以考虑该限制?

torch.sqrt(((preds.detach() - labels) ** 2).mean()).item()
Run Code Online (Sandbox Code Playgroud)

先感谢您。

prl*_*900 7

这可以通过定义自定义 MSE 损失函数* 来解决,该函数从输入张量和目标张量中屏蔽掉缺失值(在您的情况下为 0):

def mse_loss_with_nans(input, target):

    # Missing data are nan's
    # mask = torch.isnan(target)

    # Missing data are 0's
    mask = target == 0

    out = (input[~mask]-target[~mask])**2
    loss = out.mean()

    return loss
Run Code Online (Sandbox Code Playgroud)

(*)从优化的角度来看,计算 MSE 相当于 RMSE —— 优点是计算速度更快。