diff --git a/slack_bot/main.py b/slack_bot/main.py index 830b2f2..a5768b8 100644 --- a/slack_bot/main.py +++ b/slack_bot/main.py @@ -14,7 +14,7 @@ TimeWeightedPostprocessor ) from llama_index.postprocessor import CohereRerank -from llama_index.query_engine import RouterQueryEngine +from llama_index.query_engine import RouterQueryEngine, RetrieverQueryEngine from llama_index.selectors.pydantic_selectors import PydanticMultiSelector from llama_index.tools import QueryEngineTool from llama_index.vector_stores import MilvusVectorStore, MetadataFilters @@ -46,10 +46,15 @@ ML_FAQ_TOOL_DESCRIPTION = ("Useful for retrieving specific context from the course FAQ document as well as " "information about course syllabus and deadlines, schedule in other words. " - "Also, it contains midterm and capstone project evaluation criteria") + "Also, it contains midterm and capstone project evaluation criteria. " + "It is recommended to always check the FAQ first and then refer to the other sources.") +ML_GITHUB_TOOL_DESCRIPTION = ("Useful for retrieving specific context from the course GitHub repository that " + "contains information about the code, " + "as well as the formulation of homework assignments.") ML_SLACK_TOOL_DESCRIPTION = ("Useful for retrieving specific context from the course " "slack channel history especially the questions about homework " - "or when it's not likely to appear in the FAQ document") + "or when it's not likely to appear in the FAQ document. Also, Slack history can have " + "answers to any question, so it's always worth looking it up.") # Event API & Web API SLACK_BOT_TOKEN = os.getenv('SLACK_BOT_TOKEN') @@ -114,10 +119,26 @@ def links_to_source_nodes(response): link_template.format(channel_id, thread_ts_str) res.add(link_template.format(channel_id, thread_ts_str)) elif 'source' in node.metadata: - res.add(f"<{node.metadata['source']}|{node.metadata['title']}> ") + title = node.metadata['title'] + if title == 'FAQ': + res.add(f"<{node.metadata['source']}|" + f" {title}-{node.text[:35]}...> ") + else: + res.add(f"<{node.metadata['source']}| {title}") + elif 'repo' in node.metadata: + repo = node.metadata['repo'] + owner = node.metadata['owner'] + branch = node.metadata['branch'] + file_path = node.metadata['file_path'] + link_to_file = build_repo_path(owner=owner, repo=repo, branch=branch, file_path=file_path) + res.add(f'<{link_to_file}| GitHub-{repo}-{file_path.split("/")[-1]}>') return '\n'.join(res) +def build_repo_path(owner: str, repo: str, branch: str, file_path: str): + return f'https://github.com/{owner}/{repo}/blob/{branch}/{file_path}' + + def remove_mentions(input_text): # Define a regular expression pattern to match the mention mention_pattern = r'<@U[0-9A-Z]+>' @@ -228,7 +249,7 @@ def init_llama_index_callback_manager(): return CallbackManager([wandb_callback, llama_debug]) -def get_ml_query_engine(): +def get_ml_router_query_engine(): callback_manager = init_llama_index_callback_manager() # Set llm temperature to 0.7 for generation service_context = ServiceContext.from_defaults(embed_model=embeddings, @@ -238,7 +259,16 @@ def get_ml_query_engine(): faq_tool = get_query_engine_tool_by_name(collection_name=ML_FAQ_COLLECTION_NAME, service_context=service_context, description=ML_FAQ_TOOL_DESCRIPTION, - route='faq') + route='faq', + ) + + github_tool = get_query_engine_tool_by_name(collection_name=ML_FAQ_COLLECTION_NAME, + service_context=service_context, + description=ML_GITHUB_TOOL_DESCRIPTION, + similarity_top_k=6, + rerank_top_n=3, + route='github', + ) slack_tool = get_query_engine_tool_by_name(collection_name=ML_SLACK_COLLECTION_NAME, service_context=service_context, @@ -246,7 +276,8 @@ def get_ml_query_engine(): similarity_top_k=20, rerank_top_n=3, rerank_by_time=True, - route='slack') + route='slack', + ) # Create the multi selector query engine # Set llm temperature to 0.4 for routing @@ -259,11 +290,51 @@ def get_ml_query_engine(): query_engine_tools=[ slack_tool, faq_tool, + github_tool ], service_context=router_service_context ) +def get_ml_retriever_query_engine(): + callback_manager = init_llama_index_callback_manager() + service_context = ServiceContext.from_defaults(embed_model=embeddings, + callback_manager=callback_manager, + llm=ChatOpenAI(model=GPT_MODEL_NAME, + temperature=0.7)) + if os.getenv('LOCAL_MILVUS', None): + localhost = os.getenv('LOCALHOST', 'localhost') + vector_store = MilvusVectorStore(collection_name=ML_FAQ_COLLECTION_NAME, + dim=embedding_dimension, + overwrite=False, + uri=f'http://{localhost}:19530') + else: + vector_store = MilvusVectorStore(collection_name=ML_FAQ_COLLECTION_NAME, + uri=os.getenv("ZILLIZ_CLOUD_URI"), + token=os.getenv("ZILLIZ_CLOUD_API_KEY"), + dim=embedding_dimension, + overwrite=False) + vector_store_index = VectorStoreIndex.from_vector_store(vector_store, + service_context=service_context) + + cohere_rerank = CohereRerank(api_key=os.getenv('COHERE_API_KEY'), top_n=4) + recency_postprocessor = get_time_weighted_postprocessor() + node_postprocessors = [recency_postprocessor, cohere_rerank] + + return RetrieverQueryEngine(vector_store_index.as_retriever(similarity_top_k=10), + node_postprocessors=node_postprocessors, + callback_manager=callback_manager) + + +def get_time_weighted_postprocessor(): + return TimeWeightedPostprocessor( + last_accessed_key='thread_ts', + time_decay=0.4, + time_access_refresh=False, + top_k=10, + ) + + if __name__ == "__main__": client = WebClient(SLACK_BOT_TOKEN) @@ -287,6 +358,6 @@ def get_ml_query_engine(): retriever=mlops_index.as_retriever() ) - ml_query_engine = get_ml_query_engine() + ml_query_engine = get_ml_retriever_query_engine() SocketModeHandler(app, SLACK_APP_TOKEN).start()