Merge pull request #285 from VectorInstitute/sa_add_deep_mmd_loss
Update deep mmd client and loss implementation
sanaAyrml authored Nov 21, 2024
2 parents c8e967f + a6573eb commit 2509952
Showing 4 changed files with 170 additions and 19 deletions.
116 changes: 112 additions & 4 deletions fl4health/clients/deep_mmd_clients/ditto_deep_mmd_client.py
@@ -14,6 +14,7 @@
from fl4health.utils.client import clone_and_freeze_model
from fl4health.utils.losses import EvaluationLosses, LossMeterType, TrainingLosses
from fl4health.utils.metrics import Metric
from fl4health.utils.random import restore_random_state, save_random_state
from fl4health.utils.typing import TorchFeatureType, TorchInputType, TorchPredType, TorchTargetType


@@ -27,6 +28,8 @@ def __init__(
checkpointer: Optional[ClientCheckpointModule] = None,
deep_mmd_loss_weight: float = 10.0,
feature_extraction_layers_with_size: Optional[Dict[str, int]] = None,
mmd_kernel_train_interval: int = 20,
num_accumulating_batches: Optional[int] = None,
) -> None:
"""
This client implements the Deep MMD loss function in the Ditto framework. The Deep MMD loss is a measure of
@@ -47,6 +50,13 @@ def __init__(
deep_mmd_loss_weight (float, optional): weight applied to the Deep MMD loss. Defaults to 10.0.
feature_extraction_layers_with_size (Optional[Dict[str, int]], optional): Dictionary mapping the names of the
layers to extract features from to their respective feature sizes. Defaults to None.
mmd_kernel_train_interval (int, optional): Interval at which to train and update the Deep MMD kernel. If set
to a value above 0, the kernel is trained on the whole distribution of latent features of the data at the
given interval. If set to 0, the kernel is not trained. If set to -1, the kernel is trained after each
individual batch, based only on that batch. Defaults to 20.
num_accumulating_batches (Optional[int], optional): Number of batches over which to accumulate features to
approximate the whole distribution of the latent features when updating the Deep MMD kernel. This
parameter is only used if mmd_kernel_train_interval is set to a value larger than 0. Defaults to None.
"""
super().__init__(
data_path=data_path,
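To make the three kernel-update modes described in the docstring concrete, here is a minimal construction sketch. It is not taken from the repository: the class name DittoDeepMmdClient, the layer name "fc1", and the dataset path are assumptions, and any constructor arguments not visible in this diff are omitted.

from pathlib import Path

client = DittoDeepMmdClient(
    data_path=Path("/path/to/client/data"),  # placeholder path
    deep_mmd_loss_weight=10.0,
    feature_extraction_layers_with_size={"fc1": 64},  # hypothetical layer name and feature size
    mmd_kernel_train_interval=20,  # > 0: retrain the kernel every 20 steps on accumulated features
    num_accumulating_batches=32,  # batches used to approximate the latent feature distribution
)
# mmd_kernel_train_interval=0 would leave the kernel untrained, while -1 would retrain it after
# every individual batch using only that batch (num_accumulating_batches is then ignored).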
@@ -67,14 +77,21 @@ def __init__(
feature_extraction_layers_with_size = {}
self.flatten_feature_extraction_layers = {layer: True for layer in feature_extraction_layers_with_size.keys()}
self.deep_mmd_losses: Dict[str, DeepMmdLoss] = {}
# Save the random state to be restored after initializing the Deep MMD loss layers.
random_state, numpy_state, torch_state = save_random_state()
for layer, feature_size in feature_extraction_layers_with_size.items():
self.deep_mmd_losses[layer] = DeepMmdLoss(
device=self.device,
input_size=feature_size,
).to(self.device)
# Restore the random state after initializing the Deep MMD loss layers. This ensures that the random state
# does not change as a side effect of initializing the Deep MMD loss.
restore_random_state(random_state, numpy_state, torch_state)
self.initial_global_model: nn.Module
self.local_feature_extractor: FeatureExtractorBuffer
self.initial_global_feature_extractor: FeatureExtractorBuffer
self.num_accumulating_batches = num_accumulating_batches
self.mmd_kernel_train_interval = mmd_kernel_train_interval

def setup_client(self, config: Config) -> None:
super().setup_client(config)
@@ -97,9 +114,97 @@ def update_before_train(self, current_server_round: int) -> None:
)
# Register hooks to extract features from the initial global model if not already registered
self.initial_global_feature_extractor._maybe_register_hooks()
# Enable training of Deep MMD loss layers
for layer in self.flatten_feature_extraction_layers.keys():
self.deep_mmd_losses[layer].training = True
# Enable training of the Deep MMD loss layers if mmd_kernel_train_interval is set to -1, meaning that the
# betas will be updated after each individual batch, based only on that batch
if self.mmd_kernel_train_interval == -1:
for layer in self.flatten_feature_extraction_layers.keys():
self.deep_mmd_losses[layer].training = True

def _should_optimize_betas(self, step: int) -> bool:
step_at_interval = (step - 1) % self.mmd_kernel_train_interval == 0
valid_components_present = self.initial_global_model is not None
# If the Deep MMD loss doesn't matter, we don't bother optimizing betas
weighted_deep_mmd_loss = self.deep_mmd_loss_weight != 0
return step_at_interval and valid_components_present and weighted_deep_mmd_loss

def update_after_step(self, step: int, current_round: Optional[int] = None) -> None:
if self.mmd_kernel_train_interval > 0 and self._should_optimize_betas(step):
# Get the feature distribution of the local and initial global features with evaluation
# mode
local_distributions, initial_global_distributions = self.update_buffers(
self.model, self.initial_global_model
)
# As we set the training mode of the Deep MMD loss layers to True, we train the
# kernel of the Deep MMD loss based on gathered features in the buffer and compute the
# Deep MMD loss
for layer, layer_deep_mmd_loss in self.deep_mmd_losses.items():
layer_deep_mmd_loss.training = True
layer_deep_mmd_loss(local_distributions[layer], initial_global_distributions[layer])
layer_deep_mmd_loss.training = False
super().update_after_step(step)
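As a worked illustration of the gating above (not code from the repository): with the default interval of 20, the (step - 1) % mmd_kernel_train_interval == 0 check in _should_optimize_betas fires on steps 1, 21, 41, and so on, so the kernel is retrained once every 20 local steps.

# Hypothetical standalone check mirroring the interval logic in _should_optimize_betas.
interval = 20
triggering_steps = [step for step in range(1, 101) if (step - 1) % interval == 0]
print(triggering_steps)  # [1, 21, 41, 61, 81]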

def update_buffers(
self, local_model: torch.nn.Module, initial_global_model: torch.nn.Module
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
"""
Update the feature buffers of the local and initial global models and return the extracted features.
Args:
local_model (torch.nn.Module): Local model to extract features from.
initial_global_model (torch.nn.Module): Initial global model to extract features from.
Returns:
Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: A tuple containing the extracted
features using the local and initial global models.
"""

self.local_feature_extractor.clear_buffers()
self.initial_global_feature_extractor.clear_buffers()

self.local_feature_extractor.enable_accumulating_features()
self.initial_global_feature_extractor.enable_accumulating_features()

# Save the initial training state of the local model so it can be restored after the buffers are populated.
# The initial global model is already cloned and frozen, so its state does not need to be saved.
initial_state_local_model = local_model.training

# Set the local model to evaluation mode, as we don't want to create a computational graph for it when
# populating the local buffer with features used to train the Deep MMD kernel
local_model.eval()

# Make sure the local model is in evaluation mode before populating the local buffer
assert not local_model.training

# Make sure the initial global model is in evaluation mode before populating the global buffer
# as it is already cloned and frozen from the global model
assert not initial_global_model.training

with torch.no_grad():
for i, (input, _) in enumerate(self.train_loader):
input = input.to(self.device)
# Pass the input through the local model to populate the local_feature_extractor buffer
local_model(input)
# Pass the input through the initial global model to populate the initial_global_feature_extractor
# buffer
initial_global_model(input)
# Break if the number of accumulating batches is reached to avoid memory issues
if i == self.num_accumulating_batches:
break
local_distributions = self.local_feature_extractor.get_extracted_features()
initial_global_distributions = self.initial_global_feature_extractor.get_extracted_features()
# Restore the initial state of the local model
if initial_state_local_model:
local_model.train()

self.local_feature_extractor.disable_accumulating_features()
self.initial_global_feature_extractor.disable_accumulating_features()

self.local_feature_extractor.clear_buffers()
self.initial_global_feature_extractor.clear_buffers()

return local_distributions, initial_global_distributions

def predict(
self,
@@ -185,7 +290,10 @@ def compute_training_loss(
loss tensor.
"""
for layer_loss_module in self.deep_mmd_losses.values():
assert layer_loss_module.training
if self.mmd_kernel_train_interval == -1:
assert layer_loss_module.training
else:
assert not layer_loss_module.training
# Check that both models are in training mode
assert self.global_model.training and self.model.training

21 changes: 14 additions & 7 deletions fl4health/losses/deep_mmd_loss.py
@@ -47,7 +47,6 @@ def __init__(
hidden_size: int = 10,
output_size: int = 50,
lr: float = 0.001,
training: bool = True,
is_unbiased: bool = True,
gaussian_degree: int = 1,
optimization_steps: int = 5,
@@ -69,7 +68,6 @@
output_size (int, optional): The output size of the deep network as the deep kernel used to compute
the MMD loss. Defaults to 50.
lr (float, optional): Learning rate for training the Deep Kernel. Defaults to 0.001.
training (bool, optional): Whether the Deep Kernel is in training mode. Defaults to True.
is_unbiased (bool, optional): Whether to use the unbiased estimator for the MMD loss. Defaults to True.
gaussian_degree (int, optional): The degree of the generalized Gaussian kernel. Defaults to 1.
optimization_steps (int, optional): The number of optimization steps to train the Deep Kernel in each
@@ -79,28 +77,32 @@
super().__init__()
self.device = device
self.lr = lr
self.training = training
self.is_unbiased = is_unbiased
self.gaussian_degree = gaussian_degree # generalized Gaussian (if L>1)
self.optimization_steps = optimization_steps

# Initialize the model
self.featurizer = ModelLatentF(input_size, hidden_size, output_size).to(self.device)
# Set the model to evaluation mode by default
self.featurizer.eval()

# Initialize parameters
self.epsilon_opt: torch.Tensor = torch.log(torch.from_numpy(np.random.rand(1) * 10 ** (-10)).to(self.device))
self.epsilon_opt.requires_grad = self.training
self.epsilon_opt.requires_grad = False
self.sigma_q_opt: torch.Tensor = torch.sqrt(torch.tensor(2 * 32 * 32, dtype=torch.float).to(self.device))
self.sigma_q_opt.requires_grad = self.training
self.sigma_q_opt.requires_grad = False
self.sigma_phi_opt: torch.Tensor = torch.sqrt(torch.tensor(0.005, dtype=torch.float).to(self.device))
self.sigma_phi_opt.requires_grad = self.training
self.sigma_phi_opt.requires_grad = False

# Initialize optimizers
self.optimizer_F = torch.optim.AdamW(
list(self.featurizer.parameters()) + [self.epsilon_opt] + [self.sigma_q_opt] + [self.sigma_phi_opt],
lr=self.lr,
)

# Default to evaluation mode; training of the Deep Kernel is enabled externally by setting this flag to True
self.training = False
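With the training constructor argument removed, callers now toggle the flag on the instance, as the client and test changes elsewhere in this diff do. A minimal usage sketch with arbitrary tensors (not from the repository):

import torch

loss_fn = DeepMmdLoss(device=torch.device("cpu"), input_size=3)
X, Y = torch.randn(32, 3), torch.randn(32, 3)
loss_fn.training = True  # train the deep kernel on this pair of feature batches
_ = loss_fn(X, Y)
loss_fn.training = False  # subsequent calls evaluate the MMD with the frozen kernel
mmd_value = loss_fn(X, Y)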

def pairwise_distiance_squared(self, X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
"""
Compute the pairwise squared distance between X and Y.
@@ -240,7 +242,12 @@ def train_kernel(self, X: torch.Tensor, Y: torch.Tensor) -> None:
self.sigma_phi_opt.requires_grad = True
self.epsilon_opt.requires_grad = True

features = torch.cat([X, Y], 0)
# Shuffle the data so the samples are not always presented in the same order during training,
# which might lead to overfitting
indices = torch.randperm(Y.size(0))
Y_shuffled = Y[indices]

features = torch.cat([X, Y_shuffled], 0)

# ------------------------------
# Train deep network for MMD-D
36 changes: 35 additions & 1 deletion fl4health/utils/random.py
@@ -1,7 +1,7 @@
import random
import uuid
from logging import INFO
from typing import Optional
from typing import Any, Dict, Optional, Tuple

import numpy as np
import torch
@@ -60,6 +60,40 @@ def unset_all_random_seeds() -> None:
torch.use_deterministic_algorithms(False)


def save_random_state() -> Tuple[Tuple[Any, ...], Dict[str, Any], torch.Tensor]:
"""
Save the state of the random number generators for Python, NumPy, and PyTorch. This will allow you to restore the
state of the random number generators at a later time.
Returns:
Tuple[Tuple[Any, ...], Dict[str, Any], torch.Tensor]: A tuple containing the state of the random number
generators for Python, NumPy, and PyTorch, respectively.
"""
log(INFO, "Saving random state.")
random_state = random.getstate()
numpy_state = np.random.get_state()
torch_state = torch.get_rng_state()
return random_state, numpy_state, torch_state


def restore_random_state(
random_state: Tuple[Any, ...], numpy_state: Dict[str, Any], torch_state: torch.Tensor
) -> None:
"""
Restore the state of the random number generators for Python, NumPy, and PyTorch. This will allow you to restore
the state of the random number generators to a previously saved state.
Args:
random_state (Tuple[Any, ...]): The state of the Python random number generator
numpy_state (Dict[str, Any]): The state of the NumPy random number generator
torch_state (torch.Tensor): The state of the PyTorch random number generator
"""
log(INFO, "Restoring random state.")
random.setstate(random_state)
np.random.set_state(numpy_state)
torch.set_rng_state(torch_state)
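A short usage sketch of the two helpers above, mirroring how the Ditto Deep MMD client wraps its loss initialization (the Linear layer is only a stand-in for any random-consuming setup):

import torch
from fl4health.utils.random import restore_random_state, save_random_state

random_state, numpy_state, torch_state = save_random_state()
_ = torch.nn.Linear(16, 4)  # example initialization that draws from the torch RNG
restore_random_state(random_state, numpy_state, torch_state)
# The RNG streams now continue as if the initialization above had never consumed them.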


def generate_hash(length: int = 8) -> str:
"""
Generates unique hash used as id for client.
16 changes: 9 additions & 7 deletions tests/losses/test_deep_mmd_loss.py
@@ -39,7 +39,8 @@

def test_forward() -> None:
torch.manual_seed(42)
deep_mmd_loss_1 = DeepMmdLoss(device=DEVICE, input_size=3, training=True, optimization_steps=1)
deep_mmd_loss_1 = DeepMmdLoss(device=DEVICE, input_size=3, optimization_steps=1)
deep_mmd_loss_1.training = True
train_outputs_1 = []
val_outputs_1 = []
for i in range(5):
@@ -54,11 +55,11 @@ def test_forward() -> None:

# The output of the DeepMmdLoss in training mode should be different for each optimization step
# as values are updated in each step
assert pytest.approx(train_outputs_1[0].item(), abs=0.001) == 0.0584
assert pytest.approx(train_outputs_1[1].item(), abs=0.001) == 0.0682
assert pytest.approx(train_outputs_1[2].item(), abs=0.001) == 0.0773
assert pytest.approx(train_outputs_1[3].item(), abs=0.001) == 0.0850
assert pytest.approx(train_outputs_1[4].item(), abs=0.001) == 0.0914
assert pytest.approx(train_outputs_1[0].item(), abs=0.001) == 0.0573
assert pytest.approx(train_outputs_1[1].item(), abs=0.001) == 0.0670
assert pytest.approx(train_outputs_1[2].item(), abs=0.001) == 0.0767
assert pytest.approx(train_outputs_1[3].item(), abs=0.001) == 0.0848
assert pytest.approx(train_outputs_1[4].item(), abs=0.001) == 0.0927

for i in range(len(val_outputs_1)):
# The output of the DeepMmdLoss in evaluation mode should be the same as the output of the DeepMmdLoss in
@@ -67,7 +68,8 @@

# Reset the seed for the second DeepMmdLoss
torch.manual_seed(42)
deep_mmd_loss_2 = DeepMmdLoss(device=DEVICE, input_size=3, training=True, optimization_steps=5)
deep_mmd_loss_2 = DeepMmdLoss(device=DEVICE, input_size=3, optimization_steps=5)
deep_mmd_loss_2.training = True
train_output = deep_mmd_loss_2(X, Y)
deep_mmd_loss_2.training = False
val_output = deep_mmd_loss_2(X, Y)
