Skip to content

Commit

Permalink
feat: Add get_benchmark_lineage method for better testing
Browse files Browse the repository at this point in the history
  • Loading branch information
MerlinKallenbornAA committed Nov 29, 2024
1 parent b934c37 commit d6822f6
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 87 deletions.
63 changes: 48 additions & 15 deletions src/intelligence_layer/connectors/studio/studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,19 @@ class PostBenchmarkLineagesResponse(RootModel[Sequence[str]]):
pass


class GetBenchmarkLineageResponse(BaseModel):
id: str
trace_id: str
benchmark_execution_id: str
input: Any
expected_output: Any
example_metadata: Optional[dict[str, Any]] = None
output: Any
evaluation: Any
run_latency: int
run_tokens: int


class StudioClient:
"""Client for communicating with Studio.
Expand Down Expand Up @@ -491,12 +504,24 @@ def submit_benchmark_lineages(
benchmark_lineages: Sequence[BenchmarkLineage],
benchmark_id: str,
execution_id: str,
max_payload_size: int = 50 * 1024 * 1024,
max_payload_size: int = 50
* 1024
* 1024, # Maximum request size handled by Studio
) -> PostBenchmarkLineagesResponse:
"""Submit benchmark lineages in batches to avoid exceeding the maximum payload size.
Args:
benchmark_lineages: List of :class: `BenchmarkLineages` to submit.
benchmark_id: ID of the benchmark.
execution_id: ID of the execution.
max_payload_size: Maximum size of the payload in bytes. Defaults to 50MB.
Returns:
Response containing the results of the submissions.
"""
all_responses = []
remaining_lineages = list(benchmark_lineages)

converted_lineage_sizes = [
lineage_sizes = [
len(lineage.model_dump_json().encode("utf-8"))
for lineage in benchmark_lineages
]
Expand All @@ -505,9 +530,7 @@ def submit_benchmark_lineages(
batch = []
current_size = 0
# Build batch while checking size
for lineage, size in zip(
remaining_lineages, converted_lineage_sizes, strict=True
):
for lineage, size in zip(remaining_lineages, lineage_sizes, strict=True):
if current_size + size <= max_payload_size:
batch.append(lineage)
current_size += size
Expand All @@ -524,12 +547,28 @@ def submit_benchmark_lineages(
else: # Only reached if a lineage is too big for the request
print("Lineage exceeds maximum of upload size", lineage)
batch.append(lineage)

remaining_lineages = remaining_lineages[len(batch) :]
converted_lineage_sizes = converted_lineage_sizes[len(batch) :]
lineage_sizes = lineage_sizes[len(batch) :]

return PostBenchmarkLineagesResponse(all_responses)

def get_benchmark_lineage(
self, benchmark_id: str, execution_id: str, lineage_id: str
) -> GetBenchmarkLineageResponse | None:
url = urljoin(
self.url,
f"/api/projects/{self.project_id}/evaluation/benchmarks/{benchmark_id}/executions/{execution_id}/lineages/{lineage_id}",
)
response = requests.get(
url,
headers=self._headers,
)
self._raise_for_status(response)
response_text = response.json()
if response_text is None:
return None
return GetBenchmarkLineageResponse.model_validate(response_text)

def _send_compressed_batch(
self, batch: list[BenchmarkLineage], benchmark_id: str, execution_id: str
) -> list[str]:
Expand All @@ -538,8 +577,7 @@ def _send_compressed_batch(
f"/api/projects/{self.project_id}/evaluation/benchmarks/{benchmark_id}/executions/{execution_id}/lineages",
)

request_data = self._create_post_bechnmark_lineages_request(batch)
json_data = request_data.model_dump_json()
json_data = PostBenchmarkLineagesRequest(root=batch).model_dump_json()
compressed_data = gzip.compress(json_data.encode("utf-8"))

headers = {**self._headers, "Content-Encoding": "gzip"}
Expand All @@ -553,11 +591,6 @@ def _send_compressed_batch(
self._raise_for_status(response)
return response.json()

def _create_post_bechnmark_lineages_request(
self, benchmark_lineages: Sequence[BenchmarkLineage]
) -> PostBenchmarkLineagesRequest:
return PostBenchmarkLineagesRequest(root=benchmark_lineages)

def _raise_for_status(self, response: requests.Response) -> None:
try:
response.raise_for_status()
Expand Down
102 changes: 30 additions & 72 deletions tests/connectors/studio/test_studio_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,21 @@ def with_uploaded_benchmark_execution(
return benchmark_execution_id


def dummy_lineage(
trace_id: str, input: str = "input", output: str = "output"
) -> DummyBenchmarkLineage:
return DummyBenchmarkLineage(
trace_id=trace_id,
input=input,
expected_output="output",
example_metadata={"key3": "value3"},
output=output,
evaluation={"key5": "value5"},
run_latency=1,
run_tokens=3,
)


def test_create_benchmark(
studio_client: StudioClient,
studio_dataset: str,
Expand Down Expand Up @@ -235,41 +250,7 @@ def test_can_create_benchmark_execution(
assert UUID(benchmark_execution_id)


def test_submit_benchmark_lineage_uploads_single_lineage(
studio_client: StudioClient,
with_uploaded_test_trace: str,
with_uploaded_benchmark: str,
with_uploaded_benchmark_execution: str,
) -> None:
trace_id = with_uploaded_test_trace
benchmark_id = with_uploaded_benchmark
benchmark_execution_id = with_uploaded_benchmark_execution

lineages = [
DummyBenchmarkLineage(
trace_id=trace_id,
input="input",
expected_output="output",
example_metadata={"key3": "value3"},
output="output",
evaluation={"key5": "value5"},
run_latency=1,
run_tokens=3,
),
]

lineage_ids = studio_client.submit_benchmark_lineages(
benchmark_lineages=lineages,
benchmark_id=benchmark_id,
execution_id=benchmark_execution_id,
)

assert len(lineage_ids.root) == len(lineages)
for lineage_id in lineage_ids.root:
assert UUID(lineage_id)


def test_batch_upload_sends_multiple_requests(
def test_can_submit_lineages(
studio_client: StudioClient,
with_uploaded_test_trace: str,
with_uploaded_benchmark: str,
Expand All @@ -281,26 +262,10 @@ def test_batch_upload_sends_multiple_requests(
benchmark_execution_id = with_uploaded_benchmark_execution

lineages = [
DummyBenchmarkLineage(
trace_id=trace_id,
input="input",
expected_output="output",
example_metadata={"key3": "value3"},
output="output",
evaluation={"key5": "value5"},
run_latency=1,
run_tokens=3,
),
DummyBenchmarkLineage(
trace_id=trace_id,
input="input2",
expected_output="output2",
example_metadata={"key4": "value4"},
output="output2",
evaluation={"key5": "value5"},
run_latency=1,
run_tokens=3,
dummy_lineage(
trace_id,
),
dummy_lineage(trace_id, "slightly longer input", "slightly_longer_output"),
]

lineage_ids = studio_client.submit_benchmark_lineages(
Expand All @@ -327,25 +292,11 @@ def test_submit_lineage_skips_lineages_exceeding_request_size(
benchmark_execution_id = with_uploaded_benchmark_execution

lineages = [
DummyBenchmarkLineage(
trace_id=trace_id,
input="input",
expected_output="output",
example_metadata={"key3": "value3"},
output="output",
evaluation={"key5": "value5"},
run_latency=1,
run_tokens=3,
),
DummyBenchmarkLineage(
trace_id=trace_id,
dummy_lineage(trace_id),
dummy_lineage(
trace_id,
input="input input2 input3 input4 input5",
expected_output="output output output output",
example_metadata={"key3": "value3"},
output="output output output output",
evaluation={"key5": "value5"},
run_latency=1,
run_tokens=3,
),
]

Expand All @@ -354,7 +305,14 @@ def test_submit_lineage_skips_lineages_exceeding_request_size(
benchmark_id=benchmark_id,
execution_id=benchmark_execution_id,
max_payload_size=len(lineages[0].model_dump_json().encode("utf-8"))
+ 1, # to enforce only upload of first lineage
+ 1, # to enforce second lineage exceeds
)

fetched_lineage = studio_client.get_benchmark_lineage(
benchmark_id=benchmark_id,
execution_id=benchmark_execution_id,
lineage_id=lineage_ids.root[0],
)
assert len(lineage_ids.root) == 1
assert fetched_lineage
assert fetched_lineage.input == lineages[0].input

0 comments on commit d6822f6

Please sign in to comment.