Skip to content

Commit

Permalink
Merge branch 'main' into refactor/apirouter
Browse files Browse the repository at this point in the history
  • Loading branch information
kuraisle authored Sep 5, 2024
2 parents 35925a2 + fa68a12 commit acf8757
Show file tree
Hide file tree
Showing 10 changed files with 159 additions and 33 deletions.
Binary file removed .DS_Store
Binary file not shown.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
*.pyc
/*/__pycache__/
.env
.DS_Store
.vscode/settings.json
AI Assistant/tmp.ipynb
AI Assistant/.vscode/settings.json
Expand Down
Binary file removed Carrot-Assistant/.DS_Store
Binary file not shown.
1 change: 1 addition & 0 deletions Carrot-Assistant/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
prefix="/pipeline",
)


if __name__ == "__main__":
import uvicorn

Expand Down
47 changes: 47 additions & 0 deletions Carrot-Assistant/components/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,23 @@
from omop.omop_models import Concept

class EmbeddingModelName(str, Enum):
"""
This enumerates the embedding models we have the download details for
"""
BGESMALL = "BGESMALL"
MINILM = "MINILM"

class EmbeddingModelInfo(BaseModel):
"""
A simple class to hold the information for embeddings models
"""
path: str
dimensions: int

class EmbeddingModel(BaseModel):
"""
A class to match the name of an embeddings model with the details required to download and use it.
"""
name: EmbeddingModelName
info: EmbeddingModelInfo

Expand All @@ -32,6 +41,20 @@ class EmbeddingModel(BaseModel):
}

def get_embedding_model(name: EmbeddingModelName) -> EmbeddingModel:
"""
Collects the details of an embedding model when given its name
Parameters
----------
name: EmbeddingModelName
The name of an embedding model we have the details for
Returns
-------
EmbeddingModel
An EmbeddingModel object containing the name and the details used
"""
return EmbeddingModel(name=name, info=EMBEDDING_MODELS[name])

class Embeddings:
Expand Down Expand Up @@ -130,6 +153,30 @@ def _load_embeddings(self):
recreate_index=False # We're loading existing embeddings, don't recreate
)

def get_embedder(self) -> FastembedTextEmbedder:
"""
Get an embedder for queries in LLM pipelines
Returns
_______
FastembedTextEmbedder
"""
query_embedder = FastembedTextEmbedder(model=self.model.info.path, parallel=0)
query_embedder.warm_up()
return query_embedder


def get_retriever(self) -> QdrantEmbeddingRetriever:
"""
Get a retriever for LLM pipelines
Returns
-------
QdrantEmbeddingRetriever
"""
print(self.search_kwargs)
return QdrantEmbeddingRetriever(document_store=self.embeddings_store, **self.search_kwargs)

def search(self, query: List[str]) -> List[List[Dict[str,Any]]]:
"""
Search the attached vector database with a list of informal medications
Expand Down
11 changes: 4 additions & 7 deletions Carrot-Assistant/components/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging

from haystack.components.generators import OpenAIGenerator
from haystack_integrations.components.generators.llama_cpp import LlamaCppGenerator
import torch

local_models = {
Expand Down Expand Up @@ -27,7 +28,7 @@

def get_model(
model_name: str, temperature: float = 0.7, logger: logging.Logger | None = None
) -> object:
) -> OpenAIGenerator | LlamaCppGenerator:
"""
Get an interface for interacting with an LLM
Expand All @@ -50,17 +51,13 @@ def get_model(
"""
if "gpt" in model_name.lower():
logger.info(f"Loading {model_name} model")
from haystack.components.generators import OpenAIGenerator

llm = OpenAIGenerator(
model=model_name, generation_kwargs={"temperature": temperature}
)

else:
logger.info(f"Loading {model_name} model")
from haystack_integrations.components.generators.llama_cpp import (
LlamaCppGenerator,
)
from huggingface_hub import hf_hub_download

device = -1 if torch.cuda.is_available() else 0
Expand All @@ -78,7 +75,7 @@ def get_model(
except:
print(f"Error loading {model_name}")
finally:
logger.info(f"Loading llama-3.1-8b")
logger.info("Loading llama-3.1-8b")
llm = LlamaCppGenerator(
model=hf_hub_download(**local_models[model_name]),
n_ctx=0,
Expand Down
59 changes: 59 additions & 0 deletions Carrot-Assistant/components/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import argparse
import logging
import time
from typing import List, Dict

from haystack import Pipeline
from haystack.components.routers import ConditionalRouter

from components.embeddings import Embeddings
from components.models import get_model
from components.prompt import Prompts

Expand Down Expand Up @@ -59,6 +62,62 @@ def get_simple_assistant(self) -> Pipeline:

pipeline.connect("prompt.prompt", "llm.prompt")
self._logger.info(f"Pipeline connected in {time.time()-start} seconds")

return pipeline
def get_rag_assistant(self) -> Pipeline:
"""
Get an assistant that uses vector search to populate a prompt for an LLM
Returns
-------
Pipeline
The pipeline for the assistant
"""
start = time.time()
pipeline = Pipeline()
self._logger.info(f"Pipeline initialized in {time.time()-start} seconds")
start = time.time()


vec_search = Embeddings(
embeddings_path=self._opt.embeddings_path,
force_rebuild=self._opt.force_rebuild,
embed_vocab=self._opt.embed_vocab,
model_name=self._opt.embedding_model,
search_kwargs=self._opt.embedding_search_kwargs
)

vec_embedder = vec_search.get_embedder()
vec_retriever = vec_search.get_retriever()
router = ConditionalRouter(routes=[
{
"condition": "{{vec_results[0].score > 0.95}}",
"output": "{{vec_results}}",
"output_name": "exact_match",
"output_type": List[Dict],
},
{
"condition": "{{vec_results[0].score <=0.95}}",
"output": "{{vec_results}}",
"output_name": "no_exact_match",
"output_type": List[Dict]
}
])
llm = get_model(
model_name=self._model_name,
temperature=self._opt.temperature,
logger=self._logger,
)

pipeline.add_component("query_embedder", vec_embedder)
pipeline.add_component("retriever", vec_retriever)
pipeline.add_component("router", router)
pipeline.add_component("prompt", Prompts(self._model_name, "top_n_RAG").get_prompt())
pipeline.add_component("llm", llm)

pipeline.connect("query_embedder.embedding", "retriever.query_embedding")
pipeline.connect("retriever.documents", "router.vec_results")
pipeline.connect("router.no_exact_match", "prompt.vec_results")
pipeline.connect("prompt.prompt", "llm.prompt")

return pipeline
73 changes: 47 additions & 26 deletions Carrot-Assistant/components/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class Prompts:
def __init__(
self,
model_name: str,
prompt_type: str | None = "simple",
prompt_type: str = "simple",
) -> None:
"""
Initializes the Prompts class
Expand All @@ -23,34 +23,35 @@ def __init__(
"""
self._model_name = model_name
self._prompt_type = prompt_type
# I hate how the triple-quoted strings look, but if you indent them they preserve the indentation. You can use textwrap.dedent to solve it, but that's not pleasing either.
# modify this so it only adds the EOT token for llama 3.1
self._prompt_templates = {
"simple": """
You will be given the informal name of a medication. Respond only with the formal name of that medication, without any extra explanation.
def get_prompt(self) -> PromptBuilder | None:
"""
Get the prompt based on the prompt_type supplied to the object.
Examples:
Returns
-------
PromptBuilder
The prompt for the model
Informal name: Tylenol
Response: Acetaminophen
- If the _prompt_type of the object is "simple", returns a simple prompt for few-shot learning of formal drug names.
"""
if self._prompt_type == "simple":
return self._simple_prompt()
Informal name: Advil
Response: Ibuprofen
def _simple_prompt(self) -> PromptBuilder:
"""
Get a simple prompt
Informal name: Motrin
Response: Ibuprofen
Returns
-------
PromptBuilder
The simple prompt
"""
prompt_template = """
You will be given the informal name of a medication. Respond only with the formal name of that medication, without any extra explanation.
Informal name: Aleve
Response: Naproxen
Examples:
Task:
Informal name: {{informal_name}}<|eot_id|>
Response:
""",
"top_n_RAG": """
You are an assistant that suggests formal RxNorm names for a medication. You will be given the name of a medication, along with some possibly related RxNorm terms. If you do not think these terms are related, ignore them when making your suggestion.
Respond only with the formal name of the medication, without any extra explanation.
Informal name: Tylenol
Response: Acetaminophen
Expand All @@ -64,10 +65,30 @@ def _simple_prompt(self) -> PromptBuilder:
Informal name: Aleve
Response: Naproxen
Task:
Possible related terms:
{% for result in vec_results %}
{{result.content}}
{% endfor %}
Task:
Informal name: {{informal_name}}<|eot_id|>
Response:
"""
""",
}

return PromptBuilder(template=prompt_template)
def get_prompt(self) -> PromptBuilder:
"""
Get the prompt based on the prompt_type supplied to the object.
Returns
-------
PromptBuilder
The prompt for the model
- If the _prompt_type of the object is "simple", returns a simple prompt for few-shot learning of formal drug names.
"""
try:
return PromptBuilder(self._prompt_templates[self._prompt_type])
except KeyError:
print(f"No prompt named {self._prompt_type}")

Binary file removed llettuce-docs/.DS_Store
Binary file not shown.
Binary file removed llettuce-docs/_build/.DS_Store
Binary file not shown.

0 comments on commit acf8757

Please sign in to comment.