We propose Weight Diffusion (W-Diff), an approach specialized for evolving domain generalization (EDG) in the domain-incremental setting. W-Diff capitalizes on the strong modeling ability of diffusion models to capture the evolving pattern of optimized classifiers across domains.
The code is implemented with Python 3.7.16 and runs on an NVIDIA GeForce RTX 4090. To try out this project, it is recommended to set up a virtual environment first.

```
# Step-by-step installation
conda create --name wdiff python=3.7.16
conda activate wdiff

# This installs the right pip and dependencies for the fresh python
conda install -y ipython pip

# Install torch, torchvision and torchaudio
pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html

# Install the remaining required packages
pip install -r requirements.txt
```
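After installation, a quick sanity check (a minimal helper sketch, not part of the project itself) can confirm the key packages are importable before launching any experiments:

```python
import importlib.util

def check_deps(packages=("torch", "torchvision", "torchaudio")):
    """Return whether each package from the install step is importable."""
    return {pkg: importlib.util.find_spec(pkg) is not None for pkg in packages}

if __name__ == "__main__":
    for pkg, ok in check_deps().items():
        print(pkg, "OK" if ok else "MISSING")
```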
- Download yearbook.pkl
- Download fmow.pkl and fmow_v1.1.tar.gz
- Download huffpost.pkl
- Download arxiv.pkl
- ONP and 2-Moons are provided in the `datasets` folder.
- rmnist is downloaded automatically when the code runs.
The data folder should be structured as follows:
```
├── datasets/
│   ├── yearbook/
│   │   ├── yearbook.pkl
│   ├── rmnist/
│   │   ├── MNIST/
│   │   ├── rmnist.pkl
│   ├── ONP/
│   │   ├── processed/
│   ├── Moons/
│   │   ├── processed/
│   ├── huffpost/
│   │   ├── huffpost.pkl
│   ├── fMoW/
│   │   ├── fmow_v1.1/
│   │   │   ├── images/
│   │   ├── fmow.pkl
│   ├── arxiv/
│   │   ├── arxiv.pkl
```
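If you prefer not to create the folders by hand, a small helper (a sketch, assuming it is run from the repository root) can lay out the same skeleton; the downloaded files then go into the matching directories:

```python
from pathlib import Path

# Subdirectories matching the layout above. MNIST/ and the fmow_v1.1 contents
# come from the data downloads themselves, but pre-creating them is harmless.
SUBDIRS = [
    "yearbook",
    "rmnist/MNIST",
    "ONP/processed",
    "Moons/processed",
    "huffpost",
    "fMoW/fmow_v1.1/images",
    "arxiv",
]

def make_dataset_tree(root="datasets"):
    """Create the expected dataset directory skeleton under `root`."""
    for sub in SUBDIRS:
        Path(root, sub).mkdir(parents=True, exist_ok=True)

if __name__ == "__main__":
    make_dataset_tree()
```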
Training and testing together:

```
# yearbook dataset
python3 main.py --cfg ./configs/eval_fix/cfg_yearbook.yaml device 0

# rmnist dataset
python3 main.py --cfg ./configs/eval_fix/cfg_rmnist.yaml device 1

# fmow dataset
python3 main.py --cfg ./configs/eval_fix/cfg_fmow.yaml device 2

# 2-Moons dataset
python3 main.py --cfg ./configs/eval_fix/cfg_moons.yaml device 3

# ONP dataset
python3 main.py --cfg ./configs/eval_fix/cfg_onp.yaml device 4

# huffpost dataset
python3 main.py --cfg ./configs/eval_fix/cfg_huffpost.yaml device 5

# arxiv dataset
python3 main.py --cfg ./configs/eval_fix/cfg_arxiv.yaml device 6
```
If you encounter `OSError: Can't load tokenizer for 'bert-base-uncased'.` when running on the Huffpost and Arxiv datasets, try prepending `HF_ENDPOINT=https://hf-mirror.com` to the python command. For example:

```
HF_ENDPOINT=https://hf-mirror.com python3 main.py --cfg ./configs/eval_fix/cfg_huffpost.yaml device 5
```
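Alternatively, the mirror can be set from inside Python (a sketch; `HF_ENDPOINT` is read by `huggingface_hub` when it initializes, so this must run before any Hugging Face library is imported):

```python
import os

# Must execute before importing transformers/huggingface_hub,
# since HF_ENDPOINT is read when those libraries initialize.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
```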
Testing with saved model checkpoints:

You can download the models trained by W-Diff here and put them into `<root_dir>/checkpoints/`.

```
# Evaluating on the ONP dataset
python3 main_test_only.py --cfg ./configs/eval_fix/cfg_onp.yaml --model_path 'abs_path_of_onp_model.pkl' device 5
```
This project is mainly based on the open-source projects Wild-Time, EvoS, and LDM. We thank the authors for making their source code publicly available.
If you find this work helpful to your research, please consider citing the paper:
```
@inproceedings{xie2024wdiff,
  title={Weight Diffusion for Future: Learn to Generalize in Non-Stationary Environments},
  author={Xie, Mixue and Li, Shuang and Xie, Binhui and Liu, Chi Harold and Liang, Jian and Sun, Zixun and Feng, Ke and Zhu, Chengwei},
  booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
  year={2024}
}
```