Deep Deterministic Policy Gradient (DDPG)
Overview
DDPG is a popular DRL algorithm for continuous control. It extends DQN to work with the continuous action space by introducing a deterministic actor that directly outputs continuous actions. DDPG also combines techniques from DQN, such as the replay buffer and target network.
Original paper:
- Continuous control with deep reinforcement learning (Lillicrap et al., 2016)
Reference resources:
Implemented Variants
| Variants Implemented | Description | 
|---|---|
| ddpg_continuous_action.py,  docs | For continuous action space | 
Below is our single-file implementation of DDPG:
ddpg_continuous_action.py
The ddpg_continuous_action.py has the following features:
- For continuous action space
- Works with the Box observation space of low-level features
- Works with the Box(continuous) action space
Usage
poetry install
poetry install -E pybullet
python cleanrl/ddpg_continuous_action.py --help
python cleanrl/ddpg_continuous_action.py --env-id HopperBulletEnv-v0
poetry install -E mujoco # only works in Linux
python cleanrl/ddpg_continuous_action.py --env-id Hopper-v3
Explanation of the logged metrics
Running python cleanrl/ddpg_continuous_action.py will automatically record various metrics, such as actor and value losses, in TensorBoard. Below is the documentation for these metrics:
- charts/episodic_return: episodic return of the game
- charts/SPS: number of steps per second
- losses/qf1_loss: the mean squared error (MSE) between the Q values at timestep \(t\) and the Bellman update target estimated using the reward \(r_t\) and the Q values at timestep \(t+1\), thus minimizing the one-step temporal difference. Formally, it can be expressed by the equation below. $$ J(\theta^{Q}) = \mathbb{E}_{(s,a,r,s') \sim \mathcal{D}} \big[ (Q(s, a) - y)^2 \big], $$ with the Bellman update target \(y = r + \gamma \, Q^{'}(s', a')\), where \(a' = \mu^{'}(s')\), \(Q^{'}\) and \(\mu^{'}\) are the target critic and the target actor, and \(\mathcal{D}\) is the replay buffer.
- losses/actor_loss: implemented as -qf1(data.observations, actor(data.observations)).mean(); it is the negative average Q value calculated from 1) the observations and 2) the actions computed by the actor for these observations. By minimizing actor_loss, the optimizer updates the actor's parameters using the following gradient (Lillicrap et al., 2016, Algorithm 1): $$ \nabla_{\theta^{\mu}} J \approx \frac{1}{N}\sum_i \nabla_{a} Q(s, a \mid \theta^{Q})\big|_{s=s_i,\, a=\mu(s_i)} \, \nabla_{\theta^{\mu}} \mu(s \mid \theta^{\mu})\big|_{s=s_i} $$
- losses/qf1_values: implemented as qf1(data.observations, data.actions).view(-1); it is the average Q value of the data sampled from the replay buffer, which is useful for gauging whether under- or over-estimation happens.
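The three metrics above can be reproduced on toy data. The following is a minimal NumPy sketch, not CleanRL's code: td_target, qf1_loss, actor_loss, dpg_gradient and numerical_grad are hypothetical helpers, and the critic is a made-up quadratic Q(s, a) = -||a - K s||^2 paired with a linear actor mu(s) = W s, chosen only so that the deterministic policy gradient has a closed form that can be checked against finite differences. The (1 - d) termination mask in the target is an assumption that the formula above omits.

```python
import numpy as np


def td_target(rewards, dones, q_next, gamma=0.99):
    # y = r + gamma * (1 - d) * Q'(s', mu'(s')); q_next holds the target-critic values.
    # The (1 - d) mask for terminal transitions is an assumption of this sketch.
    return rewards + gamma * (1.0 - dones) * q_next


def qf1_loss(q_values, targets):
    # Mean squared one-step TD error between Q(s, a) and the fixed target y.
    return np.mean((q_values - targets) ** 2)


# Critic side: three toy transitions as if sampled from a replay buffer.
rewards = np.array([1.0, 0.5, -1.0])
dones = np.array([0.0, 0.0, 1.0])
q_next = np.array([10.0, 2.0, 7.0])    # Q'(s', mu'(s')) from the target networks
q = np.array([10.0, 3.0, 0.0])         # Q(s, a) for the stored actions
y = td_target(rewards, dones, q_next)  # approximately [10.9, 2.48, -1.0]
loss = qf1_loss(q, y)
qf1_values = q.mean()                  # what losses/qf1_values would log

# Actor side: hypothetical linear actor mu(s) = W s and quadratic critic
# Q(s, a) = -||a - K s||^2, whose action-gradient is -2 (a - K s).
rng = np.random.default_rng(0)
obs_dim, act_dim, batch = 4, 2, 32
W = rng.normal(size=(act_dim, obs_dim))
K = rng.normal(size=(act_dim, obs_dim))
S = rng.normal(size=(batch, obs_dim))


def actor_loss(W):
    # -mean_i Q(s_i, mu(s_i)), the quantity the actor optimizer minimizes.
    actions = S @ W.T
    return np.mean(np.sum((actions - S @ K.T) ** 2, axis=1))


def dpg_gradient(W):
    # (1/N) sum_i grad_a Q(s_i, a)|_{a=mu(s_i)} * grad_W mu(s_i), via the chain rule.
    actions = S @ W.T
    grad_a_q = -2.0 * (actions - S @ K.T)  # shape (batch, act_dim)
    return grad_a_q.T @ S / batch          # shape (act_dim, obs_dim), like W


def numerical_grad(f, W, eps=1e-6):
    # Central finite differences, used only to check dpg_gradient.
    grad = np.zeros_like(W)
    for idx in np.ndindex(*W.shape):
        w_plus, w_minus = W.copy(), W.copy()
        w_plus[idx] += eps
        w_minus[idx] -= eps
        grad[idx] = (f(w_plus) - f(w_minus)) / (2 * eps)
    return grad


# Minimizing actor_loss follows the negative of the deterministic policy gradient.
print(np.allclose(numerical_grad(actor_loss, W), -dpg_gradient(W), atol=1e-5))  # True
```

The finite-difference check only confirms the chain-rule expression; in ddpg_continuous_action.py the same gradient is obtained by autograd when actor_loss is backpropagated through qf1 into the actor.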
Implementation details
Our ddpg_continuous_action.py is based on OurDDPG.py from sfujim/TD3, which has the following implementation differences from (Lillicrap et al., 2016):
- ddpg_continuous_action.py uses Gaussian exploration noise \(\mathcal{N}(0, 0.1)\), while (Lillicrap et al., 2016) uses an Ornstein-Uhlenbeck process with \(\theta=0.15\) and \(\sigma=0.2\).
- ddpg_continuous_action.py runs the experiments using the openai/gym MuJoCo environments, while (Lillicrap et al., 2016) uses their proprietary MuJoCo environments.
- ddpg_continuous_action.py uses the following architecture:

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class QNetwork(nn.Module):
    def __init__(self, env):
        super(QNetwork, self).__init__()
        # The action is concatenated with the observation at the input layer.
        self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod() + np.prod(env.single_action_space.shape), 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 1)

    def forward(self, x, a):
        x = torch.cat([x, a], 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class Actor(nn.Module):
    def __init__(self, env):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod(), 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc_mu = nn.Linear(256, np.prod(env.single_action_space.shape))

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return torch.tanh(self.fc_mu(x))
```

while (Lillicrap et al., 2016, see Appendix 7 EXPERIMENT DETAILS) uses the following architecture, which differs in its hidden sizes (400 and 300 units) and in feeding the action into the critic only at the second layer:

```python
class QNetwork(nn.Module):
    def __init__(self, env):
        super(QNetwork, self).__init__()
        self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod(), 400)
        # The action is concatenated only after the first hidden layer.
        self.fc2 = nn.Linear(400 + np.prod(env.single_action_space.shape), 300)
        self.fc3 = nn.Linear(300, 1)

    def forward(self, x, a):
        x = F.relu(self.fc1(x))
        x = torch.cat([x, a], 1)
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class Actor(nn.Module):
    def __init__(self, env):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod(), 400)
        self.fc2 = nn.Linear(400, 300)
        self.fc_mu = nn.Linear(300, np.prod(env.single_action_space.shape))

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return torch.tanh(self.fc_mu(x))
```
- ddpg_continuous_action.py uses the following learning rates:

```python
import torch.optim as optim

q_optimizer = optim.Adam(list(qf1.parameters()), lr=3e-4)
actor_optimizer = optim.Adam(list(actor.parameters()), lr=3e-4)
```

while (Lillicrap et al., 2016, see Appendix 7 EXPERIMENT DETAILS) uses the following learning rates:

```python
# 1e-3 for the critic and 1e-4 for the actor
q_optimizer = optim.Adam(list(qf1.parameters()), lr=1e-3)
actor_optimizer = optim.Adam(list(actor.parameters()), lr=1e-4)
```
- ddpg_continuous_action.py uses --batch-size=256 --tau=0.005, while (Lillicrap et al., 2016, see Appendix 7 EXPERIMENT DETAILS) uses --batch-size=64 --tau=0.001.
- ddpg_continuous_action.py also adds support for continuous environments whose action-space bounds are not \([-1,1]\) or are asymmetric. The case where the bounds are not \([-1,1]\) is handled in DDPG.py (Fujimoto et al., 2018) as follows:

```python
class Actor(nn.Module):
    ...
    def forward(self, state):
        a = F.relu(self.l1(state))
        a = F.relu(self.l2(a))
        return self.max_action * torch.tanh(self.l3(a))  # Scale from [-1,1] to [-action_high, action_high]
```

On the other hand, in CleanRL's ddpg_continuous_action.py, the midpoint and the half-range of the action space are computed as action_bias and action_scale respectively. These values are in turn used to scale the output of a tanh activation function in the actor to the original action space range:

```python
class Actor(nn.Module):
    def __init__(self, env):
        ...
        # action rescaling
        self.register_buffer(
            "action_scale", torch.FloatTensor((env.action_space.high - env.action_space.low) / 2.0)
        )
        self.register_buffer(
            "action_bias", torch.FloatTensor((env.action_space.high + env.action_space.low) / 2.0)
        )

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = torch.tanh(self.fc_mu(x))
        return x * self.action_scale + self.action_bias  # Scale from [-1,1] to [action_low, action_high]
```
Additionally, when drawing the exploration noise that is added to the actions produced by the actor, CleanRL's ddpg_continuous_action.py centers the distribution the noise is sampled from at action_bias and sets its scale (standard deviation) to action_scale * exploration_noise.
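The exploration schemes and the soft target update listed above are easier to compare side by side. Below is a minimal NumPy sketch, assuming hypothetical helpers gaussian_noise_action, OrnsteinUhlenbeckNoise and soft_update (none of these are CleanRL names). The Gaussian noise follows the description above (mean action_bias, standard deviation action_scale * exploration_noise); the Ornstein-Uhlenbeck process uses \(\theta=0.15\) and \(\sigma=0.2\) with an assumed time step dt=1; the soft update is the Polyak averaging of target-network parameters controlled by --tau.

```python
import numpy as np


def gaussian_noise_action(action, action_bias, action_scale, low, high,
                          exploration_noise=0.1, rng=None):
    # Noise ~ N(action_bias, action_scale * exploration_noise), as described above;
    # for symmetric bounds action_bias is 0, i.e. N(0, 0.1 * action_scale).
    rng = np.random.default_rng() if rng is None else rng
    noise = rng.normal(loc=action_bias, scale=action_scale * exploration_noise)
    return np.clip(action + noise, low, high)  # keep the action within bounds


class OrnsteinUhlenbeckNoise:
    # Temporally correlated noise dx = theta * (mu - x) dt + sigma dW,
    # discretized with Euler-Maruyama; dt=1.0 is an assumption of this sketch.
    def __init__(self, size, theta=0.15, sigma=0.2, mu=0.0, dt=1.0, rng=None):
        self.theta, self.sigma, self.mu, self.dt = theta, sigma, mu, dt
        self.rng = np.random.default_rng() if rng is None else rng
        self.x = np.full(size, mu, dtype=np.float64)

    def reset(self):
        # Restart the process at its mean, e.g. at the start of an episode.
        self.x = np.full_like(self.x, self.mu)

    def sample(self):
        dw = np.sqrt(self.dt) * self.rng.normal(size=self.x.shape)
        self.x = self.x + self.theta * (self.mu - self.x) * self.dt + self.sigma * dw
        return self.x.copy()


def soft_update(target_params, online_params, tau):
    # Polyak averaging of the target network: theta' <- tau * theta + (1 - tau) * theta'.
    return [tau * p + (1.0 - tau) * tp for tp, p in zip(target_params, online_params)]


rng = np.random.default_rng(0)
low, high = np.full(3, -0.4), np.full(3, 0.4)  # Humanoid-v2-style bounds (3 dims here)
action_scale, action_bias = (high - low) / 2.0, (high + low) / 2.0
noisy = gaussian_noise_action(np.array([0.39, 0.0, -0.39]), action_bias,
                              action_scale, low, high, rng=rng)

ou = OrnsteinUhlenbeckNoise(size=3, rng=rng)
ou_samples = np.stack([ou.sample() for _ in range(5)])  # correlated across steps

target, online = [np.zeros(2)], [np.ones(2)]
for _ in range(100):
    target = soft_update(target, online, tau=0.005)  # --tau=0.005 as in ddpg_continuous_action.py
# After n updates toward a fixed online value, target = 1 - (1 - tau) ** n.
```

With tau=0.005 the target parameters move 0.5% of the way toward the online parameters per update, versus 0.1% with the paper's tau=0.001, so the targets in ddpg_continuous_action.py track the learned networks faster.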
Info
Note that Humanoid-v2, InvertedPendulum-v2, and Pusher-v2 have action space bounds that are not the standard [-1, 1]. See below.
Ant-v2 Observation space: Box(-inf, inf, (111,), float64) Action space: Box(-1.0, 1.0, (8,), float32)
HalfCheetah-v2 Observation space: Box(-inf, inf, (17,), float64) Action space: Box(-1.0, 1.0, (6,), float32)
Hopper-v2 Observation space: Box(-inf, inf, (11,), float64) Action space: Box(-1.0, 1.0, (3,), float32)
Humanoid-v2 Observation space: Box(-inf, inf, (376,), float64) Action space: Box(-0.4, 0.4, (17,), float32)
InvertedDoublePendulum-v2 Observation space: Box(-inf, inf, (11,), float64) Action space: Box(-1.0, 1.0, (1,), float32)
InvertedPendulum-v2 Observation space: Box(-inf, inf, (4,), float64) Action space: Box(-3.0, 3.0, (1,), float32)
Pusher-v2 Observation space: Box(-inf, inf, (23,), float64) Action space: Box(-2.0, 2.0, (7,), float32)
Reacher-v2 Observation space: Box(-inf, inf, (11,), float64) Action space: Box(-1.0, 1.0, (2,), float32)
Swimmer-v2 Observation space: Box(-inf, inf, (8,), float64) Action space: Box(-1.0, 1.0, (2,), float32)
Walker2d-v2 Observation space: Box(-inf, inf, (17,), float64) Action space: Box(-1.0, 1.0, (6,), float32)
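Using the bounds listed above, the rescaling can be checked numerically. The following NumPy sketch mirrors the action_scale/action_bias computation shown for CleanRL's Actor; action_rescaling and rescale are hypothetical helper names, and the asymmetric bounds [0, 1] and [-2, 4] are made-up values for illustration, since all the listed environments have symmetric bounds. For comparison, it also evaluates the max_action * tanh approach of DDPG.py, assuming max_action is taken as the upper bound of the first action dimension.

```python
import numpy as np


def action_rescaling(low, high):
    # Per-dimension half-range and midpoint of the action space.
    low = np.asarray(low, dtype=np.float32)
    high = np.asarray(high, dtype=np.float32)
    return (high - low) / 2.0, (high + low) / 2.0  # action_scale, action_bias


def rescale(tanh_output, action_scale, action_bias):
    # Map a tanh output in [-1, 1] to [low, high].
    return tanh_output * action_scale + action_bias


bounds = {
    "Hopper-v2": (-1.0, 1.0),
    "Humanoid-v2": (-0.4, 0.4),
    "InvertedPendulum-v2": (-3.0, 3.0),
    "Pusher-v2": (-2.0, 2.0),
}
for env_id, (low, high) in bounds.items():
    scale, bias = action_rescaling(low, high)
    print(env_id, float(scale), float(bias))  # symmetric bounds: bias is 0

# Hypothetical asymmetric bounds: dimension 0 in [0, 1], dimension 1 in [-2, 4].
low, high = np.array([0.0, -2.0]), np.array([1.0, 4.0])
scale, bias = action_rescaling(low, high)        # scale [0.5, 3.0], bias [0.5, 1.0]
extremes = np.array([[-1.0, -1.0], [1.0, 1.0]])  # tanh saturating at -1 and +1
mapped = rescale(extremes, scale, bias)          # [[0.0, -2.0], [1.0, 4.0]]

# max_action * tanh with max_action = high[0] = 1.0 (an assumption) gives [-1, 1]
# in both dimensions: below dimension 0's lower bound, and only a third of
# dimension 1's range.
naive = float(high[0]) * extremes
```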
Experiment results
To run benchmark experiments, see benchmark/ddpg.sh. Specifically, execute the following command:
Below are the average episodic returns for ddpg_continuous_action.py (3 random seeds). To ensure the quality of the implementation, we compared the results against (Fujimoto et al., 2018).
| Environment | ddpg_continuous_action.py | OurDDPG.py (Fujimoto et al., 2018, Table 1) | DDPG.py using settings from (Lillicrap et al., 2016) in (Fujimoto et al., 2018, Table 1) |
|---|---|---|---|
| HalfCheetah | 9382.32 ± 1395.52 | 8577.29 | 3305.60 | 
| Walker2d | 1598.35 ± 862.66 | 3098.11 | 1843.85 | 
| Hopper | 1313.43 ± 684.46 | 1860.02 | 2020.46 | 
| Humanoid | 897.74 ± 281.87 | not available | |
| Pusher | -34.45 ± 4.47 | not available | |
| InvertedPendulum | 645.67 ± 270.31 | 1000.00 ± 0.00 | |
Info
Note that ddpg_continuous_action.py uses the gym MuJoCo v2 environments, while OurDDPG.py (Fujimoto et al., 2018) uses the gym MuJoCo v1 environments. According to openai/gym#834, the gym MuJoCo v2 environments should be equivalent to the gym MuJoCo v1 environments.
Also note that the performance of our ddpg_continuous_action.py seems to be worse than the reference implementation on Walker2d and Hopper. This is likely due to openai/gym#938. Reproducing the gym MuJoCo v1 environments would be difficult because they have long been deprecated.
One other factor could cause the performance difference: the original code reported the average episodic return using deterministic evaluation (i.e., without exploration noise), see sfujim/TD3/main.py#L15-L32, whereas we report the episodic return during training, where exploration noise is applied and the policy gets updated between environment steps.
Learning curves:
 
 
 
 
 
 
ddpg_continuous_action_jax.py
The ddpg_continuous_action_jax.py has the following features:
- Uses JAX, Flax, and Optax instead of torch. ddpg_continuous_action_jax.py is roughly 2.5-4x faster than ddpg_continuous_action.py.
- For continuous action space
- Works with the Box observation space of low-level features
- Works with the Box(continuous) action space
Usage
poetry install -E "mujoco jax"
poetry run pip install --upgrade "jax[cuda]==0.3.14" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
poetry run python -c "import mujoco_py"
python cleanrl/ddpg_continuous_action_jax.py --help
poetry install -E mujoco # only works in Linux
python cleanrl/ddpg_continuous_action_jax.py --env-id Hopper-v3
Explanation of the logged metrics
See related docs for ddpg_continuous_action.py.
Implementation details
See related docs for ddpg_continuous_action.py.
Experiment results
To run benchmark experiments, see benchmark/ddpg.sh. Specifically, execute the following command:
Below are the average episodic returns for ddpg_continuous_action_jax.py (3 random seeds). To ensure the quality of the implementation, we compared the results against (Fujimoto et al., 2018).
| Environment | ddpg_continuous_action_jax.py | ddpg_continuous_action.py | OurDDPG.py (Fujimoto et al., 2018, Table 1) |
|---|---|---|---|
| HalfCheetah | 9910.53 ± 673.49 | 9382.32 ± 1395.52 | 8577.29 | 
| Walker2d | 1397.60 ± 677.12 | 1598.35 ± 862.66 | 3098.11 |
| Hopper | 1603.5 ± 727.281 | 1313.43 ± 684.46 | 1860.02 | 
Info
Note that we ran the ddpg_continuous_action_jax.py experiments on an RTX 3060 Ti (~810 SPS) and the ddpg_continuous_action.py experiments on an RTX 2060 (~241 SPS). Running ddpg_continuous_action.py on the RTX 3060 Ti raises its SPS from 241 to 325, meaning that on the same hardware, ddpg_continuous_action_jax.py would be roughly 810/325 ≈ 2.5x faster. However, because of the overhead of --capture-video that both scripts suffer, we suspect ddpg_continuous_action_jax.py would be 3x-4x faster when --capture-video is disabled.
Learning curves:
 
 
 
 
 
 
Tracked experiments and game play videos:
- Lillicrap, T.P., Hunt, J.J., Pritzel, A., Heess, N.M., Erez, T., Tassa, Y., Silver, D., & Wierstra, D. (2016). Continuous control with deep reinforcement learning. CoRR, abs/1509.02971. https://arxiv.org/abs/1509.02971
- Fujimoto, S., Hoof, H.V., & Meger, D. (2018). Addressing Function Approximation Error in Actor-Critic Methods. ArXiv, abs/1802.09477. https://arxiv.org/abs/1802.09477