Skip to content

Commit

Permalink
switch to retriever query engine
Browse files — browse the repository at this point in the history
add github source processing
  • Loading branch information
aaalexlit committed Dec 28, 2023
1 parent f3fe5db commit d6515e5
Showing 1 changed file with 79 additions and 8 deletions.
87 changes: 79 additions & 8 deletions slack_bot/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
TimeWeightedPostprocessor
)
from llama_index.postprocessor import CohereRerank
from llama_index.query_engine import RouterQueryEngine
from llama_index.query_engine import RouterQueryEngine, RetrieverQueryEngine
from llama_index.selectors.pydantic_selectors import PydanticMultiSelector
from llama_index.tools import QueryEngineTool
from llama_index.vector_stores import MilvusVectorStore, MetadataFilters
Expand Down Expand Up @@ -46,10 +46,15 @@

# Tool descriptions surfaced to the LLM router/selector: it reads these texts
# to decide which query-engine tool(s) should handle an incoming question, so
# the wording directly steers routing behavior.
ML_FAQ_TOOL_DESCRIPTION = ("Useful for retrieving specific context from the course FAQ document as well as "
"information about course syllabus and deadlines, schedule in other words. "
"Also, it contains midterm and capstone project evaluation criteria. "
"It is recommended to always check the FAQ first and then refer to the other sources.")
# Description for the GitHub-repository tool (code and homework formulations).
ML_GITHUB_TOOL_DESCRIPTION = ("Useful for retrieving specific context from the course GitHub repository that "
"contains information about the code, "
"as well as the formulation of homework assignments.")
# Description for the Slack-history tool; phrased to encourage the router to
# consult Slack broadly, not only for homework questions.
ML_SLACK_TOOL_DESCRIPTION = ("Useful for retrieving specific context from the course "
"slack channel history especially the questions about homework "
"or when it's not likely to appear in the FAQ document. Also, Slack history can have "
"answers to any question, so it's always worth looking it up.")

# Event API & Web API
SLACK_BOT_TOKEN = os.getenv('SLACK_BOT_TOKEN')
Expand Down Expand Up @@ -114,10 +119,26 @@ def links_to_source_nodes(response):
link_template.format(channel_id, thread_ts_str)
res.add(link_template.format(channel_id, thread_ts_str))
elif 'source' in node.metadata:
res.add(f"<{node.metadata['source']}|{node.metadata['title']}> ")
title = node.metadata['title']
if title == 'FAQ':
res.add(f"<{node.metadata['source']}|"
f" {title}-{node.text[:35]}...> ")
else:
res.add(f"<{node.metadata['source']}| {title}")
elif 'repo' in node.metadata:
repo = node.metadata['repo']
owner = node.metadata['owner']
branch = node.metadata['branch']
file_path = node.metadata['file_path']
link_to_file = build_repo_path(owner=owner, repo=repo, branch=branch, file_path=file_path)
res.add(f'<{link_to_file}| GitHub-{repo}-{file_path.split("/")[-1]}>')
return '\n'.join(res)


def build_repo_path(owner: str, repo: str, branch: str, file_path: str) -> str:
    """Return the GitHub web URL for ``file_path`` on ``branch`` of ``owner/repo``.

    The path segment is percent-encoded (keeping ``/`` separators) so that
    file names containing spaces or other URL-unsafe characters still yield
    a valid, clickable link in Slack.
    """
    from urllib.parse import quote  # local import keeps the helper self-contained
    return f'https://github.com/{owner}/{repo}/blob/{branch}/{quote(file_path)}'


def remove_mentions(input_text):
# Define a regular expression pattern to match the mention
mention_pattern = r'<@U[0-9A-Z]+>'
Expand Down Expand Up @@ -228,7 +249,7 @@ def init_llama_index_callback_manager():
return CallbackManager([wandb_callback, llama_debug])


