diff --git a/src/intelligence_layer/core/task.py b/src/intelligence_layer/core/task.py
index 885079187..401d5e69d 100644
--- a/src/intelligence_layer/core/task.py
+++ b/src/intelligence_layer/core/task.py
@@ -87,7 +87,6 @@ def run_concurrently(
         inputs: Iterable[Input],
         tracer: Tracer,
         concurrency_limit: int = MAX_CONCURRENCY,
-        trace_id: Optional[str] = None,
     ) -> Sequence[Output]:
         """Executes multiple processes of this task concurrently.
 
@@ -107,9 +106,7 @@ def run_concurrently(
             The order of Outputs corresponds to the order of the Inputs.
         """
 
-        with tracer.span(
-            f"Concurrent {type(self).__name__} tasks", trace_id=trace_id
-        ) as span:
+        with tracer.span(f"Concurrent {type(self).__name__} tasks") as span:
             with ThreadPoolExecutor(
                 max_workers=min(concurrency_limit, MAX_CONCURRENCY)
             ) as executor:
diff --git a/src/intelligence_layer/core/tracer/in_memory_tracer.py b/src/intelligence_layer/core/tracer/in_memory_tracer.py
index fbfda4b23..09a8d23d0 100644
--- a/src/intelligence_layer/core/tracer/in_memory_tracer.py
+++ b/src/intelligence_layer/core/tracer/in_memory_tracer.py
@@ -1,7 +1,7 @@
 import json
 import os
 from datetime import datetime
-from typing import Optional, Union
+from typing import Optional, Sequence, Union
 from uuid import UUID
 
 import requests
@@ -11,13 +11,18 @@
 from rich.tree import Tree
 
 from intelligence_layer.core.tracer.tracer import (
+    Context,
     EndSpan,
     EndTask,
+    Event,
+    ExportedSpan,
     LogEntry,
     LogLine,
     PlainEntry,
     PydanticSerializable,
     Span,
+    SpanAttributes,
+    SpanStatus,
     StartSpan,
     StartTask,
     TaskSpan,
@@ -45,12 +50,11 @@ def span(
         self,
         name: str,
         timestamp: Optional[datetime] = None,
-        trace_id: Optional[str] = None,
     ) -> "InMemorySpan":
         child = InMemorySpan(
             name=name,
             start_timestamp=timestamp or utc_now(),
-            trace_id=self.ensure_id(trace_id),
+            context=self.context,
         )
         self.entries.append(child)
         return child
@@ -60,13 +64,12 @@ def task_span(
         task_name: str,
         input: PydanticSerializable,
         timestamp: Optional[datetime] = None,
-        trace_id: Optional[str] = None,
     ) -> "InMemoryTaskSpan":
         child = InMemoryTaskSpan(
             name=task_name,
             input=input,
             start_timestamp=timestamp or utc_now(),
-            trace_id=self.ensure_id(trace_id),
+            context=self.context,
         )
         self.entries.append(child)
         return child
@@ -106,12 +109,22 @@ def submit_to_trace_viewer(self) -> bool:
             )
             return False
 
+    def export_for_viewing(self) -> Sequence[ExportedSpan]:
+        exported_root_spans: list[ExportedSpan] = []
+        for entry in self.entries:
+            if isinstance(entry, LogEntry):
+                raise Exception(
+                    "Found a log outside of a span. Logs can only be part of a span."
+                )
+            else:
+                exported_root_spans.extend(entry.export_for_viewing())
+        return exported_root_spans
+
 
 class InMemorySpan(InMemoryTracer, Span):
     name: str
     start_timestamp: datetime = Field(default_factory=datetime.utcnow)
     end_timestamp: Optional[datetime] = None
-    trace_id: str
 
     def id(self) -> str:
         return self.trace_id
@@ -144,6 +157,36 @@ def _rich_render_(self) -> Tree:
 
         return tree
 
+    def export_for_viewing(self) -> Sequence[ExportedSpan]:
+        logs: list[LogEntry] = []
+        exported_spans: list[ExportedSpan] = []
+        for entry in self.entries:
+            if isinstance(entry, LogEntry):
+                logs.append(entry)
+            else:
+                exported_spans.extend(entry.export_for_viewing())
+        exported_spans.append(
+            ExportedSpan(
+                context=Context(trace_id=self.id(), span_id="?"),
+                name=self.name,
+                parent_id=self.parent_id,
+                start_time=self.start_timestamp,
+                end_time=self.end_timestamp,
+                attributes=SpanAttributes(),
+                events=[
+                    Event(
+                        name="log",
+                        body=log.value,
+                        message=log.message,
+                        timestamp=log.timestamp,
+                    )
+                    for log in logs
+                ],
+                status=SpanStatus.OK,
+            )
+        )
+        return exported_spans
+
 
 class InMemoryTaskSpan(InMemorySpan, TaskSpan):
     input: SerializeAsAny[PydanticSerializable]
