+
+
+
+
+
+
+ Edge Craft RAG based Q&A Chatbot
+ Powered by Intel NEXC Edge AI solutions
+
+
+
+
+
+ """
+ )
+ _ = gr.Textbox(
+ label="System Status",
+ value=get_system_status,
+ max_lines=1,
+ every=1,
+ info="",
+ elem_id="white_border",
+ )
+
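+ # Refresh the cached pipeline list from the backend; the Pipelines dataframe re-calls this periodically.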
+ def get_pipeline_df():
+ global pipeline_df
+ pipeline_df = cli.get_current_pipelines()
+ return pipeline_df
+
+ # -------------------
+ # RAG Settings Layout
+ # -------------------
+ with gr.Tab("RAG Settings"):
+ with gr.Row():
+ with gr.Column(scale=2):
+ u_pipelines = gr.Dataframe(
+ headers=["ID", "Name"],
+ column_widths=[70, 30],
+ value=get_pipeline_df,
+ label="Pipelines",
+ show_label=True,
+ interactive=False,
+ every=5,
+ )
+
+ u_rag_pipeline_status = gr.Textbox(label="Status", value="", interactive=False)
+
+ with gr.Column(scale=3):
+ with gr.Accordion("Pipeline Configuration"):
+ with gr.Row():
+ rag_create_pipeline = gr.Button("Create Pipeline")
+ rag_activate_pipeline = gr.Button("Activate Pipeline")
+ rag_remove_pipeline = gr.Button("Remove Pipeline")
+
+ with gr.Column(variant="panel"):
+ u_pipeline_name = gr.Textbox(
+ label="Name",
+ value=cfg.name,
+ interactive=True,
+ )
+ u_active = gr.Checkbox(
+ value=True,
+ label="Activated",
+ interactive=True,
+ )
+
+ with gr.Column(variant="panel"):
+ with gr.Accordion("Node Parser"):
+ u_node_parser = gr.Dropdown(
+ choices=avail_node_parsers,
+ label="Node Parser",
+ value=cfg.node_parser,
+ info="Select a parser to split documents.",
+ multiselect=False,
+ interactive=True,
+ )
+ u_chunk_size = gr.Slider(
+ label="Chunk size",
+ value=cfg.chunk_size,
+ minimum=100,
+ maximum=2000,
+ step=50,
+ interactive=True,
+ info="Size of sentence chunk",
+ )
+
+ u_chunk_overlap = gr.Slider(
+ label="Chunk overlap",
+ value=cfg.chunk_overlap,
+ minimum=0,
+ maximum=400,
+ step=1,
+ interactive=True,
+ info=("Overlap between 2 chunks"),
+ )
+
+ with gr.Column(variant="panel"):
+ with gr.Accordion("Indexer"):
+ u_indexer = gr.Dropdown(
+ choices=avail_indexers,
+ label="Indexer",
+ value=cfg.indexer,
+ info="Select an indexer for indexing content of the documents.",
+ multiselect=False,
+ interactive=True,
+ )
+
+ with gr.Accordion("Embedding Model Configuration"):
+ u_embed_model_id = gr.Dropdown(
+ choices=avail_embed_models,
+ value=cfg.embedding_model_id,
+ label="Embedding Model",
+ # info="Select a Embedding Model",
+ multiselect=False,
+ allow_custom_value=True,
+ )
+
+ u_embed_device = gr.Dropdown(
+ choices=avail_devices,
+ value=cfg.embedding_device,
+ label="Embedding run device",
+ # info="Run embedding model on which device?",
+ multiselect=False,
+ )
+
+ with gr.Column(variant="panel"):
+ with gr.Accordion("Retriever"):
+ u_retriever = gr.Dropdown(
+ choices=avail_retrievers,
+ value=cfg.retriever,
+ label="Retriever",
+ info="Select a retriever for retrieving context.",
+ multiselect=False,
+ interactive=True,
+ )
+ u_vector_search_top_k = gr.Slider(
+ 1,
+ 50,
+ value=cfg.k_retrieval,
+ step=1,
+ label="Search top k",
+ info="Number of searching results, must >= Rerank top n",
+ interactive=True,
+ )
+
+ with gr.Column(variant="panel"):
+ with gr.Accordion("Postprocessor"):
+ u_postprocessor = gr.Dropdown(
+ choices=avail_postprocessors,
+ value=cfg.postprocessor,
+ label="Postprocessor",
+ info="Select postprocessors for post-processing of the context.",
+ multiselect=True,
+ interactive=True,
+ )
+
+ with gr.Accordion("Rerank Model Configuration", open=True):
+ u_rerank_model_id = gr.Dropdown(
+ choices=avail_rerank_models,
+ value=cfg.rerank_model_id,
+ label="Rerank Model",
+ # info="Select a Rerank Model",
+ multiselect=False,
+ allow_custom_value=True,
+ )
+
+ u_rerank_device = gr.Dropdown(
+ choices=avail_devices,
+ value=cfg.rerank_device,
+ label="Rerank run device",
+ # info="Run rerank model on which device?",
+ multiselect=False,
+ )
+
+ with gr.Column(variant="panel"):
+ with gr.Accordion("Generator"):
+ u_generator = gr.Dropdown(
+ choices=avail_generators,
+ value=cfg.generator,
+ label="Generator",
+ info="Select a generator for AI inference.",
+ multiselect=False,
+ interactive=True,
+ )
+
+ with gr.Accordion("LLM Configuration", open=True):
+ u_llm_model_id = gr.Dropdown(
+ choices=avail_llms,
+ value=cfg.llm_model_id,
+ label="Large Language Model",
+ # info="Select a Large Language Model",
+ multiselect=False,
+ allow_custom_value=True,
+ )
+
+ u_llm_device = gr.Dropdown(
+ choices=avail_devices,
+ value=cfg.llm_device,
+ label="LLM run device",
+ # info="Run LLM on which device?",
+ multiselect=False,
+ )
+
+ u_llm_weights = gr.Radio(
+ avail_weights_compression,
+ label="Weights",
+ info="weights compression",
+ )
+
+ # -------------------
+ # RAG Settings Events
+ # -------------------
+ # Event handlers
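+ # Selecting a row in the Pipelines table loads that pipeline's configuration into the form fields.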
+ def show_pipeline_detail(evt: gr.SelectData):
+ # get selected pipeline id
+ # Dataframe: {'headers': '', 'data': [[x00, x01], [x10, x11]]}
+ # SelectData.index: [i, j]
+ print(u_pipelines.value["data"])
+ print(evt.index)
+ # always use pipeline id for indexing
+ selected_id = pipeline_df[evt.index[0]][0]
+ pl = cli.get_pipeline(selected_id)
+ # TODO: change to JSON format
+ # pl["postprocessor"][0]["processor_type"]
+ # pl["postprocessor"]["model"]["model_id"], pl["postprocessor"]["model"]["device"]
+ return (
+ pl["name"],
+ pl["status"]["active"],
+ pl["node_parser"]["parser_type"],
+ pl["node_parser"]["chunk_size"],
+ pl["node_parser"]["chunk_overlap"],
+ pl["indexer"]["indexer_type"],
+ pl["retriever"]["retriever_type"],
+ pl["retriever"]["retrieve_topk"],
+ pl["generator"]["generator_type"],
+ pl["generator"]["model"]["model_id"],
+ pl["generator"]["model"]["device"],
+ "",
+ pl["indexer"]["model"]["model_id"],
+ pl["indexer"]["model"]["device"],
+ )
+
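+ # Helpers that switch the action button label between create and update modes.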
+ def modify_create_pipeline_button():
+ return "Create Pipeline"
+
+ def modify_update_pipeline_button():
+ return "Update Pipeline"
+
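+ # Forward the full pipeline configuration to the backend and refresh the pipeline table.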
+ def create_update_pipeline(
+ name,
+ active,
+ node_parser,
+ chunk_size,
+ chunk_overlap,
+ indexer,
+ retriever,
+ vector_search_top_k,
+ postprocessor,
+ generator,
+ llm_id,
+ llm_device,
+ llm_weights,
+ embedding_id,
+ embedding_device,
+ rerank_id,
+ rerank_device,
+ ):
+ res = cli.create_update_pipeline(
+ name,
+ active,
+ node_parser,
+ chunk_size,
+ chunk_overlap,
+ indexer,
+ retriever,
+ vector_search_top_k,
+ postprocessor,
+ generator,
+ llm_id,
+ llm_device,
+ llm_weights,
+ embedding_id,
+ embedding_device,
+ rerank_id,
+ rerank_device,
+ )
+ return res, get_pipeline_df()
+
+ # Events
+ u_pipelines.select(
+ show_pipeline_detail,
+ inputs=None,
+ outputs=[
+ u_pipeline_name,
+ u_active,
+ # node parser
+ u_node_parser,
+ u_chunk_size,
+ u_chunk_overlap,
+ # indexer
+ u_indexer,
+ # retriever
+ u_retriever,
+ u_vector_search_top_k,
+ # postprocessor
+ # u_postprocessor,
+ # generator
+ u_generator,
+ # models
+ u_llm_model_id,
+ u_llm_device,
+ u_llm_weights,
+ u_embed_model_id,
+ u_embed_device,
+ # u_rerank_model_id,
+ # u_rerank_device
+ ],
+ )
+
+ u_pipeline_name.input(modify_create_pipeline_button, inputs=None, outputs=rag_create_pipeline)
+
+ # The Create Pipeline button switches to an Update Pipeline button if any
+ # of the listed fields is changed
+ gr.on(
+ triggers=[
+ u_active.input,
+ # node parser
+ u_node_parser.input,
+ u_chunk_size.input,
+ u_chunk_overlap.input,
+ # indexer
+ u_indexer.input,
+ # retriever
+ u_retriever.input,
+ u_vector_search_top_k.input,
+ # postprocessor
+ u_postprocessor.input,
+ # generator
+ u_generator.input,
+ # models
+ u_llm_model_id.input,
+ u_llm_device.input,
+ u_llm_weights.input,
+ u_embed_model_id.input,
+ u_embed_device.input,
+ u_rerank_model_id.input,
+ u_rerank_device.input,
+ ],
+ fn=modify_update_pipeline_button,
+ inputs=None,
+ outputs=rag_create_pipeline,
+ )
+
+ rag_create_pipeline.click(
+ create_update_pipeline,
+ inputs=[
+ u_pipeline_name,
+ u_active,
+ u_node_parser,
+ u_chunk_size,
+ u_chunk_overlap,
+ u_indexer,
+ u_retriever,
+ u_vector_search_top_k,
+ u_postprocessor,
+ u_generator,
+ u_llm_model_id,
+ u_llm_device,
+ u_llm_weights,
+ u_embed_model_id,
+ u_embed_device,
+ u_rerank_model_id,
+ u_rerank_device,
+ ],
+ outputs=[u_rag_pipeline_status, u_pipelines],
+ queue=False,
+ )
+
+ rag_activate_pipeline.click(
+ cli.activate_pipeline,
+ inputs=[u_pipeline_name],
+ outputs=[u_rag_pipeline_status, u_active],
+ queue=False,
+ )
+
+ # --------------
+ # Chatbot Layout
+ # --------------
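+ # File management helpers for the Chatbot tab.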
+ def get_files():
+ return cli.get_files()
+
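+ # Upload the selected documents, build the vector store, then refresh the loaded-files table.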
+ def create_vectordb(docs, spliter, vector_db):
+ res = cli.create_vectordb(docs, spliter, vector_db)
+ return gr.update(value=get_files()), res
+
+ global u_files_selected_row
+ u_files_selected_row = None
+
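+ # Track the dataframe row the user clicked so the delete button knows which file to remove.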
+ def select_file(data, evt: gr.SelectData):
+ if not evt.selected or len(evt.index) == 0:
+ return "No file selected"
+ global u_files_selected_row
+ row_index = evt.index[0]
+ u_files_selected_row = data.iloc[row_index]
+ file_name, file_id = u_files_selected_row
+ return f"File Name: {file_name}\nFile ID: {file_id}"
+
+ def deselect_file():
+ global u_files_selected_row
+ u_files_selected_row = None
+ return gr.update(value=get_files()), "Selection cleared"
+
+ def delete_file():
+ global u_files_selected_row
+ if u_files_selected_row is None:
+ res = "Please select a file first."
+ else:
+ file_name, file_id = u_files_selected_row
+ u_files_selected_row = None
+ res = cli.delete_file(file_id)
+ return gr.update(value=get_files()), res
+
+ with gr.Tab("Chatbot"):
+ with gr.Row():
+ with gr.Column(scale=1):
+ docs = gr.File(
+ label="Step 1: Load text files",
+ file_count="multiple",
+ file_types=[
+ ".csv",
+ ".doc",
+ ".docx",
+ ".enex",
+ ".epub",
+ ".html",
+ ".md",
+ ".odt",
+ ".pdf",
+ ".ppt",
+ ".pptx",
+ ".txt",
+ ],
+ )
+ retriever_argument = gr.Accordion("Vector Store Configuration", open=False)
+ with retriever_argument:
+ spliter = gr.Dropdown(
+ ["Character", "RecursiveCharacter", "Markdown", "Chinese"],
+ value=cfg.splitter_name,
+ label="Text Spliter",
+ info="Method used to split the documents",
+ multiselect=False,
+ )
+
+ vector_db = gr.Dropdown(
+ ["FAISS", "Chroma"],
+ value=cfg.vector_db,
+ label="Vector Stores",
+ info="Stores embedded data and performs vector search.",
+ multiselect=False,
+ )
+ load_docs = gr.Button("Upload files")
+
+ u_files_status = gr.Textbox(label="File Processing Status", value="", interactive=False)
+ u_files = gr.Dataframe(
+ headers=["Loaded File Name", "File ID"],
+ value=get_files,
+ label="Loaded Files",
+ show_label=False,
+ interactive=False,
+ every=5,
+ )
+
+ with gr.Accordion("Delete File", open=False):
+ selected_files = gr.Textbox(label="Click file to select", value="", interactive=False)
+ with gr.Row():
+ with gr.Column():
+ delete_button = gr.Button("Delete Selected File")
+ with gr.Column():
+ deselect_button = gr.Button("Clear Selection")
+
+ do_rag = gr.Checkbox(
+ value=True,
+ label="RAG is ON",
+ interactive=True,
+ info="Whether to do RAG for generation",
+ )
+ with gr.Accordion("Generation Configuration", open=False):
+ with gr.Row():
+ with gr.Column():
+ with gr.Row():
+ temperature = gr.Slider(
+ label="Temperature",
+ value=0.1,
+ minimum=0.0,
+ maximum=1.0,
+ step=0.1,
+ interactive=True,
+ info="Higher values produce more diverse outputs",
+ )
+ with gr.Column():
+ with gr.Row():
+ top_p = gr.Slider(
+ label="Top-p (nucleus sampling)",
+ value=1.0,
+ minimum=0.0,
+ maximum=1,
+ step=0.01,
+ interactive=True,
+ info=(
+ "Sample from the smallest possible set of tokens whose cumulative probability "
+ "exceeds top_p. Set to 1 to disable and sample from all tokens."
+ ),
+ )
+ with gr.Column():
+ with gr.Row():
+ top_k = gr.Slider(
+ label="Top-k",
+ value=50,
+ minimum=0.0,
+ maximum=200,
+ step=1,
+ interactive=True,
+ info="Sample from a shortlist of top-k tokens — 0 to disable and sample from all tokens.",
+ )
+ with gr.Column():
+ with gr.Row():
+ repetition_penalty = gr.Slider(
+ label="Repetition Penalty",
+ value=1.1,
+ minimum=1.0,
+ maximum=2.0,
+ step=0.1,
+ interactive=True,
+ info="Penalize repetition — 1.0 to disable.",
+ )
+ with gr.Column(scale=4):
+ chatbot = gr.Chatbot(
+ height=600,
+ label="Step 2: Input Query",
+ show_copy_button=True,
+ )
+ with gr.Row():
+ with gr.Column():
+ msg = gr.Textbox(
+ label="QA Message Box",
+ placeholder="Chat Message Box",
+ show_label=False,
+ container=False,
+ )
+ with gr.Column():
+ with gr.Row():
+ submit = gr.Button("Submit")
+ stop = gr.Button("Stop")
+ clear = gr.Button("Clear")
+ retriever_argument = gr.Accordion("Retriever Configuration", open=True)
+ with retriever_argument:
+ with gr.Row():
+ with gr.Row():
+ do_rerank = gr.Checkbox(
+ value=True,
+ label="Rerank searching result",
+ interactive=True,
+ )
+ hide_context = gr.Checkbox(
+ value=True,
+ label="Hide searching result in prompt",
+ interactive=True,
+ )
+ with gr.Row():
+ search_method = gr.Dropdown(
+ ["similarity_score_threshold", "similarity", "mmr"],
+ value=cfg.search_method,
+ label="Searching Method",
+ info="Method used to search vector store",
+ multiselect=False,
+ interactive=True,
+ )
+ with gr.Row():
+ score_threshold = gr.Slider(
+ 0.01,
+ 0.99,
+ value=cfg.score_threshold,
+ step=0.01,
+ label="Similarity Threshold",
+ info="Only working for 'similarity score threshold' method",
+ interactive=True,
+ )
+ with gr.Row():
+ vector_rerank_top_n = gr.Slider(
+ 1,
+ 10,
+ value=cfg.k_rerank,
+ step=1,
+ label="Rerank top n",
+ info="Number of rerank results",
+ interactive=True,
+ )
+ load_docs.click(
+ create_vectordb,
+ inputs=[
+ docs,
+ spliter,
+ vector_db,
+ ],
+ outputs=[u_files, u_files_status],
+ queue=True,
+ )
+ # TODO: Need to de-select the dataframe,
+ # otherwise every time the dataframe is updated, a select event is triggered
+ u_files.select(select_file, inputs=[u_files], outputs=selected_files, queue=True)
+
+ delete_button.click(
+ delete_file,
+ outputs=[u_files, u_files_status],
+ queue=True,
+ )
+ deselect_button.click(
+ deselect_file,
+ outputs=[u_files, selected_files],
+ queue=True,
+ )
+
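+ # Submit flow: user() updates the chat history with the query, then bot() produces the response using the current generation and retrieval settings.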
+ submit_event = msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
+ bot,
+ [
+ chatbot,
+ temperature,
+ top_p,
+ top_k,
+ repetition_penalty,
+ hide_context,
+ do_rag,
+ docs,
+ spliter,
+ vector_db,
+ u_chunk_size,
+ u_chunk_overlap,
+ u_vector_search_top_k,
+ vector_rerank_top_n,
+ do_rerank,
+ search_method,
+ score_threshold,
+ ],
+ chatbot,
+ queue=True,
+ )
+ submit_click_event = submit.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(
+ bot,
+ [
+ chatbot,
+ temperature,
+ top_p,
+ top_k,
+ repetition_penalty,
+ hide_context,
+ do_rag,
+ docs,
+ spliter,
+ vector_db,
+ u_chunk_size,
+ u_chunk_overlap,
+ u_vector_search_top_k,
+ vector_rerank_top_n,
+ do_rerank,
+ search_method,
+ score_threshold,
+ ],
+ chatbot,
+ queue=True,
+ )
+ # stop.click(
+ # fn=request_cancel,
+ # inputs=None,
+ # outputs=None,
+ # cancels=[submit_event, submit_click_event],
+ # queue=False,
+ # )
+ clear.click(lambda: None, None, chatbot, queue=False)
+ return demo
+
+
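+# Entry point: parse command-line arguments, load the YAML configuration, and launch the Gradio UI.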
+def main():
+ # Create the parser
+ parser = argparse.ArgumentParser(description="Load Embedding and LLM Models with OpenVINO.")
+ # Add the arguments
+ parser.add_argument("--prompt_template", type=str, required=False, help="User specific template")
+ # parser.add_argument("--server_name", type=str, default="0.0.0.0")
+ # parser.add_argument("--server_port", type=int, default=8082)
+ parser.add_argument("--config", type=str, default="./default.yaml", help="configuration file path")
+ parser.add_argument("--share", action="store_true", help="share model")
+ parser.add_argument("--debug", action="store_true", help="enable debugging")
+
+ # Execute the parse_args() method to collect command line arguments
+ args = parser.parse_args()
+ logger.info(args)
+ cfg = OmegaConf.load(args.config)
+ init_cfg_(cfg)
+ logger.info(cfg)
+
+ demo = build_demo(cfg, args)
+ # If you are launching remotely, specify server_name and server_port, e.g.
+ # demo.launch(server_name='<server name>', server_port=<port number>)
+ # If you have any issues launching on your platform, pass share=True to the launch method:
+ # demo.launch(share=True)
+ # This creates a publicly shareable link for the interface. Read more in the docs: https://gradio.app/docs/
+ demo.queue().launch(
+ server_name=UI_SERVICE_HOST_IP, server_port=UI_SERVICE_PORT, share=args.share, allowed_paths=["."]
+ )
+
+ # Close the Gradio interface when launch() returns
+ demo.close()
+
+
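+# Fill in defaults for any settings missing from the loaded configuration.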
+def init_cfg_(cfg):
+ if "name" not in cfg:
+ cfg.name = "default"
+ if "embedding_device" not in cfg:
+ cfg.embedding_device = "CPU"
+ if "rerank_device" not in cfg:
+ cfg.rerank_device = "CPU"
+ if "llm_device" not in cfg:
+ cfg.llm_device = "CPU"
+ if "model_language" not in cfg:
+ cfg.model_language = "Chinese"
+ if "vector_db" not in cfg:
+ cfg.vector_db = "FAISS"
+ if "splitter_name" not in cfg:
+ cfg.splitter_name = "RecursiveCharacter" # or "Chinese"
+ if "search_method" not in cfg:
+ cfg.search_method = "similarity"
+ if "score_threshold" not in cfg:
+ cfg.score_threshold = 0.5
+
+
+if __name__ == "__main__":
+ main()
diff --git a/EdgeCraftRAG/ui/gradio/platform_config.py b/EdgeCraftRAG/ui/gradio/platform_config.py
new file mode 100644
index 000000000..852409c1c
--- /dev/null
+++ b/EdgeCraftRAG/ui/gradio/platform_config.py
@@ -0,0 +1,114 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+import sys
+from enum import Enum
+
+import openvino.runtime as ov
+from config import SUPPORTED_EMBEDDING_MODELS, SUPPORTED_LLM_MODELS, SUPPORTED_RERANK_MODELS
+
+sys.path.append("..")
+from edgecraftrag.base import GeneratorType, IndexerType, NodeParserType, PostProcessorType, RetrieverType
+
+
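+# Return model IDs for the given language (filtered to configured entries), or all top-level IDs when no language is given.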
+def _get_llm_model_ids(supported_models, model_language=None):
+ if model_language is None:
+ model_ids = [model_id for model_id, _ in supported_models.items()]
+ return model_ids
+
+ if model_language not in supported_models:
+ print("Invalid model language! Please choose from the available options.")
+ return None
+
+ # Create a list of model IDs based on the selected language
+ llm_model_ids = [
+ model_id
+ for model_id, model_config in supported_models[model_language].items()
+ if model_config.get("rag_prompt_template") or model_config.get("normalize_embeddings")
+ ]
+
+ return llm_model_ids
+
+
+def _list_subdirectories(parent_directory):
+ """List all subdirectories under the given parent directory using os.listdir.
+
+ Parameters:
+ parent_directory (str): The path to the parent directory from which to list subdirectories.
+
+ Returns:
+ list: A list of subdirectory names found in the parent directory.
+ """
+ # Get a list of all entries in the parent directory
+ entries = os.listdir(parent_directory)
+
+ # Filter out the entries to only keep directories
+ subdirectories = [entry for entry in entries if os.path.isdir(os.path.join(parent_directory, entry))]
+
+ return sorted(subdirectories)
+
+
+def _get_available_models(model_ids, local_dirs):
+ """Filters and sorts model IDs based on their presence in the local directories.
+
+ Parameters:
+ model_ids (list): A list of model IDs to check.
+ local_dirs (list): A list of local directory names to check against.
+
+ Returns:
+ list: A sorted list of available model IDs.
+ """
+ # Filter model_ids for those that are present in local directories
+ return sorted([model_id for model_id in model_ids if model_id in local_dirs])
+
+
+def get_local_available_models(model_type: str, local_path: str = "./"):
+ local_dirs = _list_subdirectories(local_path)
+ if model_type == "llm":
+ model_ids = _get_llm_model_ids(SUPPORTED_LLM_MODELS, "Chinese")
+ elif model_type == "embed":
+ model_ids = _get_llm_model_ids(SUPPORTED_EMBEDDING_MODELS, "Chinese")
+ elif model_type == "rerank":
+ model_ids = _get_llm_model_ids(SUPPORTED_RERANK_MODELS)
+ else:
+ print("Unknown model type")
+ return []
+ avail_models = _get_available_models(model_ids, local_dirs)
+ return avail_models
+
+
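+# List OpenVINO inference devices plus AUTO; NPU is filtered out.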
+def get_available_devices():
+ core = ov.Core()
+ avail_devices = core.available_devices + ["AUTO"]
+ if "NPU" in avail_devices:
+ avail_devices.remove("NPU")
+ return avail_devices
+
+
+def get_available_weights():
+ avail_weights_compression = ["FP16", "INT8", "INT4"]
+ return avail_weights_compression
+
+
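+# Collect the .value of every member of an enum class, skipping private attributes and callables.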
+def get_enum_values(c: Enum):
+ return [v.value for k, v in vars(c).items() if not callable(v) and not k.startswith("__") and not k.startswith("_")]
+
+
+def get_available_node_parsers():
+ return get_enum_values(NodeParserType)
+
+
+def get_available_indexers():
+ return get_enum_values(IndexerType)
+
+
+def get_available_retrievers():
+ return get_enum_values(RetrieverType)
+
+
+def get_available_postprocessors():
+ return get_enum_values(PostProcessorType)
+
+
+def get_available_generators():
+ return get_enum_values(GeneratorType)
diff --git a/FaqGen/docker_compose/intel/cpu/xeon/README.md b/FaqGen/docker_compose/intel/cpu/xeon/README.md
index 04fea0f85..c512621b0 100644
--- a/FaqGen/docker_compose/intel/cpu/xeon/README.md
+++ b/FaqGen/docker_compose/intel/cpu/xeon/README.md
@@ -114,9 +114,11 @@ docker compose up -d
3. MegaService
```bash
- curl http://${host_ip}:8888/v1/faqgen -H "Content-Type: application/json" -d '{
- "messages": "Text Embeddings Inference (TEI) is a toolkit for deploying and serving open source text embeddings and sequence classification models. TEI enables high-performance extraction for the most popular models, including FlagEmbedding, Ember, GTE and E5."
- }'
+ curl http://${host_ip}:8888/v1/faqgen \
+ -H "Content-Type: multipart/form-data" \
+ -F "messages=Text Embeddings Inference (TEI) is a toolkit for deploying and serving open source text embeddings and sequence classification models. TEI enables high-performance extraction for the most popular models, including FlagEmbedding, Ember, GTE and E5." \
+ -F "max_tokens=32" \
+ -F "stream=false"
```
Following the validation of all aforementioned microservices, we are now prepared to construct a mega-service.
diff --git a/FaqGen/docker_compose/intel/hpu/gaudi/README.md b/FaqGen/docker_compose/intel/hpu/gaudi/README.md
index acdded9c2..548a94e16 100644
--- a/FaqGen/docker_compose/intel/hpu/gaudi/README.md
+++ b/FaqGen/docker_compose/intel/hpu/gaudi/README.md
@@ -28,7 +28,7 @@ To construct the Mega Service, we utilize the [GenAIComps](https://github.com/op
```bash
git clone https://github.com/opea-project/GenAIExamples
-cd GenAIExamples/FaqGen/docker/
+cd GenAIExamples/FaqGen/
docker build --no-cache -t opea/faqgen:latest --build-arg https_proxy=$https_proxy --build-arg http_proxy=$http_proxy -f Dockerfile .
```
@@ -37,7 +37,7 @@ docker build --no-cache -t opea/faqgen:latest --build-arg https_proxy=$https_pro
Construct the frontend Docker image using the command below:
```bash
-cd GenAIExamples/FaqGen/
+cd GenAIExamples/FaqGen/ui
docker build -t opea/faqgen-ui:latest --build-arg https_proxy=$https_proxy --build-arg http_proxy=$http_proxy -f ./docker/Dockerfile .
```
@@ -115,9 +115,11 @@ docker compose up -d
3. MegaService
```bash
- curl http://${host_ip}:8888/v1/faqgen -H "Content-Type: application/json" -d '{
- "messages": "Text Embeddings Inference (TEI) is a toolkit for deploying and serving open source text embeddings and sequence classification models. TEI enables high-performance extraction for the most popular models, including FlagEmbedding, Ember, GTE and E5."
- }'
+ curl http://${host_ip}:8888/v1/faqgen \
+ -H "Content-Type: multipart/form-data" \
+ -F "messages=Text Embeddings Inference (TEI) is a toolkit for deploying and serving open source text embeddings and sequence classification models. TEI enables high-performance extraction for the most popular models, including FlagEmbedding, Ember, GTE and E5." \
+ -F "max_tokens=32" \
+ -F "stream=false"
```
## 🚀 Launch the UI
diff --git a/FaqGen/tests/test_compose_on_gaudi.sh b/FaqGen/tests/test_compose_on_gaudi.sh
index a58339780..161c1e2a7 100644
--- a/FaqGen/tests/test_compose_on_gaudi.sh
+++ b/FaqGen/tests/test_compose_on_gaudi.sh
@@ -101,13 +101,30 @@ function validate_microservices() {
}
function validate_megaservice() {
- # Curl the Mega Service
- validate_services \
- "${ip_address}:8888/v1/faqgen" \
- "Text Embeddings Inference" \
- "mega-faqgen" \
- "faqgen-gaudi-backend-server" \
- '{"messages": "Text Embeddings Inference (TEI) is a toolkit for deploying and serving open source text embeddings and sequence classification models. TEI enables high-performance extraction for the most popular models, including FlagEmbedding, Ember, GTE and E5."}'
+ local SERVICE_NAME="mega-faqgen"
+ local DOCKER_NAME="faqgen-gaudi-backend-server"
+ local EXPECTED_RESULT="Embeddings"
+ local INPUT_DATA="messages=Text Embeddings Inference (TEI) is a toolkit for deploying and serving open source text embeddings and sequence classification models. TEI enables high-performance extraction for the most popular models, including FlagEmbedding, Ember, GTE and E5."
+ local URL="${ip_address}:8888/v1/faqgen"
+ local HTTP_STATUS=$(curl -s -o /dev/null -w "%{http_code}" -X POST -F "$INPUT_DATA" -H 'Content-Type: multipart/form-data' "$URL")
+ if [ "$HTTP_STATUS" -eq 200 ]; then
+ echo "[ $SERVICE_NAME ] HTTP status is 200. Checking content..."
+
+ local CONTENT=$(curl -s -X POST -F "$INPUT_DATA" -H 'Content-Type: multipart/form-data' "$URL" | tee ${LOG_PATH}/${SERVICE_NAME}.log)
+
+ if echo "$CONTENT" | grep -q "$EXPECTED_RESULT"; then
+ echo "[ $SERVICE_NAME ] Content is as expected."
+ else
+ echo "[ $SERVICE_NAME ] Content does not match the expected result: $CONTENT"
+ docker logs ${DOCKER_NAME} >> ${LOG_PATH}/${SERVICE_NAME}.log
+ exit 1
+ fi
+ else
+ echo "[ $SERVICE_NAME ] HTTP status is not 200. Received status was $HTTP_STATUS"
+ docker logs ${DOCKER_NAME} >> ${LOG_PATH}/${SERVICE_NAME}.log
+ exit 1
+ fi
+ sleep 1s
}
function validate_frontend() {
@@ -152,7 +169,7 @@ function main() {
validate_microservices
validate_megaservice
- validate_frontend
+ # validate_frontend
stop_docker
echo y | docker system prune
diff --git a/FaqGen/tests/test_compose_on_xeon.sh b/FaqGen/tests/test_compose_on_xeon.sh
index c6265e02d..e9ed4bf1e 100755
--- a/FaqGen/tests/test_compose_on_xeon.sh
+++ b/FaqGen/tests/test_compose_on_xeon.sh
@@ -101,13 +101,30 @@ function validate_microservices() {
}
function validate_megaservice() {
- # Curl the Mega Service
- validate_services \
- "${ip_address}:8888/v1/faqgen" \
- "Text Embeddings Inference" \
- "mega-faqgen" \
- "faqgen-xeon-backend-server" \
- '{"messages": "Text Embeddings Inference (TEI) is a toolkit for deploying and serving open source text embeddings and sequence classification models. TEI enables high-performance extraction for the most popular models, including FlagEmbedding, Ember, GTE and E5."}'
+ local SERVICE_NAME="mega-faqgen"
+ local DOCKER_NAME="faqgen-xeon-backend-server"
+ local EXPECTED_RESULT="Embeddings"
+ local INPUT_DATA="messages=Text Embeddings Inference (TEI) is a toolkit for deploying and serving open source text embeddings and sequence classification models. TEI enables high-performance extraction for the most popular models, including FlagEmbedding, Ember, GTE and E5."
+ local URL="${ip_address}:8888/v1/faqgen"
+ local HTTP_STATUS=$(curl -s -o /dev/null -w "%{http_code}" -X POST -F "$INPUT_DATA" -H 'Content-Type: multipart/form-data' "$URL")
+ if [ "$HTTP_STATUS" -eq 200 ]; then
+ echo "[ $SERVICE_NAME ] HTTP status is 200. Checking content..."
+
+ local CONTENT=$(curl -s -X POST -F "$INPUT_DATA" -H 'Content-Type: multipart/form-data' "$URL" | tee ${LOG_PATH}/${SERVICE_NAME}.log)
+
+ if echo "$CONTENT" | grep -q "$EXPECTED_RESULT"; then
+ echo "[ $SERVICE_NAME ] Content is as expected."
+ else
+ echo "[ $SERVICE_NAME ] Content does not match the expected result: $CONTENT"
+ docker logs ${DOCKER_NAME} >> ${LOG_PATH}/${SERVICE_NAME}.log
+ exit 1
+ fi
+ else
+ echo "[ $SERVICE_NAME ] HTTP status is not 200. Received status was $HTTP_STATUS"
+ docker logs ${DOCKER_NAME} >> ${LOG_PATH}/${SERVICE_NAME}.log
+ exit 1
+ fi
+ sleep 1s
}
function validate_frontend() {
@@ -152,7 +169,7 @@ function main() {
validate_microservices
validate_megaservice
- validate_frontend
+ # validate_frontend
stop_docker
echo y | docker system prune
diff --git a/FaqGen/ui/svelte/.env b/FaqGen/ui/svelte/.env
index bfdca1c9a..4d0880c76 100644
--- a/FaqGen/ui/svelte/.env
+++ b/FaqGen/ui/svelte/.env
@@ -1 +1 @@
-DOC_BASE_URL = 'http://backend_address:8888/v1/faqgen'
+FAQ_BASE_URL = 'http://backend_address:8888/v1/faqgen'
diff --git a/FaqGen/ui/svelte/src/lib/doc.svelte b/FaqGen/ui/svelte/src/lib/doc.svelte
index bae896ba3..f9ea33584 100644
--- a/FaqGen/ui/svelte/src/lib/doc.svelte
+++ b/FaqGen/ui/svelte/src/lib/doc.svelte
@@ -38,8 +38,8 @@
} else {
currentIdx = index;
if (
- (currentIdx === 1 && message !== "") ||
- (currentIdx === 2 && $kb_id !== "")
+ (currentIdx === 2 && message !== "") ||
+ (currentIdx === 1 && $kb_id !== "")
) {
formModal = true;
} else {
@@ -49,10 +49,10 @@
}
function panelExchange() {
- if (currentIdx === 2) {
+ if (currentIdx === 1) {
kb_id.set("");
dispatch("clearMsg", { status: true });
- } else if (currentIdx === 1) {
+ } else if (currentIdx === 2) {
message = "";
dispatch("clearMsg", { status: true });
}
@@ -152,7 +152,7 @@
type="submit"
data-testid="sum-click"
class="xl:my-12 inline-flex items-center px-5 py-2.5 text-sm font-medium text-center text-white bg-blue-700 mt-2 focus:ring-4 focus:ring-blue-200 dark:focus:ring-blue-900 hover:bg-blue-800"
- on:click={() => generateFaq()}
+ on:click={() => generateFaq()}
>
Generate FAQs
@@ -165,11 +165,11 @@
/>
{#if currentIdx === 1}