我正在尝试计算两个火炬张量的均方根误差。我想忽略/屏蔽标签为 0(缺失值)的行。我如何修改这一行以考虑该限制?
torch.sqrt(((preds.detach() - labels) ** 2).mean()).item()
先感谢您。
pytorch
pytorch ×1