Skip to content

Commit

Permalink
Merge pull request #254 from VectorInstitute/base_reporter
Browse files Browse the repository at this point in the history
Reporting restructure
  • Loading branch information
scarere authored Oct 21, 2024
2 parents 851d18c + 02e44fe commit 739f4e0
Show file tree
Hide file tree
Showing 95 changed files with 2,376 additions and 1,217 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ dmypy.json
# vscode
launch.json
settings.json
.devcontainer*

#mac
.DS_Store
Expand Down Expand Up @@ -169,6 +170,7 @@ settings.json
**/*.pkl
**/*.png
**/*.pt
**/*.ckpt
/metrics/

# dev
Expand Down
1 change: 0 additions & 1 deletion examples/ae_examples/cvae_dim_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def main(config: Dict[str, Any]) -> None:
client_manager=SimpleClientManager(),
parameter_exchanger=parameter_exchanger,
model=model,
wandb_reporter=None,
strategy=strategy,
checkpointer=checkpointer,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def main(config: Dict[str, Any]) -> None:
client_manager=SimpleClientManager(),
parameter_exchanger=parameter_exchanger,
model=model,
wandb_reporter=None,
strategy=strategy,
checkpointer=checkpointer,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def main(config: Dict[str, Any]) -> None:
client_manager=SimpleClientManager(),
parameter_exchanger=parameter_exchanger,
model=model,
wandb_reporter=None,
strategy=strategy,
checkpointer=checkpointer,
)
Expand Down
1 change: 0 additions & 1 deletion examples/ae_examples/fedprox_vae_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def main(config: Dict[str, Any]) -> None:
client_manager=SimpleClientManager(),
parameter_exchanger=parameter_exchanger,
model=model,
wandb_reporter=None,
strategy=strategy,
checkpointer=checkpointer,
)
Expand Down
6 changes: 3 additions & 3 deletions examples/apfl_example/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from examples.models.cnn_model import MnistNetWithBnAndFrozen
from fl4health.clients.apfl_client import ApflClient
from fl4health.model_bases.apfl_base import ApflModule
from fl4health.reporting import JsonReporter
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.load_data import load_mnist_data
from fl4health.utils.metrics import Accuracy
Expand Down Expand Up @@ -58,7 +59,6 @@ def get_criterion(self, config: Config) -> _Loss:
# Set the random seed for reproducibility
set_all_random_seeds(args.seed)

client = MnistApflClient(data_path, [Accuracy()], DEVICE)
client = MnistApflClient(data_path, [Accuracy()], DEVICE, reporters=[JsonReporter()])
fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())

client.metrics_reporter.dump()
client.shutdown() # This will tell the JsonReporter to dump data
4 changes: 2 additions & 2 deletions examples/apfl_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from examples.models.cnn_model import MnistNetWithBnAndFrozen
from examples.utils.functions import make_dict_with_epochs_or_steps
from fl4health.model_bases.apfl_base import ApflModule
from fl4health.reporting import JsonReporter
from fl4health.server.base_server import FlServer
from fl4health.utils.config import load_config
from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn
Expand Down Expand Up @@ -59,15 +60,14 @@ def main(config: Dict[str, Any]) -> None:
)

client_manager = SimpleClientManager()
server = FlServer(client_manager, strategy)
server = FlServer(client_manager, strategy, reporters=[JsonReporter()])

fl.server.start_server(
server=server,
server_address="0.0.0.0:8080",
config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]),
)

server.metrics_reporter.dump()
server.shutdown()


Expand Down
1 change: 0 additions & 1 deletion examples/basic_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ def main(config: Dict[str, Any]) -> None:
client_manager=SimpleClientManager(),
parameter_exchanger=parameter_exchanger,
model=model,
wandb_reporter=None,
strategy=strategy,
checkpointer=checkpointers,
)
Expand Down
5 changes: 2 additions & 3 deletions examples/ditto_example/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from examples.models.cnn_model import MnistNet
from fl4health.clients.ditto_client import DittoClient
from fl4health.reporting import JsonReporter
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.load_data import load_mnist_data
from fl4health.utils.metrics import Accuracy
Expand Down Expand Up @@ -68,10 +69,8 @@ def get_criterion(self, config: Config) -> _Loss:
# Set the random seed for reproducibility
set_all_random_seeds(args.seed)

client = MnistDittoClient(data_path, [Accuracy()], DEVICE)
client = MnistDittoClient(data_path, [Accuracy()], DEVICE, reporters=[JsonReporter()])
fl.client.start_client(server_address=args.server_address, client=client.to_client())

# Shutdown the client gracefully
client.shutdown()

