Skip to content

Commit

Permalink
Perf optimizations in map_query_to_entities() (#1276)
Browse files Browse the repository at this point in the history
* Address perf issue in map_query_to_entities()

* Add semver

---------

Co-authored-by: Matthieu Maitre <[email protected]>
Co-authored-by: Alonso Guevara <[email protected]>
  • Loading branch information
3 people authored Oct 21, 2024
1 parent 1f70d42 commit 6aae386
Show file tree
Hide file tree
Showing 9 changed files with 388 additions and 13 deletions.
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20241014040518441266.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Perf optimizations in map_query_to_entities()"
}
19 changes: 13 additions & 6 deletions graphrag/query/context_builder/entity_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from graphrag.model import Entity, Relationship
from graphrag.query.input.retrieval.entities import (
get_entity_by_id,
get_entity_by_key,
get_entity_by_name,
)
Expand Down Expand Up @@ -36,7 +37,7 @@ def map_query_to_entities(
query: str,
text_embedding_vectorstore: BaseVectorStore,
text_embedder: BaseTextEmbedding,
all_entities: list[Entity],
all_entities_dict: dict[str, Entity],
embedding_vectorstore_key: str = EntityVectorStoreKey.ID,
include_entity_names: list[str] | None = None,
exclude_entity_names: list[str] | None = None,
Expand All @@ -48,6 +49,7 @@ def map_query_to_entities(
include_entity_names = []
if exclude_entity_names is None:
exclude_entity_names = []
all_entities = list(all_entities_dict.values())
matched_entities = []
if query != "":
# get entities with highest semantic similarity to query
Expand All @@ -58,11 +60,16 @@ def map_query_to_entities(
k=k * oversample_scaler,
)
for result in search_results:
matched = get_entity_by_key(
entities=all_entities,
key=embedding_vectorstore_key,
value=result.document.id,
)
if embedding_vectorstore_key == EntityVectorStoreKey.ID and isinstance(
result.document.id, str
):
matched = get_entity_by_id(all_entities_dict, result.document.id)
else:
matched = get_entity_by_key(
entities=all_entities,
key=embedding_vectorstore_key,
value=result.document.id,
)
if matched:
matched_entities.append(matched)
else:
Expand Down
21 changes: 15 additions & 6 deletions graphrag/query/input/retrieval/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,26 @@
from graphrag.model import Entity


def get_entity_by_id(entities: dict[str, Entity], value: str) -> Entity | None:
"""Get entity by id."""
entity = entities.get(value)
if entity is None and is_valid_uuid(value):
entity = entities.get(value.replace("-", ""))
return entity


def get_entity_by_key(
entities: Iterable[Entity], key: str, value: str | int
) -> Entity | None:
"""Get entity by key."""
for entity in entities:
if isinstance(value, str) and is_valid_uuid(value):
if getattr(entity, key) == value or getattr(entity, key) == value.replace(
"-", ""
):
if isinstance(value, str) and is_valid_uuid(value):
value_no_dashes = value.replace("-", "")
for entity in entities:
entity_value = getattr(entity, key)
if entity_value in (value, value_no_dashes):
return entity
else:
else:
for entity in entities:
if getattr(entity, key) == value:
return entity
return None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def build_context(
query=query,
text_embedding_vectorstore=self.entity_text_embeddings,
text_embedder=self.text_embedder,
all_entities=list(self.entities.values()),
all_entities_dict=self.entities,
embedding_vectorstore_key=self.embedding_vectorstore_key,
include_entity_names=include_entity_names,
exclude_entity_names=exclude_entity_names,
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/query/context_builder/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
182 changes: 182 additions & 0 deletions tests/unit/query/context_builder/test_entity_extraction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

from typing import Any

from graphrag.model import Entity
from graphrag.model.types import TextEmbedder
from graphrag.query.context_builder.entity_extraction import (
EntityVectorStoreKey,
map_query_to_entities,
)
from graphrag.query.llm.base import BaseTextEmbedding
from graphrag.vector_stores import (
BaseVectorStore,
VectorStoreDocument,
VectorStoreSearchResult,
)


class MockBaseVectorStore(BaseVectorStore):
def __init__(self, documents: list[VectorStoreDocument]) -> None:
super().__init__("mock")
self.documents = documents

def connect(self, **kwargs: Any) -> None:
raise NotImplementedError

def load_documents(
self, documents: list[VectorStoreDocument], overwrite: bool = True
) -> None:
raise NotImplementedError

def similarity_search_by_vector(
self, query_embedding: list[float], k: int = 10, **kwargs: Any
) -> list[VectorStoreSearchResult]:
return [
VectorStoreSearchResult(document=document, score=1)
for document in self.documents[:k]
]

def similarity_search_by_text(
self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any
) -> list[VectorStoreSearchResult]:
return sorted(
[
VectorStoreSearchResult(
document=document, score=abs(len(text) - len(document.text or ""))
)
for document in self.documents
],
key=lambda x: x.score,
)[:k]

def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:
return [document for document in self.documents if document.id in include_ids]


class MockBaseTextEmbedding(BaseTextEmbedding):
def embed(self, text: str, **kwargs: Any) -> list[float]:
return [len(text)]

async def aembed(self, text: str, **kwargs: Any) -> list[float]:
return [len(text)]


def test_map_query_to_entities():
entities = [
Entity(
id="2da37c7a-50a8-44d4-aa2c-fd401e19976c",
short_id="sid1",
title="t1",
rank=2,
),
Entity(
id="c4f93564-4507-4ee4-b102-98add401a965",
short_id="sid2",
title="t22",
rank=4,
),
Entity(
id="7c6f2bc9-47c9-4453-93a3-d2e174a02cd9",
short_id="sid3",
title="t333",
rank=1,
),
Entity(
id="8fd6d72a-8e9d-4183-8a97-c38bcc971c83",
short_id="sid4",
title="t4444",
rank=3,
),
]

assert map_query_to_entities(
query="t22",
text_embedding_vectorstore=MockBaseVectorStore([
VectorStoreDocument(id=entity.id, text=entity.title, vector=None)
for entity in entities
]),
text_embedder=MockBaseTextEmbedding(),
all_entities_dict={entity.id: entity for entity in entities},
embedding_vectorstore_key=EntityVectorStoreKey.ID,
k=1,
oversample_scaler=1,
) == [
Entity(
id="c4f93564-4507-4ee4-b102-98add401a965",
short_id="sid2",
title="t22",
rank=4,
)
]

assert map_query_to_entities(
query="t22",
text_embedding_vectorstore=MockBaseVectorStore([
VectorStoreDocument(id=entity.title, text=entity.title, vector=None)
for entity in entities
]),
text_embedder=MockBaseTextEmbedding(),
all_entities_dict={entity.id: entity for entity in entities},
embedding_vectorstore_key=EntityVectorStoreKey.TITLE,
k=1,
oversample_scaler=1,
) == [
Entity(
id="c4f93564-4507-4ee4-b102-98add401a965",
short_id="sid2",
title="t22",
rank=4,
)
]

assert map_query_to_entities(
query="",
text_embedding_vectorstore=MockBaseVectorStore([
VectorStoreDocument(id=entity.id, text=entity.title, vector=None)
for entity in entities
]),
text_embedder=MockBaseTextEmbedding(),
all_entities_dict={entity.id: entity for entity in entities},
embedding_vectorstore_key=EntityVectorStoreKey.ID,
k=2,
) == [
Entity(
id="c4f93564-4507-4ee4-b102-98add401a965",
short_id="sid2",
title="t22",
rank=4,
),
Entity(
id="8fd6d72a-8e9d-4183-8a97-c38bcc971c83",
short_id="sid4",
title="t4444",
rank=3,
),
]

assert map_query_to_entities(
query="",
text_embedding_vectorstore=MockBaseVectorStore([
VectorStoreDocument(id=entity.id, text=entity.title, vector=None)
for entity in entities
]),
text_embedder=MockBaseTextEmbedding(),
all_entities_dict={entity.id: entity for entity in entities},
embedding_vectorstore_key=EntityVectorStoreKey.TITLE,
k=2,
) == [
Entity(
id="c4f93564-4507-4ee4-b102-98add401a965",
short_id="sid2",
title="t22",
rank=4,
),
Entity(
id="8fd6d72a-8e9d-4183-8a97-c38bcc971c83",
short_id="sid4",
title="t4444",
rank=3,
),
]
2 changes: 2 additions & 0 deletions tests/unit/query/input/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
2 changes: 2 additions & 0 deletions tests/unit/query/input/retrieval/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
Loading

0 comments on commit 6aae386

Please sign in to comment.