Skip to content

Commit

Permalink
change mlops part to use llamaindex retriever query engine
Browse files Browse the repository at this point in the history
  • Loading branch information
aaalexlit committed May 12, 2024
1 parent b677f87 commit 527b7e5
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 178 deletions.
3 changes: 3 additions & 0 deletions dev.env
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ WANDB_API_KEY=..
ZILLIZ_CLOUD_URI=https://..
ZILLIZ_CLOUD_API_KEY=..

ZILLIZ_PUBLIC_ENDPOINT=https://..
ZILLIZ_API_KEY=..

LANGCHAIN_API_KEY=ls__..

COHERE_API_KEY=..
Expand Down
3 changes: 3 additions & 0 deletions slack_bot/dev.env
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ LANGCHAIN_API_KEY=ls__..
ZILLIZ_CLOUD_URI=https://..
ZILLIZ_CLOUD_API_KEY=..

ZILLIZ_PUBLIC_ENDPOINT=https://..
ZILLIZ_API_KEY=..

COHERE_API_KEY=..

# DEBUG log level
Expand Down
203 changes: 28 additions & 175 deletions slack_bot/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,18 @@
import sys
import uuid

import pinecone
from cohere import CohereAPIError
from langchain import callbacks
from langchain.chains import RetrievalQA
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Pinecone
from langchain_openai import ChatOpenAI
from langsmith import Client
from llama_index.callbacks.wandb import WandbCallbackHandler
from llama_index.core import ChatPromptTemplate
from llama_index.core import ServiceContext
from llama_index.core import VectorStoreIndex
from llama_index.core import get_response_synthesizer
from llama_index.core.callbacks import CallbackManager, LlamaDebugHandler
from llama_index.core.llms import ChatMessage, MessageRole
from llama_index.core.postprocessor import TimeWeightedPostprocessor
from llama_index.core.query_engine import RouterQueryEngine, RetrieverQueryEngine
from llama_index.core.selectors import PydanticMultiSelector
from llama_index.core.tools import QueryEngineTool
from llama_index.core.vector_stores import ExactMatchFilter
from llama_index.core.vector_stores import MetadataFilters
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.postprocessor.cohere_rerank import CohereRerank
from llama_index.vector_stores.milvus import MilvusVectorStore
from requests.exceptions import ChunkedEncodingError
Expand All @@ -43,32 +34,20 @@

DE_CHANNELS = ['C01FABYF2RG', 'C06CBSE16JC', 'C06BZJX8PSP']
ML_CHANNELS = ['C0288NJ5XSA', 'C05C3SGMLBB', 'C05DTQECY66']
MLOPS_CHANNELS = ['C02R98X7DS9', 'C06C1N46CQ1']
MLOPS_CHANNELS = ['C02R98X7DS9', 'C06C1N46CQ1', 'C0735558X52']

ALLOWED_CHANNELS = DE_CHANNELS + ML_CHANNELS + MLOPS_CHANNELS

PROJECT_NAME = 'datatalks-faq-slackbot'
ML_ZOOMCAMP_PROJECT_NAME = 'ml-zoomcamp-slack-bot'
DE_ZOOMCAMP_PROJECT_NAME = 'de-zoomcamp-slack-bot'

MLOPS_INDEX_NAME = 'mlops-faq-bot'
ML_FAQ_COLLECTION_NAME = 'mlzoomcamp_faq_git'
DE_FAQ_COLLECTION_NAME = 'dezoomcamp_faq_git'
ML_COLLECTION_NAME = 'mlzoomcamp_faq_git'
DE_COLLECTION_NAME = 'dezoomcamp_faq_git'
MLOPS_COLLECTION_NAME = 'mlopszoomcamp'

GPT_MODEL_NAME = 'gpt-3.5-turbo-0125'

