-
Notifications
You must be signed in to change notification settings - Fork 2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Extract base docs and entity graph * Move extracted entities and text units * Move communities and community reports * Move covariates and final documents * Move entities, nodes, relationships * Move text_units and summarized entities * Assert all snapshot null cases * Remove disabled steps util * Remove incorrect use of input "others" * Convert text_embed_df to just return the embeddings, not update the df * Convert snapshot functions to noops * Semver * Remove lingering covariates_enabled param * Name consistency * Syntax cleanup
- Loading branch information
Showing
51 changed files
with
1,340 additions
and
824 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
{ | ||
"type": "patch", | ||
"description": "Extract DataShaper-less flows." | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# Copyright (c) 2024 Microsoft Corporation. | ||
# Licensed under the MIT License | ||
|
||
"""Core workflows without DataShaper wrappings.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# Copyright (c) 2024 Microsoft Corporation. | ||
# Licensed under the MIT License | ||
|
||
"""Transform base documents by joining them with their text_units and adding optional attributes.""" | ||
|
||
import pandas as pd | ||
|
||
|
||
def create_base_documents( | ||
documents: pd.DataFrame, | ||
text_units: pd.DataFrame, | ||
document_attribute_columns: list[str] | None = None, | ||
) -> pd.DataFrame: | ||
"""Transform base documents by joining them with their text_units and adding optional attributes.""" | ||
exploded = ( | ||
text_units.explode("document_ids") | ||
.loc[:, ["id", "document_ids", "text"]] | ||
.rename( | ||
columns={ | ||
"document_ids": "chunk_doc_id", | ||
"id": "chunk_id", | ||
"text": "chunk_text", | ||
} | ||
) | ||
) | ||
|
||
joined = exploded.merge( | ||
documents, | ||
left_on="chunk_doc_id", | ||
right_on="id", | ||
how="inner", | ||
copy=False, | ||
) | ||
|
||
docs_with_text_units = joined.groupby("id", sort=False).agg( | ||
text_units=("chunk_id", list) | ||
) | ||
|
||
rejoined = docs_with_text_units.merge( | ||
documents, | ||
on="id", | ||
how="right", | ||
copy=False, | ||
).reset_index(drop=True) | ||
|
||
rejoined.rename(columns={"text": "raw_content"}, inplace=True) | ||
rejoined["id"] = rejoined["id"].astype(str) | ||
|
||
# Convert attribute columns to strings and collapse them into a JSON object | ||
if document_attribute_columns: | ||
# Convert all specified columns to string at once | ||
rejoined[document_attribute_columns] = rejoined[ | ||
document_attribute_columns | ||
].astype(str) | ||
|
||
# Collapse the document_attribute_columns into a single JSON object column | ||
rejoined["attributes"] = rejoined[document_attribute_columns].to_dict( | ||
orient="records" | ||
) | ||
|
||
# Drop the original attribute columns after collapsing them | ||
rejoined.drop(columns=document_attribute_columns, inplace=True) | ||
|
||
return rejoined |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# Copyright (c) 2024 Microsoft Corporation. | ||
# Licensed under the MIT License | ||
|
||
"""All the steps to create the base entity graph.""" | ||
|
||
from typing import Any, cast | ||
|
||
import pandas as pd | ||
from datashaper import ( | ||
VerbCallbacks, | ||
) | ||
|
||
from graphrag.index.storage import PipelineStorage | ||
from graphrag.index.verbs.graph.clustering.cluster_graph import cluster_graph_df | ||
from graphrag.index.verbs.graph.embed.embed_graph import embed_graph_df | ||
from graphrag.index.verbs.snapshot_rows import snapshot_rows_df | ||
|
||
|
||
async def create_base_entity_graph(
    entities: pd.DataFrame,
    callbacks: VerbCallbacks,
    storage: PipelineStorage,
    clustering_config: dict[str, Any],
    embedding_config: dict[str, Any],
    graphml_snapshot_enabled: bool = False,
    embed_graph_enabled: bool = False,
) -> pd.DataFrame:
    """All the steps to create the base entity graph."""
    # Hierarchically cluster the raw entity graph; this adds the
    # `clustered_graph` and `level` columns.
    graph = cluster_graph_df(
        entities,
        callbacks,
        column="entity_graph",
        strategy=clustering_config.get("strategy", {"type": "leiden"}),
        to="clustered_graph",
        level_to="level",
    )

    if graphml_snapshot_enabled:
        await snapshot_rows_df(
            graph,
            column="clustered_graph",
            base_name="clustered_graph",
            storage=storage,
            formats=[{"format": "text", "extension": "graphml"}],
        )

    # Optionally embed the clustered graph, adding an `embeddings` column.
    embedding_strategy = embedding_config.get("strategy")
    if embed_graph_enabled and embedding_strategy:
        graph = await embed_graph_df(
            graph,
            callbacks,
            column="clustered_graph",
            strategy=embedding_strategy,
            to="embeddings",
        )

    # Second snapshot, taken after the embedding step.
    # NOTE(review): this snapshots the original `entity_graph` column (not the
    # clustered/embedded one) and runs even when embedding was skipped — in
    # that case it is just a copy of the regular graph. Confirm intent.
    if graphml_snapshot_enabled:
        await snapshot_rows_df(
            graph,
            column="entity_graph",
            base_name="embedded_graph",
            storage=storage,
            formats=[{"format": "text", "extension": "graphml"}],
        )

    columns = ["level", "clustered_graph"]
    if embed_graph_enabled:
        columns.append("embeddings")

    return cast(pd.DataFrame, graph[columns])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
# Copyright (c) 2024 Microsoft Corporation. | ||
# Licensed under the MIT License | ||
|
||
"""All the steps to extract and format covariates.""" | ||
|
||
from typing import Any | ||
|
||
import pandas as pd | ||
from datashaper import ( | ||
AsyncType, | ||
VerbCallbacks, | ||
) | ||
|
||
from graphrag.index.cache import PipelineCache | ||
from graphrag.index.storage import PipelineStorage | ||
from graphrag.index.verbs.entities.extraction.entity_extract import entity_extract_df | ||
from graphrag.index.verbs.graph.merge.merge_graphs import merge_graphs_df | ||
from graphrag.index.verbs.snapshot import snapshot_df | ||
from graphrag.index.verbs.snapshot_rows import snapshot_rows_df | ||
|
||
|
||
async def create_base_extracted_entities(
    text_units: pd.DataFrame,
    cache: PipelineCache,
    callbacks: VerbCallbacks,
    storage: PipelineStorage,
    column: str,
    id_column: str,
    nodes: dict[str, Any],
    edges: dict[str, Any],
    strategy: dict[str, Any] | None,
    async_mode: AsyncType = AsyncType.AsyncIO,
    entity_types: list[str] | None = None,
    graphml_snapshot_enabled: bool = False,
    raw_entity_snapshot_enabled: bool = False,
    num_threads: int = 4,
) -> pd.DataFrame:
    """Extract entities from text units and merge them into a single entity graph.

    NOTE(review): the original docstring said "extract and format covariates",
    which appears to be copied from the covariates flow — this function
    extracts and merges entities.

    Args:
        text_units: Text units to run entity extraction over.
        cache: Pipeline cache used by the extraction step.
        callbacks: Progress/error reporting callbacks.
        storage: Pipeline storage used for optional snapshots.
        column: Name of the text column in `text_units` to extract from.
        id_column: Name of the id column in `text_units`.
        nodes: Node merge configuration passed to the graph-merge step.
        edges: Edge merge configuration passed to the graph-merge step.
        strategy: Entity-extraction strategy config; None uses the default.
        async_mode: Async execution mode for the extraction step.
        entity_types: Optional list of entity types to extract.
        graphml_snapshot_enabled: If True, snapshot the merged graph as graphml.
        raw_entity_snapshot_enabled: If True, snapshot raw extraction output as JSON.
        num_threads: Degree of parallelism for extraction.

    Returns:
        A DataFrame containing the merged `entity_graph` column.
    """
    # Extract entities from each text unit; writes per-unit results to the
    # `entities` column and a per-unit graph to `entity_graph`.
    entity_graph = await entity_extract_df(
        text_units,
        cache,
        callbacks,
        column=column,
        id_column=id_column,
        strategy=strategy,
        async_mode=async_mode,
        entity_types=entity_types,
        to="entities",
        graph_to="entity_graph",
        num_threads=num_threads,
    )

    # Optionally persist the raw (pre-merge) extraction output for debugging.
    if raw_entity_snapshot_enabled:
        await snapshot_df(
            entity_graph,
            name="raw_extracted_entities",
            storage=storage,
            formats=["json"],
        )

    # Merge all per-text-unit graphs into one graph, combining duplicate
    # nodes/edges according to the provided merge configuration.
    merged_graph = merge_graphs_df(
        entity_graph,
        callbacks,
        column="entity_graph",
        to="entity_graph",
        nodes=nodes,
        edges=edges,
    )

    # Optionally persist the merged graph as graphml for inspection.
    if graphml_snapshot_enabled:
        await snapshot_rows_df(
            merged_graph,
            base_name="merged_graph",
            column="entity_graph",
            storage=storage,
            formats=[{"format": "text", "extension": "graphml"}],
        )

    return merged_graph
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
# Copyright (c) 2024 Microsoft Corporation. | ||
# Licensed under the MIT License | ||
|
||
"""All the steps to transform base text_units.""" | ||
|
||
from typing import Any, cast | ||
|
||
import pandas as pd | ||
from datashaper import VerbCallbacks | ||
|
||
from graphrag.index.verbs.genid import genid_df | ||
from graphrag.index.verbs.overrides.aggregate import aggregate_df | ||
from graphrag.index.verbs.text.chunk.text_chunk import chunk_df | ||
|
||
|
||
def create_base_text_units(
    documents: pd.DataFrame,
    callbacks: VerbCallbacks,
    chunk_column_name: str,
    n_tokens_column_name: str,
    chunk_by_columns: list[str],
    strategy: dict[str, Any] | None = None,
) -> pd.DataFrame:
    """All the steps to transform base text_units."""
    # Deterministic ordering so chunking output is stable across runs.
    ordered = documents.sort_values(by=["id"], ascending=[True])

    # Pair each document id with its text so the id survives chunking.
    ordered["text_with_ids"] = list(
        zip(ordered["id"], ordered["text"], strict=True)
    )

    # Gather the (id, text) pairs into one array per group (or one overall
    # array when no grouping columns were requested).
    aggregated = aggregate_df(
        ordered,
        groupby=[*chunk_by_columns] if chunk_by_columns else None,
        aggregations=[
            {
                "column": "text_with_ids",
                "operation": "array_agg",
                "to": "texts",
            }
        ],
    )

    # Split the aggregated texts into chunks using the configured strategy.
    chunked = chunk_df(
        aggregated,
        column="texts",
        to="chunks",
        callbacks=callbacks,
        strategy=strategy,
    )

    # One row per chunk, carrying along the grouping columns.
    chunked = cast(pd.DataFrame, chunked[[*chunk_by_columns, "chunks"]])
    chunked = chunked.explode("chunks").rename(
        columns={"chunks": chunk_column_name}
    )

    # Stable content-derived id for each chunk.
    chunked = genid_df(
        chunked, to="chunk_id", method="md5_hash", hash=[chunk_column_name]
    )

    # Each chunk value is a (document_ids, text, n_tokens) tuple — unpack it
    # into its three columns.
    chunked[["document_ids", chunk_column_name, n_tokens_column_name]] = pd.DataFrame(
        chunked[chunk_column_name].tolist(), index=chunked.index
    )
    chunked["id"] = chunked["chunk_id"]

    # Drop empty chunks and renumber the rows.
    return cast(
        pd.DataFrame, chunked[chunked[chunk_column_name].notna()].reset_index(drop=True)
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
# Copyright (c) 2024 Microsoft Corporation. | ||
# Licensed under the MIT License | ||
|
||
"""All the steps to transform final communities.""" | ||
|
||
import pandas as pd | ||
from datashaper import ( | ||
VerbCallbacks, | ||
) | ||
|
||
from graphrag.index.verbs.graph.unpack import unpack_graph_df | ||
|
||
|
||
def create_final_communities(
    entity_graph: pd.DataFrame,
    callbacks: VerbCallbacks,
) -> pd.DataFrame:
    """All the steps to transform final communities."""
    nodes = unpack_graph_df(entity_graph, callbacks, "clustered_graph", "nodes")
    edges = unpack_graph_df(entity_graph, callbacks, "clustered_graph", "edges")

    # Pair every node with the edges that touch it, on either endpoint.
    by_source = nodes.merge(edges, left_on="label", right_on="source", how="inner")
    by_target = nodes.merge(edges, left_on="label", right_on="target", how="inner")
    touching = pd.concat([by_source, by_target], ignore_index=True)

    # Only keep node/edge pairs living at the same hierarchy level
    # (level_x comes from the node side, level_y from the edge side).
    same_level = touching[touching["level_x"] == touching["level_y"]].reset_index(
        drop=True
    )

    # Per (cluster, level): the distinct relationship and text-unit ids.
    relationships = (
        same_level.groupby(["cluster", "level_x"], sort=False)
        .agg(
            relationship_ids=("id_y", "unique"), text_unit_ids=("source_id_x", "unique")
        )
        .reset_index()
    )

    # One row per (cluster, level) pair; the cluster id doubles as the id.
    communities = (
        nodes.groupby(["cluster", "level"], sort=False)
        .agg(id=("cluster", "first"))
        .reset_index()
    )

    merged = communities.merge(
        relationships,
        left_on="id",
        right_on="cluster",
        how="inner",
    )

    # Keep only rows where the community level matches the relationship level.
    result = merged[merged["level"] == merged["level_x"]].reset_index(drop=True)

    result["title"] = "Community " + result["id"].astype(str)

    return result.loc[
        :,
        [
            "id",
            "title",
            "level",
            "relationship_ids",
            "text_unit_ids",
        ],
    ]
Oops, something went wrong.