[AIC-py][eval] allow arbitrary metric return type
- allow different primitive types or arbitrary user type
  (must be a dataclass implementing `CustomMetricValue`)
- change the existing metrics `brevity` and `substring_match` to their proper types
  (technically this is backwards-incompatible but makes more sense and should be
  a welcome change)
- add example metrics whose values are an orderable string type and an arbitrary
  (non-orderable) container type, respectively. Both use nltk sentiment scores.
jonathanlastmileai committed Dec 15, 2023
1 parent 5cce4eb commit 9d6c233
Showing 8 changed files with 221 additions and 156 deletions.
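To make the headline change concrete, here is a minimal sketch (not taken from the diff) of an evaluation function that returns a bool, now a legal metric value alongside int, float, str, and `CustomMetricValue` subclasses; the `exact_match` helper is purely illustrative:

```
# Sketch only: with this commit, an evaluation function may return any MetricValue
# (int | float | str | bool | CustomMetricValue), not just float.
# exact_match is a hypothetical helper, not part of the commit.
def exact_match(expected: str):
    def evaluate(output_datum: str) -> bool:  # bool is now a valid metric value
        return output_datum == expected

    return evaluate


print(exact_match("hello world")("hello world"))  # True
print(exact_match("hello world")("goodbye"))  # False
```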
3 changes: 2 additions & 1 deletion python/requirements.txt
@@ -15,4 +15,5 @@ prompt_toolkit
mock
pytest-asyncio
lastmile-utils==0.0.9
hypothesis==6.91.0
hypothesis==6.91.0
nltk
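The new nltk dependency backs the example sentiment metrics mentioned in the commit message. As a rough illustration (not code from this commit) of the kind of scores those metrics wrap:

```
import nltk
from nltk.sentiment import SentimentIntensityAnalyzer

# One-time download of the VADER lexicon used by the analyzer.
nltk.download("vader_lexicon", quiet=True)

# polarity_scores returns a dict of the form
# {"neg": ..., "neu": ..., "pos": ..., "compound": ...};
# the commit's example metrics wrap scores like these in custom metric value types.
sia = SentimentIntensityAnalyzer()
print(sia.polarity_scores("This trip was absolutely wonderful!"))
```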
1 change: 1 addition & 0 deletions python/src/aiconfig/eval/__init__.py
@@ -0,0 +1 @@
__all__ = []
10 changes: 4 additions & 6 deletions python/src/aiconfig/eval/api/__init__.py
@@ -9,17 +9,15 @@
TestSuiteWithInputsSettings,
)
"""
# pyright: reportWildcardImportFromLibrary=false
from ..lib import (
TestSuiteWithInputsSettings,
run_test_suite_with_inputs,
run_test_suite_outputs_only,
)
from .. import metrics

# pyright: reportWildcardImportFromLibrary=false
from ..lib import TestSuiteWithInputsSettings, run_test_suite_outputs_only, run_test_suite_with_inputs
from ..metrics import Metric, brevity, substring_match

__all__ = [
"Metric",
"metrics",
"brevity",
"substring_match",
"run_test_suite_with_inputs",
108 changes: 47 additions & 61 deletions python/src/aiconfig/eval/common.py
@@ -1,74 +1,49 @@
import json
from typing import Any, Generic
from abc import abstractmethod
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Generic, Protocol, TypeVar

import lastmile_utils.lib.core.api as cu
from pydantic import root_validator


T_InputDatum = TypeVar("T_InputDatum", contravariant=True)
T_OutputDatum = TypeVar("T_OutputDatum", contravariant=True)


@dataclass
class CustomMetricValue(ABC):
"""
Subclass this if you want your metric to return a type not included in MetricValue.
A subclass (an implementation of CustomMetricValue) can be either ordered or unordered.
If ordered, it must implement the comparison operators <, <=, >, and >=.
See TextOverallPositiveSentiment for an example.
See EvaluationMetricMetadata for more information about ordered metrics.
"""


MetricValue = int | float | str | bool | CustomMetricValue


class EvaluationFunction(Protocol, Generic[T_OutputDatum]):
@abstractmethod
def __call__(self, output_datum: T_OutputDatum) -> float:
def __call__(self, output_datum: T_OutputDatum) -> MetricValue:
pass


class EvaluationMetricMetadata(cu.Record, Generic[T_OutputDatum]):
"""A record to tie together metadata about an evaluation metric
to ensure that numbers are interpreted as intended.
IMPORTANT NOTE:
**This property is part of the contract of implementing a metric**:
`id` must uniquely identify the metric to disambiguate different
conceptual quantities in case they happen to share a name.
If you write a metric and want an automated test for this property,
see `test_metric_library_id_properties()`.
Illustration of the property:
```
# Helper function
def extract_id(matcher, matcher_input):
return matcher(matcher_input).metric_metadata.id
# These two metrics share everything but the substring.
matcher1 = substring_match("hello")
matcher2 = substring_match("world")
# Run both on the same input.
the_input = "hello world"
# They have distinct IDs because of the two different substrings.
assert extract_id(matcher1, the_input) != extract_id(matcher2, the_input)
```
`id` must however be a _constant_ with respect to the metric input.
Illustration of the property:
```
the_matcher = substring_match("hello")
input1 = "hello world"
input2 = "the quick brown fox"
assert extract_id(the_matcher, input1) == extract_id(the_matcher, input2)
```
See `metrics.substring_match()` for an example of how to set an id.
Assumptions:
* Metric type is float
(bools and ints have to be represented as floats; tensors are not supported)
* Range is a single interval with one endpoint being the best possible value
and the opposite endpoint the worst possible value.
* The metric either gets better or worse along the entire range.
* If the best and worst values are not None, then the metric is assumed to be ordered.
In this case (if the metric is ordered) then the comparison operators <, <=, >, and >=
must be implemented (see CustomMetricValue).
If a metric is ordered, the domain is assumed to be a single closed interval or fully-ordered discrete set
with one endpoint being the best possible value and
the opposite endpoint the worst possible value.
* Furthermore, if a metric is ordered, it is implicitly associated with a monotonic function of "goodness".
That is, the metric either gets better along the entire domain, or worse along the entire domain.
There are two cases: higher-is-better and lower-is-better.
Examples:
- Accuracy (higher-is-better): range = 0 -> 1. Worst score is 0, best is 1.
@@ -82,24 +57,22 @@ def extract_id(matcher, matcher_input):
@property
def id(self) -> str:
return cu.hash_id(
f"{self.name}{self.description}{self.best_value}{self.worst_value}params={self._serialize_extra_metadata()}".encode(
"utf-8"
)
f"{self.name}{self.description}{self.best_value}{self.worst_value}params={self._serialize_extra_metadata()}".encode("utf-8")
)

def _serialize_extra_metadata(self) -> str:
return json.dumps(self.extra_metadata, sort_keys=True)

name: str
description: str
best_value: float
worst_value: float
best_value: MetricValue | None = None
worst_value: MetricValue | None = None
# e.g. {"substring": "hello", "case_sensitive": False}
extra_metadata: dict[str, Any] = {}


class SampleMetricValue(cu.Record, Generic[T_OutputDatum]):
value: float
value: MetricValue
metric_metadata: EvaluationMetricMetadata[T_OutputDatum]

@root_validator(pre=True)
@@ -109,9 +82,22 @@ def check_value_range(cls, values: dict[str, Any]) -> dict[str, Any]:
values["metric_metadata"].best_value,
)
value = values["value"]
if worst_value == best_value:
if worst_value is None and best_value is None:
# fine
return values
elif worst_value is None or best_value is None:
raise ValueError(
f"""
[{values["metric_metadata"].name}]
{values["metric_metadata"].description}
You must define both worst_value and best_value, or neither.
You defined worst_value = {worst_value} and best_value = {best_value}.
"""
)
elif worst_value == best_value:
raise ValueError("best_value and worst_value cannot be equal")
if worst_value < best_value and not worst_value <= value <= best_value:
elif worst_value < best_value and not worst_value <= value <= best_value:
raise ValueError(
f"""
[{values["metric_metadata"].name}]
@@ -122,7 +108,7 @@ def check_value_range(cls, values: dict[str, Any]) -> dict[str, Any]:
but got value outside that range.
"""
)
if worst_value > best_value and not worst_value >= value >= best_value:
elif worst_value > best_value and not worst_value >= value >= best_value:
raise ValueError(
f"""
[{values["metric_metadata"].name}]
@@ -133,5 +119,5 @@ def check_value_range(cls, values: dict[str, Any]) -> dict[str, Any]:
but got value outside that range.
"""
)

return values
else:
return values
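For illustration, here is a minimal sketch (not part of this diff) of an ordered `CustomMetricValue` subclass of the kind the docstring above describes. The class name and field are hypothetical (the commit's actual ordered example is TextOverallPositiveSentiment in the metrics module, not shown in this excerpt), and the import path assumes the module location of common.py:

```
from dataclasses import dataclass
from functools import total_ordering

from aiconfig.eval.common import CustomMetricValue


# Hypothetical ordered metric value wrapping a single sentiment score.
# @dataclass supplies __eq__, and @total_ordering derives <=, >, >= from __lt__,
# satisfying the ordering contract described in the CustomMetricValue docstring.
@total_ordering
@dataclass
class PositiveSentimentScore(CustomMetricValue):
    score: float  # e.g. an nltk compound sentiment score

    def __lt__(self, other: "PositiveSentimentScore") -> bool:
        return self.score < other.score


assert PositiveSentimentScore(0.2) < PositiveSentimentScore(0.9)
```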
2 changes: 1 addition & 1 deletion python/src/aiconfig/eval/examples/travel/travel_eval.ipynb
@@ -844,7 +844,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
"version": "3.10.13"
}
},
"nbformat": 4,
46 changes: 11 additions & 35 deletions python/src/aiconfig/eval/lib.py
@@ -5,14 +5,9 @@
import lastmile_utils.lib.core.api as cu
import pandas as pd
from aiconfig.Config import AIConfigRuntime
from result import Ok, Result

from aiconfig.eval.common import (
SampleMetricValue,
T_InputDatum,
T_OutputDatum,
)
from aiconfig.eval.common import MetricValue, SampleMetricValue, T_InputDatum, T_OutputDatum
from aiconfig.eval.metrics import Metric
from result import Ok, Result

logging.basicConfig(format=cu.LOGGER_FMT)
logger = logging.getLogger(__name__)
@@ -92,12 +87,8 @@ def __str__(self) -> str:


def evaluate(
evaluation_params_list: Sequence[
SampleEvaluationParams[T_InputDatum, T_OutputDatum]
],
) -> Result[
DatasetEvaluationResult[T_InputDatum, T_OutputDatum], str
]: # pyright: ignore[fixme, reportInvalidTypeVarUse]
evaluation_params_list: Sequence[SampleEvaluationParams[T_InputDatum, T_OutputDatum]],
) -> Result[DatasetEvaluationResult[T_InputDatum, T_OutputDatum], str]: # pyright: ignore[fixme, reportInvalidTypeVarUse]
results: Sequence[SampleEvaluationResult[T_InputDatum, T_OutputDatum]] = []

for eval_params in evaluation_params_list:
@@ -111,9 +102,7 @@ def evaluate(
result = SampleEvaluationResult(
input_datum=eval_params.input_sample,
output_datum=sample,
metric_value=SampleMetricValue(
value=res_, metric_metadata=metric.metric_metadata
),
metric_value=SampleMetricValue(value=res_, metric_metadata=metric.metric_metadata),
)
results.append(result)

@@ -123,7 +112,7 @@ def evaluate(
def eval_res_to_df(
eval_res: DatasetEvaluationResult[T_InputDatum, T_OutputDatum],
) -> pd.DataFrame:
records: list[dict[str, None | str | float | T_InputDatum | T_OutputDatum]] = []
records: list[dict[str, None | MetricValue | T_InputDatum | T_OutputDatum]] = []
for sample_res in eval_res:
records.append(
dict(
@@ -167,9 +156,7 @@ async def user_test_suite_with_inputs_to_eval_params_list(
all_inputs = list(input_to_metrics_mapping.keys())

async def _run(input_datum: str) -> Result[TextOutput, str]:
return (await run_aiconfig_helper(aiconfig, prompt_name, input_datum)).map(
TextOutput
)
return (await run_aiconfig_helper(aiconfig, prompt_name, input_datum)).map(TextOutput)

# TODO: fix the race condition and then use gather
# https://github.com/lastmile-ai/aiconfig/issues/434
@@ -208,17 +195,10 @@ def user_test_suite_outputs_only_to_eval_params_list(
"""
Example: [("the_output_is_world", brevity)] -> [SampleEvaluationParams(None, "the_output_is_world", brevity)
"""
return [
SampleEvaluationParams(
input_sample=None, output_sample=TextOutput(output_datum), metric=metric
)
for output_datum, metric in test_suite
]
return [SampleEvaluationParams(input_sample=None, output_sample=TextOutput(output_datum), metric=metric) for output_datum, metric in test_suite]


async def run_aiconfig_helper(
runtime: AIConfigRuntime, prompt_name: str, question: str
) -> Result[str, str]:
async def run_aiconfig_helper(runtime: AIConfigRuntime, prompt_name: str, question: str) -> Result[str, str]:
params = {
"the_query": question,
}
@@ -252,12 +232,8 @@ async def _get_eval_params_list(
test_suite_spec: TestSuiteSpec,
) -> Result[Sequence[SampleEvaluationParams[TextInput, TextOutput]], str]:
match test_suite_spec:
case TestSuiteWithInputsSpec(
test_suite=test_suite, prompt_name=prompt_name, aiconfig=aiconfig
):
return await user_test_suite_with_inputs_to_eval_params_list(
test_suite, prompt_name, aiconfig
)
case TestSuiteWithInputsSpec(test_suite=test_suite, prompt_name=prompt_name, aiconfig=aiconfig):
return await user_test_suite_with_inputs_to_eval_params_list(test_suite, prompt_name, aiconfig)
case TestSuiteOutputsOnlySpec(test_suite=test_suite):
return Ok(user_test_suite_outputs_only_to_eval_params_list(test_suite))

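Finally, a hedged usage sketch of the outputs-only path that the lib.py changes above feed into. The test-suite shape follows the `user_test_suite_outputs_only_to_eval_params_list` docstring; treating `run_test_suite_outputs_only` as an async function that returns a tabular result is an assumption not confirmed by this diff:

```
import asyncio

from aiconfig.eval.api import brevity, run_test_suite_outputs_only, substring_match

# Hypothetical outputs-only test suite: (model output, metric) pairs, matching the
# shape user_test_suite_outputs_only_to_eval_params_list expects.
test_suite = [
    ("the output is world", brevity),
    ("the output is world", substring_match("world")),
]


async def main() -> None:
    # Assumption: run_test_suite_outputs_only is awaitable and returns a tabular
    # result (e.g. a pandas DataFrame); check the eval module for the exact contract.
    result = await run_test_suite_outputs_only(test_suite)
    print(result)


asyncio.run(main())
```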
