PyTorch Implementation of DeepMind's Human-Level Control through Deep Reinforcement Learning, Mnih et al.
This repository is an implementation of the DeepMind DQN algorithm for the OpenAI Gym Atari environments, from Mnih et al.
To train the model, run `python dqn.py --weights [pretrained weights]`. Various hyperparameters can be set in `dqn.py`. Good pretrained weights are provided in the `weights` directory, but you can also train from scratch. Experiments are saved in the `experiments` directory by default.
The details of the DQN implementation make a huge difference to performance. The following guidelines were helpful in achieving good results:
- The original paper was ambiguous about frame skipping in the environment. I originally thought the Q-network was fed the past 4 raw frames (ignoring frame skipping), but in fact it is fed the past 4 observed frames on top of frame skipping. With the default skip size of 4, that is essentially frames T, T-4, T-8, and T-12, where T is the current frame.
- Treating loss of life as the end of an episode for reward purposes was helpful (i.e. mark any frame on which a life is lost as terminal, but don't reset the environment).
- The large replay memory can be fit into a reasonable GPU or CPU memory (8GB) by storing frames as unsigned integers and by storing the original and subsequent state in the same cache. Essentially, just save all frames once in an array and sample from it as needed. This implementation also supports a memory cache split over multiple devices, but this was ultimately not needed.
- Using the Adam optimizer instead of RMSProp worked perfectly well. See the hyperparameters used in this implementation for reference. Higher learning rates worked better for easy tasks like Pong.
- I annealed epsilon from 1 to 0.1 over the first 1 million frames, and then to 0.01 over the next 30 million frames. This worked fine, but other methods anneal to 0.01 much faster.
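
As a rough illustration of the replay memory guideline above, here is a minimal sketch of a buffer that stores each observed frame once as `uint8` and builds stacked states at sampling time. It is not the code from `dqn.py`: the class name `FrameReplay`, its method names, the 84x84 frame shape, and the NumPy backing store are all assumptions made for illustration. The frames passed to `add` are assumed to be the post-frame-skip observations, so stacking 4 stored frames corresponds to T, T-4, T-8, and T-12 in raw frames. The `terminal` flag is supplied by the caller, so it can also be set on loss of life.

```python
import numpy as np


class FrameReplay:
    """Hypothetical single-array replay memory (illustrative, not the repo's class)."""

    def __init__(self, capacity, frame_shape=(84, 84), stack=4):
        self.capacity, self.stack = capacity, stack
        # Every frame is stored exactly once, as 8-bit unsigned integers.
        self.frames = np.zeros((capacity, *frame_shape), dtype=np.uint8)
        self.actions = np.zeros(capacity, dtype=np.int64)
        self.rewards = np.zeros(capacity, dtype=np.float32)
        self.terminals = np.zeros(capacity, dtype=bool)
        self.size = 0  # number of stored transitions
        self.pos = 0   # next write position (circular)

    def add(self, frame, action, reward, terminal):
        """Store the frame seen before `action`, and the outcome of that action."""
        self.frames[self.pos] = frame
        self.actions[self.pos] = action
        self.rewards[self.pos] = reward
        self.terminals[self.pos] = terminal
        self.pos = (self.pos + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def _valid(self, i):
        """True if frames i-stack+1 .. i+1 form one contiguous piece of an episode."""
        first, last = i - self.stack + 1, i + 1
        if first < 0 or last >= self.size:
            return False
        # Once the buffer is full, skip windows that span the oldest/newest seam.
        if self.size == self.capacity and first < self.pos <= last:
            return False
        # No episode boundary may fall inside the stacked state itself.
        return not self.terminals[first:i].any()

    def sample(self, batch_size, rng):
        """Return (states, actions, rewards, next_states, terminals)."""
        if self.size <= self.stack:
            raise ValueError("not enough frames stored to build a stacked state")
        picks = []
        while len(picks) < batch_size:  # simple rejection sampling
            i = int(rng.integers(self.stack - 1, self.size - 1))
            if self._valid(i):
                picks.append(i)
        picks = np.array(picks)
        offsets = np.arange(-self.stack + 1, 1)
        states = self.frames[picks[:, None] + offsets]           # (B, stack, H, W)
        next_states = self.frames[picks[:, None] + offsets + 1]  # shifted by one frame
        return (states, self.actions[picks], self.rewards[picks],
                next_states, self.terminals[picks])
```

Because the state and next state are views into the same frame array shifted by one step, nothing is duplicated in storage. Assuming the 1 million frame capacity and 84x84 frames from the original paper, the frame array takes about 7 GB, which is consistent with the 8GB figure above. For terminal transitions, `next_states` still contains whatever frame follows, so it should be masked out in the target computation.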
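
The epsilon schedule described above can be sketched as a piecewise function of the frame count. This is only an illustration: the function name `epsilon_at` is hypothetical, and linear interpolation within each phase is an assumption, since the text only says epsilon was annealed.

```python
def epsilon_at(frame: int) -> float:
    """Two-phase epsilon schedule (linear interpolation is an assumption).

    Phase 1: 1.0 -> 0.1 over the first 1 million frames.
    Phase 2: 0.1 -> 0.01 over the next 30 million frames.
    Afterwards epsilon stays at 0.01.
    """
    phase1_len = 1_000_000
    phase2_len = 30_000_000
    if frame <= 0:
        return 1.0
    if frame < phase1_len:
        return 1.0 + (0.1 - 1.0) * frame / phase1_len
    if frame < phase1_len + phase2_len:
        return 0.1 + (0.01 - 0.1) * (frame - phase1_len) / phase2_len
    return 0.01
```

At each step the agent would then take a random action with probability `epsilon_at(frame)` and the greedy action otherwise.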
Here is the metrics plot for a long training run, with orange showing the total average unclipped reward:
Here is the clipped reward over time: