Commit

Merge pull request #66 from Health-Informatics-UoN/feature/LLMPipeline-class

Feature/llm pipeline class
Karthi-DStech authored Oct 24, 2024
2 parents 9e89a90 + 2b4d1ee commit 15d9a35
Showing 5 changed files with 191 additions and 68 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -15,3 +15,4 @@ RAG/tmp.py
 Carrot-Assistant/omop_tmp.py
 RAG/.cache/
 *.qdrant
+/Carrot-Assistant/tests/log
74 changes: 40 additions & 34 deletions Carrot-Assistant/components/pipeline.py
@@ -9,7 +9,6 @@
 from components.embeddings import Embeddings
 from components.models import get_model
 from components.prompt import Prompts
-from tests.test_prompt_build import mock_rag_results


@@ -52,10 +51,12 @@ def get_simple_assistant(self) -> Pipeline:
         self._logger.info(f"Pipeline initialized in {time.time()-start} seconds")
         start = time.time()

-        pipeline.add_component("prompt", Prompts(
-            model_name=self._model_name,
-            eot_token=self._eot_token
-        ).get_prompt())
+        pipeline.add_component(
+            "prompt",
+            Prompts(
+                model_name=self._model_name, eot_token=self._eot_token
+            ).get_prompt(),
+        )
         self._logger.info(f"Prompt added to pipeline in {time.time()-start} seconds")
         start = time.time()

@@ -72,6 +73,7 @@ def get_simple_assistant(self) -> Pipeline:
         self._logger.info(f"Pipeline connected in {time.time()-start} seconds")

         return pipeline
+
     def get_rag_assistant(self) -> Pipeline:
         """
         Get an assistant that uses vector search to populate a prompt for an LLM
@@ -85,46 +87,50 @@ def get_rag_assistant(self) -> Pipeline:
         pipeline = Pipeline()
         self._logger.info(f"Pipeline initialized in {time.time()-start} seconds")
         start = time.time()
-
-

         vec_search = Embeddings(
-            embeddings_path=self._opt.embeddings_path,
-            force_rebuild=self._opt.force_rebuild,
-            embed_vocab=self._opt.embed_vocab,
-            model_name=self._opt.embedding_model,
-            search_kwargs=self._opt.embedding_search_kwargs
-        )
+            embeddings_path=self._opt.embeddings_path,
+            force_rebuild=self._opt.force_rebuild,
+            embed_vocab=self._opt.embed_vocab,
+            model_name=self._opt.embedding_model,
+            search_kwargs=self._opt.embedding_search_kwargs,
+        )

         vec_embedder = vec_search.get_embedder()
         vec_retriever = vec_search.get_retriever()
-        router = ConditionalRouter(routes=[
-            {
-                "condition": "{{vec_results[0].score > 0.95}}",
-                "output": "{{vec_results}}",
-                "output_name": "exact_match",
-                "output_type": List[Dict],
-            },
-            {
-                "condition": "{{vec_results[0].score <=0.95}}",
-                "output": "{{vec_results}}",
-                "output_name": "no_exact_match",
-                "output_type": List[Dict]
-            }
-        ])
+        router = ConditionalRouter(
+            routes=[
+                {
+                    "condition": "{{vec_results[0].score > 0.95}}",
+                    "output": "{{vec_results}}",
+                    "output_name": "exact_match",
+                    "output_type": List[Dict],
+                },
+                {
+                    "condition": "{{vec_results[0].score <=0.95}}",
+                    "output": "{{vec_results}}",
+                    "output_name": "no_exact_match",
+                    "output_type": List[Dict],
+                },
+            ]
+        )
         llm = get_model(
             model_name=self._model_name,
             temperature=self._opt.temperature,
             logger=self._logger,
         )

         pipeline.add_component("query_embedder", vec_embedder)
         pipeline.add_component("retriever", vec_retriever)
         pipeline.add_component("router", router)
-        pipeline.add_component("prompt", Prompts(
-            model_name=self._model_name,
-            prompt_type="top_n_RAG",
-            eot_token=self._eot_token
-        ).get_prompt())
+        pipeline.add_component(
+            "prompt",
+            Prompts(
+                model_name=self._model_name,
+                prompt_type="top_n_RAG",
+                eot_token=self._eot_token,
+            ).get_prompt(),
+        )
         pipeline.add_component("llm", llm)

         pipeline.connect("query_embedder.embedding", "retriever.query_embedding")
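
For context, the router added in get_rag_assistant gates on the score of the top vector-search result: anything above 0.95 is routed out as an exact match and skips the LLM, while anything at or below the threshold flows on to the RAG prompt. Below is a minimal standalone sketch of that rule (not part of the commit), assuming Haystack 2.x's ConditionalRouter; the result dictionaries are invented for illustration.

from typing import Dict, List

from haystack.components.routers import ConditionalRouter

# The same two routes as in get_rag_assistant: compare the top result's
# score against the 0.95 exact-match threshold.
router = ConditionalRouter(
    routes=[
        {
            "condition": "{{vec_results[0].score > 0.95}}",
            "output": "{{vec_results}}",
            "output_name": "exact_match",
            "output_type": List[Dict],
        },
        {
            "condition": "{{vec_results[0].score <= 0.95}}",
            "output": "{{vec_results}}",
            "output_name": "no_exact_match",
            "output_type": List[Dict],
        },
    ]
)

# A 0.97 score comes out under "exact_match"; a 0.62 score comes out under
# "no_exact_match" and would flow on to the prompt and LLM components.
print(router.run(vec_results=[{"score": 0.97, "content": "aspirin"}]))
print(router.run(vec_results=[{"score": 0.62, "content": "aspirin"}]))
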
80 changes: 53 additions & 27 deletions Carrot-Assistant/evaluation/evaltypes.py
@@ -1,93 +1,119 @@
 from abc import ABC, abstractmethod
-from typing import TypeVar, Generic
+from typing import TypeVar, Generic, Any


 class EvaluationFramework:
-    def __init__(self, results_file='results.json'):
+    def __init__(self, results_file="results.json"):
         self.results_file = results_file

     def run_evaluations(self):
         # Run some tests
         self._save_evaluations()

     def _save_evaluations(self):
         # Append to 'results.json'
         pass


 class Metric(ABC):
     """Base class for all metrics."""

     @abstractmethod
     def calculate(self, *args, **kwargs) -> float:
         """
         Calculate the metric value.
         """
         pass


 class TestPipeline(ABC):
     """
     Base class for Pipeline runs
     """
+
     @abstractmethod
-    def run(self, *args, **kwargs):
+    def run(self, *args, **kwargs) -> Any:
         """
         Run the pipeline
         """
-        pass
+        ...

-M = TypeVar('M', bound=Metric)
+
+M = TypeVar("M", bound=Metric)


 class PipelineTest(Generic[M]):
     """
     Base class for Pipeline tests
     """

     def __init__(self, name: str, pipeline: TestPipeline, metrics: list[M]):
         self.name = name
         self.pipeline = pipeline
         self.metrics = metrics

     @abstractmethod
     def run_pipeline(self, *args, **kwargs):
         pass

     @abstractmethod
     def evaluate(self, *args, **kwargs) -> dict[str, float]:
         pass


 class SingleResultMetric(Metric):
     """Metric for evaluating pipelines that return a single result."""


 class InformationRetrievalMetric(Metric):
     """Metric for evaluating information retrieval pipelines."""

     pass


 class SingleResultPipeline(TestPipeline):
     """
     Base class for pipelines returning a single result
     """


 class SingleResultPipelineTest(PipelineTest[SingleResultMetric]):
-    def __init__(self, name: str, pipeline: SingleResultPipeline, metrics: list[SingleResultMetric]):
+    def __init__(
+        self,
+        name: str,
+        pipeline: SingleResultPipeline,
+        metrics: list[SingleResultMetric],
+    ):
         super().__init__(name, pipeline, metrics)

     def run_pipeline(self, input_data):
         """
         Run the pipeline with the given input data.
         Args:
             input_data: The input data for the pipeline.
         Returns:
             The result of running the pipeline on the input data.
         """
         return self.pipeline.run(input_data)

     def evaluate(self, input_data, expected_output):
         """
         Evaluate the pipeline by running it on the input data and comparing the result
         to the expected output using all metrics.
         Args:
             input_data: The input data for the pipeline.
             expected_output: The expected output to compare against.
         Returns:
             A dictionary mapping metric names to their calculated values.
         """
         pipeline_output = self.run_pipeline(input_data)
-        return {metric.__class__.__name__: metric.calculate(pipeline_output, expected_output)
-                for metric in self.metrics}
+        return {
+            metric.__class__.__name__: metric.calculate(
+                pipeline_output, expected_output
+            )
+            for metric in self.metrics
+        }
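
The reworked evaluate method returns a dictionary keyed by each metric's class name. A hypothetical sketch of how these base classes compose (EchoPipeline and ExactMatch are invented for illustration and are not part of the commit):

from evaluation.evaltypes import (
    SingleResultMetric,
    SingleResultPipeline,
    SingleResultPipelineTest,
)


class ExactMatch(SingleResultMetric):
    """Scores 1.0 when the pipeline output equals the expected output."""

    def calculate(self, output, expected) -> float:
        return float(output == expected)


class EchoPipeline(SingleResultPipeline):
    """Trivial pipeline that returns its input unchanged."""

    def run(self, input_data):
        return input_data


test = SingleResultPipelineTest("echo", EchoPipeline(), [ExactMatch()])
print(test.evaluate("aspirin", "aspirin"))  # {'ExactMatch': 1.0}
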
47 changes: 47 additions & 0 deletions Carrot-Assistant/evaluation/pipelines.py
@@ -0,0 +1,47 @@
+from typing import Dict
+from evaluation.evaltypes import SingleResultPipeline
+from options.pipeline_options import LLMModel
+from components.models import local_models
+from jinja2 import Template
+from llama_cpp import Llama
+from huggingface_hub import hf_hub_download
+
+
+class LLMPipeline(SingleResultPipeline):
+    """
+    This class runs a simple LLM-only pipeline on provided input
+    """
+
+    def __init__(self, llm: LLMModel, prompt_template: Template) -> None:
+        """
+        Initialises the LLMPipeline class
+        Parameters
+        ----------
+        llm: LLMModel
+            One of the model options in the LLMModel enum
+        prompt_template: Template
+            A jinja2 template for a prompt
+        """
+        self.llm = llm
+        self.prompt_template = prompt_template
+        self._model = Llama(hf_hub_download(**local_models[self.llm.value]))
+
+    def run(self, input: Dict[str, str]) -> str:
+        """
+        Runs the LLMPipeline on a given input
+        Parameters
+        ----------
+        input: Dict[str, str]
+            The input is rendered into a prompt string by the .render method of the prompt template, so needs to be a dictionary of the template's parameters
+        Returns
+        -------
+        str
+            The output of running the prompt through the given model
+        """
+        prompt = self.prompt_template.render(input)
+        return self._model.create_chat_completion(
+            messages=[{"role": "user", "content": prompt}]
+        )["choices"][0]["message"]["content"]
