Skip to content

Commit

Permalink
Reduce calls to HuggingFace in tests(#730)
Browse files Browse the repository at this point in the history
* refactor: move huggingface inits to parent class

* refactor: Add huggingFaceXXXRepo to the standart repository tests with a mock implementation

* docs: update changelog

Task: IL-423
  • Loading branch information
NiklasKoehneckeAA authored Apr 11, 2024
1 parent b829383 commit fd8814a
Show file tree
Hide file tree
Showing 11 changed files with 198 additions and 225 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@

## Unreleased

### Breaking Changes
- The implementation of the HuggingFace repository creation and deletion got moved to `HuggingFaceRepository`
### New Features
...
- feature: HuggingFaceDataset- & AggregationRepositories now have an explicit `create_repository` function.
### Fixes
...


## 0.8.2

### New Features
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
from pathlib import Path

import huggingface_hub # type: ignore
from huggingface_hub import HfFileSystem, create_repo

from intelligence_layer.evaluation.aggregation.file_aggregation_repository import (
FileSystemAggregationRepository,
)
Expand All @@ -14,30 +9,5 @@
class HuggingFaceAggregationRepository(
FileSystemAggregationRepository, HuggingFaceRepository
):
def __init__(self, repository_id: str, token: str, private: bool) -> None:
assert repository_id[-1] != "/"

create_repo(
repo_id=repository_id,
token=token,
repo_type=HuggingFaceRepository._REPO_TYPE,
private=private,
exist_ok=True,
)

file_system = HfFileSystem(token=token)
root_directory = Path(
f"{HuggingFaceRepository._ROOT_DIRECTORY_PREFIX_}/{repository_id}"
)
super().__init__(file_system, root_directory)

self._repository_id = repository_id
self._file_system = file_system # for better type checks

def delete_repository(self) -> None:
huggingface_hub.delete_repo(
repo_id=self._repository_id,
token=self._file_system.token,
repo_type=HuggingFaceRepository._REPO_TYPE,
missing_ok=True,
)
# this class inherits all its behavior from its parents
pass
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,7 @@ def create_dataset(

dataset_path = self._dataset_path(dataset.id)
examples_path = self._dataset_examples_path(dataset.id)
if self._file_system.exists(dataset_path) or self._file_system.exists(
examples_path
):
if self.exists(dataset_path) or self.exists(examples_path):
raise ValueError(
f"One of the dataset files already exist for dataset {dataset}. This should not happen. Files: {dataset_path}, {examples_path}."
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
from pathlib import Path
from typing import Optional

import huggingface_hub # type: ignore
from huggingface_hub import HfFileSystem, create_repo

from intelligence_layer.evaluation.dataset.domain import Dataset
from intelligence_layer.evaluation.dataset.file_dataset_repository import (
FileSystemDatasetRepository,
Expand All @@ -14,41 +10,6 @@


class HuggingFaceDatasetRepository(HuggingFaceRepository, FileSystemDatasetRepository):
def __init__(self, repository_id: str, token: str, private: bool) -> None:
"""Create a HuggingFace dataset repository
Args:
repository_id: The HuggingFace namespace and repository name separated by a "/".
token: The HuggingFace token.
private: Whether the dataset repository should be private.
"""
assert repository_id[-1] != "/"

create_repo(
repo_id=repository_id,
token=token,
repo_type=HuggingFaceDatasetRepository._REPO_TYPE,
private=private,
exist_ok=True,
)

file_system = HfFileSystem(token=token)
root_directory = Path(
f"{HuggingFaceRepository._ROOT_DIRECTORY_PREFIX_}/{repository_id}"
)
super().__init__(file_system, root_directory)

self._repository_id = repository_id
self._file_system = file_system # for better type checks

def delete_repository(self) -> None:
huggingface_hub.delete_repo(
repo_id=self._repository_id,
token=self._file_system.token,
repo_type=HuggingFaceDatasetRepository._REPO_TYPE,
missing_ok=True,
)

def delete_dataset(self, dataset_id: str) -> None:
"""Deletes a dataset identified by the given dataset ID.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from pathlib import Path

import huggingface_hub # type: ignore

from intelligence_layer.evaluation.dataset.file_dataset_repository import (
FileSystemDatasetRepository,
)
Expand All @@ -14,3 +16,41 @@ class HuggingFaceRepository(FileSystemDatasetRepository):
@staticmethod
def path_to_str(path: Path) -> str:
return path.as_posix()

def __init__(self, repository_id: str, token: str, private: bool) -> None:
"""Create a HuggingFace repository.
Creates a corresponding repository and initializes the file system.
Args:
repository_id: The HuggingFace namespace and repository name, separated by a "/".
token: The HuggingFace authentication token.
private: Whether the dataset repository should be private.
"""
assert repository_id[-1] != "/"
self.create_repository(repository_id, token, private)

file_system = huggingface_hub.HfFileSystem(token=token)
root_directory = Path(f"{self._ROOT_DIRECTORY_PREFIX_}/{repository_id}")

super().__init__(file_system, root_directory)
self._repository_id = repository_id
# the file system is assigned in super init but this fixes the typing
self._file_system: huggingface_hub.HfFileSystem

def create_repository(self, repository_id: str, token: str, private: bool) -> None:
huggingface_hub.create_repo(
repo_id=repository_id,
token=token,
repo_type=self._REPO_TYPE,
private=private,
exist_ok=True,
)

def delete_repository(self) -> None:
huggingface_hub.delete_repo(
repo_id=self._repository_id,
token=self._file_system.token,
repo_type=self._REPO_TYPE,
missing_ok=True,
)
16 changes: 16 additions & 0 deletions tests/evaluation/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from uuid import uuid4

from dotenv import load_dotenv
from fsspec.implementations.memory import MemoryFileSystem # type: ignore
from pydantic import BaseModel
from pytest import fixture

Expand Down Expand Up @@ -241,6 +242,21 @@ def stub_argilla_client() -> StubArgillaClient:
return StubArgillaClient()


@fixture()
def temp_file_system() -> Iterable[MemoryFileSystem]:
mfs = MemoryFileSystem()

try:
yield mfs
finally:
mfs.store.clear()


@fixture(scope="session")
def hugging_face_test_repository_id() -> str:
return f"Aleph-Alpha/test-{str(uuid4())}"


@fixture(scope="session")
def hugging_face_token() -> str:
load_dotenv()
Expand Down
29 changes: 24 additions & 5 deletions tests/evaluation/test_aggregation_repository.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import Iterable
from unittest.mock import patch
from uuid import uuid4

from _pytest.fixtures import FixtureRequest
from fsspec.implementations.memory import MemoryFileSystem # type: ignore
from pytest import fixture, mark

from intelligence_layer.core import utc_now
Expand All @@ -10,14 +12,36 @@
AggregationRepository,
EvaluationOverview,
)
from intelligence_layer.evaluation.aggregation.hugging_face_aggregation_repository import (
HuggingFaceAggregationRepository,
)
from tests.evaluation.conftest import DummyAggregatedEvaluation

test_repository_fixtures = [
"file_aggregation_repository",
"in_memory_aggregation_repository",
"mocked_hugging_face_aggregation_repository",
]


@fixture
def mocked_hugging_face_aggregation_repository(
temp_file_system: MemoryFileSystem,
) -> Iterable[HuggingFaceAggregationRepository]:
class_to_patch = "intelligence_layer.evaluation.aggregation.hugging_face_aggregation_repository.HuggingFaceAggregationRepository"
with patch(f"{class_to_patch}.create_repository", autospec=True), patch(
f"{class_to_patch}.delete_repository",
autospec=True,
):
repo = HuggingFaceAggregationRepository(
repository_id="doesn't-matter",
token="non-existing-token",
private=True,
)
repo._file_system = temp_file_system
yield repo


@fixture
def aggregation_overviews(
evaluation_overview: EvaluationOverview,
Expand Down Expand Up @@ -70,7 +94,6 @@ def test_aggregation_repository_stores_and_returns_an_aggregation_overview(
def test_aggregation_overview_returns_none_for_not_existing_id(
repository_fixture: str,
request: FixtureRequest,
aggregation_overview: AggregationOverview[DummyAggregatedEvaluation],
) -> None:
aggregation_repository: AggregationRepository = request.getfixturevalue(
repository_fixture
Expand All @@ -90,9 +113,7 @@ def test_aggregation_overview_returns_none_for_not_existing_id(
def test_aggregation_overviews_returns_all_aggregation_overviews(
repository_fixture: str,
request: FixtureRequest,
evaluation_overview: EvaluationOverview,
aggregation_overviews: Iterable[AggregationOverview[DummyAggregatedEvaluation]],
dummy_aggregated_evaluation: DummyAggregatedEvaluation,
) -> None:
aggregation_repository: AggregationRepository = request.getfixturevalue(
repository_fixture
Expand All @@ -116,9 +137,7 @@ def test_aggregation_overviews_returns_all_aggregation_overviews(
def test_aggregation_overview_ids_returns_sorted_ids(
repository_fixture: str,
request: FixtureRequest,
evaluation_overview: EvaluationOverview,
aggregation_overviews: Iterable[AggregationOverview[DummyAggregatedEvaluation]],
dummy_aggregated_evaluation: DummyAggregatedEvaluation,
) -> None:
aggregation_repository: AggregationRepository = request.getfixturevalue(
repository_fixture
Expand Down
42 changes: 38 additions & 4 deletions tests/evaluation/test_dataset_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,17 @@
from unittest.mock import patch

import pytest
from fsspec.implementations.memory import MemoryFileSystem # type: ignore
from pytest import FixtureRequest, fixture, mark, raises

from intelligence_layer.evaluation import (
DatasetRepository,
Example,
FileDatasetRepository,
)
from intelligence_layer.evaluation.dataset.hugging_face_dataset_repository import (
HuggingFaceDatasetRepository,
)
from tests.conftest import DummyStringInput, DummyStringOutput


Expand All @@ -18,9 +22,28 @@ def file_dataset_repository(tmp_path: Path) -> FileDatasetRepository:
return FileDatasetRepository(tmp_path)


@fixture
def mocked_hugging_face_dataset_repository(
temp_file_system: MemoryFileSystem,
) -> Iterable[HuggingFaceDatasetRepository]:
class_to_patch = "intelligence_layer.evaluation.dataset.hugging_face_dataset_repository.HuggingFaceDatasetRepository"
with patch(f"{class_to_patch}.create_repository", autospec=True), patch(
f"{class_to_patch}.delete_repository",
autospec=True,
):
repo = HuggingFaceDatasetRepository(
repository_id="doesn't-matter",
token="non-existing-token",
private=True,
)
repo._file_system = temp_file_system
yield repo


test_repository_fixtures = [
"file_dataset_repository",
"in_memory_dataset_repository",
"mocked_hugging_face_dataset_repository",
]


Expand Down Expand Up @@ -72,15 +95,13 @@ def test_dataset_repository_can_create_and_store_a_dataset(
assert stored_examples[0] == dummy_string_example


@patch(
target="intelligence_layer.evaluation.dataset.domain.uuid4", return_value="12345"
)
@patch(target="intelligence_layer.evaluation.dataset.domain.uuid4", return_value="1234")
@mark.parametrize(
"repository_fixture",
test_repository_fixtures,
)
def test_dataset_repository_ensures_unique_dataset_ids(
_mock_uuid4: Any,
_mock_uuid4: Any, # this is necessary as otherwise the other fixtures aren't found
repository_fixture: str,
request: FixtureRequest,
dummy_string_example: Example[DummyStringInput, DummyStringOutput],
Expand Down Expand Up @@ -153,6 +174,19 @@ def test_delete_dataset_deletes_a_dataset(
) # tests whether function is idempotent


@mark.parametrize(
"repository_fixture",
test_repository_fixtures,
)
def test_deleting_a_nonexistant_repo_does_not_cause_an_exception(
repository_fixture: str,
request: FixtureRequest,
) -> None:
dataset_repository: DatasetRepository = request.getfixturevalue(repository_fixture)

dataset_repository.delete_dataset("non-existant-id")


@mark.parametrize(
"repository_fixture",
test_repository_fixtures,
Expand Down
Loading

0 comments on commit fd8814a

Please sign in to comment.