From 7ab59c4683ae2b24ede8927db581f9c6860e0693 Mon Sep 17 00:00:00 2001
From: James Mitchell-White
Date: Mon, 16 Sep 2024 12:28:32 +0100
Subject: [PATCH 1/4] eot_id only added if llama-3.1

Added the logic so that the end-of-turn token is only added at the end
of the prompt if the model is llama-3.1. Other models don't like it.
---
 Carrot-Assistant/components/__init__.py     |  0
 Carrot-Assistant/components/prompt.py       | 14 +++++++++-----
 Carrot-Assistant/tests/test_prompt_build.py | 20 ++++++++++++++++++++
 3 files changed, 29 insertions(+), 5 deletions(-)
 create mode 100644 Carrot-Assistant/components/__init__.py
 create mode 100644 Carrot-Assistant/tests/test_prompt_build.py

diff --git a/Carrot-Assistant/components/__init__.py b/Carrot-Assistant/components/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/Carrot-Assistant/components/prompt.py b/Carrot-Assistant/components/prompt.py
index ff410d1..0c5d766 100644
--- a/Carrot-Assistant/components/prompt.py
+++ b/Carrot-Assistant/components/prompt.py
@@ -45,8 +45,7 @@ def __init__(
 
 Task:
 
-Informal name: {{informal_name}}<|eot_id|>
-Response:
+Informal name: {{informal_name}}
 """,
             "top_n_RAG": """
 You are an assistant that suggests formal RxNorm names for a medication. You will be given the name of a medication, along with some possibly related RxNorm terms. If you do not think these terms are related, ignore them when making your suggestion.
@@ -71,8 +70,7 @@ def __init__(
 {% endfor %}
 
 Task:
 
-Informal name: {{informal_name}}<|eot_id|>
-Response:
+Informal name: {{informal_name}}
 """,
         }
@@ -88,7 +86,13 @@ def get_prompt(self) -> PromptBuilder:
             - If the _prompt_type of the object is "simple", returns a simple prompt for few-shot learning of formal drug names.
         """
         try:
-            return PromptBuilder(self._prompt_templates[self._prompt_type])
+            template = self._prompt_templates[self._prompt_type]
+            if "llama-3.1" in self._model_name:
+                template += """<|eot_id|>
+                Response:"""
+            else:
+                template += "\nResponse:"
+            return PromptBuilder(template)
         except KeyError:
             print(f"No prompt named {self._prompt_type}")
 
diff --git a/Carrot-Assistant/tests/test_prompt_build.py b/Carrot-Assistant/tests/test_prompt_build.py
new file mode 100644
index 0000000..fe70e1a
--- /dev/null
+++ b/Carrot-Assistant/tests/test_prompt_build.py
@@ -0,0 +1,20 @@
+import pytest
+from components.prompt import Prompts
+
+@pytest.fixture
+def llama_3_simple_prompt_builder():
+    return Prompts(
+        model_name="llama-3-8b",
+        prompt_type="simple",
+    )
+
+@pytest.fixture
+def llama_3_rag_prompt_builder():
+    return Prompts(
+        model_name="llama-3-8b",
+        prompt_type="top_n_RAG",
+    )
+
+@pytest.fixture
+def llama_3_1_simple_prompt_builder():
+    return Prompts
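A note on what patch 1 does: per the commit message, the Llama 3.1 chat format expects an explicit <|eot_id|> end-of-turn token, while other models mishandle it, so the token is appended only when the model name matches. A minimal runnable sketch of that branch (the helper name finish_template is illustrative, not from the repository):

# Sketch of the conditional introduced in patch 1: append the end-of-turn
# token only for llama-3.1 models; other models get a bare "Response:" cue.
def finish_template(template: str, model_name: str) -> str:
    if "llama-3.1" in model_name:
        return template + "<|eot_id|>\nResponse:"
    return template + "\nResponse:"

print(finish_template("Informal name: {{informal_name}}", "llama-3.1-8b"))
print(finish_template("Informal name: {{informal_name}}", "llama-3-8b"))
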
From 659e77b9d7f610b8ee7c538ebab218cad01c51d7 Mon Sep 17 00:00:00 2001
From: James Mitchell-White
Date: Mon, 16 Sep 2024 16:14:19 +0100
Subject: [PATCH 2/4] Update test_prompt_build.py

Added test fixtures for the prompt builder and a simple test.
---
 Carrot-Assistant/tests/test_prompt_build.py | 24 ++++++++++++++++++---
 1 file changed, 21 insertions(+), 3 deletions(-)

diff --git a/Carrot-Assistant/tests/test_prompt_build.py b/Carrot-Assistant/tests/test_prompt_build.py
index fe70e1a..980d786 100644
--- a/Carrot-Assistant/tests/test_prompt_build.py
+++ b/Carrot-Assistant/tests/test_prompt_build.py
@@ -6,15 +6,33 @@ def llama_3_simple_prompt_builder():
     return Prompts(
         model_name="llama-3-8b",
         prompt_type="simple",
-    )
+    ).get_prompt()
 
 @pytest.fixture
 def llama_3_rag_prompt_builder():
     return Prompts(
         model_name="llama-3-8b",
         prompt_type="top_n_RAG",
-    )
+    ).get_prompt()
 
 @pytest.fixture
 def llama_3_1_simple_prompt_builder():
-    return Prompts
+    return Prompts(
+        model_name="llama-3.1-8b",
+        prompt_type="simple",
+        eot_token="<|eot_id|>"
+    ).get_prompt()
+
+@pytest.fixture
+def llama_3_1_rag_prompt_builder():
+    return Prompts(
+        model_name="llama-3.1-8b",
+        prompt_type="top_n_RAG",
+        eot_token="<|eot_id|>"
+    ).get_prompt()
+
+def test_simple_prompt_returned(llama_3_simple_prompt_builder):
+    assert "banana" in llama_3_simple_prompt_builder.run(informal_name="banana")["prompt"]
+
+def test_rag_prompt_returned(llama_3_rag_prompt_builder):
+

From ad32ad02acd43f76d681a56f41883e2b4bd79d1f Mon Sep 17 00:00:00 2001
From: James Mitchell-White
Date: Tue, 17 Sep 2024 10:56:04 +0100
Subject: [PATCH 3/4] Added tests

Added tests for the version of the prompt builder that adds an
end-of-turn token.
---
 Carrot-Assistant/components/prompt.py       | 10 ++++-----
 Carrot-Assistant/tests/test_prompt_build.py | 23 +++++++++++++++++++--
 2 files changed, 25 insertions(+), 8 deletions(-)

diff --git a/Carrot-Assistant/components/prompt.py b/Carrot-Assistant/components/prompt.py
index 0c5d766..1c0637a 100644
--- a/Carrot-Assistant/components/prompt.py
+++ b/Carrot-Assistant/components/prompt.py
@@ -9,7 +9,8 @@ class Prompts:
     def __init__(
         self,
         model_name: str,
-        prompt_type: str = "simple",
+        prompt_type: str="simple",
+        eot_token: str=""
     ) -> None:
         """
         Initializes the Prompts class
@@ -23,6 +24,7 @@ def __init__(
         """
         self._model_name = model_name
         self._prompt_type = prompt_type
+        self._eot_token = eot_token
         # I hate how the triple-quoted strings look, but if you indent them they preserve the indentation. You can use textwrap.dedent to solve it, but that's not pleasing either.
         # modify this so it only adds the EOT token for llama 3.1
         self._prompt_templates = {
@@ -87,11 +89,7 @@ def get_prompt(self) -> PromptBuilder:
         """
         try:
             template = self._prompt_templates[self._prompt_type]
-            if "llama-3.1" in self._model_name:
-                template += """<|eot_id|>
-                Response:"""
-            else:
-                template += "\nResponse:"
+            template += self._eot_token + "\nResponse:"
             return PromptBuilder(template)
         except KeyError:
             print(f"No prompt named {self._prompt_type}")
diff --git a/Carrot-Assistant/tests/test_prompt_build.py b/Carrot-Assistant/tests/test_prompt_build.py
index 980d786..a1c1952 100644
--- a/Carrot-Assistant/tests/test_prompt_build.py
+++ b/Carrot-Assistant/tests/test_prompt_build.py
@@ -31,8 +31,27 @@ def llama_3_1_rag_prompt_builder():
         eot_token="<|eot_id|>"
     ).get_prompt()
 
+@pytest.fixture
+def mock_rag_results():
+    return [
+        {'content': 'apple'}
+    ]
+
 def test_simple_prompt_returned(llama_3_simple_prompt_builder):
     assert "banana" in llama_3_simple_prompt_builder.run(informal_name="banana")["prompt"]
 
-def test_rag_prompt_returned(llama_3_rag_prompt_builder):
-
+def test_rag_prompt_returned(llama_3_rag_prompt_builder, mock_rag_results):
+    result = llama_3_rag_prompt_builder.run(informal_name="banana", vec_results=mock_rag_results)["prompt"]
+    assert "banana" in result
+    assert "apple" in result
+
+def test_simple_prompt_with_eot(llama_3_1_simple_prompt_builder):
+    result = llama_3_1_simple_prompt_builder.run(informal_name="banana")["prompt"]
+    assert "banana" in result
+    assert "<|eot_id|>" in result
+
+def test_rag_prompt_with_eot(llama_3_1_rag_prompt_builder, mock_rag_results):
+    result = llama_3_1_rag_prompt_builder.run(informal_name="banana", vec_results=mock_rag_results)["prompt"]
+    assert "banana" in result
+    assert "apple" in result
+    assert "<|eot_id|>" in result
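A note on patches 2 and 3: after the refactor, the caller decides whether an end-of-turn token is needed and passes it in, so get_prompt() no longer inspects the model name, and the tests double as usage documentation. A minimal sketch of that calling pattern, mirroring the fixtures above (assumes the components package is importable, as in the test suite):

from components.prompt import Prompts

# Build a prompt for a llama-3.1 model, supplying the end-of-turn token
# explicitly rather than having the template code guess from the name.
builder = Prompts(
    model_name="llama-3.1-8b",
    prompt_type="simple",
    eot_token="<|eot_id|>",
).get_prompt()

# The underlying PromptBuilder fills the Jinja placeholder and returns the
# rendered text under the "prompt" key, exactly as the tests exercise it.
prompt = builder.run(informal_name="banana")["prompt"]
assert "banana" in prompt and "<|eot_id|>" in prompt
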
From ada1778163d8addd72306c2230cf122f794a43f4 Mon Sep 17 00:00:00 2001
From: James Mitchell-White
Date: Tue, 17 Sep 2024 11:02:26 +0100
Subject: [PATCH 4/4] Update pipeline.py

Added the eot_token option to the pipeline. If we use llama-3.1
derivatives that don't have "llama-3.1" in the name, we will need a
more robust solution, but this will do for now.
---
 Carrot-Assistant/components/pipeline.py | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/Carrot-Assistant/components/pipeline.py b/Carrot-Assistant/components/pipeline.py
index 6073e18..29e8249 100644
--- a/Carrot-Assistant/components/pipeline.py
+++ b/Carrot-Assistant/components/pipeline.py
@@ -9,6 +9,7 @@
 from components.embeddings import Embeddings
 from components.models import get_model
 from components.prompt import Prompts
+from tests.test_prompt_build import mock_rag_results
 
 
 class llm_pipeline:
@@ -32,6 +33,10 @@ def __init__(
         self._opt = opt
         self._model_name = opt.llm_model
         self._logger = logger
+        if "llama-3.1" in opt.llm_model:
+            self._eot_token = "<|eot_id|>"
+        else:
+            self._eot_token = ""
 
     def get_simple_assistant(self) -> Pipeline:
         """
@@ -47,7 +52,10 @@ def get_simple_assistant(self) -> Pipeline:
         self._logger.info(f"Pipeline initialized in {time.time()-start} seconds")
         start = time.time()
 
-        pipeline.add_component("prompt", Prompts(self._model_name).get_prompt())
+        pipeline.add_component("prompt", Prompts(
+            model_name=self._model_name,
+            eot_token=self._eot_token
+        ).get_prompt())
         self._logger.info(f"Prompt added to pipeline in {time.time()-start} seconds")
         start = time.time()
 
@@ -112,7 +120,11 @@ def get_rag_assistant(self) -> Pipeline:
         pipeline.add_component("query_embedder", vec_embedder)
         pipeline.add_component("retriever", vec_retriever)
         pipeline.add_component("router", router)
-        pipeline.add_component("prompt", Prompts(self._model_name, "top_n_RAG").get_prompt())
+        pipeline.add_component("prompt", Prompts(
+            model_name=self._model_name,
+            prompt_type="top_n_RAG",
+            eot_token=self._eot_token
+        ).get_prompt())
         pipeline.add_component("llm", llm)
 
         pipeline.connect("query_embedder.embedding", "retriever.query_embedding")
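On the caveat in the last commit message: if llama-3.1 derivatives ever ship under names that don't contain "llama-3.1", the substring check in llm_pipeline.__init__ will silently skip the token. One possible shape for the more robust solution it anticipates (purely an illustrative sketch; the table contents and function name are not part of these patches):

# Sketch: resolve the end-of-turn token from an explicit table first, and
# only fall back to the substring heuristic the pipeline currently uses.
EOT_TOKENS = {
    "llama-3.1-8b": "<|eot_id|>",
    # hypothetical derivative whose name omits "llama-3.1":
    "some-llama-3.1-finetune": "<|eot_id|>",
}

def eot_token_for(model_name: str) -> str:
    if model_name in EOT_TOKENS:
        return EOT_TOKENS[model_name]
    return "<|eot_id|>" if "llama-3.1" in model_name else ""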