我正在尝试计算两个火炬张量的均方根误差。我想忽略/屏蔽标签为 0(缺失值)的行。我如何修改这一行以考虑该限制?
torch.sqrt(((preds.detach() - labels) ** 2).mean()).item()
Run Code Online (Sandbox Code Playgroud)
先感谢您。
这可以通过定义自定义 MSE 损失函数* 来解决,该函数从输入张量和目标张量中屏蔽掉缺失值(在您的情况下为 0):
def mse_loss_with_nans(input, target):
# Missing data are nan's
# mask = torch.isnan(target)
# Missing data are 0's
mask = target == 0
out = (input[~mask]-target[~mask])**2
loss = out.mean()
return loss
Run Code Online (Sandbox Code Playgroud)
(*)从优化的角度来看,计算 MSE 相当于 RMSE —— 优点是计算速度更快。
归档时间: |
|
查看次数: |
1974 次 |
最近记录: |