Commit
feat: Include batch upload of benchmark lineages
TASK: PHS-885
MerlinKallenbornAA committed Nov 28, 2024
1 parent 922bee6 commit 10afe8b
Showing 2 changed files with 183 additions and 51 deletions.
57 changes: 50 additions & 7 deletions src/intelligence_layer/connectors/studio/studio.py
@@ -1,3 +1,4 @@
import gzip
import json
import os
from collections import defaultdict, deque
@@ -161,9 +162,6 @@ class PostBenchmarkLineagesResponse(RootModel[Sequence[str]]):
pass


# class PostBenchmarkLineageRequest


class StudioClient:
"""Client for communicating with Studio.
@@ -493,22 +491,67 @@ def submit_benchmark_lineages(
benchmark_lineages: Sequence[BenchmarkLineage],
benchmark_id: str,
execution_id: str,
max_payload_size: int = 50 * 1024 * 1024,
) -> PostBenchmarkLineagesResponse:
all_responses = []
remaining_lineages = list(benchmark_lineages)

converted_lineage_sizes = [
len(lineage.model_dump_json().encode("utf-8"))
for lineage in benchmark_lineages
]

while remaining_lineages:
batch = []
current_size = 0
# Build batch while checking size
for lineage, size in zip(
remaining_lineages, converted_lineage_sizes, strict=True
):
if current_size + size <= max_payload_size:
batch.append(lineage)
current_size += size
else:
break

if batch:
# Send batch
response = self._send_compressed_batch(
batch, benchmark_id, execution_id
)
all_responses.extend(response)

else: # Only reached if a lineage is too big for the request
print("Lineage exceeds maximum of upload size", lineage)
batch.append(lineage)

remaining_lineages = remaining_lineages[len(batch) :]
converted_lineage_sizes = converted_lineage_sizes[len(batch) :]

return PostBenchmarkLineagesResponse(all_responses)

def _send_compressed_batch(
self, batch: list[BenchmarkLineage], benchmark_id: str, execution_id: str
) -> list[str]:
url = urljoin(
self.url,
f"/api/projects/{self.project_id}/evaluation/benchmarks/{benchmark_id}/executions/{execution_id}/lineages",
)

request_data = self._create_post_bechnmark_lineages_request(benchmark_lineages)
request_data = self._create_post_bechnmark_lineages_request(batch)
json_data = request_data.model_dump_json()
compressed_data = gzip.compress(json_data.encode("utf-8"))

headers = {**self._headers, "Content-Encoding": "gzip"}

response = requests.post(
url,
headers=self._headers,
data=request_data.model_dump_json(),
headers=headers,
data=compressed_data,
)

self._raise_for_status(response)
return PostBenchmarkLineagesResponse(response.json())
return response.json()

def _create_post_bechnmark_lineages_request(
self, benchmark_lineages: Sequence[BenchmarkLineage]
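
The sketch below is not part of the diff; it restates the new upload strategy in isolation for readers skimming the change: greedy batching by serialized payload size, followed by a gzip-compressed POST, which is what submit_benchmark_lineages and _send_compressed_batch now do. The Lineage protocol, post_batch helper, and URL/header handling here are illustrative placeholders, not the real client API.

# Illustrative sketch only: greedy batching by serialized size, then a
# gzip-compressed POST. Names below (Lineage, post_batch) are placeholders.
import gzip
import json
from typing import Protocol, Sequence

import requests


class Lineage(Protocol):
    """Anything that can serialize itself to JSON, e.g. a pydantic model."""

    def model_dump_json(self) -> str: ...


def batch_by_size(
    lineages: Sequence[Lineage], max_payload_size: int
) -> list[list[Lineage]]:
    """Greedily group lineages so each batch stays below max_payload_size bytes.

    A lineage that on its own exceeds the limit is skipped, mirroring the
    committed behaviour.
    """
    sizes = [len(lineage.model_dump_json().encode("utf-8")) for lineage in lineages]
    batches: list[list[Lineage]] = []
    current: list[Lineage] = []
    current_size = 0
    for lineage, size in zip(lineages, sizes, strict=True):
        if size > max_payload_size:
            continue  # too large for any single request
        if current and current_size + size > max_payload_size:
            batches.append(current)
            current, current_size = [], 0
        current.append(lineage)
        current_size += size
    if current:
        batches.append(current)
    return batches


def post_batch(url: str, headers: dict[str, str], batch: Sequence[Lineage]) -> list[str]:
    """Send one batch as gzip-compressed JSON with a Content-Encoding header."""
    payload = json.dumps([json.loads(lineage.model_dump_json()) for lineage in batch])
    compressed = gzip.compress(payload.encode("utf-8"))
    response = requests.post(
        url, headers={**headers, "Content-Encoding": "gzip"}, data=compressed
    )
    response.raise_for_status()
    return response.json()

In the actual client each batch is first wrapped in a request model before serialization; the sketch inlines that step for brevity.
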
177 changes: 133 additions & 44 deletions tests/connectors/studio/test_studio_benchmark.py
@@ -84,6 +84,31 @@ def studio_dataset(
return studio_client.submit_dataset(StudioDataset(name="dataset_name"), examples)


@fixture
def post_benchmark_execution() -> PostBenchmarkExecution:
return PostBenchmarkExecution(
name="name",
description="Test benchmark execution",
labels={"performance", "testing"},
metadata={"project": "AI Testing", "team": "QA"},
start=datetime.now(),
end=datetime.now(),
run_start=datetime.now(),
run_end=datetime.now(),
run_successful_count=10,
run_failed_count=2,
run_success_avg_latency=120,
run_success_avg_token_count=300,
eval_start=datetime.now(),
eval_end=datetime.now(),
eval_successful_count=8,
eval_failed_count=1,
aggregation_start=datetime.now(),
aggregation_end=datetime.now(),
statistics=DummyAggregatedEvaluation(score=1.0).model_dump_json(),
)


@fixture
def evaluation_logic_identifier() -> EvaluationLogicIdentifier:
return create_evaluation_logic_identifier(DummyEvaluationLogic())
@@ -126,6 +151,21 @@ def with_uploaded_benchmark(
return benchmark_id


@fixture
def with_uploaded_benchmark_execution(
studio_client: StudioClient,
studio_dataset: str,
evaluation_logic_identifier: EvaluationLogicIdentifier,
aggregation_logic_identifier: AggregationLogicIdentifier,
with_uploaded_benchmark: str,
post_benchmark_execution: PostBenchmarkExecution,
) -> str:
benchmark_execution_id = studio_client.submit_benchmark_execution(
benchmark_id=with_uploaded_benchmark, data=post_benchmark_execution
)
return benchmark_execution_id


def test_create_benchmark(
studio_client: StudioClient,
studio_dataset: str,
@@ -182,30 +222,11 @@ def test_get_non_existing_benchmark(studio_client: StudioClient) -> None:
def test_can_create_benchmark_execution(
studio_client: StudioClient,
with_uploaded_benchmark: str,
post_benchmark_execution: PostBenchmarkExecution,
) -> None:
benchmark_id = with_uploaded_benchmark

example_request = PostBenchmarkExecution(
name="name",
description="Test benchmark execution",
labels={"performance", "testing"},
metadata={"project": "AI Testing", "team": "QA"},
start=datetime.now(),
end=datetime.now(),
run_start=datetime.now(),
run_end=datetime.now(),
run_successful_count=10,
run_failed_count=2,
run_success_avg_latency=120,
run_success_avg_token_count=300,
eval_start=datetime.now(),
eval_end=datetime.now(),
eval_successful_count=8,
eval_failed_count=1,
aggregation_start=datetime.now(),
aggregation_end=datetime.now(),
statistics=DummyAggregatedEvaluation(score=1.0).model_dump_json(),
)
example_request = post_benchmark_execution

benchmark_execution_id = studio_client.submit_benchmark_execution(
benchmark_id=benchmark_id, data=example_request
@@ -218,35 +239,47 @@ def test_submit_benchmark_lineage_uploads_single_lineage(
studio_client: StudioClient,
with_uploaded_test_trace: str,
with_uploaded_benchmark: str,
with_uploaded_benchmark_execution: str,
) -> None:
trace_id = with_uploaded_test_trace
benchmark_id = with_uploaded_benchmark
benchmark_execution_id = with_uploaded_benchmark_execution

example_request = PostBenchmarkExecution(
name="name",
description="Test benchmark execution",
labels={"performance", "testing"},
metadata={"project": "AI Testing", "team": "QA"},
start=datetime.now(),
end=datetime.now(),
run_start=datetime.now(),
run_end=datetime.now(),
run_successful_count=10,
run_failed_count=2,
run_success_avg_latency=120,
run_success_avg_token_count=300,
eval_start=datetime.now(),
eval_end=datetime.now(),
eval_successful_count=8,
eval_failed_count=1,
aggregation_start=datetime.now(),
aggregation_end=datetime.now(),
statistics=DummyAggregatedEvaluation(score=1.0).model_dump_json(),
)
lineages = [
DummyBenchmarkLineage(
trace_id=trace_id,
input="input",
expected_output="output",
example_metadata={"key3": "value3"},
output="output",
evaluation={"key5": "value5"},
run_latency=1,
run_tokens=3,
),
]

benchmark_execution_id = studio_client.submit_benchmark_execution(
benchmark_id=benchmark_id, data=example_request
lineage_ids = studio_client.submit_benchmark_lineages(
benchmark_lineages=lineages,
benchmark_id=benchmark_id,
execution_id=benchmark_execution_id,
)

assert len(lineage_ids.root) == len(lineages)
for lineage_id in lineage_ids.root:
assert UUID(lineage_id)


def test_batch_upload_sends_multiple_requests(
studio_client: StudioClient,
with_uploaded_test_trace: str,
with_uploaded_benchmark: str,
with_uploaded_benchmark_execution: str,
post_benchmark_execution: PostBenchmarkExecution,
) -> None:
trace_id = with_uploaded_test_trace
benchmark_id = with_uploaded_benchmark
benchmark_execution_id = with_uploaded_benchmark_execution

lineages = [
DummyBenchmarkLineage(
trace_id=trace_id,
@@ -258,14 +291,70 @@ def test_submit_benchmark_lineage_uploads_single_lineage(
run_latency=1,
run_tokens=3,
),
DummyBenchmarkLineage(
trace_id=trace_id,
input="input2",
expected_output="output2",
example_metadata={"key4": "value4"},
output="output2",
evaluation={"key5": "value5"},
run_latency=1,
run_tokens=3,
),
]

lineage_ids = studio_client.submit_benchmark_lineages(
benchmark_lineages=lineages,
benchmark_id=benchmark_id,
execution_id=benchmark_execution_id,
max_payload_size=len(lineages[1].model_dump_json().encode("utf-8"))
+ 1,  # to force two requests for the lineages
)

assert len(lineage_ids.root) == len(lineages)
for lineage_id in lineage_ids.root:
assert UUID(lineage_id)


def test_submit_lineage_skips_lineages_exceeding_request_size(
studio_client: StudioClient,
with_uploaded_test_trace: str,
with_uploaded_benchmark: str,
with_uploaded_benchmark_execution: str,
) -> None:
trace_id = with_uploaded_test_trace
benchmark_id = with_uploaded_benchmark
benchmark_execution_id = with_uploaded_benchmark_execution

lineages = [
DummyBenchmarkLineage(
trace_id=trace_id,
input="input",
expected_output="output",
example_metadata={"key3": "value3"},
output="output",
evaluation={"key5": "value5"},
run_latency=1,
run_tokens=3,
),
DummyBenchmarkLineage(
trace_id=trace_id,
input="input input2 input3 input4 input5",
expected_output="output output output output",
example_metadata={"key3": "value3"},
output="output output output output",
evaluation={"key5": "value5"},
run_latency=1,
run_tokens=3,
),
]

lineage_ids = studio_client.submit_benchmark_lineages(
benchmark_lineages=lineages,
benchmark_id=benchmark_id,
execution_id=benchmark_execution_id,
max_payload_size=len(lineages[0].model_dump_json().encode("utf-8"))
+ 1,  # to ensure only the first lineage is uploaded
)

assert len(lineage_ids.root) == 1
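
To tie the tests back to the public API, a minimal caller-side sketch follows. It assumes a configured StudioClient plus an existing benchmark, execution, and trace, and a lineage model such as the DummyBenchmarkLineage used above; all identifiers are placeholders.

# Usage pattern only; studio_client, the IDs, and DummyBenchmarkLineage stand in
# for objects created elsewhere (e.g. via the fixtures above).
lineages = [
    DummyBenchmarkLineage(
        trace_id=trace_id,
        input="input",
        expected_output="output",
        example_metadata={"key3": "value3"},
        output="output",
        evaluation={"key5": "value5"},
        run_latency=1,
        run_tokens=3,
    )
]

lineage_ids = studio_client.submit_benchmark_lineages(
    benchmark_lineages=lineages,
    benchmark_id=benchmark_id,
    execution_id=benchmark_execution_id,
    # Optional; defaults to 50 MiB. Lowering it forces more, smaller batches.
    max_payload_size=10 * 1024 * 1024,
)

assert len(lineage_ids.root) == len(lineages)
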
