Skip to content

Commit

Permalink
Merge pull request #47 from Health-Informatics-UoN/bugfix/eot_token
Browse files Browse the repository at this point in the history
Bugfix/eot token
  • Loading branch information
Karthi-DStech authored Sep 18, 2024
2 parents cc9d167 + ada1778 commit 289c57b
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 8 deletions.
Empty file.
16 changes: 14 additions & 2 deletions Carrot-Assistant/components/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from components.embeddings import Embeddings
from components.models import get_model
from components.prompt import Prompts
from tests.test_prompt_build import mock_rag_results


class llm_pipeline:
Expand All @@ -32,6 +33,10 @@ def __init__(
self._opt = opt
self._model_name = opt.llm_model
self._logger = logger
if "llama-3.1" in opt.llm_model:
self._eot_token = "<|eot_id|>"
else:
self._eot_token = ""

def get_simple_assistant(self) -> Pipeline:
"""
Expand All @@ -47,7 +52,10 @@ def get_simple_assistant(self) -> Pipeline:
self._logger.info(f"Pipeline initialized in {time.time()-start} seconds")
start = time.time()

pipeline.add_component("prompt", Prompts(self._model_name).get_prompt())
pipeline.add_component("prompt", Prompts(
model_name=self._model_name,
eot_token=self._eot_token
).get_prompt())
self._logger.info(f"Prompt added to pipeline in {time.time()-start} seconds")
start = time.time()

Expand Down Expand Up @@ -112,7 +120,11 @@ def get_rag_assistant(self) -> Pipeline:
pipeline.add_component("query_embedder", vec_embedder)
pipeline.add_component("retriever", vec_retriever)
pipeline.add_component("router", router)
pipeline.add_component("prompt", Prompts(self._model_name, "top_n_RAG").get_prompt())
pipeline.add_component("prompt", Prompts(
model_name=self._model_name,
prompt_type="top_n_RAG",
eot_token=self._eot_token
).get_prompt())
pipeline.add_component("llm", llm)

pipeline.connect("query_embedder.embedding", "retriever.query_embedding")
Expand Down
14 changes: 8 additions & 6 deletions Carrot-Assistant/components/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ class Prompts:
def __init__(
self,
model_name: str,
prompt_type: str = "simple",
prompt_type: str="simple",
eot_token: str=""
) -> None:
"""
Initializes the Prompts class
Expand All @@ -23,6 +24,7 @@ def __init__(
"""
self._model_name = model_name
self._prompt_type = prompt_type
self._eot_token = eot_token
# I hate how the triple-quoted strings look, but if you indent them they preserve the indentation. You can use textwrap.dedent to solve it, but that's not pleasing either.
# modify this so it only adds the EOT token for llama 3.1
self._prompt_templates = {
Expand All @@ -45,8 +47,7 @@ def __init__(
Task:
Informal name: {{informal_name}}<|eot_id|>
Response:
Informal name: {{informal_name}}
""",
"top_n_RAG": """
You are an assistant that suggests formal RxNorm names for a medication. You will be given the name of a medication, along with some possibly related RxNorm terms. If you do not think these terms are related, ignore them when making your suggestion.
Expand All @@ -71,8 +72,7 @@ def __init__(
{% endfor %}
Task:
Informal name: {{informal_name}}<|eot_id|>
Response:
Informal name: {{informal_name}}
""",
}

Expand All @@ -88,7 +88,9 @@ def get_prompt(self) -> PromptBuilder:
- If the _prompt_type of the object is "simple", returns a simple prompt for few-shot learning of formal drug names.
"""
try:
return PromptBuilder(self._prompt_templates[self._prompt_type])
template = self._prompt_templates[self._prompt_type]
template += self._eot_token + "\nResponse:"
return PromptBuilder(template)
except KeyError:
print(f"No prompt named {self._prompt_type}")

57 changes: 57 additions & 0 deletions Carrot-Assistant/tests/test_prompt_build.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import pytest
from components.prompt import Prompts

@pytest.fixture
def llama_3_simple_prompt_builder():
    """A simple-prompt PromptBuilder for a llama-3 model (no EOT token supplied)."""
    prompts = Prompts(
        model_name="llama-3-8b",
        prompt_type="simple",
    )
    return prompts.get_prompt()

@pytest.fixture
def llama_3_rag_prompt_builder():
    """A top_n_RAG PromptBuilder for a llama-3 model (no EOT token supplied)."""
    prompts = Prompts(
        model_name="llama-3-8b",
        prompt_type="top_n_RAG",
    )
    return prompts.get_prompt()

@pytest.fixture
def llama_3_1_simple_prompt_builder():
    """A simple-prompt PromptBuilder for llama-3.1, with the EOT token appended."""
    prompts = Prompts(
        model_name="llama-3.1-8b",
        prompt_type="simple",
        eot_token="<|eot_id|>",
    )
    return prompts.get_prompt()

@pytest.fixture
def llama_3_1_rag_prompt_builder():
    """A top_n_RAG PromptBuilder for llama-3.1, with the EOT token appended."""
    prompts = Prompts(
        model_name="llama-3.1-8b",
        prompt_type="top_n_RAG",
        eot_token="<|eot_id|>",
    )
    return prompts.get_prompt()

@pytest.fixture
def mock_rag_results():
    """A minimal stand-in for vector-retrieval output: one hit whose content is 'apple'."""
    return [{'content': 'apple'}]

def test_simple_prompt_returned(llama_3_simple_prompt_builder):
    """The informal name must be interpolated into the rendered simple prompt."""
    rendered = llama_3_simple_prompt_builder.run(informal_name="banana")["prompt"]
    assert "banana" in rendered

def test_rag_prompt_returned(llama_3_rag_prompt_builder, mock_rag_results):
    """Both the informal name and the retrieved content must appear in the RAG prompt."""
    rendered = llama_3_rag_prompt_builder.run(
        informal_name="banana",
        vec_results=mock_rag_results,
    )["prompt"]
    assert "banana" in rendered
    assert "apple" in rendered

def test_simple_prompt_with_eot(llama_3_1_simple_prompt_builder):
    """For llama-3.1 the rendered simple prompt must contain the EOT token.

    NOTE(review): the simple template appears to contain <|eot_id|> in its
    few-shot examples as well, so this containment check may not isolate the
    token appended by ``eot_token`` — consider asserting on the prompt suffix
    instead. Confirm against the full template.
    """
    rendered = llama_3_1_simple_prompt_builder.run(informal_name="banana")["prompt"]
    assert "banana" in rendered
    assert "<|eot_id|>" in rendered

def test_rag_prompt_with_eot(llama_3_1_rag_prompt_builder, mock_rag_results):
    """For llama-3.1 the rendered RAG prompt must contain name, content, and EOT token.

    NOTE(review): as with the simple-prompt test, <|eot_id|> also appears
    inside the template's examples, so this check may pass without the
    appended token — verify against the full template.
    """
    rendered = llama_3_1_rag_prompt_builder.run(
        informal_name="banana",
        vec_results=mock_rag_results,
    )["prompt"]
    assert "banana" in rendered
    assert "apple" in rendered
    assert "<|eot_id|>" in rendered

0 comments on commit 289c57b

Please sign in to comment.