ML_FAQ_TOOL_DESCRIPTION = ("Useful for retrieving specific context from the course FAQ document as well as "
"information about course syllabus and deadlines, schedule in other words. "
"Also, it contains midterm and capstone project evaluation criteria. "
"It is recommended to always check the FAQ first and then refer to the other sources.")
ML_GITHUB_TOOL_DESCRIPTION = ("Useful for retrieving specific context from the course GitHub repository that "
"contains information about the code, "
"as well as the formulation of homework assignments.")
ML_SLACK_TOOL_DESCRIPTION = ("Useful for retrieving specific context from the course "
"slack channel history especially the questions about homework "
"or when it's not likely to appear in the FAQ document. Also, Slack history can have "
"answers to any question, so it's always worth looking it up.")

# Event API & Web API
SLACK_BOT_TOKEN = os.getenv('SLACK_BOT_TOKEN')
SLACK_APP_TOKEN = os.getenv('SLACK_APP_TOKEN')
Expand Down Expand Up @@ -199,7 +178,7 @@ def handle_message_events(body):
try:
with callbacks.collect_runs() as cb:
if channel_id in MLOPS_CHANNELS:
response = mlops_qa.run(question)
response = mlops_query_engine.query(question)
elif channel_id in ML_CHANNELS:
response = ml_query_engine.query(question)
else:
Expand Down Expand Up @@ -352,14 +331,12 @@ def get_greeting_message(channel_id):
"The answers might not be accurate since I'm " \
"just a human-friendly interface to the " \
"<https://docs.google.com/document/d/{link}| {name} Zoomcamp FAQ>" \
"{additional_message} and this course's <https://github.com/DataTalksClub/{repo}|GitHub repo>." \
", this Slack channel, and this course's <https://github.com/DataTalksClub/{repo}|GitHub repo>." \
"\nThanks for your request, I'm on it!"
additional_message = ", this Slack channel,"
if channel_id in MLOPS_CHANNELS:
name = 'MLOps'
link = '12TlBfhIiKtyBv8RnsoJR6F72bkPDGEvPOItJIxaEzE0/edit#heading=h.uwpp1jrsj0d'
repo = 'mlops-zoomcamp'
additional_message = ""
elif channel_id in ML_CHANNELS:
name = 'ML'
link = '1LpPanc33QJJ6BSsyxVg-pWNMplal84TdZtq10naIhD8/edit#heading=h.98qq6wfuzeck'
Expand All @@ -368,13 +345,7 @@ def get_greeting_message(channel_id):
name = 'DE'
link = '19bnYs80DwuUimHM65UV3sylsCn2j1vziPOwzBwQrebw/edit#heading=h.o29af0z8xx88'
repo = 'data-engineering-zoomcamp'
return message_template.format(name=name, link=link, repo=repo, additional_message=additional_message)


def log_langchain_to_wandb():
# Log everything to WANDB!!!
os.environ["LANGCHAIN_WANDB_TRACING"] = "true"
os.environ["WANDB_PROJECT"] = PROJECT_NAME
return message_template.format(name=name, link=link, repo=repo)


def log_to_langsmith():
Expand All @@ -383,124 +354,6 @@ def log_to_langsmith():
os.environ["LANGCHAIN_PROJECT"] = PROJECT_NAME


def setup_mlops_index():
logger.info('Initiating pinecone client...')
pinecone.init(
api_key=os.getenv('PINECONE_API_KEY'),
environment=os.getenv('PINECONE_ENV')
)
pinecone_index = Pinecone.from_existing_index(index_name=MLOPS_INDEX_NAME,
embedding=embeddings)
index = pinecone.GRPCIndex(MLOPS_INDEX_NAME)
logger.info(f"Mlops index stats: {index.describe_index_stats()}")
return pinecone_index


