Skip to content

Commit

Permalink
[OPIK-133] [SDK] Implement batching for trace creation (#790)
Browse files Browse the repository at this point in the history
* refactor span/trace conversion during creation processing

* add batch support for create trace message

* add trace batcher

* add trace batch handling to backend emulator

* add tests

* rename `batcher` file

* rename `as_dict` messages method

* fix as_payload_dict

* enable batching for all e2e tests

* add more tests

* fix flaky e2e test
  • Loading branch information
japdubengsub authored Dec 4, 2024
1 parent 3925650 commit 1d0d5ea
Show file tree
Hide file tree
Showing 12 changed files with 341 additions and 135 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,35 @@
from .. import messages

from . import base_batcher
from . import create_span_message_batcher
from . import batchers
from . import batch_manager

CREATE_SPANS_MESSAGE_BATCHER_FLUSH_INTERVAL_SECONDS = 1.0
CREATE_SPANS_MESSAGE_BATCHER_MAX_BATCH_SIZE = 1000


def create_batch_manager(message_queue: queue.Queue) -> batch_manager.BatchManager:
    """Build a BatchManager wired with one batcher per batchable message type.

    Span- and trace-creation messages are accumulated by their respective
    batchers; completed batches are flushed onto ``message_queue`` via
    ``message_queue.put``. Both batchers share the same flush interval and
    maximum batch size constants.

    Args:
        message_queue: queue that receives the flushed batch messages.

    Returns:
        A configured ``batch_manager.BatchManager``.
    """
    create_span_message_batcher_ = batchers.CreateSpanMessageBatcher(
        flush_interval_seconds=CREATE_SPANS_MESSAGE_BATCHER_FLUSH_INTERVAL_SECONDS,
        max_batch_size=CREATE_SPANS_MESSAGE_BATCHER_MAX_BATCH_SIZE,
        flush_callback=message_queue.put,
    )

    create_trace_message_batcher_ = batchers.CreateTraceMessageBatcher(
        flush_interval_seconds=CREATE_SPANS_MESSAGE_BATCHER_FLUSH_INTERVAL_SECONDS,
        max_batch_size=CREATE_SPANS_MESSAGE_BATCHER_MAX_BATCH_SIZE,
        flush_callback=message_queue.put,
    )

    message_to_batcher_mapping: Dict[
        Type[messages.BaseMessage], base_batcher.BaseBatcher
    ] = {
        messages.CreateSpanMessage: create_span_message_batcher_,
        messages.CreateTraceMessage: create_trace_message_batcher_,
    }

    batch_manager_ = batch_manager.BatchManager(
        message_to_batcher_mapping=message_to_batcher_mapping
    )

    return batch_manager_
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,10 @@ def _create_batch_from_accumulated_messages(
self,
) -> messages.CreateSpansBatchMessage:
return messages.CreateSpansBatchMessage(batch=self._accumulated_messages) # type: ignore


class CreateTraceMessageBatcher(base_batcher.BaseBatcher):
    """Accumulates CreateTraceMessage objects and flushes them as a single
    CreateTraceBatchMessage, mirroring the span batcher above."""

    def _create_batch_from_accumulated_messages(
        self,
    ) -> messages.CreateTraceBatchMessage:
        # NOTE(review): assumes _accumulated_messages holds CreateTraceMessage
        # items; the type: ignore matches the span batcher's convention.
        accumulated = self._accumulated_messages
        return messages.CreateTraceBatchMessage(batch=accumulated)  # type: ignore
76 changes: 28 additions & 48 deletions sdks/python/src/opik/message_processing/message_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from . import messages
from ..jsonable_encoder import jsonable_encoder
from .. import dict_utils
from ..rest_api.types import feedback_score_batch_item
from ..rest_api.types import feedback_score_batch_item, trace_write
from ..rest_api.types import span_write
from ..rest_api import core as rest_api_core
from ..rest_api import client as rest_api_client
Expand Down Expand Up @@ -36,6 +36,7 @@ def __init__(self, rest_client: rest_api_client.OpikApi):
messages.AddTraceFeedbackScoresBatchMessage: self._process_add_trace_feedback_scores_batch_message, # type: ignore
messages.AddSpanFeedbackScoresBatchMessage: self._process_add_span_feedback_scores_batch_message, # type: ignore
messages.CreateSpansBatchMessage: self._process_create_span_batch_message, # type: ignore
messages.CreateTraceBatchMessage: self._process_create_trace_batch_message, # type: ignore
}

def process(self, message: messages.BaseMessage) -> None:
Expand Down Expand Up @@ -65,24 +66,7 @@ def process(self, message: messages.BaseMessage) -> None:
)

def _process_create_span_message(self, message: messages.CreateSpanMessage) -> None:
create_span_kwargs = {
"id": message.span_id,
"trace_id": message.trace_id,
"project_name": message.project_name,
"parent_span_id": message.parent_span_id,
"name": message.name,
"start_time": message.start_time,
"end_time": message.end_time,
"type": message.type,
"input": message.input,
"output": message.output,
"metadata": message.metadata,
"tags": message.tags,
"usage": message.usage,
"model": message.model,
"provider": message.provider,
}

