refactor: split evaluation and aggregation
Task: IL-259
Valentina Galata committed Feb 21, 2024
1 parent 6724e9f commit a28ab3d
Showing 11 changed files with 598 additions and 281 deletions.
1 change: 1 addition & 0 deletions src/intelligence_layer/evaluation/__init__.py
@@ -1,4 +1,5 @@
from .accumulator import MeanAccumulator as MeanAccumulator
from .aggregator import Aggregator as Aggregator
from .argilla import ArgillaEvaluationLogic as ArgillaEvaluationLogic
from .argilla import ArgillaEvaluator as ArgillaEvaluator
from .argilla import (
223 changes: 223 additions & 0 deletions src/intelligence_layer/evaluation/aggregator.py
@@ -0,0 +1,223 @@
from functools import lru_cache
from typing import (
Callable,
Generic,
Iterable,
Iterator,
Mapping,
TypeVar,
cast,
final,
get_args,
get_origin,
)
from uuid import uuid4

from intelligence_layer.core.tracer import utc_now
from intelligence_layer.evaluation.base_logic import AggregationLogic
from intelligence_layer.evaluation.data_storage.aggregation_repository import (
AggregationRepository,
)
from intelligence_layer.evaluation.data_storage.evaluation_repository import (
EvaluationRepository,
)
from intelligence_layer.evaluation.domain import (
AggregatedEvaluation,
AggregationOverview,
Evaluation,
EvaluationOverview,
FailedExampleEvaluation,
)

T = TypeVar("T")


class CountingFilterIterable(Iterable[T]):
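    """Iterable wrapper that yields only the elements accepted by ``filter`` while
    counting how many elements were included and how many were excluded."""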
def __init__(
self, wrapped_iterable: Iterable[T], filter: Callable[[T], bool]
) -> None:
self._wrapped_iterator = iter(wrapped_iterable)
self._filter = filter
self._included_count = 0
self._excluded_count = 0

def __next__(self) -> T:
e = next(self._wrapped_iterator)
while not self._filter(e):
self._excluded_count += 1
e = next(self._wrapped_iterator)
self._included_count += 1
return e

def __iter__(self) -> Iterator[T]:
return self

def included_count(self) -> int:
return self._included_count

def excluded_count(self) -> int:
return self._excluded_count


class Aggregator(Generic[Evaluation, AggregatedEvaluation]):
"""Aggregator that can handle automatic aggregation of evaluation scenarios.
This aggregator should be used for automatic eval. A user still has to implement
:class: `AggregationLogic`.
Arguments:
evaluation_repository: The repository that will be used to store evaluation results.
aggregation_repository: The repository that will be used to store aggregation results.
description: Human-readable description for the evaluator.
aggregation_logic: The logic to aggregate the evaluations.
Generics:
Evaluation: Interface of the metrics that come from the evaluated :class:`Task`.
AggregatedEvaluation: The aggregated results of an evaluation run with a :class:`Dataset`.
"""

def __init__(
self,
evaluation_repository: EvaluationRepository,
aggregation_repository: AggregationRepository,
description: str,
aggregation_logic: AggregationLogic[Evaluation, AggregatedEvaluation],
) -> None:
self._evaluation_repository = evaluation_repository
self._aggregation_repository = aggregation_repository
self._aggregation_logic = aggregation_logic
self.description = description

@lru_cache(maxsize=1)
def _get_types(self) -> Mapping[str, type]:
"""Type magic function that gets the actual types of the generic parameters.
Traverses the inheritance history of `BaseEvaluator`-subclass to find an actual type every time a TypeVar is replaced.
Returns:
Name of generic parameter to the type found.
"""

def is_eligible_subclass(parent: type) -> bool:
return hasattr(parent, "__orig_bases__") and issubclass(
parent, AggregationLogic
)

def update_types() -> None:
num_types_set = 0
for current_index, current_type in enumerate(current_types):
if type(current_type) is not TypeVar:
type_var_count = num_types_set - 1
for element_index, element in enumerate(type_list):
if type(element) is TypeVar:
type_var_count += 1
if type_var_count == current_index:
break
assert type_var_count == current_index
type_list[element_index] = current_type
num_types_set += 1

# mypy does not know __orig_bases__
base_types = AggregationLogic.__orig_bases__[1] # type: ignore
type_list: list[type | TypeVar] = list(get_args(base_types))
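        # Example (hypothetical types): for a logic class deriving from
        # AggregationLogic[QaEvaluation, QaAggregate], type_list starts as the bare
        # TypeVars [Evaluation, AggregatedEvaluation] and update_types() narrows it to
        # [QaEvaluation, QaAggregate] while walking the logic's MRO below.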

possible_parent_classes = [
p
for p in reversed(type(self._aggregation_logic).__mro__)
if is_eligible_subclass(p)
]
for parent in possible_parent_classes:
# mypy does not know __orig_bases__
for base in parent.__orig_bases__: # type: ignore
origin = get_origin(base)
if origin is None or not issubclass(origin, AggregationLogic):
continue
current_types = list(get_args(base))
update_types()

return {
name: param_type
for name, param_type in zip(
(a.__name__ for a in get_args(base_types)), type_list
)
if type(param_type) is not TypeVar
}

def evaluation_type(self) -> type[Evaluation]:
"""Returns the type of the evaluation result of an example.
This can be used to retrieve properly typed evaluations of an evaluation run
from a :class:`EvaluationRepository`
Returns:
Returns the type of the evaluation result of an example.
"""
try:
evaluation_type = self._get_types()["Evaluation"]
except KeyError:
            raise TypeError(
                "Could not resolve the Evaluation type from the generic parameters. "
                f"Alternatively overwrite evaluation_type() in {type(self)}."
            )
return cast(type[Evaluation], evaluation_type)

@final
def aggregate_evaluation(
self, *eval_ids: str
) -> AggregationOverview[AggregatedEvaluation]:
"""Aggregates all evaluations into an overview that includes high-level statistics.
Aggregates :class:`Evaluation`s according to the implementation of :func:`BaseEvaluator.aggregate`.
Args:
evaluation_overview: An overview of the evaluation to be aggregated. Does not include
actual evaluations as these will be retrieved from the repository.
Returns:
An overview of the aggregated evaluation.
"""

def load_eval_overview(eval_id: str) -> EvaluationOverview:
evaluation_overview = self._evaluation_repository.evaluation_overview(
eval_id
)
if not evaluation_overview:
                raise ValueError(
                    f"No EvaluationOverview found for eval-id: {eval_id}"
                )
return evaluation_overview

evaluation_overviews = frozenset(load_eval_overview(id) for id in set(eval_ids))

nested_evaluations = [
self._evaluation_repository.example_evaluations(
overview.id, self.evaluation_type()
)
for overview in evaluation_overviews
]
example_evaluations = [
eval for sublist in nested_evaluations for eval in sublist
]

successful_evaluations = CountingFilterIterable(
(example_eval.result for example_eval in example_evaluations),
lambda evaluation: not isinstance(evaluation, FailedExampleEvaluation),
)
id = str(uuid4())
start = utc_now()
statistics = self._aggregation_logic.aggregate(
cast(Iterable[Evaluation], successful_evaluations)
)

aggregation_overview = AggregationOverview(
evaluation_overviews=frozenset(evaluation_overviews),
id=id,
start=start,
end=utc_now(),
successful_evaluation_count=successful_evaluations.included_count(),
crashed_during_eval_count=successful_evaluations.excluded_count(),
description=self.description,
statistics=statistics,
)
self._aggregation_repository.store_aggregation_overview(aggregation_overview)
return aggregation_overview
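A minimal usage sketch of the new Aggregator follows. The QaEvaluation/QaAggregate models, QaAggregationLogic, and the aggregate_qa_evaluation helper are hypothetical illustrations; only Aggregator, MeanAccumulator, AggregationLogic, the repository interfaces, and aggregate_evaluation come from this commit, and the aggregate() signature is inferred from the call inside aggregate_evaluation above.

from typing import Iterable

from pydantic import BaseModel

from intelligence_layer.evaluation import Aggregator, MeanAccumulator
from intelligence_layer.evaluation.base_logic import AggregationLogic
from intelligence_layer.evaluation.data_storage.aggregation_repository import (
    AggregationRepository,
)
from intelligence_layer.evaluation.data_storage.evaluation_repository import (
    EvaluationRepository,
)


class QaEvaluation(BaseModel):
    """Hypothetical per-example evaluation result."""

    correct: bool


class QaAggregate(BaseModel):
    """Hypothetical aggregated statistics."""

    percentage_correct: float


class QaAggregationLogic(AggregationLogic[QaEvaluation, QaAggregate]):
    def aggregate(self, evaluations: Iterable[QaEvaluation]) -> QaAggregate:
        # MeanAccumulator is assumed to expose add()/extract() as in the accumulator module
        acc = MeanAccumulator()
        for evaluation in evaluations:
            acc.add(1.0 if evaluation.correct else 0.0)
        return QaAggregate(percentage_correct=acc.extract())


def aggregate_qa_evaluation(
    evaluation_repository: EvaluationRepository,
    aggregation_repository: AggregationRepository,
    eval_id: str,
):
    aggregator = Aggregator(
        evaluation_repository=evaluation_repository,
        aggregation_repository=aggregation_repository,
        description="QA accuracy",
        aggregation_logic=QaAggregationLogic(),
    )
    # aggregate_evaluation accepts one or more eval ids and returns an AggregationOverview
    return aggregator.aggregate_evaluation(eval_id)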
48 changes: 39 additions & 9 deletions src/intelligence_layer/evaluation/argilla.py
@@ -13,6 +13,7 @@
RecordData,
)
from intelligence_layer.core import Input, InstructInput, Output, PromptOutput
from intelligence_layer.evaluation import Aggregator
from intelligence_layer.evaluation.accumulator import MeanAccumulator
from intelligence_layer.evaluation.base_logic import AggregationLogic, EvaluationLogic
from intelligence_layer.evaluation.data_storage.aggregation_repository import (
@@ -71,56 +72,85 @@ def _to_record(


class ArgillaEvaluator(
Evaluator[Input, Output, ExpectedOutput, ArgillaEvaluation, AggregatedEvaluation],
Evaluator[Input, Output, ExpectedOutput, ArgillaEvaluation],
ABC,
):
"""Evaluator used to integrate with Argilla (https://github.com/argilla-io/argilla).
Use this evaluator if you would like to easily do human eval.
This evaluator runs a dataset and sends the input and output to Argilla to be evaluated.
After they have been evaluated, you can fetch the results by using the `aggregate_evaluation` method.
Arguments:
dataset_repository: The repository with the examples that will be taken for the evaluation.
run_repository: The repository of the runs to evaluate.
evaluation_repository: The repository that will be used to store evaluation results.
aggregation_repository: The repository that will be used to store aggregation results.
description: Human-readable description for the evaluator.
evaluation_logic: The logic to use for evaluation.
aggregation_logic: The logic to aggregate the evaluations.
Generics:
Input: Interface to be passed to the :class:`Task` that shall be evaluated.
Output: Type of the output of the :class:`Task` to be evaluated.
ExpectedOutput: Output that is expected from the run with the supplied input.
ArgillaEvaluation: Interface of the metrics that come from the Argilla task`.
AggregatedEvaluation: The aggregated results of an evaluation run with a :class:`Dataset`.
"""

def __init__(
self,
dataset_repository: DatasetRepository,
run_repository: RunRepository,
evaluation_repository: ArgillaEvaluationRepository,
aggregation_repository: AggregationRepository,
description: str,
evaluation_logic: ArgillaEvaluationLogic[Input, Output, ExpectedOutput],
aggregation_logic: AggregationLogic[ArgillaEvaluation, AggregatedEvaluation],
) -> None:
super().__init__(
dataset_repository,
run_repository,
evaluation_repository,
aggregation_repository,
description,
evaluation_logic, # type: ignore
aggregation_logic, # TODO: check if the non-matching types of the evaluation logic and aggregation logic (in the line above) are a problem
)

def evaluation_type(self) -> type[ArgillaEvaluation]: # type: ignore
return ArgillaEvaluation


class ArgillaAggregator(
Aggregator[ArgillaEvaluation, AggregatedEvaluation],
ABC,
):
"""Aggregator used to aggregate Argilla (https://github.com/argilla-io/argilla) evaluations.
You can fetch the results by using the `aggregate_evaluation` method.
Arguments:
evaluation_repository: The repository that will be used to store evaluation results.
aggregation_repository: The repository that will be used to store aggregation results.
description: Human-readable description for the evaluator.
aggregation_logic: The logic to aggregate the evaluations.
Generics:
ArgillaEvaluation: Interface of the metrics that come from the Argilla task`.
AggregatedEvaluation: The aggregated results of an evaluation run with a :class:`Dataset`.
"""

def evaluation_type(self) -> type[ArgillaEvaluation]: # type: ignore
return ArgillaEvaluation

def __init__(
self,
evaluation_repository: ArgillaEvaluationRepository,
aggregation_repository: AggregationRepository,
description: str,
aggregation_logic: AggregationLogic[ArgillaEvaluation, AggregatedEvaluation],
) -> None:
super().__init__(
evaluation_repository,
aggregation_repository,
description,
aggregation_logic,
)


class AggregatedInstructComparison(BaseModel):
scores: Mapping[str, PlayerScore]

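With this split, the ArgillaEvaluator keeps pushing records to Argilla, while the collected human evaluations are fetched and summarized through an aggregator. A minimal sketch, assuming the names imported or defined in argilla.py above are in scope; MyArgillaAggregator, my_aggregation_logic, and fetch_argilla_results are hypothetical placeholders.

class MyArgillaAggregator(ArgillaAggregator[AggregatedInstructComparison]):
    # trivial concrete subclass; ArgillaAggregator is declared as an ABC
    pass


def fetch_argilla_results(
    evaluation_repository: ArgillaEvaluationRepository,
    aggregation_repository: AggregationRepository,
    my_aggregation_logic: AggregationLogic[ArgillaEvaluation, AggregatedInstructComparison],
    eval_id: str,
):
    aggregator = MyArgillaAggregator(
        evaluation_repository=evaluation_repository,
        aggregation_repository=aggregation_repository,
        description="Argilla human evaluation",
        aggregation_logic=my_aggregation_logic,
    )
    # pulls the human evaluations collected in Argilla and aggregates them into an overview
    return aggregator.aggregate_evaluation(eval_id)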