forked from adap/flower
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Flower Datasets v0.0.1 functionality (adap#2195)
Co-authored-by: Daniel J. Beutel <[email protected]>
- Loading branch information
1 parent
e31a974
commit 26454f8
Showing
5 changed files
with
332 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
# Copyright 2023 Flower Labs GmbH. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""FederatedDataset.""" | ||
|
||
|
||
from typing import Dict, Optional | ||
|
||
import datasets | ||
from datasets import Dataset, DatasetDict | ||
from flwr_datasets.partitioner import Partitioner | ||
from flwr_datasets.utils import _check_if_dataset_supported, _instantiate_partitioners | ||
|
||
|
||
class FederatedDataset:
    """Representation of a dataset for federated learning/evaluation/analytics.

    Download, partition data among clients (edge devices), or load full dataset.

    Partitions are created using IidPartitioner. Support for different partitioners
    specification and types will come in future releases.

    Parameters
    ----------
    dataset: str
        The name of the dataset in the Hugging Face Hub.
    partitioners: Dict[str, int]
        Dataset split to the number of IID partitions.

    Examples
    --------
    Use MNIST dataset for Federated Learning with 100 clients (edge devices):

    >>> mnist_fds = FederatedDataset(dataset="mnist", partitioners={"train": 100})

    Load partition for client with ID 10.

    >>> partition = mnist_fds.load_partition(10, "train")

    Use test split for centralized evaluation.

    >>> centralized = mnist_fds.load_full("test")
    """

    def __init__(self, *, dataset: str, partitioners: Dict[str, int]) -> None:
        _check_if_dataset_supported(dataset)
        self._dataset_name: str = dataset
        self._partitioners: Dict[str, Partitioner] = _instantiate_partitioners(
            partitioners
        )
        # Init (download) lazily on the first call to `load_partition` or `load_full`
        self._dataset: Optional[DatasetDict] = None

    def load_partition(self, idx: int, split: str) -> Dataset:
        """Load the partition specified by the idx in the selected split.

        The dataset is downloaded only when the first call to `load_partition` or
        `load_full` is made.

        Parameters
        ----------
        idx: int
            Partition index for the selected split, idx in {0, ..., num_partitions - 1}.
        split: str
            Name of the (partitioned) split (e.g. "train", "test").

        Returns
        -------
        partition: Dataset
            Single partition from the dataset split.

        Raises
        ------
        ValueError
            If the split is absent from the dataset or has no partitioner.
        """
        self._download_dataset_if_none()
        self._check_if_split_present(split)
        self._check_if_split_possible_to_federate(split)
        partitioner: Partitioner = self._partitioners[split]
        self._assign_dataset_to_partitioner(split)
        return partitioner.load_partition(idx)

    def load_full(self, split: str) -> Dataset:
        """Load the full split of the dataset.

        The dataset is downloaded only when the first call to `load_partition` or
        `load_full` is made.

        Parameters
        ----------
        split: str
            Split name of the downloaded dataset (e.g. "train", "test").

        Returns
        -------
        dataset_split: Dataset
            Part of the dataset identified by its split name.

        Raises
        ------
        ValueError
            If the split is absent from the dataset.
        """
        self._download_dataset_if_none()
        self._check_if_split_present(split)
        return self._dataset_or_raise()[split]

    def _dataset_or_raise(self) -> DatasetDict:
        """Return the loaded dataset; raise ValueError if it is not loaded yet.

        Centralizes the `None` guard that was previously repeated in every
        method, so `Optional[DatasetDict]` is narrowed in a single place.
        """
        if self._dataset is None:
            raise ValueError("Dataset is not loaded yet.")
        return self._dataset

    def _download_dataset_if_none(self) -> None:
        """Lazily load (and potentially download) the Dataset instance into memory."""
        if self._dataset is None:
            self._dataset = datasets.load_dataset(self._dataset_name)

    def _check_if_split_present(self, split: str) -> None:
        """Check if the split (for partitioning or full return) is in the dataset."""
        available_splits = list(self._dataset_or_raise().keys())
        if split not in available_splits:
            raise ValueError(
                f"The given split: '{split}' is not present in the dataset's splits: "
                f"'{available_splits}'."
            )

    def _check_if_split_possible_to_federate(self, split: str) -> None:
        """Check if the split has corresponding partitioner."""
        partitioners_keys = list(self._partitioners.keys())
        if split not in partitioners_keys:
            raise ValueError(
                f"The given split: '{split}' does not have a partitioner to perform "
                f"partitioning. Partitioners were specified for the following splits:"
                f"'{partitioners_keys}'."
            )

    def _assign_dataset_to_partitioner(self, split: str) -> None:
        """Assign the corresponding split of the dataset to the partitioner.

        Assign only if the dataset is not assigned yet.
        """
        dataset = self._dataset_or_raise()
        if not self._partitioners[split].is_dataset_assigned():
            self._partitioners[split].dataset = dataset[split]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
# Copyright 2023 Flower Labs GmbH. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Federated Dataset tests.""" | ||
|
||
|
||
import unittest | ||
|
||
import pytest | ||
from parameterized import parameterized, parameterized_class | ||
|
||
import datasets | ||
from flwr_datasets.federated_dataset import FederatedDataset | ||
|
||
|
||
@parameterized_class(
    [
        {"dataset_name": "mnist", "test_split": "test"},
        {"dataset_name": "cifar10", "test_split": "test"},
        {"dataset_name": "fashion_mnist", "test_split": "test"},
        {"dataset_name": "sasha/dog-food", "test_split": "test"},
        {"dataset_name": "zh-plus/tiny-imagenet", "test_split": "valid"},
    ]
)
class RealDatasetsFederatedDatasetsTrainTest(unittest.TestCase):
    """Test Real Dataset (MNIST, CIFAR10) in FederatedDatasets."""

    # Placeholders filled in by `parameterized_class` for each dataset entry.
    dataset_name = ""
    test_split = ""

    @parameterized.expand([("10", 10), ("100", 100)])  # type: ignore
    def test_load_partition_size(self, _: str, train_num_partitions: int) -> None:
        """Test if the partition size is correct based on the number of partitions."""
        fds = FederatedDataset(
            dataset=self.dataset_name, partitioners={"train": train_num_partitions}
        )
        partition = fds.load_partition(0, "train")
        reference = datasets.load_dataset(self.dataset_name)
        expected_len = len(reference["train"]) // train_num_partitions
        self.assertEqual(len(partition), expected_len)

    def test_load_full(self) -> None:
        """Test if the load_full works with the correct split name."""
        fds = FederatedDataset(dataset=self.dataset_name, partitioners={"train": 100})
        full_split = fds.load_full(self.test_split)
        reference_split = datasets.load_dataset(self.dataset_name)[self.test_split]
        self.assertEqual(len(full_split), len(reference_split))

    def test_multiple_partitioners(self) -> None:
        """Test if the dataset works when multiple partitioners are specified."""
        num_train_partitions = 100
        num_test_partitions = 100
        fds = FederatedDataset(
            dataset=self.dataset_name,
            partitioners={
                "train": num_train_partitions,
                self.test_split: num_test_partitions,
            },
        )
        test_partition = fds.load_partition(0, self.test_split)

        reference = datasets.load_dataset(self.dataset_name)
        expected_len = len(reference[self.test_split]) // num_test_partitions
        self.assertEqual(len(test_partition), expected_len)
|
||
|
||
class IncorrectUsageFederatedDatasets(unittest.TestCase):
    """Test incorrect usages in FederatedDatasets."""

    def test_no_partitioner_for_split(self) -> None:  # pylint: disable=R0201
        """Test using load_partition with missing partitioner."""
        fds = FederatedDataset(dataset="mnist", partitioners={"train": 100})

        # "test" was never given a partitioner, so partitioning must fail.
        with self.assertRaises(ValueError):
            fds.load_partition(0, "test")

    def test_no_split_in_the_dataset(self) -> None:  # pylint: disable=R0201
        """Test using load_partition with non-existent split name."""
        fds = FederatedDataset(
            dataset="mnist", partitioners={"non-existent-split": 100}
        )

        # The split name does not exist in MNIST, so loading must fail.
        with self.assertRaises(ValueError):
            fds.load_partition(0, "non-existent-split")

    def test_unsupported_dataset(self) -> None:  # pylint: disable=R0201
        """Test creating FederatedDataset for unsupported dataset."""
        # food101 is outside the tested/supported allow-list.
        with self.assertRaises(ValueError):
            FederatedDataset(dataset="food101", partitioners={"train": 100})
|
||
|
||
# Allow running this test module directly (e.g. `python <module>.py`).
if __name__ == "__main__":
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
# Copyright 2023 Flower Labs GmbH. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Utils for FederatedDataset.""" | ||
|
||
|
||
from typing import Dict | ||
|
||
from flwr_datasets.partitioner import IidPartitioner, Partitioner | ||
|
||
|
||
def _instantiate_partitioners(partitioners: Dict[str, int]) -> Dict[str, Partitioner]:
    """Transform the partitioners from the initial format to instantiated objects.

    Parameters
    ----------
    partitioners: Dict[str, int]
        Partitioners specified as split to the number of partitions format.

    Returns
    -------
    partitioners: Dict[str, Partitioner]
        Partitioners specified as split to Partitioner object.
    """
    # One IidPartitioner per split, sized by the requested partition count.
    return {
        split_name: IidPartitioner(num_partitions=num_partitions)
        for split_name, num_partitions in partitioners.items()
    }
|
||
|
||
def _check_if_dataset_supported(dataset: str) -> None: | ||
"""Check if the dataset is in the narrowed down list of the tested datasets.""" | ||
supported_datasets = [ | ||
"mnist", | ||
"cifar10", | ||
"fashion_mnist", | ||
"sasha/dog-food", | ||
"zh-plus/tiny-imagenet", | ||
] | ||
if dataset not in supported_datasets: | ||
raise ValueError( | ||
f"The currently tested and supported dataset are {supported_datasets}. " | ||
f"Given: {dataset}" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters