-
Notifications
You must be signed in to change notification settings - Fork 0
/
repl.py
122 lines (83 loc) · 3.19 KB
/
repl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import models
import postgres
from copy import deepcopy
from typing import List, Tuple
def title_to_url(title: str):
    """Return the English Wikipedia article URL for *title*.

    Every single space in the title becomes an underscore; no other
    characters are changed or escaped.
    """
    slug = "_".join(title.split(" "))
    return "https://en.wikipedia.org/wiki/" + slug
def print_chunk_info(chunk, rank):
    """Print one search hit as two indented lines.

    The first line holds the chunk's title, its rank in parentheses and the
    article URL; the second line holds the chunk's description.
    """
    url = title_to_url(chunk.title)
    header = f" {chunk.title} ({rank}) {url}"
    print(header)
    print(f" {chunk.description}")
def print_results(chunks_with_distances):
    """Print the retrieved chunks, one entry per page, best match first.

    Several chunks of the same page may be retrieved. Only the one with the
    smallest distance is shown for each page, and pages are listed in
    ascending distance order.

    Args:
        chunks_with_distances: iterable of ``(chunk, distance)`` pairs, where
            ``chunk`` has ``pageId``, ``title`` and ``description`` attributes.
            The pairs no longer need to be pre-sorted. Previously an unsorted
            input, a tie, or a distance of exactly 0 tripped an assertion.
    """
    # A stable sort by distance means the first chunk seen for each page is
    # its best match; ties keep their input order.
    ranked = sorted(chunks_with_distances, key=lambda pair: pair[1])
    if not ranked:
        print("No results found.")
        return

    # Merge multiple chunks of the same page, keeping the smallest distance.
    # Dicts preserve insertion order, so `pages` is already sorted by distance.
    pages = {}
    for chunk, distance in ranked:
        pages.setdefault(chunk.pageId, (chunk, distance))

    print("\nResults:\n")
    for chunk, distance in pages.values():
        print_chunk_info(chunk, distance)
        print()
def print_llm_stream(stream, width=72):
    """Print a streamed chat response as it arrives, soft-wrapping long lines.

    Once the current output line is longer than ``width`` characters, the
    next piece that contains a space is broken at its first space. Pieces
    that already contain a newline are printed unchanged.

    Args:
        stream: iterable of response parts. Each part must be a mapping whose
            text is at ``part["message"]["content"]``.
        width: column after which wrapping starts (default 72, the value that
            was previously hard-coded).
    """
    column = 0  # characters printed since the last newline
    print()
    for part in stream:
        text: str = part["message"]["content"]
        # Only wrap pieces that do not already break the line themselves.
        if "\n" not in text and column > width and " " in text:
            split = text.index(" ")
            text = f"{text[:split]}\n{text[split:].lstrip()}"
        if "\n" in text:
            # The new line starts after the last newline. The text before it
            # belongs to the previous line and must not be counted.
            column = len(text.rpartition("\n")[2])
        else:
            column += len(text)
        print(text, end="", flush=True)
    print()
# System prompt for the RAG chat. It asks the model to use the supplied
# documents and to cite them by URL.
_SYS_PROMPT = """
You are a powerful AI trained to help people. You are augmented by a context
with a number of documents, and your job is to use and consume the documents to
best help the user. When you answer the user's query, you cite documents from
the context by referencing document URLs. You should answer in full sentences,
using proper grammar and spelling.
"""


def get_sys_prompt():
    """Return the system prompt used for every RAG chat request."""
    return _SYS_PROMPT
def get_user_prompt(query, chunks_with_distances: List[Tuple[postgres.Chunk, float]]):
    """Build the user message: retrieved documents as context, then the query.

    Each document is rendered as a ``Document url:`` line with its Wikipedia
    URL, then ``title: description``, then the chunk text, followed by a blank
    line. This lets the model cite documents by URL. Distances are ignored.

    Args:
        query: the user's question, inserted verbatim at the end of the prompt.
        chunks_with_distances: ``(chunk, distance)`` pairs from the similarity
            search. Only ``chunk.title``, ``chunk.description`` and
            ``chunk.text`` are read.

    Returns:
        The complete prompt string.
    """
    # Build every document section and join once, instead of growing a
    # string with += in a loop.
    context = "".join(
        f"Document url:{title_to_url(chunk.title)}\n"
        f"{chunk.title}: {chunk.description}\n"
        f"{chunk.text}\n\n"
        for chunk, _ in chunks_with_distances
    )
    return f"""
Respond to a query using the following context. Base your answer on documents
from this context only and cite the documents you used by using the URL like
this:
According to https://en.wikipedia.org/wiki/Cat the domestic cat is a small carnivorous mammal.
Context:
{context}
End of context.
Respond to the following query: {query}.
"""
def rag(query, number_of_documents=5):
    """Answer *query* with retrieval-augmented generation, streamed to stdout.

    Steps:
        1. Embed the query with ``models.embedding_string`` using the QUERY
           prefix.
        2. Fetch ``number_of_documents`` chunks with their distances via
           ``postgres.get_similar_chunks_with_distance``.
        3. Print those chunks.
        4. Stream the chat model's answer to a prompt built from the chunks.
    """
    query_embedding = models.embedding_string(query, models.EmbeddingPrefix.QUERY)
    retrieved = postgres.get_similar_chunks_with_distance(
        query_embedding, number_of_documents
    )
    print_results(retrieved)

    system_message = get_sys_prompt()
    user_message = get_user_prompt(query, retrieved)
    answer_stream = models.chat(
        input=user_message,
        system=[system_message],
        stream=True,
    )
    print_llm_stream(answer_stream)
if __name__ == "__main__":
    # Interactive loop: one RAG answer per query. End-of-input (Ctrl-D) or
    # Ctrl-C at the prompt exits cleanly instead of printing a traceback.
    while True:
        print(78 * ".")
        try:
            query = input("\nQuery >> ")
        except (EOFError, KeyboardInterrupt):
            print()
            break
        # Skip blank input rather than running retrieval and the LLM on it.
        if not query.strip():
            continue
        rag(query, number_of_documents=1)