diff --git a/comps/dataprep/neo4j/llama_index/config.py b/comps/dataprep/neo4j/llama_index/config.py
index 3037b8f9f..42c1abe0d 100644
--- a/comps/dataprep/neo4j/llama_index/config.py
+++ b/comps/dataprep/neo4j/llama_index/config.py
@@ -16,3 +16,7 @@
 OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
 OPENAI_EMBEDDING_MODEL = os.getenv("OPENAI_EMBEDDING_MODEL", "text-embedding-3-small")
 OPENAI_LLM_MODEL = os.getenv("OPENAI_LLM_MODEL", "gpt-4o")
+
+LLM_MODEL_ID = os.getenv("LLM_MODEL_ID", "meta-llama/Meta-Llama-3.1-70B-Instruct")
+MAX_INPUT_LEN = os.getenv("MAX_INPUT_LEN", "8192")
+MAX_OUTPUT_TOKENS = os.getenv("MAX_OUTPUT_TOKENS", "1024")
diff --git a/comps/dataprep/neo4j/llama_index/extract_graph_neo4j.py b/comps/dataprep/neo4j/llama_index/extract_graph_neo4j.py
index a7ece023f..c1732b732 100644
--- a/comps/dataprep/neo4j/llama_index/extract_graph_neo4j.py
+++ b/comps/dataprep/neo4j/llama_index/extract_graph_neo4j.py
@@ -16,6 +16,9 @@
 import openai
 import requests
 from config import (
+    LLM_MODEL_ID,
+    MAX_INPUT_LEN,
+    MAX_OUTPUT_TOKENS,
     NEO4J_PASSWORD,
     NEO4J_URL,
     NEO4J_USERNAME,
@@ -37,6 +40,7 @@
 from llama_index.embeddings.text_embeddings_inference import TextEmbeddingsInference
 from llama_index.graph_stores.neo4j import Neo4jPropertyGraphStore
 from llama_index.llms.openai import OpenAI
+from llama_index.llms.openai_like import OpenAILike
 from llama_index.llms.text_generation_inference import TextGenerationInference
 from neo4j import GraphDatabase
 from openai import Client
@@ -54,6 +58,9 @@
 
 nest_asyncio.apply()
 
+import time
+import traceback
+
 from llama_index.core.async_utils import run_jobs
 from llama_index.core.bridge.pydantic import BaseModel, Field
 from llama_index.core.graph_stores.types import KG_NODES_KEY, KG_RELATIONS_KEY, EntityNode, Relation
@@ -71,19 +78,17 @@ class GraphRAGStore(Neo4jPropertyGraphStore):
     max_cluster_size = 100
 
     def __init__(self, username: str, password: str, url: str, llm: LLM):
-        super().__init__(username=username, password=password, url=url)
+        super().__init__(username=username, password=password, url=url, refresh_schema=False)
         self.llm = llm
         self.driver = GraphDatabase.driver(NEO4J_URL, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
 
-    def generate_community_summary(self, text):
+    async def generate_community_summary(self, text):
         """Generate summary for a given text using an LLM."""
-        # Get model information from the TGI endpoint
-        model_name = get_attribute_from_tgi_endpoint(TGI_LLM_ENDPOINT, "model_id")
-        max_input_length = get_attribute_from_tgi_endpoint(TGI_LLM_ENDPOINT, "max_input_length")
+        model_name = LLM_MODEL_ID
+        max_input_length = int(MAX_INPUT_LEN)
         if not model_name or not max_input_length:
             raise ValueError(f"Could not retrieve model information from TGI endpoint: {TGI_LLM_ENDPOINT}")
 
-        # Get the tokenizer
         tokenizer = AutoTokenizer.from_pretrained(model_name)
 
         messages = [
@@ -105,14 +110,14 @@ def generate_community_summary(self, text):
         trimmed_messages = trim_messages_to_token_limit(tokenizer, messages, max_input_length)
 
         if OPENAI_API_KEY:
-            response = OpenAI().chat(messages)
+            response = await OpenAI().achat(messages)
         else:
-            response = self.llm.chat(trimmed_messages)
+            response = await self.llm.achat(trimmed_messages)
 
         clean_response = re.sub(r"^assistant:\s*", "", str(response)).strip()
         return clean_response
 
-    def build_communities(self):
+    async def build_communities(self):
         """Builds communities from the graph and summarizes them."""
         nx_graph = self._create_nx_graph()
         community_hierarchical_clusters = hierarchical_leiden(nx_graph, max_cluster_size=self.max_cluster_size)
@@ -124,7 +129,7 @@ def build_communities(self):
         # self._print_cluster_info(self.entity_info, community_info)
         self.save_entity_info(self.entity_info)
         # entity_from_db = self.read_entity_info()  # to verify if the data is stored in db
-        self._summarize_communities(community_info)
+        await self._summarize_communities(community_info)
         # sum = self.read_all_community_summaries()  # to verify summaries are stored in db
 
     def _create_nx_graph(self):
@@ -232,29 +237,38 @@ def read_entity_info(self) -> dict:
             entity_info[record["entity_id"]] = [int(cluster_id) for cluster_id in record["cluster_ids"]]
         return entity_info
 
-    def _summarize_communities(self, community_info):
+    async def _summarize_communities(self, community_info, num_workers=5):
         """Generate and store summaries for each community."""
+        # Run tasks concurrently with a limited number of workers
+        tasks = []
         for community_id, details in community_info.items():
             logger.info(f"Summarizing community {community_id}")
             details_text = "\n".join(details) + "."  # Ensure it ends with a period
-            self.community_summary[community_id] = self.generate_community_summary(details_text)
+            tasks.append(self._process_community(community_id, details_text))
+        await run_jobs(
+            tasks,
+            workers=num_workers,
+            show_progress=True,
+            desc="Summarize communities",
+        )
 
-        # To store summaries in neo4j
-        summary = self.generate_community_summary(details_text)
-        self.store_community_summary_in_neo4j(community_id, summary)
-        # self.community_summary[
-        #     community_id
-        # ] = self.store_community_summary_in_neo4j(community_id, summary)
+    async def _process_community(self, community_id, details_text):
+        """Process a single community and store the summary."""
+        summary = await self.generate_community_summary(details_text)
+        self.community_summary[community_id] = summary
+        self.store_community_summary_in_neo4j(community_id, summary)
 
     def store_community_summary_in_neo4j(self, community_id, summary):
         """Store the community summary in Neo4j."""
+        logger.info(f"Community_id: {community_id} type: {type(community_id)}")
         with self.driver.session() as session:
             session.run(
                 """
-                MERGE (c:Cluster {id: $community_id})
-                SET c.summary = $summary
+                MATCH (c:Cluster {id: $community_id, name: $community_name})
+                SET c.summary = $summary
                 """,
-                community_id=int(community_id),
+                community_id=str(community_id),
+                community_name=str(community_id),
                 summary=summary,
             )
 
@@ -347,7 +361,7 @@ def __call__(self, nodes: List[BaseNode], show_progress: bool = False, **kwargs:
     async def _aextract(self, node: BaseNode) -> BaseNode:
         """Extract triples from a node."""
         assert hasattr(node, "text")
-
+        starttime = time.time()
         text = node.get_content(metadata_mode="llm")
         try:
             llm_response = await self.llm.apredict(
@@ -359,7 +373,9 @@ async def _aextract(self, node: BaseNode) -> BaseNode:
         except ValueError:
             entities = []
             entities_relationship = []
+        logger.info(f"Time taken to LLM and parse: {time.time() - starttime}")
 
+        starttime = time.time()
         existing_nodes = node.metadata.pop(KG_NODES_KEY, [])
         existing_relations = node.metadata.pop(KG_RELATIONS_KEY, [])
         entity_metadata = node.metadata.copy()
@@ -383,6 +399,7 @@ async def _aextract(self, node: BaseNode) -> BaseNode:
 
         node.metadata[KG_NODES_KEY] = existing_nodes
         node.metadata[KG_RELATIONS_KEY] = existing_relations
+        logger.info(f"Time taken to process entities and relations: {time.time() - starttime}")
         logger.info(f"number of extracted nodes {len(existing_nodes), existing_nodes}")
         logger.info(f"number of extracted relations {len(existing_relations), existing_relations}")
         return node
@@ -469,19 +486,24 @@ def trim_messages_to_token_limit(tokenizer, messages, max_tokens):
     """Trim the messages to fit within the token limit."""
     total_tokens = 0
     trimmed_messages = []
+    buffer = 100
+    effective_max_tokens = max_tokens - buffer
 
     for message in messages:
        tokens = tokenizer.tokenize(message.content)
-        total_tokens += len(tokens)
-        if total_tokens > max_tokens:
+        message_token_count = len(tokens)
+        if total_tokens + message_token_count > effective_max_tokens:
             # Trim the message to fit within the remaining token limit
-            logger.info(f"Trimming messages: {total_tokens} > {max_tokens}")
-            remaining_tokens = max_tokens - (total_tokens - len(tokens))
+            logger.info(f"Trimming messages: {total_tokens + message_token_count} > {effective_max_tokens}")
+            logger.info(f"message_token_count: {message_token_count}")
+            remaining_tokens = effective_max_tokens - total_tokens
+            logger.info(f"remaining_tokens: {remaining_tokens}")
             tokens = tokens[:remaining_tokens]
             message.content = tokenizer.convert_tokens_to_string(tokens)
             trimmed_messages.append(message)
             break
         else:
+            total_tokens += message_token_count
             trimmed_messages.append(message)
 
     return trimmed_messages
@@ -493,9 +515,66 @@ def trim_messages_to_token_limit(tokenizer, messages, max_tokens):
 upload_folder = "./uploaded_files/"
 client = OpenAI()
 
+# Global variables to store the initialized objects
+llm = None
+embed_model = None
+graph_store = None
+kg_extractor = None
+initialized = False
+
+
+def initialize_graph_store_and_models():
+    global llm, embed_model, graph_store, kg_extractor, initialized
+
+    if OPENAI_API_KEY:
+        logger.info("OpenAI API Key is set. Verifying its validity...")
+        openai.api_key = OPENAI_API_KEY
+        try:
+            llm = OpenAI(temperature=0, model=OPENAI_LLM_MODEL)
+            embed_model = OpenAIEmbedding(model=OPENAI_EMBEDDING_MODEL, embed_batch_size=100)
+            logger.info("OpenAI API Key is valid.")
+        except openai.AuthenticationError:
+            logger.info("OpenAI API Key is invalid.")
+        except Exception as e:
+            logger.info(f"An error occurred while verifying the API Key: {e}")
+    else:
+        logger.info("NO OpenAI API Key. TGI/VLLM/TEI endpoints will be used.")
+        # works with TGI and VLLM endpoints
+        llm = OpenAILike(
+            model=LLM_MODEL_ID,
+            api_base=TGI_LLM_ENDPOINT + "/v1",
+            api_key="fake",
+            temperature=0.7,
+            max_tokens=MAX_OUTPUT_TOKENS,  # 1512
+            timeout=1200,  # timeout in seconds
+        )
+        emb_name = get_attribute_from_tgi_endpoint(TEI_EMBEDDING_ENDPOINT, "model_id")
+        embed_model = TextEmbeddingsInference(
+            base_url=TEI_EMBEDDING_ENDPOINT,
+            model_name=emb_name,
+            timeout=1200,  # timeout in seconds
+            embed_batch_size=10,  # batch size for embedding
+        )
+    Settings.embed_model = embed_model
+    Settings.llm = llm
+    kg_extractor = GraphRAGExtractor(
+        llm=llm,
+        extract_prompt=KG_TRIPLET_EXTRACT_TMPL,
+        max_paths_per_chunk=2,
+        parse_fn=parse_fn,
+    )
+    graph_store = GraphRAGStore(username=NEO4J_USERNAME, password=NEO4J_PASSWORD, url=NEO4J_URL, llm=llm)
+    initialized = True
+
 
 def ingest_data_to_neo4j(doc_path: DocPath):
     """Ingest document to Neo4J."""
+    global initialized
+    if not initialized:
+        starttime = time.time()
+        initialize_graph_store_and_models()
+        logger.info(f"Time taken to initialize: {time.time() - starttime}")
+
     path = doc_path.path
     if logflag:
         logger.info(f"Parsing document {path}.")
@@ -539,44 +618,7 @@ def ingest_data_to_neo4j(doc_path: DocPath):
 
     if logflag:
         logger.info(f"Done preprocessing. Created {len(nodes)} chunks of the original file.")
-    if OPENAI_API_KEY:
-        logger.info("OpenAI API Key is set. Verifying its validity...")
-        openai.api_key = OPENAI_API_KEY
-        try:
-            llm = OpenAI(temperature=0, model=OPENAI_LLM_MODEL)
-            embed_model = OpenAIEmbedding(model=OPENAI_EMBEDDING_MODEL, embed_batch_size=100)
-            logger.info("OpenAI API Key is valid.")
-        except openai.AuthenticationError:
-            logger.info("OpenAI API Key is invalid.")
-        except Exception as e:
-            logger.info(f"An error occurred while verifying the API Key: {e}")
-    else:
-        logger.info("NO OpenAI API Key. TGI/TEI endpoints will be used.")
-        llm_name = get_attribute_from_tgi_endpoint(TGI_LLM_ENDPOINT, "model_id")
-        llm = TextGenerationInference(
-            model_url=TGI_LLM_ENDPOINT,
-            model_name=llm_name,
-            temperature=0.7,
-            max_tokens=1512,
-            timeout=600,  # timeout in seconds
-        )
-        emb_name = get_attribute_from_tgi_endpoint(TEI_EMBEDDING_ENDPOINT, "model_id")
-        embed_model = TextEmbeddingsInference(
-            base_url=TEI_EMBEDDING_ENDPOINT,
-            model_name=emb_name,
-            timeout=600,  # timeout in seconds
-            embed_batch_size=10,  # batch size for embedding
-        )
-        Settings.embed_model = embed_model
-        Settings.llm = llm
-    kg_extractor = GraphRAGExtractor(
-        llm=llm,
-        extract_prompt=KG_TRIPLET_EXTRACT_TMPL,
-        max_paths_per_chunk=2,
-        parse_fn=parse_fn,
-    )
-    graph_store = GraphRAGStore(username=NEO4J_USERNAME, password=NEO4J_PASSWORD, url=NEO4J_URL, llm=llm)
-
+    starttime = time.time()
     # nodes are the chunked docs to insert
     index = PropertyGraphIndex(
         nodes=nodes,
@@ -588,7 +630,8 @@ def ingest_data_to_neo4j(doc_path: DocPath):
     )
     if logflag:
         logger.info("The graph is built.")
-        logger.info(f"Total number of triplets {len(index.property_graph_store.get_triplets())}")
+        logger.info(f"Time taken to update PropertyGraphIndex: {time.time() - starttime}")
+        # logger.info(f"Total number of triplets {len(index.property_graph_store.get_triplets())}")
 
     if logflag:
         logger.info("Done building communities.")
@@ -596,13 +639,15 @@ def ingest_data_to_neo4j(doc_path: DocPath):
     return index
 
 
-def build_communities(index: PropertyGraphIndex):
+async def build_communities(index: PropertyGraphIndex):
     try:
-        index.property_graph_store.build_communities()
+        await index.property_graph_store.build_communities()
         if logflag:
             logger.info("Done building communities.")
     except Exception as e:
         logger.error(f"Error building communities: {e}")
+        error_trace = traceback.format_exc()
+        logger.error(f"Error building communities: {e}\n{error_trace}")
     return True
 
 
@@ -621,42 +666,30 @@ async def ingest_documents(
     chunk_overlap: int = Form(100),
     process_table: bool = Form(False),
     table_strategy: str = Form("fast"),
+    skip_ingestion: bool = Form(False),
 ):
     if logflag:
         logger.info(f"files:{files}")
         logger.info(f"link_list:{link_list}")
-
-    if files:
-        if not isinstance(files, list):
-            files = [files]
-        uploaded_files = []
-        for file in files:
-            encode_file = encode_filename(file.filename)
-            save_path = upload_folder + encode_file
-            await save_content_to_local_disk(save_path, file)
-            index = ingest_data_to_neo4j(
-                DocPath(
-                    path=save_path,
-                    chunk_size=chunk_size,
-                    chunk_overlap=chunk_overlap,
-                    process_table=process_table,
-                    table_strategy=table_strategy,
-                )
-            )
-            uploaded_files.append(save_path)
-            if logflag:
-                logger.info(f"Successfully saved file {save_path}")
-
-    if link_list:
-        link_list = json.loads(link_list)  # Parse JSON string to list
-        if not isinstance(link_list, list):
-            raise HTTPException(status_code=400, detail="link_list should be a list.")
-        for link in link_list:
-            encoded_link = encode_filename(link)
-            save_path = upload_folder + encoded_link + ".txt"
-            content = parse_html_new([link], chunk_size=chunk_size, chunk_overlap=chunk_overlap)
-            try:
-                await save_content_to_local_disk(save_path, content)
+        logger.info(f"skip_ingestion:{skip_ingestion}")
+
+    if skip_ingestion:
+        initialize_graph_store_and_models()
+        index = PropertyGraphIndex.from_existing(
+            property_graph_store=graph_store,
+            embed_model=embed_model or Settings.embed_model,
+            embed_kg_nodes=True,
+        )
+    else:
+        if files:
+            if not isinstance(files, list):
+                files = [files]
+            uploaded_files = []
+            for file in files:
+                encode_file = encode_filename(file.filename)
+                save_path = upload_folder + encode_file
+                await save_content_to_local_disk(save_path, file)
+                starttime = time.time()
                 index = ingest_data_to_neo4j(
                     DocPath(
                         path=save_path,
@@ -666,14 +699,38 @@ async def ingest_documents(
                         table_strategy=table_strategy,
                     )
                 )
-            except json.JSONDecodeError:
-                raise HTTPException(status_code=500, detail="Fail to ingest data")
+                logger.info(f"Time taken to ingest file:{encode_file} {time.time() - starttime}")
+                uploaded_files.append(save_path)
+                if logflag:
+                    logger.info(f"Successfully saved file {save_path}")
+
+        if link_list:
+            link_list = json.loads(link_list)  # Parse JSON string to list
+            if not isinstance(link_list, list):
+                raise HTTPException(status_code=400, detail="link_list should be a list.")
+            for link in link_list:
+                encoded_link = encode_filename(link)
+                save_path = upload_folder + encoded_link + ".txt"
+                content = parse_html_new([link], chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+                try:
+                    await save_content_to_local_disk(save_path, content)
+                    index = ingest_data_to_neo4j(
+                        DocPath(
+                            path=save_path,
+                            chunk_size=chunk_size,
+                            chunk_overlap=chunk_overlap,
+                            process_table=process_table,
+                            table_strategy=table_strategy,
+                        )
+                    )
+                except json.JSONDecodeError:
+                    raise HTTPException(status_code=500, detail="Fail to ingest data")
 
-            if logflag:
-                logger.info(f"Successfully saved link {link}")
+                if logflag:
+                    logger.info(f"Successfully saved link {link}")
 
-    if files or link_list:
-        build_communities(index)
+    if files or link_list or skip_ingestion:
+        await build_communities(index)
     result = {"status": 200, "message": "Data preparation succeeded"}
     if logflag:
         logger.info(result)
diff --git a/comps/dataprep/neo4j/llama_index/requirements.txt b/comps/dataprep/neo4j/llama_index/requirements.txt
index c183ecf3d..5f0631bc2 100644
--- a/comps/dataprep/neo4j/llama_index/requirements.txt
+++ b/comps/dataprep/neo4j/llama_index/requirements.txt
@@ -15,9 +15,10 @@ langchain_community
 llama-index
 llama-index-core
 llama-index-embeddings-text-embeddings-inference
+llama-index-graph-stores-neo4j
 llama-index-llms-openai
+llama-index-llms-openai-like
 llama-index-llms-text-generation-inference
-llama_index_graph_stores_neo4j==0.3.3
 markdown
 neo4j
 numpy
diff --git a/comps/retrievers/neo4j/llama_index/config.py b/comps/retrievers/neo4j/llama_index/config.py
index 3037b8f9f..42c1abe0d 100644
--- a/comps/retrievers/neo4j/llama_index/config.py
+++ b/comps/retrievers/neo4j/llama_index/config.py
@@ -16,3 +16,7 @@
 OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
 OPENAI_EMBEDDING_MODEL = os.getenv("OPENAI_EMBEDDING_MODEL", "text-embedding-3-small")
 OPENAI_LLM_MODEL = os.getenv("OPENAI_LLM_MODEL", "gpt-4o")
+
+LLM_MODEL_ID = os.getenv("LLM_MODEL_ID", "meta-llama/Meta-Llama-3.1-70B-Instruct")
+MAX_INPUT_LEN = os.getenv("MAX_INPUT_LEN", "8192")
+MAX_OUTPUT_TOKENS = os.getenv("MAX_OUTPUT_TOKENS", "1024")
diff --git a/comps/retrievers/neo4j/llama_index/requirements.txt b/comps/retrievers/neo4j/llama_index/requirements.txt
index c91f71ba7..0b65432a6 100644
--- a/comps/retrievers/neo4j/llama_index/requirements.txt
+++ b/comps/retrievers/neo4j/llama_index/requirements.txt
@@ -12,9 +12,10 @@ langchain-community
 llama-index-core
 llama-index-embeddings-openai
 llama-index-embeddings-text-embeddings-inference
+llama-index-graph-stores-neo4j
 llama-index-llms-openai
+llama-index-llms-openai-like
 llama-index-llms-text-generation-inference
-llama_index_graph_stores_neo4j==0.3.3
 neo4j
 numpy
 opencv-python
diff --git a/comps/retrievers/neo4j/llama_index/retriever_community_answers_neo4j.py b/comps/retrievers/neo4j/llama_index/retriever_community_answers_neo4j.py
index 830dc2775..e7d0cd650 100644
--- a/comps/retrievers/neo4j/llama_index/retriever_community_answers_neo4j.py
+++ b/comps/retrievers/neo4j/llama_index/retriever_community_answers_neo4j.py
@@ -9,6 +9,8 @@
 
 import openai
 from config import (
+    LLM_MODEL_ID,
+    MAX_OUTPUT_TOKENS,
     NEO4J_PASSWORD,
     NEO4J_URL,
     NEO4J_USERNAME,
@@ -25,6 +27,7 @@
 from llama_index.embeddings.openai import OpenAIEmbedding
 from llama_index.embeddings.text_embeddings_inference import TextEmbeddingsInference
 from llama_index.llms.openai import OpenAI
+from llama_index.llms.openai_like import OpenAILike
 from llama_index.llms.text_generation_inference import TextGenerationInference
 from neo4j import GraphDatabase
 from pydantic import BaseModel, PrivateAttr
@@ -67,21 +70,21 @@ def __init__(self, graph_store: GraphRAGStore, llm: LLM, index: PropertyGraphInd
         self._llm = llm
         self._similarity_top_k = similarity_top_k
 
-    def custom_query(self, query_str: str) -> RetrievalResponseData:
+    def custom_query(self, query_str: str, batch_size: int = 16) -> RetrievalResponseData:
         """Process all community summaries to generate answers to a specific query."""
         entities = self.get_entities(query_str, self._similarity_top_k)
-        entity_info = self._graph_store.read_entity_info()
-        community_ids = self.retrieve_entity_communities(entity_info, entities)
         community_summaries = self.retrieve_community_summaries_cypher(entities)
         community_ids = list(community_summaries.keys())
         if logflag:
             logger.info(f"Community ids: {community_ids}")
-        # community_summaries of relevant communities
-        community_answers = [
-            self.generate_answer_from_summary(community_summary, query_str)
-            for id, community_summary in community_summaries.items()
-        ]
+        # Process community summaries in batches
+        community_answers = []
+        for i in range(0, len(community_ids), batch_size):
+            batch_ids = community_ids[i : i + batch_size]
+            batch_summaries = {community_id: community_summaries[community_id] for community_id in batch_ids}
+            batch_answers = self.generate_batch_answers_from_summaries(batch_summaries, query_str)
+            community_answers.extend(batch_answers)
         # Convert answers to RetrievalResponseData objects
         response_data = [RetrievalResponseData(text=answer, metadata={}) for answer in community_answers]
         return response_data
 
@@ -104,7 +107,7 @@ def get_entities(self, query_str, similarity_top_k):
         entities = set()
         pattern = r"(\w+(?:\s+\w+)*)\s*->\s*(\w+(?:\s+\w+)*)\s*->\s*(\w+(?:\s+\w+)*)"
         if logflag:
-            logger.info(f" len of triplets {len(self._index.property_graph_store.get_triplets())}")
+            # logger.info(f" len of triplets {len(self._index.property_graph_store.get_triplets())}")
             logger.info(f"number of nodes retrieved {len(nodes_retrieved), nodes_retrieved}")
         for node in nodes_retrieved:
             matches = re.findall(pattern, node.text, re.DOTALL)
@@ -179,26 +182,53 @@ def generate_answer_from_summary(self, community_summary, query):
         cleaned_response = re.sub(r"^assistant:\s*", "", str(response)).strip()
         return cleaned_response
 
-
-@register_microservice(
-    name="opea_service@retriever_community_answers_neo4j",
-    service_type=ServiceType.RETRIEVER,
-    endpoint="/v1/retrieval",
-    host="0.0.0.0",
-    port=6009,
-)
-@register_statistics(names=["opea_service@retriever_community_answers_neo4j"])
-async def retrieve(input: Union[ChatCompletionRequest]) -> Union[ChatCompletionRequest]:
-    if logflag:
-        logger.info(input)
-    start = time.time()
-
-    if isinstance(input.messages, str):
-        query = input.messages
-    else:
-        query = input.messages[0]["content"]
-    logger.info(f"Query received in retriever: {query}")
-
+    def generate_batch_answers_from_summaries(self, batch_summaries, query):
+        """Generate answers from a batch of community summaries based on a given query using LLM."""
+        batch_prompts = []
+        for community_id, summary in batch_summaries.items():
+            prompt = (
+                f"Given the community summary: {summary}, " f"how would you answer the following query? Query: {query}"
+            )
+            messages = [
+                ChatMessage(role="system", content=prompt),
+                ChatMessage(
+                    role="user",
+                    content="I need an answer based on the above information.",
+                ),
+            ]
+            batch_prompts.append((community_id, messages))
+
+        # Generate answers for the batch
+        answers = self.generate_batch_responses(batch_prompts)
+        return answers
+
+    def generate_batch_responses(self, batch_prompts):
+        """Generate responses for a batch of prompts using LLM."""
+        responses = {}
+        messages = [messages for _, messages in batch_prompts]
+
+        # Generate responses for the batch
+        if OPENAI_API_KEY:
+            batch_responses = [OpenAI().chat(message) for message in messages]
+        else:
+            batch_responses = [self._llm.chat(message) for message in messages]
+
+        for (community_id, _), response in zip(batch_prompts, batch_responses):
+            cleaned_response = re.sub(r"^assistant:\s*", "", str(response)).strip()
+            responses[community_id] = cleaned_response
+
+        return [responses[community_id] for community_id, _ in batch_prompts]
+
+
+# Global variables to store the graph_store and index
+graph_store = None
+query_engine = None
+index = None
+initialized = False
+
+
+async def initialize_graph_store_and_index():
+    global graph_store, index, initialized, query_engine
     if OPENAI_API_KEY:
         logger.info("OpenAI API Key is set. Verifying its validity...")
         openai.api_key = OPENAI_API_KEY
@@ -211,31 +241,41 @@ async def retrieve(input: Union[ChatCompletionRequest]) -> Union[ChatCompletionR
         except Exception as e:
             logger.info(f"An error occurred while verifying the API Key: {e}")
     else:
-        logger.info("No OpenAI API KEY provided. Will use TGI and TEI endpoints")
-        llm_name = get_attribute_from_tgi_endpoint(TGI_LLM_ENDPOINT, "model_id")
-        llm = TextGenerationInference(
-            model_url=TGI_LLM_ENDPOINT,
-            model_name=llm_name,
+        logger.info("No OpenAI API KEY provided. Will use TGI/VLLM and TEI endpoints")
+        # llm_name = get_attribute_from_tgi_endpoint(TGI_LLM_ENDPOINT, "model_id")
+        # works w VLLM too
+        llm = OpenAILike(
+            model=LLM_MODEL_ID,
+            api_base=TGI_LLM_ENDPOINT + "/v1",
+            api_key="fake",
+            timeout=600,
             temperature=0.7,
-            max_tokens=1512,  # 512otherwise too shor
+            max_tokens=MAX_OUTPUT_TOKENS,
        )
         emb_name = get_attribute_from_tgi_endpoint(TEI_EMBEDDING_ENDPOINT, "model_id")
         embed_model = TextEmbeddingsInference(
             base_url=TEI_EMBEDDING_ENDPOINT,
             model_name=emb_name,
-            timeout=60,  # timeout in seconds
+            timeout=1200,  # timeout in seconds
             embed_batch_size=10,  # batch size for embedding
         )
         Settings.embed_model = embed_model
         Settings.llm = llm
+
+    logger.info("Creating graph store from existing...")
+    starttime = time.time()
     # pre-existiing graph store (created with data_prep/llama-index/extract_graph_neo4j.py)
     graph_store = GraphRAGStore(username=NEO4J_USERNAME, password=NEO4J_PASSWORD, url=NEO4J_URL, llm=llm)
+    logger.info(f"Time to create graph store: {time.time() - starttime:.2f} seconds")
+    logger.info("Creating index from existing...")
+    starttime = time.time()
     index = PropertyGraphIndex.from_existing(
         property_graph_store=graph_store,
         embed_model=embed_model or Settings.embed_model,
         embed_kg_nodes=True,
     )
+    logger.info(f"Time to create index: {time.time() - starttime:.2f} seconds")
 
     query_engine = GraphRAGQueryEngine(
         graph_store=index.property_graph_store,
@@ -243,6 +283,29 @@ async def retrieve(input: Union[ChatCompletionRequest]) -> Union[ChatCompletionR
         index=index,
         similarity_top_k=3,
     )
+    initialized = True
+
+
+@register_microservice(
+    name="opea_service@retriever_community_answers_neo4j",
+    service_type=ServiceType.RETRIEVER,
+    endpoint="/v1/retrieval",
+    host="0.0.0.0",
+    port=6009,
+)
+@register_statistics(names=["opea_service@retriever_community_answers_neo4j"])
+async def retrieve(input: Union[ChatCompletionRequest]) -> Union[ChatCompletionRequest]:
+    if logflag:
+        logger.info(input)
+    start = time.time()
+
+    if isinstance(input.messages, str):
+        query = input.messages
+    else:
+        query = input.messages[0]["content"]
+    logger.info(f"Query received in retriever: {query}")
+    if not initialized:
+        await initialize_graph_store_and_index()
 
     # these are the answers from the community summaries
     answers_by_community = query_engine.query(query)