Skip to content

Commit

Permalink
use custom prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
aaalexlit committed Jan 5, 2024
1 parent d799221 commit 75a2440
Showing 1 changed file with 64 additions and 15 deletions.
79 changes: 64 additions & 15 deletions slack_bot/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Pinecone
from llama_index import ServiceContext, VectorStoreIndex
from llama_index import ServiceContext, VectorStoreIndex, get_response_synthesizer, ChatPromptTemplate
from llama_index.callbacks import WandbCallbackHandler, CallbackManager, LlamaDebugHandler
from llama_index.indices.postprocessor import (
TimeWeightedPostprocessor
)
from llama_index.llms import ChatMessage, MessageRole
from llama_index.postprocessor import CohereRerank
from llama_index.query_engine import RouterQueryEngine, RetrieverQueryEngine
from llama_index.selectors.pydantic_selectors import PydanticMultiSelector
Expand Down Expand Up @@ -102,7 +103,7 @@ def handle_message_events(body):
client.chat_postMessage(channel=channel_id,
thread_ts=event_ts,
text=f"Here you go: \n{response} \n"
f"Sources:\n{sources}"
f"References:\n{sources}"
)
except Exception as e:
client.chat_postMessage(channel=channel_id,
Expand All @@ -125,8 +126,9 @@ def links_to_source_nodes(response):
elif 'source' in node.metadata:
title = node.metadata['title']
if title == 'FAQ':
section_title = node.text.split('\n', 1)[0]
res.add(f"<{node.metadata['source']}|"
f" {title}-{node.text[:35]}...> ")
f" {title}-{section_title}...> ")
else:
res.add(f"<{node.metadata['source']}| {title}")
elif 'repo' in node.metadata:
Expand Down Expand Up @@ -302,17 +304,54 @@ def get_ml_router_query_engine():
)


def get_de_retriever_query_engine():
    """Return the retriever query engine for the DE FAQ collection.

    A callback manager is initialised for ``DE_ZOOMCAMP_PROJECT_NAME`` and
    handed, together with ``DE_FAQ_COLLECTION_NAME``, to
    ``get_retriever_query_engine``.
    """
    return get_retriever_query_engine(
        init_llama_index_callback_manager(DE_ZOOMCAMP_PROJECT_NAME),
        DE_FAQ_COLLECTION_NAME,
    )


def get_ml_retriever_query_engine():
def get_prompt_template(zoomcamp_name: str, cohort_year: int) -> ChatPromptTemplate:
    """Build the question-answering chat prompt for one ZoomCamp course.

    The template consists of two messages:

    * a SYSTEM message carrying the assistant persona and answering
      guidelines, with the course name and cohort year baked in;
    * a USER message that wraps the retrieved context and the query.

    Args:
        zoomcamp_name: Course name inserted into the system message,
            e.g. ``'Data Engineering'``.
        cohort_year: Year of the current cohort, inserted into the
            system message.

    Returns:
        A ``ChatPromptTemplate`` holding the system and user messages,
        in that order.
    """
    # Assembled line by line so the prompt text does not depend on how this
    # source file happens to be indented. Only the two lines that interpolate
    # arguments are f-strings.
    system_text = (
        f"You are a helpful AI assistant for the {zoomcamp_name} ZoomCamp course at DataTalksClub,\n"
        "and you can be found in the course's Slack channel.\n"
        "As a trustworthy assistant, you must provide helpful answers to students' questions about the course,\n"
        "and assist them in finding solutions when they encounter errors while following the course.\n"
        "You must do it using only the excerpts from the course FAQ document, Slack threads, and GitHub repository\n"
        "in the provided context, without relying on prior knowledge.\n"
        f"Current cohort is year {cohort_year} one.\n"
        "Here are your guidelines:\n"
        "1. Provide clear and concise explanations for your conclusions, including relevant code snippets\n"
        "if the question pertains to code. Include citations only if strictly necessary.\n"
        "2. If the question already contains an answer and requests confirmation, avoid repetition. Instead,\n"
        "disregard the answer and conduct your analysis based on the provided context.\n"
        "3. In your response, refrain from rephrasing the user's question or problem; simply provide an answer.\n"
        "4. Make sure that the code examples you provide are accurate and runnable.\n"
        "5. In cases where the provided context is insufficient and you are uncertain about the response, reply with:\n"
        # Closing quote added: the canned reply was previously left
        # unterminated, so its end was ambiguous to the model.
        "'I don't think I have an answer for this; you'll have to ask your fellows or instructors.'\n"
        "6. All the hyperlinks need to be taken from the provided context, not from the prior knowledge\n"
        "7. The hyperlinks need to be formatted the following way: <hyperlink|displayed text>\n"
        "Example of the correctly formatted link to github:\n"
        "<https://github.com/DataTalksClub/data-engineering-zoomcamp|DE zoomcamp GitHub repo>\n"
    )
    # Deliberately NOT an f-string: {context_str} and {query_str} stay as
    # literal placeholders for the prompt template to substitute later.
    user_text = (
        "Context information is below.\n"
        "---------------------\n"
        "{context_str}\n"
        "---------------------\n"
        "Given the context information and not prior knowledge, "
        "answer the query.\n"
        "Query: {query_str}\n"
        "Answer: "
    )
    system_message = ChatMessage(content=system_text, role=MessageRole.SYSTEM)
    user_message = ChatMessage(content=user_text, role=MessageRole.USER)
    return ChatPromptTemplate(message_templates=[system_message, user_message])


def get_retriever_query_engine(collection_name: str,
zoomcamp_name: str,
cohort_year: int, ):
callback_manager = init_llama_index_callback_manager(ML_ZOOMCAMP_PROJECT_NAME)
return get_retriever_query_engine(callback_manager, ML_FAQ_COLLECTION_NAME)


def get_retriever_query_engine(callback_manager, collection_name):
service_context = ServiceContext.from_defaults(embed_model=embeddings,
callback_manager=callback_manager,
llm=ChatOpenAI(model=GPT_MODEL_NAME,
Expand All @@ -334,8 +373,15 @@ def get_retriever_query_engine(callback_manager, collection_name):
cohere_rerank = CohereRerank(api_key=os.getenv('COHERE_API_KEY'), top_n=4)
recency_postprocessor = get_time_weighted_postprocessor()
node_postprocessors = [recency_postprocessor, cohere_rerank]
qa_prompt_template = get_prompt_template(zoomcamp_name=zoomcamp_name,
cohort_year=cohort_year, )
response_synthesizer = get_response_synthesizer(service_context=service_context,
text_qa_template=qa_prompt_template,
verbose=True,
)
return RetrieverQueryEngine(vector_store_index.as_retriever(similarity_top_k=10),
node_postprocessors=node_postprocessors,
response_synthesizer=response_synthesizer,
callback_manager=callback_manager)


Expand Down Expand Up @@ -370,8 +416,11 @@ def get_time_weighted_postprocessor():
retriever=mlops_index.as_retriever()
)

ml_query_engine = get_ml_retriever_query_engine()

de_query_engine = get_de_retriever_query_engine()
ml_query_engine = get_retriever_query_engine(collection_name=ML_FAQ_COLLECTION_NAME,
zoomcamp_name='Machine Learning',
cohort_year=2023)

de_query_engine = get_retriever_query_engine(collection_name=DE_FAQ_COLLECTION_NAME,
zoomcamp_name='Data Engineering',
cohort_year=2024)
SocketModeHandler(app, SLACK_APP_TOKEN).start()

0 comments on commit 75a2440

Please sign in to comment.