def get_query_engine_tool_by_name(collection_name: str,
service_context: ServiceContext,
description: str,
route: str = None,
similarity_top_k: int = 4,
rerank_top_n: int = 2,
rerank_by_time: bool = False):
if os.getenv('LOCAL_MILVUS', None):
localhost = os.getenv('LOCALHOST', 'localhost')
vector_store = MilvusVectorStore(collection_name=collection_name,
dim=embedding_dimension,
overwrite=False,
uri=f'http://{localhost}:19530')
else:
vector_store = MilvusVectorStore(collection_name=collection_name,
uri=os.getenv("ZILLIZ_CLOUD_URI"),
token=os.getenv("ZILLIZ_CLOUD_API_KEY"),
dim=embedding_dimension,
overwrite=False)
vector_store_index = VectorStoreIndex.from_vector_store(vector_store,
service_context=service_context)

cohere_rerank = CohereRerank(api_key=os.getenv('COHERE_API_KEY'), top_n=rerank_top_n)
node_postprocessors = [cohere_rerank]
if rerank_by_time:
key = 'thread_ts'
recency_postprocessor = TimeWeightedPostprocessor(
last_accessed_key=key,
time_decay=0.4,
time_access_refresh=False,
top_k=10
)
node_postprocessors.insert(0, recency_postprocessor)

if route:
filters = MetadataFilters(
filters=[ExactMatchFilter(key="route", value=route)]
)
else:
filters = None

return QueryEngineTool.from_defaults(
query_engine=vector_store_index.as_query_engine(
similarity_top_k=similarity_top_k,
node_postprocessors=node_postprocessors,
filters=filters,
),
description=description,
name=route
)


def init_llama_index_callback_manager(project_name):
llama_debug = LlamaDebugHandler(print_trace_on_end=True)
wandb_callback = WandbCallbackHandler(run_args=dict(project=project_name))
return CallbackManager([wandb_callback, llama_debug])


def get_ml_router_query_engine():
callback_manager = init_llama_index_callback_manager(ML_ZOOMCAMP_PROJECT_NAME)
# Set llm temperature to 0.7 for generation
service_context = ServiceContext.from_defaults(embed_model=embeddings,
callback_manager=callback_manager,
llm=ChatOpenAI(model=GPT_MODEL_NAME,
temperature=0.7))
faq_tool = get_query_engine_tool_by_name(collection_name=ML_FAQ_COLLECTION_NAME,
service_context=service_context,
description=ML_FAQ_TOOL_DESCRIPTION,
route='faq',
)

github_tool = get_query_engine_tool_by_name(collection_name=ML_FAQ_COLLECTION_NAME,
service_context=service_context,
description=ML_GITHUB_TOOL_DESCRIPTION,
similarity_top_k=6,
rerank_top_n=3,
route='github',
)

slack_tool = get_query_engine_tool_by_name(collection_name=ML_FAQ_COLLECTION_NAME,
service_context=service_context,
description=ML_SLACK_TOOL_DESCRIPTION,
similarity_top_k=20,
rerank_top_n=3,
rerank_by_time=True,
route='slack',
)

# Create the multi selector query engine
# Set llm temperature to 0.4 for routing
router_service_context = ServiceContext.from_defaults(embed_model=embeddings,
callback_manager=callback_manager,
llm=ChatOpenAI(model=GPT_MODEL_NAME,
temperature=0.4))
return RouterQueryEngine(
selector=PydanticMultiSelector.from_defaults(verbose=True),
query_engine_tools=[
slack_tool,
faq_tool,
github_tool
],
service_context=router_service_context
)


