-
Notifications
You must be signed in to change notification settings - Fork 2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Collapse create summarized entities (#1237)
* Collapse entity summarize * Semver
- Loading branch information
Showing
6 changed files
with
192 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
{ | ||
"type": "patch", | ||
"description": "Collapse entity summarize." | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
61 changes: 61 additions & 0 deletions
61
graphrag/index/workflows/v1/subflows/create_summarized_entities.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# Copyright (c) 2024 Microsoft Corporation. | ||
# Licensed under the MIT License | ||
|
||
"""All the steps to transform final documents.""" | ||
|
||
from typing import Any, cast | ||
|
||
import pandas as pd | ||
from datashaper import ( | ||
Table, | ||
VerbCallbacks, | ||
VerbInput, | ||
verb, | ||
) | ||
from datashaper.table_store.types import VerbResult, create_verb_result | ||
|
||
from graphrag.index.cache import PipelineCache | ||
from graphrag.index.storage import PipelineStorage | ||
from graphrag.index.verbs.entities.summarize.description_summarize import ( | ||
summarize_descriptions_df, | ||
) | ||
from graphrag.index.verbs.snapshot_rows import snapshot_rows_df | ||
|
||
|
||
@verb( | ||
name="create_summarized_entities", | ||
treats_input_tables_as_immutable=True, | ||
) | ||
async def create_summarized_entities( | ||
input: VerbInput, | ||
cache: PipelineCache, | ||
callbacks: VerbCallbacks, | ||
storage: PipelineStorage, | ||
strategy: dict[str, Any] | None = None, | ||
num_threads: int = 4, | ||
graphml_snapshot_enabled: bool = False, | ||
**_kwargs: dict, | ||
) -> VerbResult: | ||
"""All the steps to transform final documents.""" | ||
source = cast(pd.DataFrame, input.get_input()) | ||
|
||
summarized = await summarize_descriptions_df( | ||
source, | ||
cache, | ||
callbacks, | ||
column="entity_graph", | ||
to="entity_graph", | ||
strategy=strategy, | ||
num_threads=num_threads, | ||
) | ||
|
||
if graphml_snapshot_enabled: | ||
await snapshot_rows_df( | ||
summarized, | ||
column="entity_graph", | ||
base_name="summarized_graph", | ||
storage=storage, | ||
formats=[{"format": "text", "extension": "graphml"}], | ||
) | ||
|
||
return create_verb_result(cast(Table, summarized)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
# Copyright (c) 2024 Microsoft Corporation. | ||
# Licensed under the MIT License | ||
|
||
import networkx as nx | ||
|
||
from graphrag.index.storage.memory_pipeline_storage import MemoryPipelineStorage | ||
from graphrag.index.workflows.v1.create_summarized_entities import ( | ||
build_steps, | ||
workflow_name, | ||
) | ||
|
||
from .util import ( | ||
get_config_for_workflow, | ||
get_workflow_output, | ||
load_expected, | ||
load_input_tables, | ||
remove_disabled_steps, | ||
) | ||
|
||
|
||
async def test_create_summarized_entities(): | ||
input_tables = load_input_tables([ | ||
"workflow:create_base_extracted_entities", | ||
]) | ||
expected = load_expected(workflow_name) | ||
|
||
config = get_config_for_workflow(workflow_name) | ||
|
||
del config["summarize_descriptions"]["strategy"]["llm"] | ||
|
||
steps = remove_disabled_steps(build_steps(config)) | ||
|
||
actual = await get_workflow_output( | ||
input_tables, | ||
{ | ||
"steps": steps, | ||
}, | ||
) | ||
|
||
# the serialization of the graph may differ so we can't assert the dataframes directly | ||
assert actual.shape == expected.shape, "Graph dataframe shapes differ" | ||
|
||
# let's parse a sample of the raw graphml | ||
actual_graphml_0 = actual["entity_graph"][:1][0] | ||
actual_graph_0 = nx.parse_graphml(actual_graphml_0) | ||
|
||
expected_graphml_0 = expected["entity_graph"][:1][0] | ||
expected_graph_0 = nx.parse_graphml(expected_graphml_0) | ||
|
||
assert ( | ||
actual_graph_0.number_of_nodes() == expected_graph_0.number_of_nodes() | ||
), "Graphml node count differs" | ||
assert ( | ||
actual_graph_0.number_of_edges() == expected_graph_0.number_of_edges() | ||
), "Graphml edge count differs" | ||
|
||
# ensure the mock summary was injected to the nodes | ||
nodes = list(actual_graph_0.nodes(data=True)) | ||
assert ( | ||
nodes[0][1]["description"] | ||
== "This is a MOCK response for the LLM. It is summarized!" | ||
) | ||
|
||
|
||
async def test_create_summarized_entities_with_snapshots(): | ||
input_tables = load_input_tables([ | ||
"workflow:create_base_extracted_entities", | ||
]) | ||
expected = load_expected(workflow_name) | ||
|
||
storage = MemoryPipelineStorage() | ||
|
||
config = get_config_for_workflow(workflow_name) | ||
|
||
del config["summarize_descriptions"]["strategy"]["llm"] | ||
config["graphml_snapshot"] = True | ||
|
||
steps = remove_disabled_steps(build_steps(config)) | ||
|
||
actual = await get_workflow_output( | ||
input_tables, | ||
{ | ||
"steps": steps, | ||
}, | ||
storage=storage, | ||
) | ||
|
||
assert actual.shape == expected.shape, "Graph dataframe shapes differ" | ||
|
||
assert storage.keys() == [ | ||
"summarized_graph.graphml", | ||
], "Graph snapshot keys differ" |