Skip to content

Commit

Permalink
Feature/optimize count relationships (#1312)
Browse files Browse the repository at this point in the history
* refactor build text unit context for better performance

* Further optimization and styling

* Remove TODO

---------

Co-authored-by: Brad Firesheets <[email protected]>
Co-authored-by: bfirems <[email protected]>
Co-authored-by: Josh Bradley <[email protected]>
  • Loading branch information
4 people authored Oct 23, 2024
1 parent 3df6f8c commit ac09e0a
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 48 deletions.
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20241023002453006383.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Optimize text unit relationship count"
}
44 changes: 16 additions & 28 deletions graphrag/query/context_builder/source_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pandas as pd
import tiktoken

from graphrag.model import Entity, Relationship, TextUnit
from graphrag.model import Relationship, TextUnit
from graphrag.query.llm.text_utils import num_tokens

"""
Expand Down Expand Up @@ -78,33 +78,21 @@ def build_text_unit_context(


def count_relationships(
text_unit: TextUnit, entity: Entity, relationships: dict[str, Relationship]
entity_relationships: list[Relationship], text_unit: TextUnit
) -> int:
"""Count the number of relationships of the selected entity that are associated with the text unit."""
matching_relationships = list[Relationship]()
if text_unit.relationship_ids is None:
entity_relationships = [
rel
for rel in relationships.values()
if rel.source == entity.title or rel.target == entity.title
]
entity_relationships = [
rel for rel in entity_relationships if rel.text_unit_ids
]
matching_relationships = [
rel
if not text_unit.relationship_ids:
# Use list comprehension to count relationships where the text_unit.id is in rel.text_unit_ids
return sum(
1
for rel in entity_relationships
if text_unit.id in rel.text_unit_ids # type: ignore
] # type: ignore
else:
text_unit_relationships = [
relationships[rel_id]
for rel_id in text_unit.relationship_ids
if rel_id in relationships
]
matching_relationships = [
rel
for rel in text_unit_relationships
if rel.source == entity.title or rel.target == entity.title
]
return len(matching_relationships)
if rel.text_unit_ids and text_unit.id in rel.text_unit_ids
)

# Use a set for faster lookups if entity_relationships is large
entity_relationship_ids = {rel.id for rel in entity_relationships}

# Count matching relationship ids efficiently
return sum(
1 for rel_id in text_unit.relationship_ids if rel_id in entity_relationship_ids
)
37 changes: 17 additions & 20 deletions graphrag/query/structured_search/local_search/mixed_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,13 +206,13 @@ def build_context(
final_context.append(str(local_context))
final_context_data = {**final_context_data, **local_context_data}

# build text unit context
text_unit_tokens = max(int(max_tokens * text_unit_prop), 0)
text_unit_context, text_unit_context_data = self._build_text_unit_context(
selected_entities=selected_entities,
max_tokens=text_unit_tokens,
return_candidate_context=return_candidate_context,
)

if text_unit_context.strip() != "":
final_context.append(text_unit_context)
final_context_data = {**final_context_data, **text_unit_context_data}
Expand Down Expand Up @@ -312,34 +312,32 @@ def _build_text_unit_context(
"""Rank matching text units and add them to the context window until it hits the max_tokens limit."""
if not selected_entities or not self.text_units:
return ("", {context_name.lower(): pd.DataFrame()})

selected_text_units = []
text_unit_ids_set = set()

unit_info_list = []
relationship_values = list(self.relationships.values())

for index, entity in enumerate(selected_entities):
# get matching relationships
entity_relationships = [
rel
for rel in relationship_values
if rel.source == entity.title or rel.target == entity.title
]

for text_id in entity.text_unit_ids or []:
if text_id not in text_unit_ids_set and text_id in self.text_units:
text_unit_ids_set.add(text_id)
selected_unit = deepcopy(self.text_units[text_id])
num_relationships = count_relationships(
selected_unit, entity, self.relationships
entity_relationships, selected_unit
)
if selected_unit.attributes is None:
selected_unit.attributes = {}
selected_unit.attributes["entity_order"] = index
selected_unit.attributes["num_relationships"] = num_relationships
selected_text_units.append(selected_unit)

selected_text_units.sort(
key=lambda x: (
x.attributes["entity_order"],
-x.attributes["num_relationships"],
)
)
unit_info_list.append((selected_unit, index, num_relationships))

for unit in selected_text_units:
unit.attributes.pop("entity_order", None)
unit.attributes.pop("num_relationships", None)
# sort by entity_order and the number of relationships desc
unit_info_list.sort(key=lambda x: (x[1], -x[2]))

selected_text_units = [unit[0] for unit in unit_info_list]

context_text, context_data = build_text_unit_context(
text_units=selected_text_units,
Expand Down Expand Up @@ -484,7 +482,6 @@ def _build_local_context(
final_context_data[key] = candidate_df
else:
final_context_data[key]["in_context"] = True

else:
for key in final_context_data:
final_context_data[key]["in_context"] = True
Expand Down

0 comments on commit ac09e0a

Please sign in to comment.