Iul*_*uli 1 artificial-intelligence python-3.x pytorch
I'm trying to build an AI with PyTorch, but I get the following error:
RuntimeError: gather_out_cpu(): Expected dtype int64 for index
Here is my function:
def learn(self, batch_state, batch_next_state, batch_reward, batch_action):
    outputs = self.model(batch_state).gather(1, batch_action.unsqueeze(1)).squeeze(1)
    next_outputs = self.model(batch_next_state).detach().max(1)[0]
    target = self.gamma * next_outputs + batch_reward
    td_loss = F.smooth_l1_loss(outputs, target)
    self.optimizer.zero_grad()
    td_loss.backward(retain_variables = True)
    self.optimizer.step()
You need to change the dtype of the batch_action tensor to int64 before passing it to torch.gather.
def learn(...):
    batch_action = batch_action.type(torch.int64)
    outputs = ...
    ...

# or

outputs = self.model(batch_state).gather(1, batch_action.type(torch.int64).unsqueeze(1)).squeeze(1)
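To see the fix in isolation, here is a minimal sketch that does not depend on the asker's model. The tensors q_values and batch_action are hypothetical stand-ins for self.model(batch_state) and the sampled actions, assuming the actions arrive as floats (which would explain the error). It uses .long(), which is shorthand for converting to torch.int64.

```python
import torch

# Hypothetical stand-in for self.model(batch_state): shape (batch, n_actions).
q_values = torch.tensor([[0.1, 0.9, 0.3],
                         [0.7, 0.2, 0.5]])

# Hypothetical actions stored as floats, which gather() rejects as an index.
batch_action = torch.tensor([1.0, 0.0])

# Cast to int64 first; .long() is equivalent to .type(torch.int64).
index = batch_action.long().unsqueeze(1)          # shape (batch, 1)

# Pick, for each row, the Q-value of the action that was taken.
chosen = q_values.gather(1, index).squeeze(1)     # shape (batch,)
print(index.dtype, chosen)
```

If the actions are always used as indices, it may be simpler to create the tensor as int64 in the first place, so that no cast is needed inside learn.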