created EvaluationLogicBase to fix type magic for async case
JohannesWesch committed May 7, 2024
1 parent a5dd176 commit f2f1426
Showing 7 changed files with 41 additions and 103 deletions.
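The core of the change: `ArgillaEvaluationLogic` previously subclassed `Generic[...]` directly, so the evaluator's runtime type introspection, which keyed on the `EvaluationLogic` ABC, could not resolve its type parameters. This commit introduces a plain generic base class, `EvaluationLogicBase`, shared by the synchronous ABC and the Argilla (async) logic alike. A minimal sketch of the resulting hierarchy, with method signatures simplified and bodies elided:

```python
from abc import ABC, abstractmethod
from typing import Generic, TypeVar

# The four type parameters used throughout the evaluation package.
Input = TypeVar("Input")
Output = TypeVar("Output")
ExpectedOutput = TypeVar("ExpectedOutput")
Evaluation = TypeVar("Evaluation")


class EvaluationLogicBase(Generic[Input, Output, ExpectedOutput, Evaluation]):
    # Deliberately empty: a shared, non-abstract anchor that runtime
    # type introspection can key on for both sync and async logic.
    pass


class EvaluationLogic(
    ABC, EvaluationLogicBase[Input, Output, ExpectedOutput, Evaluation]
):
    # Synchronous variant; the real do_evaluate also takes an Example and
    # SuccessfulExampleOutputs -- simplified here.
    @abstractmethod
    def do_evaluate(self) -> Evaluation: ...


class ArgillaEvaluationLogic(
    EvaluationLogicBase[Input, Output, ExpectedOutput, Evaluation], ABC
):
    # Async variant: converts examples to Argilla records on submission and
    # maps completed ArgillaEvaluations back to Evaluation objects later.
    ...
```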
@@ -2,7 +2,7 @@
from abc import ABC, abstractmethod
from datetime import datetime
from itertools import combinations
from typing import Generic, Mapping, Optional, Sequence
from typing import Mapping, Optional, Sequence

from pydantic import BaseModel

@@ -30,11 +30,14 @@
from intelligence_layer.evaluation.evaluation.evaluation_repository import (
EvaluationRepository,
)
from intelligence_layer.evaluation.evaluation.evaluator import EvaluationLogicBase
from intelligence_layer.evaluation.run.domain import SuccessfulExampleOutput
from intelligence_layer.evaluation.run.run_repository import RunRepository


class ArgillaEvaluationLogic(Generic[Input, Output, ExpectedOutput, Evaluation], ABC):
class ArgillaEvaluationLogic(
EvaluationLogicBase[Input, Output, ExpectedOutput, Evaluation], ABC
):
def __init__(self, fields: Mapping[str, Field], questions: Sequence[Question]):
self.fields = fields
self.questions = questions
@@ -156,7 +159,7 @@ def submit(
):
record_sequence = self._evaluation_logic._to_record(example, *outputs)
for record in record_sequence.records:
self._client.add_record(self._workspace_id, record)
self._client.add_record(argilla_dataset_id, record)

return PartialEvaluationOverview(
run_overviews=frozenset(run_overviews),
17 changes: 2 additions & 15 deletions src/intelligence_layer/evaluation/evaluation/async_evaluation.py
@@ -1,15 +1,14 @@
from abc import ABC, abstractmethod
from typing import Generic, Optional
from typing import Optional

from intelligence_layer.core.task import Input, Output
from intelligence_layer.evaluation.dataset.domain import Example, ExpectedOutput
from intelligence_layer.evaluation.dataset.domain import ExpectedOutput
from intelligence_layer.evaluation.evaluation.domain import (
Evaluation,
EvaluationOverview,
PartialEvaluationOverview,
)
from intelligence_layer.evaluation.evaluation.evaluator import Evaluator
from intelligence_layer.evaluation.run.domain import SuccessfulExampleOutput


class AsyncEvaluator(Evaluator[Input, Output, ExpectedOutput, Evaluation], ABC):
@@ -23,15 +22,3 @@ def submit(

@abstractmethod
def retrieve(self, id: str) -> EvaluationOverview: ...


class AsyncEvaluationLogic(ABC, Generic[Input, Output, ExpectedOutput, Evaluation]):
@abstractmethod
def submit(
self,
example: Example[Input, ExpectedOutput],
*output: SuccessfulExampleOutput[Output],
) -> None: ...

@abstractmethod
def retrieve(self, eval_id: str) -> EvaluationOverview: ...
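With the duplicated `AsyncEvaluationLogic` class deleted, the async contract lives solely in `AsyncEvaluator`: `submit` hands a run over for external evaluation and returns immediately, while `retrieve` later fetches the finished `EvaluationOverview` by id. A runnable toy of that split (the real classes are generic and return `PartialEvaluationOverview` / `EvaluationOverview`; plain strings stand in here):

```python
from abc import ABC, abstractmethod


class ToyAsyncEvaluator(ABC):
    @abstractmethod
    def submit(self, run_id: str) -> str: ...

    @abstractmethod
    def retrieve(self, id: str) -> str: ...


class InMemoryAsyncEvaluator(ToyAsyncEvaluator):
    def __init__(self) -> None:
        self._pending: dict[str, str] = {}

    def submit(self, run_id: str) -> str:
        # Non-blocking: hand the run over (e.g. to Argilla) and remember it.
        eval_id = f"eval-{run_id}"
        self._pending[eval_id] = run_id
        return eval_id

    def retrieve(self, id: str) -> str:
        # Later: collect the finished evaluation for a submitted run.
        return f"overview-for-{self._pending[id]}"


evaluator = InMemoryAsyncEvaluator()
eval_id = evaluator.submit("run-1")
assert evaluator.retrieve(eval_id) == "overview-for-run-1"
```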
14 changes: 10 additions & 4 deletions src/intelligence_layer/evaluation/evaluation/evaluator.py
@@ -42,7 +42,13 @@
from intelligence_layer.evaluation.run.run_repository import RunRepository


class EvaluationLogic(ABC, Generic[Input, Output, ExpectedOutput, Evaluation]):
class EvaluationLogicBase(Generic[Input, Output, ExpectedOutput, Evaluation]):
pass


class EvaluationLogic(
ABC, EvaluationLogicBase[Input, Output, ExpectedOutput, Evaluation]
):
@abstractmethod
def do_evaluate(
self,
@@ -130,7 +136,7 @@ def _get_types(self) -> Mapping[str, type]:

def is_eligible_subclass(parent: type) -> bool:
return hasattr(parent, "__orig_bases__") and issubclass(
parent, EvaluationLogic
parent, EvaluationLogicBase
)

def update_types() -> None:
@@ -148,7 +154,7 @@ def update_types() -> None:
num_types_set += 1

# mypy does not know __orig_bases__
base_types = EvaluationLogic.__orig_bases__[1] # type: ignore
base_types = EvaluationLogicBase.__orig_bases__[0] # type: ignore
type_list: list[type | TypeVar] = list(get_args(base_types))
possible_parent_classes = [
p
@@ -159,7 +165,7 @@
# mypy does not know __orig_bases__
for base in parent.__orig_bases__: # type: ignore
origin = get_origin(base)
if origin is None or not issubclass(origin, EvaluationLogic):
if origin is None or not issubclass(origin, EvaluationLogicBase):
continue
current_types = list(get_args(base))
update_types()
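Why keying the introspection on `EvaluationLogicBase` works: Python records subscripted generic bases on `__orig_bases__`, so the evaluator can pair the `TypeVar`s declared on the base with the concrete types a subclass binds to them. A standalone sketch of the mechanism (shortened names; this mirrors the idea behind `_get_types`, not its exact code):

```python
from typing import Generic, TypeVar, get_args, get_origin

In_ = TypeVar("In_")
Out = TypeVar("Out")


class LogicBase(Generic[In_, Out]):  # stands in for EvaluationLogicBase
    pass


class MyLogic(LogicBase[str, int]):  # a concrete evaluation logic
    pass


# On the plain base, __orig_bases__ is just (Generic[In_, Out],), hence index 0.
# (EvaluationLogic also inherited ABC, which is why the old code used index 1.)
type_vars = get_args(LogicBase.__orig_bases__[0])  # (In_, Out)

# On a subclass, the same attribute carries the parameterized base...
(base,) = MyLogic.__orig_bases__                   # (LogicBase[str, int],)
assert get_origin(base) is LogicBase               # the issubclass check in the diff
concrete_types = get_args(base)                    # (str, int)

# ...so the TypeVars can be zipped with their concrete bindings.
print(dict(zip(type_vars, concrete_types)))        # In_ -> str, Out -> int
```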
19 changes: 6 additions & 13 deletions tests/conftest.py
@@ -4,7 +4,6 @@

from aleph_alpha_client import Client, Image
from dotenv import load_dotenv
from faker import Faker
from pydantic import BaseModel
from pytest import fixture

@@ -110,26 +109,20 @@ def to_document(document_chunk: DocumentChunk) -> Document:


class DummyStringInput(BaseModel):
input: str

@classmethod
def any(cls) -> "DummyStringInput":
fake = Faker()
return cls(input=fake.text())
input: str = "dummy-input"


class DummyStringOutput(BaseModel):
output: str
output: str = "dummy-output"


@classmethod
def any(cls) -> "DummyStringOutput":
fake = Faker()
return cls(output=fake.text())
class DummyStringEvaluation(BaseModel):
evaluation: str = "dummy-evaluation"


class DummyStringTask(Task[DummyStringInput, DummyStringOutput]):
def do_run(self, input: DummyStringInput, task_span: TaskSpan) -> DummyStringOutput:
return DummyStringOutput.any()
return DummyStringOutput()


@fixture
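The `Faker`-backed `.any()` constructors give way to plain pydantic field defaults, which drops the `faker` dependency from these fixtures and makes them deterministic. A small sketch of the resulting behaviour, assuming pydantic's standard default handling:

```python
from pydantic import BaseModel


class DummyStringInput(BaseModel):
    input: str = "dummy-input"  # fixed default instead of Faker-generated text


# Every construction now yields the same, comparable object.
assert DummyStringInput() == DummyStringInput()
assert DummyStringInput(input="override").input == "override"
```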
4 changes: 1 addition & 3 deletions tests/evaluation/conftest.py
@@ -165,9 +165,7 @@ def aggregation_overview(

@fixture
def dummy_string_example() -> Example[DummyStringInput, DummyStringOutput]:
return Example(
input=DummyStringInput.any(), expected_output=DummyStringOutput.any()
)
return Example(input=DummyStringInput(), expected_output=DummyStringOutput())


@fixture
77 changes: 14 additions & 63 deletions tests/evaluation/test_argilla_evaluator.py
@@ -22,7 +22,12 @@
Runner,
SuccessfulExampleOutput,
)
from tests.conftest import DummyStringInput, DummyStringOutput, DummyStringTask
from tests.conftest import (
DummyStringEvaluation,
DummyStringInput,
DummyStringOutput,
DummyStringTask,
)
from tests.evaluation.conftest import DummyAggregatedEvaluation, StubArgillaClient


@@ -50,7 +55,7 @@ class DummyStringTaskArgillaEvaluationLogic(
DummyStringInput,
DummyStringOutput,
DummyStringOutput,
DummyStringOutput,
DummyStringEvaluation,
]
):
def __init__(self) -> None:
@@ -85,8 +90,10 @@ def _to_record(
]
)

def _from_record(self, argilla_evaluation: ArgillaEvaluation) -> DummyStringOutput:
return DummyStringOutput(output="test")
def _from_record(
self, argilla_evaluation: ArgillaEvaluation
) -> DummyStringEvaluation:
return DummyStringEvaluation()


class DummyArgillaClient(ArgillaClient):
@@ -105,7 +112,7 @@ def ensure_dataset_exists(
return dataset_id

def add_record(self, dataset_id: str, record: RecordData) -> None:
if dataset_id not in self._datasets:
if dataset_id not in self._datasets.keys():
raise Exception("Add record: dataset not found")
self._datasets[dataset_id].append(record)

@@ -182,7 +189,7 @@ def string_argilla_evaluator(
DummyStringInput,
DummyStringOutput,
DummyStringOutput,
DummyStringOutput,
DummyStringEvaluation,
]:
evaluator = ArgillaEvaluator(
in_memory_dataset_repository,
@@ -242,7 +249,7 @@ def test_argilla_evaluator_can_submit_evals_to_argilla(
"dummy-string-task",
DummyStringTaskArgillaEvaluationLogic(),
DummyArgillaClient(),
workspace_id="1",
workspace_id="workspace-id",
)

run_overview = string_argilla_runner.run_dataset(string_dataset_id)
@@ -268,62 +275,6 @@
assert len(DummyArgillaClient()._datasets[partial_evaluation_overview.id]) == 1


def test_argilla_evaluator_can_do_sync_evaluation(
string_argilla_evaluator: ArgillaEvaluator[
DummyStringInput,
DummyStringOutput,
DummyStringOutput,
DummyStringOutput,
],
string_argilla_runner: Runner[DummyStringInput, DummyStringOutput],
string_dataset_id: str,
) -> None:
argilla_client = cast(
StubArgillaClient,
string_argilla_evaluator._evaluation_repository._client, # type: ignore
)

run_overview = string_argilla_runner.run_dataset(string_dataset_id)
eval_overview = string_argilla_evaluator.evaluate_runs(run_overview.id)
examples_iter = string_argilla_evaluator._dataset_repository.examples(
string_dataset_id, DummyStringInput, DummyStringOutput
)
assert examples_iter is not None

assert eval_overview.id in argilla_client._datasets
saved_dataset = argilla_client._datasets[eval_overview.id]
examples = list(examples_iter)
assert len(saved_dataset) == len(examples)
assert saved_dataset[0].example_id == examples[0].id
assert saved_dataset[0].content["input"] == examples[0].input.input


def test_argilla_evaluator_can_aggregate_evaluation(
string_argilla_evaluator: ArgillaEvaluator[
DummyStringInput,
DummyStringOutput,
DummyStringOutput,
DummyStringOutput,
],
string_argilla_runner: Runner[DummyStringInput, DummyStringOutput],
string_dataset_id: str,
string_argilla_aggregator: ArgillaAggregator[DummyAggregatedEvaluation],
) -> None:
# given
argilla_client = cast(
StubArgillaClient,
string_argilla_evaluator._evaluation_repository._client, # type: ignore
)
# when
run_overview = string_argilla_runner.run_dataset(string_dataset_id)
eval_overview = string_argilla_evaluator.evaluate_runs(run_overview.id)
aggregated_eval_overview = string_argilla_aggregator.aggregate_evaluation(
eval_overview.id
)
# then
assert aggregated_eval_overview.statistics.score == argilla_client._score


def test_argilla_aggregation_logic_works() -> None:
argilla_aggregation_logic = InstructComparisonArgillaAggregationLogic()
evaluations = (
4 changes: 2 additions & 2 deletions tests/evaluation/test_dataset_repository.py
@@ -243,8 +243,8 @@ def test_examples_returns_all_examples_sorted_by_their_id(
dataset_repository: DatasetRepository = request.getfixturevalue(repository_fixture)
examples = [
Example(
input=DummyStringInput.any(),
expected_output=DummyStringOutput.any(),
input=DummyStringInput(),
expected_output=DummyStringOutput(),
)
for i in range(0, 10)
]
