Skip to content

Commit

Permalink
Collapse create summarized entities (#1237)
Browse files Browse the repository at this point in the history
* Collapse entity summarize

* Semver
  • Loading branch information
natoverse authored Oct 1, 2024
1 parent 5220bb7 commit 630679f
Show file tree
Hide file tree
Showing 6 changed files with 192 additions and 22 deletions.
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20240930235021909622.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Collapse entity summarize."
}
30 changes: 26 additions & 4 deletions graphrag/index/verbs/entities/summarize/description_summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,29 @@ async def summarize_descriptions(
strategy: dict[str, Any] | None = None,
**kwargs,
) -> TableContainer:
"""Summarize entity and relationship descriptions from an entity graph."""
source = cast(pd.DataFrame, input.get_input())
output = await summarize_descriptions_df(
source,
cache,
callbacks,
column=column,
to=to,
strategy=strategy,
**kwargs,
)
return TableContainer(table=output)


async def summarize_descriptions_df(
input: pd.DataFrame,
cache: PipelineCache,
callbacks: VerbCallbacks,
column: str,
to: str,
strategy: dict[str, Any] | None = None,
**kwargs,
) -> pd.DataFrame:
"""
Summarize entity and relationship descriptions from an entity graph.
Expand Down Expand Up @@ -111,7 +134,6 @@ async def summarize_descriptions(
```
"""
log.debug("summarize_descriptions strategy=%s", strategy)
output = cast(pd.DataFrame, input.get_input())
strategy = strategy or {}
strategy_exec = load_strategy(
strategy.get("type", SummarizeStrategyType.graph_intelligence)
Expand Down Expand Up @@ -181,7 +203,7 @@ async def do_summarize_descriptions(
semaphore = asyncio.Semaphore(kwargs.get("num_threads", 4))

results = [
await get_resolved_entities(row, semaphore) for row in output.itertuples()
await get_resolved_entities(row, semaphore) for row in input.itertuples()
]

to_result = []
Expand All @@ -191,8 +213,8 @@ async def do_summarize_descriptions(
to_result.append(result.graph)
else:
to_result.append(None)
output[to] = to_result
return TableContainer(table=output)
input[to] = to_result
return input


def load_strategy(strategy_type: SummarizeStrategyType) -> SummarizationStrategy:
Expand Down
25 changes: 7 additions & 18 deletions graphrag/index/workflows/v1/create_summarized_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@

"""A module containing build_steps method definition."""

from datashaper import AsyncType

from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep

workflow_name = "create_summarized_entities"
Expand All @@ -20,28 +18,19 @@ def build_steps(
* `workflow:create_base_text_units`
"""
summarize_descriptions_config = config.get("summarize_descriptions", {})
strategy = summarize_descriptions_config.get("strategy", {})
num_threads = strategy.get("num_threads", 4)

graphml_snapshot_enabled = config.get("graphml_snapshot", False) or False

return [
{
"verb": "summarize_descriptions",
"verb": "create_summarized_entities",
"args": {
**summarize_descriptions_config,
"column": "entity_graph",
"to": "entity_graph",
"async_mode": summarize_descriptions_config.get(
"async_mode", AsyncType.AsyncIO
),
"strategy": strategy,
"num_threads": num_threads,
"graphml_snapshot_enabled": graphml_snapshot_enabled,
},
"input": {"source": "workflow:create_base_extracted_entities"},
},
{
"verb": "snapshot_rows",
"enabled": graphml_snapshot_enabled,
"args": {
"base_name": "summarized_graph",
"column": "entity_graph",
"formats": [{"format": "text", "extension": "graphml"}],
},
},
]
2 changes: 2 additions & 0 deletions graphrag/index/workflows/v1/subflows/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
create_final_relationships,
)
from .create_final_text_units import create_final_text_units
from .create_summarized_entities import create_summarized_entities

__all__ = [
"create_base_documents",
Expand All @@ -29,4 +30,5 @@
"create_final_nodes",
"create_final_relationships",
"create_final_text_units",
"create_summarized_entities",
]
61 changes: 61 additions & 0 deletions graphrag/index/workflows/v1/subflows/create_summarized_entities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""All the steps to create the summarized entities."""

from typing import Any, cast

import pandas as pd
from datashaper import (
Table,
VerbCallbacks,
VerbInput,
verb,
)
from datashaper.table_store.types import VerbResult, create_verb_result

from graphrag.index.cache import PipelineCache
from graphrag.index.storage import PipelineStorage
from graphrag.index.verbs.entities.summarize.description_summarize import (
summarize_descriptions_df,
)
from graphrag.index.verbs.snapshot_rows import snapshot_rows_df


@verb(
    name="create_summarized_entities",
    treats_input_tables_as_immutable=True,
)
async def create_summarized_entities(
    input: VerbInput,
    cache: PipelineCache,
    callbacks: VerbCallbacks,
    storage: PipelineStorage,
    strategy: dict[str, Any] | None = None,
    num_threads: int = 4,
    graphml_snapshot_enabled: bool = False,
    **_kwargs: dict,
) -> VerbResult:
    """Summarize the descriptions in the entity graph and optionally snapshot it.

    The input table is handed to ``summarize_descriptions_df``, reading from and
    writing to the ``entity_graph`` column. When ``graphml_snapshot_enabled`` is
    truthy, the summarized table is additionally passed to ``snapshot_rows_df``
    with the base name ``summarized_graph`` and a text/graphml format entry.

    Args:
        input: Verb input whose table is fetched via ``get_input()``.
        cache: Forwarded unchanged to ``summarize_descriptions_df``.
        callbacks: Forwarded unchanged to ``summarize_descriptions_df``.
        storage: Forwarded to ``snapshot_rows_df`` when snapshots are enabled.
        strategy: Summarization strategy configuration, forwarded as-is.
        num_threads: Forwarded as the ``num_threads`` keyword argument.
        graphml_snapshot_enabled: Whether to run the snapshot step.
        **_kwargs: Any other step arguments; accepted and ignored.

    Returns:
        The summarized table wrapped with ``create_verb_result``.
    """
    # The graph is read from and written back to the same column.
    graph_column = "entity_graph"

    entity_table = cast(pd.DataFrame, input.get_input())

    # NOTE(review): this verb is declared with
    # treats_input_tables_as_immutable=True; confirm that
    # summarize_descriptions_df does not modify the frame it is given.
    summarized_table = await summarize_descriptions_df(
        entity_table,
        cache,
        callbacks,
        column=graph_column,
        to=graph_column,
        strategy=strategy,
        num_threads=num_threads,
    )

    if graphml_snapshot_enabled:
        snapshot_formats = [{"format": "text", "extension": "graphml"}]
        await snapshot_rows_df(
            summarized_table,
            column=graph_column,
            base_name="summarized_graph",
            storage=storage,
            formats=snapshot_formats,
        )

    return create_verb_result(cast(Table, summarized_table))
92 changes: 92 additions & 0 deletions tests/verbs/test_create_summarized_entities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

import networkx as nx

from graphrag.index.storage.memory_pipeline_storage import MemoryPipelineStorage
from graphrag.index.workflows.v1.create_summarized_entities import (
build_steps,
workflow_name,
)

from .util import (
get_config_for_workflow,
get_workflow_output,
load_expected,
load_input_tables,
remove_disabled_steps,
)


async def test_create_summarized_entities():
    """Run the workflow without snapshots and spot-check the first graph.

    The output table must have the same shape as the stored expectation, the
    first row's graphml must parse to a graph with matching node and edge
    counts, and the first node must carry the mock summary text.
    """
    tables = load_input_tables([
        "workflow:create_base_extracted_entities",
    ])
    expected = load_expected(workflow_name)

    workflow_config = get_config_for_workflow(workflow_name)

    # Remove the llm settings from the strategy before building the steps
    # (presumably so the mock response asserted below is used -- confirm).
    del workflow_config["summarize_descriptions"]["strategy"]["llm"]

    enabled_steps = remove_disabled_steps(build_steps(workflow_config))

    actual = await get_workflow_output(
        tables,
        {
            "steps": enabled_steps,
        },
    )

    # the serialization of the graph may differ so we can't assert the dataframes directly
    assert actual.shape == expected.shape, "Graph dataframe shapes differ"

    def first_graph(table):
        """Parse the graphml held in the first row of the entity_graph column."""
        return nx.parse_graphml(table["entity_graph"][:1][0])

    # let's parse a sample of the raw graphml
    actual_graph = first_graph(actual)
    expected_graph = first_graph(expected)

    node_counts_match = (
        actual_graph.number_of_nodes() == expected_graph.number_of_nodes()
    )
    assert node_counts_match, "Graphml node count differs"

    edge_counts_match = (
        actual_graph.number_of_edges() == expected_graph.number_of_edges()
    )
    assert edge_counts_match, "Graphml edge count differs"

    # ensure the mock summary was injected to the nodes
    _node_id, first_node_attrs = list(actual_graph.nodes(data=True))[0]
    assert (
        first_node_attrs["description"]
        == "This is a MOCK response for the LLM. It is summarized!"
    )


async def test_create_summarized_entities_with_snapshots():
    """Run the workflow with graphml snapshots enabled.

    Besides matching the expected table shape, the in-memory storage handed to
    the workflow must end up holding exactly one key,
    ``summarized_graph.graphml``.
    """
    tables = load_input_tables([
        "workflow:create_base_extracted_entities",
    ])
    expected = load_expected(workflow_name)

    # In-memory storage lets the test inspect which snapshot keys were written.
    memory_storage = MemoryPipelineStorage()

    workflow_config = get_config_for_workflow(workflow_name)

    del workflow_config["summarize_descriptions"]["strategy"]["llm"]
    workflow_config["graphml_snapshot"] = True

    enabled_steps = remove_disabled_steps(build_steps(workflow_config))

    actual = await get_workflow_output(
        tables,
        {
            "steps": enabled_steps,
        },
        storage=memory_storage,
    )

    assert actual.shape == expected.shape, "Graph dataframe shapes differ"

    expected_keys = [
        "summarized_graph.graphml",
    ]
    assert memory_storage.keys() == expected_keys, "Graph snapshot keys differ"

0 comments on commit 630679f

Please sign in to comment.