Locally Estimated Global Perturbations are Better than Local Perturbations for Federated Sharpness-aware Minimization
| 📑 Paper | 🐱 Github Repo |
Ziqing Fan1,2, Shengchao Hu1,2, Jiangchao Yao1,2, Gang Niu3, Ya Zhang1,2, Masashi Sugiyama3,4, Yanfeng Wang1,2
1 Shanghai Jiao Tong University, 2 Shanghai AI Laboratory, 3 RIKEN AIP, 4 The University of Tokyo.
In federated learning (FL), the multi-step updates and data heterogeneity among clients often lead to a loss landscape with sharper minima, degrading the performance of the resulting global model. Prevalent federated approaches incorporate sharpness-aware minimization (SAM) into local training to mitigate this problem. However, the local loss landscapes may not accurately reflect the flatness of the global loss landscape in heterogeneous environments; as a result, minimizing local sharpness with perturbations calculated on client data may fail to align the efficacy of SAM in FL with that in centralized training. To overcome this challenge, we propose FedLESAM, a novel algorithm that locally estimates the direction of the global perturbation on the client side as the difference between the global models received in the previous active round and the current round. Besides the improved quality, FedLESAM also speeds up federated SAM-based approaches, since it performs only one backpropagation per iteration.
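For intuition, here is a minimal PyTorch sketch of one FedLESAM local iteration based on the description above. The names `fedlesam_local_step`, `old_global`, and `new_global` are illustrative assumptions, not the repo's actual API, and the real implementation may differ in details:

```python
import torch

def fedlesam_local_step(model, loss_fn, batch, optimizer,
                        old_global, new_global, rho=0.5):
    # old_global / new_global: lists of tensors holding the global weights
    # received in the previous active round and the current round, in the
    # same order as model.parameters().
    x, y = batch

    # Estimated global perturbation direction: old_global - new_global
    # (gradient descent moves opposite to the gradient, so this difference
    # approximates the direction of the global gradient).
    diffs = [wo - wn for wo, wn in zip(old_global, new_global)]
    scale = rho / (torch.norm(torch.stack([d.norm() for d in diffs])).item() + 1e-12)

    # Perturb the local weights along the estimated direction; unlike
    # vanilla SAM, no extra backward pass is needed for the perturbation.
    with torch.no_grad():
        for p, d in zip(model.parameters(), diffs):
            p.add_(d, alpha=scale)

    # Single backpropagation at the perturbed point.
    loss = loss_fn(model(x), y)
    optimizer.zero_grad()
    loss.backward()

    # Undo the perturbation, then apply the sharpness-aware gradient.
    with torch.no_grad():
        for p, d in zip(model.parameters(), diffs):
            p.sub_(d, alpha=scale)
    optimizer.step()
    return loss.item()
```

Because the perturbation comes from already-available global models rather than an extra gradient computation, each iteration costs one backward pass instead of the two required by vanilla SAM.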
Here we provide implementations of the following methods on the CIFAR-10 and CIFAR-100 datasets:
FedAvg: Communication-Efficient Learning of Deep Networks from Decentralized Data
FedProx: Federated Optimization in Heterogeneous Networks
FedAdam: Adaptive Federated Optimization
SCAFFOLD: SCAFFOLD: Stochastic Controlled Averaging for Federated Learning
FedDyn: Federated Learning Based on Dynamic Regularization
FedCM: FedCM: Federated Learning with Client-level Momentum
FedSAM/MoFedSAM: Generalized Federated Learning via Sharpness Aware Minimization
FedSkip (coming soon)
FedMR (coming soon)
FedGELA (coming soon)
FedLESAM, FedLESAM-S, FedLESAM-D: Locally Estimated Global Perturbations are Better than Local Perturbations for Federated Sharpness-aware Minimization
Here is an example command to start training with one of the algorithms:
CUDA_VISIBLE_DEVICES=0 python train.py --non-iid --dataset CIFAR10 --model ResNet18 --split-rule Dirichlet --split-coef 0.6 --active-ratio 0.1 --total-client 100 --batchsize 50 --rho 0.5 --method FedLESAM-S --local-epochs 5 --comm-rounds 800
For the best results, you may need to tune the hyperparameter rho (--rho).
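The `--split-rule Dirichlet --split-coef 0.6` flags control how heterogeneous the client shards are. Below is a minimal sketch of the standard Dirichlet partitioning recipe commonly used in FL benchmarks; `dirichlet_split` is a hypothetical helper, not necessarily how train.py implements it:

```python
import numpy as np

def dirichlet_split(labels, num_clients=100, alpha=0.6, seed=0):
    """Partition sample indices into non-IID client shards by drawing,
    for each class, client proportions from Dirichlet(alpha).
    Smaller alpha -> more heterogeneous clients."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    client_indices = [[] for _ in range(num_clients)]
    for c in np.unique(labels):
        idx = rng.permutation(np.where(labels == c)[0])
        proportions = rng.dirichlet(alpha * np.ones(num_clients))
        # Cut this class's samples at the cumulative proportions.
        cuts = (np.cumsum(proportions)[:-1] * len(idx)).astype(int)
        for client, shard in enumerate(np.split(idx, cuts)):
            client_indices[client].extend(shard.tolist())
    return client_indices
```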
As for FedSMOO and FedGAMMA, the authors have recently open-sourced their code. Please refer to the FedSMOO repo, which may be more faithful to their algorithms. Notably, we plan to implement our previous works FedSkip (ICDM 2022), FedMR (TMLR 2023), and FedGELA (NeurIPS 2023) in this repo. Feel free to use these methods for heterogeneous data in federated learning.
If you find this work relevant to your research or applications, please feel free to cite our paper:
@inproceedings{FedLESAM,
  title={Locally Estimated Global Perturbations are Better than Local Perturbations for Federated Sharpness-aware Minimization},
  author={Fan, Ziqing and Hu, Shengchao and Yao, Jiangchao and Niu, Gang and Zhang, Ya and Sugiyama, Masashi and Wang, Yanfeng},
  booktitle={International Conference on Machine Learning},
  year={2024}
}
This repo benefits from FedSMOO. Thanks for their wonderful work!