-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathstreamlit_app.py
73 lines (57 loc) · 2.6 KB
/
streamlit_app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import logging
import os
import sys
from typing import Any, Dict, Generator, List, Union
import openai
import streamlit as st
from llama_index import StorageContext, load_index_from_storage
# Route log records to stdout so they show up in the Streamlit console.
# NOTE: basicConfig already installs a stdout StreamHandler on the root
# logger; adding a second StreamHandler here (as the original did) made
# every record print twice, so the duplicate handler is removed.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

# A query-engine response may be a generator (streaming), a plain object,
# a list, or a dict.
ResponseType = Union[Generator[Any, None, None], Any, List, Dict]

# Never hard-code secrets in source: read the key from the environment,
# falling back to the original placeholder so behavior is unchanged when
# the variable is unset.
openai.api_key = os.environ.get("OPENAI_API_KEY", "YOUR_OPENAI_API_KEY")
@st.cache_resource(show_spinner=False)  # type: ignore[misc]
def load_index() -> Any:
    """Load the persisted index and return a query engine over it.

    The index is read from the ``.kb`` directory located next to this
    file (resolved from ``__file__`` so the app works regardless of the
    current working directory). Streamlit caches the result, so the load
    happens only once per server process.

    Returns:
        A llama_index query engine ready to answer questions.
    """
    # Use the configured logging setup instead of bare print() calls.
    logging.info("Loading index...")
    base_dir = os.path.dirname(os.path.abspath(__file__))
    dir_path = os.path.join(base_dir, ".kb")
    # Rebuild the storage context from the persisted files, then load
    # the index and wrap it in a query engine.
    storage_context = StorageContext.from_defaults(persist_dir=dir_path)
    index = load_index_from_storage(storage_context)
    query_engine = index.as_query_engine()
    logging.info("Done.")
    return query_engine
def main() -> None:
    """Run the chatbot UI: render chat history and answer new prompts."""
    # Create the (Streamlit-cached) query engine once per session.
    if "query_engine" not in st.session_state:
        st.session_state.query_engine = load_index()

    st.title("Chat with BlogAI Assistant!!")
    st.write("All about Snowpark for Data Engineering Quickstarts from quickstarts.snowflake.com. Ask away your questions!")

    # Seed the conversation with a system prompt on first run.
    if "messages" not in st.session_state:
        system_prompt = (
            "Your purpose is to answer questions about specific documents only. "
            "Please answer the user's questions based on what you know about the document. "
            "If the question is outside scope of the document, please politely decline. "
            "If you don't know the answer, say `I don't know`. "
        )
        st.session_state.messages = [{"role": "system", "content": system_prompt}]

    # Replay prior turns; the system prompt is intentionally not shown.
    for message in st.session_state.messages:
        if message["role"] not in ["user", "assistant"]:
            continue
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # Handle a new prompt from the chat input box.
    if prompt := st.chat_input():
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        with st.chat_message("assistant"):
            message_placeholder = st.empty()
            # Use the configured logging setup instead of bare print().
            logging.info("Querying query engine API...")
            response = st.session_state.query_engine.query(prompt)
            # The response object stringifies to the answer text.
            full_response = f"{response}"
            logging.info(full_response)
            message_placeholder.markdown(full_response)
            st.session_state.messages.append({"role": "assistant", "content": full_response})
# Entry point when executed directly (e.g. `streamlit run streamlit_app.py`).
if __name__ == "__main__":
    main()