diff --git a/src/intelligence_layer/core/tracer/tracer.py b/src/intelligence_layer/core/tracer/tracer.py
index cb9182e03..764389fad 100644
--- a/src/intelligence_layer/core/tracer/tracer.py
+++ b/src/intelligence_layer/core/tracer/tracer.py
@@ -78,9 +78,13 @@ class SpanStatus(Enum):
     ERROR = "ERROR"
 
 
+class Context(BaseModel):
+    trace_id: str
+    span_id: str
+
+
 class ExportedSpan:
-    id: str
-    # we ignore context as we only need the id from it
+    context: Context
     name: str | None
     parent_id: str | None
     start_time: datetime
@@ -102,13 +106,14 @@ class Tracer(ABC):
     documentation of each implementation to see how to use the resulting tracer.
     """
 
+    context: Context | None = None
+
     @abstractmethod
     def span(
         self,
         name: str,
         timestamp: Optional[datetime] = None,
-        trace_id: Optional[str] = None,
-    ) -> "Span":
+    ) -> "Span":  # TODO
         """Generate a span from the current span or logging instance.
 
         Allows for grouping multiple logs and duration together as a single, logical step in the
@@ -134,8 +139,7 @@ def task_span(
         task_name: str,
         input: PydanticSerializable,
         timestamp: Optional[datetime] = None,
-        trace_id: Optional[str] = None,
-    ) -> "TaskSpan":
+    ) -> "TaskSpan":  # TODO
         """Generate a task-specific span from the current span or logging instance.
 
         Allows for grouping multiple logs together, as well as the task's specific input, output,
@@ -167,6 +171,18 @@ def ensure_id(self, id: Optional[str]) -> str:
 
         return id if id is not None else str(uuid4())
 
+    @abstractmethod
+    def export_for_viewing(self) -> Sequence[ExportedSpan]:
+        """Converts the trace to a format that can be read by the trace viewer.
+
+        The format is inspired by the OpenTelemetry Format, but does not abide by it,
+        because it is too complex for our use-case.
+
+        Returns:
+            A list of spans which includes the current span and all its child spans.
+        """
+        ...
+
 
 class ErrorValue(BaseModel):
     error_type: str
@@ -183,9 +199,19 @@ class Span(Tracer, AbstractContextManager["Span"]):
     span only in scope while it is active.
     """
 
-    @abstractmethod
-    def id(self) -> str:
-        pass
+    def __init__(self, context: Optional[Context] = None):
+        """Assigns this span a context with a freshly generated span id.
+
+        Args:
+            context: Context of the tracer or span creating this span. If given, its
+                trace id is reused; if `None`, this span starts a trace of its own.
+        """
+        span_id = str(uuid4())
+        if context is None:
+            trace_id = span_id  # root span: its own id doubles as the trace id
+        else:
+            trace_id = context.trace_id  # read the argument, not the unset self.context
+        self.context = Context(trace_id=trace_id, span_id=span_id)
 
     def __enter__(self) -> Self:
         return self
@@ -246,18 +272,6 @@ def __exit__(
             self.log(error_value.message, error_value)
         self.end()
 
-    @abstractmethod
-    def export_for_viewing(self) -> Sequence[ExportedSpan]:
-        """Converts the span to a format that can be read by the trace viewer.
-
-        The format is inspired by the OpenTelemetry Format, but does not abide by it,
-        because it is too complex for our use-case.
-
-        Returns:
-            A list of spans which includes the current span and all its child spans.
-        """
-        ...
-
 
 class TaskSpan(Span):
     """Specialized span for instrumenting :class:`Task` input, output, and nested spans and logs.
diff --git a/tests/core/tracer/test_tracer.py b/tests/core/tracer/test_tracer.py
index 9f8e3d63c..b207acadd 100644
--- a/tests/core/tracer/test_tracer.py
+++ b/tests/core/tracer/test_tracer.py
@@ -54,6 +54,7 @@ def test_tracer_exports_task_spans_to_unified_format() -> None:
     assert span.attributes.input == "input"
     assert span.attributes.output == "output"
     assert span.status == SpanStatus.OK
+    assert span.context.trace_id == span.context.span_id
 
 
 def test_tracer_exports_error_correctly() -> None:
@@ -93,3 +94,21 @@ def test_tracer_export_nests_correctly() -> None:
     assert child.name == "name-2"
     assert child.parent_id == parent.id
     assert len(child.events) == 0
+    assert child.context.trace_id == parent.context.trace_id
+    assert child.context.span_id != parent.context.span_id
+
+
+def test_tracer_exports_unrelated_spans_correctly() -> None:
+    tracer = InMemoryTracer()
+    tracer.span("name")
+    tracer.span("name-2")
+
+    unified_format = tracer.export_for_viewing()
+
+    assert len(unified_format) == 2
+    span_1, span_2 = unified_format[0], unified_format[1]
+
+    assert span_1.parent_id is None
+    assert span_2.parent_id is None
+
+    assert span_1.context.trace_id != span_2.context.trace_id