create_span_kwargs = message.as_payload_dict()
cleaned_create_span_kwargs = dict_utils.remove_none_from_dict(
create_span_kwargs
)
Expand All @@ -93,18 +77,7 @@ def _process_create_span_message(self, message: messages.CreateSpanMessage) -> N
def _process_create_trace_message(
self, message: messages.CreateTraceMessage
) -> None:
create_trace_kwargs = {
"id": message.trace_id,
"name": message.name,
"project_name": message.project_name,
"start_time": message.start_time,
"end_time": message.end_time,
"input": message.input,
"output": message.output,
"metadata": message.metadata,
"tags": message.tags,
}

create_trace_kwargs = message.as_payload_dict()
cleaned_create_trace_kwargs = dict_utils.remove_none_from_dict(
create_trace_kwargs
)
Expand Down Expand Up @@ -192,23 +165,7 @@ def _process_create_span_batch_message(
rest_spans: List[span_write.SpanWrite] = []

for item in message.batch:
span_write_kwargs = {
"id": item.span_id,
"trace_id": item.trace_id,
"project_name": item.project_name,
"parent_span_id": item.parent_span_id,
"name": item.name,
"start_time": item.start_time,
"end_time": item.end_time,
"type": item.type,
"input": item.input,
"output": item.output,
"metadata": item.metadata,
"tags": item.tags,
"usage": item.usage,
"model": item.model,
"provider": item.provider,
}
span_write_kwargs = item.as_payload_dict()
cleaned_span_write_kwargs = dict_utils.remove_none_from_dict(
span_write_kwargs
)
Expand All @@ -224,3 +181,26 @@ def _process_create_span_batch_message(
LOGGER.debug("Create spans batch request of size %d", len(batch))
self._rest_client.spans.create_spans(spans=batch)
LOGGER.debug("Sent spans batch of size %d", len(batch))

def _process_create_trace_batch_message(
    self, message: messages.CreateTraceBatchMessage
) -> None:
    """Convert a batch of trace-create messages into REST ``TraceWrite``
    objects and send them, splitting into memory-bounded sub-batches first.

    Args:
        message: the batch message whose items are sent to the backend.
    """
    rest_traces: List[trace_write.TraceWrite] = []
    for trace_message in message.batch:
        payload = trace_message.as_payload_dict()
        payload = dict_utils.remove_none_from_dict(payload)
        payload = jsonable_encoder(payload)
        rest_traces.append(trace_write.TraceWrite(**payload))

    # Keep every request under the payload memory limit.
    memory_limited_batches = sequence_splitter.split_into_batches(
        items=rest_traces,
        max_payload_size_MB=BATCH_MEMORY_LIMIT_MB,
    )

    for batch in memory_limited_batches:
        LOGGER.debug("Create trace batch request of size %d", len(batch))
        self._rest_client.traces.create_traces(traces=batch)
        LOGGER.debug("Sent trace batch of size %d", len(batch))
20 changes: 19 additions & 1 deletion sdks/python/src/opik/message_processing/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@

@dataclasses.dataclass
class BaseMessage:
    """Base class for all message-processing messages."""

    def as_payload_dict(self) -> Dict[str, Any]:
        """Return a shallow copy of this message's fields as a plain dict."""
        # `dataclasses.asdict()` is deliberately avoided here: it deep-copies
        # every field value and raises when a field holds an object that
        # cannot be copied/serialized.
        return dict(self.__dict__)


@dataclasses.dataclass
Expand All @@ -21,6 +24,11 @@ class CreateTraceMessage(BaseMessage):
metadata: Optional[Dict[str, Any]]
tags: Optional[List[str]]

def as_payload_dict(self) -> Dict[str, Any]:
    """Return the REST payload dict, renaming ``trace_id`` to ``id``."""
    payload = super().as_payload_dict()
    trace_id = payload.pop("trace_id")
    return {**payload, "id": trace_id}


@dataclasses.dataclass
class UpdateTraceMessage(BaseMessage):
Expand Down Expand Up @@ -51,6 +59,11 @@ class CreateSpanMessage(BaseMessage):
model: Optional[str]
provider: Optional[str]

def as_payload_dict(self) -> Dict[str, Any]:
    """Return the REST payload dict, renaming ``span_id`` to ``id``."""
    payload = super().as_payload_dict()
    span_id = payload.pop("span_id")
    return {**payload, "id": span_id}


@dataclasses.dataclass
class UpdateSpanMessage(BaseMessage):
Expand Down Expand Up @@ -97,3 +110,8 @@ class AddSpanFeedbackScoresBatchMessage(BaseMessage):
@dataclasses.dataclass
class CreateSpansBatchMessage(BaseMessage):
    """A batch of CreateSpanMessage objects to be sent in one request."""

    # Accumulated span-creation messages forming this batch.
    batch: List[CreateSpanMessage]


@dataclasses.dataclass
class CreateTraceBatchMessage(BaseMessage):
    """A batch of CreateTraceMessage objects to be sent in one request."""

    # Accumulated trace-creation messages forming this batch.
    batch: List[CreateTraceMessage]
2 changes: 1 addition & 1 deletion sdks/python/tests/e2e/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def configure_e2e_tests_env():

@pytest.fixture()
def opik_client(configure_e2e_tests_env, shutdown_cached_client_after_test):
opik_client_ = opik.api_objects.opik_client.Opik()
opik_client_ = opik.api_objects.opik_client.Opik(_use_batching=True)

yield opik_client_

Expand Down
79 changes: 67 additions & 12 deletions sdks/python/tests/e2e/test_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,69 @@ def f_inner(y):
)


def test_tracked_function__two_traces_and_two_spans__happyflow(opik_client):
    """Two independent tracked calls produce two traces, each with one root span."""
    # Setup
    project_name = "e2e-tests-batching-messages"
    ids = {}

    @opik.track(project_name=project_name)
    def f1(x):
        ids["f1-trace-id"] = opik_context.get_current_trace_data().id
        ids["f1-span-id"] = opik_context.get_current_span_data().id
        return "f1-output"

    @opik.track(project_name=project_name)
    def f2(y):
        ids["f2-trace-id"] = opik_context.get_current_trace_data().id
        ids["f2-span-id"] = opik_context.get_current_span_data().id
        return "f2-output"

    # Call
    f1("f1-input")
    f2("f2-input")
    opik.flush_tracker()

    # Verify one trace and one root span per tracked call.
    for fn_name, arg_name in (("f1", "x"), ("f2", "y")):
        expected_input = {arg_name: f"{fn_name}-input"}
        expected_output = {"output": f"{fn_name}-output"}
        trace_id = ids[f"{fn_name}-trace-id"]

        verifiers.verify_trace(
            opik_client=opik_client,
            trace_id=trace_id,
            name=fn_name,
            input=expected_input,
            output=expected_output,
            project_name=project_name,
        )
        verifiers.verify_span(
            opik_client=opik_client,
            span_id=ids[f"{fn_name}-span-id"],
            parent_span_id=None,
            trace_id=trace_id,
            name=fn_name,
            input=expected_input,
            output=expected_output,
            project_name=project_name,
        )


def test_tracked_function__try_different_project_names(opik_client):
"""
In this test we will try to use different project names for outer and inner spans.
Expand Down Expand Up @@ -291,27 +354,19 @@ def test_search_spans__happyflow(opik_client):
)
trace.span(
name="span-name",
input={"input": "Some random input"},
input={"input": "Some random input 1"},
output={"output": "span-output"},
)

# Send a trace that does not match the input filter
trace = opik_client.trace(
id=trace_id,
name="trace-name",
input={"input": "Some random input"},
output={"output": "trace-output"},
project_name=OPIK_E2E_TESTS_PROJECT_NAME,
)
trace.span(
name="span-name",
input={"input": "Some random input"},
input={"input": "Some random input 2"},
output={"output": "span-output"},
)

opik_client.flush()

# Search for the traces - Note that we use a large max_results to ensure that we get all traces, if the project has more than 100000 matching traces it is possible
# Search for the traces - Note that we use a large max_results to ensure that we get all traces,
# if the project has more than 100000 matching traces it is possible
spans = opik_client.search_spans(
project_name=OPIK_E2E_TESTS_PROJECT_NAME,
trace_id=trace_id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,9 @@ def process(self, message: messages.BaseMessage) -> None:
elif isinstance(message, messages.CreateSpansBatchMessage):
for item in message.batch:
self.process(item)
elif isinstance(message, messages.CreateTraceBatchMessage):
for item in message.batch:
self.process(item)
elif isinstance(message, messages.UpdateSpanMessage):
span: SpanModel = self._observations[message.span_id]
span.output = message.output
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from opik.message_processing import messages
from opik.message_processing.batching import batch_manager
from opik.message_processing.batching import create_span_message_batcher
from opik.message_processing.batching import batchers

NOT_USED = None

Expand Down Expand Up @@ -101,7 +101,7 @@ def test_batch_manager__start_and_stop_were_called__accumulated_data_is_flushed(
provider=NOT_USED,
)

example_span_batcher = create_span_message_batcher.CreateSpanMessageBatcher(
example_span_batcher = batchers.CreateSpanMessageBatcher(
flush_callback=flush_callback, max_batch_size=42, flush_interval_seconds=0.1
)
tested = batch_manager.BatchManager(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import time
import mock
from opik.message_processing.batching import (
batchers,
flushing_thread,
create_span_message_batcher,
)


def test_flushing_thread__batcher_is_flushed__every_time_flush_interval_time_passes():
flush_callback = mock.Mock()
FLUSH_INTERVAL = 0.2
very_big_batch_size = float("inf")
batcher = create_span_message_batcher.CreateSpanMessageBatcher(
batcher = batchers.CreateSpanMessageBatcher(
flush_callback=flush_callback,
max_batch_size=very_big_batch_size,
flush_interval_seconds=FLUSH_INTERVAL,
Expand Down
Loading

0 comments on commit 1d0d5ea

Please sign in to comment.