Skip to content

Commit

Permalink
Fix/text unit code cleanup (#1040)
Browse files Browse the repository at this point in the history
* Optimized _build_text_unit_context function for improved time and space complexity

Refactored the _build_text_unit_context function to enhance its performance and efficiency. Key optimizations include:

1. Set for Text Unit IDs: Replaced list-based membership checks with a set (text_unit_ids_set) to achieve constant-time complexity for membership checks, reducing overall time complexity.
2. Direct Attribute Removal: Utilized pop with a default value (None) to directly remove attributes entity_order and num_relationships from text units, minimizing overhead and avoiding potential KeyError.
3. Default Dictionary for Entity Orders: Implemented defaultdict for managing entity orders, simplifying the ranking process and improving readability.

These improvements result in a more efficient function with better performance, especially when handling large datasets or numerous selected entities. The refactoring ensures that the core functionality remains unchanged while enhancing both time and space complexity.

* Format

* Ruff fixes

* semver

---------

Co-authored-by: arjun-234 <[email protected]>
Co-authored-by: Arjun D. <[email protected]>
  • Loading branch information
3 people authored Aug 27, 2024
1 parent 5d8e60c commit 22df2f8
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 30 deletions.
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20240827212041426794.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Refactor text unit build at local search"
}
53 changes: 23 additions & 30 deletions graphrag/query/structured_search/local_search/mixed_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,42 +309,36 @@ def _build_text_unit_context(
context_name: str = "Sources",
) -> tuple[str, dict[str, pd.DataFrame]]:
"""Rank matching text units and add them to the context window until it hits the max_tokens limit."""
if len(selected_entities) == 0 or len(self.text_units) == 0:
if not selected_entities or not self.text_units:
return ("", {context_name.lower(): pd.DataFrame()})

selected_text_units = list[TextUnit]()
# for each matching text unit, rank first by the order of the entities that match it, then by the number of matching relationships
# that the text unit has with the matching entities
selected_text_units = []
text_unit_ids_set = set()

for index, entity in enumerate(selected_entities):
if entity.text_unit_ids:
for text_id in entity.text_unit_ids:
if (
text_id not in [unit.id for unit in selected_text_units]
and text_id in self.text_units
):
selected_unit = self.text_units[text_id]
num_relationships = count_relationships(
selected_unit, entity, self.relationships
)
if selected_unit.attributes is None:
selected_unit.attributes = {}
selected_unit.attributes["entity_order"] = index
selected_unit.attributes["num_relationships"] = (
num_relationships
)
selected_text_units.append(selected_unit)
for text_id in entity.text_unit_ids or []:
if text_id not in text_unit_ids_set and text_id in self.text_units:
text_unit_ids_set.add(text_id)
selected_unit = self.text_units[text_id]
num_relationships = count_relationships(
selected_unit, entity, self.relationships
)
if selected_unit.attributes is None:
selected_unit.attributes = {}
selected_unit.attributes["entity_order"] = index
selected_unit.attributes["num_relationships"] = num_relationships
selected_text_units.append(selected_unit)

# sort selected text units by ascending order of entity order and descending order of number of relationships
selected_text_units.sort(
key=lambda x: (
x.attributes["entity_order"], # type: ignore
-x.attributes["num_relationships"], # type: ignore
x.attributes["entity_order"],
-x.attributes["num_relationships"],
)
)

for unit in selected_text_units:
del unit.attributes["entity_order"] # type: ignore
del unit.attributes["num_relationships"] # type: ignore
unit.attributes.pop("entity_order", None)
unit.attributes.pop("num_relationships", None)

context_text, context_data = build_text_unit_context(
text_units=selected_text_units,
Expand All @@ -362,21 +356,20 @@ def _build_text_unit_context(
)
context_key = context_name.lower()
if context_key not in context_data:
candidate_context_data["in_context"] = False
context_data[context_key] = candidate_context_data
context_data[context_key]["in_context"] = False
else:
if (
"id" in candidate_context_data.columns
and "id" in context_data[context_key].columns
):
candidate_context_data["in_context"] = candidate_context_data[
"id"
].isin( # cspell:disable-line
context_data[context_key]["id"]
)
].isin(context_data[context_key]["id"])
context_data[context_key] = candidate_context_data
else:
context_data[context_key]["in_context"] = True

return (str(context_text), context_data)

def _build_local_context(
Expand Down

0 comments on commit 22df2f8

Please sign in to comment.