Post by Yil*_* L.

PyTorch DQN / DDQN using .detach() produces a very large (exponentially growing) loss and does not learn at all

This is my implementation of DQN and DDQN for CartPole-v0, which I believe is correct.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import gym
import torch.optim as optim
import random
import os
import time


class NETWORK(torch.nn.Module):
    def __init__(self, input_dim: int, output_dim: int, hidden_dim: int) -> None:

        super(NETWORK, self).__init__()

        self.layer1 = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU()
        )

        self.layer2 = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.ReLU()
        )

        self.final = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.final(x)
        return x
…
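Since the question is about `.detach()` in the TD target, here is a minimal sketch of how the DQN and DDQN targets are usually computed with a separate target network, so that no gradient flows through the bootstrap term. All names (`online`, `target`, batch shapes) are illustrative assumptions, not taken from the truncated code above:

```python
import torch
import torch.nn as nn

# Illustrative setup: a tiny online net and a target net with the same weights.
torch.manual_seed(0)
online = nn.Linear(4, 2)   # 4 state dims, 2 actions (CartPole-like)
target = nn.Linear(4, 2)
target.load_state_dict(online.state_dict())

gamma = 0.99
states = torch.randn(8, 4)
next_states = torch.randn(8, 4)
actions = torch.randint(0, 2, (8, 1))
rewards = torch.randn(8, 1)
dones = torch.zeros(8, 1)

# Q(s, a) for the actions actually taken; gradients flow here.
q_sa = online(states).gather(1, actions)

# DQN target: max over the *target* network. Wrapping in no_grad (or calling
# .detach()) keeps the bootstrap term out of the computation graph.
with torch.no_grad():
    max_next_q = target(next_states).max(dim=1, keepdim=True).values
    # DDQN variant: action selected by the online net, evaluated by the target net.
    best_actions = online(next_states).argmax(dim=1, keepdim=True)
    ddqn_next_q = target(next_states).gather(1, best_actions)

td_target = rewards + gamma * (1 - dones) * max_next_q

# Huber loss is the common choice; MSE on unclipped targets can blow up.
loss = nn.functional.smooth_l1_loss(q_sa, td_target)
loss.backward()  # gradients reach only the online network
```

If the loss still grows exponentially with this structure, the usual suspects are bootstrapping from the online network itself (no target network or no periodic sync) or computing the target without `.detach()`/`no_grad`, which lets gradients chase a moving target.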

reinforcement-learning q-learning pytorch dqn

2 votes
1 solution
1081 views
