From 75a24400db3ee4bb885ad323220cd21db4243884 Mon Sep 17 00:00:00 2001 From: Alex Litvinov Date: Fri, 5 Jan 2024 22:57:42 +0100 Subject: [PATCH] use custom prompt --- slack_bot/main.py | 79 ++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 64 insertions(+), 15 deletions(-) diff --git a/slack_bot/main.py b/slack_bot/main.py index 1dcd284..930d6f3 100644 --- a/slack_bot/main.py +++ b/slack_bot/main.py @@ -8,11 +8,12 @@ from langchain.chat_models import ChatOpenAI from langchain.embeddings import HuggingFaceEmbeddings from langchain.vectorstores import Pinecone -from llama_index import ServiceContext, VectorStoreIndex +from llama_index import ServiceContext, VectorStoreIndex, get_response_synthesizer, ChatPromptTemplate from llama_index.callbacks import WandbCallbackHandler, CallbackManager, LlamaDebugHandler from llama_index.indices.postprocessor import ( TimeWeightedPostprocessor ) +from llama_index.llms import ChatMessage, MessageRole from llama_index.postprocessor import CohereRerank from llama_index.query_engine import RouterQueryEngine, RetrieverQueryEngine from llama_index.selectors.pydantic_selectors import PydanticMultiSelector @@ -102,7 +103,7 @@ def handle_message_events(body): client.chat_postMessage(channel=channel_id, thread_ts=event_ts, text=f"Here you go: \n{response} \n" - f"Sources:\n{sources}" + f"References:\n{sources}" ) except Exception as e: client.chat_postMessage(channel=channel_id, @@ -125,8 +126,9 @@ def links_to_source_nodes(response): elif 'source' in node.metadata: title = node.metadata['title'] if title == 'FAQ': + section_title = node.text.split('\n', 1)[0] res.add(f"<{node.metadata['source']}|" - f" {title}-{node.text[:35]}...> ") + f" {title}-{section_title}...> ") else: res.add(f"<{node.metadata['source']}| {title}") elif 'repo' in node.metadata: @@ -302,17 +304,54 @@ def get_ml_router_query_engine(): ) -def get_de_retriever_query_engine(): - callback_manager = init_llama_index_callback_manager(DE_ZOOMCAMP_PROJECT_NAME) - return get_retriever_query_engine(callback_manager, DE_FAQ_COLLECTION_NAME) - - -def get_ml_retriever_query_engine(): +def get_prompt_template(zoomcamp_name: str, cohort_year: int) -> ChatPromptTemplate: + system_prompt = ChatMessage( + content=( + f"""You are a helpful AI assistant for the {zoomcamp_name} ZoomCamp course at DataTalksClub, + and you can be found in the course's Slack channel. + As a trustworthy assistant, you must provide helpful answers to students' questions about the course, + and assist them in finding solutions when they encounter errors while following the course. + You must do it using only the excerpts from the course FAQ document, Slack threads, and GitHub repository + in the provided context, without relying on prior knowledge. + Current cohort is year {cohort_year} one. + + Here are your guidelines: + 1. Provide clear and concise explanations for your conclusions, including relevant code snippets + if the question pertains to code. Include citations only if strictly necessary. + 2. If the question already contains an answer and requests confirmation, avoid repetition. Instead, + disregard the answer and conduct your analysis based on the provided context. + 3. In your response, refrain from rephrasing the user's question or problem; simply provide an answer. + 4. Make sure that the code examples you provide are accurate and runnable. + 5. In cases where the provided context is insufficient and you are uncertain about the response, reply with: + 'I don't think I have an answer for this; you'll have to ask your fellows or instructors. + 6. All the hyperlinks need to be taken from the provided context, not from the prior knowledge + 7. The hyperlinks need to be formatted the following way: + Example of the correctly formatted link to github: + + """ + ), + role=MessageRole.SYSTEM, + ) + user_prompt = ChatMessage(content=("Context information is below.\n" + "---------------------\n" + "{context_str}\n" + "---------------------\n" + "Given the context information and not prior knowledge, " + "answer the query.\n" + "Query: {query_str}\n" + "Answer: "), + role=MessageRole.USER, ) + return ChatPromptTemplate(message_templates=[ + system_prompt, + user_prompt, + ]) + + +def get_retriever_query_engine(collection_name: str, + zoomcamp_name: str, + cohort_year: int, ): callback_manager = init_llama_index_callback_manager(ML_ZOOMCAMP_PROJECT_NAME) - return get_retriever_query_engine(callback_manager, ML_FAQ_COLLECTION_NAME) - -def get_retriever_query_engine(callback_manager, collection_name): service_context = ServiceContext.from_defaults(embed_model=embeddings, callback_manager=callback_manager, llm=ChatOpenAI(model=GPT_MODEL_NAME, @@ -334,8 +373,15 @@ def get_retriever_query_engine(callback_manager, collection_name): cohere_rerank = CohereRerank(api_key=os.getenv('COHERE_API_KEY'), top_n=4) recency_postprocessor = get_time_weighted_postprocessor() node_postprocessors = [recency_postprocessor, cohere_rerank] + qa_prompt_template = get_prompt_template(zoomcamp_name=zoomcamp_name, + cohort_year=cohort_year, ) + response_synthesizer = get_response_synthesizer(service_context=service_context, + text_qa_template=qa_prompt_template, + verbose=True, + ) return RetrieverQueryEngine(vector_store_index.as_retriever(similarity_top_k=10), node_postprocessors=node_postprocessors, + response_synthesizer=response_synthesizer, callback_manager=callback_manager) @@ -370,8 +416,11 @@ def get_time_weighted_postprocessor(): retriever=mlops_index.as_retriever() ) - ml_query_engine = get_ml_retriever_query_engine() - - de_query_engine = get_de_retriever_query_engine() + ml_query_engine = get_retriever_query_engine(collection_name=ML_FAQ_COLLECTION_NAME, + zoomcamp_name='Machine Learning', + cohort_year=2023) + de_query_engine = get_retriever_query_engine(collection_name=DE_FAQ_COLLECTION_NAME, + zoomcamp_name='Data Engineering', + cohort_year=2024) SocketModeHandler(app, SLACK_APP_TOKEN).start()