
Commit bc64208
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Dec 16, 2024
1 parent 078e04f commit bc64208
Showing 3 changed files with 23 additions and 24 deletions.
26 changes: 13 additions & 13 deletions comps/dataprep/neo4j/llama_index/extract_graph_neo4j.py
@@ -16,6 +16,9 @@
import openai
import requests
from config import (
LLM_MODEL_ID,
MAX_INPUT_LEN,
MAX_OUTPUT_TOKENS,
NEO4J_PASSWORD,
NEO4J_URL,
NEO4J_USERNAME,
@@ -25,9 +28,6 @@
TEI_EMBEDDING_ENDPOINT,
TGI_LLM_ENDPOINT,
host_ip,
LLM_MODEL_ID,
MAX_INPUT_LEN,
MAX_OUTPUT_TOKENS,
)
from fastapi import File, Form, HTTPException, UploadFile
from graspologic.partition import hierarchical_leiden
@@ -40,8 +40,8 @@
from llama_index.embeddings.text_embeddings_inference import TextEmbeddingsInference
from llama_index.graph_stores.neo4j import Neo4jPropertyGraphStore
from llama_index.llms.openai import OpenAI
from llama_index.llms.text_generation_inference import TextGenerationInference
from llama_index.llms.openai_like import OpenAILike
from llama_index.llms.text_generation_inference import TextGenerationInference
from neo4j import GraphDatabase
from openai import Client
from transformers import AutoTokenizer
@@ -85,7 +85,7 @@ def __init__(self, username: str, password: str, url: str, llm: LLM):
async def generate_community_summary(self, text):
"""Generate summary for a given text using an LLM."""
model_name = LLM_MODEL_ID
max_input_length=int(MAX_INPUT_LEN)
max_input_length = int(MAX_INPUT_LEN)
if not model_name or not max_input_length:
raise ValueError(f"Could not retrieve model information from TGI endpoint: {TGI_LLM_ENDPOINT}")

@@ -116,7 +116,6 @@ async def generate_community_summary(self, text):

clean_response = re.sub(r"^assistant:\s*", "", str(response)).strip()
return clean_response


async def build_communities(self):
"""Builds communities from the graph and summarizes them."""
@@ -241,7 +240,7 @@ def read_entity_info(self) -> dict:
async def _summarize_communities(self, community_info, num_workers=5):
"""Generate and store summaries for each community."""
# Run tasks concurrently with a limited number of workers
tasks=[]
tasks = []
for community_id, details in community_info.items():
logger.info(f"Summarizing community {community_id}")
details_text = "\n".join(details) + "." # Ensure it ends with a period
@@ -252,12 +251,13 @@ async def _summarize_communities(self, community_info, num_workers=5):
show_progress=True,
desc="Summarize communities",
)

async def _process_community(self, community_id, details_text):
"""Process a single community and store the summary."""
summary = await self.generate_community_summary(details_text)
self.community_summary[community_id] = summary
self.store_community_summary_in_neo4j(community_id, summary)

def store_community_summary_in_neo4j(self, community_id, summary):
"""Store the community summary in Neo4j."""
logger.info(f"Community_id: {community_id} type: {type(community_id)}")
@@ -539,13 +539,13 @@ def initialize_graph_store_and_models():
logger.info(f"An error occurred while verifying the API Key: {e}")
else:
logger.info("NO OpenAI API Key. TGI/VLLM/TEI endpoints will be used.")
#works with TGI and VLLM endpoints
# works with TGI and VLLM endpoints
llm = OpenAILike(
model=LLM_MODEL_ID,
api_base=TGI_LLM_ENDPOINT+"/v1",
model=LLM_MODEL_ID,
api_base=TGI_LLM_ENDPOINT + "/v1",
api_key="fake",
temperature=0.7,
max_tokens=MAX_OUTPUT_TOKENS, # 1512
max_tokens=MAX_OUTPUT_TOKENS, # 1512
timeout=1200, # timeout in seconds)
)
emb_name = get_attribute_from_tgi_endpoint(TEI_EMBEDDING_ENDPOINT, "model_id")
@@ -730,7 +730,7 @@ async def ingest_documents(
logger.info(f"Successfully saved link {link}")

if files or link_list or skip_ingestion:
await build_communities(index)
await build_communities(index)
result = {"status": 200, "message": "Data preparation succeeded"}
if logflag:
logger.info(result)
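Note on the OpenAILike change above: the client points at the OpenAI-compatible /v1 route that TGI and vLLM both expose, so one object works for either backend. A minimal standalone sketch of that usage follows; the endpoint URL, prompt, and max_tokens value are assumptions for illustration, while the model name and temperature mirror the defaults visible in this commit.

# Hypothetical smoke test for an OpenAI-compatible TGI/vLLM endpoint.
# URL, prompt, and max_tokens are assumptions, not values from this repository.
from llama_index.llms.openai_like import OpenAILike

llm = OpenAILike(
    model="meta-llama/Meta-Llama-3.1-70B-Instruct",  # mirrors the config default below
    api_base="http://localhost:8008/v1",             # assumed TGI/vLLM address
    api_key="fake",                                   # local endpoints ignore the key
    temperature=0.7,
    max_tokens=1024,                                  # assumed; config default is "1024"
    timeout=1200,
)

print(llm.complete("Summarize what a graph community is in one sentence."))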
2 changes: 1 addition & 1 deletion comps/retrievers/neo4j/llama_index/config.py
@@ -19,4 +19,4 @@

LLM_MODEL_ID = os.getenv("LLM_MODEL_ID", "meta-llama/Meta-Llama-3.1-70B-Instruct")
MAX_INPUT_LEN = os.getenv("MAX_INPUT_LEN", "8192")
MAX_OUTPUT_TOKENS = os.getenv("MAX_OUTPUT_TOKENS", "1024")
MAX_OUTPUT_TOKENS = os.getenv("MAX_OUTPUT_TOKENS", "1024")
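Since os.getenv always returns strings, these defaults stay strings until a caller casts them (the dataprep module above does int(MAX_INPUT_LEN)). A small sketch of one possible pattern, casting at load time instead, is shown below as an assumption rather than the repository's actual layout.

import os

# Cast once at load time so downstream code receives ints, not "8192"/"1024".
MAX_INPUT_LEN = int(os.getenv("MAX_INPUT_LEN", "8192"))
MAX_OUTPUT_TOKENS = int(os.getenv("MAX_OUTPUT_TOKENS", "1024"))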
@@ -9,6 +9,8 @@

import openai
from config import (
LLM_MODEL_ID,
MAX_OUTPUT_TOKENS,
NEO4J_PASSWORD,
NEO4J_URL,
NEO4J_USERNAME,
@@ -17,8 +19,6 @@
OPENAI_LLM_MODEL,
TEI_EMBEDDING_ENDPOINT,
TGI_LLM_ENDPOINT,
LLM_MODEL_ID,
MAX_OUTPUT_TOKENS,
)
from llama_index.core import PropertyGraphIndex, Settings
from llama_index.core.indices.property_graph.sub_retrievers.vector import VectorContextRetriever
@@ -81,7 +81,7 @@ def custom_query(self, query_str: str, batch_size: int = 16) -> RetrievalRespons
# Process community summaries in batches
community_answers = []
for i in range(0, len(community_ids), batch_size):
batch_ids = community_ids[i:i + batch_size]
batch_ids = community_ids[i : i + batch_size]
batch_summaries = {community_id: community_summaries[community_id] for community_id in batch_ids}
batch_answers = self.generate_batch_answers_from_summaries(batch_summaries, query_str)
community_answers.extend(batch_answers)
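The loop above walks community_ids in slices of batch_size so each round of answer generation sees at most that many summaries. A toy illustration of the same slicing pattern, with data and sizes invented for the example:

# Toy data standing in for real community summaries.
community_summaries = {cid: f"summary {cid}" for cid in range(5)}
community_ids = list(community_summaries)
batch_size = 2

for i in range(0, len(community_ids), batch_size):
    batch_ids = community_ids[i : i + batch_size]
    batch_summaries = {cid: community_summaries[cid] for cid in batch_ids}
    print(batch_ids, batch_summaries)  # batches of 2, 2, then 1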
@@ -181,14 +181,13 @@ def generate_answer_from_summary(self, community_summary, query):
response = self._llm.chat(messages)
cleaned_response = re.sub(r"^assistant:\s*", "", str(response)).strip()
return cleaned_response

def generate_batch_answers_from_summaries(self, batch_summaries, query):
"""Generate answers from a batch of community summaries based on a given query using LLM."""
batch_prompts = []
for community_id, summary in batch_summaries.items():
prompt = (
f"Given the community summary: {summary}, "
f"how would you answer the following query? Query: {query}"
f"Given the community summary: {summary}, " f"how would you answer the following query? Query: {query}"
)
messages = [
ChatMessage(role="system", content=prompt),
@@ -202,7 +201,7 @@ def generate_batch_answers_from_summaries(self, batch_prompts):
# Generate answers for the batch
answers = self.generate_batch_responses(batch_prompts)
return answers

def generate_batch_responses(self, batch_prompts):
"""Generate responses for a batch of prompts using LLM."""
responses = {}
@@ -243,11 +242,11 @@ async def initialize_graph_store_and_index():
logger.info(f"An error occurred while verifying the API Key: {e}")
else:
logger.info("No OpenAI API KEY provided. Will use TGI/VLLM and TEI endpoints")
#llm_name = get_attribute_from_tgi_endpoint(TGI_LLM_ENDPOINT, "model_id")
#works w VLLM too
# llm_name = get_attribute_from_tgi_endpoint(TGI_LLM_ENDPOINT, "model_id")
# works w VLLM too
llm = OpenAILike(
model=LLM_MODEL_ID,
api_base=TGI_LLM_ENDPOINT+"/v1",
api_base=TGI_LLM_ENDPOINT + "/v1",
api_key="fake",
timeout=600,
temperature=0.7,