Add example frontend apps and related config options
sd109 committed on Dec 12, 2023
1 parent d3083cf · commit da7d504
Showing 6 changed files with 135 additions and 58 deletions.
@@ -0,0 +1,17 @@
import requests
import time


def wait_for_backend(url):
    """
    This function acts as a startup check so that the frontend web app does not
    accept requests until the backend API is up and running.
    """
    ready = False
    while not ready:
        try:
            # Backend is considered ready once its docs endpoint responds with 200
            ready = requests.get(f"{url}/docs").status_code == 200
        except requests.exceptions.ConnectionError:
            # API not reachable yet - keep polling
            pass
        if not ready:
            print("Waiting for backend API to start")
            time.sleep(1)
    return
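For context, the check above can be exercised on its own. A minimal sketch, assuming a locally running API server; the localhost URL is illustrative and not something defined in this commit:

from api_startup_check import wait_for_backend

# Illustrative URL only - the apps in this commit point at the in-cluster service address instead
wait_for_backend("http://localhost:8000")  # blocks until GET <url>/docs returns 200
print("Backend API is up - safe to start the frontend")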
@@ -0,0 +1,57 @@
import json

import gradio as gr
import requests

from api_startup_check import wait_for_backend

# NOTE: This url should match the chart's api service name & namespace
backend_url = "http://text-generation-inference.default.svc"
wait_for_backend(backend_url)

# Llama-2-style instruction template wrapped around the accumulated chat context
prompt = """
[INST] <<SYS>>
You are a cheeky, disrespectful and comedic assistant. Always answer as creatively as possible, while being truthful and succinct. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, tell the user that they are being stupid. If you don't know the answer to a question, please don't share false information.
<</SYS>>
{prompt}[/INST]
"""


def inference(message, history):
    # Rebuild the conversation context from previous (user, assistant) turns
    context = ""
    user_template = "<<USER>>\n{user_input}\n<</USER>>\n"
    sys_template = "<<SYS>>{response}\n<</SYS>>\n"
    for user_input, response in history:
        # context += user_template.format(user_input=user_input)
        context += sys_template.format(response=response)
    context += user_template.format(user_input=message)

    headers = {"User-Agent": "vLLM Client"}
    pload = {
        "prompt": prompt.format(prompt=context),
        "stream": True,
        "max_tokens": 1000,
    }
    response = requests.post(
        f"{backend_url}/generate", headers=headers, json=pload, stream=True
    )

    # The backend streams null-byte-delimited JSON chunks, each containing the
    # text generated so far; yield the portion after the final [/INST] tag.
    for chunk in response.iter_lines(
        chunk_size=8192, decode_unicode=False, delimiter=b"\0"
    ):
        if chunk:
            data = json.loads(chunk.decode("utf-8"))
            output = data["text"][0].split("[/INST]")[-1]
            yield output


gr.ChatInterface(
    inference,
    chatbot=gr.Chatbot(
        height=500,
        show_copy_button=True,
        # layout='panel',
    ),
    textbox=gr.Textbox(placeholder="Ask me anything...", container=False, scale=7),
    title="Large Language Model",
    retry_btn="Retry",
    undo_btn="Undo",
    clear_btn="Clear",
).queue().launch(server_name="0.0.0.0")
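The streaming exchange that inference() performs can also be driven outside Gradio. A minimal sketch, assuming the same /generate endpoint and null-delimited JSON stream the app above relies on; the URL and prompt below are illustrative:

import json
import requests

backend_url = "http://localhost:8000"  # illustrative; the app uses the in-cluster service URL

payload = {
    "prompt": "[INST] Tell me a short joke [/INST]",  # made-up prompt for demonstration
    "stream": True,
    "max_tokens": 100,
}
with requests.post(f"{backend_url}/generate", json=payload, stream=True) as response:
    for chunk in response.iter_lines(chunk_size=8192, delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode("utf-8"))
            # Each chunk carries the text generated so far, mirroring the app's parsing
            print(data["text"][0])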
@@ -0,0 +1,57 @@
import json

import gradio as gr
import requests

from api_startup_check import wait_for_backend

# NOTE: This url should match the chart's api service name & namespace
backend_url = "http://text-generation-inference.default.svc"
wait_for_backend(backend_url)

# Llama-2-style instruction template wrapped around the accumulated chat context
prompt = """
[INST] <<SYS>>
You are a cheeky, disrespectful and comedic assistant. Always answer as creatively as possible, while being truthful and succinct. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, tell the user that they are being stupid. If you don't know the answer to a question, please don't share false information.
<</SYS>>
{prompt}[/INST]
"""


def inference(message, history):
    # Rebuild the conversation context from previous (user, assistant) turns
    context = ""
    user_template = "<<USER>>\n{user_input}\n<</USER>>\n"
    sys_template = "<<SYS>>{response}\n<</SYS>>\n"
    for user_input, response in history:
        # context += user_template.format(user_input=user_input)
        context += sys_template.format(response=response)
    context += user_template.format(user_input=message)

    headers = {"User-Agent": "vLLM Client"}
    pload = {
        "prompt": prompt.format(prompt=context),
        "stream": True,
        "max_tokens": 1000,
    }
    response = requests.post(
        f"{backend_url}/generate", headers=headers, json=pload, stream=True
    )

    # The backend streams null-byte-delimited JSON chunks, each containing the
    # text generated so far; yield the portion after the final [/INST] tag.
    for chunk in response.iter_lines(
        chunk_size=8192, decode_unicode=False, delimiter=b"\0"
    ):
        if chunk:
            data = json.loads(chunk.decode("utf-8"))
            output = data["text"][0].split("[/INST]")[-1]
            yield output


gr.ChatInterface(
    inference,
    chatbot=gr.Chatbot(
        height=500,
        show_copy_button=True,
        # layout='panel',
    ),
    textbox=gr.Textbox(placeholder="Ask me anything...", container=False, scale=7),
    title="Large Language Model",
    retry_btn="Retry",
    undo_btn="Undo",
    clear_btn="Clear",
).queue().launch(server_name="0.0.0.0")
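To make the context handling in inference() concrete, here is a small sketch that reproduces the apps' template logic with an invented one-turn history (the history contents are purely illustrative):

# Reproduces the context assembly from inference() with toy data
user_template = "<<USER>>\n{user_input}\n<</USER>>\n"
sys_template = "<<SYS>>{response}\n<</SYS>>\n"

history = [("What is Helm?", "A package manager for Kubernetes.")]  # invented example turn
message = "And what does this chart deploy?"

context = ""
for user_input, response in history:
    # As in the apps above, only previous assistant responses are folded back in
    context += sys_template.format(response=response)
context += user_template.format(user_input=message)
print(context)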