Skip to content

Commit

Permalink
[AIC-py][eval] fix the types pt 1 (#523)
Browse files Browse the repository at this point in the history
[AIC-py][eval] fix the types pt 1

---
Stack created with [Sapling](https://sapling-scm.com). Best reviewed
with
[ReviewStack](https://reviewstack.dev/lastmile-ai/aiconfig/pull/523).
* #539
* __->__ #523
* #522
* #513
  • Loading branch information
jonathanlastmileai authored Dec 19, 2023
2 parents adfb206 + e0b9a09 commit 35a77dd
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 118 deletions.
60 changes: 31 additions & 29 deletions python/src/aiconfig/eval/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,20 @@
import lastmile_utils.lib.core.api as cu
import result
from aiconfig.eval import common
from pydantic import BaseModel, root_validator
from pydantic import BaseModel
from result import Result

T_InputDatum = TypeVar("T_InputDatum", contravariant=True)
T_OutputDatum = TypeVar("T_OutputDatum", contravariant=True)

T_Evaluable = TypeVar("T_Evaluable", contravariant=True)

T_BaseModel = TypeVar("T_BaseModel", bound=BaseModel)

SerializedJSON = NewType("SerializedJSON", str)


@dataclass
@dataclass(eq=False)
class CustomMetricValue(ABC):
"""
Subclass this if you want your metric to return a type not included in MetricValue.
Expand All @@ -28,27 +30,28 @@ class CustomMetricValue(ABC):
"""


T_MetricValue = TypeVar("T_MetricValue", int, float, str, bool, CustomMetricValue, covariant=True)


class CompletionTextToSerializedJSON(Protocol):
@abstractmethod
def __call__(self, output_datum: str) -> Result[common.SerializedJSON, str]:
pass


MetricValue = int | float | str | bool | CustomMetricValue


@dataclass
class CustomMetricPydanticObject(CustomMetricValue, Generic[T_BaseModel]):
data: T_BaseModel


class EvaluationFunction(Protocol, Generic[T_OutputDatum]):
class EvaluationFunction(Protocol, Generic[T_Evaluable, T_MetricValue]):
@abstractmethod
async def __call__(self, output_datum: T_OutputDatum) -> MetricValue:
async def __call__(self, datum: T_Evaluable) -> T_MetricValue:
pass


class EvaluationMetricMetadata(cu.Record, Generic[T_OutputDatum]):
class EvaluationMetricMetadata(cu.Record, Generic[T_Evaluable, T_MetricValue]):

"""A record to tie together metadata about an evaluation metric
to ensure that numbers are interpreted as intended.
Expand Down Expand Up @@ -83,8 +86,8 @@ def _serialize_extra_metadata(self) -> str:

name: str
description: str
best_value: MetricValue | None = None
worst_value: MetricValue | None = None
best_value: T_MetricValue | None = None
worst_value: T_MetricValue | None = None
# e.g. {"substring": "hello", "case_sensitive": False}
extra_metadata: dict[str, Any] = {}

Expand All @@ -95,56 +98,55 @@ def __repr__(self) -> str:
return f"EvaluationMetricMetadata({s_json})"


class SampleMetricValue(cu.Record, Generic[T_OutputDatum]):
value: MetricValue | None
metric_metadata: EvaluationMetricMetadata[T_OutputDatum]
@dataclass
class SampleMetricValue(Generic[T_Evaluable, T_MetricValue]):
value: T_MetricValue | None
metric_metadata: EvaluationMetricMetadata[T_Evaluable, T_MetricValue]

@root_validator(pre=True)
def check_value_range(cls, values: dict[str, Any]) -> dict[str, Any]:
def __post_init__(self) -> None:
metric_metadata = self.metric_metadata
worst_value, best_value = (
values["metric_metadata"].worst_value,
values["metric_metadata"].best_value,
metric_metadata.worst_value,
metric_metadata.best_value,
)
value = values["value"]
value = self.value
if worst_value is None and best_value is None:
# fine
return values
return
elif worst_value is None or best_value is None:
raise ValueError(
f"""
[{values["metric_metadata"].name}]
{values["metric_metadata"].description}
[{metric_metadata.name}]
{metric_metadata.description}
You must define both worst_value and best_value, or neither.
You defined worst_value = {worst_value} and best_value = {best_value}.
"""
)
elif worst_value == best_value:
raise ValueError("best_value and worst_value cannot be equal")
elif value is not None and worst_value < best_value and not worst_value <= value <= best_value:
elif value is not None and worst_value < best_value and not worst_value <= value <= best_value: # type: ignore[fixme]
raise ValueError(
f"""
[{values["metric_metadata"].name}]
{values["metric_metadata"].description}
[{metric_metadata.name}]
{metric_metadata.description}
Value {value} is not in range [{worst_value}, {best_value}].
You defined worst_value = {worst_value} and best_value = {best_value},
but got value outside that range.
"""
)
elif value is not None and worst_value > best_value and not worst_value >= value >= best_value:
elif value is not None and worst_value > best_value and not worst_value >= value >= best_value: # type: ignore[fixme]
raise ValueError(
f"""
[{values["metric_metadata"].name}]
{values["metric_metadata"].description}
[{metric_metadata.name}]
{metric_metadata.description}
Value {value} is not in range [{worst_value}, {best_value}].
You defined worst_value = {worst_value} and best_value = {best_value},
but got value outside that range.
"""
)
else:
return values


class TextRatingsData(cu.Record):
Expand Down
105 changes: 57 additions & 48 deletions python/src/aiconfig/eval/lib.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
import asyncio
import logging
from dataclasses import dataclass
from typing import Generic, NewType, Sequence, Tuple, TypeVar
from typing import Any, Generic, NewType, Sequence, Tuple, TypeVar

import lastmile_utils.lib.core.api as cu
import pandas as pd
from aiconfig.Config import AIConfigRuntime
from aiconfig.eval.common import MetricValue, SampleMetricValue, T_InputDatum, T_OutputDatum
from aiconfig.eval.common import SampleMetricValue, T_InputDatum, T_MetricValue, T_OutputDatum
from aiconfig.eval.metrics import Metric
from result import Err, Ok, Result

logging.basicConfig(format=cu.LOGGER_FMT)
LOGGER = logging.getLogger(__name__)


# TODO: figure out a way to do heterogenous list without Any
# Each test is a (input_datum, Metric) pair
UserTestSuiteWithInputs = Sequence[Tuple[str, Metric[str]]]
UserTestSuiteWithInputs = Sequence[Tuple[str, Metric[str, Any]]]

# Each test is a (output_datum, Metric) pair
UserTestSuiteOutputsOnly = Sequence[Tuple[str, Metric[str]]]
UserTestSuiteOutputsOnly = Sequence[Tuple[str, Metric[str, Any]]]


class TestSuiteWithInputsSettings(cu.Record):
Expand Down Expand Up @@ -62,68 +64,75 @@ class NumericalEvalDataset(cu.Record):
# TODO:
# GenericBeforeBaseModelWarning: Classes should inherit from `BaseModel` before generic classes (e.g. `typing.Generic[T]`) for pydantic generics to work properly.
# But swapping the order breaks
class SampleEvaluationResult(Generic[T_InputDatum, T_OutputDatum], cu.Record):
class SampleEvaluationResult(Generic[T_InputDatum, T_OutputDatum, T_MetricValue], cu.Record):
input_datum: T_InputDatum | None
output_datum: T_OutputDatum
metric_value: SampleMetricValue[T_OutputDatum]

Tuple[T_OutputDatum, SampleMetricValue[T_OutputDatum]]


DatasetEvaluationResult = Sequence[SampleEvaluationResult[T_InputDatum, T_OutputDatum]]
metric_value: SampleMetricValue[T_OutputDatum, T_MetricValue]


@dataclass(frozen=True)
class SampleEvaluationParams(Generic[T_InputDatum, T_OutputDatum]):
class SampleEvaluationParams(Generic[T_InputDatum, T_OutputDatum, T_MetricValue]):
# input_sample doesn't _need_ to be here, because we already have
# output_sample ready to input to eval.
# input_sample is here for documentation/debugging.
input_sample: T_InputDatum | None
output_sample: T_OutputDatum
metric: Metric[T_OutputDatum]
metric: Metric[T_OutputDatum, T_MetricValue]

def __str__(self) -> str:
return f"\nSampleEvaluationParams:\n\t{self.output_sample=}\n\t{self.metric=}"


async def evaluate(
evaluation_params_list: Sequence[SampleEvaluationParams[T_InputDatum, T_OutputDatum]],
) -> Result[DatasetEvaluationResult[T_InputDatum, T_OutputDatum], str]:
results: Sequence[SampleEvaluationResult[T_InputDatum, T_OutputDatum]] = []
# TODO: don't use Any.
DatasetEvaluationResult = Sequence[SampleEvaluationResult[T_InputDatum, T_OutputDatum, Any]]
DatasetEvaluationParams = Sequence[SampleEvaluationParams[T_InputDatum, T_OutputDatum, Any]]
MetricList = list[Metric[T_OutputDatum, Any]]

for eval_params in evaluation_params_list:
sample, metric = (
eval_params.output_sample,
eval_params.metric,
)

async def _calculate() -> MetricValue:
return await metric.evaluation_fn(sample)

def _ok_with_log(res_: Result[MetricValue, str]) -> MetricValue | None:
match res_:
case Ok(res):
return res
case Err(e):
LOGGER.error(f"Error evaluating {eval_params=}: {e}")
return None

# TODO: figure out the right timeout
res_ = await cu.run_thunk_safe(_calculate(), timeout=1)
result = SampleEvaluationResult(
input_datum=eval_params.input_sample,
output_datum=sample,
metric_value=SampleMetricValue(value=_ok_with_log(res_), metric_metadata=metric.metric_metadata),
)
results.append(result)
async def _evaluate_for_sample(
eval_params: SampleEvaluationParams[T_InputDatum, T_OutputDatum, T_MetricValue]
) -> SampleEvaluationResult[T_InputDatum, T_OutputDatum, T_MetricValue]:
sample, metric = (
eval_params.output_sample,
eval_params.metric,
)

async def _calculate() -> T_MetricValue:
return await metric.evaluation_fn(sample)

def _ok_with_log(res_: Result[T_MetricValue, str]) -> T_MetricValue | None:
match res_:
case Ok(res):
return res
case Err(e):
LOGGER.error(f"Error evaluating {eval_params=}: {e}")
return None

# TODO: figure out the right timeout
res_ = await cu.run_thunk_safe(_calculate(), timeout=1)
result = SampleEvaluationResult(
input_datum=eval_params.input_sample,
output_datum=sample,
metric_value=SampleMetricValue(
#
value=_ok_with_log(res_),
metric_metadata=metric.metric_metadata,
),
)
return result

return Ok(results)

async def evaluate(
evaluation_params_list: DatasetEvaluationParams[T_InputDatum, T_OutputDatum],
) -> Result[DatasetEvaluationResult[T_InputDatum, T_OutputDatum], str]:
return Ok(await asyncio.gather(*map(_evaluate_for_sample, evaluation_params_list)))


def eval_res_to_df(
eval_res: DatasetEvaluationResult[T_InputDatum, T_OutputDatum],
) -> pd.DataFrame:
records: list[dict[str, None | MetricValue | T_InputDatum | T_OutputDatum]] = []
# TODO: dont use Any
records: list[dict[str, Any]] = []
for sample_res in eval_res:
records.append(
dict(
Expand All @@ -148,17 +157,17 @@ def eval_res_to_df(

async def user_test_suite_with_inputs_to_eval_params_list(
test_suite: UserTestSuiteWithInputs, prompt_name: str, aiconfig: AIConfigRuntime
) -> Result[Sequence[SampleEvaluationParams[TextInput, TextOutput]], str]:
) -> Result[DatasetEvaluationParams[TextInput, TextOutput], str]:
"""
Example in/out:
[("hello", brevity)] -> [SampleEvaluationParams("hello", "output_is_world", brevity)]
"""
out: list[SampleEvaluationParams[TextInput, TextOutput]] = []
out: DatasetEvaluationParams[TextInput, TextOutput] = []

# Group by input so that we only run each input through the AIConfig once.
# This is sort of an optimization because the user can give the same input
# multiple times (with different metrics).
input_to_metrics_mapping: dict[str, list[Metric[str]]] = {}
input_to_metrics_mapping: dict[str, MetricList[TextOutput]] = {}
for input_datum, metric in test_suite:
if input_datum not in input_to_metrics_mapping:
input_to_metrics_mapping[input_datum] = []
Expand Down Expand Up @@ -202,7 +211,7 @@ def _zip_inputs_outputs(outputs: list[TextOutput]):

def user_test_suite_outputs_only_to_eval_params_list(
test_suite: UserTestSuiteOutputsOnly,
) -> Sequence[SampleEvaluationParams[TextInput, TextOutput]]:
) -> DatasetEvaluationParams[TextInput, TextOutput]:
"""
Example: [("the_output_is_world", brevity)] -> [SampleEvaluationParams(None, "the_output_is_world", brevity)
"""
Expand Down Expand Up @@ -241,7 +250,7 @@ async def run_test_suite_helper(
) -> Result[pd.DataFrame, str]:
async def _get_eval_params_list(
test_suite_spec: TestSuiteSpec,
) -> Result[Sequence[SampleEvaluationParams[TextInput, TextOutput]], str]:
) -> Result[DatasetEvaluationParams[TextInput, TextOutput], str]:
match test_suite_spec:
case TestSuiteWithInputsSpec(test_suite=test_suite, prompt_name=prompt_name, aiconfig=aiconfig):
return await user_test_suite_with_inputs_to_eval_params_list(test_suite, prompt_name, aiconfig)
Expand Down
Loading

0 comments on commit 35a77dd

Please sign in to comment.