feat: WIP add first part of more comprehensive elo test

TASK: IL-394

Parent: 16dde84
Commit: 8c289e5
Showing 3 changed files with 311 additions and 49 deletions.
src/intelligence_layer/evaluation/evaluation/elo_graders/llama_grader.py (new file, 104 additions, 0 deletions)
import math
from typing import Mapping, Sequence

from aleph_alpha_client import Prompt
from pydantic import BaseModel

from intelligence_layer.core import (
    CompleteInput,
    CompleteOutput,
    Llama3InstructModel,
    Task,
    TaskSpan,
)
from intelligence_layer.evaluation import MatchOutcome


class LlamaGradingInput(BaseModel):
    instruction: str
    first_completion: str
    second_completion: str


class LlamaGrader(Task[LlamaGradingInput, MatchOutcome]):
    """Compares two completions to the same instruction with a Llama model
    and maps the model's preference to a MatchOutcome (A wins, B wins, draw)."""

    INPUT_TEMPLATE = """
Your task is to compare two answers to an instruction on one metric.
Please make sure you read and understand these instructions carefully. Please keep this document open while reviewing, and refer to it as needed.
The Instruction for the answers was: {instruction}
Evaluation Procedure:
1. Read both answers carefully and identify the main facts and details they present.
2. Check if the answers contain any factual errors that are not supported by the instruction.
3. Evaluate which answer is more correct.
Answer A: {first_completion}
Answer B: {second_completion}
Which answer is more correct given the Instruction and Evaluation Procedure, Answer A or Answer B?
Response: Answer """
    VALUES = [
        " A",
        " B",
    ]  # The space before the A and B is important due to tokenization

    def __init__(self) -> None:
        super().__init__()
        # Instantiate the model; assigning the class itself here was a bug.
        self._model = Llama3InstructModel()

    def do_run(self, input: LlamaGradingInput, task_span: TaskSpan) -> MatchOutcome:
        text = self.INPUT_TEMPLATE.format(
            instruction=input.instruction,
            first_completion=input.first_completion,
            second_completion=input.second_completion,
        )

        complete_input = self._create_complete_input(Prompt.from_text(text))
        complete_output = self._model.complete_task().run(complete_input, task_span)

        return self._calculate_winners(complete_output)

    def _create_complete_input(self, prompt: Prompt) -> CompleteInput:
        # Force a single-token completion and request the top log probs,
        # so the A/B preference can be read off the token distribution.
        return CompleteInput(
            prompt=prompt,
            maximum_tokens=1,
            log_probs=3,
            disable_optimizations=True,
        )

    def _calculate_winners(self, complete_output: CompleteOutput) -> MatchOutcome:
        default_log_prob = float("-inf")

        def get_normalized_prob(
            log_prob_list: Sequence[Mapping[str, float | None]] | None,
        ) -> float:
            assert log_prob_list is not None
            log_probs = log_prob_list[0]
            # Convert the log probs of " A" and " B" to probabilities;
            # a missing or None entry falls back to -inf, i.e. probability 0.
            values = [
                math.exp(log_probs.get(key, default_log_prob) or default_log_prob)
                for key in self.VALUES
            ]
            if all(v == 0 for v in values):
                raise ValueError(
                    f"LLM evaluation response does not contain log probs for the required tokens: {self.VALUES}"
                )
            # Probability mass assigned to Answer A, normalized over both answers.
            return values[0] / sum(values)

        def categorize_value(value: float) -> MatchOutcome:
            if value > 0.7:
                return MatchOutcome.A_WINS
            elif value < 0.3:
                return MatchOutcome.B_WINS
            else:
                return MatchOutcome.DRAW

        normalized_probability = get_normalized_prob(
            complete_output.completions[0].log_probs
        )
        return categorize_value(normalized_probability)
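
For reviewers who want to try the grader locally, a minimal usage sketch follows. It is not part of this commit: it assumes the standard Task.run(input, tracer) entry point and the NoOpTracer from intelligence_layer.core, a configured Aleph Alpha backend, and the example input values are made up for illustration.

from intelligence_layer.core import NoOpTracer

# Hypothetical example input: one instruction plus two candidate answers.
grading_input = LlamaGradingInput(
    instruction="Name the capital of France.",
    first_completion="The capital of France is Paris.",
    second_completion="The capital of France is Lyon.",
)

# run() wraps do_run() with tracing and returns a MatchOutcome.
# With the thresholds above: P(A) > 0.7 -> A_WINS, P(A) < 0.3 -> B_WINS, else DRAW.
outcome = LlamaGrader().run(grading_input, NoOpTracer())
print(outcome)  # expected for this pair: MatchOutcome.A_WINS

Reading the verdict out of the log probs of a single forced token (maximum_tokens=1) keeps each comparison to one cheap completion and turns the A/B preference into a continuous probability rather than a hard label, which is what the 0.3/0.7 draw band relies on.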