diff --git a/src/intelligence_layer/evaluation/dataset/file_dataset_repository.py b/src/intelligence_layer/evaluation/dataset/file_dataset_repository.py index 6fd7cbf66..10a147160 100644 --- a/src/intelligence_layer/evaluation/dataset/file_dataset_repository.py +++ b/src/intelligence_layer/evaluation/dataset/file_dataset_repository.py @@ -22,6 +22,8 @@ class FileSystemDatasetRepository(DatasetRepository, FileSystemBasedRepository): def __init__(self, filesystem: AbstractFileSystem, root_directory: Path) -> None: super().__init__(file_system=filesystem, root_directory=root_directory) + # this is a local lru cache per repository instance, instead of a global one for all classes + self.examples = lru_cache(maxsize=2)(self.examples) # type: ignore def create_dataset( self, @@ -89,13 +91,11 @@ def example( ) if not self.exists(example_path): return None - for example in self.examples(dataset_id, input_type, expected_output_type): if example.id == example_id: return example return None - @lru_cache(maxsize=1) def examples( self, dataset_id: str, diff --git a/tests/evaluation/test_dataset_repository.py b/tests/evaluation/test_dataset_repository.py index 96ffb1ac8..2b94b42ac 100644 --- a/tests/evaluation/test_dataset_repository.py +++ b/tests/evaluation/test_dataset_repository.py @@ -1,6 +1,6 @@ from pathlib import Path from typing import Any, Iterable -from unittest.mock import Mock, patch +from unittest.mock import patch import pytest from fsspec.implementations.memory import MemoryFileSystem # type: ignore @@ -11,7 +11,6 @@ Example, FileDatasetRepository, ) -from intelligence_layer.evaluation.dataset.file_dataset_repository import FileSystemDatasetRepository from intelligence_layer.evaluation.dataset.hugging_face_dataset_repository import ( HuggingFaceDatasetRepository, ) @@ -336,26 +335,3 @@ def test_example_raises_error_for_not_existing_dataset_id( DummyStringInput, DummyStringOutput, ) - -@mark.parametrize("repository_fixture", test_repository_fixtures) -def test_example_raises_error_for_not_existing_dataset_id( - repository_fixture: str, - request: FixtureRequest, - dummy_string_example: Example[DummyStringInput, DummyStringOutput], -) -> None: - dataset_repository: FileSystemDatasetRepository = request.getfixturevalue(repository_fixture) - - dataset = dataset_repository.create_dataset([], "temp") - - dataset_repository._file_system = Mock() - try: - dataset_repository.example(dataset.id, "", str, str) - except Exception as e: - pass - try: - dataset_repository.example(dataset.id, "", str, str) - except Exception as e: - pass - - assert dataset_repository._file_system.open.call_count == 1 - diff --git a/tests/evaluation/test_file_dataset_repository.py b/tests/evaluation/test_file_dataset_repository.py new file mode 100644 index 000000000..184396a54 --- /dev/null +++ b/tests/evaluation/test_file_dataset_repository.py @@ -0,0 +1,34 @@ +from pathlib import Path +from typing import Iterable + +from fsspec import AbstractFileSystem # type: ignore +from pytest import fixture +from intelligence_layer.core.task import Input +from intelligence_layer.evaluation.dataset.domain import Example, ExpectedOutput +from intelligence_layer.evaluation.dataset.file_dataset_repository import FileSystemDatasetRepository + + +class FileDatasetRepositoryTestWrapper(FileSystemDatasetRepository): + def __init__(self, filesystem: AbstractFileSystem, root_directory: Path) -> None: + super().__init__(filesystem, root_directory) + self.counter = 0 + + def examples(self, dataset_id: str, input_type: type[Input], expected_output_type: type[ExpectedOutput]) -> Iterable[Example[Input, ExpectedOutput]]: + self.counter += 1 + return super().examples(dataset_id, input_type, expected_output_type) + + +@fixture +def file_data_repo_stub(temp_file_system: AbstractFileSystem) -> FileDatasetRepositoryTestWrapper: + return FileDatasetRepositoryTestWrapper(temp_file_system, Path("Root")) + + +def test_opens_files_only_once_when_reading_multiple_examples( + file_data_repo_stub: FileDatasetRepositoryTestWrapper, +) -> None: + dataset = file_data_repo_stub.create_dataset([], "temp") + + file_data_repo_stub.example(dataset.id, "", str, str) + file_data_repo_stub.example(dataset.id, "", str, str) + + assert file_data_repo_stub.counter == 1