Reorganize flows (#1240)
* Extract base docs and entity graph

* Move extracted entities and text units

* Move communities and community reports

* Move covariates and final documents

* Move entities, nodes, relationships

* Move text_units and summarized entities

* Assert all snapshot null cases

* Remove disabled steps util

* Remove incorrect use of input "others"

* Convert text_embed_df to just return the embeddings, not update the df

* Convert snapshot functions to noops

* Semver

* Remove lingering covariates_enabled param

* Name consistency

* Syntax cleanup
natoverse authored Oct 2, 2024
1 parent 9070ea5 commit f5c5876
Showing 51 changed files with 1,340 additions and 824 deletions.
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20241001194728312270.json
@@ -0,0 +1,4 @@
{
    "type": "patch",
    "description": "Extract DataShaper-less flows."
}
4 changes: 4 additions & 0 deletions graphrag/index/flows/__init__.py
@@ -0,0 +1,4 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Core workflows without DataShaper wrappings."""
64 changes: 64 additions & 0 deletions graphrag/index/flows/create_base_documents.py
@@ -0,0 +1,64 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Transform base documents by joining them with their text_units and adding optional attributes."""

import pandas as pd


def create_base_documents(
    documents: pd.DataFrame,
    text_units: pd.DataFrame,
    document_attribute_columns: list[str] | None = None,
) -> pd.DataFrame:
    """Transform base documents by joining them with their text_units and adding optional attributes."""
    exploded = (
        text_units.explode("document_ids")
        .loc[:, ["id", "document_ids", "text"]]
        .rename(
            columns={
                "document_ids": "chunk_doc_id",
                "id": "chunk_id",
                "text": "chunk_text",
            }
        )
    )

    joined = exploded.merge(
        documents,
        left_on="chunk_doc_id",
        right_on="id",
        how="inner",
        copy=False,
    )

    docs_with_text_units = joined.groupby("id", sort=False).agg(
        text_units=("chunk_id", list)
    )

    rejoined = docs_with_text_units.merge(
        documents,
        on="id",
        how="right",
        copy=False,
    ).reset_index(drop=True)

    rejoined.rename(columns={"text": "raw_content"}, inplace=True)
    rejoined["id"] = rejoined["id"].astype(str)

    # Convert attribute columns to strings and collapse them into a JSON object
    if document_attribute_columns:
        # Convert all specified columns to string at once
        rejoined[document_attribute_columns] = rejoined[
            document_attribute_columns
        ].astype(str)

        # Collapse the document_attribute_columns into a single JSON object column
        rejoined["attributes"] = rejoined[document_attribute_columns].to_dict(
            orient="records"
        )

        # Drop the original attribute columns after collapsing them
        rejoined.drop(columns=document_attribute_columns, inplace=True)

    return rejoined
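
For reference, a minimal usage sketch of this extracted flow (not part of the diff; the toy values are invented, but the column names match what the function expects):

import pandas as pd

from graphrag.index.flows.create_base_documents import create_base_documents

documents = pd.DataFrame({
    "id": ["d1"],
    "text": ["full document text"],
    "source": ["web"],  # an optional attribute column
})
text_units = pd.DataFrame({
    "id": ["c1", "c2"],
    "document_ids": [["d1"], ["d1"]],
    "text": ["chunk one", "chunk two"],
})

base_docs = create_base_documents(
    documents, text_units, document_attribute_columns=["source"]
)
# One row per document: "text_units" lists its chunk ids, "raw_content" holds
# the original text, and "attributes" collapses the attribute columns to a dict.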
74 changes: 74 additions & 0 deletions graphrag/index/flows/create_base_entity_graph.py
@@ -0,0 +1,74 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""All the steps to create the base entity graph."""

from typing import Any, cast

import pandas as pd
from datashaper import (
    VerbCallbacks,
)

from graphrag.index.storage import PipelineStorage
from graphrag.index.verbs.graph.clustering.cluster_graph import cluster_graph_df
from graphrag.index.verbs.graph.embed.embed_graph import embed_graph_df
from graphrag.index.verbs.snapshot_rows import snapshot_rows_df


async def create_base_entity_graph(
    entities: pd.DataFrame,
    callbacks: VerbCallbacks,
    storage: PipelineStorage,
    clustering_config: dict[str, Any],
    embedding_config: dict[str, Any],
    graphml_snapshot_enabled: bool = False,
    embed_graph_enabled: bool = False,
) -> pd.DataFrame:
    """All the steps to create the base entity graph."""
    clustering_strategy = clustering_config.get("strategy", {"type": "leiden"})

    clustered = cluster_graph_df(
        entities,
        callbacks,
        column="entity_graph",
        strategy=clustering_strategy,
        to="clustered_graph",
        level_to="level",
    )

    if graphml_snapshot_enabled:
        await snapshot_rows_df(
            clustered,
            column="clustered_graph",
            base_name="clustered_graph",
            storage=storage,
            formats=[{"format": "text", "extension": "graphml"}],
        )

    embedding_strategy = embedding_config.get("strategy")
    if embed_graph_enabled and embedding_strategy:
        clustered = await embed_graph_df(
            clustered,
            callbacks,
            column="clustered_graph",
            strategy=embedding_strategy,
            to="embeddings",
        )

    # take a second snapshot after embedding
    # TODO: this could be skipped if embedding isn't performed; otherwise it is a copy of the regular graph
    if graphml_snapshot_enabled:
        await snapshot_rows_df(
            clustered,
            column="entity_graph",
            base_name="embedded_graph",
            storage=storage,
            formats=[{"format": "text", "extension": "graphml"}],
        )

    final_columns = ["level", "clustered_graph"]
    if embed_graph_enabled:
        final_columns.append("embeddings")

    return cast(pd.DataFrame, clustered[final_columns])
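
A hedged wiring sketch (not part of the diff): NoopVerbCallbacks and MemoryPipelineStorage are assumed to be the usual no-op callback and in-memory storage stand-ins for the callbacks and storage parameters.

from datashaper import NoopVerbCallbacks  # assumed no-op callbacks helper

from graphrag.index.flows.create_base_entity_graph import create_base_entity_graph
from graphrag.index.storage import MemoryPipelineStorage  # assumed in-memory storage


async def build_graph(entities):
    # `entities` is the output of create_base_extracted_entities: a DataFrame
    # with an "entity_graph" column holding graphml strings. The result has
    # "level" and "clustered_graph" columns.
    return await create_base_entity_graph(
        entities,
        NoopVerbCallbacks(),
        MemoryPipelineStorage(),
        clustering_config={"strategy": {"type": "leiden"}},
        embedding_config={},  # no "strategy" key, so graph embedding is skipped
    )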
79 changes: 79 additions & 0 deletions graphrag/index/flows/create_base_extracted_entities.py
@@ -0,0 +1,79 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""All the steps to extract and merge the base entities."""

from typing import Any

import pandas as pd
from datashaper import (
    AsyncType,
    VerbCallbacks,
)

from graphrag.index.cache import PipelineCache
from graphrag.index.storage import PipelineStorage
from graphrag.index.verbs.entities.extraction.entity_extract import entity_extract_df
from graphrag.index.verbs.graph.merge.merge_graphs import merge_graphs_df
from graphrag.index.verbs.snapshot import snapshot_df
from graphrag.index.verbs.snapshot_rows import snapshot_rows_df


async def create_base_extracted_entities(
    text_units: pd.DataFrame,
    cache: PipelineCache,
    callbacks: VerbCallbacks,
    storage: PipelineStorage,
    column: str,
    id_column: str,
    nodes: dict[str, Any],
    edges: dict[str, Any],
    strategy: dict[str, Any] | None,
    async_mode: AsyncType = AsyncType.AsyncIO,
    entity_types: list[str] | None = None,
    graphml_snapshot_enabled: bool = False,
    raw_entity_snapshot_enabled: bool = False,
    num_threads: int = 4,
) -> pd.DataFrame:
    """All the steps to extract and merge the base entities."""
    entity_graph = await entity_extract_df(
        text_units,
        cache,
        callbacks,
        column=column,
        id_column=id_column,
        strategy=strategy,
        async_mode=async_mode,
        entity_types=entity_types,
        to="entities",
        graph_to="entity_graph",
        num_threads=num_threads,
    )

    if raw_entity_snapshot_enabled:
        await snapshot_df(
            entity_graph,
            name="raw_extracted_entities",
            storage=storage,
            formats=["json"],
        )

    merged_graph = merge_graphs_df(
        entity_graph,
        callbacks,
        column="entity_graph",
        to="entity_graph",
        nodes=nodes,
        edges=edges,
    )

    if graphml_snapshot_enabled:
        await snapshot_rows_df(
            merged_graph,
            base_name="merged_graph",
            column="entity_graph",
            storage=storage,
            formats=[{"format": "text", "extension": "graphml"}],
        )

    return merged_graph
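
A hedged invocation sketch (not part of the diff). The strategy type name, the column names, and the empty node/edge merge configs are placeholders; real values come from the pipeline config.

from graphrag.index.flows.create_base_extracted_entities import (
    create_base_extracted_entities,
)


async def extract_entities(text_units, cache, callbacks, storage):
    # text_units: DataFrame of chunks; the other three are the pipeline's
    # PipelineCache, VerbCallbacks, and PipelineStorage instances.
    return await create_base_extracted_entities(
        text_units,
        cache,
        callbacks,
        storage,
        column="chunk",
        id_column="chunk_id",
        nodes={},  # node-merge operations for merge_graphs_df (placeholder)
        edges={},  # edge-merge operations for merge_graphs_df (placeholder)
        strategy={"type": "graph_intelligence"},  # assumed strategy type name
        entity_types=["organization", "person", "geo", "event"],
    )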
71 changes: 71 additions & 0 deletions graphrag/index/flows/create_base_text_units.py
@@ -0,0 +1,71 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""All the steps to transform base text_units."""

from typing import Any, cast

import pandas as pd
from datashaper import VerbCallbacks

from graphrag.index.verbs.genid import genid_df
from graphrag.index.verbs.overrides.aggregate import aggregate_df
from graphrag.index.verbs.text.chunk.text_chunk import chunk_df


def create_base_text_units(
    documents: pd.DataFrame,
    callbacks: VerbCallbacks,
    chunk_column_name: str,
    n_tokens_column_name: str,
    chunk_by_columns: list[str],
    strategy: dict[str, Any] | None = None,
) -> pd.DataFrame:
    """All the steps to transform base text_units."""
    sort = documents.sort_values(by=["id"], ascending=[True])

    sort["text_with_ids"] = list(
        zip(*[sort[col] for col in ["id", "text"]], strict=True)
    )

    aggregated = aggregate_df(
        sort,
        groupby=[*chunk_by_columns] if len(chunk_by_columns) > 0 else None,
        aggregations=[
            {
                "column": "text_with_ids",
                "operation": "array_agg",
                "to": "texts",
            }
        ],
    )

    chunked = chunk_df(
        aggregated,
        column="texts",
        to="chunks",
        callbacks=callbacks,
        strategy=strategy,
    )

    chunked = cast(pd.DataFrame, chunked[[*chunk_by_columns, "chunks"]])
    chunked = chunked.explode("chunks")
    chunked.rename(
        columns={
            "chunks": chunk_column_name,
        },
        inplace=True,
    )

    chunked = genid_df(
        chunked, to="chunk_id", method="md5_hash", hash=[chunk_column_name]
    )

    chunked[["document_ids", chunk_column_name, n_tokens_column_name]] = pd.DataFrame(
        chunked[chunk_column_name].tolist(), index=chunked.index
    )
    chunked["id"] = chunked["chunk_id"]

    return cast(
        pd.DataFrame, chunked[chunked[chunk_column_name].notna()].reset_index(drop=True)
    )
73 changes: 73 additions & 0 deletions graphrag/index/flows/create_final_communities.py
@@ -0,0 +1,73 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""All the steps to transform final communities."""

import pandas as pd
from datashaper import (
    VerbCallbacks,
)

from graphrag.index.verbs.graph.unpack import unpack_graph_df


def create_final_communities(
    entity_graph: pd.DataFrame,
    callbacks: VerbCallbacks,
) -> pd.DataFrame:
    """All the steps to transform final communities."""
    graph_nodes = unpack_graph_df(entity_graph, callbacks, "clustered_graph", "nodes")
    graph_edges = unpack_graph_df(entity_graph, callbacks, "clustered_graph", "edges")

    # Merge graph_nodes with graph_edges for both source and target matches
    source_clusters = graph_nodes.merge(
        graph_edges, left_on="label", right_on="source", how="inner"
    )

    target_clusters = graph_nodes.merge(
        graph_edges, left_on="label", right_on="target", how="inner"
    )

    # Concatenate the source and target clusters
    clusters = pd.concat([source_clusters, target_clusters], ignore_index=True)

    # Keep only rows where level_x == level_y
    combined_clusters = clusters[
        clusters["level_x"] == clusters["level_y"]
    ].reset_index(drop=True)

    cluster_relationships = (
        combined_clusters.groupby(["cluster", "level_x"], sort=False)
        .agg(
            relationship_ids=("id_y", "unique"), text_unit_ids=("source_id_x", "unique")
        )
        .reset_index()
    )

    all_clusters = (
        graph_nodes.groupby(["cluster", "level"], sort=False)
        .agg(id=("cluster", "first"))
        .reset_index()
    )

    joined = all_clusters.merge(
        cluster_relationships,
        left_on="id",
        right_on="cluster",
        how="inner",
    )

    filtered = joined[joined["level"] == joined["level_x"]].reset_index(drop=True)

    filtered["title"] = "Community " + filtered["id"].astype(str)

    return filtered.loc[
        :,
        [
            "id",
            "title",
            "level",
            "relationship_ids",
            "text_unit_ids",
        ],
    ]
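
And the end of the chain, as a short sketch (not part of the diff), restating this flow's contract:

from graphrag.index.flows.create_final_communities import create_final_communities


def build_communities(entity_graph, callbacks):
    # entity_graph: the clustered output of create_base_entity_graph;
    # callbacks: the pipeline's VerbCallbacks. Returns one row per
    # (community, level): an id, a "Community <id>" title, and the unique
    # relationship and text-unit ids gathered from both edge endpoints.
    return create_final_communities(entity_graph, callbacks)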