Skip to content

Commit

Permalink
refactored argilla logic to only have to_record and from_record
Browse files Browse the repository at this point in the history
  • Loading branch information
JohannesWesch committed May 6, 2024
1 parent 85663cb commit 163ed02
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 21 deletions.
52 changes: 32 additions & 20 deletions src/intelligence_layer/evaluation/evaluation/argilla_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from abc import ABC, abstractmethod
from datetime import datetime
from itertools import combinations
from typing import Mapping, Optional, Sequence
from typing import Mapping, Optional, Sequence, cast

from intelligence_layer.connectors.argilla.argilla_client import (
ArgillaClient,
Expand All @@ -18,22 +18,29 @@
ArgillaEvaluationRepository,
RecordDataSequence,
)
from intelligence_layer.evaluation.evaluation.async_evaluation import AsyncEvaluator
from intelligence_layer.evaluation.evaluation.async_evaluation import (
AsyncEvaluationLogic,
AsyncEvaluator,
)
from intelligence_layer.evaluation.evaluation.domain import (
Evaluation,
EvaluationOverview,
ExampleEvaluation,
PartialEvaluationOverview,
)
from intelligence_layer.evaluation.evaluation.evaluation_repository import (
EvaluationRepository,
)
from intelligence_layer.evaluation.evaluation.evaluator import EvaluationLogic
from intelligence_layer.evaluation.run.domain import SuccessfulExampleOutput
from intelligence_layer.evaluation.run.run_repository import RunRepository


class ArgillaEvaluationLogic(
EvaluationLogic[Input, Output, ExpectedOutput, RecordDataSequence], ABC
AsyncEvaluationLogic[Input, Output, ExpectedOutput, Evaluation], ABC
):
# Keeps a reference to the given Argilla client on the instance as `_client`.
def __init__(self, client: ArgillaClient) -> None:
self._client = client

# Returns the Argilla field definitions used by this logic. This
# implementation returns a single Field(name="name", title="title").
# NOTE(review): the values look like placeholders; presumably concrete
# logic classes override this -- confirm.
def fields(self) -> Sequence[Field]:
return [Field(name="name", title="title")]

Expand All @@ -42,14 +49,6 @@ def questions(self) -> Sequence[Question]:
Question(name="name", title="title", description="description", options=[0])
]

# Evaluates one example by delegating to `_to_record` and returning the
# resulting RecordDataSequence unchanged.
def do_evaluate(
self,
example: Example[Input, ExpectedOutput],
*output: SuccessfulExampleOutput[Output],
) -> RecordDataSequence:
# NOTE: this should probably hold download logic rather than to-record.
return self._to_record(example, *output)

@abstractmethod
def _to_record(
self,
Expand All @@ -63,7 +62,9 @@ def _to_record(
example: The example to be translated.
output: The output of the example that was run.
"""
...
...

def _from_record(self, argilla_evaluation: ArgillaEvaluation) -> Evaluation:
    """Translate an evaluation retrieved from Argilla into an `Evaluation`.

    This is the counterpart of `_to_record`. The method takes `self` so that
    it can be invoked on an instance (`obj._from_record(record)`); without
    it, an instance call fails with a TypeError because the instance itself
    is passed as the first positional argument.

    Args:
        argilla_evaluation: The evaluation as returned by Argilla.

    Returns:
        The evaluation derived from `argilla_evaluation`. The body here is
        only a stub (it evaluates `...` and returns None); presumably
        concrete logic classes override it -- confirm.
    """
    ...


class ArgillaEvaluator(
Expand Down Expand Up @@ -111,8 +112,21 @@ def __init__(

def retrieve(
self,
id: str,
evaluation_id: str,
) -> EvaluationOverview:
example_evaluations = [
ExampleEvaluation(
evaluation_id=evaluation_id,
example_id=example_evaluation.example_id,
# cast to Evaluation because mypy thinks ArgillaEvaluation cannot be Evaluation
result=self._from_record(example_evaluation),
)
for example_evaluation in self._client.evaluations(evaluation_id)
]
evaluations = sorted(example_evaluations, key=lambda i: i.example_id)
for evaluation in evaluations:
self._evaluation_repository.store_example_evaluation(evaluation)

return EvaluationOverview(
run_overviews=frozenset(),
id=id,
Expand Down Expand Up @@ -140,15 +154,17 @@ def submit(
questions=self._evaluation_logic.questions(),
)

run_overviews = self._load_run_overviews(*run_ids)
for example, outputs in self.retrieve_eval_logic_input(
*run_ids, num_examples=num_examples
run_overviews, num_examples=num_examples
):
self._evaluation_logic._to_record(example, outputs)
record_sequence = self._evaluation_logic.submit(example, outputs)
for record in record_sequence:
self._client.add_record(self._workspace_id, record)

return PartialEvaluationOverview(
run_overviews=[],
run_overviews=frozenset(run_overviews),
id=argilla_dataset_id,
start_date=datetime.now(),
description=self.description,
Expand Down Expand Up @@ -212,10 +228,6 @@ def _create_record_data(
},
)

# No-op stub: returns immediately without doing anything.
def _do_submit(self) -> None:
# "Here" (translated placeholder).
# NOTE(review): presumably marks where submit logic was meant to go -- confirm.
return


def create_instruct_comparison_argilla_evaluation_classes(
workspace_id: str,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,5 @@ def submit(
) -> None: ...

@abstractmethod
def retrieve(self, id: str) -> None: ...
def retrieve(self, eval_id: str) -> EvaluationOverview:
...
2 changes: 2 additions & 0 deletions tests/evaluation/test_argilla_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,8 @@ def test_argilla_evaluator_can_submit_evals_to_argilla(
assert eval_overview.successful_evaluation_count == 1
assert eval_overview.failed_evaluation_count == 0

assert len(in_memory_evaluation_repository.example_evaluations(eval_overview.id, DummyStringOutput)) == 1

assert len(DummyArgillaClient()._datasets[partial_evaluation_overview.id]) == 1


Expand Down

0 comments on commit 163ed02

Please sign in to comment.