diff --git a/CHANGELOG.md b/CHANGELOG.md
index b3d11a644..ad1970fa8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,8 @@
 - `PostgresInstructionFinetuningDataRepository` to work with data stored in a Postgres database.
 - `FileInstructionFinetuningDataRepository` to work with data stored in the local file-system.
 - Compute precision, recall and f1-score by class in `SingleLabelClassifyAggregationLogic`
+- Add `submit_dataset` function to `StudioClient`
+  - Add `how_to_upload_existing_datasets_to_studio.ipynb` to how-tos
 
 ### Fixes
...
diff --git a/README.md b/README.md
index 566219b34..120875bc4 100644
--- a/README.md
+++ b/README.md
@@ -150,7 +150,7 @@ The how-tos are quick lookups about how to do things. Compared to the tutorials,
 | [...define a task](./src/documentation/how_tos/how_to_define_a_task.ipynb) | How to come up with a new task and formulate it |
 | [...implement a task](./src/documentation/how_tos/how_to_implement_a_task.ipynb) | Implement a formulated task and make it run with the Intelligence Layer |
 | [...debug and log a task](./src/documentation/how_tos/how_to_log_and_debug_a_task.ipynb) | Tools for logging and debugging in tasks |
-| [...use Studio with traces](./src/documentation/how_tos/how_to_use_studio_with_traces.ipynb) | Submitting Traces to Studio for debugging |
+| [...use Studio with traces](./src/documentation/how_tos/studio/how_to_use_studio_with_traces.ipynb) | Submitting Traces to Studio for debugging |
 | **Analysis Pipeline** | |
 | [...implement a simple evaluation and aggregation logic](./src/documentation/how_tos/how_to_implement_a_simple_evaluation_and_aggregation_logic.ipynb) | Basic examples of evaluation and aggregation logic |
 | [...create a dataset](./src/documentation/how_tos/how_to_create_a_dataset.ipynb) | Create a dataset used for running a task |
diff --git a/src/documentation/how_tos/how_to_aggregate_evaluations.ipynb b/src/documentation/how_tos/how_to_aggregate_evaluations.ipynb
index d64431a19..873861633 100644
--- a/src/documentation/how_tos/how_to_aggregate_evaluations.ipynb
+++ b/src/documentation/how_tos/how_to_aggregate_evaluations.ipynb
@@ -70,7 +70,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "intelligence-layer-d3iSWYpm-py3.10",
+   "display_name": "intelligence-layer-aL2cXmJM-py3.11",
    "language": "python",
    "name": "python3"
   },
diff --git a/src/documentation/how_tos/how_to_log_and_debug_a_task.ipynb b/src/documentation/how_tos/how_to_log_and_debug_a_task.ipynb
index a8a7c939e..1fd88c4ab 100644
--- a/src/documentation/how_tos/how_to_log_and_debug_a_task.ipynb
+++ b/src/documentation/how_tos/how_to_log_and_debug_a_task.ipynb
@@ -90,7 +90,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "intelligence-layer-d3iSWYpm-py3.10",
+   "display_name": ".venv",
    "language": "python",
    "name": "python3"
   },
diff --git a/src/documentation/how_tos/studio/how_to_upload_existing_datasets_to_studio.ipynb b/src/documentation/how_tos/studio/how_to_upload_existing_datasets_to_studio.ipynb
new file mode 100644
index 000000000..83eed05dd
--- /dev/null
+++ b/src/documentation/how_tos/studio/how_to_upload_existing_datasets_to_studio.ipynb
@@ -0,0 +1,102 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from uuid import uuid4\n",
+    "\n",
+    "from documentation.how_tos.example_data import example_data\n",
+    "from intelligence_layer.connectors import StudioClient\n",
+    "from intelligence_layer.evaluation.dataset.studio_dataset_repository import (\n",
+    "    StudioDatasetRepository,\n",
+    ")\n",
+    "\n",
+    "my_example_data = example_data()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# How to upload (existing) datasets to Studio\n",
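+    "\n",
+    "This guide shows how to take a `Dataset` and its `Example`s from an existing `DatasetRepository` and upload them to Studio.\n",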
\n", + "\n", + "Make sure your account has permissions to use the Studio application.\n", + "\n", + "For an on-prem or local installation, please contact the corresponding team.\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "0. Extract `Dataset` and `Examples` from your `DatasetRepository`.\n", + "\n", + "1. Initialize a `StudioClient` with a project.\n", + " - Use an existing project or create a new one with the `StudioClient.create_project` function.\n", + " \n", + "2. Create a `StudioDatasetRepository` and create a new `Dataset` via `StudioDatasetRepository.create_dataset`, which will automatically upload this new `Dataset` to Studio.\n", + "\n", + "### Example" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Step 0\n", + "existing_dataset_repo = my_example_data.dataset_repository\n", + "\n", + "existing_dataset = existing_dataset_repo.dataset(dataset_id=my_example_data.dataset.id)\n", + "assert existing_dataset, \"Make sure your dataset still exists.\"\n", + "\n", + "existing_examples = existing_dataset_repo.examples(\n", + " existing_dataset.id, input_type=str, expected_output_type=str\n", + ")\n", + "\n", + "# Step 1\n", + "project_name = str(uuid4())\n", + "studio_client = StudioClient(project=project_name)\n", + "my_project = studio_client.create_project(project=project_name)\n", + "\n", + "# Step 2\n", + "studio_dataset_repo = StudioDatasetRepository(studio_client=studio_client)\n", + "\n", + "studio_dataset = studio_dataset_repo.create_dataset(\n", + " examples=existing_examples,\n", + " dataset_name=existing_dataset.name,\n", + " labels=existing_dataset.labels,\n", + " metadata=existing_dataset.metadata,\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "intelligence-layer-aL2cXmJM-py3.11", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/documentation/how_tos/how_to_use_studio_with_traces.ipynb b/src/documentation/how_tos/studio/how_to_use_studio_with_traces.ipynb similarity index 99% rename from src/documentation/how_tos/how_to_use_studio_with_traces.ipynb rename to src/documentation/how_tos/studio/how_to_use_studio_with_traces.ipynb index 401895c50..b9d8695c8 100644 --- a/src/documentation/how_tos/how_to_use_studio_with_traces.ipynb +++ b/src/documentation/how_tos/studio/how_to_use_studio_with_traces.ipynb @@ -88,7 +88,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.2" + "version": "3.11.8" } }, "nbformat": 4, diff --git a/src/intelligence_layer/connectors/studio/studio.py b/src/intelligence_layer/connectors/studio/studio.py index 6a7848caa..29e3aaa54 100644 --- a/src/intelligence_layer/connectors/studio/studio.py +++ b/src/intelligence_layer/connectors/studio/studio.py @@ -1,25 +1,71 @@ +import json import os from collections import defaultdict -from collections.abc import Sequence -from typing import Optional +from collections.abc import Iterable, Sequence +from typing import Generic, Optional, TypeVar from urllib.parse import urljoin +from uuid import uuid4 import requests -from pydantic import BaseModel +from pydantic import BaseModel, Field from requests.exceptions import ConnectionError, MissingSchema +from intelligence_layer.connectors.base.json_serializable import ( + SerializableDict, +) from intelligence_layer.core.tracer.tracer import ( # 
Import to be fixed with PHS-731
     ExportedSpan,
     ExportedSpanList,
+    PydanticSerializable,
     Tracer,
 )
 
+Input = TypeVar("Input", bound=PydanticSerializable)
+ExpectedOutput = TypeVar("ExpectedOutput", bound=PydanticSerializable)
+
 
 class StudioProject(BaseModel):
     name: str
     description: Optional[str]
 
 
+class StudioExample(BaseModel, Generic[Input, ExpectedOutput]):
+    """Represents an instance of :class:`Example` as sent to Studio.
+
+    Attributes:
+        input: Input for the :class:`Task`. Must be the same type as the input of the task being evaluated.
+        expected_output: The expected output from a given example run.
+            The evaluator compares the received output against this.
+        id: Identifier for the example, defaults to uuid.
+        metadata: Optional dictionary of custom key-value pairs.
+
+    Generics:
+        Input: Interface to be passed to the :class:`Task` that shall be evaluated.
+        ExpectedOutput: Output that is expected from the run with the supplied input.
+    """
+
+    input: Input
+    expected_output: ExpectedOutput
+    id: str = Field(default_factory=lambda: str(uuid4()))
+    metadata: Optional[SerializableDict] = None
+
+
+class StudioDataset(BaseModel):
+    """Represents a :class:`Dataset` linked to multiple examples as sent to Studio.
+
+    Attributes:
+        id: Dataset ID.
+        name: A short name of the dataset.
+        labels: Labels for filtering datasets. Defaults to an empty set.
+        metadata: Additional information about the dataset. Defaults to an empty dict.
+    """
+
+    id: str = Field(default_factory=lambda: str(uuid4()))
+    name: str
+    labels: set[str] = set()
+    metadata: SerializableDict = dict()
+
 
 class StudioClient:
     """Client for communicating with Studio.
@@ -50,7 +96,6 @@ def __init__(
                 "'AA_TOKEN' is not set and auth_token is not given as a parameter. Please provide one or the other."
             )
         self._headers = {
-            "Content-Type": "application/json",
             "Accept": "application/json",
             "Authorization": f"Bearer {self._token}",
         }
@@ -148,7 +193,7 @@ def submit_trace(self, data: Sequence[ExportedSpan]) -> str:
             spans belong to multiple traces.
 
         Args:
-            data: Spans to create the trace from. Created by exporting from a `Tracer`.
+            data: :class:`ExportedSpan`s to create the trace from. Created by exporting from a :class:`Tracer`.
 
         Returns:
             The ID of the created trace.
@@ -161,7 +206,7 @@ def submit_from_tracer(self, tracer: Tracer) -> list[str]:
         """Sends all trace data from the Tracer to Studio.
 
         Args:
-            tracer: Tracer to extract data from.
+            tracer: :class:`Tracer` to extract data from.
 
         Returns:
             List of created trace IDs.
@@ -191,3 +236,53 @@ def _upload_trace(self, trace: ExportedSpanList) -> str:
             case _:
                 response.raise_for_status()
         return str(response.json())
+
+    def submit_dataset(
+        self,
+        dataset: StudioDataset,
+        examples: Iterable[StudioExample[Input, ExpectedOutput]],
+    ) -> str:
+        """Submits a dataset to Studio.
+
+        Args:
+            dataset: :class:`StudioDataset` to be uploaded.
+            examples: :class:`StudioExample`s belonging to the dataset.
+
+        Returns:
+            The ID of the created dataset.
+        """
+        url = urljoin(self.url, f"/api/projects/{self.project_id}/evaluation/datasets")
+        source_data_list = [
+            example.model_dump_json()
+            for example in sorted(examples, key=lambda x: x.id)
+        ]
+
+        source_data_file = "\n".join(source_data_list).encode()
+
+        data = {
+            "name": dataset.name,
+            "labels": list(dataset.labels) if dataset.labels is not None else [],
+            "total_datapoints": len(source_data_list),
+        }
+
+        if dataset.metadata:
+            data["metadata"] = json.dumps(dataset.metadata)
+
+        response = requests.post(
+            url,
+            files={"source_data": source_data_file},
+            data=data,
+            headers=self._headers,
+        )
+
+        self._raise_for_status(response)
+        return str(response.text)
+
+    def _raise_for_status(self, response: requests.Response) -> None:
+        try:
+            response.raise_for_status()
+        except requests.HTTPError as e:
+            print(
+                f"The following error was raised during execution: {e.response.text}"
+            )
+            raise e
diff --git a/src/intelligence_layer/evaluation/dataset/studio_dataset_repository.py b/src/intelligence_layer/evaluation/dataset/studio_dataset_repository.py
index e1423175e..bb0817d8a 100644
--- a/src/intelligence_layer/evaluation/dataset/studio_dataset_repository.py
+++ b/src/intelligence_layer/evaluation/dataset/studio_dataset_repository.py
@@ -1,16 +1,19 @@
-import json
+import warnings
 from collections.abc import Iterable
 from typing import Optional
 
-from intelligence_layer.connectors.base.json_serializable import (
+from intelligence_layer.connectors import (
     SerializableDict,
+    StudioClient,
+)
+from intelligence_layer.connectors.studio.studio import (
+    StudioDataset,
+    StudioExample,
 )
-from intelligence_layer.connectors.data import DataClient
-from intelligence_layer.connectors.data.models import DatasetCreate
 from intelligence_layer.core import Input
-from intelligence_layer.evaluation.dataset.dataset_repository import DatasetRepository
-from intelligence_layer.evaluation.dataset.domain import (
+from intelligence_layer.evaluation import (
     Dataset,
+    DatasetRepository,
     Example,
     ExpectedOutput,
 )
@@ -19,15 +22,16 @@
 class StudioDatasetRepository(DatasetRepository):
-    """Dataset repository interface with Data Platform."""
+    """Dataset repository that interfaces with Studio."""
 
-    def __init__(self, repository_id: str, data_client: DataClient) -> None:
+    def __init__(self, studio_client: StudioClient) -> None:
         """Initializes the StudioDatasetRepository.
 
         Args:
-            data_client: Data client to interact with the Data Platform API.
-            repository_id: Repository ID that identifies the repository(group of datasets).
+            studio_client: Client to interact with the Studio API.
         """
-        self.data_client = data_client
-        self.repository_id = repository_id
+        self.studio_client = studio_client
+        warnings.warn(
+            "The StudioDatasetRepository is currently in beta and only supports create_dataset."
+ ) def create_dataset( self, @@ -53,29 +57,23 @@ def create_dataset( raise NotImplementedError( "Custom dataset IDs are not supported by the StudioDataRepository" ) - source_data_list = [ - example.model_dump_json() - for example in sorted(examples, key=lambda x: x.id) - ] - remote_dataset = self.data_client.create_dataset( - repository_id=self.repository_id, - dataset=DatasetCreate( - source_data="\n".join(source_data_list).encode(), - name=dataset_name, - labels=list(labels) if labels is not None else [], - total_datapoints=len(source_data_list), - metadata=metadata, - ), + + created_dataset = Dataset( + name=dataset_name, + labels=labels or set(), + metadata=metadata or dict(), ) - return Dataset( - id=remote_dataset.dataset_id, - name=remote_dataset.name or "", - labels=set(remote_dataset.labels) - if remote_dataset.labels is not None - else set(), - metadata=remote_dataset.metadata or dict(), + + studio_dataset = self.map_to_studio_dataset(created_dataset) + studio_examples = self.map_to_many_studio_example(examples) + + studio_dataset_id = self.studio_client.submit_dataset( + dataset=studio_dataset, examples=studio_examples ) + created_dataset.id = studio_dataset_id + return created_dataset + def delete_dataset(self, dataset_id: str) -> None: """Deletes a dataset identified by the given dataset ID. @@ -83,9 +81,7 @@ def delete_dataset(self, dataset_id: str) -> None: dataset_id: Dataset ID of the dataset to delete. """ - self.data_client.delete_dataset( - repository_id=self.repository_id, dataset_id=dataset_id - ) + raise NotImplementedError() def dataset(self, dataset_id: str) -> Optional[Dataset]: """Returns a dataset identified by the given dataset ID. @@ -96,17 +92,7 @@ def dataset(self, dataset_id: str) -> Optional[Dataset]: Returns: :class:`Dataset` if it was not, `None` otherwise. """ - remote_dataset = self.data_client.get_dataset( - repository_id=self.repository_id, dataset_id=dataset_id - ) - return Dataset( - id=remote_dataset.dataset_id, - name=remote_dataset.name or "", - labels=set(remote_dataset.labels) - if remote_dataset.labels is not None - else set(), - metadata=remote_dataset.metadata or dict(), - ) + raise NotImplementedError() def datasets(self) -> Iterable[Dataset]: """Returns all :class:`Dataset`s. Sorting is not guaranteed. @@ -114,17 +100,7 @@ def datasets(self) -> Iterable[Dataset]: Returns: :class:`Sequence` of :class:`Dataset`s. """ - for remote_dataset in self.data_client.list_datasets( - repository_id=self.repository_id - ): - yield Dataset( - id=remote_dataset.dataset_id, - name=remote_dataset.name or "", - labels=set(remote_dataset.labels) - if remote_dataset.labels is not None - else set(), - metadata=remote_dataset.metadata or dict(), - ) + raise NotImplementedError() def dataset_ids(self) -> Iterable[str]: """Returns all sorted dataset IDs. @@ -132,8 +108,7 @@ def dataset_ids(self) -> Iterable[str]: Returns: :class:`Iterable` of dataset IDs. """ - datasets = self.data_client.list_datasets(repository_id=self.repository_id) - return (dataset.dataset_id for dataset in datasets) + raise NotImplementedError() def example( self, @@ -153,14 +128,7 @@ def example( Returns: :class:`Example` if it was found, `None` otherwise. 
""" - stream = self.data_client.stream_dataset( - repository_id=self.repository_id, dataset_id=dataset_id - ) - for item in stream: - data = json.loads(item.decode()) - if data["id"] == example_id: - return Example[input_type, expected_output_type].model_validate(data) # type: ignore - return None + raise NotImplementedError() def examples( self, @@ -180,11 +148,17 @@ def examples( Returns: :class:`Iterable` of :class`Example`s. """ - stream = self.data_client.stream_dataset( - repository_id=self.repository_id, dataset_id=dataset_id - ) - for item in stream: - data = json.loads(item.decode()) - if examples_to_skip is not None and data["id"] in examples_to_skip: - continue - yield Example[input_type, expected_output_type].model_validate(data) # type: ignore + raise NotImplementedError() + + def map_to_studio_example( + self, example_to_map: Example[Input, ExpectedOutput] + ) -> StudioExample[Input, ExpectedOutput]: + return StudioExample(**example_to_map.model_dump()) + + def map_to_many_studio_example( + self, examples_to_map: Iterable[Example[Input, ExpectedOutput]] + ) -> Iterable[StudioExample[Input, ExpectedOutput]]: + return (self.map_to_studio_example(example) for example in examples_to_map) + + def map_to_studio_dataset(self, dataset_to_map: Dataset) -> StudioDataset: + return StudioDataset(**dataset_to_map.model_dump()) diff --git a/tests/connectors/studio/test_studio.py b/tests/connectors/studio/test_studio.py index 7bfff9682..7f2e04bd1 100644 --- a/tests/connectors/studio/test_studio.py +++ b/tests/connectors/studio/test_studio.py @@ -2,7 +2,7 @@ import time from collections.abc import Sequence from typing import Any -from unittest.mock import patch +from unittest.mock import Mock, patch from uuid import uuid4 import pytest @@ -11,6 +11,13 @@ from intelligence_layer.connectors import StudioClient from intelligence_layer.core import ExportedSpan, InMemoryTracer, Task, TaskSpan +from intelligence_layer.evaluation.dataset.domain import Example +from intelligence_layer.evaluation.dataset.in_memory_dataset_repository import ( + InMemoryDatasetRepository, +) +from intelligence_layer.evaluation.dataset.studio_dataset_repository import ( + StudioDatasetRepository, +) class TracerTestSubTask(Task[None, None]): @@ -56,6 +63,29 @@ def studio_client() -> StudioClient: return client +@pytest.fixture +def mock_studio_client() -> Mock: + return Mock(spec=StudioClient) + + +@fixture +def examples() -> Sequence[Example[str, str]]: + return [ + Example(input="input_str", expected_output="output_str"), + Example(input="input_str2", expected_output="output_str2"), + ] + + +@fixture +def labels() -> set[str]: + return {"label1", "label2"} + + +@fixture +def metadata() -> dict[str, Any]: + return {"key": "value"} + + def test_cannot_connect_to_non_existing_project() -> None: project_name = "non-existing-project" with pytest.raises(ValueError, match=project_name): @@ -151,3 +181,47 @@ def test_submit_from_tracer_works_with_empty_tracer( empty_trace_id_list = studio_client.submit_from_tracer(tracer) assert len(empty_trace_id_list) == 0 + + +def test_can_upload_dataset_with_minimal_request_body( + studio_client: StudioClient, + examples: Sequence[Example[str, str]], +) -> None: + dataset_repo = InMemoryDatasetRepository() + dataset = dataset_repo.create_dataset(examples, "my_dataset") + + studio_dataset = StudioDatasetRepository(studio_client).map_to_studio_dataset( + dataset + ) + studio_examples = StudioDatasetRepository(studio_client).map_to_many_studio_example( + examples + ) + + result = 
studio_client.submit_dataset( + dataset=studio_dataset, examples=studio_examples + ) + assert result + + +def test_can_upload_dataset_with_complete_request_body( + studio_client: StudioClient, + examples: Sequence[Example[str, str]], + labels: set[str], + metadata: dict[str, Any], +) -> None: + dataset_repo = InMemoryDatasetRepository() + dataset = dataset_repo.create_dataset( + examples, "my_dataset", labels=labels, metadata=metadata + ) + + studio_dataset = StudioDatasetRepository(studio_client).map_to_studio_dataset( + dataset + ) + studio_examples = StudioDatasetRepository(studio_client).map_to_many_studio_example( + examples + ) + + result = studio_client.submit_dataset( + dataset=studio_dataset, examples=studio_examples + ) + assert result diff --git a/tests/evaluation/dataset/test_studio_data_repository.py b/tests/evaluation/dataset/test_studio_data_repository.py index 0d90986b5..0a33b2352 100644 --- a/tests/evaluation/dataset/test_studio_data_repository.py +++ b/tests/evaluation/dataset/test_studio_data_repository.py @@ -3,7 +3,8 @@ import pytest from pydantic import BaseModel -from intelligence_layer.connectors import DataClient, DataDataset, DatasetCreate +from intelligence_layer.connectors import DataClient, StudioClient +from intelligence_layer.connectors.studio.studio import StudioDataset from intelligence_layer.evaluation.dataset.domain import ( Dataset, Example, @@ -19,8 +20,15 @@ def mock_data_client() -> Mock: @pytest.fixture -def studio_dataset_repository(mock_data_client: Mock) -> StudioDatasetRepository: - return StudioDatasetRepository(repository_id="repo1", data_client=mock_data_client) +def mock_studio_client() -> Mock: + return Mock(spec=StudioClient) + + +@pytest.fixture +def studio_dataset_repository(mock_studio_client: Mock) -> StudioDatasetRepository: + return StudioDatasetRepository( + studio_client=mock_studio_client, + ) class InputExample(BaseModel): @@ -31,15 +39,14 @@ class ExpectedOutputExample(BaseModel): data: str -def test_create_dataset( - studio_dataset_repository: StudioDatasetRepository, mock_data_client: Mock -) -> None: - return_dataset_mock = Mock(spec=DataDataset) - return_dataset_mock.dataset_id = "dataset1" - return_dataset_mock.labels = ["label"] - return_dataset_mock.metadata = {} - return_dataset_mock.name = "Dataset 1" - mock_data_client.create_dataset.return_value = return_dataset_mock +def test_create_dataset(studio_dataset_repository: StudioDatasetRepository) -> None: + expected_dataset_id = "dataset1" + studio_dataset_repository.studio_client.submit_dataset.return_value = ( # type: ignore + expected_dataset_id + ) + + dataset_name = "Dataset 1" + dataset_labels = {"label"} examples = [ Example( @@ -55,178 +62,32 @@ def test_create_dataset( ] dataset = studio_dataset_repository.create_dataset( - examples=examples, dataset_name="Dataset 1", labels={"label"}, metadata={} + examples=examples, + dataset_name=dataset_name, + labels=dataset_labels, + metadata={}, ) assert isinstance(dataset, Dataset) - assert dataset.id == "dataset1" - assert dataset.name == "Dataset 1" - assert dataset.labels == {"label"} + assert dataset.id == expected_dataset_id + assert dataset.name == dataset_name + assert dataset.labels == dataset_labels assert dataset.metadata == {} - mock_data_client.create_dataset.assert_called_once_with( - repository_id="repo1", - dataset=DatasetCreate.model_validate( - { - "source_data": 
b'{"input":{"data":"input1"},"expected_output":{"data":"output1"},"id":"example1","metadata":null}\n{"input":{"data":"input2"},"expected_output":{"data":"output2"},"id":"example2","metadata":null}', - "labels": ["label"], - "name": "Dataset 1", - "total_datapoints": 2, - "metadata": {}, - } - ), - ) + studio_dataset_repository.studio_client.submit_dataset.assert_called_once() # type: ignore + actual_call = studio_dataset_repository.studio_client.submit_dataset.call_args # type: ignore + submitted_dataset = actual_call[1]["dataset"] + submitted_examples = list(actual_call[1]["examples"]) -def test_delete_dataset( - studio_dataset_repository: StudioDatasetRepository, mock_data_client: Mock -) -> None: - studio_dataset_repository.delete_dataset(dataset_id="dataset1") + expected_dataset = StudioDataset(name=dataset_name, labels=dataset_labels) - mock_data_client.delete_dataset.assert_called_once_with( - repository_id="repo1", dataset_id="dataset1" + studio_examples = list( + studio_dataset_repository.map_to_many_studio_example(examples) ) - -def test_dataset( - studio_dataset_repository: StudioDatasetRepository, mock_data_client: Mock -) -> None: - return_dataset_mock = Mock(spec=DataDataset) - return_dataset_mock.dataset_id = "dataset1" - return_dataset_mock.labels = [] - return_dataset_mock.metadata = {} - return_dataset_mock.name = "Dataset 1" - mock_data_client.get_dataset.return_value = return_dataset_mock - dataset = studio_dataset_repository.dataset(dataset_id="dataset1") - - assert isinstance(dataset, Dataset) - assert dataset.id == "dataset1" - assert dataset.name == "Dataset 1" - assert dataset.labels == set() - assert dataset.metadata == {} - - mock_data_client.get_dataset.assert_called_once_with( - repository_id="repo1", dataset_id="dataset1" - ) - - -def test_datasets( - studio_dataset_repository: StudioDatasetRepository, mock_data_client: Mock -) -> None: - return_dataset_mock = Mock(spec=DataDataset) - return_dataset_mock.dataset_id = "dataset1" - return_dataset_mock.labels = [] - return_dataset_mock.metadata = {} - return_dataset_mock.name = "Dataset 1" - - return_dataset_mock_2 = Mock(spec=DataDataset) - return_dataset_mock_2.dataset_id = "dataset2" - return_dataset_mock_2.labels = [] - return_dataset_mock_2.metadata = {} - return_dataset_mock_2.name = "Dataset 2" - - mock_data_client.list_datasets.return_value = [ - return_dataset_mock, - return_dataset_mock_2, - ] - - datasets = list(studio_dataset_repository.datasets()) - - assert len(datasets) == 2 - assert isinstance(datasets[0], Dataset) - assert datasets[0].id == "dataset1" - assert datasets[0].name == "Dataset 1" - assert datasets[0].labels == set() - assert datasets[0].metadata == {} - assert isinstance(datasets[1], Dataset) - assert datasets[1].id == "dataset2" - assert datasets[1].name == "Dataset 2" - assert datasets[1].labels == set() - assert datasets[1].metadata == {} - - mock_data_client.list_datasets.assert_called_once_with(repository_id="repo1") - - -def test_dataset_ids( - studio_dataset_repository: StudioDatasetRepository, mock_data_client: Mock -) -> None: - return_dataset_mock = Mock(spec=DataDataset) - return_dataset_mock.dataset_id = "dataset1" - return_dataset_mock.labels = ["label"] - return_dataset_mock.metadata = {} - return_dataset_mock.name = "Dataset 1" - - return_dataset_mock_2 = Mock(spec=DataDataset) - return_dataset_mock_2.dataset_id = "dataset2" - return_dataset_mock_2.labels = ["label"] - return_dataset_mock_2.metadata = {} - return_dataset_mock_2.name = "Dataset 2" - - 
mock_data_client.list_datasets.return_value = [ - return_dataset_mock, - return_dataset_mock_2, - ] - - dataset_ids = list(studio_dataset_repository.dataset_ids()) - - assert len(dataset_ids) == 2 - assert dataset_ids[0] == "dataset1" - assert dataset_ids[1] == "dataset2" - - mock_data_client.list_datasets.assert_called_once_with(repository_id="repo1") - - -def test_example( - studio_dataset_repository: StudioDatasetRepository, mock_data_client: Mock -) -> None: - mock_data_client.stream_dataset.return_value = [ - b'{"input": {"data": "input1"}, "expected_output": {"data": "output1"}, "id": "example1"}', - b'{"input": {"data": "input2"}, "expected_output": {"data": "output2"}, "id": "example2"}', - ] - - example = studio_dataset_repository.example( - dataset_id="dataset1", - example_id="example1", - input_type=InputExample, - expected_output_type=ExpectedOutputExample, - ) - - assert isinstance(example, Example) - assert example.input.data == "input1" - assert example.expected_output.data == "output1" - assert example.id == "example1" - - mock_data_client.stream_dataset.assert_called_once_with( - repository_id="repo1", dataset_id="dataset1" - ) - - -def test_examples( - studio_dataset_repository: StudioDatasetRepository, mock_data_client: Mock -) -> None: - mock_data_client.stream_dataset.return_value = [ - b'{"input": {"data": "input1"}, "expected_output": {"data": "output1"}, "id": "example1"}', - b'{"input": {"data": "input2"}, "expected_output": {"data": "output2"}, "id": "example2"}', - ] - - examples = list( - studio_dataset_repository.examples( - dataset_id="dataset1", - input_type=InputExample, - expected_output_type=ExpectedOutputExample, - ) - ) - - assert len(examples) == 2 - assert isinstance(examples[0], Example) - assert examples[0].input.data == "input1" - assert examples[0].expected_output.data == "output1" - assert examples[0].id == "example1" - assert isinstance(examples[1], Example) - assert examples[1].input.data == "input2" - assert examples[1].expected_output.data == "output2" - assert examples[1].id == "example2" - - mock_data_client.stream_dataset.assert_called_once_with( - repository_id="repo1", dataset_id="dataset1" - ) + # Assertions + assert submitted_dataset.labels == expected_dataset.labels + assert submitted_dataset.name == expected_dataset.name + assert submitted_dataset.metadata == expected_dataset.metadata + assert submitted_examples == studio_examples
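Taken together, the changes above give the following upload flow. A minimal sketch of the new `submit_dataset` path (assumes a reachable Studio instance and a valid `AA_TOKEN` in the environment; the project and dataset names are illustrative):

```python
from intelligence_layer.connectors import StudioClient
from intelligence_layer.connectors.studio.studio import StudioDataset, StudioExample

# Connect to Studio and create the project the dataset will live in
# (illustrative name; any unique project name works).
studio_client = StudioClient(project="my-project")
studio_client.create_project(project="my-project")

# Examples are serialized to JSON Lines and uploaded in a single request.
examples = [
    StudioExample(input="input_str", expected_output="output_str"),
    StudioExample(input="input_str2", expected_output="output_str2"),
]
dataset = StudioDataset(name="my_dataset", labels={"label1"})

# Returns the ID Studio assigned to the newly created dataset.
dataset_id = studio_client.submit_dataset(dataset=dataset, examples=examples)
print(dataset_id)
```

The same flow is what `StudioDatasetRepository.create_dataset` performs internally: it maps `Dataset`/`Example` objects to their Studio counterparts and delegates the upload to `StudioClient.submit_dataset`.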