
Weight Diffusion for Future: Learn to Generalize in Non-Stationary Environments [NeurIPS 2024]


Overview

We propose Weight Diffusion (W-Diff), an approach specialized for evolving domain generalization (EDG) in the domain-incremental setting. W-Diff capitalizes on the strong modeling ability of diffusion models to capture the evolving pattern of optimized classifiers across domains.
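For intuition only, here is a minimal sketch of the underlying idea: a small denoising network is trained to recover flattened classifier weight vectors from noise, conditioned on information summarizing previously seen domains, so that plausible classifiers for a future domain can later be sampled. The module names, conditioning scheme, and shapes below are illustrative assumptions, not the repository's implementation.

```
# Conceptual sketch only -- NOT the actual W-Diff implementation.
# A denoiser learns to reconstruct flattened classifier weight vectors,
# conditioned on a vector summarizing earlier domains. All names, shapes,
# and the conditioning scheme are assumptions for exposition.
import torch
import torch.nn as nn

class WeightDenoiser(nn.Module):
    def __init__(self, weight_dim, cond_dim, hidden=512):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(weight_dim + cond_dim + 1, hidden), nn.SiLU(),
            nn.Linear(hidden, hidden), nn.SiLU(),
            nn.Linear(hidden, weight_dim),
        )

    def forward(self, noisy_w, cond, t):
        # the diffusion timestep enters as one extra (scaled) scalar feature
        t_feat = t.float().unsqueeze(-1) / 1000.0
        return self.net(torch.cat([noisy_w, cond, t_feat], dim=-1))

def diffusion_step(model, w, cond, alphas_cumprod, optimizer):
    # one DDPM-style training step on a batch of classifier weight vectors
    t = torch.randint(0, len(alphas_cumprod), (w.size(0),))
    a_bar = alphas_cumprod[t].unsqueeze(-1)
    noise = torch.randn_like(w)
    noisy_w = a_bar.sqrt() * w + (1 - a_bar).sqrt() * noise
    pred = model(noisy_w, cond, t)            # predict the injected noise
    loss = nn.functional.mse_loss(pred, noise)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```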

Prerequisites and Installation

  • The code is implemented with Python 3.7.16 and was run on an NVIDIA GeForce RTX 4090 GPU. To try out this project, it is recommended to set up a virtual environment first.

    # Step-by-step installation
    conda create --name wdiff python=3.7.16
    conda activate wdiff
    
    # this installs the right pip and dependencies for the fresh python
    conda install -y ipython pip
    
    # install torch, torchvision and torchaudio
    pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html
    
    # this installs required packages
    pip install -r requirements.txt
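
After installation, an optional sanity check (plain PyTorch calls, not part of the repository) confirms that the pinned CUDA build is active:

```
# Optional environment check
import torch
import torchvision

print("torch:", torch.__version__)              # expected: 1.9.1+cu111
print("torchvision:", torchvision.__version__)  # expected: 0.10.1+cu111
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
else:
    print("CUDA not available -- check the installed wheel and driver")
```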

Dataset Preparation

The data folder should be structured as follows:

```
├── datasets/
│   ├── yearbook/
│   │   ├── yearbook.pkl
│   ├── rmnist/
│   │   ├── MNIST/
│   │   ├── rmnist.pkl
│   ├── ONP/
│   │   ├── processed/
│   ├── Moons/
│   │   ├── processed/
│   ├── huffpost/
│   │   ├── huffpost.pkl
│   ├── fMoW/
│   │   ├── fmow_v1.1/
│   │   │   ├── images/
│   │   ├── fmow.pkl
│   ├── arxiv/
│   │   ├── arxiv.pkl
```
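
As a convenience (not a script shipped with the repository; the paths simply mirror the tree above), a short check can confirm the expected files are in place before training:

```
# Optional helper: verify the dataset layout shown above.
# Adjust `root` if your data folder lives elsewhere.
import os

EXPECTED = [
    "yearbook/yearbook.pkl",
    "rmnist/MNIST",
    "rmnist/rmnist.pkl",
    "ONP/processed",
    "Moons/processed",
    "huffpost/huffpost.pkl",
    "fMoW/fmow_v1.1/images",
    "fMoW/fmow.pkl",
    "arxiv/arxiv.pkl",
]

root = "datasets"
for rel in EXPECTED:
    path = os.path.join(root, rel)
    status = "ok" if os.path.exists(path) else "MISSING"
    print(f"[{status}] {path}")
```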

Running the Code

  • Training and testing together:

    # running for yearbook dataset:
    python3 main.py --cfg ./configs/eval_fix/cfg_yearbook.yaml device 0
    
    # running for rmnist dataset:
    python3 main.py --cfg ./configs/eval_fix/cfg_rmnist.yaml device 1
    
    # running for fmow dataset:
    python3 main.py --cfg ./configs/eval_fix/cfg_fmow.yaml device 2
    
    # running for 2-Moons dataset:
    python3 main.py --cfg ./configs/eval_fix/cfg_moons.yaml device 3
    
    # running for ONP dataset:
    python3 main.py --cfg ./configs/eval_fix/cfg_onp.yaml device 4
    
    # running for huffpost dataset:
    python3 main.py --cfg ./configs/eval_fix/cfg_huffpost.yaml device 5
    
    # running for arxiv dataset:
    python3 main.py --cfg ./configs/eval_fix/cfg_arxiv.yaml device 6
    

If you encounter "OSError: Can't load tokenizer for 'bert-base-uncased'." when running the code on the HuffPost and arXiv datasets, you can try prepending HF_ENDPOINT=https://hf-mirror.com to the python command. For example,

HF_ENDPOINT=https://hf-mirror.com python3 main.py --cfg ./configs/eval_fix/cfg_huffpost.yaml device 5
  • Testing with saved model checkpoints:

    You can download the models trained by W-Diff here and put them into <root_dir>/checkpoints/.

    # evaluating on ONP dataset
    python3 main_test_only.py --cfg ./configs/eval_fix/cfg_onp.yaml --model_path 'abs_path_of_onp_model.pkl' device 5
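
To replay all of the training commands listed above in sequence, a small wrapper like the following can be used (a convenience sketch, not part of the repository; the config filenames simply follow the list above):

```
# Convenience sketch: launch training/testing for every benchmark in turn.
import subprocess

DATASETS = ["yearbook", "rmnist", "fmow", "moons", "onp", "huffpost", "arxiv"]
DEVICE = 0  # GPU id passed through the config override

for name in DATASETS:
    cmd = [
        "python3", "main.py",
        "--cfg", f"./configs/eval_fix/cfg_{name}.yaml",
        "device", str(DEVICE),
    ]
    print("Running:", " ".join(cmd))
    subprocess.run(cmd, check=True)
```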

Acknowledgments

This project is mainly built on the open-source projects Wild-Time, EvoS, and LDM. We thank the authors for making their source code publicly available.

Citation

If you find this work helpful to your research, please consider citing the paper:

@inproceedings{xie2024wdiff,
  title={Weight Diffusion for Future: Learn to Generalize in Non-Stationary Environments},
  author={Mixue Xie and Shuang Li and Binhui Xie and Chi Harold Liu and Jian Liang and Zixun Sun and Ke Feng and Chengwei Zhu},
  booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
  year={2024}
}
