Skip to content

Commit

Permalink
Update pipelines.py
Browse files Browse the repository at this point in the history
Documentation
  • Loading branch information
kuraisle committed Oct 18, 2024
1 parent a5b4a31 commit 2b4d1ee
Showing 1 changed file with 27 additions and 0 deletions.
27 changes: 27 additions & 0 deletions Carrot-Assistant/evaluation/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,39 @@


class LLMPipeline(SingleResultPipeline):
"""
This class runs a simple LLM-only pipeline on provided input
"""

def __init__(self, llm: LLMModel, prompt_template: Template) -> None:
"""
Initialises the LLMPipeline class
Parameters
----------
llm: LLMModel
One of the model options in the LLMModel enum
prompt_template: Template
A jinja2 template for a prompt
"""
self.llm = llm
self.prompt_template = prompt_template
self._model = Llama(hf_hub_download(**local_models[self.llm.value]))

def run(self, input: Dict[str, str]) -> str:
"""
Runs the LLMPipeline on a given input
Parameters
----------
input: Dict[str, str]
The input is rendered into a prompt string by the .render method of the prompt template, so needs to be a dictionary of the template's parameters
Returns
-------
str
The output of running the prompt through the given model
"""
prompt = self.prompt_template.render(input)
return self._model.create_chat_completion(
messages=[{"role": "user", "content": prompt}]
Expand Down

0 comments on commit 2b4d1ee

Please sign in to comment.