From 630679f8e3ceffb24b80c024c33eb40b263ebf7e Mon Sep 17 00:00:00 2001 From: Nathan Evans Date: Mon, 30 Sep 2024 17:17:44 -0700 Subject: [PATCH] Collapse create summarized entities (#1237) * Collapse entity summarize * Semver --- .../patch-20240930235021909622.json | 4 + .../summarize/description_summarize.py | 30 +++++- .../v1/create_summarized_entities.py | 25 ++--- .../index/workflows/v1/subflows/__init__.py | 2 + .../v1/subflows/create_summarized_entities.py | 61 ++++++++++++ .../verbs/test_create_summarized_entities.py | 92 +++++++++++++++++++ 6 files changed, 192 insertions(+), 22 deletions(-) create mode 100644 .semversioner/next-release/patch-20240930235021909622.json create mode 100644 graphrag/index/workflows/v1/subflows/create_summarized_entities.py create mode 100644 tests/verbs/test_create_summarized_entities.py diff --git a/.semversioner/next-release/patch-20240930235021909622.json b/.semversioner/next-release/patch-20240930235021909622.json new file mode 100644 index 0000000000..682aff3385 --- /dev/null +++ b/.semversioner/next-release/patch-20240930235021909622.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "Collapse entity summarize." +} diff --git a/graphrag/index/verbs/entities/summarize/description_summarize.py b/graphrag/index/verbs/entities/summarize/description_summarize.py index 5b7feb4184..40200b4cce 100644 --- a/graphrag/index/verbs/entities/summarize/description_summarize.py +++ b/graphrag/index/verbs/entities/summarize/description_summarize.py @@ -53,6 +53,29 @@ async def summarize_descriptions( strategy: dict[str, Any] | None = None, **kwargs, ) -> TableContainer: + """Summarize entity and relationship descriptions from an entity graph.""" + source = cast(pd.DataFrame, input.get_input()) + output = await summarize_descriptions_df( + source, + cache, + callbacks, + column=column, + to=to, + strategy=strategy, + **kwargs, + ) + return TableContainer(table=output) + + +async def summarize_descriptions_df( + input: pd.DataFrame, + cache: PipelineCache, + callbacks: VerbCallbacks, + column: str, + to: str, + strategy: dict[str, Any] | None = None, + **kwargs, +) -> pd.DataFrame: """ Summarize entity and relationship descriptions from an entity graph. @@ -111,7 +134,6 @@ async def summarize_descriptions( ``` """ log.debug("summarize_descriptions strategy=%s", strategy) - output = cast(pd.DataFrame, input.get_input()) strategy = strategy or {} strategy_exec = load_strategy( strategy.get("type", SummarizeStrategyType.graph_intelligence) @@ -181,7 +203,7 @@ async def do_summarize_descriptions( semaphore = asyncio.Semaphore(kwargs.get("num_threads", 4)) results = [ - await get_resolved_entities(row, semaphore) for row in output.itertuples() + await get_resolved_entities(row, semaphore) for row in input.itertuples() ] to_result = [] @@ -191,8 +213,8 @@ async def do_summarize_descriptions( to_result.append(result.graph) else: to_result.append(None) - output[to] = to_result - return TableContainer(table=output) + input[to] = to_result + return input def load_strategy(strategy_type: SummarizeStrategyType) -> SummarizationStrategy: diff --git a/graphrag/index/workflows/v1/create_summarized_entities.py b/graphrag/index/workflows/v1/create_summarized_entities.py index 8f8d7f0042..d4d95786ab 100644 --- a/graphrag/index/workflows/v1/create_summarized_entities.py +++ b/graphrag/index/workflows/v1/create_summarized_entities.py @@ -3,8 +3,6 @@ """A module containing build_steps method definition.""" -from datashaper import AsyncType - from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep workflow_name = "create_summarized_entities" @@ -20,28 +18,19 @@ def build_steps( * `workflow:create_base_text_units` """ summarize_descriptions_config = config.get("summarize_descriptions", {}) + strategy = summarize_descriptions_config.get("strategy", {}) + num_threads = strategy.get("num_threads", 4) + graphml_snapshot_enabled = config.get("graphml_snapshot", False) or False return [ { - "verb": "summarize_descriptions", + "verb": "create_summarized_entities", "args": { - **summarize_descriptions_config, - "column": "entity_graph", - "to": "entity_graph", - "async_mode": summarize_descriptions_config.get( - "async_mode", AsyncType.AsyncIO - ), + "strategy": strategy, + "num_threads": num_threads, + "graphml_snapshot_enabled": graphml_snapshot_enabled, }, "input": {"source": "workflow:create_base_extracted_entities"}, }, - { - "verb": "snapshot_rows", - "enabled": graphml_snapshot_enabled, - "args": { - "base_name": "summarized_graph", - "column": "entity_graph", - "formats": [{"format": "text", "extension": "graphml"}], - }, - }, ] diff --git a/graphrag/index/workflows/v1/subflows/__init__.py b/graphrag/index/workflows/v1/subflows/__init__.py index 8bc0e81809..95c4774d4b 100644 --- a/graphrag/index/workflows/v1/subflows/__init__.py +++ b/graphrag/index/workflows/v1/subflows/__init__.py @@ -16,6 +16,7 @@ create_final_relationships, ) from .create_final_text_units import create_final_text_units +from .create_summarized_entities import create_summarized_entities __all__ = [ "create_base_documents", @@ -29,4 +30,5 @@ "create_final_nodes", "create_final_relationships", "create_final_text_units", + "create_summarized_entities", ] diff --git a/graphrag/index/workflows/v1/subflows/create_summarized_entities.py b/graphrag/index/workflows/v1/subflows/create_summarized_entities.py new file mode 100644 index 0000000000..6ec042c418 --- /dev/null +++ b/graphrag/index/workflows/v1/subflows/create_summarized_entities.py @@ -0,0 +1,61 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""All the steps to transform final documents.""" + +from typing import Any, cast + +import pandas as pd +from datashaper import ( + Table, + VerbCallbacks, + VerbInput, + verb, +) +from datashaper.table_store.types import VerbResult, create_verb_result + +from graphrag.index.cache import PipelineCache +from graphrag.index.storage import PipelineStorage +from graphrag.index.verbs.entities.summarize.description_summarize import ( + summarize_descriptions_df, +) +from graphrag.index.verbs.snapshot_rows import snapshot_rows_df + + +@verb( + name="create_summarized_entities", + treats_input_tables_as_immutable=True, +) +async def create_summarized_entities( + input: VerbInput, + cache: PipelineCache, + callbacks: VerbCallbacks, + storage: PipelineStorage, + strategy: dict[str, Any] | None = None, + num_threads: int = 4, + graphml_snapshot_enabled: bool = False, + **_kwargs: dict, +) -> VerbResult: + """All the steps to transform final documents.""" + source = cast(pd.DataFrame, input.get_input()) + + summarized = await summarize_descriptions_df( + source, + cache, + callbacks, + column="entity_graph", + to="entity_graph", + strategy=strategy, + num_threads=num_threads, + ) + + if graphml_snapshot_enabled: + await snapshot_rows_df( + summarized, + column="entity_graph", + base_name="summarized_graph", + storage=storage, + formats=[{"format": "text", "extension": "graphml"}], + ) + + return create_verb_result(cast(Table, summarized)) diff --git a/tests/verbs/test_create_summarized_entities.py b/tests/verbs/test_create_summarized_entities.py new file mode 100644 index 0000000000..51fe57e2ee --- /dev/null +++ b/tests/verbs/test_create_summarized_entities.py @@ -0,0 +1,92 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +import networkx as nx + +from graphrag.index.storage.memory_pipeline_storage import MemoryPipelineStorage +from graphrag.index.workflows.v1.create_summarized_entities import ( + build_steps, + workflow_name, +) + +from .util import ( + get_config_for_workflow, + get_workflow_output, + load_expected, + load_input_tables, + remove_disabled_steps, +) + + +async def test_create_summarized_entities(): + input_tables = load_input_tables([ + "workflow:create_base_extracted_entities", + ]) + expected = load_expected(workflow_name) + + config = get_config_for_workflow(workflow_name) + + del config["summarize_descriptions"]["strategy"]["llm"] + + steps = remove_disabled_steps(build_steps(config)) + + actual = await get_workflow_output( + input_tables, + { + "steps": steps, + }, + ) + + # the serialization of the graph may differ so we can't assert the dataframes directly + assert actual.shape == expected.shape, "Graph dataframe shapes differ" + + # let's parse a sample of the raw graphml + actual_graphml_0 = actual["entity_graph"][:1][0] + actual_graph_0 = nx.parse_graphml(actual_graphml_0) + + expected_graphml_0 = expected["entity_graph"][:1][0] + expected_graph_0 = nx.parse_graphml(expected_graphml_0) + + assert ( + actual_graph_0.number_of_nodes() == expected_graph_0.number_of_nodes() + ), "Graphml node count differs" + assert ( + actual_graph_0.number_of_edges() == expected_graph_0.number_of_edges() + ), "Graphml edge count differs" + + # ensure the mock summary was injected to the nodes + nodes = list(actual_graph_0.nodes(data=True)) + assert ( + nodes[0][1]["description"] + == "This is a MOCK response for the LLM. It is summarized!" + ) + + +async def test_create_summarized_entities_with_snapshots(): + input_tables = load_input_tables([ + "workflow:create_base_extracted_entities", + ]) + expected = load_expected(workflow_name) + + storage = MemoryPipelineStorage() + + config = get_config_for_workflow(workflow_name) + + del config["summarize_descriptions"]["strategy"]["llm"] + config["graphml_snapshot"] = True + + steps = remove_disabled_steps(build_steps(config)) + + actual = await get_workflow_output( + input_tables, + { + "steps": steps, + }, + storage=storage, + ) + + assert actual.shape == expected.shape, "Graph dataframe shapes differ" + + assert storage.keys() == [ + "summarized_graph.graphml", + ], "Graph snapshot keys differ"