diff --git a/src/intelligence_layer/evaluation/evaluation/elo_evaluator.py b/src/intelligence_layer/evaluation/evaluation/elo_evaluator.py
index 3c3347bcc..8f0bcfe90 100644
--- a/src/intelligence_layer/evaluation/evaluation/elo_evaluator.py
+++ b/src/intelligence_layer/evaluation/evaluation/elo_evaluator.py
@@ -1,13 +1,18 @@
-from abc import abstractmethod
+from itertools import combinations
+import random
 from typing import Sequence
 
 from pydantic import BaseModel
 
-from intelligence_layer.core import Input, NoOpTracer, Output, Tracer
+from intelligence_layer.core import Input, Language, NoOpTracer, Output, Task, Tracer
 from intelligence_layer.evaluation import EvaluationLogic
 from intelligence_layer.evaluation.aggregation.elo import MatchOutcome
 from intelligence_layer.evaluation.dataset.domain import Example
+from intelligence_layer.evaluation.evaluation.elo_graders.llama_grader import LlamaGradingInput
 from intelligence_layer.evaluation.run.domain import SuccessfulExampleOutput
+from intelligence_layer.examples.qa.single_chunk_qa import QA_INSTRUCTIONS
 
 
 class Match(BaseModel):
@@ -16,50 +21,77 @@ class Match(BaseModel):
     outcome: MatchOutcome
 
 
-class EloEvaluationLogic(EvaluationLogic[Input, Output, str, list[Match]]):
-    # def __init__(
-    #     self,
-    #     # client: Client,
-    #     tracer: Tracer = NoOpTracer(),
-    # ):
-    #     self._tracer = tracer
-    #     # self._grader = Grader(  ## TODO
-    #     #     LlamaControlModel(name="llama-2-70b-chat", client=client)
-    #     # )
+class Matches(BaseModel):
+    matches: Sequence[Match]
 
-    @abstractmethod
-    def do_evaluate(
+
+class EloEvaluationLogic(EvaluationLogic[Input, Output, str, Matches]):
+    """Evaluation logic for a pair-wise ELO comparison.
+
+    Args:
+        grader: The :class:`Task` that performs the grading, i.e. the actual comparison of two run outputs.
+        tracer: :class:`Tracer` used for tracking and debugging.
+    """
+
+    def __init__(
        self,
-        example: Example[Input, str],
-        *output: SuccessfulExampleOutput[Output],
-    ) -> list[Match]:
-        # pairs = combinations(output, 2)
-        # return Matches(
-        #     matches=[
-        #         self._run_grader(first, second, example)
-        #         for [first, second] in pairs
-        #         if self._high_priority_runs is None  ##TODO: Adapts to iterative elo class
-        #         or len(self._high_priority_runs) == 0
-        #         or first.run_id in self._high_priority_runs
-        #         or second.run_id in self._high_priority_runs
-        #     ]
-        # )
-        pass
+        grader: Task[LlamaGradingInput, MatchOutcome],
+        tracer: Tracer = NoOpTracer(),
+    ):
+        self.tracer = tracer
+        self.grader = grader
+
+    def do_evaluate(
+        self,
+        example: Example[Input, str],
+        *output: SuccessfulExampleOutput[Output],
+    ) -> Matches:
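+        # Every pair of runs is compared exactly once, so n runs yield n * (n - 1) / 2 matches.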
+        pairs = combinations(output, 2)
+        return Matches(
+            matches=[
+                Match(
+                    outcome=self._run_grader(first, second, example),
+                    player_a=first.run_id,
+                    player_b=second.run_id,
+                )
+                for first, second in pairs
+            ]
+        )
 
     def _run_grader(
         self,
         first: SuccessfulExampleOutput[Output],
         second: SuccessfulExampleOutput[Output],
         example: Example[Input, str],
     ) -> Match:
-        pass
-        # if random.choice([True, False]):
-        #     first, second = second, first
-        #
-        #
-        #
-        # return Match(
-        #     outcome='str',
-        #     player_a=first.run_id,
-        #     player_b=second.run_id,
-        # )
+        """Compare two run outputs and return a :class:`Match` with the result of the comparison.
+
+        The order of the two outputs is shuffled randomly to avoid a positional bias in the grader.
+
+        Args:
+            first: :class:`SuccessfulExampleOutput`, the first of the two outputs to compare.
+            second: :class:`SuccessfulExampleOutput`, the second of the two outputs to compare.
+            example: :class:`Example` that `first` and `second` were produced for.
+
+        Returns:
+            :class:`Match` that contains the result of the comparison.
+        """
+        if random.choice([True, False]):
+            first, second = second, first
+
+        qa_instruction = QA_INSTRUCTIONS[Language("en")].unformatted_instruction.format(
+            question=example.input.question
+        )
+        no_answer = "There is no answer."
+        grading_input = LlamaGradingInput(
+            instruction=f"{example.input.chunk} {qa_instruction}",
+            first_completion=(
+                first.output.answer if first.output.answer is not None else no_answer
+            ),
+            second_completion=(
+                second.output.answer if second.output.answer is not None else no_answer
+            ),
+        )
+
+        grading_output = self.grader.run(grading_input, self.tracer)
+
+        return Match(
+            outcome=grading_output,
+            player_a=first.run_id,
+            player_b=second.run_id,
+        )
diff --git a/src/intelligence_layer/evaluation/evaluation/elo_graders/llama_grader.py b/src/intelligence_layer/evaluation/evaluation/elo_graders/llama_grader.py
new file mode 100644
index 000000000..f201efe0d
--- /dev/null
+++ b/src/intelligence_layer/evaluation/evaluation/elo_graders/llama_grader.py
@@ -0,0 +1,104 @@
+import math
+from typing import Mapping, Sequence
+
+from aleph_alpha_client import Prompt
+from pydantic import BaseModel
+
+from intelligence_layer.core import (
+    CompleteInput,
+    CompleteOutput,
+    Llama3InstructModel,
+    Task,
+    TaskSpan,
+)
+from intelligence_layer.evaluation import MatchOutcome
+
+
+class LlamaGradingInput(BaseModel):
+    instruction: str
+    first_completion: str
+    second_completion: str
+
+
+class LlamaGrader(Task[LlamaGradingInput, MatchOutcome]):
+    INPUT_TEMPLATE = """
+Your task is to compare two answers to an instruction on one metric.
+
+Please make sure you read and understand these instructions carefully. Please keep this document open while reviewing, and refer to it as needed.
+
+The Instruction for the answers was:{instruction}
+
+Evaluation Procedure:
+1. Read both answers carefully and identify the main facts and details they present.
+2. Check if the answers contain any factual errors that are not supported by the instruction.
+3. Evaluate which answer is more correct.
+
+Answer A:{first_completion}
+
+Answer B:{second_completion}
+
+Which answer is more correct given the Instruction and Evaluation Procedure, Answer A or Answer B?
+
+Response: Answer """
+    VALUES = [
+        " A",
+        " B",
+    ]  # The space before the A and B is important due to tokenization
+
+    def __init__(self):
+        super().__init__()
+        self._model = Llama3InstructModel()
+
+    def do_run(self, input: LlamaGradingInput, task_span: TaskSpan) -> MatchOutcome:
+        text = self.INPUT_TEMPLATE.format(
+            instruction=input.instruction,
+            first_completion=input.first_completion,
+            second_completion=input.second_completion,
+        )
+
+        complete_input = self._create_complete_input(Prompt.from_text(text))
+        complete_output = self._model.complete_task().run(complete_input, task_span)
+
+        return self._calculate_winners(complete_output)
+
+    def _create_complete_input(self, prompt: Prompt) -> CompleteInput:
+        return CompleteInput(
+            prompt=prompt,
+            maximum_tokens=1,
+            log_probs=3,
+            disable_optimizations=True,
+        )
+
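+    # Grading is read off token log-probabilities instead of generated text: the
+    # probabilities of the completions " A" and " B" are exponentiated and
+    # normalised against each other, and the resulting share of " A" is mapped to
+    # A_WINS (> 0.7), B_WINS (< 0.3) or DRAW (otherwise).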
+    def _calculate_winners(self, complete_output: CompleteOutput) -> MatchOutcome:
+        default_log_prob = float("-inf")
+
+        def get_normalized_prob(
+            log_prob_list: Sequence[Mapping[str, float | None]] | None
+        ) -> float:
+            assert log_prob_list is not None
+            log_probs = log_prob_list[0]
+            values = [
+                math.exp(log_probs.get(key, default_log_prob) or default_log_prob)
+                for key in self.VALUES
+            ]
+            if all(v == 0 for v in values):
+                raise ValueError(
+                    f"LLM evaluation response does not contain logprobs for the required tokens for the values: {self.VALUES}"
+                )
+            normalized_A_prob = values[0] / sum(values)
+            return normalized_A_prob
+
+        def categorize_value(value: float) -> MatchOutcome:
+            if value > 0.7:
+                return MatchOutcome.A_WINS
+            elif value < 0.3:
+                return MatchOutcome.B_WINS
+            else:
+                return MatchOutcome.DRAW
+
+        normalized_probability = get_normalized_prob(
+            complete_output.completions[0].log_probs
+        )
+        return categorize_value(normalized_probability)
diff --git a/tests/evaluation/test_elo_evaluator.py b/tests/evaluation/test_elo_evaluator.py
index e9f57c9f1..14ae10725 100644
--- a/tests/evaluation/test_elo_evaluator.py
+++ b/tests/evaluation/test_elo_evaluator.py
@@ -1,13 +1,23 @@
-from itertools import combinations
+from typing import Sequence, Tuple
 
+from pytest import fixture
+
+from intelligence_layer.core import Language, Task, utc_now
+from intelligence_layer.evaluation import Evaluator, InMemoryEvaluationRepository
 from intelligence_layer.evaluation.aggregation.elo import MatchOutcome
 from intelligence_layer.evaluation.dataset.domain import Example
+from intelligence_layer.evaluation.dataset.in_memory_dataset_repository import InMemoryDatasetRepository
 from intelligence_layer.evaluation.evaluation.elo_evaluator import (
     EloEvaluationLogic,
     Match,
 )
-from intelligence_layer.evaluation.run.domain import SuccessfulExampleOutput
+from intelligence_layer.evaluation.evaluation.elo_graders.llama_grader import LlamaGrader
+from intelligence_layer.evaluation.run.domain import ExampleOutput, RunOverview, SuccessfulExampleOutput
+from intelligence_layer.evaluation.run.in_memory_run_repository import InMemoryRunRepository
+from intelligence_layer.examples import SingleChunkQaInput, SingleChunkQaOutput
 
 
 def choose_winner(first: SuccessfulExampleOutput[str], second: SuccessfulExampleOutput[str]
@@ -23,20 +33,113 @@ def choose_winner(first: SuccessfulExampleOutput[str], second: SuccessfulExample
 class LexicographicELoComparisonEvaluationLogic(
     EloEvaluationLogic[str, str]
 ):
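+    """Deterministic stand-in for an LLM grader: the lexicographically smaller run id wins."""
+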
-    def do_evaluate(
+    def _run_grader(
         self,
-        example: Example[str, str],
-        *output: SuccessfulExampleOutput[str],
-    ) -> list[Match]:
-        pairs = combinations(output, 2)
-        return [
-            Match(
-                outcome=choose_winner(first, second),
-                player_a=first.run_id,
-                player_b=second.run_id,
-            )
-            for [first, second] in pairs
-        ]
+        first: SuccessfulExampleOutput[str],
+        second: SuccessfulExampleOutput[str],
+        example: Example[str, str],
+    ) -> Match:
+        return Match(
+            outcome=choose_winner(first, second),
+            player_a=first.run_id,
+            player_b=second.run_id,
+        )
+
+
+@fixture
+def in_memory_dataset_repository() -> InMemoryDatasetRepository:
+    return InMemoryDatasetRepository()
+
+
+@fixture
+def in_memory_run_repository() -> InMemoryRunRepository:
+    return InMemoryRunRepository()
+
+
+@fixture
+def in_memory_evaluation_repository() -> InMemoryEvaluationRepository:
+    return InMemoryEvaluationRepository()
+
+
+@fixture
+def elo_evaluation_logic() -> EloEvaluationLogic:
+    # The grader is a placeholder here; the lexicographic logic never invokes it.
+    return LexicographicELoComparisonEvaluationLogic(grader=Task[None, None])
+
+
+@fixture
+def elo_evaluator(
+    in_memory_dataset_repository: InMemoryDatasetRepository,
+    in_memory_run_repository: InMemoryRunRepository,
+    in_memory_evaluation_repository: InMemoryEvaluationRepository,
+    elo_evaluation_logic: EloEvaluationLogic,
+) -> Evaluator:
+    return Evaluator(
+        in_memory_dataset_repository,
+        in_memory_run_repository,
+        in_memory_evaluation_repository,
+        "Testing",
+        elo_evaluation_logic,
+    )
+
+
+@fixture
+def qa_outputs() -> Sequence[SingleChunkQaOutput]:
+    return [
+        SingleChunkQaOutput(answer=answer, highlights=[])
+        for answer in [
+            "Surface micromachining builds microstructures.",
+            "Surface micromachining builds microstructures. This is done by deposition and etching structural layers over a substrate.",
+            "Surface micromachining builds microstructures by deposition and etching structural layers over a substrate. This is different from Bulk micromachining, in which a silicon substrate wafer is selectively etched to produce structures.",
+        ]
+    ]
+
+
+@fixture
+def qa_setup(
+    in_memory_dataset_repository: InMemoryDatasetRepository,
+    in_memory_run_repository: InMemoryRunRepository,
+    qa_outputs: Sequence[SingleChunkQaOutput],
+) -> Tuple[Sequence[str], str]:
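+    """Stores a single QA example and three runs with increasingly detailed answers for it.
+
+    Returns:
+        The ids of the stored runs and the id of the created dataset.
+    """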
+    qa_input_text = """Surface micromachining builds microstructures by deposition and etching structural layers over a substrate.[1] This is different from Bulk micromachining, in which a silicon substrate wafer is selectively etched to produce structures."""
+    qa_input = SingleChunkQaInput(
+        chunk=qa_input_text, question="What is micromachining?", language=Language("en")
+    )
+    expected_output = "Surface micromachining builds microstructures by deposition and etching structural layers over a substrate."
+
+    example_id = "some-example-id"
+    dataset_id = in_memory_dataset_repository.create_dataset(
+        examples=[
+            Example(input=qa_input, expected_output=expected_output, id=example_id)
+        ],
+        dataset_name="some-example-dataset-name",
+    ).id
+
+    run_ids = [f"some-run-id-{i}" for i in range(len(qa_outputs))]
+    for i, output in enumerate(qa_outputs):
+        in_memory_run_repository.store_example_output(
+            example_output=ExampleOutput(
+                run_id=run_ids[i],
+                example_id=example_id,
+                output=output,
+            )
+        )
+        in_memory_run_repository.store_run_overview(
+            RunOverview(
+                dataset_id=dataset_id,
+                id=run_ids[i],
+                start=utc_now(),
+                end=utc_now(),
+                failed_example_count=0,
+                successful_example_count=1,  # each run holds exactly one example
+                description="runner",
+            )
+        )
+
+    return run_ids, dataset_id
 
 
 def test_choose_winner_should_return_contestant_with_lower_run_id():
@@ -63,13 +166,36 @@ def test_do_evaluate_should_build_correct_matches():
     contestant_c = SuccessfulExampleOutput[str](run_id="c", example_id="_", output="_")
     contestants = [contestant_a, contestant_b, contestant_c]
 
-    evaluation_logic = LexicographicELoComparisonEvaluationLogic()
+    evaluation_logic = LexicographicELoComparisonEvaluationLogic(grader=Task[None, None])
 
-    matches = evaluation_logic.do_evaluate(example, *contestants)
+    matches = evaluation_logic.do_evaluate(example, *contestants).matches
 
     for match in matches:
         assert isinstance(match, Match)
         if match.player_a < match.player_b:
             assert match.outcome == MatchOutcome.A_WINS
         elif match.player_a > match.player_b:
-            assert match.outcome == MatchOutcome.B_WINS
\ No newline at end of file
+            assert match.outcome == MatchOutcome.B_WINS
+
+
+def test_full_elo_eval_run(  # TODO: Better name
+    qa_setup: Tuple[Sequence[str], str],
+    elo_evaluator: Evaluator,
+    in_memory_dataset_repository: InMemoryDatasetRepository,
+    in_memory_run_repository: InMemoryRunRepository,
+    in_memory_evaluation_repository: InMemoryEvaluationRepository,
+) -> None:
+    run_ids, _ = qa_setup
+
+    evaluation_overview = elo_evaluator.evaluate_runs(run_ids[0], run_ids[1])
+
+    # Grades with the LlamaGrader, so this part needs access to an Aleph Alpha model.
+    new_elo_qa_evaluator = Evaluator(
+        in_memory_dataset_repository,
+        in_memory_run_repository,
+        in_memory_evaluation_repository,
+        "Testing",
+        evaluation_logic=EloEvaluationLogic[SingleChunkQaInput, SingleChunkQaOutput](
+            grader=LlamaGrader()
+        ),
+    )
+
+    new_evaluation_overview = new_elo_qa_evaluator.evaluate_runs(*run_ids)
+
+    # TODO check if above code runs and add assertions
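+    # Minimal sketch of assertions for the TODO above; it assumes that
+    # EvaluationOverview exposes the run overviews it was computed from as
+    # `run_overviews`. Adjust the attribute names if the domain model differs.
+    assert evaluation_overview is not None
+    assert new_evaluation_overview is not None
+    assert len(new_evaluation_overview.run_overviews) == len(run_ids)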