client.metrics_reporter.dump()
7 changes: 4 additions & 3 deletions examples/dp_fed_examples/instance_level_dp/server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import string
from collections.abc import Sequence
from functools import partial
from random import choices
from typing import Any, Dict, Optional
Expand All @@ -15,7 +16,7 @@
from fl4health.checkpointing.opacus_checkpointer import BestLossOpacusCheckpointer, OpacusCheckpointer
from fl4health.client_managers.poisson_sampling_manager import PoissonSamplingClientManager
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.reporting.fl_wandb import ServerWandBReporter
from fl4health.reporting.base_reporter import BaseReporter
from fl4health.server.instance_level_dp_server import InstanceLevelDpServer
from fl4health.strategies.basic_fedavg import OpacusBasicFedAvg
from fl4health.utils.config import load_config
Expand Down Expand Up @@ -71,8 +72,8 @@ def __init__(
strategy: OpacusBasicFedAvg,
local_epochs: Optional[int] = None,
local_steps: Optional[int] = None,
wandb_reporter: Optional[ServerWandBReporter] = None,
checkpointer: Optional[OpacusCheckpointer] = None,
reporters: Sequence[BaseReporter] | None = None,
delta: Optional[float] = None,
) -> None:
super().__init__(
Expand All @@ -83,8 +84,8 @@ def __init__(
strategy,
local_epochs,
local_steps,
wandb_reporter,
checkpointer,
reporters,
delta,
)
self.parameter_exchanger = FullParameterExchanger()
Expand Down
6 changes: 3 additions & 3 deletions examples/feddg_ga_example/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from examples.models.cnn_model import MnistNetWithBnAndFrozen
from fl4health.clients.apfl_client import ApflClient
from fl4health.model_bases.apfl_base import ApflModule
from fl4health.reporting import JsonReporter
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.load_data import load_mnist_data
from fl4health.utils.metrics import Accuracy
Expand Down Expand Up @@ -58,7 +59,6 @@ def get_criterion(self, config: Config) -> _Loss:
# Set the random seed for reproducibility
set_all_random_seeds(args.seed)

client = MnistApflClient(data_path, [Accuracy()], DEVICE)
client = MnistApflClient(data_path, [Accuracy()], DEVICE, reporters=[JsonReporter()])
fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())

client.metrics_reporter.dump()
client.shutdown()
5 changes: 3 additions & 2 deletions examples/feddg_ga_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from examples.utils.functions import make_dict_with_epochs_or_steps
from fl4health.client_managers.fixed_sampling_client_manager import FixedSamplingClientManager
from fl4health.model_bases.apfl_base import ApflModule
from fl4health.reporting import JsonReporter
from fl4health.server.base_server import FlServer
from fl4health.strategies.feddg_ga_strategy import FedDgGaStrategy
from fl4health.utils.config import load_config
Expand Down Expand Up @@ -66,15 +67,15 @@ def main(config: Dict[str, Any]) -> None:
# will return the same sampling until it is told to reset, which in FedDgGaStrategy
# is done right before fit_round.
client_manager = FixedSamplingClientManager()
server = FlServer(strategy=strategy, client_manager=client_manager)
server = FlServer(strategy=strategy, client_manager=client_manager, reporters=[JsonReporter()])

fl.server.start_server(
server=server,
server_address="0.0.0.0:8080",
config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]),
)

server.metrics_reporter.dump()
server.shutdown()


if __name__ == "__main__":
Expand Down
6 changes: 3 additions & 3 deletions examples/federated_eval_example/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from examples.models.cnn_model import Net
from fl4health.clients.evaluate_client import EvaluateClient
from fl4health.reporting.metrics import MetricsReporter
from fl4health.reporting.base_reporter import BaseReporter
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.load_data import load_cifar10_test_data
from fl4health.utils.losses import LossMeterType
Expand All @@ -25,15 +25,15 @@ def __init__(
metrics: Sequence[Metric],
device: torch.device,
model_checkpoint_path: Optional[Path],
metrics_reporter: Optional[MetricsReporter] = None,
reporters: Sequence[BaseReporter] | None = None,
) -> None:
super().__init__(
data_path=data_path,
metrics=metrics,
device=device,
model_checkpoint_path=model_checkpoint_path,
loss_meter_type=LossMeterType.AVERAGE,
metrics_reporter=metrics_reporter,
reporters=reporters,
)

def initialize_global_model(self, config: Config) -> Optional[nn.Module]:
Expand Down
1 change: 0 additions & 1 deletion examples/fedpca_examples/dim_reduction/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def main(config: Dict[str, Any]) -> None:
client_manager=SimpleClientManager(),
parameter_exchanger=parameter_exchanger,
model=model,
wandb_reporter=None,
strategy=strategy,
checkpointer=checkpointer,
)
Expand Down
5 changes: 2 additions & 3 deletions examples/fedprox_example/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from examples.models.cnn_model import MnistNet
from fl4health.clients.fed_prox_client import FedProxClient
from fl4health.reporting import JsonReporter
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.load_data import load_mnist_data
from fl4health.utils.metrics import Accuracy
Expand Down Expand Up @@ -64,10 +65,8 @@ def get_criterion(self, config: Config) -> _Loss:
# Set the random seed for reproducibility
set_all_random_seeds(args.seed)

client = MnistFedProxClient(data_path, [Accuracy()], DEVICE)
client = MnistFedProxClient(data_path, [Accuracy()], DEVICE, reporters=[JsonReporter()])
fl.client.start_client(server_address=args.server_address, client=client.to_client())