def get_ml_query_engine():
def get_ml_router_query_engine():
callback_manager = init_llama_index_callback_manager()
# Set llm temperature to 0.7 for generation
service_context = ServiceContext.from_defaults(embed_model=embeddings,
Expand All @@ -238,15 +259,25 @@ def get_ml_query_engine():
faq_tool = get_query_engine_tool_by_name(collection_name=ML_FAQ_COLLECTION_NAME,
service_context=service_context,
description=ML_FAQ_TOOL_DESCRIPTION,
route='faq')
route='faq',
)

github_tool = get_query_engine_tool_by_name(collection_name=ML_FAQ_COLLECTION_NAME,
service_context=service_context,
description=ML_GITHUB_TOOL_DESCRIPTION,
similarity_top_k=6,
rerank_top_n=3,
route='github',
)

slack_tool = get_query_engine_tool_by_name(collection_name=ML_SLACK_COLLECTION_NAME,
service_context=service_context,
description=ML_SLACK_TOOL_DESCRIPTION,
similarity_top_k=20,
rerank_top_n=3,
rerank_by_time=True,
route='slack')
route='slack',
)

# Create the multi selector query engine
# Set llm temperature to 0.4 for routing
Expand All @@ -259,11 +290,51 @@ def get_ml_query_engine():
query_engine_tools=[
slack_tool,
faq_tool,
github_tool
],
service_context=router_service_context
)


def get_ml_retriever_query_engine():
    """Build a RetrieverQueryEngine over the ML FAQ Milvus collection.

    The retriever fetches the 10 most similar nodes, which are then
    post-processed by recency weighting and Cohere re-ranking (top 4)
    before answer synthesis. Generation uses the GPT model at
    temperature 0.7.
    """
    cb_manager = init_llama_index_callback_manager()
    ctx = ServiceContext.from_defaults(
        embed_model=embeddings,
        callback_manager=cb_manager,
        llm=ChatOpenAI(model=GPT_MODEL_NAME, temperature=0.7),
    )

    # Shared Milvus settings; only the connection target differs between a
    # local instance and Zilliz Cloud.
    store_kwargs = dict(
        collection_name=ML_FAQ_COLLECTION_NAME,
        dim=embedding_dimension,
        overwrite=False,
    )
    if os.getenv('LOCAL_MILVUS', None):
        host = os.getenv('LOCALHOST', 'localhost')
        store_kwargs['uri'] = f'http://{host}:19530'
    else:
        store_kwargs['uri'] = os.getenv("ZILLIZ_CLOUD_URI")
        store_kwargs['token'] = os.getenv("ZILLIZ_CLOUD_API_KEY")
    store = MilvusVectorStore(**store_kwargs)

    index = VectorStoreIndex.from_vector_store(store, service_context=ctx)

    # Order matters: recency weighting runs first, then Cohere trims to top 4.
    postprocessors = [
        get_time_weighted_postprocessor(),
        CohereRerank(api_key=os.getenv('COHERE_API_KEY'), top_n=4),
    ]

    return RetrieverQueryEngine(
        index.as_retriever(similarity_top_k=10),
        node_postprocessors=postprocessors,
        callback_manager=cb_manager,
    )


def get_time_weighted_postprocessor():
    """Return a TimeWeightedPostprocessor keyed on Slack's `thread_ts`.

    Nodes with more recent `thread_ts` values are favored; access times
    are not refreshed, and at most 10 nodes are kept.
    """
    settings = {
        'last_accessed_key': 'thread_ts',
        'time_decay': 0.4,
        'time_access_refresh': False,
        'top_k': 10,
    }
    return TimeWeightedPostprocessor(**settings)


if __name__ == "__main__":
client = WebClient(SLACK_BOT_TOKEN)

Expand All @@ -287,6 +358,6 @@ def get_ml_query_engine():
retriever=mlops_index.as_retriever()
)

ml_query_engine = get_ml_query_engine()
ml_query_engine = get_ml_retriever_query_engine()

SocketModeHandler(app, SLACK_APP_TOKEN).start()

0 comments on commit d6515e5

Please sign in to comment.