From bc64208abfa6964c8dfd203f1323ba22fb3f6c05 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 16 Dec 2024 06:13:22 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../neo4j/llama_index/extract_graph_neo4j.py | 26 +++++++++----------
 comps/retrievers/neo4j/llama_index/config.py |  2 +-
 .../retriever_community_answers_neo4j.py     | 19 +++++++-------
 3 files changed, 23 insertions(+), 24 deletions(-)

diff --git a/comps/dataprep/neo4j/llama_index/extract_graph_neo4j.py b/comps/dataprep/neo4j/llama_index/extract_graph_neo4j.py
index 9c166ea94..c1732b732 100644
--- a/comps/dataprep/neo4j/llama_index/extract_graph_neo4j.py
+++ b/comps/dataprep/neo4j/llama_index/extract_graph_neo4j.py
@@ -16,6 +16,9 @@
 import openai
 import requests
 from config import (
+    LLM_MODEL_ID,
+    MAX_INPUT_LEN,
+    MAX_OUTPUT_TOKENS,
     NEO4J_PASSWORD,
     NEO4J_URL,
     NEO4J_USERNAME,
@@ -25,9 +28,6 @@
     TEI_EMBEDDING_ENDPOINT,
     TGI_LLM_ENDPOINT,
     host_ip,
-    LLM_MODEL_ID,
-    MAX_INPUT_LEN,
-    MAX_OUTPUT_TOKENS,
 )
 from fastapi import File, Form, HTTPException, UploadFile
 from graspologic.partition import hierarchical_leiden
@@ -40,8 +40,8 @@
 from llama_index.embeddings.text_embeddings_inference import TextEmbeddingsInference
 from llama_index.graph_stores.neo4j import Neo4jPropertyGraphStore
 from llama_index.llms.openai import OpenAI
-from llama_index.llms.text_generation_inference import TextGenerationInference
 from llama_index.llms.openai_like import OpenAILike
+from llama_index.llms.text_generation_inference import TextGenerationInference
 from neo4j import GraphDatabase
 from openai import Client
 from transformers import AutoTokenizer
@@ -85,7 +85,7 @@ def __init__(self, username: str, password: str, url: str, llm: LLM):
     async def generate_community_summary(self, text):
         """Generate summary for a given text using an LLM."""
         model_name = LLM_MODEL_ID
-        max_input_length=int(MAX_INPUT_LEN)
+        max_input_length = int(MAX_INPUT_LEN)
 
         if not model_name or not max_input_length:
             raise ValueError(f"Could not retrieve model information from TGI endpoint: {TGI_LLM_ENDPOINT}")
@@ -116,7 +116,6 @@ async def generate_community_summary(self, text):
         clean_response = re.sub(r"^assistant:\s*", "", str(response)).strip()
         return clean_response
 
-
     async def build_communities(self):
         """Builds communities from the graph and summarizes them."""
 
@@ -241,7 +240,7 @@ def read_entity_info(self) -> dict:
     async def _summarize_communities(self, community_info, num_workers=5):
         """Generate and store summaries for each community."""
         # Run tasks concurrently with a limited number of workers
-        tasks=[]
+        tasks = []
         for community_id, details in community_info.items():
             logger.info(f"Summarizing community {community_id}")
             details_text = "\n".join(details) + "."  # Ensure it ends with a period
@@ -252,12 +251,13 @@ async def _summarize_communities(self, community_info, num_workers=5):
             show_progress=True,
             desc="Summarize communities",
         )
+
     async def _process_community(self, community_id, details_text):
         """Process a single community and store the summary."""
         summary = await self.generate_community_summary(details_text)
         self.community_summary[community_id] = summary
         self.store_community_summary_in_neo4j(community_id, summary)
-    
+
     def store_community_summary_in_neo4j(self, community_id, summary):
         """Store the community summary in Neo4j."""
         logger.info(f"Community_id: {community_id} type: {type(community_id)}")
@@ -539,13 +539,13 @@ def initialize_graph_store_and_models():
             logger.info(f"An error occurred while verifying the API Key: {e}")
     else:
         logger.info("NO OpenAI API Key. TGI/VLLM/TEI endpoints will be used.")
-        #works with TGI and VLLM endpoints
+        # works with TGI and VLLM endpoints
         llm = OpenAILike(
-            model=LLM_MODEL_ID, 
-            api_base=TGI_LLM_ENDPOINT+"/v1",
+            model=LLM_MODEL_ID,
+            api_base=TGI_LLM_ENDPOINT + "/v1",
             api_key="fake",
             temperature=0.7,
-            max_tokens=MAX_OUTPUT_TOKENS, # 1512
+            max_tokens=MAX_OUTPUT_TOKENS,  # 1512
             timeout=1200,  # timeout in seconds)
         )
         emb_name = get_attribute_from_tgi_endpoint(TEI_EMBEDDING_ENDPOINT, "model_id")
