我在下面编写了以下代码,但出现此错误:
TypeError: backward() got an unexpected keyword argument 'retain_variables'
Run Code Online (Sandbox Code Playgroud)
我的代码是:
def learn(self, batch_state, batch_next_state, batch_reward, batch_action):
outputs = self.model(batch_state).gather(1, batch_action.unsqueeze(1)).squeeze(1)
next_outputs = self.model(batch_next_state).detach().max(1)[0]
target = self.gamma*next_outputs + batch_reward
td_loss = F.smooth_l1_loss(outputs, target)
self.optimizer.zero_grad()
td_loss.backward(retain_variables = True)
self.optimizer.step()
Run Code Online (Sandbox Code Playgroud)
小智 10
我遇到了同样的问题。这个解决方案对我有用。
td_loss.backward(retain_graph = True)
Run Code Online (Sandbox Code Playgroud)
有效。
归档时间: |
|
查看次数: |
3075 次 |
最近记录: |