def get_prompt_template(zoomcamp_name: str, cohort_year: int, course_start_date: str) -> ChatPromptTemplate:
system_prompt = ChatMessage(
content=(
Expand Down Expand Up @@ -537,8 +390,6 @@ def get_prompt_template(zoomcamp_name: str, cohort_year: int, course_start_date:
"---------------------\n"
"{context_str}\n"
"---------------------\n"
# "Given the information above and not prior knowledge, "
# "answer the question.\n"
"Question: {query_str}\n"
"Answer: "),
role=MessageRole.USER, )
Expand All @@ -556,10 +407,7 @@ def get_retriever_query_engine(collection_name: str,
zoomcamp_name: str,
cohort_year: int,
course_start_date: str):
callback_manager = init_llama_index_callback_manager(ML_ZOOMCAMP_PROJECT_NAME)

service_context = ServiceContext.from_defaults(embed_model=embeddings,
callback_manager=callback_manager,
llm=ChatOpenAI(model=GPT_MODEL_NAME,
temperature=0.7))
if os.getenv('LOCAL_MILVUS', None):
Expand All @@ -569,11 +417,18 @@ def get_retriever_query_engine(collection_name: str,
overwrite=False,
uri=f'http://{localhost}:19530')
else:
vector_store = MilvusVectorStore(collection_name=collection_name,
uri=os.getenv("ZILLIZ_CLOUD_URI"),
token=os.getenv("ZILLIZ_CLOUD_API_KEY"),
dim=embedding_dimension,
overwrite=False)
if collection_name == MLOPS_COLLECTION_NAME:
vector_store = MilvusVectorStore(collection_name=collection_name,
uri=os.getenv("ZILLIZ_PUBLIC_ENDPOINT"),
token=os.getenv("ZILLIZ_API_KEY"),
dim=embedding_dimension,
overwrite=False)
else:
vector_store = MilvusVectorStore(collection_name=collection_name,
uri=os.getenv("ZILLIZ_CLOUD_URI"),
token=os.getenv("ZILLIZ_CLOUD_API_KEY"),
dim=embedding_dimension,
overwrite=False)
vector_store_index = VectorStoreIndex.from_vector_store(vector_store,
service_context=service_context)
cohere_rerank = CohereRerank(api_key=os.getenv('COHERE_API_KEY'), top_n=4)
Expand All @@ -589,7 +444,7 @@ def get_retriever_query_engine(collection_name: str,
return RetrieverQueryEngine(vector_store_index.as_retriever(similarity_top_k=15),
node_postprocessors=node_postprocessors,
response_synthesizer=response_synthesizer,
callback_manager=callback_manager)
)


def get_time_weighted_postprocessor():
Expand All @@ -616,20 +471,18 @@ def get_time_weighted_postprocessor():

log_to_langsmith()

mlops_index = setup_mlops_index()

mlops_qa = RetrievalQA.from_chain_type(
llm=ChatOpenAI(model_name=GPT_MODEL_NAME),
retriever=mlops_index.as_retriever()
)

ml_query_engine = get_retriever_query_engine(collection_name=ML_FAQ_COLLECTION_NAME,
ml_query_engine = get_retriever_query_engine(collection_name=ML_COLLECTION_NAME,
zoomcamp_name='Machine Learning',
cohort_year=2023,
course_start_date='11 September 2023')

de_query_engine = get_retriever_query_engine(collection_name=DE_FAQ_COLLECTION_NAME,
de_query_engine = get_retriever_query_engine(collection_name=DE_COLLECTION_NAME,
zoomcamp_name='Data Engineering',
cohort_year=2024,
course_start_date='15 January 2024')

mlops_query_engine = get_retriever_query_engine(collection_name=MLOPS_COLLECTION_NAME,
zoomcamp_name='MLOps',
cohort_year=2024,
course_start_date='13 May 2024')
SocketModeHandler(app, SLACK_APP_TOKEN).start()
3 changes: 0 additions & 3 deletions slack_bot/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,13 @@ slack-bolt==1.18.1
slack-sdk==3.27.1
langchain==0.1.10
sentence-transformers==2.5.1
wandb==0.16.3
cohere==4.51
pymilvus==2.3.6
langchain-openai==0.0.8
llama-index-core==0.10.15
llama-index-readers-web==0.1.6
llama-index-callbacks-wandb==0.1.2
llama-index-readers-github==0.1.7
llama-index-vector-stores-milvus==0.1.4
llama-index-embeddings-langchain==0.1.2
llama-index-postprocessor-cohere-rerank==0.1.2
llama-index-llms-langchain==0.1.3
pinecone-client[grpc]==2.2.4

0 comments on commit 527b7e5

Please sign in to comment.