@@ -730,7 +730,7 @@ async def ingest_documents(
             logger.info(f"Successfully saved link {link}")
 
     if files or link_list or skip_ingestion:
-            await build_communities(index)
+        await build_communities(index)
     result = {"status": 200, "message": "Data preparation succeeded"}
     if logflag:
         logger.info(result)
diff --git a/comps/retrievers/neo4j/llama_index/config.py b/comps/retrievers/neo4j/llama_index/config.py
index 33c5be7fd..42c1abe0d 100644
--- a/comps/retrievers/neo4j/llama_index/config.py
+++ b/comps/retrievers/neo4j/llama_index/config.py
@@ -19,4 +19,4 @@
 
 LLM_MODEL_ID = os.getenv("LLM_MODEL_ID", "meta-llama/Meta-Llama-3.1-70B-Instruct")
 MAX_INPUT_LEN = os.getenv("MAX_INPUT_LEN", "8192")
-MAX_OUTPUT_TOKENS = os.getenv("MAX_OUTPUT_TOKENS", "1024")
\ No newline at end of file
+MAX_OUTPUT_TOKENS = os.getenv("MAX_OUTPUT_TOKENS", "1024")
diff --git a/comps/retrievers/neo4j/llama_index/retriever_community_answers_neo4j.py b/comps/retrievers/neo4j/llama_index/retriever_community_answers_neo4j.py
index 5b56b23eb..e7d0cd650 100644
--- a/comps/retrievers/neo4j/llama_index/retriever_community_answers_neo4j.py
+++ b/comps/retrievers/neo4j/llama_index/retriever_community_answers_neo4j.py
@@ -9,6 +9,8 @@
 
 import openai
 from config import (
+    LLM_MODEL_ID,
+    MAX_OUTPUT_TOKENS,
     NEO4J_PASSWORD,
     NEO4J_URL,
     NEO4J_USERNAME,
@@ -17,8 +19,6 @@
     OPENAI_LLM_MODEL,
     TEI_EMBEDDING_ENDPOINT,
     TGI_LLM_ENDPOINT,
-    LLM_MODEL_ID,
-    MAX_OUTPUT_TOKENS,
 )
 from llama_index.core import PropertyGraphIndex, Settings
 from llama_index.core.indices.property_graph.sub_retrievers.vector import VectorContextRetriever
@@ -81,7 +81,7 @@ def custom_query(self, query_str: str, batch_size: int = 16) -> RetrievalRespons
         # Process community summaries in batches
         community_answers = []
         for i in range(0, len(community_ids), batch_size):
-            batch_ids = community_ids[i:i + batch_size]
+            batch_ids = community_ids[i : i + batch_size]
             batch_summaries = {community_id: community_summaries[community_id] for community_id in batch_ids}
             batch_answers = self.generate_batch_answers_from_summaries(batch_summaries, query_str)
             community_answers.extend(batch_answers)
@@ -181,14 +181,13 @@ def generate_answer_from_summary(self, community_summary, query):
         response = self._llm.chat(messages)
         cleaned_response = re.sub(r"^assistant:\s*", "", str(response)).strip()
         return cleaned_response
-    
+
     def generate_batch_answers_from_summaries(self, batch_summaries, query):
         """Generate answers from a batch of community summaries based on a given query using LLM."""
         batch_prompts = []
         for community_id, summary in batch_summaries.items():
             prompt = (
-                f"Given the community summary: {summary}, "
-                f"how would you answer the following query? Query: {query}"
+                f"Given the community summary: {summary}, " f"how would you answer the following query? Query: {query}"
             )
             messages = [
                 ChatMessage(role="system", content=prompt),
@@ -202,7 +201,7 @@ def generate_batch_answers_from_summaries(self, batch_summaries, query):
         # Generate answers for the batch
         answers = self.generate_batch_responses(batch_prompts)
         return answers
-    
+
     def generate_batch_responses(self, batch_prompts):
         """Generate responses for a batch of prompts using LLM."""
         responses = {}
@@ -243,11 +242,11 @@ async def initialize_graph_store_and_index():
             logger.info(f"An error occurred while verifying the API Key: {e}")
     else:
         logger.info("No OpenAI API KEY provided. Will use TGI/VLLM and TEI endpoints")
-        #llm_name = get_attribute_from_tgi_endpoint(TGI_LLM_ENDPOINT, "model_id")
-        #works w VLLM too
+        # llm_name = get_attribute_from_tgi_endpoint(TGI_LLM_ENDPOINT, "model_id")
+        # works w VLLM too
         llm = OpenAILike(
             model=LLM_MODEL_ID,
-            api_base=TGI_LLM_ENDPOINT+"/v1",
+            api_base=TGI_LLM_ENDPOINT + "/v1",
             api_key="fake",
             timeout=600,
             temperature=0.7,