Collapse graph documents workflows (#1284)
* Copy base documents logic into final documents

* Delete create_base_documents

* Combine graph creation under create_base_entity_graph

* Delete collapsed workflows

* Migrate most graph internals to nx.Graph

* Fix None edge case

* Semver

* Remove comment typo

* Fix smoke tests
natoverse authored Oct 15, 2024
1 parent 137a5cd commit ce5b120
Showing 41 changed files with 447 additions and 1,099 deletions.
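For context on the "Migrate most graph internals to nx.Graph" step in the commit message: the collapsed create_base_entity_graph flow below passes graphs directly between the extract, merge, summarize, and cluster operations as in-memory networkx objects, and GraphML is only written when a snapshot is requested. A minimal, illustrative sketch of that pattern in plain networkx (not graphrag code; the helper name extract_graph_for_unit is hypothetical):

# Minimal sketch of the in-memory graph hand-off described in this commit.
# Illustrative only: extract_graph_for_unit is a hypothetical stand-in for
# entity extraction on a single text unit, not a graphrag API.
import networkx as nx


def extract_graph_for_unit(text_unit: str) -> nx.Graph:
    g = nx.Graph()
    g.add_node("ENTITY_A", description=f"mentioned in: {text_unit[:20]}")
    g.add_node("ENTITY_B", description=f"mentioned in: {text_unit[:20]}")
    g.add_edge("ENTITY_A", "ENTITY_B", weight=1)
    return g


def build_entity_graph(text_units: list[str], graphml_snapshot: bool = False) -> nx.Graph:
    # One graph per text unit, merged in memory instead of being written out
    # as an intermediate workflow table between pipeline steps.
    merged = nx.Graph()
    for unit in text_units:
        merged = nx.compose(merged, extract_graph_for_unit(unit))

    if graphml_snapshot:
        # Optional debugging artifact, analogous to the graphml_snapshot flag
        # used by create_base_entity_graph below.
        graphml_text = "\n".join(nx.generate_graphml(merged))
        print(graphml_text[:80], "...")

    return merged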
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20241014220522452574.json
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Collapse intermediate workflow outputs."
}
27 changes: 3 additions & 24 deletions graphrag/index/create_pipeline_config.py
@@ -49,9 +49,7 @@
PipelineWorkflowReference,
)
from graphrag.index.workflows.default_workflows import (
create_base_documents,
create_base_entity_graph,
create_base_extracted_entities,
create_base_text_units,
create_final_communities,
create_final_community_reports,
@@ -61,7 +59,6 @@
create_final_nodes,
create_final_relationships,
create_final_text_units,
create_summarized_entities,
)

log = logging.getLogger(__name__)
@@ -173,17 +170,12 @@ def _document_workflows(
)
return [
PipelineWorkflowReference(
name=create_base_documents,
name=create_final_documents,
config={
"document_attribute_columns": list(
{*(settings.input.document_attribute_columns)}
- builtin_document_attributes
)
},
),
PipelineWorkflowReference(
name=create_final_documents,
config={
),
"document_raw_content_embed": _get_embedding_settings(
settings.embeddings,
"document_raw_content",
@@ -267,10 +259,9 @@ def _graph_workflows(
)
return [
PipelineWorkflowReference(
name=create_base_extracted_entities,
name=create_base_entity_graph,
config={
"graphml_snapshot": settings.snapshots.graphml,
"raw_entity_snapshot": settings.snapshots.raw_entities,
"entity_extract": {
**settings.entity_extraction.parallelization.model_dump(),
"async_mode": settings.entity_extraction.async_mode,
@@ -279,25 +270,13 @@
),
"entity_types": settings.entity_extraction.entity_types,
},
},
),
PipelineWorkflowReference(
name=create_summarized_entities,
config={
"graphml_snapshot": settings.snapshots.graphml,
"summarize_descriptions": {
**settings.summarize_descriptions.parallelization.model_dump(),
"async_mode": settings.summarize_descriptions.async_mode,
"strategy": settings.summarize_descriptions.resolved_strategy(
settings.root_dir,
),
},
},
),
PipelineWorkflowReference(
name=create_base_entity_graph,
config={
"graphml_snapshot": settings.snapshots.graphml,
"embed_graph_enabled": settings.embed_graph.enabled,
"cluster_graph": {
"strategy": settings.cluster_graph.resolved_strategy()
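With create_base_extracted_entities and create_summarized_entities removed, _graph_workflows now emits a single PipelineWorkflowReference for create_base_entity_graph. A hedged sketch of the resulting config shape, as a plain dict with illustrative placeholder values (keys mirror those visible in this diff; the summarize_descriptions block is assumed to fold into the same reference, since the collapsed flow accepts a summarization strategy):

collapsed_graph_workflow_config = {
    "graphml_snapshot": False,            # settings.snapshots.graphml
    "raw_entity_snapshot": False,         # settings.snapshots.raw_entities
    "entity_extract": {
        "async_mode": "asyncio",          # settings.entity_extraction.async_mode
        "entity_types": ["organization", "person", "geo", "event"],
        # parallelization and strategy settings are collapsed in the diff above
    },
    "summarize_descriptions": {           # assumed to move under this reference
        "async_mode": "asyncio",          # settings.summarize_descriptions.async_mode
        # parallelization and strategy settings are collapsed in the diff above
    },
    "embed_graph_enabled": False,         # settings.embed_graph.enabled
    "cluster_graph": {
        "strategy": {"type": "leiden"},   # settings.cluster_graph.resolved_strategy()
    },
}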
64 changes: 0 additions & 64 deletions graphrag/index/flows/create_base_documents.py

This file was deleted.
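Per the "Copy base documents logic into final documents" note in the commit message, the logic that lived here now runs inside create_final_documents. A hedged sketch of the kind of join involved, in plain pandas with illustrative column names (not the exact graphrag schema): collect the text-unit ids that belong to each document and attach them to the documents table.

import pandas as pd


def attach_text_units(documents: pd.DataFrame, text_units: pd.DataFrame) -> pd.DataFrame:
    # text_units holds one row per chunk, each carrying the ids of its source documents
    exploded = text_units.explode("document_ids")[["id", "document_ids"]]
    grouped = (
        exploded.groupby("document_ids")["id"]
        .agg(list)
        .reset_index()
        .rename(columns={"document_ids": "id", "id": "text_unit_ids"})
    )
    return documents.merge(grouped, on="id", how="left")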

97 changes: 81 additions & 16 deletions graphrag/index/flows/create_base_entity_graph.py
@@ -7,42 +7,83 @@

import pandas as pd
from datashaper import (
AsyncType,
VerbCallbacks,
)

from graphrag.index.cache import PipelineCache
from graphrag.index.operations.cluster_graph import cluster_graph
from graphrag.index.operations.embed_graph import embed_graph
from graphrag.index.operations.extract_entities import extract_entities
from graphrag.index.operations.merge_graphs import merge_graphs
from graphrag.index.operations.snapshot import snapshot
from graphrag.index.operations.snapshot_graphml import snapshot_graphml
from graphrag.index.operations.snapshot_rows import snapshot_rows
from graphrag.index.operations.summarize_descriptions import (
summarize_descriptions,
)
from graphrag.index.storage import PipelineStorage


async def create_base_entity_graph(
entities: pd.DataFrame,
text_units: pd.DataFrame,
callbacks: VerbCallbacks,
cache: PipelineCache,
storage: PipelineStorage,
text_column: str,
id_column: str,
clustering_strategy: dict[str, Any],
embedding_strategy: dict[str, Any] | None,
extraction_strategy: dict[str, Any] | None = None,
extraction_num_threads: int = 4,
extraction_async_mode: AsyncType = AsyncType.AsyncIO,
entity_types: list[str] | None = None,
node_merge_config: dict[str, Any] | None = None,
edge_merge_config: dict[str, Any] | None = None,
summarization_strategy: dict[str, Any] | None = None,
summarization_num_threads: int = 4,
embedding_strategy: dict[str, Any] | None = None,
graphml_snapshot_enabled: bool = False,
raw_entity_snapshot_enabled: bool = False,
) -> pd.DataFrame:
"""All the steps to create the base entity graph."""
# this returns a graph for each text unit, to be merged later
entities, entity_graphs = await extract_entities(
text_units,
callbacks,
cache,
text_column=text_column,
id_column=id_column,
strategy=extraction_strategy,
async_mode=extraction_async_mode,
entity_types=entity_types,
to="entities",
num_threads=extraction_num_threads,
)

merged_graph = merge_graphs(
entity_graphs,
callbacks,
node_operations=node_merge_config,
edge_operations=edge_merge_config,
)

summarized = await summarize_descriptions(
merged_graph,
callbacks,
cache,
strategy=summarization_strategy,
num_threads=summarization_num_threads,
)

clustered = cluster_graph(
entities,
summarized,
callbacks,
column="entity_graph",
strategy=clustering_strategy,
to="clustered_graph",
level_to="level",
)

if graphml_snapshot_enabled:
await snapshot_rows(
clustered,
column="clustered_graph",
base_name="clustered_graph",
storage=storage,
formats=[{"format": "text", "extension": "graphml"}],
)

if embedding_strategy:
clustered["embeddings"] = await embed_graph(
clustered,
@@ -51,16 +92,40 @@ async def create_base_entity_graph(
strategy=embedding_strategy,
)

# take second snapshot after embedding
# todo: this could be skipped if embedding isn't performed, otherwise it is a copy of the regular graph?
if raw_entity_snapshot_enabled:
await snapshot(
entities,
name="raw_extracted_entities",
storage=storage,
formats=["json"],
)

if graphml_snapshot_enabled:
await snapshot_graphml(
merged_graph,
name="merged_graph",
storage=storage,
)
await snapshot_graphml(
summarized,
name="summarized_graph",
storage=storage,
)
await snapshot_rows(
clustered,
column="entity_graph",
base_name="embedded_graph",
column="clustered_graph",
base_name="clustered_graph",
storage=storage,
formats=[{"format": "text", "extension": "graphml"}],
)
if embedding_strategy:
await snapshot_rows(
clustered,
column="entity_graph",
base_name="embedded_graph",
storage=storage,
formats=[{"format": "text", "extension": "graphml"}],
)

final_columns = ["level", "clustered_graph"]
if embedding_strategy:
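A hedged sketch of invoking the collapsed flow with the signature shown in this diff. The strategy dicts, column names, and entity types are illustrative placeholders, and the callbacks, cache, and storage objects are assumed to come from the pipeline runtime:

import pandas as pd
from datashaper import VerbCallbacks

from graphrag.index.cache import PipelineCache
from graphrag.index.flows.create_base_entity_graph import create_base_entity_graph
from graphrag.index.storage import PipelineStorage


async def run_entity_graph(
    text_units: pd.DataFrame,
    callbacks: VerbCallbacks,
    cache: PipelineCache,
    storage: PipelineStorage,
) -> pd.DataFrame:
    # One call now covers extraction, graph merging, description summarization,
    # clustering, and the optional embedding and snapshot steps.
    return await create_base_entity_graph(
        text_units,
        callbacks,
        cache,
        storage,
        text_column="chunk",                       # placeholder column names
        id_column="chunk_id",
        clustering_strategy={"type": "leiden"},    # placeholder strategy
        entity_types=["organization", "person"],   # placeholder entity types
        summarization_strategy=None,               # None -> operation defaults
        embedding_strategy=None,                   # skip graph embedding
        graphml_snapshot_enabled=False,
        raw_entity_snapshot_enabled=False,
    )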
79 changes: 0 additions & 79 deletions graphrag/index/flows/create_base_extracted_entities.py

This file was deleted.

(Diffs for the remaining changed files are not shown here.)
