Fix issues in test_evaluator.py
NickyHavoc committed Nov 23, 2023
1 parent 0d4bb68 · commit 923f03c
Showing 3 changed files with 40 additions and 43 deletions.
5 changes: 2 additions & 3 deletions src/intelligence_layer/core/__init__.py
@@ -17,11 +17,10 @@
from .evaluator import Dataset as Dataset
from .evaluator import EvaluationException as EvaluationException
from .evaluator import EvaluationRepository as EvaluationRepository
from .evaluator import Evaluator as Evaluator
from .evaluator import ExampleResult as ExampleResult
from .evaluator import EvaluationException as EvaluationException
from .evaluator import EvaluationRunOverview as EvaluationRunOverview
from .evaluator import Evaluator as Evaluator
from .evaluator import Example as Example
from .evaluator import ExampleResult as ExampleResult
from .evaluator import InMemoryEvaluationRepository as InMemoryEvaluationRepository
from .evaluator import LogTrace as LogTrace
from .evaluator import SequenceDataset as SequenceDataset
40 changes: 21 additions & 19 deletions src/intelligence_layer/core/evaluator.py
@@ -73,7 +73,7 @@ def examples(self) -> Iterable[Example[Input, ExpectedOutput]]:

class SequenceDataset(BaseModel, Generic[Input, ExpectedOutput]):
"""A :class:`Dataset` that contains all examples in a sequence.
We recommend using this when it is certain that all examples
fit in memory.
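
For context — not part of this commit — a minimal sketch of constructing such a dataset, modeled on the tests further down; the dataset name "greetings" and the example inputs are invented:

    from intelligence_layer.core import Example, SequenceDataset

    # All examples live in a plain in-memory sequence, so this only suits
    # datasets that comfortably fit in memory.
    dataset = SequenceDataset(
        name="greetings",
        examples=[
            Example(input="Hello", expected_output=None),
            Example(input="Goodbye", expected_output=None),
        ],
    )
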
@@ -88,11 +88,11 @@ class SequenceDataset(BaseModel, Generic[Input, ExpectedOutput]):

class EvaluationException(BaseModel):
"""Captures an exception raised during evaluating a :class:`Task`.
Attributes:
error_message: String-representation of the exception.
"""

error_message: str
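
For context — not part of this commit — EvaluationException is an ordinary pydantic model, so a failure can be captured and stored in place of an evaluation result instead of aborting the whole run (the tests below assert exactly that). A minimal sketch:

    from intelligence_layer.core import EvaluationException

    # Sketch: wrap a caught error into the model that gets persisted
    # for the failing example.
    try:
        raise RuntimeError("fail in eval")
    except Exception as error:
        failure = EvaluationException(error_message=str(error))

    assert failure.error_message == "fail in eval"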


@@ -101,12 +101,13 @@ class EvaluationException(BaseModel):

class SpanTrace(BaseModel):
"""Represents traces contained by :class:`Span`
Attributes:
traces: The child traces.
start: Start time of the span.
end: End time of the span.
"""

model_config = ConfigDict(frozen=True)

traces: Sequence[Trace]
@@ -135,6 +136,7 @@ class TaskSpanTrace(SpanTrace):
input: Input from the traced :class:`Task`.
output: Output of the traced :class:`Task`.
"""

model_config = ConfigDict(frozen=True)

input: SerializeAsAny[JsonSerializable]
@@ -169,7 +171,7 @@ def _ipython_display_(self) -> None:

class LogTrace(BaseModel):
"""Represents a :class:`LogEntry`.
Attributes:
message: A description of the value that is being logged, such as the step in the
:class:`Task` this is related to.
@@ -213,9 +215,9 @@ def _to_trace_entry(entry: InMemoryTaskSpan | InMemorySpan | LogEntry) -> Trace:

class ExampleResult(BaseModel, Generic[Evaluation]):
"""Result of a single evaluated :class:`Example`
Created to persist the evaluation result in the repository.
Attributes:
example_id: Identifier of the :class:`Example` evaluated.
result: If the evaluation was successful, evaluation's result,
@@ -231,12 +233,12 @@ class ExampleResult(BaseModel, Generic[Evaluation]):

class EvaluationRunOverview(BaseModel, Generic[AggregatedEvaluation]):
"""Overview of the results of evaluating a :class:`Task` on a :class:`Dataset`.
Created when running :func:`Evaluator.evaluate_dataset`. Contains high-level information and statistics.
Attributes:
id: Identifier of the run.
statistics: Aggregated statistics of the run.
statistics: Aggregated statistics of the run.
"""

id: str
@@ -250,7 +252,7 @@ class EvaluationRunOverview(BaseModel, Generic[AggregatedEvaluation]):

class EvaluationRepository(ABC):
"""Base evaluation repository interface.
Provides methods to store and load evaluation results for individual examples
of a run and the aggregated evaluation of said run.
"""
@@ -260,12 +262,12 @@ def evaluation_run_results(
self, run_id: str, evaluation_type: type[Evaluation]
) -> Sequence[ExampleResult[Evaluation]]:
"""Returns all :class:`ExampleResult` instances of a given run
Args:
run_id: Identifier of the run to obtain the results for.
evaluation_type: Type of evaluations that the :class:`Evaluator` returned
in :func:`Evaluator.do_evaluate`
Returns:
All :class:`ExampleResult` of the run. Will return an empty list if there's none.
"""
@@ -276,13 +278,13 @@ def evaluation_example_result(
self, run_id: str, example_id: str, evaluation_type: type[Evaluation]
) -> Optional[ExampleResult[Evaluation]]:
"""Returns an :class:`ExampleResult` of a given run by its id.
Args:
run_id: Identifier of the run to obtain the results for.
example_id: Identifier of the :class:`ExampleResult` to be retrieved.
evaluation_type: Type of evaluations that the `Evaluator` returned
in :func:`Evaluator.do_evaluate`
Returns:
:class:`ExampleResult` if one was found, `None` otherwise.
"""
@@ -293,7 +295,7 @@ def store_example_result(
self, run_id: str, result: ExampleResult[Evaluation]
) -> None:
"""Stores an :class:`ExampleResult` for a run in the repository.
Args:
run_id: Identifier of the run.
result: The result to be persisted.
@@ -305,12 +307,12 @@ def evaluation_run_overview(
self, run_id: str, aggregation_type: type[AggregatedEvaluation]
) -> Optional[EvaluationRunOverview[AggregatedEvaluation]]:
"""Returns an :class:`EvaluationRunOverview` of a given run by its id.
Args:
run_id: Identifier of the run to obtain the overview for.
aggregation_type: Type of aggregations that the :class:`Evaluator` returned
in :func:`Evaluator.aggregate`
Returns:
:class:`EvaluationRunOverview` if one was found, `None` otherwise.
"""
@@ -321,7 +323,7 @@ def store_evaluation_run_overview(
self, overview: EvaluationRunOverview[AggregatedEvaluation]
) -> None:
"""Stores an :class:`EvaluationRunOverview` in the repository.
Args:
overview: The overview to be persisted.
"""
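
Not part of the commit, but useful orientation for the interface above: a hedged sketch of a dict-backed EvaluationRepository. The class name DictEvaluationRepository and its attribute names are invented, only the five methods visible in this diff are implemented (the real base class may declare more), and in practice the InMemoryEvaluationRepository re-exported from core/__init__.py is the implementation the tests use.

    from collections import defaultdict
    from typing import Optional, Sequence

    from intelligence_layer.core import (
        EvaluationRepository,
        EvaluationRunOverview,
        ExampleResult,
    )


    class DictEvaluationRepository(EvaluationRepository):
        """Toy, dict-backed store; per-run results keyed by example id."""

        def __init__(self) -> None:
            self._results: dict[str, dict[str, ExampleResult]] = defaultdict(dict)
            self._overviews: dict[str, EvaluationRunOverview] = {}

        def evaluation_run_results(
            self, run_id: str, evaluation_type: type
        ) -> Sequence[ExampleResult]:
            # Unknown run ids simply yield an empty list, as the docstring above promises.
            return list(self._results[run_id].values())

        def evaluation_example_result(
            self, run_id: str, example_id: str, evaluation_type: type
        ) -> Optional[ExampleResult]:
            return self._results[run_id].get(example_id)

        def store_example_result(self, run_id: str, result: ExampleResult) -> None:
            self._results[run_id][result.example_id] = result

        def evaluation_run_overview(
            self, run_id: str, aggregation_type: type
        ) -> Optional[EvaluationRunOverview]:
            return self._overviews.get(run_id)

        def store_evaluation_run_overview(
            self, overview: EvaluationRunOverview
        ) -> None:
            self._overviews[overview.id] = overview
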
38 changes: 17 additions & 21 deletions tests/core/test_evaluator.py
@@ -1,11 +1,10 @@
from datetime import datetime
from typing import Iterable, Literal, Optional, Sequence
from typing import Iterable, Literal, Sequence

from pydantic import BaseModel
from pytest import fixture

from intelligence_layer.core import (
Dataset,
EvaluationException,
Evaluator,
Example,
@@ -19,7 +18,6 @@
from intelligence_layer.core.task import Task
from intelligence_layer.core.tracer import InMemorySpan, InMemoryTaskSpan, LogEntry


DummyTaskInput = Literal["success", "fail in task", "fail in eval"]
DummyTaskOutput = DummyTaskInput

@@ -34,15 +32,19 @@ class AggregatedDummyEvaluation(BaseModel):

class DummyEvaluator(
Evaluator[
DummyTaskInput, DummyTaskOutput, None, DummyEvaluation, AggregatedDummyEvaluation
DummyTaskInput,
DummyTaskOutput,
None,
DummyEvaluation,
AggregatedDummyEvaluation,
]
):
def do_evaluate(
self, input: DummyTaskInput, output: DummyTaskOutput, expected_output: None
) -> DummyEvaluation:
if output == "fail in eval":
raise RuntimeError(output)
return DummyEvaluation(result="pass")
return DummyEvaluation(result="pass")

def aggregate(
self, evaluations: Iterable[DummyEvaluation]
@@ -51,10 +53,10 @@ def aggregate(


class DummyTask(Task[DummyTaskInput, DummyTaskOutput]):
def do_run(self, input: str | None, tracer: Tracer) -> str | None:
def do_run(self, input: DummyTaskInput, tracer: Tracer) -> DummyTaskOutput:
if input == "fail in task":
raise RuntimeError(input)
return input
return input


@fixture
@@ -69,16 +71,6 @@ def dummy_evaluator(
return DummyEvaluator(DummyTask(), evaluation_repository)


def test_evaluate_dataset_does_not_throw_an_exception_for_failure(
dummy_evaluator: DummyEvaluator,
) -> None:
dataset: Dataset[Optional[str], None] = SequenceDataset(
name="test",
examples=[Example(input="fail", expected_output=None)],
)
dummy_evaluator.evaluate_dataset(dataset, NoOpTracer())


def test_evaluate_dataset_stores_example_results(
dummy_evaluator: DummyEvaluator,
) -> None:
@@ -89,7 +81,7 @@ def test_evaluate_dataset_stores_example_results(
Example(input="fail in eval", expected_output=None),
]

dataset: SequenceDataset[str | None, None] = SequenceDataset(
dataset: SequenceDataset[DummyTaskInput, None] = SequenceDataset(
name="test",
examples=examples,
)
@@ -106,8 +98,12 @@
)

assert success_result and isinstance(success_result.result, DummyEvaluation)
assert failure_result_task and isinstance(failure_result_task.result, EvaluationException)
assert failure_result_eval and isinstance(failure_result_eval.result, EvaluationException)
assert failure_result_task and isinstance(
failure_result_task.result, EvaluationException
)
assert failure_result_eval and isinstance(
failure_result_eval.result, EvaluationException
)
assert success_result.trace.input == "success"
assert failure_result_task.trace.input == "fail in task"
assert failure_result_eval.trace.input == "fail in eval"
@@ -118,7 +114,7 @@ def test_evaluate_dataset_stores_aggregated_results(
) -> None:
evaluation_repository = dummy_evaluator.repository

dataset: SequenceDataset[str | None, None] = SequenceDataset(
dataset: SequenceDataset[DummyTaskInput, None] = SequenceDataset(
name="test",
examples=[],
)
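
Not part of the commit: a hedged sketch of driving the dummy classes above outside pytest. It reuses DummyEvaluator, DummyTask and DummyTaskInput from this test module, assumes NoOpTracer is importable from intelligence_layer.core (as the removed test did), and the dataset name "demo" is invented.

    from intelligence_layer.core import (
        Example,
        InMemoryEvaluationRepository,
        NoOpTracer,
        SequenceDataset,
    )

    # Reuses DummyEvaluator, DummyTask and DummyTaskInput defined in the test module above.
    repository = InMemoryEvaluationRepository()
    evaluator = DummyEvaluator(DummyTask(), repository)

    dataset: SequenceDataset[DummyTaskInput, None] = SequenceDataset(
        name="demo",
        examples=[
            Example(input="success", expected_output=None),
            Example(input="fail in task", expected_output=None),
            Example(input="fail in eval", expected_output=None),
        ],
    )

    # Failures are not raised here; each failing example is stored as an
    # EvaluationException result in the repository, next to the successful one.
    evaluator.evaluate_dataset(dataset, NoOpTracer())

The per-example results and the run's aggregated overview then live in `repository`, which is what the remaining tests assert.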
