This is the official implementation for SFR-on on ImageNet using DiT. The code is based on the DiT official implementation.
Figure 1: Samples from the original DiT-XL-2-256x256.
Figure 2: Samples from the unlearned DiT. We successfully remove the knowledge of the Golden retriever class while preserving the knowledge of other classes.
Install the requirements using a conda environment:
conda env create -f environment.yml
conda activate DiT
- Dataset ImageNet: Download the ImageNet dataset and place it at data_path.
- Pre-trained DiT checkpoints: Either download the pre-trained DiT model (XL-2-256x256 in the paper) in advance, or let the weights be downloaded automatically according to the model you use.
- First, generate the Fisher diagonal used to build the saliency map.
python generate_fisher.py --model DiT-XL/2 --data-path $data_path --batch-size $bz --ckpt $ckpt --n-iters $iters --forget-class $cls --mask-path $mask_path
- data_path: path to the ImageNet dataset.
- bz: batch size. (default: 1, can be increased for better results but costs more memory)
- ckpt: path to the pre-trained DiT checkpoint.
- iters: number of iterations used to estimate the Fisher diagonal. (default: 2000, can be increased for better results but costs more time)
- cls: the class to be forgotten. (such as: 207)
- mask_path: path to the saliency mask. (default: ./mask)
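For intuition, a diagonal Fisher estimate is commonly computed as the average of squared per-sample gradients on the forget-class data. The sketch below illustrates that idea on a toy linear model in plain Python. It is not the code in generate_fisher.py; the function names, the toy data, and the squared-error loss are all assumptions made for illustration.

```python
import random


def grad_squared_error(weights, x, y):
    """Gradient of 0.5 * (w . x - y)^2 with respect to each weight."""
    residual = sum(w * xi for w, xi in zip(weights, x)) - y
    return [residual * xi for xi in x]


def fisher_diagonal(weights, samples, n_iters, rng):
    """Average the squared per-sample gradient over n_iters random draws."""
    fisher = [0.0] * len(weights)
    for _ in range(n_iters):
        x, y = rng.choice(samples)
        grad = grad_squared_error(weights, x, y)
        for i, g in enumerate(grad):
            fisher[i] += g * g
    return [f / n_iters for f in fisher]


rng = random.Random(0)
# Hypothetical "forget class" data: only the first feature is ever active,
# so only the first weight should receive a non-zero Fisher value.
forget_samples = [((1.0, 0.0), 2.0), ((2.0, 0.0), 4.0)]
fisher = fisher_diagonal([0.5, 0.5], forget_samples, n_iters=200, rng=rng)
```

In this toy setting a larger n_iters only reduces the sampling noise of the estimate, which matches the note above that more iterations cost more time.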
- Next, generate the saliency map with a threshold ($\gamma$ in the paper) for unlearning.
python generate_mask.py --mask-path $mask_path --forget-class $cls --thresholds $thresholds
- mask_path: path to the saliency mask. (default: ./mask)
- cls: the class to be forgotten. (such as: 207)
- thresholds: a list of threshold values to apply for hard-coding the mask. (default: "0.5 1 3 5 10")
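As a rough illustration of turning saliency scores into a hard mask, the sketch below keeps the highest-scoring entries and zeroes the rest. It assumes each threshold is a percentage of parameters to keep; that reading is a guess and may not match how generate_mask.py uses its thresholds. The helper name top_percent_mask and the example scores are hypothetical.

```python
def top_percent_mask(scores, percent):
    """Return a 0/1 mask marking the top `percent` % of entries by score."""
    keep = max(1, round(len(scores) * percent / 100.0))
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    selected = set(ranked[:keep])
    return [1 if i in selected else 0 for i in range(len(scores))]


# Hypothetical saliency scores for ten parameters.
scores = [0.9, 0.1, 0.4, 0.05, 0.7, 0.2, 0.3, 0.6, 0.8, 0.0]

# One mask per threshold, mirroring the idea of passing a list of thresholds.
masks = {p: top_percent_mask(scores, p) for p in (10, 30, 50)}
```

Under this assumption, a larger threshold selects more parameters for the unlearning update.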
- Finally, run the forgetting training with SFR-on.
python forget.py --model DiT-XL/2 --data-path $d --batch-size $bz --ckpt $ckpt --n-iters 600 --snapshot-every 50 --lr 1e-4 --forget-class $cls --method ron --unlearn-loss ga \
--forget-alpha 1e-3 --decay-forget-alpha --remain-alpha 1.0 --mask-path $mask_path
- d: path to the ImageNet dataset.
- bz: batch size. (default: 1, can be increased for better results but costs more memory)
- ckpt: path to the pre-trained DiT checkpoint.
- n-iters: number of iterations for unlearning. (recommended range: 500 to 1000)
- snapshot-every: interval at which to visualize results. (default: 50)
- lr: learning rate. (default: 1e-4)
- forget-class: the class to be forgotten. (such as: 207)
- method: the method to generate the saliency map. (default: ron)
- unlearn-loss: the loss function to use for unlearning. (default: ga)
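To make the roles of the mask, the gradient-ascent (ga) loss, and the two alpha weights more concrete, here is a minimal sketch of one masked update step. It assumes the step descends on the remaining-data loss, ascends on the forget-class loss, and touches only parameters where the mask is 1. This combination rule is an assumption, not the logic of forget.py, and masked_unlearning_step is a hypothetical helper.

```python
def masked_unlearning_step(weights, grad_forget, grad_remain, mask,
                           lr, forget_alpha, remain_alpha):
    """Apply one update to the masked weights only (assumed update rule)."""
    new_weights = []
    for w, gf, gr, m in zip(weights, grad_forget, grad_remain, mask):
        # Descend on the remaining-data loss, ascend on the forget-class loss.
        direction = remain_alpha * gr - forget_alpha * gf
        # Entries with mask 0 are left unchanged.
        new_weights.append(w - lr * m * direction)
    return new_weights


# Toy values chosen only to show the effect of the mask.
updated = masked_unlearning_step(
    weights=[1.0, 1.0, 1.0],
    grad_forget=[2.0, 2.0, 2.0],
    grad_remain=[0.5, 0.5, 0.5],
    mask=[1, 0, 1],
    lr=0.1,
    forget_alpha=1.0,
    remain_alpha=1.0,
)
```

In this sketch the middle weight keeps its original value because its mask entry is 0, while the other two move in the direction that increases the forget-class loss.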