Add Flower Datasets v0.0.1 functionality (adap#2195)
Co-authored-by: Daniel J. Beutel <[email protected]>
adam-narozniak and danieljanes authored Sep 4, 2023
1 parent e31a974 commit 26454f8
Showing 5 changed files with 332 additions and 0 deletions.
5 changes: 5 additions & 0 deletions datasets/flwr_datasets/__init__.py
@@ -13,3 +13,8 @@
# limitations under the License.
# ==============================================================================
"""Flower Datasets main package."""


from .federated_dataset import FederatedDataset

__all__ = ["FederatedDataset"]
148 changes: 148 additions & 0 deletions datasets/flwr_datasets/federated_dataset.py
@@ -0,0 +1,148 @@
# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""FederatedDataset."""


from typing import Dict, Optional

import datasets
from datasets import Dataset, DatasetDict
from flwr_datasets.partitioner import Partitioner
from flwr_datasets.utils import _check_if_dataset_supported, _instantiate_partitioners


class FederatedDataset:
    """Representation of a dataset for federated learning/evaluation/analytics.

    Download, partition data among clients (edge devices), or load the full dataset.

    Partitions are created using IidPartitioner. Support for different partitioner
    types and specifications will come in future releases.

    Parameters
    ----------
    dataset: str
        The name of the dataset in the Hugging Face Hub.
    partitioners: Dict[str, int]
        Mapping from dataset split to the number of IID partitions.

    Examples
    --------
    Use the MNIST dataset for Federated Learning with 100 clients (edge devices):

    >>> mnist_fds = FederatedDataset(dataset="mnist", partitioners={"train": 100})

    Load the partition for the client with ID 10:

    >>> partition = mnist_fds.load_partition(10, "train")

    Use the test split for centralized evaluation:

    >>> centralized = mnist_fds.load_full("test")
    """

    def __init__(self, *, dataset: str, partitioners: Dict[str, int]) -> None:
        _check_if_dataset_supported(dataset)
        self._dataset_name: str = dataset
        self._partitioners: Dict[str, Partitioner] = _instantiate_partitioners(
            partitioners
        )
        # Init (download) lazily on the first call to `load_partition` or `load_full`
        self._dataset: Optional[DatasetDict] = None

    def load_partition(self, idx: int, split: str) -> Dataset:
        """Load the partition specified by the idx in the selected split.

        The dataset is downloaded only when the first call to `load_partition` or
        `load_full` is made.

        Parameters
        ----------
        idx: int
            Partition index for the selected split, idx in {0, ..., num_partitions - 1}.
        split: str
            Name of the (partitioned) split (e.g. "train", "test").

        Returns
        -------
        partition: Dataset
            Single partition from the dataset split.
        """
        self._download_dataset_if_none()
        if self._dataset is None:
            raise ValueError("Dataset is not loaded yet.")
        self._check_if_split_present(split)
        self._check_if_split_possible_to_federate(split)
        partitioner: Partitioner = self._partitioners[split]
        self._assign_dataset_to_partitioner(split)
        return partitioner.load_partition(idx)

    def load_full(self, split: str) -> Dataset:
        """Load the full split of the dataset.

        The dataset is downloaded only when the first call to `load_partition` or
        `load_full` is made.

        Parameters
        ----------
        split: str
            Split name of the downloaded dataset (e.g. "train", "test").

        Returns
        -------
        dataset_split: Dataset
            Part of the dataset identified by its split name.
        """
        self._download_dataset_if_none()
        if self._dataset is None:
            raise ValueError("Dataset is not loaded yet.")
        self._check_if_split_present(split)
        return self._dataset[split]

    def _download_dataset_if_none(self) -> None:
        """Lazily load (and potentially download) the Dataset instance into memory."""
        if self._dataset is None:
            self._dataset = datasets.load_dataset(self._dataset_name)

    def _check_if_split_present(self, split: str) -> None:
        """Check if the split (for partitioning or full return) is in the dataset."""
        if self._dataset is None:
            raise ValueError("Dataset is not loaded yet.")
        available_splits = list(self._dataset.keys())
        if split not in available_splits:
            raise ValueError(
                f"The given split: '{split}' is not present in the dataset's splits: "
                f"'{available_splits}'."
            )

    def _check_if_split_possible_to_federate(self, split: str) -> None:
        """Check if the split has a corresponding partitioner."""
        partitioners_keys = list(self._partitioners.keys())
        if split not in partitioners_keys:
            raise ValueError(
                f"The given split: '{split}' does not have a partitioner to perform "
                f"partitioning. Partitioners were specified for the following splits: "
                f"'{partitioners_keys}'."
            )

    def _assign_dataset_to_partitioner(self, split: str) -> None:
        """Assign the corresponding split of the dataset to the partitioner.

        Assign only if the dataset is not assigned yet.
        """
        if self._dataset is None:
            raise ValueError("Dataset is not loaded yet.")
        if not self._partitioners[split].is_dataset_assigned():
            self._partitioners[split].dataset = self._dataset[split]
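
Taken together, a minimal end-to-end usage sketch (assuming network access to the Hugging Face Hub; the dataset and split names follow the class docstring above):

>>> from flwr_datasets import FederatedDataset
>>> # Partition the train split into 100 IID shards, one per client
>>> fds = FederatedDataset(dataset="mnist", partitioners={"train": 100})
>>> # The first call triggers the lazy download, then returns partition 0
>>> partition = fds.load_partition(0, "train")
>>> # The unpartitioned test split serves centralized evaluation
>>> centralized = fds.load_full("test")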
120 changes: 120 additions & 0 deletions datasets/flwr_datasets/federated_dataset_test.py
@@ -0,0 +1,120 @@
# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Federated Dataset tests."""


import unittest

import pytest
from parameterized import parameterized, parameterized_class

import datasets
from flwr_datasets.federated_dataset import FederatedDataset


@parameterized_class(
    [
        {"dataset_name": "mnist", "test_split": "test"},
        {"dataset_name": "cifar10", "test_split": "test"},
        {"dataset_name": "fashion_mnist", "test_split": "test"},
        {"dataset_name": "sasha/dog-food", "test_split": "test"},
        {"dataset_name": "zh-plus/tiny-imagenet", "test_split": "valid"},
    ]
)
class RealDatasetsFederatedDatasetsTrainTest(unittest.TestCase):
    """Test real datasets (e.g., MNIST, CIFAR10) in FederatedDataset."""

    dataset_name = ""
    test_split = ""

    @parameterized.expand(  # type: ignore
        [
            ("10", 10),
            ("100", 100),
        ]
    )
    def test_load_partition_size(self, _: str, train_num_partitions: int) -> None:
        """Test if the partition size is correct based on the number of partitions."""
        dataset_fds = FederatedDataset(
            dataset=self.dataset_name, partitioners={"train": train_num_partitions}
        )
        dataset_partition0 = dataset_fds.load_partition(0, "train")
        dataset = datasets.load_dataset(self.dataset_name)
        self.assertEqual(
            len(dataset_partition0), len(dataset["train"]) // train_num_partitions
        )

    def test_load_full(self) -> None:
        """Test if load_full works with the correct split name."""
        dataset_fds = FederatedDataset(
            dataset=self.dataset_name, partitioners={"train": 100}
        )
        dataset_fds_test = dataset_fds.load_full(self.test_split)
        dataset_test = datasets.load_dataset(self.dataset_name)[self.test_split]
        self.assertEqual(len(dataset_fds_test), len(dataset_test))

    def test_multiple_partitioners(self) -> None:
        """Test if the dataset works when multiple partitioners are specified."""
        num_train_partitions = 100
        num_test_partitions = 100
        dataset_fds = FederatedDataset(
            dataset=self.dataset_name,
            partitioners={
                "train": num_train_partitions,
                self.test_split: num_test_partitions,
            },
        )
        dataset_test_partition0 = dataset_fds.load_partition(0, self.test_split)

        dataset = datasets.load_dataset(self.dataset_name)
        self.assertEqual(
            len(dataset_test_partition0),
            len(dataset[self.test_split]) // num_test_partitions,
        )


class IncorrectUsageFederatedDatasets(unittest.TestCase):
    """Test incorrect usages in FederatedDatasets."""

    def test_no_partitioner_for_split(self) -> None:  # pylint: disable=R0201
        """Test using load_partition with a missing partitioner."""
        dataset_fds = FederatedDataset(dataset="mnist", partitioners={"train": 100})

        with pytest.raises(ValueError):
            dataset_fds.load_partition(0, "test")

    def test_no_split_in_the_dataset(self) -> None:  # pylint: disable=R0201
        """Test using load_partition with a non-existent split name."""
        dataset_fds = FederatedDataset(
            dataset="mnist", partitioners={"non-existent-split": 100}
        )

        with pytest.raises(ValueError):
            dataset_fds.load_partition(0, "non-existent-split")

    def test_unsupported_dataset(self) -> None:  # pylint: disable=R0201
        """Test creating FederatedDataset for an unsupported dataset."""
        with pytest.raises(ValueError):
            FederatedDataset(dataset="food101", partitioners={"train": 100})


if __name__ == "__main__":
    unittest.main()
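
The incorrect-usage tests above pin down the error contract; interactively, the same misuse surfaces as a ValueError (a sketch of the expected behavior, using the "mnist" setup from the tests):

>>> fds = FederatedDataset(dataset="mnist", partitioners={"train": 100})
>>> fds.load_partition(0, "test")  # no partitioner registered for "test"
Traceback (most recent call last):
    ...
ValueError: The given split: 'test' does not have a partitioner to perform partitioning. Partitioners were specified for the following splits: '['train']'.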
57 changes: 57 additions & 0 deletions datasets/flwr_datasets/utils.py
@@ -0,0 +1,57 @@
# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils for FederatedDataset."""


from typing import Dict

from flwr_datasets.partitioner import IidPartitioner, Partitioner


def _instantiate_partitioners(partitioners: Dict[str, int]) -> Dict[str, Partitioner]:
    """Transform the partitioners from the initial format to instantiated objects.

    Parameters
    ----------
    partitioners: Dict[str, int]
        Partitioners specified as a mapping from split name to the number of
        partitions.

    Returns
    -------
    partitioners: Dict[str, Partitioner]
        Partitioners specified as a mapping from split name to a Partitioner object.
    """
    instantiated_partitioners: Dict[str, Partitioner] = {}
    for split_name, num_partitions in partitioners.items():
        instantiated_partitioners[split_name] = IidPartitioner(
            num_partitions=num_partitions
        )
    return instantiated_partitioners


def _check_if_dataset_supported(dataset: str) -> None:
    """Check if the dataset is in the narrowed-down list of tested datasets."""
    supported_datasets = [
        "mnist",
        "cifar10",
        "fashion_mnist",
        "sasha/dog-food",
        "zh-plus/tiny-imagenet",
    ]
    if dataset not in supported_datasets:
        raise ValueError(
            f"The currently tested and supported datasets are {supported_datasets}. "
            f"Given: {dataset}"
        )
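
For illustration, a short sketch of how `_instantiate_partitioners` turns the user-facing dict into partitioner objects (behavior inferred from the code above):

>>> from flwr_datasets.utils import _instantiate_partitioners
>>> partitioners = _instantiate_partitioners({"train": 100, "test": 10})
>>> # Each split name now maps to its own IidPartitioner instance
>>> type(partitioners["train"]).__name__
'IidPartitioner'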
2 changes: 2 additions & 0 deletions datasets/pyproject.toml
@@ -54,6 +54,7 @@ exclude = [
[tool.poetry.dependencies]
python = "^3.8"
numpy = "^1.21.0"
datasets = "^2.14.3"

[tool.poetry.dev-dependencies]
isort = "==5.11.5"
@@ -62,6 +63,7 @@ docformatter = "==1.7.1"
mypy = "==1.4.0"
pylint = "==2.13.9"
flake8 = "==3.9.2"
parameterized = "==0.9.0"
pytest = "==7.1.2"
pytest-watch = "==4.2.0"
ruff = "==0.0.277"