Skip to content

Commit

Permalink
EdgeCraft RAG UI bug fix (opea-project#1189)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Yongbozzz and pre-commit-ci[bot] authored Dec 2, 2024
1 parent 0f8344e commit bb466b3
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 48 deletions.
3 changes: 3 additions & 0 deletions EdgeCraftRAG/edgecraftrag/components/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def run(self, chat_request, retrieved_nodes, **kwargs):
repetition_penalty=chat_request.repetition_penalty,
)
self.llm().generate_kwargs = generate_kwargs
self.llm().max_new_tokens = chat_request.max_tokens
if chat_request.stream:

async def stream_generator():
Expand Down Expand Up @@ -99,8 +100,10 @@ def run_vllm(self, chat_request, retrieved_nodes, **kwargs):
max_tokens=chat_request.max_tokens,
model=model_name,
top_p=chat_request.top_p,
top_k=chat_request.top_k,
temperature=chat_request.temperature,
streaming=chat_request.stream,
repetition_penalty=chat_request.repetition_penalty,
)

if chat_request.stream:
Expand Down
2 changes: 1 addition & 1 deletion EdgeCraftRAG/ui/gradio/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ name: "default"

# Node parser
node_parser: "simple"
chunk_size: 192
chunk_size: 400
chunk_overlap: 48

# Indexer
Expand Down
2 changes: 1 addition & 1 deletion EdgeCraftRAG/ui/gradio/ecrag_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def create_update_pipeline(
weight=llm_weights,
),
),
retriever=api_schema.RetrieverIn(retriever_type=retriever, retriever_topk=vector_search_top_k),
retriever=api_schema.RetrieverIn(retriever_type=retriever, retrieve_topk=vector_search_top_k),
postprocessor=[
api_schema.PostProcessorIn(
processor_type=postprocessor[0],
Expand Down
71 changes: 25 additions & 46 deletions EdgeCraftRAG/ui/gradio/ecragui.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,15 +87,11 @@ async def bot(
top_k,
repetition_penalty,
max_tokens,
hide_full_prompt,
docs,
chunk_size,
chunk_overlap,
vector_search_top_k,
vector_search_top_n,
run_rerank,
search_method,
score_threshold,
vector_rerank_top_n,
):
"""Callback function for running chatbot on submit button click.
Expand All @@ -108,8 +104,21 @@ async def bot(
repetition_penalty: parameter for penalizing tokens based on how frequently they occur in the text.
conversation_id: unique conversation identifier.
"""
if history[-1][0] == "" or len(history[-1][0]) == 0:
yield history[:-1]
return

stream_opt = True
new_req = {"messages": history[-1][0], "stream": stream_opt, "max_tokens": max_tokens}
new_req = {
"messages": history[-1][0],
"stream": stream_opt,
"max_tokens": max_tokens,
"top_n": vector_rerank_top_n,
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"repetition_penalty": repetition_penalty,
}
server_addr = f"http://{MEGA_SERVICE_HOST_IP}:{MEGA_SERVICE_PORT}"

# Async for streaming response
Expand Down Expand Up @@ -362,7 +371,7 @@ def get_pipeline_df():
choices=avail_llm_inference_type, label="LLM Inference Type", value="local"
)

with gr.Accordion("LLM Configuration", open=True):
with gr.Accordion("LLM Configuration", open=True) as accordion:
u_llm_model_id = gr.Dropdown(
choices=avail_llms,
value=cfg.llm_model_id,
Expand Down Expand Up @@ -393,6 +402,12 @@ def get_pipeline_df():
# RAG Settings Events
# -------------------
# Event handlers
def update_visibility(selected_value): # Accept the event argument, even if not used
if selected_value == "vllm":
return gr.Accordion(visible=False)
else:
return gr.Accordion(visible=True)

def show_pipeline_detail(evt: gr.SelectData):
# get selected pipeline id
# Dataframe: {'headers': '', 'data': [[x00, x01], [x10, x11]}
Expand Down Expand Up @@ -470,6 +485,8 @@ def create_update_pipeline(
return res, get_pipeline_df()

# Events
u_llm_infertype.change(update_visibility, inputs=u_llm_infertype, outputs=accordion)

u_pipelines.select(
show_pipeline_detail,
inputs=None,
Expand Down Expand Up @@ -735,39 +752,9 @@ def delete_file():
with gr.Row():
submit = gr.Button("Submit")
clear = gr.Button("Clear")
retriever_argument = gr.Accordion("Retriever Configuration", open=True)
retriever_argument = gr.Accordion("Retriever Configuration", open=False)
with retriever_argument:
with gr.Row():
with gr.Row():
do_rerank = gr.Checkbox(
value=True,
label="Rerank searching result",
interactive=True,
)
hide_context = gr.Checkbox(
value=True,
label="Hide searching result in prompt",
interactive=True,
)
with gr.Row():
search_method = gr.Dropdown(
["similarity_score_threshold", "similarity", "mmr"],
value=cfg.search_method,
label="Searching Method",
info="Method used to search vector store",
multiselect=False,
interactive=True,
)
with gr.Row():
score_threshold = gr.Slider(
0.01,
0.99,
value=cfg.score_threshold,
step=0.01,
label="Similarity Threshold",
info="Only working for 'similarity score threshold' method",
interactive=True,
)
with gr.Row():
vector_rerank_top_n = gr.Slider(
1,
Expand Down Expand Up @@ -811,15 +798,11 @@ def delete_file():
top_k,
repetition_penalty,
u_max_tokens,
hide_context,
docs,
u_chunk_size,
u_chunk_overlap,
u_vector_search_top_k,
vector_rerank_top_n,
do_rerank,
search_method,
score_threshold,
],
chatbot,
queue=True,
Expand All @@ -833,15 +816,11 @@ def delete_file():
top_k,
repetition_penalty,
u_max_tokens,
hide_context,
docs,
u_chunk_size,
u_chunk_overlap,
u_vector_search_top_k,
vector_rerank_top_n,
do_rerank,
search_method,
score_threshold,
],
chatbot,
queue=True,
Expand Down

0 comments on commit bb466b3

Please sign in to comment.