diff --git a/CHANGELOG.md b/CHANGELOG.md
index b3d11a644..ad1970fa8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,8 @@
- `PostgresInstructionFinetuningDataRepository` to work with data stored in a Postgres database.
- `FileInstructionFinetuningDataRepository` to work with data stored in the local file-system.
- Compute precision, recall and f1-score by class in `SingleLabelClassifyAggregationLogic`
+- Add `submit_dataset` function to `StudioClient`
+ - Add `how_to_upload_existing_datasets_to_studio.ipynb` to how-tos
### Fixes
...
diff --git a/README.md b/README.md
index 566219b34..120875bc4 100644
--- a/README.md
+++ b/README.md
@@ -150,7 +150,7 @@ The how-tos are quick lookups about how to do things. Compared to the tutorials,
| [...define a task](./src/documentation/how_tos/how_to_define_a_task.ipynb) | How to come up with a new task and formulate it |
| [...implement a task](./src/documentation/how_tos/how_to_implement_a_task.ipynb) | Implement a formulated task and make it run with the Intelligence Layer |
| [...debug and log a task](./src/documentation/how_tos/how_to_log_and_debug_a_task.ipynb) | Tools for logging and debugging in tasks |
-| [...use Studio with traces](./src/documentation/how_tos/how_to_use_studio_with_traces.ipynb) | Submitting Traces to Studio for debugging |
+| [...use Studio with traces](./src/documentation/how_tos/studio/how_to_use_studio_with_traces.ipynb) | Submitting Traces to Studio for debugging |
| **Analysis Pipeline** | |
| [...implement a simple evaluation and aggregation logic](./src/documentation/how_tos/how_to_implement_a_simple_evaluation_and_aggregation_logic.ipynb) | Basic examples of evaluation and aggregation logic |
| [...create a dataset](./src/documentation/how_tos/how_to_create_a_dataset.ipynb) | Create a dataset used for running a task |
diff --git a/src/documentation/how_tos/how_to_aggregate_evaluations.ipynb b/src/documentation/how_tos/how_to_aggregate_evaluations.ipynb
index d64431a19..873861633 100644
--- a/src/documentation/how_tos/how_to_aggregate_evaluations.ipynb
+++ b/src/documentation/how_tos/how_to_aggregate_evaluations.ipynb
@@ -70,7 +70,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "intelligence-layer-d3iSWYpm-py3.10",
+ "display_name": "intelligence-layer-aL2cXmJM-py3.11",
"language": "python",
"name": "python3"
},
diff --git a/src/documentation/how_tos/how_to_log_and_debug_a_task.ipynb b/src/documentation/how_tos/how_to_log_and_debug_a_task.ipynb
index a8a7c939e..1fd88c4ab 100644
--- a/src/documentation/how_tos/how_to_log_and_debug_a_task.ipynb
+++ b/src/documentation/how_tos/how_to_log_and_debug_a_task.ipynb
@@ -90,7 +90,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "intelligence-layer-d3iSWYpm-py3.10",
+ "display_name": ".venv",
"language": "python",
"name": "python3"
},
diff --git a/src/documentation/how_tos/studio/how_to_upload_existing_datasets_to_studio.ipynb b/src/documentation/how_tos/studio/how_to_upload_existing_datasets_to_studio.ipynb
new file mode 100644
index 000000000..83eed05dd
--- /dev/null
+++ b/src/documentation/how_tos/studio/how_to_upload_existing_datasets_to_studio.ipynb
@@ -0,0 +1,102 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from uuid import uuid4\n",
+ "\n",
+ "from documentation.how_tos.example_data import example_data\n",
+ "from intelligence_layer.connectors import StudioClient\n",
+ "from intelligence_layer.evaluation.dataset.studio_dataset_repository import (\n",
+ " StudioDatasetRepository,\n",
+ ")\n",
+ "\n",
+ "my_example_data = example_data()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# How to upload (existing) datasets to Studio\n",
+ "
\n",
+ "\n",
+ "Make sure your account has permissions to use the Studio application.\n",
+ "\n",
+ "For an on-prem or local installation, please contact the corresponding team.\n",
+ "
"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n",
+ "0. Extract `Dataset` and `Examples` from your `DatasetRepository`.\n",
+ "\n",
+ "1. Initialize a `StudioClient` with a project.\n",
+ " - Use an existing project or create a new one with the `StudioClient.create_project` function.\n",
+ " \n",
+ "2. Create a `StudioDatasetRepository` and create a new `Dataset` via `StudioDatasetRepository.create_dataset`, which will automatically upload this new `Dataset` to Studio.\n",
+ "\n",
+ "### Example"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Step 0\n",
+ "existing_dataset_repo = my_example_data.dataset_repository\n",
+ "\n",
+ "existing_dataset = existing_dataset_repo.dataset(dataset_id=my_example_data.dataset.id)\n",
+ "assert existing_dataset, \"Make sure your dataset still exists.\"\n",
+ "\n",
+ "existing_examples = existing_dataset_repo.examples(\n",
+ " existing_dataset.id, input_type=str, expected_output_type=str\n",
+ ")\n",
+ "\n",
+ "# Step 1\n",
+ "project_name = str(uuid4())\n",
+ "studio_client = StudioClient(project=project_name)\n",
+ "my_project = studio_client.create_project(project=project_name)\n",
+ "\n",
+ "# Step 2\n",
+ "studio_dataset_repo = StudioDatasetRepository(studio_client=studio_client)\n",
+ "\n",
+ "studio_dataset = studio_dataset_repo.create_dataset(\n",
+ " examples=existing_examples,\n",
+ " dataset_name=existing_dataset.name,\n",
+ " labels=existing_dataset.labels,\n",
+ " metadata=existing_dataset.metadata,\n",
+ ")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "intelligence-layer-aL2cXmJM-py3.11",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/src/documentation/how_tos/how_to_use_studio_with_traces.ipynb b/src/documentation/how_tos/studio/how_to_use_studio_with_traces.ipynb
similarity index 99%
rename from src/documentation/how_tos/how_to_use_studio_with_traces.ipynb
rename to src/documentation/how_tos/studio/how_to_use_studio_with_traces.ipynb
index 401895c50..b9d8695c8 100644
--- a/src/documentation/how_tos/how_to_use_studio_with_traces.ipynb
+++ b/src/documentation/how_tos/studio/how_to_use_studio_with_traces.ipynb
@@ -88,7 +88,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.12.2"
+ "version": "3.11.8"
}
},
"nbformat": 4,
diff --git a/src/intelligence_layer/connectors/studio/studio.py b/src/intelligence_layer/connectors/studio/studio.py
index 6a7848caa..29e3aaa54 100644
--- a/src/intelligence_layer/connectors/studio/studio.py
+++ b/src/intelligence_layer/connectors/studio/studio.py
@@ -1,25 +1,71 @@
+import json
import os
from collections import defaultdict
-from collections.abc import Sequence
-from typing import Optional
+from collections.abc import Iterable, Sequence
+from typing import Generic, Optional, TypeVar
from urllib.parse import urljoin
+from uuid import uuid4
import requests
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
from requests.exceptions import ConnectionError, MissingSchema
+from intelligence_layer.connectors.base.json_serializable import (
+ SerializableDict,
+)
from intelligence_layer.core.tracer.tracer import ( # Import to be fixed with PHS-731
ExportedSpan,
ExportedSpanList,
+ PydanticSerializable,
Tracer,
)
+Input = TypeVar("Input", bound=PydanticSerializable)
+ExpectedOutput = TypeVar("ExpectedOutput", bound=PydanticSerializable)
+
class StudioProject(BaseModel):
name: str
description: Optional[str]
+class StudioExample(BaseModel, Generic[Input, ExpectedOutput]):
+ """Represents an instance of :class:`Example`as sent to Studio.
+
+ Attributes:
+        input: Input for the :class:`Task`. Has to be the same type as the input of the task used.
+        expected_output: The expected output from a given example run.
+            The evaluator will compare the received output against it.
+ id: Identifier for the example, defaults to uuid.
+ metadata: Optional dictionary of custom key-value pairs.
+
+ Generics:
+ Input: Interface to be passed to the :class:`Task` that shall be evaluated.
+ ExpectedOutput: Output that is expected from the run with the supplied input.
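+
+    Example:
+        A minimal, hypothetical string-to-string instance::
+
+            example = StudioExample(input="What is 1+1?", expected_output="2")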
+ """
+
+ input: Input
+ expected_output: ExpectedOutput
+ id: str = Field(default_factory=lambda: str(uuid4()))
+ metadata: Optional[SerializableDict] = None
+
+
+class StudioDataset(BaseModel):
+ """Represents a :class:`Dataset` linked to multiple examples as sent to Studio.
+
+ Attributes:
+ id: Dataset ID.
+ name: A short name of the dataset.
+        labels: Labels for filtering datasets. Defaults to empty set.
+ metadata: Additional information about the dataset. Defaults to empty dict.
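+
+    Example:
+        A minimal, hypothetical instance; the ID is generated automatically::
+
+            dataset = StudioDataset(name="my-dataset", labels={"demo"})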
+ """
+
+ id: str = Field(default_factory=lambda: str(uuid4()))
+ name: str
+ labels: set[str] = set()
+ metadata: SerializableDict = dict()
+
+
class StudioClient:
"""Client for communicating with Studio.
@@ -50,7 +96,6 @@ def __init__(
"'AA_TOKEN' is not set and auth_token is not given as a parameter. Please provide one or the other."
)
self._headers = {
- "Content-Type": "application/json",
"Accept": "application/json",
"Authorization": f"Bearer {self._token}",
}
@@ -148,7 +193,7 @@ def submit_trace(self, data: Sequence[ExportedSpan]) -> str:
spans belong to multiple traces.
Args:
- data: Spans to create the trace from. Created by exporting from a `Tracer`.
+            data: Spans to create the trace from, as a sequence of :class:`ExportedSpan`. Created by exporting from a :class:`Tracer`.
Returns:
The ID of the created trace.
@@ -161,7 +206,7 @@ def submit_from_tracer(self, tracer: Tracer) -> list[str]:
"""Sends all trace data from the Tracer to Studio.
Args:
- tracer: Tracer to extract data from.
+ tracer: :class:`Tracer` to extract data from.
Returns:
List of created trace IDs.
@@ -191,3 +236,53 @@ def _upload_trace(self, trace: ExportedSpanList) -> str:
case _:
response.raise_for_status()
return str(response.json())
+
+ def submit_dataset(
+ self,
+ dataset: StudioDataset,
+ examples: Iterable[StudioExample[Input, ExpectedOutput]],
+ ) -> str:
+ """Submits a dataset to Studio.
+
+ Args:
+            dataset: :class:`StudioDataset` to be uploaded.
+            examples: :class:`StudioExample`s belonging to the given dataset.
+
+ Returns:
+            ID of the created dataset.
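+
+        Example:
+            A minimal sketch, assuming a reachable Studio instance and a valid
+            token; the project and dataset names are placeholders::
+
+                client = StudioClient(project="my-project")
+                client.create_project(project="my-project")
+                dataset = StudioDataset(name="my-dataset")
+                examples = [StudioExample(input="What is 1+1?", expected_output="2")]
+                dataset_id = client.submit_dataset(dataset=dataset, examples=examples)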
+ """
+ url = urljoin(self.url, f"/api/projects/{self.project_id}/evaluation/datasets")
+ source_data_list = [
+ example.model_dump_json()
+ for example in sorted(examples, key=lambda x: x.id)
+ ]
+
+ source_data_file = "\n".join(source_data_list).encode()
+
+ data = {
+ "name": dataset.name,
+ "labels": list(dataset.labels) if dataset.labels is not None else [],
+ "total_datapoints": len(source_data_list),
+ }
+
+ if dataset.metadata:
+ data["metadata"] = json.dumps(dataset.metadata)
+
+ response = requests.post(
+ url,
+ files={"source_data": source_data_file},
+ data=data,
+ headers=self._headers,
+ )
+
+ self._raise_for_status(response)
+ return str(response.text)
+
+ def _raise_for_status(self, response: requests.Response) -> None:
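+        """Prints the response body of a failed request before re-raising the HTTP error."""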
+ try:
+ response.raise_for_status()
+ except requests.HTTPError as e:
+            print(
+                f"The following error was raised during execution: {e.response.text}"
+            )
+ raise e
diff --git a/src/intelligence_layer/evaluation/dataset/studio_dataset_repository.py b/src/intelligence_layer/evaluation/dataset/studio_dataset_repository.py
index e1423175e..bb0817d8a 100644
--- a/src/intelligence_layer/evaluation/dataset/studio_dataset_repository.py
+++ b/src/intelligence_layer/evaluation/dataset/studio_dataset_repository.py
@@ -1,16 +1,19 @@
-import json
+import warnings
from collections.abc import Iterable
from typing import Optional
-from intelligence_layer.connectors.base.json_serializable import (
+from intelligence_layer.connectors import (
SerializableDict,
+ StudioClient,
+)
+from intelligence_layer.connectors.studio.studio import (
+ StudioDataset,
+ StudioExample,
)
-from intelligence_layer.connectors.data import DataClient
-from intelligence_layer.connectors.data.models import DatasetCreate
from intelligence_layer.core import Input
-from intelligence_layer.evaluation.dataset.dataset_repository import DatasetRepository
-from intelligence_layer.evaluation.dataset.domain import (
+from intelligence_layer.evaluation import (
Dataset,
+ DatasetRepository,
Example,
ExpectedOutput,
)
@@ -19,15 +22,16 @@
class StudioDatasetRepository(DatasetRepository):
"""Dataset repository interface with Data Platform."""
- def __init__(self, repository_id: str, data_client: DataClient) -> None:
+ def __init__(self, studio_client: StudioClient) -> None:
"""Initializes the StudioDatasetRepository.
Args:
- data_client: Data client to interact with the Data Platform API.
- repository_id: Repository ID that identifies the repository(group of datasets).
+ studio_client: Client to interact with the Studio API.
"""
- self.data_client = data_client
- self.repository_id = repository_id
+ self.studio_client = studio_client
+ warnings.warn(
+ "The StudioDatasetRepository is currently in beta and only supports create_dataset."
+ )
def create_dataset(
self,
@@ -53,29 +57,23 @@ def create_dataset(
raise NotImplementedError(
"Custom dataset IDs are not supported by the StudioDataRepository"
)
- source_data_list = [
- example.model_dump_json()
- for example in sorted(examples, key=lambda x: x.id)
- ]
- remote_dataset = self.data_client.create_dataset(
- repository_id=self.repository_id,
- dataset=DatasetCreate(
- source_data="\n".join(source_data_list).encode(),
- name=dataset_name,
- labels=list(labels) if labels is not None else [],
- total_datapoints=len(source_data_list),
- metadata=metadata,
- ),
+
+ created_dataset = Dataset(
+ name=dataset_name,
+ labels=labels or set(),
+ metadata=metadata or dict(),
)
- return Dataset(
- id=remote_dataset.dataset_id,
- name=remote_dataset.name or "",
- labels=set(remote_dataset.labels)
- if remote_dataset.labels is not None
- else set(),
- metadata=remote_dataset.metadata or dict(),
+
+ studio_dataset = self.map_to_studio_dataset(created_dataset)
+ studio_examples = self.map_to_many_studio_example(examples)
+
+ studio_dataset_id = self.studio_client.submit_dataset(
+ dataset=studio_dataset, examples=studio_examples
)
+ created_dataset.id = studio_dataset_id
+ return created_dataset
+
def delete_dataset(self, dataset_id: str) -> None:
"""Deletes a dataset identified by the given dataset ID.
@@ -83,9 +81,7 @@ def delete_dataset(self, dataset_id: str) -> None:
dataset_id: Dataset ID of the dataset to delete.
"""
- self.data_client.delete_dataset(
- repository_id=self.repository_id, dataset_id=dataset_id
- )
+ raise NotImplementedError()
def dataset(self, dataset_id: str) -> Optional[Dataset]:
"""Returns a dataset identified by the given dataset ID.
@@ -96,17 +92,7 @@ def dataset(self, dataset_id: str) -> Optional[Dataset]:
Returns:
-            :class:`Dataset` if it was not, `None` otherwise.
+            :class:`Dataset` if it was found, `None` otherwise.
"""
- remote_dataset = self.data_client.get_dataset(
- repository_id=self.repository_id, dataset_id=dataset_id
- )
- return Dataset(
- id=remote_dataset.dataset_id,
- name=remote_dataset.name or "",
- labels=set(remote_dataset.labels)
- if remote_dataset.labels is not None
- else set(),
- metadata=remote_dataset.metadata or dict(),
- )
+ raise NotImplementedError()
def datasets(self) -> Iterable[Dataset]:
"""Returns all :class:`Dataset`s. Sorting is not guaranteed.
@@ -114,17 +100,7 @@ def datasets(self) -> Iterable[Dataset]:
Returns:
:class:`Sequence` of :class:`Dataset`s.
"""
- for remote_dataset in self.data_client.list_datasets(
- repository_id=self.repository_id
- ):
- yield Dataset(
- id=remote_dataset.dataset_id,
- name=remote_dataset.name or "",
- labels=set(remote_dataset.labels)
- if remote_dataset.labels is not None
- else set(),
- metadata=remote_dataset.metadata or dict(),
- )
+ raise NotImplementedError()
def dataset_ids(self) -> Iterable[str]:
"""Returns all sorted dataset IDs.
@@ -132,8 +108,7 @@ def dataset_ids(self) -> Iterable[str]:
Returns:
:class:`Iterable` of dataset IDs.
"""
- datasets = self.data_client.list_datasets(repository_id=self.repository_id)
- return (dataset.dataset_id for dataset in datasets)
+ raise NotImplementedError()
def example(
self,
@@ -153,14 +128,7 @@ def example(
Returns:
:class:`Example` if it was found, `None` otherwise.
"""
- stream = self.data_client.stream_dataset(
- repository_id=self.repository_id, dataset_id=dataset_id
- )
- for item in stream:
- data = json.loads(item.decode())
- if data["id"] == example_id:
- return Example[input_type, expected_output_type].model_validate(data) # type: ignore
- return None
+ raise NotImplementedError()
def examples(
self,
@@ -180,11 +148,17 @@ def examples(
Returns:
-            :class:`Iterable` of :class`Example`s.
+            :class:`Iterable` of :class:`Example`s.
"""
- stream = self.data_client.stream_dataset(
- repository_id=self.repository_id, dataset_id=dataset_id
- )
- for item in stream:
- data = json.loads(item.decode())
- if examples_to_skip is not None and data["id"] in examples_to_skip:
- continue
- yield Example[input_type, expected_output_type].model_validate(data) # type: ignore
+ raise NotImplementedError()
+
+ def map_to_studio_example(
+ self, example_to_map: Example[Input, ExpectedOutput]
+ ) -> StudioExample[Input, ExpectedOutput]:
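+        """Converts an :class:`Example` into its :class:`StudioExample` representation."""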
+ return StudioExample(**example_to_map.model_dump())
+
+ def map_to_many_studio_example(
+ self, examples_to_map: Iterable[Example[Input, ExpectedOutput]]
+ ) -> Iterable[StudioExample[Input, ExpectedOutput]]:
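+        """Lazily converts each :class:`Example` into its :class:`StudioExample` representation."""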
+ return (self.map_to_studio_example(example) for example in examples_to_map)
+
+ def map_to_studio_dataset(self, dataset_to_map: Dataset) -> StudioDataset:
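+        """Converts a :class:`Dataset` into its :class:`StudioDataset` representation."""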
+ return StudioDataset(**dataset_to_map.model_dump())
diff --git a/tests/connectors/studio/test_studio.py b/tests/connectors/studio/test_studio.py
index 7bfff9682..7f2e04bd1 100644
--- a/tests/connectors/studio/test_studio.py
+++ b/tests/connectors/studio/test_studio.py
@@ -2,7 +2,7 @@
import time
from collections.abc import Sequence
from typing import Any
-from unittest.mock import patch
+from unittest.mock import Mock, patch
from uuid import uuid4
import pytest
@@ -11,6 +11,13 @@
from intelligence_layer.connectors import StudioClient
from intelligence_layer.core import ExportedSpan, InMemoryTracer, Task, TaskSpan
+from intelligence_layer.evaluation.dataset.domain import Example
+from intelligence_layer.evaluation.dataset.in_memory_dataset_repository import (
+ InMemoryDatasetRepository,
+)
+from intelligence_layer.evaluation.dataset.studio_dataset_repository import (
+ StudioDatasetRepository,
+)
class TracerTestSubTask(Task[None, None]):
@@ -56,6 +63,29 @@ def studio_client() -> StudioClient:
return client
+@pytest.fixture
+def mock_studio_client() -> Mock:
+ return Mock(spec=StudioClient)
+
+
+@pytest.fixture
+def examples() -> Sequence[Example[str, str]]:
+ return [
+ Example(input="input_str", expected_output="output_str"),
+ Example(input="input_str2", expected_output="output_str2"),
+ ]
+
+
+@pytest.fixture
+def labels() -> set[str]:
+ return {"label1", "label2"}
+
+
+@pytest.fixture
+def metadata() -> dict[str, Any]:
+ return {"key": "value"}
+
+
def test_cannot_connect_to_non_existing_project() -> None:
project_name = "non-existing-project"
with pytest.raises(ValueError, match=project_name):
@@ -151,3 +181,47 @@ def test_submit_from_tracer_works_with_empty_tracer(
empty_trace_id_list = studio_client.submit_from_tracer(tracer)
assert len(empty_trace_id_list) == 0
+
+
+def test_can_upload_dataset_with_minimal_request_body(
+ studio_client: StudioClient,
+ examples: Sequence[Example[str, str]],
+) -> None:
+ dataset_repo = InMemoryDatasetRepository()
+ dataset = dataset_repo.create_dataset(examples, "my_dataset")
+
+    studio_dataset_repo = StudioDatasetRepository(studio_client)
+    studio_dataset = studio_dataset_repo.map_to_studio_dataset(dataset)
+    studio_examples = studio_dataset_repo.map_to_many_studio_example(examples)
+
+ result = studio_client.submit_dataset(
+ dataset=studio_dataset, examples=studio_examples
+ )
+ assert result
+
+
+def test_can_upload_dataset_with_complete_request_body(
+ studio_client: StudioClient,
+ examples: Sequence[Example[str, str]],
+ labels: set[str],
+ metadata: dict[str, Any],
+) -> None:
+ dataset_repo = InMemoryDatasetRepository()
+ dataset = dataset_repo.create_dataset(
+ examples, "my_dataset", labels=labels, metadata=metadata
+ )
+
+    studio_dataset_repo = StudioDatasetRepository(studio_client)
+    studio_dataset = studio_dataset_repo.map_to_studio_dataset(dataset)
+    studio_examples = studio_dataset_repo.map_to_many_studio_example(examples)
+
+ result = studio_client.submit_dataset(
+ dataset=studio_dataset, examples=studio_examples
+ )
+ assert result
diff --git a/tests/evaluation/dataset/test_studio_data_repository.py b/tests/evaluation/dataset/test_studio_data_repository.py
index 0d90986b5..0a33b2352 100644
--- a/tests/evaluation/dataset/test_studio_data_repository.py
+++ b/tests/evaluation/dataset/test_studio_data_repository.py
@@ -3,7 +3,8 @@
import pytest
from pydantic import BaseModel
-from intelligence_layer.connectors import DataClient, DataDataset, DatasetCreate
+from intelligence_layer.connectors import DataClient, StudioClient
+from intelligence_layer.connectors.studio.studio import StudioDataset
from intelligence_layer.evaluation.dataset.domain import (
Dataset,
Example,
@@ -19,8 +20,15 @@ def mock_data_client() -> Mock:
@pytest.fixture
-def studio_dataset_repository(mock_data_client: Mock) -> StudioDatasetRepository:
- return StudioDatasetRepository(repository_id="repo1", data_client=mock_data_client)
+def mock_studio_client() -> Mock:
+ return Mock(spec=StudioClient)
+
+
+@pytest.fixture
+def studio_dataset_repository(mock_studio_client: Mock) -> StudioDatasetRepository:
+ return StudioDatasetRepository(
+ studio_client=mock_studio_client,
+ )
class InputExample(BaseModel):
@@ -31,15 +39,14 @@ class ExpectedOutputExample(BaseModel):
data: str
-def test_create_dataset(
- studio_dataset_repository: StudioDatasetRepository, mock_data_client: Mock
-) -> None:
- return_dataset_mock = Mock(spec=DataDataset)
- return_dataset_mock.dataset_id = "dataset1"
- return_dataset_mock.labels = ["label"]
- return_dataset_mock.metadata = {}
- return_dataset_mock.name = "Dataset 1"
- mock_data_client.create_dataset.return_value = return_dataset_mock
+def test_create_dataset(studio_dataset_repository: StudioDatasetRepository) -> None:
+ expected_dataset_id = "dataset1"
+ studio_dataset_repository.studio_client.submit_dataset.return_value = ( # type: ignore
+ expected_dataset_id
+ )
+
+ dataset_name = "Dataset 1"
+ dataset_labels = {"label"}
examples = [
Example(
@@ -55,178 +62,32 @@ def test_create_dataset(
]
dataset = studio_dataset_repository.create_dataset(
- examples=examples, dataset_name="Dataset 1", labels={"label"}, metadata={}
+ examples=examples,
+ dataset_name=dataset_name,
+ labels=dataset_labels,
+ metadata={},
)
assert isinstance(dataset, Dataset)
- assert dataset.id == "dataset1"
- assert dataset.name == "Dataset 1"
- assert dataset.labels == {"label"}
+ assert dataset.id == expected_dataset_id
+ assert dataset.name == dataset_name
+ assert dataset.labels == dataset_labels
assert dataset.metadata == {}
- mock_data_client.create_dataset.assert_called_once_with(
- repository_id="repo1",
- dataset=DatasetCreate.model_validate(
- {
- "source_data": b'{"input":{"data":"input1"},"expected_output":{"data":"output1"},"id":"example1","metadata":null}\n{"input":{"data":"input2"},"expected_output":{"data":"output2"},"id":"example2","metadata":null}',
- "labels": ["label"],
- "name": "Dataset 1",
- "total_datapoints": 2,
- "metadata": {},
- }
- ),
- )
+ studio_dataset_repository.studio_client.submit_dataset.assert_called_once() # type: ignore
+ actual_call = studio_dataset_repository.studio_client.submit_dataset.call_args # type: ignore
+ submitted_dataset = actual_call[1]["dataset"]
+ submitted_examples = list(actual_call[1]["examples"])
-def test_delete_dataset(
- studio_dataset_repository: StudioDatasetRepository, mock_data_client: Mock
-) -> None:
- studio_dataset_repository.delete_dataset(dataset_id="dataset1")
+ expected_dataset = StudioDataset(name=dataset_name, labels=dataset_labels)
- mock_data_client.delete_dataset.assert_called_once_with(
- repository_id="repo1", dataset_id="dataset1"
+ studio_examples = list(
+ studio_dataset_repository.map_to_many_studio_example(examples)
)
-
-def test_dataset(
- studio_dataset_repository: StudioDatasetRepository, mock_data_client: Mock
-) -> None:
- return_dataset_mock = Mock(spec=DataDataset)
- return_dataset_mock.dataset_id = "dataset1"
- return_dataset_mock.labels = []
- return_dataset_mock.metadata = {}
- return_dataset_mock.name = "Dataset 1"
- mock_data_client.get_dataset.return_value = return_dataset_mock
- dataset = studio_dataset_repository.dataset(dataset_id="dataset1")
-
- assert isinstance(dataset, Dataset)
- assert dataset.id == "dataset1"
- assert dataset.name == "Dataset 1"
- assert dataset.labels == set()
- assert dataset.metadata == {}
-
- mock_data_client.get_dataset.assert_called_once_with(
- repository_id="repo1", dataset_id="dataset1"
- )
-
-
-def test_datasets(
- studio_dataset_repository: StudioDatasetRepository, mock_data_client: Mock
-) -> None:
- return_dataset_mock = Mock(spec=DataDataset)
- return_dataset_mock.dataset_id = "dataset1"
- return_dataset_mock.labels = []
- return_dataset_mock.metadata = {}
- return_dataset_mock.name = "Dataset 1"
-
- return_dataset_mock_2 = Mock(spec=DataDataset)
- return_dataset_mock_2.dataset_id = "dataset2"
- return_dataset_mock_2.labels = []
- return_dataset_mock_2.metadata = {}
- return_dataset_mock_2.name = "Dataset 2"
-
- mock_data_client.list_datasets.return_value = [
- return_dataset_mock,
- return_dataset_mock_2,
- ]
-
- datasets = list(studio_dataset_repository.datasets())
-
- assert len(datasets) == 2
- assert isinstance(datasets[0], Dataset)
- assert datasets[0].id == "dataset1"
- assert datasets[0].name == "Dataset 1"
- assert datasets[0].labels == set()
- assert datasets[0].metadata == {}
- assert isinstance(datasets[1], Dataset)
- assert datasets[1].id == "dataset2"
- assert datasets[1].name == "Dataset 2"
- assert datasets[1].labels == set()
- assert datasets[1].metadata == {}
-
- mock_data_client.list_datasets.assert_called_once_with(repository_id="repo1")
-
-
-def test_dataset_ids(
- studio_dataset_repository: StudioDatasetRepository, mock_data_client: Mock
-) -> None:
- return_dataset_mock = Mock(spec=DataDataset)
- return_dataset_mock.dataset_id = "dataset1"
- return_dataset_mock.labels = ["label"]
- return_dataset_mock.metadata = {}
- return_dataset_mock.name = "Dataset 1"
-
- return_dataset_mock_2 = Mock(spec=DataDataset)
- return_dataset_mock_2.dataset_id = "dataset2"
- return_dataset_mock_2.labels = ["label"]
- return_dataset_mock_2.metadata = {}
- return_dataset_mock_2.name = "Dataset 2"
-
- mock_data_client.list_datasets.return_value = [
- return_dataset_mock,
- return_dataset_mock_2,
- ]
-
- dataset_ids = list(studio_dataset_repository.dataset_ids())
-
- assert len(dataset_ids) == 2
- assert dataset_ids[0] == "dataset1"
- assert dataset_ids[1] == "dataset2"
-
- mock_data_client.list_datasets.assert_called_once_with(repository_id="repo1")
-
-
-def test_example(
- studio_dataset_repository: StudioDatasetRepository, mock_data_client: Mock
-) -> None:
- mock_data_client.stream_dataset.return_value = [
- b'{"input": {"data": "input1"}, "expected_output": {"data": "output1"}, "id": "example1"}',
- b'{"input": {"data": "input2"}, "expected_output": {"data": "output2"}, "id": "example2"}',
- ]
-
- example = studio_dataset_repository.example(
- dataset_id="dataset1",
- example_id="example1",
- input_type=InputExample,
- expected_output_type=ExpectedOutputExample,
- )
-
- assert isinstance(example, Example)
- assert example.input.data == "input1"
- assert example.expected_output.data == "output1"
- assert example.id == "example1"
-
- mock_data_client.stream_dataset.assert_called_once_with(
- repository_id="repo1", dataset_id="dataset1"
- )
-
-
-def test_examples(
- studio_dataset_repository: StudioDatasetRepository, mock_data_client: Mock
-) -> None:
- mock_data_client.stream_dataset.return_value = [
- b'{"input": {"data": "input1"}, "expected_output": {"data": "output1"}, "id": "example1"}',
- b'{"input": {"data": "input2"}, "expected_output": {"data": "output2"}, "id": "example2"}',
- ]
-
- examples = list(
- studio_dataset_repository.examples(
- dataset_id="dataset1",
- input_type=InputExample,
- expected_output_type=ExpectedOutputExample,
- )
- )
-
- assert len(examples) == 2
- assert isinstance(examples[0], Example)
- assert examples[0].input.data == "input1"
- assert examples[0].expected_output.data == "output1"
- assert examples[0].id == "example1"
- assert isinstance(examples[1], Example)
- assert examples[1].input.data == "input2"
- assert examples[1].expected_output.data == "output2"
- assert examples[1].id == "example2"
-
- mock_data_client.stream_dataset.assert_called_once_with(
- repository_id="repo1", dataset_id="dataset1"
- )
+    # Verify the submitted dataset and examples match the originals
+ assert submitted_dataset.labels == expected_dataset.labels
+ assert submitted_dataset.name == expected_dataset.name
+ assert submitted_dataset.metadata == expected_dataset.metadata
+ assert submitted_examples == studio_examples