diff --git a/doc/source/ref-api-flwr.rst b/doc/source/ref-api-flwr.rst index 07936f117444..e1983cd92c90 100644 --- a/doc/source/ref-api-flwr.rst +++ b/doc/source/ref-api-flwr.rst @@ -214,6 +214,16 @@ server.strategy.Krum .. automethod:: __init__ +.. _flwr-server-strategy-Bulyan-apiref: + +server.strategy.Bulyan +^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: flwr.server.strategy.Bulyan + :members: + + .. automethod:: __init__ + .. _flwr-server-strategy-FedXgbNnAvg-apiref: diff --git a/doc/source/ref-changelog.md b/doc/source/ref-changelog.md index 299228170b80..10112088b88c 100644 --- a/doc/source/ref-changelog.md +++ b/doc/source/ref-changelog.md @@ -58,6 +58,10 @@ Flower received many improvements under the hood, too many to list here. +- **Add new** `Bulyan` **strategy** ([#1817](https://github.com/adap/flower/pull/1817), [#1891](https://github.com/adap/flower/pull/1891)) + + The new `Bulyan` strategy implements Bulyan by [El Mhamdi et al., 2018](https://arxiv.org/abs/1802.07927) + ### Incompatible changes - **Remove support for Python 3.7** ([#2280](https://github.com/adap/flower/pull/2280), [#2299](https://github.com/adap/flower/pull/2299), [#2304](https://github.com/adap/flower/pull/2304), [#2306](https://github.com/adap/flower/pull/2306), [#2355](https://github.com/adap/flower/pull/2355), [#2356](https://github.com/adap/flower/pull/2356)) diff --git a/src/py/flwr/server/strategy/__init__.py b/src/py/flwr/server/strategy/__init__.py index 72429694bfe7..908267d04b3f 100644 --- a/src/py/flwr/server/strategy/__init__.py +++ b/src/py/flwr/server/strategy/__init__.py @@ -15,6 +15,7 @@ """Contains the strategy abstraction and different implementations.""" +from .bulyan import Bulyan as Bulyan from .dpfedavg_adaptive import DPFedAvgAdaptive as DPFedAvgAdaptive from .dpfedavg_fixed import DPFedAvgFixed as DPFedAvgFixed from .fault_tolerant_fedavg import FaultTolerantFedAvg as FaultTolerantFedAvg @@ -48,6 +49,7 @@ "FedMedian", "FedTrimmedAvg", "Krum", + "Bulyan", "DPFedAvgAdaptive", "DPFedAvgFixed", "Strategy", diff --git a/src/py/flwr/server/strategy/aggregate.py b/src/py/flwr/server/strategy/aggregate.py index 42390a08a110..63926f2eaa51 100644 --- a/src/py/flwr/server/strategy/aggregate.py +++ b/src/py/flwr/server/strategy/aggregate.py @@ -13,10 +13,10 @@ # limitations under the License. # ============================================================================== """Aggregation functions for strategy implementations.""" - +# mypy: disallow_untyped_calls=False from functools import reduce -from typing import List, Tuple +from typing import Any, Callable, List, Tuple import numpy as np @@ -56,7 +56,7 @@ def aggregate_median(results: List[Tuple[NDArrays, int]]) -> NDArrays: def aggregate_krum( results: List[Tuple[NDArrays, int]], num_malicious: int, to_keep: int ) -> NDArrays: - """Choose one parameter vector according to the Krum fucntion. + """Choose one parameter vector according to the Krum function. If to_keep is not None, then MultiKrum is applied. """ @@ -91,6 +91,89 @@ def aggregate_krum( return weights[np.argmin(scores)] +# pylint: disable=too-many-locals +def aggregate_bulyan( + results: List[Tuple[NDArrays, int]], + num_malicious: int, + aggregation_rule: Callable, # type: ignore + **aggregation_rule_kwargs: Any, +) -> NDArrays: + """Perform Bulyan aggregation. + + Parameters + ---------- + results: List[Tuple[NDArrays, int]] + Weights and number of samples for each of the client. + num_malicious: int + The maximum number of malicious clients. 
+ aggregation_rule: Callable + Byzantine resilient aggregation rule used as the first step of the Bulyan + aggregation_rule_kwargs: Any + The arguments to the aggregation rule. + + Returns + ------- + aggregated_parameters: NDArrays + Aggregated parameters according to the Bulyan strategy. + """ + byzantine_resilient_single_ret_model_aggregation = [aggregate_krum] + # also GeoMed (but not implemented yet) + byzantine_resilient_many_return_models_aggregation = [] # type: ignore + # Brute, Medoid (but not implemented yet) + + num_clients = len(results) + if num_clients < 4 * num_malicious + 3: + raise ValueError( + "The Bulyan aggregation requires then number of clients to be greater or " + "equal to the 4 * num_malicious + 3. This is the assumption of this method." + "It is needed to ensure that the method reduces the attacker's leeway to " + "the one proved in the paper." + ) + selected_models_set: List[Tuple[NDArrays, int]] = [] + + theta = len(results) - 2 * num_malicious + beta = theta - 2 * num_malicious + + for _ in range(theta): + best_model = aggregation_rule( + results=results, num_malicious=num_malicious, **aggregation_rule_kwargs + ) + list_of_weights = [weights for weights, num_samples in results] + # This group gives exact result + if aggregation_rule in byzantine_resilient_single_ret_model_aggregation: + best_idx = _find_reference_weights(best_model, list_of_weights) + # This group requires finding the closest model to the returned one + # (weights distance wise) + elif aggregation_rule in byzantine_resilient_many_return_models_aggregation: + # when different aggregation strategies available + # write a function to find the closest model + raise NotImplementedError( + "aggregate_bulyan currently does not support the aggregation rules that" + " return many models as results. " + "Such aggregation rules are currently not available in Flower." + ) + else: + raise ValueError( + "The given aggregation rule is not added as Byzantine resilient. " + "Please choose from Byzantine resilient rules." + ) + + selected_models_set.append(results[best_idx]) + + # remove idx from tracker and weights_results + results.pop(best_idx) + + # Compute median parameter vector across selected_models_set + median_vect = aggregate_median(selected_models_set) + + # Take the averaged beta parameters of the closest distance to the median + # (coordinate-wise) + parameters_aggregated = _aggregate_n_closest_weights( + median_vect, selected_models_set, beta_closest=beta + ) + return parameters_aggregated + + def weighted_loss_avg(results: List[Tuple[int, float]]) -> float: """Aggregate evaluation results obtained from multiple clients.""" num_total_evaluation_examples = sum([num_examples for num_examples, _ in results]) @@ -168,3 +251,90 @@ def aggregate_trimmed_avg( ] return trimmed_w + + +def _check_weights_equality(weights1: NDArrays, weights2: NDArrays) -> bool: + """Check if weights are the same.""" + if len(weights1) != len(weights2): + return False + return all( + np.array_equal(layer_weights1, layer_weights2) + for layer_weights1, layer_weights2 in zip(weights1, weights2) + ) + + +def _find_reference_weights( + reference_weights: NDArrays, list_of_weights: List[NDArrays] +) -> int: + """Find the reference weights by looping through the `list_of_weights`. + + Raise Error if the reference weights is not found. + + Parameters + ---------- + reference_weights: NDArrays + Weights that will be searched for. + list_of_weights: List[NDArrays] + List of weights that will be searched through. 
+ + Returns + ------- + index: int + The index of `reference_weights` in the `list_of_weights`. + + Raises + ------ + ValueError + If `reference_weights` is not found in `list_of_weights`. + """ + for idx, weights in enumerate(list_of_weights): + if _check_weights_equality(reference_weights, weights): + return idx + raise ValueError("The reference weights not found in list_of_weights.") + + +def _aggregate_n_closest_weights( + reference_weights: NDArrays, results: List[Tuple[NDArrays, int]], beta_closest: int +) -> NDArrays: + """Calculate element-wise mean of the `N` closest values. + + Note, each i-th coordinate of the result weight is the average of the beta_closest + -ith coordinates to the reference weights + + + Parameters + ---------- + reference_weights: NDArrays + The weights from which the distances will be computed + results: List[Tuple[NDArrays, int]] + The weights from models + beta_closest: int + The number of the closest distance weights that will be averaged + + Returns + ------- + aggregated_weights: NDArrays + Averaged (element-wise) beta weights that have the closest distance to + reference weights + """ + list_of_weights = [weights for weights, num_examples in results] + aggregated_weights = [] + + for layer_id, layer_weights in enumerate(reference_weights): + other_weights_layer_list = [] + for other_w in list_of_weights: + other_weights_layer = other_w[layer_id] + other_weights_layer_list.append(other_weights_layer) + other_weights_layer_np = np.array(other_weights_layer_list) + diff_np = np.abs(layer_weights - other_weights_layer_np) + # Create indices of the smallest differences + # We do not need the exact order but just the beta closest weights + # therefore np.argpartition is used instead of np.argsort + indices = np.argpartition(diff_np, kth=beta_closest - 1, axis=0) + # Take the weights (coordinate-wise) corresponding to the beta of the + # closest distances + beta_closest_weights = np.take_along_axis( + other_weights_layer_np, indices=indices, axis=0 + )[:beta_closest] + aggregated_weights.append(np.mean(beta_closest_weights, axis=0)) + return aggregated_weights diff --git a/src/py/flwr/server/strategy/aggregate_test.py b/src/py/flwr/server/strategy/aggregate_test.py index 81cc7189b14d..f8b4e3c03b50 100644 --- a/src/py/flwr/server/strategy/aggregate_test.py +++ b/src/py/flwr/server/strategy/aggregate_test.py @@ -19,7 +19,13 @@ import numpy as np -from .aggregate import aggregate, weighted_loss_avg +from .aggregate import ( + _aggregate_n_closest_weights, + _check_weights_equality, + _find_reference_weights, + aggregate, + weighted_loss_avg, +) def test_aggregate() -> None: @@ -64,3 +70,72 @@ def test_weighted_loss_avg_multiple_values() -> None: # Assert assert expected == actual + + +def test_check_weights_equality_true() -> None: + """Check weights equality - the same weights.""" + weights1 = [np.array([1, 2]), np.array([[1, 2], [3, 4]])] + weights2 = [np.array([1, 2]), np.array([[1, 2], [3, 4]])] + results = _check_weights_equality(weights1, weights2) + expected = True + assert expected == results + + +def test_check_weights_equality_numeric_false() -> None: + """Check weights equality - different weights, same length.""" + weights1 = [np.array([1, 2]), np.array([[1, 2], [3, 4]])] + weights2 = [np.array([2, 2]), np.array([[1, 2], [3, 4]])] + results = _check_weights_equality(weights1, weights2) + expected = False + assert expected == results + + +def test_check_weights_equality_various_length_false() -> None: + """Check weights equality - the same first 
layer weights, different length.""" + weights1 = [np.array([1, 2]), np.array([[1, 2], [3, 4]])] + weights2 = [np.array([1, 2])] + results = _check_weights_equality(weights1, weights2) + expected = False + assert expected == results + + +def test_find_reference_weights() -> None: + """Check if the finding weights from list of weigths work.""" + reference_weights = [np.array([1, 2]), np.array([[1, 2], [3, 4]])] + list_of_weights = [ + [np.array([2, 2]), np.array([[1, 2], [3, 4]])], + [np.array([3, 2]), np.array([[1, 2], [3, 4]])], + [np.array([3, 2]), np.array([[1, 2], [10, 4]])], + [np.array([1, 2]), np.array([[1, 2], [3, 4]])], + ] + + result = _find_reference_weights(reference_weights, list_of_weights) + + expected = 3 + assert result == expected + + +def test_aggregate_n_closest_weights_mean() -> None: + """Check if aggregation of n closest weights to the reference works.""" + beta_closest = 2 + reference_weights = [np.array([1, 2]), np.array([[1, 2], [3, 4]])] + + list_of_weights = [ + [np.array([1, 2]), np.array([[1, 2], [3, 4]])], + [np.array([1.1, 2.1]), np.array([[1.1, 2.1], [3.1, 4.1]])], + [np.array([1.2, 2.2]), np.array([[1.2, 2.2], [3.2, 4.2]])], + [np.array([1.3, 2.3]), np.array([[0.9, 2.5], [3.4, 3.8]])], + ] + list_of_weights_and_samples = [(weights, 0) for weights in list_of_weights] + + beta_closest_weights = _aggregate_n_closest_weights( + reference_weights, list_of_weights_and_samples, beta_closest=beta_closest + ) + expected_averaged = [np.array([1.05, 2.05]), np.array([[0.95, 2.05], [3.05, 4.05]])] + + assert all( + ( + np.array_equal(expected, result) + for expected, result in zip(expected_averaged, beta_closest_weights) + ) + ) diff --git a/src/py/flwr/server/strategy/bulyan.py b/src/py/flwr/server/strategy/bulyan.py new file mode 100644 index 000000000000..0243f4e6546f --- /dev/null +++ b/src/py/flwr/server/strategy/bulyan.py @@ -0,0 +1,162 @@ +# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Bulyan [El Mhamdi et al., 2018] strategy. 
+ +Paper: arxiv.org/abs/1802.07927 +""" + + +from logging import WARNING +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +from flwr.common import ( + FitRes, + MetricsAggregationFn, + NDArrays, + Parameters, + Scalar, + ndarrays_to_parameters, + parameters_to_ndarrays, +) +from flwr.common.logger import log +from flwr.server.client_proxy import ClientProxy + +from .aggregate import aggregate_bulyan, aggregate_krum +from .fedavg import FedAvg + + +# flake8: noqa: E501 +class Bulyan(FedAvg): + """Bulyan strategy implementation.""" + + # pylint: disable=too-many-arguments,too-many-instance-attributes,line-too-long, too-many-locals + def __init__( + self, + *, + fraction_fit: float = 1.0, + fraction_evaluate: float = 1.0, + min_fit_clients: int = 2, + min_evaluate_clients: int = 2, + min_available_clients: int = 2, + num_malicious_clients: int = 0, + evaluate_fn: Optional[ + Callable[ + [int, NDArrays, Dict[str, Scalar]], + Optional[Tuple[float, Dict[str, Scalar]]], + ] + ] = None, + on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + accept_failures: bool = True, + initial_parameters: Optional[Parameters] = None, + fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, + evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, + first_aggregation_rule: Callable = aggregate_krum, # type: ignore + **aggregation_rule_kwargs: Any, + ) -> None: + """Bulyan strategy. + + Implementation based on https://arxiv.org/abs/1802.07927. + + Parameters + ---------- + fraction_fit : float, optional + Fraction of clients used during training. Defaults to 1.0. + fraction_evaluate : float, optional + Fraction of clients used during validation. Defaults to 1.0. + min_fit_clients : int, optional + Minimum number of clients used during training. Defaults to 2. + min_evaluate_clients : int, optional + Minimum number of clients used during validation. Defaults to 2. + min_available_clients : int, optional + Minimum number of total clients in the system. Defaults to 2. + num_malicious_clients : int, optional + Number of malicious clients in the system. Defaults to 0. + evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]]] + Optional function used for validation. Defaults to None. + on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional + Function used to configure training. Defaults to None. + on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional + Function used to configure validation. Defaults to None. + accept_failures : bool, optional + Whether or not accept rounds containing failures. Defaults to True. + initial_parameters : Parameters, optional + Initial global model parameters. 
+ first_aggregation_rule: Callable + Byzantine resilient aggregation rule that is used as the first step of the Bulyan (e.g., Krum) + **aggregation_rule_kwargs: Any + arguments to the first_aggregation rule + """ + super().__init__( + fraction_fit=fraction_fit, + fraction_evaluate=fraction_evaluate, + min_fit_clients=min_fit_clients, + min_evaluate_clients=min_evaluate_clients, + min_available_clients=min_available_clients, + evaluate_fn=evaluate_fn, + on_fit_config_fn=on_fit_config_fn, + on_evaluate_config_fn=on_evaluate_config_fn, + accept_failures=accept_failures, + initial_parameters=initial_parameters, + fit_metrics_aggregation_fn=fit_metrics_aggregation_fn, + evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn, + ) + self.num_malicious_clients = num_malicious_clients + self.first_aggregation_rule = first_aggregation_rule + self.aggregation_rule_kwargs = aggregation_rule_kwargs + + def __repr__(self) -> str: + """Compute a string representation of the strategy.""" + rep = f"Bulyan(accept_failures={self.accept_failures})" + return rep + + def aggregate_fit( + self, + server_round: int, + results: List[Tuple[ClientProxy, FitRes]], + failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], + ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + """Aggregate fit results using Bulyan.""" + if not results: + return None, {} + # Do not aggregate if there are failures and failures are not accepted + if not self.accept_failures and failures: + return None, {} + + # Convert results + weights_results = [ + (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples) + for _, fit_res in results + ] + + # Aggregate weights + parameters_aggregated = ndarrays_to_parameters( + aggregate_bulyan( + weights_results, + self.num_malicious_clients, + self.first_aggregation_rule, + **self.aggregation_rule_kwargs, + ) + ) + + # Aggregate custom metrics if aggregation fn was provided + metrics_aggregated = {} + if self.fit_metrics_aggregation_fn: + fit_metrics = [(res.num_examples, res.metrics) for _, res in results] + metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics) + elif server_round == 1: # Only log this warning once + log(WARNING, "No fit_metrics_aggregation_fn provided") + + return parameters_aggregated, metrics_aggregated diff --git a/src/py/flwr/server/strategy/bulyan_test.py b/src/py/flwr/server/strategy/bulyan_test.py new file mode 100644 index 000000000000..299ed49066fb --- /dev/null +++ b/src/py/flwr/server/strategy/bulyan_test.py @@ -0,0 +1,131 @@ +# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Bulyan tests.""" + + +from typing import List, Tuple +from unittest.mock import MagicMock + +from numpy import array, float32 + +from flwr.common import ( + Code, + FitRes, + NDArrays, + Parameters, + Status, + ndarrays_to_parameters, + parameters_to_ndarrays, +) +from flwr.server.client_proxy import ClientProxy + +from .bulyan import Bulyan + + +# pylint: disable=too-many-locals +def test_aggregate_fit() -> None: + """Tests if Bulyan is aggregating correctly.""" + # Prepare + previous_weights: NDArrays = [array([0.1, 0.1, 0.1, 0.1], dtype=float32)] + strategy = Bulyan( + initial_parameters=ndarrays_to_parameters(previous_weights), + num_malicious_clients=0, + to_keep=0, + ) + param_0: Parameters = ndarrays_to_parameters( + [array([0.2, 0.2, 0.2, 0.2], dtype=float32)] + ) + param_1: Parameters = ndarrays_to_parameters( + [array([0.5, 0.5, 0.5, 0.5], dtype=float32)] + ) + param_2: Parameters = ndarrays_to_parameters( + [array([0.7, 0.7, 0.7, 0.7], dtype=float32)] + ) + param_3: Parameters = ndarrays_to_parameters( + [array([12.0, 12.0, 12.0, 12.0], dtype=float32)] + ) + param_4: Parameters = ndarrays_to_parameters( + [array([0.1, 0.1, 0.1, 0.1], dtype=float32)] + ) + param_5: Parameters = ndarrays_to_parameters( + [array([0.1, 0.1, 0.1, 0.1], dtype=float32)] + ) + results: List[Tuple[ClientProxy, FitRes]] = [ + ( + MagicMock(), + FitRes( + status=Status(code=Code.OK, message="Success"), + parameters=param_0, + num_examples=5, + metrics={}, + ), + ), + ( + MagicMock(), + FitRes( + status=Status(code=Code.OK, message="Success"), + parameters=param_1, + num_examples=5, + metrics={}, + ), + ), + ( + MagicMock(), + FitRes( + status=Status(code=Code.OK, message="Success"), + parameters=param_2, + num_examples=5, + metrics={}, + ), + ), + ( + MagicMock(), + FitRes( + status=Status(code=Code.OK, message="Success"), + parameters=param_3, + num_examples=5, + metrics={}, + ), + ), + ( + MagicMock(), + FitRes( + status=Status(code=Code.OK, message="Success"), + parameters=param_4, + num_examples=5, + metrics={}, + ), + ), + ( + MagicMock(), + FitRes( + status=Status(code=Code.OK, message="Success"), + parameters=param_5, + num_examples=5, + metrics={}, + ), + ), + ] + coordinate = (0.2 + 0.5 + 0.7 + 12.0 + 0.1 + 0.1) / 6 + expected: NDArrays = [array([coordinate] * 4, dtype=float32)] + + # Execute + actual_aggregated, _ = strategy.aggregate_fit( + server_round=1, results=results, failures=[] + ) + if actual_aggregated: + actual_list = parameters_to_ndarrays(actual_aggregated) + actual = actual_list[0] + assert (actual == expected[0]).all()
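
Usage note (not part of the diff above): a minimal sketch of how the new `Bulyan` strategy could be plugged into a Flower server. The server address, round count, and client counts are illustrative assumptions, not values from this PR; `to_keep=0` is forwarded through `**aggregation_rule_kwargs` to the default first-step rule (`aggregate_krum`), matching the value used in `bulyan_test.py`.

```python
# Sketch only: address, num_rounds, and client counts are placeholders.
import flwr as fl
from flwr.server.strategy import Bulyan

# Bulyan raises a ValueError when fewer than 4 * num_malicious_clients + 3
# results are aggregated, so with num_malicious_clients=1 at least 7 clients
# must participate in every fit round.
strategy = Bulyan(
    num_malicious_clients=1,
    to_keep=0,               # forwarded to the first-step rule (Krum by default)
    min_fit_clients=7,
    min_available_clients=7,
)

fl.server.start_server(
    server_address="0.0.0.0:8080",
    config=fl.server.ServerConfig(num_rounds=3),
    strategy=strategy,
)
```

Because of the `4 * num_malicious_clients + 3` requirement, `min_fit_clients` and `min_available_clients` should be set at least that high; otherwise `aggregate_fit` will fail as soon as a round is aggregated.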
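
For reviewers, a second sketch that exercises `aggregate_bulyan` directly. It imports internal helpers from their module path rather than the public API, and the synthetic updates, seed, and outlier value are made up for illustration. With `n = 7` results and `num_malicious = 1`, the selection set size is `theta = n - 2f = 5` and the coordinate-wise trimmed average uses the `beta = theta - 2f = 3` closest values.

```python
# Illustrative only: six benign updates around 0.5 plus one obvious outlier
# standing in for a poisoned client. aggregate_bulyan and aggregate_krum are
# internal helpers of flwr.server.strategy, imported from the module itself.
import numpy as np
from flwr.server.strategy.aggregate import aggregate_bulyan, aggregate_krum

rng = np.random.default_rng(0)

# Each result is (list of layer ndarrays, num_examples).
results = [([rng.normal(loc=0.5, scale=0.05, size=4)], 10) for _ in range(6)]
results.append(([np.full(4, 100.0)], 10))  # n = 7 satisfies n >= 4 * f + 3

# Note: aggregate_bulyan pops selected entries from the list it is given,
# so pass a copy if the original list is still needed afterwards.
aggregated = aggregate_bulyan(
    results,
    num_malicious=1,
    aggregation_rule=aggregate_krum,
    to_keep=0,  # plain Krum as the Byzantine-resilient first step
)
print(aggregated[0])  # stays close to 0.5; the outlier never enters the average
```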