Dan*_*Dan 5 python neural-network pytorch automl
我想在两个损失函数或 nn.Modules 或 python 对象之间动态应用数学运算。在pytorch中生成动态图也可以被视为一个问题。
\n例如:在下面的例子中,我想添加两个损失函数。
\nnn.L1Loss() + nn.CosineEmbeddingLoss()\nRun Code Online (Sandbox Code Playgroud)\n如果我这样做,它会给我一个错误:
\n----> 1 nn.L1Loss() + nn.CosineEmbeddingLoss()\nTypeError: unsupported operand type(s) for +: 'L1Loss' and 'CosineEmbeddingLoss'\nRun Code Online (Sandbox Code Playgroud)\n我还尝试创建一个具有转发功能和火炬操作的包装器,如下所示,但它也不起作用。在下面的情况下x, 和y可以是任何损失函数,也op可以是任何数学运算,例如加法、减法等。
class Execute_Op(nn.Module):\n def __init__(self):\n super().__init__()\n \n def forward(self, x, y, op):\n if op == 'add':\n return torch.add(x, y)\n elif op == 'subtract':\n return torch.subtract(x - y)\n\nexec_op = Execute_Op()\nexec_op(nn.L1Loss(), nn.CosineEmbeddingLoss(), 'add')\nRun Code Online (Sandbox Code Playgroud)\n它给出如下错误:
\nExecute_Op.forward(self, x, y, op)\n 5 def forward(self, x, y, op):\n 6 if op == 'add':\n----> 7 return torch.add(x, y)\n 8 elif op == 'subtract':\n 9 return torch.subtract(x - y)\n\nTypeError: add(): argument 'input' (position 1) must be Tensor, not L1Loss\nRun Code Online (Sandbox Code Playgroud)\n我了解函数式 API 以及将真值和预测值传递给损失函数的一般方法。但在这种情况下,我无法在运行时动态组合损失函数。
\n我不确定具体如何实施。但非常感谢任何帮助。\n此外,如果有一种 Pythonic 方式或 Pytorch 方式来做到这一点,那就太好了。
\n编辑:
\n这里的问题是损失函数没有标准的调用约定。不同的损失函数采用具有不同语义的不同参数。例如,nn.L1Loss()采用两个参数:相同形状的输入和目标。另一方面nn.CosineEmbeddingLoss()需要三个参数:两个相同形状的输入和一个不同形状的目标。
这意味着 的调用约定ExecuteOp将取决于所选的特定损失函数。出于这个原因,我认为将损失函数变得通用并不是一个好主意。
没有额外的泛化级别的简单解决方案是
class ExecuteOp(nn.Module):
def init(self):
super().__init__()
self.l1_loss = nn.L1Loss()
self.emb_loss = nn.CosineEmbeddingLoss()
def forward(self, op, input_l1, target_l1, input1_emb, input2_emb, target_emb):
assert op in ('add', 'subtract')
loss1 = self.l1_loss(input_l1, target_l1)
loss2 = self.emb_loss(input1_emb, input2_emb, target_emb)
if op == 'add':
return loss1 + loss2
return loss1 - loss2
Run Code Online (Sandbox Code Playgroud)
然后你可以按如下方式使用它:
exec_op = ExecuteOp()
...
# at training/val time, assuming inputs and targets are provided from your model and dataloader
loss = exec_op('add', input_l1, target_l1, input1_emb, input2_emb, target_emb)
Run Code Online (Sandbox Code Playgroud)
exec_op考虑一下如果它是一般的 wrt 损失函数,您将如何在运行时调用。你会提供什么论据,以什么顺序?