From ae175752276b46cf0f5d8f85cfdde6bf09eb651b Mon Sep 17 00:00:00 2001
From: Gordon Lichtstein <72274426+generic-account@users.noreply.github.com>
Date: Fri, 13 Sep 2024 20:51:41 -0400
Subject: [PATCH] Fixed typos in util files and extended algos type hinting
 (#79)

Co-authored-by: generic-account
Co-authored-by: Abhishek Singh
---
 src/algos/DisPFL.py      | 121 +++++++++++++++++++++------------
 src/algos/base_class.py  |  42 +++++++-------
 src/algos/def_kt.py      |  38 ++++++------
 src/algos/fl_random.py   |   1 +
 src/utils/data_utils.py  |   2 +-
 src/utils/model_utils.py |   1 +
 6 files changed, 109 insertions(+), 96 deletions(-)

diff --git a/src/algos/DisPFL.py b/src/algos/DisPFL.py
index 6c5cc80..310eb92 100644
--- a/src/algos/DisPFL.py
+++ b/src/algos/DisPFL.py
@@ -6,7 +6,7 @@
 import math
 import random
 from collections import OrderedDict
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -21,34 +21,42 @@ class CommProtocol:
     Communication protocol tags for the server and clients.
     """

-    DONE = 0  # Used to signal the server that the client is done with local training
-    START = 1  # Used to signal by the server to start the current round
-    UPDATES = 2  # Used to send the updates from the server to the clients
-    LAST_ROUND = 3
-    SHARE_MASKS = 4
-    SHARE_WEIGHTS = 5
-    FINISH = 6  # Used to signal the server to finish the current round
+    DONE: int = 0  # Used to signal the server that the client is done with local training
+    START: int = 1  # Used by the server to signal the start of the current round
+    UPDATES: int = 2  # Used to send the updates from the server to the clients
+    LAST_ROUND: int = 3
+    SHARE_MASKS: int = 4
+    SHARE_WEIGHTS: int = 5
+    FINISH: int = 6  # Used to signal the server to finish the current round
 
 
 class DisPFLClient(BaseClient):
     """
     Client class for DisPFL (Distributed Personalized Federated Learning).
     """
-    def __init__(self, config) -> None:
+
+    def __init__(self, config: Dict[str, Any]) -> None:
         super().__init__(config)
+        self.params: Optional[Dict[str, Tensor]] = None
+        self.mask: Optional[OrderedDict[str, Tensor]] = None
+        self.index: Optional[int] = None
+        self.repr: Optional[OrderedDict[str, Tensor]] = None
         self.config = config
         self.tag = CommProtocol
-        self.model_save_path = f"{self.config['results_path']}/saved_models/node_{self.node_id}.pt"
-        self.dense_ratio = self.config["dense_ratio"]
-        self.anneal_factor = self.config["anneal_factor"]
-        self.dis_gradient_check = self.config["dis_gradient_check"]
-        self.server_node = 1  # leader node
-        self.num_users = config["num_users"]
-        self.neighbors = list(range(self.num_users))
+        self.model_save_path = (
+            f"{self.config['results_path']}/saved_models/node_{self.node_id}.pt"
+        )
+        self.dense_ratio: float = self.config["dense_ratio"]
+        self.anneal_factor: float = self.config["anneal_factor"]
+        self.dis_gradient_check: bool = self.config["dis_gradient_check"]
+        self.server_node: int = 1  # leader node
+        self.num_users: int = config["num_users"]
+        self.neighbors: List[int] = list(range(self.num_users))
+
         if self.node_id == 1:
             self.clients = list(range(2, self.num_users + 1))
 
-    def local_train(self):
+    def local_train(self) -> None:
         """
         Train the model locally.
         """
@@ -57,20 +65,16 @@ def local_train(self):
         )
         print(f"Node{self.node_id} train loss: {loss}, train acc: {acc}")
 
-    def local_test(self, **kwargs):
+    def local_test(self, **kwargs: Any) -> Tuple[float, float]:
         """
         Test the model locally, not to be used in the traditional FedAvg.
""" test_loss, acc = self.model_utils.test( self.model, self._test_loader, self.loss_fn, self.device ) - # TODO save the model if the accuracy is better than the best accuracy so far - # if acc > self.best_acc: - # self.best_acc = acc - # self.model_utils.save_model(self.model, self.model_save_path) return test_loss, acc - def get_trainable_params(self): + def get_trainable_params(self) -> Dict[str, Tensor]: param_dict = {} for name, param in self.model.named_parameters(): param_dict[name] = param @@ -82,13 +86,13 @@ def get_representation(self) -> OrderedDict[str, Tensor]: """ return self.model.state_dict() - def set_representation(self, representation: OrderedDict[str, Tensor]): + def set_representation(self, representation: OrderedDict[str, Tensor]) -> None: """ Set the model weights. """ self.model.load_state_dict(representation) - def fire_mask(self, masks, round_num, total_round): + def fire_mask(self, masks: OrderedDict[str, Tensor], round_num: int, total_round: int) -> Tuple[OrderedDict[str, Tensor], Dict[str, int]]: """ Fire mask method for model pruning. """ @@ -97,7 +101,7 @@ def fire_mask(self, masks, round_num, total_round): self.anneal_factor / 2 * (1 + np.cos((round_num * np.pi) / total_round)) ) new_masks = copy.deepcopy(masks) - num_remove = {} + num_remove: Dict[str, int] = {} for name in masks: num_non_zeros = torch.sum(masks[name]) num_remove[name] = math.ceil(drop_ratio * num_non_zeros) @@ -110,7 +114,7 @@ def fire_mask(self, masks, round_num, total_round): new_masks[name].view(-1)[idx[: num_remove[name]]] = 0 return new_masks, num_remove - def regrow_mask(self, masks, num_remove, gradient=None): + def regrow_mask(self, masks: OrderedDict[str, Tensor], num_remove: Dict[str, int], gradient: Optional[Dict[str, Tensor]] = None) -> OrderedDict[str, Tensor]: """ Regrow mask method for model pruning. """ @@ -138,7 +142,7 @@ def regrow_mask(self, masks, num_remove, gradient=None): new_masks[name].view(-1)[idx] = 1 return new_masks - def aggregate(self, nei_indexes, weights_lstrnd, masks_lstrnd): + def aggregate(self, nei_indexes: List[int], weights_lstrnd: List[OrderedDict[str, Tensor]], masks_lstrnd: List[OrderedDict[str, Tensor]]) -> Tuple[OrderedDict[str, Tensor], OrderedDict[str, Tensor]]: """ Aggregate the model weights. """ @@ -171,7 +175,7 @@ def aggregate(self, nei_indexes, weights_lstrnd, masks_lstrnd): w_tmp[name] = w_tmp[name] * self.mask[name].to(self.device) return w_tmp, w_p_g - def send_representations(self, representation): + def send_representations(self, representation: OrderedDict[str, Tensor]) -> None: """ Set the model. """ @@ -179,13 +183,13 @@ def send_representations(self, representation): self.comm_utils.send_signal(client_node, representation, self.tag.UPDATES) print(f"Node 1 sent average weight to {len(self.clients)} nodes") - def calculate_sparsities(self, params, tabu=None, distribution="ERK", sparse=0.5): + def calculate_sparsities(self, params: Dict[str, Tensor], tabu: Optional[List[str]] = None, distribution: str = "ERK", sparse: float = 0.5) -> Dict[str, float]: """ Calculate sparsities for model pruning. 
""" if tabu is None: tabu = [] - sparsities = {} + sparsities: Dict[str, float] = {} if distribution == "uniform": for name in params: if name not in tabu: @@ -239,14 +243,14 @@ def calculate_sparsities(self, params, tabu=None, distribution="ERK", sparse=0.5 sparsities[name] = 1 - epsilon * raw_probabilities[name] return sparsities - def init_masks(self, params, sparsities): + def init_masks(self, params: Dict[str, Tensor], sparsities: Dict[str, float]) -> OrderedDict[str, Tensor]: """ Initialize masks for model pruning. """ masks = OrderedDict() for name in params: masks[name] = zeros_like(params[name]) - dense_numel = int((1 - sparsities[name]) * numel(masks[name])) + dense_numel = int((1 - sparsities[name]) * masks[name].numel()) if dense_numel > 0: temp = masks[name].view(-1) perm = randperm(len(temp)) @@ -254,7 +258,7 @@ def init_masks(self, params, sparsities): temp[perm] = 1 return masks - def screen_gradient(self): + def screen_gradient(self) -> Dict[str, Tensor]: """ Screen gradient method for model pruning. """ @@ -273,7 +277,7 @@ def screen_gradient(self): return gradient - def hamming_distance(self, mask_a, mask_b): + def hamming_distance(self, mask_a: OrderedDict[str, Tensor], mask_b: OrderedDict[str, Tensor]) -> Tuple[int, int]: """ Calculate the Hamming distance between two masks. """ @@ -288,20 +292,20 @@ def hamming_distance(self, mask_a, mask_b): def _benefit_choose( self, - round_idx, - cur_clnt, - client_num_in_total, - client_num_per_round, - dist_local, - total_dist, - cs=False, - active_ths_rnd=None, - ): + round_idx: int, # pylint: disable=unused-argument + cur_clnt: int, + client_num_in_total: int, + client_num_per_round: int, + dist_local: Optional[np.ndarray], # pylint: disable=unused-argument + total_dist: Optional[np.ndarray], # pylint: disable=unused-argument + cs: bool = False, + active_ths_rnd: Optional[np.ndarray] = None, + ) -> np.ndarray: """ Benefit choose method for client selection. """ if client_num_in_total == client_num_per_round: - client_indexes = list(range(client_num_in_total)) + client_indexes = np.array(list(range(client_num_in_total))) return client_indexes if cs == "random": @@ -326,7 +330,7 @@ def _benefit_choose( ) return client_indexes - def model_difference(self, model_a, model_b): + def model_difference(self, model_a: OrderedDict[str, Tensor], model_b: OrderedDict[str, Tensor]) -> Tensor: """ Calculate the difference between two models. """ @@ -335,7 +339,7 @@ def model_difference(self, model_a, model_b): ) return diff - def run_protocol(self): + def run_protocol(self) -> None: """ Runs the entire training protocol. """ @@ -446,14 +450,21 @@ class DisPFLServer(BaseServer): """ Server class for DisPFL (Distributed Personalized Federated Learning). 
""" - def __init__(self, config) -> None: + + def __init__(self, config: Dict[str, Any]) -> None: super().__init__(config) + self.best_acc: float = 0 + self.round: int = 0 + self.masks: Any = 0 + self.reprs: Any = 0 self.config = config self.set_model_parameters(config) self.tag = CommProtocol - self.model_save_path = f"{self.config['results_path']}/saved_models/node_{self.node_id}.pt" - self.dense_ratio = self.config["dense_ratio"] - self.num_users = self.config["num_users"] + self.model_save_path = ( + f"{self.config['results_path']}/saved_models/node_{self.node_id}.pt" + ) + self.dense_ratio: float = self.config["dense_ratio"] + self.num_users: int = self.config["num_users"] def get_representation(self) -> OrderedDict[str, Tensor]: """ @@ -461,7 +472,7 @@ def get_representation(self) -> OrderedDict[str, Tensor]: """ return self.model.state_dict() - def send_representations(self, representations): + def send_representations(self, representations: Dict[int, OrderedDict[str, Tensor]]) -> None: """ Set the model. """ @@ -484,7 +495,7 @@ def test(self) -> float: self.model_utils.save_model(self.model, self.model_save_path) return acc - def single_round(self, epoch, active_ths_rnd): + def single_round(self, epoch: int, active_ths_rnd: np.ndarray) -> None: """ Runs the whole training procedure. """ @@ -509,13 +520,13 @@ def single_round(self, epoch, active_ths_rnd): self.users, self.tag.SHARE_WEIGHTS ) - def get_trainable_params(self): + def get_trainable_params(self) -> Dict[str, Tensor]: param_dict = {} for name, param in self.model.named_parameters(): param_dict[name] = param return param_dict - def run_protocol(self): + def run_protocol(self) -> None: """ Runs the entire training protocol. """ diff --git a/src/algos/base_class.py b/src/algos/base_class.py index f3d98df..a57288c 100644 --- a/src/algos/base_class.py +++ b/src/algos/base_class.py @@ -65,10 +65,10 @@ def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> self.dset_obj = get_dataset(self.dset, dpath=config["dpath"]) self.set_constants() - def set_constants(self): + def set_constants(self) -> None: self.best_acc = 0.0 - def setup_cuda(self, config: Dict[str, Any]): + def setup_cuda(self, config: Dict[str, Any]) -> None: # Need a mapping from rank to device id device_ids_map = config["device_ids"] node_name = "node_{}".format(self.node_id) @@ -82,7 +82,7 @@ def setup_cuda(self, config: Dict[str, Any]): self.device = torch.device("cpu") print("Using CPU") - def set_model_parameters(self, config: Dict[str, Any]): + def set_model_parameters(self, config: Dict[str, Any]) -> None: # Model related parameters optim_name = config.get("optimizer", "adam") if optim_name == "adam": @@ -112,7 +112,7 @@ def set_model_parameters(self, config: Dict[str, Any]): else: self.loss_fn = torch.nn.CrossEntropyLoss() - def set_shared_exp_parameters(self, config): + def set_shared_exp_parameters(self, config: Dict[str, Any]) -> None: if self.node_id != 0: community_type, number_of_communities = config.get( @@ -158,12 +158,12 @@ class BaseClient(BaseNode): Abstract class for all algorithms """ - def __init__(self, config, comm_utils) -> None: + def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None: super().__init__(config, comm_utils) self.server_node = 0 self.set_parameters(config) - def set_parameters(self, config): + def set_parameters(self, config: Dict[str, Any]) -> None: """ Set the parameters for the user """ @@ -209,7 +209,7 @@ def set_parameters(self, config): random.seed(seed) 
         numpy.random.seed(seed)
 
-    def set_data_parameters(self, config):
+    def set_data_parameters(self, config: Dict[str, Any]) -> None:
 
         # Train set and test set from original dataset
         train_dset = self.dset_obj.train_dset
@@ -388,10 +388,10 @@ def get_representation(self, **kwargs: Any) -> OrderedDict[str, Tensor] | List[Tensor]:
         """
         raise NotImplementedError
 
-    def run_protocol(self):
+    def run_protocol(self) -> None:
         raise NotImplementedError
 
-    def print_data_summary(self, train_test, test_dset, val_dset=None):
+    def print_data_summary(self, train_test: Any, test_dset: Any, val_dset: Optional[Any] = None) -> None:
         """
         Print the data summary
         """
@@ -435,13 +435,13 @@ class BaseServer(BaseNode):
     Abstract class for orchestrator
     """
 
-    def __init__(self, config, comm_utils) -> None:
+    def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None:
         super().__init__(config, comm_utils)
         self.num_users = config["num_users"]
         self.users = list(range(1, self.num_users + 1))
         self.set_data_parameters(config)
 
-    def set_data_parameters(self, config):
+    def set_data_parameters(self, config: Dict[str, Any]) -> None:
         test_dset = self.dset_obj.test_dset
         batch_size = config["batch_size"]
         self._test_loader = DataLoader(test_dset, batch_size=batch_size)
@@ -458,13 +458,13 @@ def test(self, **kwargs: Any) -> List[float]:
         """
         raise NotImplementedError
 
-    def get_model(self, **kwargs):
+    def get_model(self, **kwargs: Any) -> Any:
         """
         Get the model
         """
         raise NotImplementedError
 
-    def run_protocol(self):
+    def run_protocol(self) -> None:
         raise NotImplementedError
 
@@ -485,7 +485,7 @@ class BaseFedAvgClient(BaseClient):
     """
     Abstract class for FedAvg based algorithms
     """
-    def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager, comm_protocol=CommProtocol) -> None:
+    def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager, comm_protocol: type[CommProtocol] = CommProtocol) -> None:
         super().__init__(config, comm_utils)
         self.config = config
         self.model_save_path = "{}/saved_models/node_{}.pt".format(
@@ -500,7 +500,7 @@ def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager, comm_protocol=CommProtocol) -> None:
         keys = self.model_utils.get_last_layer_keys(self.get_model_weights())
         self.model_keys_to_ignore.extend(keys)
 
-    def local_train(self, epochs):
+    def local_train(self, epochs: int) -> Tuple[float, float]:
         """
         Train the model locally
         """
@@ -523,7 +523,7 @@ def local_train(self, epochs):
 
         return avg_loss, avg_acc
 
-    def local_test(self, **kwargs):
+    def local_test(self, **kwargs: Any) -> float:
         """
         Test the model locally, not to be used in the traditional FedAvg
         """
@@ -541,7 +541,7 @@ def get_model_weights(self) -> OrderedDict[str, Tensor]:
         """
         return {k: v.cpu() for k, v in self.model.state_dict().items()}
 
-    def set_model_weights(self, model_wts: OrderedDict[str, Tensor], keys_to_ignore=[]):
+    def set_model_weights(self, model_wts: OrderedDict[str, Tensor], keys_to_ignore: List[str] = []) -> None:
         """
         Set the model weights
         """
@@ -561,9 +561,9 @@ def weighted_aggregate(
         self,
         models_wts: Dict[int, OrderedDict[str, Tensor]],
         collab_weights_dict: Dict[int, float],
-        keys_to_ignore=[],
+        keys_to_ignore: List[str] = [],
         label_dict: Optional[Dict[int, Dict[str, int]]] = None,
-    ):
+    ) -> OrderedDict[str, Tensor]:
         """
         Aggregate the model weights
         """
@@ -651,7 +651,7 @@ def get_collaborator_weights(
     ) -> Dict[int, float]:
         raise NotImplementedError
 
-    def run_protocol(self):
+    def run_protocol(self) -> None:
         raise NotImplementedError
 
@@ -660,7 +660,7 @@ class BaseFedAvgServer(BaseServer):
     """
     Abstract class for orchestrator
     """
 
-    def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager, comm_protocol=CommProtocol) -> None:
+    def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager, comm_protocol: type[CommProtocol] = CommProtocol) -> None:
         super().__init__(config, comm_utils)
         self.tag = comm_protocol
 
diff --git a/src/algos/def_kt.py b/src/algos/def_kt.py
index 875e7cc..e721965 100644
--- a/src/algos/def_kt.py
+++ b/src/algos/def_kt.py
@@ -6,7 +6,7 @@
 import copy
 import random
 from collections import OrderedDict
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional, Tuple
 from torch import Tensor
 from utils.communication.comm_utils import CommunicationManager
 import torch.nn as nn
@@ -19,10 +19,10 @@ class CommProtocol:
     Communication protocol tags for the server and clients
     """
 
-    DONE = 0  # Used to signal the server that the client is done with local training
-    START = 1  # Used to signal by the server to start the current round
-    UPDATES = 2  # Used to send the updates from the server to the clients
-    FINISH = 3  # Used to signal the server to finish the current round
+    DONE: int = 0  # Used to signal the server that the client is done with local training
+    START: int = 1  # Used by the server to signal the start of the current round
+    UPDATES: int = 2  # Used to send the updates from the server to the clients
+    FINISH: int = 3  # Used to signal the server to finish the current round
 
 
 class DefKTClient(BaseClient):
@@ -40,7 +40,7 @@ def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) ->
         self.num_users = config["num_users"]
         self.clients = list(range(2, self.num_users + 1))
 
-    def local_train(self):
+    def local_train(self) -> None:
         """
         Train the model locally
         """
@@ -48,7 +48,7 @@ def local_train(self):
             self.model, self.optim, self.dloader, self.loss_fn, self.device
         )
 
-    def local_test(self, **kwargs):
+    def local_test(self, **kwargs: Any) -> float:
         """
         Test the model locally, not to be used in the traditional FedAvg
         """
@@ -60,7 +60,7 @@ def local_test(self, **kwargs):
             self.model_utils.save_model(self.model, self.model_save_path)
         return acc
 
-    def deep_mutual_train(self, teacher_repr):
+    def deep_mutual_train(self, teacher_repr: OrderedDict[str, Tensor]) -> None:
         """
         Train the model locally with deep mutual learning
         """
@@ -77,13 +77,13 @@ def get_representation(self) -> OrderedDict[str, Tensor]:
         """
         return self.model.state_dict()
 
-    def set_representation(self, representation: OrderedDict[str, Tensor]):
+    def set_representation(self, representation: OrderedDict[str, Tensor]) -> None:
         """
         Set the model weights
         """
         self.model.load_state_dict(representation)
 
-    def fed_avg(self, model_wts: List[OrderedDict[str, Tensor]]):
+    def fed_avg(self, model_wts: List[OrderedDict[str, Tensor]]) -> OrderedDict[str, Tensor]:
         """
         Federated averaging of model weights
         """
@@ -101,14 +101,14 @@ def fed_avg(self, model_wts: List[OrderedDict[str, Tensor]]):
                 avgd_wts[key] += coeff * local_wts[key].to(self.device)
         return avgd_wts
 
-    def aggregate(self, representation_list: List[OrderedDict[str, Tensor]]):
+    def aggregate(self, representation_list: List[OrderedDict[str, Tensor]]) -> OrderedDict[str, Tensor]:
         """
         Aggregate the model weights
         """
         avg_wts = self.fed_avg(representation_list)
         return avg_wts
 
-    def send_representations(self, representation):
+    def send_representations(self, representation: OrderedDict[str, Tensor]) -> None:
         """
         Send the model representations to the clients
         """
@@ -116,7 +116,7 @@ def send_representations(self, representation):
             self.comm_utils.send(client_node, representation, tag=self.tag.UPDATES)
         print(f"Node 1 sent average weight to {len(self.clients)} nodes")
 
-    def single_round(self, self_repr):
+    def single_round(self, self_repr: OrderedDict[str, Tensor]) -> OrderedDict[str, Tensor]:
         """
         Runs a single training round
         """
@@ -128,7 +128,7 @@ def single_round(self, self_repr):
         self.send_representations(avg_wts)
         return avg_wts
 
-    def assign_own_status(self, status):
+    def assign_own_status(self, status: List[List[int]]) -> None:
         """
         Assign the status (teacher/student) to the client
         """
@@ -145,7 +145,7 @@ def assign_own_status(self, status):
             self.pair_id = None
         print(f"Node {self.node_id} is a {self.status}, pair with {self.pair_id}")
 
-    def run_protocol(self):
+    def run_protocol(self) -> None:
         """
         Runs the entire training protocol
         """
@@ -182,7 +182,7 @@ def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None:
         self.model_save_path = f"{self.config['results_path']}/saved_models/node_{self.node_id}.pt"
         self.best_acc = 0.0  # Initialize best accuracy attribute
 
-    def send_representations(self, representations):
+    def send_representations(self, representations: Dict[int, OrderedDict[str, Tensor]]) -> None:
         """
         Send the model representations to the clients
         """
@@ -204,7 +204,7 @@ def test(self) -> float:
             self.model_utils.save_model(self.model, self.model_save_path)
         return acc
 
-    def assigns_clients(self):
+    def assigns_clients(self) -> Optional[Tuple[List[int], List[int]]]:
         """
         Assigns clients as teachers and students
         """
@@ -218,7 +218,7 @@ def assigns_clients(self):
         students = selected_elements[num_teachers:]
         return teachers, students
 
-    def single_round(self):
+    def single_round(self) -> None:
         """
         Runs a single training round
         """
@@ -233,7 +233,7 @@ def single_round(self):
             dest=client_node, data=[teachers, students], tag=self.tag.START
         )
 
-    def run_protocol(self):
+    def run_protocol(self) -> None:
         """
         Runs the entire training protocol
         """
diff --git a/src/algos/fl_random.py b/src/algos/fl_random.py
index 930f5f6..5a67e2e 100644
--- a/src/algos/fl_random.py
+++ b/src/algos/fl_random.py
@@ -3,6 +3,7 @@
 
 
 class RandomTopology:
+    """Selects collaborator IDs at random, optionally sampling within communities."""
     def get_selected_ids(self, node_id, config, reprs_dict, communities):
         within_community_sampling = config.get("within_community_sampling", 1)
 
diff --git a/src/utils/data_utils.py b/src/utils/data_utils.py
index e2910f1..8de66e1 100644
--- a/src/utils/data_utils.py
+++ b/src/utils/data_utils.py
@@ -1,5 +1,5 @@
 import importlib
-from typing import Any, List, Sequence, Tuple
+from typing import Any, List, Sequence, Tuple, Optional
 import numpy as np
 import torch
 import torchvision.transforms as T
diff --git a/src/utils/model_utils.py b/src/utils/model_utils.py
index 384c0a3..b7b2ccf 100644
--- a/src/utils/model_utils.py
+++ b/src/utils/model_utils.py
@@ -16,6 +16,7 @@ class ModelUtils():
 
     def __init__(self, device: torch.device) -> None:
         self.device = device
+        self.dset = None
         self.models_layers_idx = {
             "resnet10": {
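
A minimal standalone sketch of the mask-initialization pattern that the new
`init_masks(params: Dict[str, Tensor], sparsities: Dict[str, float]) -> OrderedDict[str, Tensor]`
annotation describes. This is illustrative only, not code from the repository:
the layer names and the uniform 50% sparsity are made up for the example.

from collections import OrderedDict
from typing import Dict

import torch
from torch import Tensor, randperm, zeros_like


def init_masks(params: Dict[str, Tensor], sparsities: Dict[str, float]) -> "OrderedDict[str, Tensor]":
    # Keep a random (1 - sparsity) fraction of each layer's entries as ones.
    masks: "OrderedDict[str, Tensor]" = OrderedDict()
    for name in params:
        masks[name] = zeros_like(params[name])
        dense_numel = int((1 - sparsities[name]) * masks[name].numel())
        if dense_numel > 0:
            flat = masks[name].view(-1)  # the view shares storage with the mask
            keep = randperm(flat.numel())[:dense_numel]
            flat[keep] = 1
    return masks


# Hypothetical two-layer parameter dict, uniform 50% sparsity per layer.
params = {"fc.weight": torch.randn(10, 20), "fc.bias": torch.randn(10)}
masks = init_masks(params, {name: 0.5 for name in params})
for name, mask in masks.items():
    print(f"{name}: kept {int(mask.sum())}/{mask.numel()} entries")

With the annotated signature, a type checker can now catch callers that pass,
say, a single sparsity float instead of a per-layer dict, which is the point of
extending the hints in this patch.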