Commit

Merge pull request #204 from VectorInstitute/progress_bar

Added progress bar and fixed bugs

scarere authored Aug 16, 2024
2 parents dc574f7 + 70e0f7d commit 0e0a6ed
Showing 10 changed files with 256 additions and 113 deletions.
1 change: 1 addition & 0 deletions examples/nnunet_example/client.py
@@ -76,6 +76,7 @@ def main(
device=DEVICE,
metrics=[dice],
data_path=dataset_path, # Argument not actually used by nnUNetClient
progress_bar=True,
)

start_client(server_address=server_address, client=client.to_client())
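As a usage aside (not part of this diff): a minimal sketch of how the new progress_bar flag could be passed to any BasicClient subclass, such as the one changed below. MyClient, the dataset path, and the empty metrics list are hypothetical placeholders; only the progress_bar argument itself comes from this PR.

from pathlib import Path

import torch

from fl4health.clients.basic_client import BasicClient


class MyClient(BasicClient):
    # Hypothetical subclass; in real use it would implement the usual data,
    # model, optimizer and loss hooks expected of a BasicClient.
    pass


client = MyClient(
    data_path=Path("examples/datasets/my_dataset"),  # hypothetical path
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    metrics=[],  # a sequence of fl4health Metric objects in real use
    progress_bar=True,  # new flag added in this PR; defaults to False
)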
159 changes: 122 additions & 37 deletions fl4health/clients/basic_client.py
@@ -3,16 +3,17 @@
from enum import Enum
from logging import INFO
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
from flwr.client import NumPyClient
from flwr.common.logger import log
from flwr.common.logger import LOG_COLORS, log
from flwr.common.typing import Config, NDArrays, Scalar
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from tqdm import tqdm

from fl4health.checkpointing.client_module import CheckpointMode, ClientCheckpointModule
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
@@ -41,6 +42,7 @@ def __init__(
loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
checkpointer: Optional[ClientCheckpointModule] = None,
metrics_reporter: Optional[MetricsReporter] = None,
progress_bar: bool = False,
) -> None:
"""
Base FL Client with functionality to train, evaluate, log, report and checkpoint.
@@ -59,12 +61,16 @@ def __init__(
None.
metrics_reporter (Optional[MetricsReporter], optional): A metrics reporter instance to record the metrics
during the execution. Defaults to an instance of MetricsReporter with default init parameters.
progress_bar (bool): Whether or not to display a progress bar
during client training and validation. Uses tqdm. Defaults to
False.
"""

self.data_path = data_path
self.device = device
self.metrics = metrics
self.checkpointer = checkpointer
self.progress_bar = progress_bar

self.client_name = generate_hash()

@@ -363,7 +369,35 @@ def _should_evaluate_after_fit(self, evaluate_after_fit: bool) -> bool:
)
return evaluate_after_fit or pre_aggregation_checkpointing_enabled

def _handle_logging(
def _log_header_str(
self,
current_round: Optional[int] = None,
current_epoch: Optional[int] = None,
logging_mode: LoggingMode = LoggingMode.TRAIN,
) -> None:
"""
Logs a header string. By default this is logged at the beginning of each local
epoch, or at the beginning of the round if training by steps.
Args:
current_round (Optional[int], optional): The current FL round (i.e., the current
server round). Defaults to None.
current_epoch (Optional[int], optional): The current epoch of local
training. Defaults to None.
"""

log_str = f"Current FL Round: {str(current_round)}\t" if current_round is not None else ""
log_str += f"Current Epoch: {str(current_epoch)}" if current_epoch is not None else ""

# Maybe add client specific info to initial log string
client_str, _ = self.get_client_specific_logs(current_round, current_epoch, logging_mode)

log_str += client_str

log(INFO, "") # For aesthetics
log(INFO, log_str)

def _log_results(
self,
loss_dict: Dict[str, float],
metrics_dict: Dict[str, Scalar],
@@ -382,23 +416,15 @@ def _handle_logging(
current_epoch (Optional[int]): The current epoch of local training.
logging_mode (LoggingMode): The logging mode (Training, Validation, or Testing).
"""
log(INFO, "") # An empty log line for aesthetics
_, client_logs = self.get_client_specific_logs(current_round, current_epoch, logging_mode)

initial_log_str = f"Current FL Round: {str(current_round)}\t" if current_round is not None else ""
initial_log_str += f"Current Epoch: {str(current_epoch)}" if current_epoch is not None else ""

# Maybe add client specific info to initial log string
client_str, client_logs = self.get_client_specific_logs()
initial_log_str += client_str

if initial_log_str != "":
log(INFO, initial_log_str)
self.add_to_initial_log_str = "" # Reset variable

# Log loss/losses
# Get Metric Prefix
metric_prefix = logging_mode.value
log(INFO, f"Client {metric_prefix} Losses:")
[log(INFO, f"\t {key}: {str(val)}") for key, val in loss_dict.items()]

# Log losses if any were provided
if len(loss_dict.keys()) > 0:
log(INFO, f"Client {metric_prefix} Losses:")
[log(INFO, f"\t {key}: {str(val)}") for key, val in loss_dict.items()]

# Log metrics if any
if len(metrics_dict.keys()) > 0:
@@ -409,21 +435,32 @@
if len(client_logs) > 0:
[log(level.value, msg) for level, msg in client_logs]

def get_client_specific_logs(self) -> Tuple[str, List[Tuple[LogLevel, str]]]:
def get_client_specific_logs(
self, current_round: Optional[int], current_epoch: Optional[int], logging_mode: LoggingMode
) -> Tuple[str, List[Tuple[LogLevel, str]]]:
"""
This function can be overriden to provide any client specific
This function can be overridden to provide any client specific
information to the basic client logging. For example, perhaps a client
uses an LR scheduler and wants the LR to be logged each epoch. The
logging is called at the end of either every epoch for
train_by_epochs, or the end of the server round for train_by_steps
uses an LR scheduler and wants the LR to be logged each epoch. Called at the
beginning and end of each server round or local epoch. Also called at the end
of validation/testing.
Args:
current_round (Optional[int]): The current FL round (i.e., current
server round).
current_epoch (Optional[int]): The current epoch of local training.
logging_mode (LoggingMode): The logging mode (Training,
Validation, or Testing).
Returns:
Optional[str]: A string to append to the initial log string that
typically announces the current server round and current epoch
Optional[str]: A string to append to the header log string that
typically announces the current server round and current epoch at the
beginning of each round or local epoch.
Optional[List[Tuple[LogLevel, str]]]: A list of tuples where the
first element is a LogLevel as defined in fl4health.utils.
typing and the second element is a string message. Each item
in the list will be logged when self._handle_logging is called
in the list will be logged at the end of each server round or epoch.
Elements will also be logged at the end of validation/testing.
"""
return "", []

@@ -612,13 +649,16 @@ def train_by_epochs(
Loss is a dictionary of one or more losses that represent the different components of the loss.
"""
self.model.train()
local_step = 0
steps_this_round = 0 # Reset number of steps this round
for local_epoch in range(epochs):
self.train_metric_manager.clear()
self.train_loss_meter.clear()
# Print initial log string on epoch start
self._log_header_str(current_round, local_epoch)
# update before epoch hook
self.update_before_epoch(epoch=local_epoch)
for input, target in self.train_loader:
for input, target in self.maybe_progress_bar(self.train_loader):
self.update_before_step(steps_this_round)
# Assume first dimension is batch size. Sampling iterators (such as Poisson batch sampling), can
# construct empty batches. We skip the iteration if this occurs.
if self.is_empty_batch(input):
@@ -630,14 +670,14 @@
losses, preds = self.train_step(input, target)
self.train_loss_meter.update(losses)
self.update_metric_manager(preds, target, self.train_metric_manager)
self.update_after_step(local_step)
self.update_after_step(steps_this_round)
self.total_steps += 1
local_step += 1
steps_this_round += 1
metrics = self.train_metric_manager.compute()
loss_dict = self.train_loss_meter.compute().as_dict()

# Log results and maybe report via WANDB
self._handle_logging(loss_dict, metrics, current_round=current_round, current_epoch=local_epoch)
self._log_results(loss_dict, metrics, current_round, local_epoch)
self._handle_reporting(loss_dict, metrics, current_round=current_round)

# Return final training metrics
@@ -663,7 +703,11 @@

self.train_loss_meter.clear()
self.train_metric_manager.clear()
for step in range(steps):
self._log_header_str(current_round)
for step in self.maybe_progress_bar(range(steps)):

self.update_before_step(step)

try:
input, target = next(train_iterator)
except StopIteration:
@@ -690,7 +734,7 @@
metrics = self.train_metric_manager.compute()

# Log results and maybe report via WANDB
self._handle_logging(loss_dict, metrics, current_round=current_round)
self._log_results(loss_dict, metrics, current_round)
self._handle_reporting(loss_dict, metrics, current_round=current_round)

return loss_dict, metrics
@@ -720,7 +764,7 @@ def _validate_or_test(
metric_manager.clear()
loss_meter.clear()
with torch.no_grad():
for input, target in loader:
for input, target in self.maybe_progress_bar(loader):
input = self._move_data_to_device(input)
target = self._move_data_to_device(target)
losses, preds = self.val_step(input, target)
@@ -730,7 +774,7 @@
# Compute losses and metrics over validation set
loss_dict = loss_meter.compute().as_dict()
metrics = metric_manager.compute()
self._handle_logging(loss_dict, metrics, logging_mode=logging_mode)
self._log_results(loss_dict, metrics, logging_mode=logging_mode)

return loss_dict["checkpoint"], metrics

@@ -1074,19 +1118,31 @@ def update_after_train(self, local_steps: int, loss_dict: Dict[str, float]) -> N
aggregation.
Args:
local_steps (int): The number of steps in the local training.
local_steps (int): The number of local training steps performed so far in
the round.
loss_dict (Dict[str, float]): A dictionary of losses from local training.
"""
pass

def update_before_step(self, step: int) -> None:
"""
Hook method called before local train step.
Args:
step (int): The index of the local training step about to be performed
(i.e., the number of steps completed so far this round). Resets only at
the end of the round.
"""
pass

def update_after_step(self, step: int) -> None:
"""
Hook method called after local train step on client. step is an integer that represents
the local training step that was most recently completed. For example, used by the APFL
method to update the alpha value after a training a step.
Args:
step (int): The step number in local training that was most recently completed.
step (int): The step number in local training that was most recently
completed. Resets only at the end of the round.
"""
pass
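As a small illustration of the per-step hooks (a hedged sketch, not part of this diff): a client could time each local step and periodically log the duration. The subclass is hypothetical and the names imported at the top of basic_client.py (log, INFO) are assumed to be in scope.

import time


class TimedClient(BasicClient):  # hypothetical subclass
    def update_before_step(self, step: int) -> None:
        # Record when the upcoming step starts; step counts completed steps this round.
        self._step_start_time = time.perf_counter()

    def update_after_step(self, step: int) -> None:
        # Log the duration of every 100th step using the same flwr logger as BasicClient.
        if step % 100 == 0:
            log(INFO, f"Local step {step} took {time.perf_counter() - self._step_start_time:.3f}s")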

@@ -1100,3 +1156,32 @@ def update_before_epoch(self, epoch: int) -> None:
epoch (int): Integer representing the epoch about to begin
"""
pass

def maybe_progress_bar(self, iterable: Iterable) -> Iterable:
"""
Used to print progress bars during client training and validation. If
self.progress_bar is False, just returns the original input iterable
without modifying it.
Args:
iterable (Iterable): The iterable to wrap
Returns:
Iterable: An iterator which acts exactly like the original
iterable, but prints a dynamically updating progress bar every
time a value is requested, or the original iterable if
self.progress_bar is False.
"""
if not self.progress_bar:
return iterable
else:
# Create a clean looking tqdm instance that matches the flwr logging
kwargs = {
"leave": True,
"ascii": " >=",
# "desc": f"{LOG_COLORS['INFO']}INFO{LOG_COLORS['RESET']} ",
"unit": "steps",
"dynamic_ncols": True,
"bar_format": f"{LOG_COLORS['INFO']}INFO{LOG_COLORS['RESET']}" + " : {l_bar}{bar}{r_bar}",
}
return tqdm(iterable, **kwargs)
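As a usage note (a hedged sketch, not part of this diff): because maybe_progress_bar simply wraps any iterable, a subclass can reuse it in its own loops. The subclass, its extra_eval_loader attribute, and the method name are hypothetical; torch, _move_data_to_device, and val_step are as used in basic_client.py above.

class MyEvalClient(BasicClient):  # hypothetical subclass
    def run_extra_evaluation(self) -> None:
        # Shows a tqdm bar only if the client was constructed with progress_bar=True.
        self.model.eval()
        with torch.no_grad():
            for input, target in self.maybe_progress_bar(self.extra_eval_loader):
                input = self._move_data_to_device(input)
                target = self._move_data_to_device(target)
                self.val_step(input, target)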
3 changes: 2 additions & 1 deletion fl4health/clients/flash_client.py
@@ -65,6 +65,7 @@ def train_by_epochs(
for local_epoch in range(epochs):
self.train_metric_manager.clear()
self.train_loss_meter.clear()
self._log_header_str(current_round, local_epoch)
for input, target in self.train_loader:
if self.is_empty_batch(input):
log(INFO, "Empty batch generated by data loader. Skipping step.")
@@ -83,7 +84,7 @@
loss_dict = self.train_loss_meter.compute().as_dict()
current_loss, _ = self.validate()

self._handle_logging(loss_dict, metrics, current_round=current_round, current_epoch=local_epoch)
self._log_results(loss_dict, metrics, current_round=current_round, current_epoch=local_epoch)
self._handle_reporting(loss_dict, metrics, current_round=current_round)

if self.gamma is not None and previous_loss - current_loss < self.gamma / (local_epoch + 1):
(Diffs for the remaining 7 changed files are not shown.)