# Shutdown the client gracefully
client.shutdown()

client.metrics_reporter.dump()
7 changes: 3 additions & 4 deletions examples/fedprox_example/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@ local_epochs: 1 # The number of epochs to complete for client
batch_size: 128 # The batch size for client training

reporting_config:
enabled: False
project_name: FL4Health # Name of the project under which everything should be logged
run_name: "FedProx Server" # Name of the run on the server-side, each client will also have it's own run name
group_name: "FedProx Experiment" # Group under which each of the FL run logging will be stored
project: FL4Health # Name of the project under which everything should be logged
name: "FedProx Server" # Name of the run on the server-side, each client will also have it's own run name
group: "FedProx Experiment" # Group under which each of the FL run logging will be stored
entity: "your_entity_here" # WandB user name
notes: "Testing WB reporting"
tags: ["Test", "FedProx"]
36 changes: 18 additions & 18 deletions examples/fedprox_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from examples.models.cnn_model import MnistNet
from examples.utils.functions import make_dict_with_epochs_or_steps
from fl4health.reporting.fl_wandb import ServerWandBReporter
from fl4health.reporting import JsonReporter, WandBReporter
from fl4health.server.adaptive_constraint_servers.fedprox_server import FedProxServer
from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint
from fl4health.utils.config import load_config
Expand All @@ -22,24 +22,24 @@
def fit_config(
batch_size: int,
n_server_rounds: int,
reporting_enabled: bool,
project_name: str,
group_name: str,
entity: str,
current_round: int,
reporting_config: Optional[Dict[str, str]] = None,
local_epochs: Optional[int] = None,
local_steps: Optional[int] = None,
) -> Config:
return {
base_config: Config = {
**make_dict_with_epochs_or_steps(local_epochs, local_steps),
"batch_size": batch_size,
"n_server_rounds": n_server_rounds,
"current_server_round": current_round,
"reporting_enabled": reporting_enabled,
"project_name": project_name,
"group_name": group_name,
"entity": entity,
}
if reporting_config is not None:
# NOTE: that name is not included, it will be set in the clients
base_config["project"] = reporting_config.get("project", "")
base_config["group"] = reporting_config.get("group", "")
base_config["entity"] = reporting_config.get("entity", "")

return base_config


def main(config: Dict[str, Any], server_address: str) -> None:
Expand All @@ -48,11 +48,7 @@ def main(config: Dict[str, Any], server_address: str) -> None:
fit_config,
config["batch_size"],
config["n_server_rounds"],
config["reporting_config"].get("enabled", False),
# Note that run name is not included, it will be set in the clients
config["reporting_config"].get("project_name", ""),
config["reporting_config"].get("group_name", ""),
config["reporting_config"].get("entity", ""),
reporting_config=config.get("reporting_config"),
local_epochs=config.get("local_epochs"),
local_steps=config.get("local_steps"),
)
Expand All @@ -78,9 +74,14 @@ def main(config: Dict[str, Any], server_address: str) -> None:
loss_weight_patience=config["proximal_weight_patience"],
)

wandb_reporter = ServerWandBReporter.from_config(config)
json_reporter = JsonReporter()
client_manager = SimpleClientManager()
server = FedProxServer(client_manager=client_manager, strategy=strategy, model=None, wandb_reporter=wandb_reporter)
if "reporting_config" in config:
wandb_reporter = WandBReporter("round", **config["reporting_config"])
reporters = [wandb_reporter, json_reporter]
else:
reporters = [json_reporter]
server = FedProxServer(client_manager=client_manager, strategy=strategy, model=None, reporters=reporters)

fl.server.start_server(
server=server,
Expand All @@ -89,7 +90,6 @@ def main(config: Dict[str, Any], server_address: str) -> None:
)
# Shutdown the server gracefully
server.shutdown()
server.metrics_reporter.dump()


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ def main(config: Dict[str, Any]) -> None:
client_manager=SimpleClientManager(),
parameter_exchanger=parameter_exchanger,
model=model,
wandb_reporter=None,
strategy=strategy,
checkpointer=checkpointer,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,10 @@ def main(config: Dict[str, Any]) -> None:

# Initializing the model on the server side
model: nn.Module = FedSimClrModel(
CifarSslEncoder(), CifarSslProjectionHead(), CifarSslPredictionHead(), pretrain=True
CifarSslEncoder(),
CifarSslProjectionHead(),
CifarSslPredictionHead(),
pretrain=True,
)
# To facilitate checkpointing
parameter_exchanger = FullParameterExchanger()
Expand All @@ -67,7 +70,6 @@ def main(config: Dict[str, Any]) -> None:
client_manager=SimpleClientManager(),
parameter_exchanger=parameter_exchanger,
model=model,
wandb_reporter=None,
strategy=strategy,
checkpointer=checkpointer,
)
Expand Down
Loading

0 comments on commit 739f4e0

Please sign in to comment.