Skip to content

a continual learning optimizer mitigating catastrophic forgetting and loss of plasticity

License

Notifications You must be signed in to change notification settings

yumikim381/upgd-dl-project

 
 

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

Is Utility all you need?

The repo for the ETH Deep Learning project, based on the original paper "Addressing Loss of Plasticity and Catastrophic Forgetting in Continual Learning" from here.

Todo (change / delete this sentence): Here we describe how to reproduce the results. If you only want the implementation of the UPGD algorithm you can find it here:

import torch

class UPGD(torch.optim.Optimizer):
    def __init__(self, params, lr=1e-5, weight_decay=0.001, beta_utility=0.999, sigma=0.001):
        defaults = dict(lr=lr, weight_decay=weight_decay, beta_utility=beta_utility, sigma=sigma)
        super(UPGD, self).__init__(params, defaults)
    def step(self):
        global_max_util = torch.tensor(-torch.inf)
        for group in self.param_groups:
            for p in group["params"]:
                state = self.state[p]
                if len(state) == 0:
                    state["step"] = 0
                    state["avg_utility"] = torch.zeros_like(p.data)
                state["step"] += 1
                avg_utility = state["avg_utility"]
                avg_utility.mul_(group["beta_utility"]).add_(
                    -p.grad.data * p.data, alpha=1 - group["beta_utility"]
                )
                current_util_max = avg_utility.max()
                if current_util_max > global_max_util:
                    global_max_util = current_util_max
        for group in self.param_groups:
            for p in group["params"]:
                state = self.state[p]
                bias_correction_utility = 1 - group["beta_utility"] ** state["step"]
                noise = torch.randn_like(p.grad) * group["sigma"]
                scaled_utility = torch.sigmoid_((state["avg_utility"] / bias_correction_utility) / global_max_util)
                p.data.mul_(1 - group["lr"] * group["weight_decay"]).add_(
                    (p.grad.data + noise) * (1-scaled_utility),
                    alpha=-2.0*group["lr"],
                )

Installation:

1. You need to have environemnt with python 3.7:

git clone  --recursive [email protected]:mohmdelsayed/upgd.git
python3.7 -m venv .upgd
source .upgd/bin/activate

2. Install Dependencies:

python -m pip install --upgrade pip
pip install -r requirements.txt 
pip install HesScale/.
pip install .

3. Run experiment:

Label-permuted CIFAR-10/EMNIST/miniImageNet (Figure 6):

You first need to define the grid search of each method then you generate then python cmds using:

python experiments/label_permuted_cifar10.py

This would generate a list of python cmds you need to run them. After they are done, the results would be saved in logs/ in a JSON format. To plot, use the following after choosing what to plot:

python core/plot/plotter.py

Input/Label-permuted Tasks Diagnostic Statistics (Figure 5):

You first need to choose the method and the hyperparameter setting you want to run the statistics on from:

python experiments/statistics_output_permuted_cifar10.py

This would generate a list of python cmds you need to run them. After they are done, the results would be saved in logs/ in a JSON format.

Policy collapse experiment (Figure 8):

You need to choose the environment id and the seed number. In the paper, we averaged over 30 different seeds.

python core/run/rl/ppo_continuous_action_adam.py --seed 0 --env_id HalfCheetah-v4
python core/run/rl/ppo_continuous_action_upgd.py --seed 0 --env_id HalfCheetah-v4

About

a continual learning optimizer mitigating catastrophic forgetting and loss of plasticity

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages

  • Jupyter Notebook 60.1%
  • Python 38.6%
  • Shell 1.3%