Skip to content

Commit

Permalink
Rename trace names
Browse files Browse the repository at this point in the history
  • Loading branch information
volkerstampa committed Nov 23, 2023
1 parent 19fa376 commit d44fb02
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 24 deletions.
4 changes: 2 additions & 2 deletions src/intelligence_layer/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
from .evaluator import Evaluator as Evaluator
from .evaluator import Example as Example
from .evaluator import InMemoryEvaluationRepository as InMemoryEvaluationRepository
from .evaluator import LogTrace as LogTrace
from .evaluator import SequenceDataset as SequenceDataset
from .evaluator import SpanTrace as SpanTrace
from .evaluator import TaskTrace as TaskTrace
from .evaluator import TraceLog as TraceLog
from .evaluator import TaskSpanTrace as TaskSpanTrace
from .explain import Explain, ExplainInput, ExplainOutput
from .graders import BleuGrader, RougeScores
from .prompt_template import (
Expand Down
61 changes: 44 additions & 17 deletions src/intelligence_layer/core/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,13 @@ class EvaluationException(BaseModel):
error_message: str


TraceEntry = Union["TaskTrace", "SpanTrace", "TraceLog"]
Trace = Union["TaskSpanTrace", "SpanTrace", "LogTrace"]


class SpanTrace(BaseModel):
model_config = ConfigDict(frozen=True)

traces: Sequence[TraceEntry]
traces: Sequence[Trace]
start: datetime
end: Optional[datetime]

Expand All @@ -105,15 +105,15 @@ def _rich_render_(self) -> Tree:
return tree


class TaskTrace(SpanTrace):
class TaskSpanTrace(SpanTrace):
model_config = ConfigDict(frozen=True)

input: SerializeAsAny[JsonSerializable]
output: SerializeAsAny[JsonSerializable]

@staticmethod
def from_task_span(task_span: InMemoryTaskSpan) -> "TaskTrace":
return TaskTrace(
def from_task_span(task_span: InMemoryTaskSpan) -> "TaskSpanTrace":
return TaskSpanTrace(
traces=[to_trace_entry(t) for t in task_span.entries],
start=task_span.start_timestamp,
end=task_span.end_timestamp,
Expand All @@ -138,15 +138,15 @@ def _ipython_display_(self) -> None:
print(self._rich_render_())


class TraceLog(BaseModel):
class LogTrace(BaseModel):
model_config = ConfigDict(frozen=True)

message: str
value: SerializeAsAny[JsonSerializable]

@staticmethod
def from_log_entry(entry: LogEntry) -> "TraceLog":
return TraceLog(
def from_log_entry(entry: LogEntry) -> "LogTrace":
return LogTrace(
message=entry.message,
# RootModel.model_dump is declared to return the type of root, but actually returns
# a JSON-like structure that fits to the JsonSerializable type
Expand All @@ -164,19 +164,29 @@ def _render_log_value(value: JsonSerializable, title: str) -> Panel:
)


def to_trace_entry(entry: InMemoryTaskSpan | InMemorySpan | LogEntry) -> TraceEntry:
def to_trace_entry(entry: InMemoryTaskSpan | InMemorySpan | LogEntry) -> Trace:
if isinstance(entry, InMemoryTaskSpan):
return TaskTrace.from_task_span(entry)
return TaskSpanTrace.from_task_span(entry)
elif isinstance(entry, InMemorySpan):
return SpanTrace.from_span(entry)
else:
return TraceLog.from_log_entry(entry)
return LogTrace.from_log_entry(entry)


class ExampleResult(BaseModel, Generic[Evaluation]):
example_id: str
result: SerializeAsAny[Evaluation | EvaluationException]
trace: TaskTrace
trace: TaskSpanTrace


class EvaluationRunOverview(BaseModel, Generic[AggregatedEvaluation]):
id: str
# dataset_id: str
# failed_evaluation_count: int
# successful_evaluation_count: int
# start: datetime
# end: datetime
statistics: SerializeAsAny[AggregatedEvaluation]


class EvaluationRepository(ABC):
Expand All @@ -198,13 +208,25 @@ def store_example_result(
) -> None:
...

@abstractmethod
def store_evaluation_run_overview(
self, overview: EvaluationRunOverview[AggregatedEvaluation]
) -> None:
...

@abstractmethod
def evaluation_run_overview(
self, run_id: str, aggregation_type: type[AggregatedEvaluation]
) -> Optional[EvaluationRunOverview[AggregatedEvaluation]]:
...


class InMemoryEvaluationRepository(EvaluationRepository):
class SerializedExampleResult(BaseModel):
example_id: str
is_exception: bool
json_result: str
trace: TaskTrace
trace: TaskSpanTrace

_example_results: dict[str, list[str]] = defaultdict(list)

Expand Down Expand Up @@ -263,10 +285,15 @@ def store_example_result(
)
self._example_results[run_id].append(json_result.model_dump_json())

def store_evaluation_run_overview(
self, overview: EvaluationRunOverview[AggregatedEvaluation]
) -> None:
pass

class EvaluationRunOverview(BaseModel, Generic[AggregatedEvaluation]):
id: str
statistics: SerializeAsAny[AggregatedEvaluation]
def evaluation_run_overview(
self, run_id: str, aggregation_type: type[AggregatedEvaluation]
) -> EvaluationRunOverview[AggregatedEvaluation] | None:
return None


class Evaluator(
Expand Down Expand Up @@ -322,7 +349,7 @@ def _evaluate(
example_result = ExampleResult(
example_id=example.id,
result=result,
trace=TaskTrace.from_task_span(
trace=TaskSpanTrace.from_task_span(
cast(InMemoryTaskSpan, eval_tracer.entries[0])
),
)
Expand Down
20 changes: 15 additions & 5 deletions tests/core/test_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
InMemoryEvaluationRepository,
NoOpTracer,
SequenceDataset,
TaskTrace,
TaskSpanTrace,
Tracer,
)
from intelligence_layer.core.evaluator import SpanTrace, TraceLog, to_trace_entry
from intelligence_layer.core.evaluator import LogTrace, SpanTrace, to_trace_entry
from intelligence_layer.core.task import Task
from intelligence_layer.core.tracer import InMemorySpan, InMemoryTaskSpan, LogEntry

Expand Down Expand Up @@ -109,24 +109,34 @@ def test_to_trace_entry() -> None:
)
)

assert entry == TaskTrace(
assert entry == TaskSpanTrace(
input="input",
output="output",
start=now,
end=now,
traces=[
TraceLog(message="message", value="value"),
LogTrace(message="message", value="value"),
SpanTrace(traces=[], start=now, end=now),
],
)


def test_deserialize_task_trace() -> None:
trace = TaskTrace(
trace = TaskSpanTrace(
start=datetime.utcnow(),
end=datetime.utcnow(),
traces=[],
input=[{"a": "b"}],
output=["c"],
)
assert trace.model_validate_json(trace.model_dump_json()) == trace


# TODO

# - check CI problem
# - add documentation to evaluator
# - store aggregation result
# - file-based repo
# - refactor _rich_render_ (reuse in tracer and evaluator?)
# - introduce MappingTask (to remove redundancy in ClassifyEvaluators)

0 comments on commit d44fb02

Please sign in to comment.