Added type test
Minor change to the type signature of the TestPipeline.run abstract method.

Added a test that LLMPipeline.run returns a string.
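
To run just the new test locally, something like

    pytest Carrot-Assistant/tests/test_evals.py -k TestBasicLLM

should work (the path is taken from the diff below). Note that the test drives a real llama-3.1-8b model through LLMPipeline, so it presumably needs that model to be available in your environment.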
kuraisle committed Oct 18, 2024
1 parent 7907b79 commit 13269d6
Showing 2 changed files with 53 additions and 9 deletions.
5 changes: 3 additions & 2 deletions Carrot-Assistant/evaluation/evaltypes.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import TypeVar, Generic
+from typing import TypeVar, Generic, Any
 
 
 class EvaluationFramework:
@@ -32,7 +32,7 @@ class TestPipeline(ABC):
     """
 
     @abstractmethod
-    def run(self, *args, **kwargs):
+    def run(self, *args, **kwargs) -> Any:
         """
         Run the pipeline
         """
@@ -48,6 +48,7 @@ class PipelineTest(Generic[M]):
     """
 
     def __init__(self, name: str, pipeline: TestPipeline, metrics: list[M]):
+        self.name = name
         self.pipeline = pipeline
         self.metrics = metrics
 
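For context, a minimal sketch of what the two evaltypes.py changes enable: a concrete pipeline can narrow the base run() return type from Any, and a test object can now be referred to by its stored name. The describe() helper below is hypothetical, not part of this commit.

    # Sketch only, not part of the commit.
    from evaluation.evaltypes import TestPipeline

    class EchoPipeline(TestPipeline):
        def run(self, input_data: str) -> str:  # narrower than the base -> Any
            return input_data

    def describe(test) -> str:
        # PipelineTest now stores its name, so a report can reference it
        return f"{test.name}: {len(test.metrics)} metric(s)"
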
57 changes: 50 additions & 7 deletions Carrot-Assistant/tests/test_evals.py
@@ -1,26 +1,36 @@
 import pytest
+from jinja2 import Environment, Template
 
 from evaluation.evaltypes import SingleResultPipeline, SingleResultPipelineTest
 from evaluation.metrics import ExactMatchMetric
+from evaluation.pipelines import LLMPipeline
+
+from options.pipeline_options import LLMModel
 
 
 class IdentityPipeline(SingleResultPipeline):
     def run(self, input_data):
         return input_data
 
 
 class ExactMatchTest(SingleResultPipelineTest):
     def __init__(self, name: str, pipeline: SingleResultPipeline):
         super().__init__(name, pipeline, [ExactMatchMetric()])
 
     def run_pipeline(self, input_data):
         return self.pipeline.run(input_data)
 
 
 class TestExactMatch:
     @pytest.fixture
     def identity_pipeline(self):
         return IdentityPipeline()
 
     @pytest.fixture
     def exact_match_test(self, identity_pipeline):
-        return SingleResultPipelineTest("Exact Match Test", identity_pipeline, [ExactMatchMetric()])
+        return SingleResultPipelineTest(
+            "Exact Match Test", identity_pipeline, [ExactMatchMetric()]
+        )
 
     @pytest.fixture
     def all_match_dataset(self):
@@ -32,12 +32,20 @@ def no_match_dataset(self):
 
     @pytest.fixture
     def half_match_dataset(self):
-        return [("input1", "input1"), ("input2", "output2"), ("input3", "input3"), ("input4", "output4")]
+        return [
+            ("input1", "input1"),
+            ("input2", "output2"),
+            ("input3", "input3"),
+            ("input4", "output4"),
+        ]
 
     def run_test(self, test, dataset):
-        results = [test.evaluate(input_data, expected_output) for input_data, expected_output in dataset]
-        exact_match_results = [result['ExactMatchMetric'] for result in results]
-        return sum(exact_match_results) / len(exact_match_results)
+        results = [
+            test.evaluate(input_data, expected_output)
+            for input_data, expected_output in dataset
+        ]
+        exact_match_results = [result["ExactMatchMetric"] for result in results]
+        return sum(exact_match_results) / len(exact_match_results)
 
     def test_all_match(self, exact_match_test, all_match_dataset):
         assert self.run_test(exact_match_test, all_match_dataset) == 1.0
@@ -47,3 +65,28 @@ def test_no_match(self, exact_match_test, no_match_dataset):
 
     def test_half_match(self, exact_match_test, half_match_dataset):
         assert self.run_test(exact_match_test, half_match_dataset) == 0.5
+
+
+# LLM pipeline tests
+
+
+class TestBasicLLM:
+    @pytest.fixture
+    def llm_prompt(self):
+        env = Environment()
+        template = env.from_string(
+            """
+            You are a parrot that repeats whatever is said to you, with no explanation. You will be given a sentence as input, repeat it.
+            Sentence: {{input_sentence}}
+            """
+        )
+        return template
+
+    @pytest.fixture
+    def llm_pipeline(self, llm_prompt):
+        return LLMPipeline(LLMModel["llama-3.1-8b"], llm_prompt)
+
+    def test_returns_string(self, llm_pipeline):
+        model_output = llm_pipeline.run({"input_sentence": "Polly wants a cracker"})
+        assert isinstance(model_output, str)
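
As an aside, the prompt fixture can be checked on its own. A sketch, assuming LLMPipeline.run renders the template with the input dict it receives (the pipeline's internals are not shown in this diff):

    # Sketch only: render the parrot prompt outside the pipeline.
    from jinja2 import Environment

    env = Environment()
    template = env.from_string(
        "You are a parrot that repeats whatever is said to you, with no "
        "explanation. You will be given a sentence as input, repeat it.\n"
        "Sentence: {{input_sentence}}"
    )
    # Prints the fully substituted prompt that the model would see.
    print(template.render(input_sentence="Polly wants a cracker"))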
