From f5c5876ddef0ce084c7d10ca0475af706f8ee0de Mon Sep 17 00:00:00 2001
From: Nathan Evans
Date: Wed, 2 Oct 2024 08:57:08 -0700
Subject: [PATCH] Reorganize flows (#1240)

* Extract base docs and entity graph
* Move extracted entities and text units
* Move communities and community reports
* Move covariates and final documents
* Move entities, nodes, relationships
* Move text_units and summarized entities
* Assert all snapshot null cases
* Remove disabled steps util
* Remove incorrect use of input "others"
* Convert text_embed_df to just return the embeddings, not update the df
* Convert snapshot functions to noops
* Semver
* Remove lingering covariates_enabled param
* Name consistency
* Syntax cleanup
---
 .../patch-20241001194728312270.json | 4 +
 graphrag/index/flows/__init__.py | 4 +
 graphrag/index/flows/create_base_documents.py | 64 +++++++
 .../index/flows/create_base_entity_graph.py | 74 ++++++++
 .../flows/create_base_extracted_entities.py | 79 +++++++++
 .../index/flows/create_base_text_units.py | 71 ++++++++
 .../index/flows/create_final_communities.py | 73 ++++++++
 .../flows/create_final_community_reports.py | 166 ++++++++++++++++++
 .../index/flows/create_final_covariates.py | 69 ++++++++
 .../index/flows/create_final_documents.py | 33 ++++
 graphrag/index/flows/create_final_entities.py | 80 +++++++++
 graphrag/index/flows/create_final_nodes.py | 72 ++++++++
 .../index/flows/create_final_relationships.py | 69 ++++++++
 .../index/flows/create_final_text_units.py | 119 +++++++++++++
 .../index/flows/create_summarized_entities.py | 50 ++++++
 graphrag/index/operations/__init__.py | 4 +
 .../extract_covariates/extract_covariates.py | 4 +-
 .../entities/extraction/entity_extract.py | 4 +-
 graphrag/index/verbs/snapshot.py | 8 +-
 graphrag/index/verbs/snapshot_rows.py | 8 +-
 graphrag/index/verbs/text/embed/text_embed.py | 37 ++--
 graphrag/index/workflows/load.py | 7 -
 .../workflows/v1/create_base_documents.py | 2 +-
 .../v1/create_final_community_reports.py | 13 +-
 .../workflows/v1/create_final_documents.py | 5 +-
 .../workflows/v1/create_final_entities.py | 8 +-
 .../v1/create_final_relationships.py | 5 +-
 .../workflows/v1/create_final_text_units.py | 23 ++-
 .../v1/subflows/create_base_documents.py | 58 +-----
 .../v1/subflows/create_base_entity_graph.py | 59 ++-----
 .../create_base_extracted_entities.py | 63 +++----
 .../v1/subflows/create_base_text_units.py | 63 ++-----
 .../v1/subflows/create_final_communities.py | 63 +------
 .../create_final_community_reports.py | 163 +++--------------
 .../v1/subflows/create_final_covariates.py | 40 +----
 .../v1/subflows/create_final_documents.py | 26 ++-
 .../v1/subflows/create_final_entities.py | 75 ++------
 .../v1/subflows/create_final_nodes.py | 62 ++-----
 .../v1/subflows/create_final_relationships.py | 78 ++------
 .../v1/subflows/create_final_text_units.py | 119 +++----------
 .../v1/subflows/create_summarized_entities.py | 32 ++--
 tests/verbs/test_create_base_entity_graph.py | 12 +-
 .../test_create_base_extracted_entities.py | 5 +
 .../test_create_final_community_reports.py | 5 +-
 tests/verbs/test_create_final_documents.py | 5 +-
 tests/verbs/test_create_final_entities.py | 9 +-
 tests/verbs/test_create_final_nodes.py | 41 ++++-
 .../verbs/test_create_final_relationships.py | 5 +-
 tests/verbs/test_create_final_text_units.py | 7 +-
 .../verbs/test_create_summarized_entities.py | 10 +-
 tests/verbs/util.py | 9 +-
 51 files changed, 1340 insertions(+), 824 deletions(-)
 create mode 100644 .semversioner/next-release/patch-20241001194728312270.json
 create mode 100644 graphrag/index/flows/__init__.py
 create mode 100644 graphrag/index/flows/create_base_documents.py
 create mode 100644 graphrag/index/flows/create_base_entity_graph.py
 create mode 100644 graphrag/index/flows/create_base_extracted_entities.py
 create mode 100644 graphrag/index/flows/create_base_text_units.py
 create mode 100644 graphrag/index/flows/create_final_communities.py
 create mode 100644 graphrag/index/flows/create_final_community_reports.py
 create mode 100644 graphrag/index/flows/create_final_covariates.py
 create mode 100644 graphrag/index/flows/create_final_documents.py
 create mode 100644 graphrag/index/flows/create_final_entities.py
 create mode 100644 graphrag/index/flows/create_final_nodes.py
 create mode 100644 graphrag/index/flows/create_final_relationships.py
 create mode 100644 graphrag/index/flows/create_final_text_units.py
 create mode 100644 graphrag/index/flows/create_summarized_entities.py
 create mode 100644 graphrag/index/operations/__init__.py

diff --git a/.semversioner/next-release/patch-20241001194728312270.json b/.semversioner/next-release/patch-20241001194728312270.json
new file mode 100644
index 0000000000..5501149eb4
--- /dev/null
+++ b/.semversioner/next-release/patch-20241001194728312270.json
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Extract DataShaper-less flows."
+}
diff --git a/graphrag/index/flows/__init__.py b/graphrag/index/flows/__init__.py
new file mode 100644
index 0000000000..b09c865054
--- /dev/null
+++ b/graphrag/index/flows/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Core workflows without DataShaper wrappings."""
diff --git a/graphrag/index/flows/create_base_documents.py b/graphrag/index/flows/create_base_documents.py
new file mode 100644
index 0000000000..3f1ba29e36
--- /dev/null
+++ b/graphrag/index/flows/create_base_documents.py
@@ -0,0 +1,64 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Transform base documents by joining them with their text_units and adding optional attributes."""
+
+import pandas as pd
+
+
+def create_base_documents(
+    documents: pd.DataFrame,
+    text_units: pd.DataFrame,
+    document_attribute_columns: list[str] | None = None,
+) -> pd.DataFrame:
+    """Transform base documents by joining them with their text_units and adding optional attributes."""
+    exploded = (
+        text_units.explode("document_ids")
+        .loc[:, ["id", "document_ids", "text"]]
+        .rename(
+            columns={
+                "document_ids": "chunk_doc_id",
+                "id": "chunk_id",
+                "text": "chunk_text",
+            }
+        )
+    )
+
+    joined = exploded.merge(
+        documents,
+        left_on="chunk_doc_id",
+        right_on="id",
+        how="inner",
+        copy=False,
+    )
+
+    docs_with_text_units = joined.groupby("id", sort=False).agg(
+        text_units=("chunk_id", list)
+    )
+
+    rejoined = docs_with_text_units.merge(
+        documents,
+        on="id",
+        how="right",
+        copy=False,
+    ).reset_index(drop=True)
+
+    rejoined.rename(columns={"text": "raw_content"}, inplace=True)
+    rejoined["id"] = rejoined["id"].astype(str)
+
+    # Convert attribute columns to strings and collapse them into a JSON object
+    if document_attribute_columns:
+        # Convert all specified columns to string at once
+        rejoined[document_attribute_columns] = rejoined[
+            document_attribute_columns
+        ].astype(str)
+
+        # Collapse the document_attribute_columns into a single JSON object column
+        rejoined["attributes"] = rejoined[document_attribute_columns].to_dict(
+            orient="records"
+        )
+
+        # Drop the original attribute columns after collapsing them
+        rejoined.drop(columns=document_attribute_columns, inplace=True)
+
+    return rejoined
diff --git a/graphrag/index/flows/create_base_entity_graph.py b/graphrag/index/flows/create_base_entity_graph.py
new file mode 100644
index 0000000000..052023191e
--- /dev/null
+++ b/graphrag/index/flows/create_base_entity_graph.py
@@ -0,0 +1,74 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License + +"""All the steps to create the base entity graph.""" + +from typing import Any, cast + +import pandas as pd +from datashaper import ( + VerbCallbacks, +) + +from graphrag.index.storage import PipelineStorage +from graphrag.index.verbs.graph.clustering.cluster_graph import cluster_graph_df +from graphrag.index.verbs.graph.embed.embed_graph import embed_graph_df +from graphrag.index.verbs.snapshot_rows import snapshot_rows_df + + +async def create_base_entity_graph( + entities: pd.DataFrame, + callbacks: VerbCallbacks, + storage: PipelineStorage, + clustering_config: dict[str, Any], + embedding_config: dict[str, Any], + graphml_snapshot_enabled: bool = False, + embed_graph_enabled: bool = False, +) -> pd.DataFrame: + """All the steps to create the base entity graph.""" + clustering_strategy = clustering_config.get("strategy", {"type": "leiden"}) + + clustered = cluster_graph_df( + entities, + callbacks, + column="entity_graph", + strategy=clustering_strategy, + to="clustered_graph", + level_to="level", + ) + + if graphml_snapshot_enabled: + await snapshot_rows_df( + clustered, + column="clustered_graph", + base_name="clustered_graph", + storage=storage, + formats=[{"format": "text", "extension": "graphml"}], + ) + + embedding_strategy = embedding_config.get("strategy") + if embed_graph_enabled and embedding_strategy: + clustered = await embed_graph_df( + clustered, + callbacks, + column="clustered_graph", + strategy=embedding_strategy, + to="embeddings", + ) + + # take second snapshot after embedding + # todo: this could be skipped if embedding isn't performed, other wise it is a copy of the regular graph? + if graphml_snapshot_enabled: + await snapshot_rows_df( + clustered, + column="entity_graph", + base_name="embedded_graph", + storage=storage, + formats=[{"format": "text", "extension": "graphml"}], + ) + + final_columns = ["level", "clustered_graph"] + if embed_graph_enabled: + final_columns.append("embeddings") + + return cast(pd.DataFrame, clustered[final_columns]) diff --git a/graphrag/index/flows/create_base_extracted_entities.py b/graphrag/index/flows/create_base_extracted_entities.py new file mode 100644 index 0000000000..b538f18fb9 --- /dev/null +++ b/graphrag/index/flows/create_base_extracted_entities.py @@ -0,0 +1,79 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""All the steps to extract and format covariates.""" + +from typing import Any + +import pandas as pd +from datashaper import ( + AsyncType, + VerbCallbacks, +) + +from graphrag.index.cache import PipelineCache +from graphrag.index.storage import PipelineStorage +from graphrag.index.verbs.entities.extraction.entity_extract import entity_extract_df +from graphrag.index.verbs.graph.merge.merge_graphs import merge_graphs_df +from graphrag.index.verbs.snapshot import snapshot_df +from graphrag.index.verbs.snapshot_rows import snapshot_rows_df + + +async def create_base_extracted_entities( + text_units: pd.DataFrame, + cache: PipelineCache, + callbacks: VerbCallbacks, + storage: PipelineStorage, + column: str, + id_column: str, + nodes: dict[str, Any], + edges: dict[str, Any], + strategy: dict[str, Any] | None, + async_mode: AsyncType = AsyncType.AsyncIO, + entity_types: list[str] | None = None, + graphml_snapshot_enabled: bool = False, + raw_entity_snapshot_enabled: bool = False, + num_threads: int = 4, +) -> pd.DataFrame: + """All the steps to extract and format covariates.""" + entity_graph = await entity_extract_df( + text_units, + cache, + callbacks, + column=column, + id_column=id_column, + strategy=strategy, + async_mode=async_mode, + entity_types=entity_types, + to="entities", + graph_to="entity_graph", + num_threads=num_threads, + ) + + if raw_entity_snapshot_enabled: + await snapshot_df( + entity_graph, + name="raw_extracted_entities", + storage=storage, + formats=["json"], + ) + + merged_graph = merge_graphs_df( + entity_graph, + callbacks, + column="entity_graph", + to="entity_graph", + nodes=nodes, + edges=edges, + ) + + if graphml_snapshot_enabled: + await snapshot_rows_df( + merged_graph, + base_name="merged_graph", + column="entity_graph", + storage=storage, + formats=[{"format": "text", "extension": "graphml"}], + ) + + return merged_graph diff --git a/graphrag/index/flows/create_base_text_units.py b/graphrag/index/flows/create_base_text_units.py new file mode 100644 index 0000000000..091b922105 --- /dev/null +++ b/graphrag/index/flows/create_base_text_units.py @@ -0,0 +1,71 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""All the steps to transform base text_units.""" + +from typing import Any, cast + +import pandas as pd +from datashaper import VerbCallbacks + +from graphrag.index.verbs.genid import genid_df +from graphrag.index.verbs.overrides.aggregate import aggregate_df +from graphrag.index.verbs.text.chunk.text_chunk import chunk_df + + +def create_base_text_units( + documents: pd.DataFrame, + callbacks: VerbCallbacks, + chunk_column_name: str, + n_tokens_column_name: str, + chunk_by_columns: list[str], + strategy: dict[str, Any] | None = None, +) -> pd.DataFrame: + """All the steps to transform base text_units.""" + sort = documents.sort_values(by=["id"], ascending=[True]) + + sort["text_with_ids"] = list( + zip(*[sort[col] for col in ["id", "text"]], strict=True) + ) + + aggregated = aggregate_df( + sort, + groupby=[*chunk_by_columns] if len(chunk_by_columns) > 0 else None, + aggregations=[ + { + "column": "text_with_ids", + "operation": "array_agg", + "to": "texts", + } + ], + ) + + chunked = chunk_df( + aggregated, + column="texts", + to="chunks", + callbacks=callbacks, + strategy=strategy, + ) + + chunked = cast(pd.DataFrame, chunked[[*chunk_by_columns, "chunks"]]) + chunked = chunked.explode("chunks") + chunked.rename( + columns={ + "chunks": chunk_column_name, + }, + inplace=True, + ) + + chunked = genid_df( + chunked, to="chunk_id", method="md5_hash", hash=[chunk_column_name] + ) + + chunked[["document_ids", chunk_column_name, n_tokens_column_name]] = pd.DataFrame( + chunked[chunk_column_name].tolist(), index=chunked.index + ) + chunked["id"] = chunked["chunk_id"] + + return cast( + pd.DataFrame, chunked[chunked[chunk_column_name].notna()].reset_index(drop=True) + ) diff --git a/graphrag/index/flows/create_final_communities.py b/graphrag/index/flows/create_final_communities.py new file mode 100644 index 0000000000..5165412133 --- /dev/null +++ b/graphrag/index/flows/create_final_communities.py @@ -0,0 +1,73 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""All the steps to transform final communities.""" + +import pandas as pd +from datashaper import ( + VerbCallbacks, +) + +from graphrag.index.verbs.graph.unpack import unpack_graph_df + + +def create_final_communities( + entity_graph: pd.DataFrame, + callbacks: VerbCallbacks, +) -> pd.DataFrame: + """All the steps to transform final communities.""" + graph_nodes = unpack_graph_df(entity_graph, callbacks, "clustered_graph", "nodes") + graph_edges = unpack_graph_df(entity_graph, callbacks, "clustered_graph", "edges") + + # Merge graph_nodes with graph_edges for both source and target matches + source_clusters = graph_nodes.merge( + graph_edges, left_on="label", right_on="source", how="inner" + ) + + target_clusters = graph_nodes.merge( + graph_edges, left_on="label", right_on="target", how="inner" + ) + + # Concatenate the source and target clusters + clusters = pd.concat([source_clusters, target_clusters], ignore_index=True) + + # Keep only rows where level_x == level_y + combined_clusters = clusters[ + clusters["level_x"] == clusters["level_y"] + ].reset_index(drop=True) + + cluster_relationships = ( + combined_clusters.groupby(["cluster", "level_x"], sort=False) + .agg( + relationship_ids=("id_y", "unique"), text_unit_ids=("source_id_x", "unique") + ) + .reset_index() + ) + + all_clusters = ( + graph_nodes.groupby(["cluster", "level"], sort=False) + .agg(id=("cluster", "first")) + .reset_index() + ) + + joined = all_clusters.merge( + cluster_relationships, + left_on="id", + right_on="cluster", + how="inner", + ) + + filtered = joined[joined["level"] == joined["level_x"]].reset_index(drop=True) + + filtered["title"] = "Community " + filtered["id"].astype(str) + + return filtered.loc[ + :, + [ + "id", + "title", + "level", + "relationship_ids", + "text_unit_ids", + ], + ] diff --git a/graphrag/index/flows/create_final_community_reports.py b/graphrag/index/flows/create_final_community_reports.py new file mode 100644 index 0000000000..911ae60ac8 --- /dev/null +++ b/graphrag/index/flows/create_final_community_reports.py @@ -0,0 +1,166 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""All the steps to transform community reports.""" + +from uuid import uuid4 + +import pandas as pd +from datashaper import ( + AsyncType, + VerbCallbacks, +) + +from graphrag.index.cache import PipelineCache +from graphrag.index.graph.extractors.community_reports.schemas import ( + CLAIM_DESCRIPTION, + CLAIM_DETAILS, + CLAIM_ID, + CLAIM_STATUS, + CLAIM_SUBJECT, + CLAIM_TYPE, + EDGE_DEGREE, + EDGE_DESCRIPTION, + EDGE_DETAILS, + EDGE_ID, + EDGE_SOURCE, + EDGE_TARGET, + NODE_DEGREE, + NODE_DESCRIPTION, + NODE_DETAILS, + NODE_ID, + NODE_NAME, +) +from graphrag.index.verbs.graph.report.create_community_reports import ( + create_community_reports_df, +) +from graphrag.index.verbs.graph.report.prepare_community_reports import ( + prepare_community_reports_df, +) +from graphrag.index.verbs.graph.report.restore_community_hierarchy import ( + restore_community_hierarchy_df, +) +from graphrag.index.verbs.text.embed.text_embed import text_embed_df + + +async def create_final_community_reports( + nodes_input: pd.DataFrame, + edges_input: pd.DataFrame, + claims_input: pd.DataFrame | None, + callbacks: VerbCallbacks, + cache: PipelineCache, + strategy: dict, + async_mode: AsyncType = AsyncType.AsyncIO, + num_threads: int = 4, + full_content_text_embed: dict | None = None, + summary_text_embed: dict | None = None, + title_text_embed: dict | None = None, +) -> pd.DataFrame: + """All the steps to transform community reports.""" + nodes = _prep_nodes(nodes_input) + edges = _prep_edges(edges_input) + + claims = None + if claims_input is not None: + claims = _prep_claims(claims_input) + + community_hierarchy = restore_community_hierarchy_df(nodes) + + local_contexts = prepare_community_reports_df( + nodes, edges, claims, callbacks, strategy.get("max_input_length", 16_000) + ) + + community_reports = await create_community_reports_df( + local_contexts, + nodes, + community_hierarchy, + callbacks, + cache, + strategy, + async_mode=async_mode, + num_threads=num_threads, + ) + + community_reports["id"] = community_reports["community"].apply( + lambda _x: str(uuid4()) + ) + + # Embed full content if not skipped + if full_content_text_embed: + community_reports["full_content_embedding"] = await text_embed_df( + community_reports, + callbacks, + cache, + column="full_content", + strategy=full_content_text_embed["strategy"], + embedding_name="community_report_full_content", + ) + + # Embed summary if not skipped + if summary_text_embed: + community_reports["summary_embedding"] = await text_embed_df( + community_reports, + callbacks, + cache, + column="summary", + strategy=summary_text_embed["strategy"], + embedding_name="community_report_summary", + ) + + # Embed title if not skipped + if title_text_embed: + community_reports["title_embedding"] = await text_embed_df( + community_reports, + callbacks, + cache, + column="title", + strategy=title_text_embed["strategy"], + embedding_name="community_report_title", + ) + + return community_reports + + +def _prep_nodes(input: pd.DataFrame) -> pd.DataFrame: + input = input.fillna(value={NODE_DESCRIPTION: "No Description"}) + # merge values of four columns into a map column + input[NODE_DETAILS] = input.apply( + lambda x: { + NODE_ID: x[NODE_ID], + NODE_NAME: x[NODE_NAME], + NODE_DESCRIPTION: x[NODE_DESCRIPTION], + NODE_DEGREE: x[NODE_DEGREE], + }, + axis=1, + ) + return input + + +def _prep_edges(input: pd.DataFrame) -> pd.DataFrame: + input = input.fillna(value={NODE_DESCRIPTION: "No Description"}) + input[EDGE_DETAILS] = input.apply( + 
lambda x: { + EDGE_ID: x[EDGE_ID], + EDGE_SOURCE: x[EDGE_SOURCE], + EDGE_TARGET: x[EDGE_TARGET], + EDGE_DESCRIPTION: x[EDGE_DESCRIPTION], + EDGE_DEGREE: x[EDGE_DEGREE], + }, + axis=1, + ) + return input + + +def _prep_claims(input: pd.DataFrame) -> pd.DataFrame: + input = input.fillna(value={NODE_DESCRIPTION: "No Description"}) + input[CLAIM_DETAILS] = input.apply( + lambda x: { + CLAIM_ID: x[CLAIM_ID], + CLAIM_SUBJECT: x[CLAIM_SUBJECT], + CLAIM_TYPE: x[CLAIM_TYPE], + CLAIM_STATUS: x[CLAIM_STATUS], + CLAIM_DESCRIPTION: x[CLAIM_DESCRIPTION], + }, + axis=1, + ) + return input diff --git a/graphrag/index/flows/create_final_covariates.py b/graphrag/index/flows/create_final_covariates.py new file mode 100644 index 0000000000..98b352e48f --- /dev/null +++ b/graphrag/index/flows/create_final_covariates.py @@ -0,0 +1,69 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""All the steps to extract and format covariates.""" + +from typing import Any, cast +from uuid import uuid4 + +import pandas as pd +from datashaper import ( + AsyncType, + VerbCallbacks, +) + +from graphrag.index.cache import PipelineCache +from graphrag.index.verbs.covariates.extract_covariates.extract_covariates import ( + extract_covariates_df, +) + + +async def create_final_covariates( + text_units: pd.DataFrame, + cache: PipelineCache, + callbacks: VerbCallbacks, + column: str, + covariate_type: str, + strategy: dict[str, Any] | None, + async_mode: AsyncType = AsyncType.AsyncIO, + entity_types: list[str] | None = None, + num_threads: int = 4, +) -> pd.DataFrame: + """All the steps to extract and format covariates.""" + covariates = await extract_covariates_df( + text_units, + cache, + callbacks, + column, + covariate_type, + strategy, + async_mode, + entity_types, + num_threads, + ) + + covariates["id"] = covariates["covariate_type"].apply(lambda _x: str(uuid4())) + covariates["human_readable_id"] = (covariates.index + 1).astype(str) + covariates.rename(columns={"chunk_id": "text_unit_id"}, inplace=True) + + return cast( + pd.DataFrame, + covariates[ + [ + "id", + "human_readable_id", + "covariate_type", + "type", + "description", + "subject_id", + "object_id", + "status", + "start_date", + "end_date", + "source_text", + "text_unit_id", + "document_ids", + "n_tokens", + ] + ], + ) diff --git a/graphrag/index/flows/create_final_documents.py b/graphrag/index/flows/create_final_documents.py new file mode 100644 index 0000000000..a4456dddae --- /dev/null +++ b/graphrag/index/flows/create_final_documents.py @@ -0,0 +1,33 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""All the steps to transform final documents.""" + +import pandas as pd +from datashaper import ( + VerbCallbacks, +) + +from graphrag.index.cache import PipelineCache +from graphrag.index.verbs.text.embed.text_embed import text_embed_df + + +async def create_final_documents( + documents: pd.DataFrame, + callbacks: VerbCallbacks, + cache: PipelineCache, + text_embed: dict | None = None, +) -> pd.DataFrame: + """All the steps to transform final documents.""" + documents.rename(columns={"text_units": "text_unit_ids"}, inplace=True) + + if text_embed: + documents["raw_content_embedding"] = await text_embed_df( + documents, + callbacks, + cache, + column="raw_content", + strategy=text_embed["strategy"], + ) + + return documents diff --git a/graphrag/index/flows/create_final_entities.py b/graphrag/index/flows/create_final_entities.py new file mode 100644 index 0000000000..a9348e33da --- /dev/null +++ b/graphrag/index/flows/create_final_entities.py @@ -0,0 +1,80 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""All the steps to transform final entities.""" + +import pandas as pd +from datashaper import ( + VerbCallbacks, +) + +from graphrag.index.cache import PipelineCache +from graphrag.index.verbs.graph.unpack import unpack_graph_df +from graphrag.index.verbs.text.embed.text_embed import text_embed_df +from graphrag.index.verbs.text.split import text_split_df + + +async def create_final_entities( + entity_graph: pd.DataFrame, + callbacks: VerbCallbacks, + cache: PipelineCache, + name_text_embed: dict, + description_text_embed: dict, +) -> pd.DataFrame: + """All the steps to transform final entities.""" + # Process nodes + nodes = ( + unpack_graph_df(entity_graph, callbacks, "clustered_graph", "nodes") + .rename(columns={"label": "name"}) + .loc[ + :, + [ + "id", + "name", + "type", + "description", + "human_readable_id", + "graph_embedding", + "source_id", + ], + ] + .drop_duplicates(subset="id") + ) + + nodes = nodes.loc[nodes["name"].notna()] + + # Split 'source_id' column into 'text_unit_ids' + nodes = text_split_df( + nodes, column="source_id", separator=",", to="text_unit_ids" + ).drop(columns=["source_id"]) + + # Embed name if not skipped + if name_text_embed: + nodes["name_embedding"] = await text_embed_df( + nodes, + callbacks, + cache, + column="name", + strategy=name_text_embed["strategy"], + embedding_name="entity_name", + ) + + # Embed description if not skipped + if description_text_embed: + # Concatenate 'name' and 'description' and embed + nodes["name_description"] = nodes["name"] + ":" + nodes["description"] + nodes["description_embedding"] = await text_embed_df( + nodes, + callbacks, + cache, + column="name_description", + strategy=description_text_embed["strategy"], + embedding_name="entity_name_description", + ) + + # Drop rows with NaN 'description_embedding' if not using vector store + if not description_text_embed.get("strategy", {}).get("vector_store"): + nodes = nodes.loc[nodes["description_embedding"].notna()] + nodes.drop(columns="name_description", inplace=True) + + return nodes diff --git a/graphrag/index/flows/create_final_nodes.py b/graphrag/index/flows/create_final_nodes.py new file mode 100644 index 0000000000..4597a6f074 --- /dev/null +++ b/graphrag/index/flows/create_final_nodes.py @@ -0,0 +1,72 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""All the steps to transform final nodes.""" + +from typing import Any, cast + +import pandas as pd +from datashaper import ( + VerbCallbacks, +) + +from graphrag.index.storage import PipelineStorage +from graphrag.index.verbs.graph.layout.layout_graph import layout_graph_df +from graphrag.index.verbs.graph.unpack import unpack_graph_df +from graphrag.index.verbs.snapshot import snapshot_df + + +async def create_final_nodes( + entity_graph: pd.DataFrame, + callbacks: VerbCallbacks, + storage: PipelineStorage, + strategy: dict[str, Any], + level_for_node_positions: int, + snapshot_top_level_nodes: bool = False, +) -> pd.DataFrame: + """All the steps to transform final nodes.""" + laid_out_entity_graph = cast( + pd.DataFrame, + layout_graph_df( + entity_graph, + callbacks, + strategy, + embeddings_column="embeddings", + graph_column="clustered_graph", + to="node_positions", + graph_to="positioned_graph", + ), + ) + + nodes = cast( + pd.DataFrame, + unpack_graph_df( + laid_out_entity_graph, callbacks, column="positioned_graph", type="nodes" + ), + ) + + nodes_without_positions = nodes.drop(columns=["x", "y"]) + + nodes = nodes[nodes["level"] == level_for_node_positions].reset_index(drop=True) + nodes = cast(pd.DataFrame, nodes[["id", "x", "y"]]) + + if snapshot_top_level_nodes: + await snapshot_df( + nodes, + name="top_level_nodes", + storage=storage, + formats=["json"], + ) + + nodes.rename(columns={"id": "top_level_node_id"}, inplace=True) + nodes["top_level_node_id"] = nodes["top_level_node_id"].astype(str) + + joined = nodes_without_positions.merge( + nodes, + left_on="id", + right_on="top_level_node_id", + how="inner", + ) + joined.rename(columns={"label": "title", "cluster": "community"}, inplace=True) + + return joined diff --git a/graphrag/index/flows/create_final_relationships.py b/graphrag/index/flows/create_final_relationships.py new file mode 100644 index 0000000000..a94ec6f3dd --- /dev/null +++ b/graphrag/index/flows/create_final_relationships.py @@ -0,0 +1,69 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""All the steps to transform final relationships.""" + +from typing import cast + +import pandas as pd +from datashaper import ( + VerbCallbacks, +) + +from graphrag.index.cache import PipelineCache +from graphrag.index.verbs.graph.compute_edge_combined_degree import ( + compute_edge_combined_degree_df, +) +from graphrag.index.verbs.graph.unpack import unpack_graph_df +from graphrag.index.verbs.text.embed.text_embed import text_embed_df + + +async def create_final_relationships( + entity_graph: pd.DataFrame, + nodes: pd.DataFrame, + callbacks: VerbCallbacks, + cache: PipelineCache, + text_embed: dict | None = None, +) -> pd.DataFrame: + """All the steps to transform final relationships.""" + graph_edges = unpack_graph_df(entity_graph, callbacks, "clustered_graph", "edges") + + graph_edges.rename(columns={"source_id": "text_unit_ids"}, inplace=True) + + filtered = cast( + pd.DataFrame, graph_edges[graph_edges["level"] == 0].reset_index(drop=True) + ) + + if text_embed: + filtered["description_embedding"] = await text_embed_df( + filtered, + callbacks, + cache, + column="description", + strategy=text_embed["strategy"], + embedding_name="relationship_description", + ) + + pruned_edges = filtered.drop(columns=["level"]) + + filtered_nodes = nodes[nodes["level"] == 0].reset_index(drop=True) + filtered_nodes = cast(pd.DataFrame, filtered_nodes[["title", "degree"]]) + + edge_combined_degree = compute_edge_combined_degree_df( + pruned_edges, + filtered_nodes, + to="rank", + node_name_column="title", + node_degree_column="degree", + edge_source_column="source", + edge_target_column="target", + ) + + edge_combined_degree["human_readable_id"] = edge_combined_degree[ + "human_readable_id" + ].astype(str) + edge_combined_degree["text_unit_ids"] = edge_combined_degree[ + "text_unit_ids" + ].str.split(",") + + return edge_combined_degree diff --git a/graphrag/index/flows/create_final_text_units.py b/graphrag/index/flows/create_final_text_units.py new file mode 100644 index 0000000000..93e5340a94 --- /dev/null +++ b/graphrag/index/flows/create_final_text_units.py @@ -0,0 +1,119 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""All the steps to transform the text units.""" + +from typing import cast + +import pandas as pd +from datashaper import ( + VerbCallbacks, +) + +from graphrag.index.cache import PipelineCache +from graphrag.index.verbs.text.embed.text_embed import text_embed_df + + +async def create_final_text_units( + text_units: pd.DataFrame, + final_entities: pd.DataFrame, + final_relationships: pd.DataFrame, + final_covariates: pd.DataFrame | None, + callbacks: VerbCallbacks, + cache: PipelineCache, + text_embed: dict | None = None, +) -> pd.DataFrame: + """All the steps to transform the text units.""" + selected = text_units.loc[:, ["id", "chunk", "document_ids", "n_tokens"]].rename( + columns={"chunk": "text"} + ) + + entity_join = _entities(final_entities) + relationship_join = _relationships(final_relationships) + + entity_joined = _join(selected, entity_join) + relationship_joined = _join(entity_joined, relationship_join) + final_joined = relationship_joined + + if final_covariates is not None: + covariate_join = _covariates(final_covariates) + final_joined = _join(relationship_joined, covariate_join) + + aggregated = final_joined.groupby("id", sort=False).agg("first").reset_index() + + is_using_vector_store = False + if text_embed: + aggregated["text_embedding"] = await text_embed_df( + aggregated, + callbacks, + cache, + column="text", + strategy=text_embed["strategy"], + ) + is_using_vector_store = ( + text_embed.get("strategy", {}).get("vector_store", None) is not None + ) + + return cast( + pd.DataFrame, + aggregated[ + [ + "id", + "text", + *( + [] + if (not text_embed or is_using_vector_store) + else ["text_embedding"] + ), + "n_tokens", + "document_ids", + "entity_ids", + "relationship_ids", + *([] if final_covariates is None else ["covariate_ids"]), + ] + ], + ) + + +def _entities(df: pd.DataFrame) -> pd.DataFrame: + selected = df.loc[:, ["id", "text_unit_ids"]] + unrolled = selected.explode(["text_unit_ids"]).reset_index(drop=True) + + return ( + unrolled.groupby("text_unit_ids", sort=False) + .agg(entity_ids=("id", "unique")) + .reset_index() + .rename(columns={"text_unit_ids": "id"}) + ) + + +def _relationships(df: pd.DataFrame) -> pd.DataFrame: + selected = df.loc[:, ["id", "text_unit_ids"]] + unrolled = selected.explode(["text_unit_ids"]).reset_index(drop=True) + + return ( + unrolled.groupby("text_unit_ids", sort=False) + .agg(relationship_ids=("id", "unique")) + .reset_index() + .rename(columns={"text_unit_ids": "id"}) + ) + + +def _covariates(df: pd.DataFrame) -> pd.DataFrame: + selected = df.loc[:, ["id", "text_unit_id"]] + + return ( + selected.groupby("text_unit_id", sort=False) + .agg(covariate_ids=("id", "unique")) + .reset_index() + .rename(columns={"text_unit_id": "id"}) + ) + + +def _join(left, right): + return left.merge( + right, + on="id", + how="left", + suffixes=["_1", "_2"], + ) diff --git a/graphrag/index/flows/create_summarized_entities.py b/graphrag/index/flows/create_summarized_entities.py new file mode 100644 index 0000000000..dc5c6d25e9 --- /dev/null +++ b/graphrag/index/flows/create_summarized_entities.py @@ -0,0 +1,50 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""All the steps to summarize entities.""" + +from typing import Any + +import pandas as pd +from datashaper import ( + VerbCallbacks, +) + +from graphrag.index.cache import PipelineCache +from graphrag.index.storage import PipelineStorage +from graphrag.index.verbs.entities.summarize.description_summarize import ( + summarize_descriptions_df, +) +from graphrag.index.verbs.snapshot_rows import snapshot_rows_df + + +async def create_summarized_entities( + entities: pd.DataFrame, + cache: PipelineCache, + callbacks: VerbCallbacks, + storage: PipelineStorage, + strategy: dict[str, Any] | None = None, + num_threads: int = 4, + graphml_snapshot_enabled: bool = False, +) -> pd.DataFrame: + """All the steps to summarize entities.""" + summarized = await summarize_descriptions_df( + entities, + cache, + callbacks, + column="entity_graph", + to="entity_graph", + strategy=strategy, + num_threads=num_threads, + ) + + if graphml_snapshot_enabled: + await snapshot_rows_df( + summarized, + column="entity_graph", + base_name="summarized_graph", + storage=storage, + formats=[{"format": "text", "extension": "graphml"}], + ) + + return summarized diff --git a/graphrag/index/operations/__init__.py b/graphrag/index/operations/__init__.py new file mode 100644 index 0000000000..c5a0f18b95 --- /dev/null +++ b/graphrag/index/operations/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Reusable data frame operations.""" diff --git a/graphrag/index/verbs/covariates/extract_covariates/extract_covariates.py b/graphrag/index/verbs/covariates/extract_covariates/extract_covariates.py index 59d2486b17..92785efe50 100644 --- a/graphrag/index/verbs/covariates/extract_covariates/extract_covariates.py +++ b/graphrag/index/verbs/covariates/extract_covariates/extract_covariates.py @@ -74,7 +74,7 @@ async def extract_covariates_df( strategy: dict[str, Any] | None, async_mode: AsyncType = AsyncType.AsyncIO, entity_types: list[str] | None = None, - **kwargs, + num_threads: int = 4, ): """Extract claims from a piece of text.""" log.debug("extract_covariates strategy=%s", strategy) @@ -104,7 +104,7 @@ async def run_strategy(row): run_strategy, callbacks, scheduling_type=async_mode, - num_threads=kwargs.get("num_threads", 4), + num_threads=num_threads, ) return pd.DataFrame([item for row in results for item in row or []]) diff --git a/graphrag/index/verbs/entities/extraction/entity_extract.py b/graphrag/index/verbs/entities/extraction/entity_extract.py index 2487c900be..e5c8eff2c0 100644 --- a/graphrag/index/verbs/entities/extraction/entity_extract.py +++ b/graphrag/index/verbs/entities/extraction/entity_extract.py @@ -84,7 +84,7 @@ async def entity_extract_df( graph_to: str | None = None, async_mode: AsyncType = AsyncType.AsyncIO, entity_types=DEFAULT_ENTITY_TYPES, - **kwargs, + num_threads: int = 4, ) -> pd.DataFrame: """ Extract entities from a piece of text. 
@@ -194,7 +194,7 @@ async def run_strategy(row): run_strategy, callbacks, scheduling_type=async_mode, - num_threads=kwargs.get("num_threads", 4), + num_threads=num_threads, ) to_result = [] diff --git a/graphrag/index/verbs/snapshot.py b/graphrag/index/verbs/snapshot.py index 3789a4c998..032e19512e 100644 --- a/graphrag/index/verbs/snapshot.py +++ b/graphrag/index/verbs/snapshot.py @@ -20,11 +20,11 @@ async def snapshot( **_kwargs: dict, ) -> TableContainer: """Take a entire snapshot of the tabular data.""" - data = cast(pd.DataFrame, input.get_input()) + source = cast(pd.DataFrame, input.get_input()) - await snapshot_df(data, name, formats, storage) + await snapshot_df(source, name, formats, storage) - return TableContainer(table=data) + return TableContainer(table=source) async def snapshot_df( @@ -32,7 +32,7 @@ async def snapshot_df( name: str, formats: list[str], storage: PipelineStorage, -): +) -> None: """Take a entire snapshot of the tabular data.""" for fmt in formats: if fmt == "parquet": diff --git a/graphrag/index/verbs/snapshot_rows.py b/graphrag/index/verbs/snapshot_rows.py index 5aa85a52a6..e4b567aaec 100644 --- a/graphrag/index/verbs/snapshot_rows.py +++ b/graphrag/index/verbs/snapshot_rows.py @@ -33,7 +33,7 @@ async def snapshot_rows( ) -> TableContainer: """Take a by-row snapshot of the tabular data.""" source = cast(pd.DataFrame, input.get_input()) - output = await snapshot_rows_df( + await snapshot_rows_df( source, column, base_name, @@ -41,10 +41,9 @@ async def snapshot_rows( formats, row_name_column, ) - return TableContainer(table=output) + return TableContainer(table=source) -# todo: once this is out of "verb land", it does not need to return the input async def snapshot_rows_df( input: pd.DataFrame, column: str | None, @@ -52,7 +51,7 @@ async def snapshot_rows_df( storage: PipelineStorage, formats: list[str | dict[str, Any]], row_name_column: str | None = None, -) -> pd.DataFrame: +) -> None: """Take a by-row snapshot of the tabular data.""" parsed_formats = _parse_formats(formats) num_rows = len(input) @@ -82,7 +81,6 @@ def get_row_name(row: Any, row_idx: Any): msg = "column must be specified for text format" raise ValueError(msg) await storage.set(f"{row_name}.{extension}", str(row[column])) - return input def _parse_formats(formats: list[str | dict[str, Any]]) -> list[FormatSpecifier]: diff --git a/graphrag/index/verbs/text/embed/text_embed.py b/graphrag/index/verbs/text/embed/text_embed.py index d2aa1e8f35..bf431249ff 100644 --- a/graphrag/index/verbs/text/embed/text_embed.py +++ b/graphrag/index/verbs/text/embed/text_embed.py @@ -80,10 +80,11 @@ async def text_embed( ``` """ input_df = cast(pd.DataFrame, input.get_input()) - result_df = await text_embed_df( + to = kwargs.get("to", f"{column}_embedding") + input_df[to] = await text_embed_df( input_df, callbacks, cache, column, strategy, **kwargs ) - return TableContainer(table=result_df) + return TableContainer(table=input_df) # TODO: this ultimately just creates a new column, so our embed function could just generate a series instead of updating the dataframe @@ -116,7 +117,6 @@ async def text_embed_df( vector_store, vector_store_workflow_config, vector_store_config.get("store_in_table", False), - kwargs.get("to", f"{column}_embedding"), ) return await _text_embed_in_memory( @@ -125,7 +125,6 @@ async def text_embed_df( cache, column, strategy, - kwargs.get("to", f"{column}_embedding"), ) @@ -135,19 +134,15 @@ async def _text_embed_in_memory( cache: PipelineCache, column: str, strategy: dict, - to: str, ): - 
output_df = input strategy_type = strategy["type"] strategy_exec = load_strategy(strategy_type) strategy_args = {**strategy} - input_table = input - texts: list[str] = input_table[column].to_numpy().tolist() + texts: list[str] = input[column].to_numpy().tolist() result = await strategy_exec(texts, callbacks, cache, strategy_args) - output_df[to] = result.embeddings - return output_df + return result.embeddings async def _text_embed_with_vector_store( @@ -159,9 +154,7 @@ async def _text_embed_with_vector_store( vector_store: BaseVectorStore, vector_store_config: dict, store_in_table: bool = False, - to: str = "", ): - output_df = input strategy_type = strategy["type"] strategy_exec = load_strategy(strategy_type) strategy_args = {**strategy} @@ -174,18 +167,20 @@ async def _text_embed_with_vector_store( id_column: str = vector_store_config.get("id_column", "id") overwrite: bool = vector_store_config.get("overwrite", True) - if column not in output_df.columns: - msg = f"Column {column} not found in input dataframe with columns {output_df.columns}" + if column not in input.columns: + msg = ( + f"Column {column} not found in input dataframe with columns {input.columns}" + ) raise ValueError(msg) - if title_column not in output_df.columns: - msg = f"Column {title_column} not found in input dataframe with columns {output_df.columns}" + if title_column not in input.columns: + msg = f"Column {title_column} not found in input dataframe with columns {input.columns}" raise ValueError(msg) - if id_column not in output_df.columns: - msg = f"Column {id_column} not found in input dataframe with columns {output_df.columns}" + if id_column not in input.columns: + msg = f"Column {id_column} not found in input dataframe with columns {input.columns}" raise ValueError(msg) total_rows = 0 - for row in output_df[column]: + for row in input[column]: if isinstance(row, list): total_rows += len(row) else: @@ -231,9 +226,9 @@ async def _text_embed_with_vector_store( i += 1 if store_in_table: - output_df[to] = all_results + return all_results - return output_df + return None def _create_vector_store( diff --git a/graphrag/index/workflows/load.py b/graphrag/index/workflows/load.py index 4dd6f9bfd0..a9f65b86d1 100644 --- a/graphrag/index/workflows/load.py +++ b/graphrag/index/workflows/load.py @@ -132,7 +132,6 @@ def create_workflow( **(additional_workflows or {}), } steps = steps or _get_steps_for_workflow(name, config, additional_workflows) - steps = _remove_disabled_steps(steps) return Workflow( verbs=additional_verbs or {}, schema={ @@ -163,9 +162,3 @@ def _get_steps_for_workflow( raise UnknownWorkflowError(name) return workflows[name](config or {}) - - -def _remove_disabled_steps( - steps: list[PipelineWorkflowStep], -) -> list[PipelineWorkflowStep]: - return [step for step in steps if step.get("enabled", True)] diff --git a/graphrag/index/workflows/v1/create_base_documents.py b/graphrag/index/workflows/v1/create_base_documents.py index 9fd08a7ed6..1186b1d475 100644 --- a/graphrag/index/workflows/v1/create_base_documents.py +++ b/graphrag/index/workflows/v1/create_base_documents.py @@ -28,7 +28,7 @@ def build_steps( }, "input": { "source": DEFAULT_INPUT_NAME, - "others": ["workflow:create_final_text_units"], + "text_units": "workflow:create_final_text_units", }, }, ] diff --git a/graphrag/index/workflows/v1/create_final_community_reports.py b/graphrag/index/workflows/v1/create_final_community_reports.py index ec79cd1a17..5e933c8ea7 100644 --- a/graphrag/index/workflows/v1/create_final_community_reports.py +++ 
b/graphrag/index/workflows/v1/create_final_community_reports.py @@ -43,13 +43,18 @@ def build_steps( { "verb": "create_final_community_reports", "args": { - "covariates_enabled": covariates_enabled, "skip_full_content_embedding": skip_full_content_embedding, "skip_summary_embedding": skip_summary_embedding, "skip_title_embedding": skip_title_embedding, - "full_content_text_embed": community_report_full_content_embed_config, - "summary_text_embed": community_report_summary_embed_config, - "title_text_embed": community_report_title_embed_config, + "full_content_text_embed": community_report_full_content_embed_config + if not skip_full_content_embedding + else None, + "summary_text_embed": community_report_summary_embed_config + if not skip_summary_embedding + else None, + "title_text_embed": community_report_title_embed_config + if not skip_title_embedding + else None, **create_community_reports_config, }, "input": input, diff --git a/graphrag/index/workflows/v1/create_final_documents.py b/graphrag/index/workflows/v1/create_final_documents.py index 6bc94aeaa1..7f39cfcfb8 100644 --- a/graphrag/index/workflows/v1/create_final_documents.py +++ b/graphrag/index/workflows/v1/create_final_documents.py @@ -26,8 +26,9 @@ def build_steps( { "verb": "create_final_documents", "args": { - "skip_embedding": skip_raw_content_embedding, - "text_embed": document_raw_content_embed_config, + "text_embed": document_raw_content_embed_config + if not skip_raw_content_embedding + else None, }, "input": {"source": "workflow:create_base_documents"}, }, diff --git a/graphrag/index/workflows/v1/create_final_entities.py b/graphrag/index/workflows/v1/create_final_entities.py index 20ccef44a5..d7393f4ab4 100644 --- a/graphrag/index/workflows/v1/create_final_entities.py +++ b/graphrag/index/workflows/v1/create_final_entities.py @@ -31,8 +31,12 @@ def build_steps( "args": { "skip_name_embedding": skip_name_embedding, "skip_description_embedding": skip_description_embedding, - "name_text_embed": entity_name_embed_config, - "description_text_embed": entity_name_description_embed_config, + "name_text_embed": entity_name_embed_config + if not skip_name_embedding + else None, + "description_text_embed": entity_name_description_embed_config + if not skip_description_embedding + else None, }, "input": {"source": "workflow:create_base_entity_graph"}, }, diff --git a/graphrag/index/workflows/v1/create_final_relationships.py b/graphrag/index/workflows/v1/create_final_relationships.py index c2f8f25894..086822a822 100644 --- a/graphrag/index/workflows/v1/create_final_relationships.py +++ b/graphrag/index/workflows/v1/create_final_relationships.py @@ -27,8 +27,9 @@ def build_steps( { "verb": "create_final_relationships", "args": { - "skip_embedding": skip_description_embedding, - "text_embed": relationship_description_embed_config, + "text_embed": relationship_description_embed_config + if not skip_description_embedding + else None, }, "input": { "source": "workflow:create_base_entity_graph", diff --git a/graphrag/index/workflows/v1/create_final_text_units.py b/graphrag/index/workflows/v1/create_final_text_units.py index 1117a36364..15a8dfb115 100644 --- a/graphrag/index/workflows/v1/create_final_text_units.py +++ b/graphrag/index/workflows/v1/create_final_text_units.py @@ -24,24 +24,23 @@ def build_steps( skip_text_unit_embedding = config.get("skip_text_unit_embedding", False) covariates_enabled = config.get("covariates_enabled", False) - others = [ - "workflow:create_final_entities", - "workflow:create_final_relationships", - ] + 
input = { + "source": "workflow:create_base_text_units", + "entities": "workflow:create_final_entities", + "relationships": "workflow:create_final_relationships", + } + if covariates_enabled: - others.append("workflow:create_final_covariates") + input["covariates"] = "workflow:create_final_covariates" return [ { "verb": "create_final_text_units", "args": { - "skip_embedding": skip_text_unit_embedding, - "text_embed": text_unit_text_embed_config, - "covariates_enabled": covariates_enabled, - }, - "input": { - "source": "workflow:create_base_text_units", - "others": others, + "text_embed": text_unit_text_embed_config + if not skip_text_unit_embedding + else None, }, + "input": input, }, ] diff --git a/graphrag/index/workflows/v1/subflows/create_base_documents.py b/graphrag/index/workflows/v1/subflows/create_base_documents.py index 718b978405..6e682e6733 100644 --- a/graphrag/index/workflows/v1/subflows/create_base_documents.py +++ b/graphrag/index/workflows/v1/subflows/create_base_documents.py @@ -13,6 +13,11 @@ ) from datashaper.table_store.types import VerbResult, create_verb_result +from graphrag.index.flows.create_base_documents import ( + create_base_documents as create_base_documents_flow, +) +from graphrag.index.utils.ds_util import get_required_input_table + @verb(name="create_base_documents", treats_input_tables_as_immutable=True) def create_base_documents( @@ -22,60 +27,13 @@ def create_base_documents( ) -> VerbResult: """All the steps to transform base documents.""" source = cast(pd.DataFrame, input.get_input()) - text_units = cast(pd.DataFrame, input.get_others()[0]) - - text_units = ( - text_units.explode("document_ids") - .loc[:, ["id", "document_ids", "text"]] - .rename( - columns={ - "document_ids": "chunk_doc_id", - "id": "chunk_id", - "text": "chunk_text", - } - ) - ) - - joined = text_units.merge( - source, - left_on="chunk_doc_id", - right_on="id", - how="inner", - copy=False, - ) - - docs_with_text_units = joined.groupby("id", sort=False).agg( - text_units=("chunk_id", list) - ) - - rejoined = docs_with_text_units.merge( - source, - on="id", - how="right", - copy=False, - ).reset_index(drop=True) - - rejoined.rename(columns={"text": "raw_content"}, inplace=True) - rejoined["id"] = rejoined["id"].astype(str) - - # Convert attribute columns to strings and collapse them into a JSON object - if document_attribute_columns: - # Convert all specified columns to string at once - rejoined[document_attribute_columns] = rejoined[ - document_attribute_columns - ].astype(str) - - # Collapse the document_attribute_columns into a single JSON object column - rejoined["attributes"] = rejoined[document_attribute_columns].to_dict( - orient="records" - ) + text_units = cast(pd.DataFrame, get_required_input_table(input, "text_units").table) - # Drop the original attribute columns after collapsing them - rejoined.drop(columns=document_attribute_columns, inplace=True) + output = create_base_documents_flow(source, text_units, document_attribute_columns) return create_verb_result( cast( Table, - rejoined, + output, ) ) diff --git a/graphrag/index/workflows/v1/subflows/create_base_entity_graph.py b/graphrag/index/workflows/v1/subflows/create_base_entity_graph.py index e34ec0f40b..6b981790ba 100644 --- a/graphrag/index/workflows/v1/subflows/create_base_entity_graph.py +++ b/graphrag/index/workflows/v1/subflows/create_base_entity_graph.py @@ -1,7 +1,7 @@ # Copyright (c) 2024 Microsoft Corporation. 
# Licensed under the MIT License -"""All the steps to transform final documents.""" +"""All the steps to create the base entity graph.""" from typing import Any, cast @@ -14,10 +14,10 @@ ) from datashaper.table_store.types import VerbResult, create_verb_result +from graphrag.index.flows.create_base_entity_graph import ( + create_base_entity_graph as create_base_entity_graph_flow, +) from graphrag.index.storage import PipelineStorage -from graphrag.index.verbs.graph.clustering.cluster_graph import cluster_graph_df -from graphrag.index.verbs.graph.embed.embed_graph import embed_graph_df -from graphrag.index.verbs.snapshot_rows import snapshot_rows_df @verb( @@ -34,52 +34,17 @@ async def create_base_entity_graph( embed_graph_enabled: bool = False, **_kwargs: dict, ) -> VerbResult: - """All the steps to transform final documents.""" + """All the steps to create the base entity graph.""" source = cast(pd.DataFrame, input.get_input()) - clustering_strategy = clustering_config.get("strategy", {"type": "leiden"}) - - clustered = cluster_graph_df( + output = await create_base_entity_graph_flow( source, callbacks, - column="entity_graph", - strategy=clustering_strategy, - to="clustered_graph", - level_to="level", + storage, + clustering_config, + embedding_config, + graphml_snapshot_enabled, + embed_graph_enabled, ) - if graphml_snapshot_enabled: - await snapshot_rows_df( - clustered, - column="clustered_graph", - base_name="clustered_graph", - storage=storage, - formats=[{"format": "text", "extension": "graphml"}], - ) - - embedding_strategy = embedding_config.get("strategy") - if embed_graph_enabled and embedding_strategy: - clustered = await embed_graph_df( - clustered, - callbacks, - column="clustered_graph", - strategy=embedding_strategy, - to="embeddings", - ) - - # take second snapshot after embedding - # todo: this could be skipped if embedding isn't performed, other wise it is a copy of the regular graph? - if graphml_snapshot_enabled: - await snapshot_rows_df( - clustered, - column="entity_graph", - base_name="embedded_graph", - storage=storage, - formats=[{"format": "text", "extension": "graphml"}], - ) - - final_columns = ["level", "clustered_graph"] - if embed_graph_enabled: - final_columns.append("embeddings") - - return create_verb_result(cast(Table, clustered[final_columns])) + return create_verb_result(cast(Table, output)) diff --git a/graphrag/index/workflows/v1/subflows/create_base_extracted_entities.py b/graphrag/index/workflows/v1/subflows/create_base_extracted_entities.py index 464b5a11d1..846467f31d 100644 --- a/graphrag/index/workflows/v1/subflows/create_base_extracted_entities.py +++ b/graphrag/index/workflows/v1/subflows/create_base_extracted_entities.py @@ -1,7 +1,7 @@ # Copyright (c) 2024 Microsoft Corporation. 
# Licensed under the MIT License -"""All the steps to extract and format covariates.""" +"""All the steps to extract and format base entities.""" from typing import Any, cast @@ -16,11 +16,10 @@ from datashaper.table_store.types import VerbResult, create_verb_result from graphrag.index.cache import PipelineCache +from graphrag.index.flows.create_base_extracted_entities import ( + create_base_extracted_entities as create_base_extracted_entities_flow, +) from graphrag.index.storage import PipelineStorage -from graphrag.index.verbs.entities.extraction.entity_extract import entity_extract_df -from graphrag.index.verbs.graph.merge.merge_graphs import merge_graphs_df -from graphrag.index.verbs.snapshot import snapshot_df -from graphrag.index.verbs.snapshot_rows import snapshot_rows_df @verb(name="create_base_extracted_entities", treats_input_tables_as_immutable=True) @@ -36,51 +35,29 @@ async def create_base_extracted_entities( strategy: dict[str, Any] | None, async_mode: AsyncType = AsyncType.AsyncIO, entity_types: list[str] | None = None, + num_threads: int = 4, graphml_snapshot_enabled: bool = False, raw_entity_snapshot_enabled: bool = False, - **kwargs: dict, + **_kwargs: dict, ) -> VerbResult: - """All the steps to extract and format covariates.""" + """All the steps to extract and format base entities.""" source = cast(pd.DataFrame, input.get_input()) - entity_graph = await entity_extract_df( + output = await create_base_extracted_entities_flow( source, cache, callbacks, - column=column, - id_column=id_column, - strategy=strategy, - async_mode=async_mode, - entity_types=entity_types, - to="entities", - graph_to="entity_graph", - **kwargs, - ) - - if raw_entity_snapshot_enabled: - await snapshot_df( - entity_graph, - name="raw_extracted_entities", - storage=storage, - formats=["json"], - ) - - merged_graph = merge_graphs_df( - entity_graph, - callbacks, - column="entity_graph", - to="entity_graph", - nodes=nodes, - edges=edges, + storage, + column, + id_column, + nodes, + edges, + strategy, + async_mode, + entity_types, + graphml_snapshot_enabled, + raw_entity_snapshot_enabled, + num_threads=num_threads, ) - if graphml_snapshot_enabled: - await snapshot_rows_df( - merged_graph, - base_name="merged_graph", - column="entity_graph", - storage=storage, - formats=[{"format": "text", "extension": "graphml"}], - ) - - return create_verb_result(cast(Table, merged_graph)) + return create_verb_result(cast(Table, output)) diff --git a/graphrag/index/workflows/v1/subflows/create_base_text_units.py b/graphrag/index/workflows/v1/subflows/create_base_text_units.py index 344e4cafea..370abc590f 100644 --- a/graphrag/index/workflows/v1/subflows/create_base_text_units.py +++ b/graphrag/index/workflows/v1/subflows/create_base_text_units.py @@ -14,9 +14,9 @@ ) from datashaper.table_store.types import VerbResult, create_verb_result -from graphrag.index.verbs.genid import genid_df -from graphrag.index.verbs.overrides.aggregate import aggregate_df -from graphrag.index.verbs.text.chunk.text_chunk import chunk_df +from graphrag.index.flows.create_base_text_units import ( + create_base_text_units as create_base_text_units_flow, +) @verb(name="create_base_text_units", treats_input_tables_as_immutable=True) @@ -30,57 +30,20 @@ def create_base_text_units( **_kwargs: dict, ) -> VerbResult: """All the steps to transform base text_units.""" - table = cast(pd.DataFrame, input.get_input()) - - sort = table.sort_values(by=["id"], ascending=[True]) - - sort["text_with_ids"] = list( - zip(*[sort[col] for col in ["id", "text"]], 
strict=True) - ) - - aggregated = aggregate_df( - sort, - groupby=[*chunk_by_columns] if len(chunk_by_columns) > 0 else None, - aggregations=[ - { - "column": "text_with_ids", - "operation": "array_agg", - "to": "texts", - } - ], - ) - - chunked = chunk_df( - aggregated, - column="texts", - to="chunks", - callbacks=callbacks, - strategy=strategy, + source = cast(pd.DataFrame, input.get_input()) + + output = create_base_text_units_flow( + source, + callbacks, + chunk_column_name, + n_tokens_column_name, + chunk_by_columns, + strategy, ) - chunked = cast(pd.DataFrame, chunked[[*chunk_by_columns, "chunks"]]) - chunked = chunked.explode("chunks") - chunked.rename( - columns={ - "chunks": chunk_column_name, - }, - inplace=True, - ) - - chunked = genid_df( - chunked, to="chunk_id", method="md5_hash", hash=[chunk_column_name] - ) - - chunked[["document_ids", chunk_column_name, n_tokens_column_name]] = pd.DataFrame( - chunked[chunk_column_name].tolist(), index=chunked.index - ) - chunked["id"] = chunked["chunk_id"] - - filtered = chunked[chunked[chunk_column_name].notna()].reset_index(drop=True) - return create_verb_result( cast( Table, - filtered, + output, ) ) diff --git a/graphrag/index/workflows/v1/subflows/create_final_communities.py b/graphrag/index/workflows/v1/subflows/create_final_communities.py index 2cbe8f6cae..6224925a07 100644 --- a/graphrag/index/workflows/v1/subflows/create_final_communities.py +++ b/graphrag/index/workflows/v1/subflows/create_final_communities.py @@ -14,7 +14,9 @@ ) from datashaper.table_store.types import VerbResult, create_verb_result -from graphrag.index.verbs.graph.unpack import unpack_graph_df +from graphrag.index.flows.create_final_communities import ( + create_final_communities as create_final_communities_flow, +) @verb(name="create_final_communities", treats_input_tables_as_immutable=True) @@ -24,65 +26,16 @@ def create_final_communities( **_kwargs: dict, ) -> VerbResult: """All the steps to transform final communities.""" - table = cast(pd.DataFrame, input.get_input()) - - graph_nodes = unpack_graph_df(table, callbacks, "clustered_graph", "nodes") - graph_edges = unpack_graph_df(table, callbacks, "clustered_graph", "edges") - - # Merge graph_nodes with graph_edges for both source and target matches - source_clusters = graph_nodes.merge( - graph_edges, left_on="label", right_on="source", how="inner" - ) + source = cast(pd.DataFrame, input.get_input()) - target_clusters = graph_nodes.merge( - graph_edges, left_on="label", right_on="target", how="inner" + output = create_final_communities_flow( + source, + callbacks, ) - # Concatenate the source and target clusters - clusters = pd.concat([source_clusters, target_clusters], ignore_index=True) - - # Keep only rows where level_x == level_y - combined_clusters = clusters[ - clusters["level_x"] == clusters["level_y"] - ].reset_index(drop=True) - - cluster_relationships = ( - combined_clusters.groupby(["cluster", "level_x"], sort=False) - .agg( - relationship_ids=("id_y", "unique"), text_unit_ids=("source_id_x", "unique") - ) - .reset_index() - ) - - all_clusters = ( - graph_nodes.groupby(["cluster", "level"], sort=False) - .agg(id=("cluster", "first")) - .reset_index() - ) - - joined = all_clusters.merge( - cluster_relationships, - left_on="id", - right_on="cluster", - how="inner", - ) - - filtered = joined[joined["level"] == joined["level_x"]].reset_index(drop=True) - - filtered["title"] = "Community " + filtered["id"].astype(str) - return create_verb_result( cast( Table, - filtered.loc[ - :, - [ - "id", - 
"title", - "level", - "relationship_ids", - "text_unit_ids", - ], - ], + output, ) ) diff --git a/graphrag/index/workflows/v1/subflows/create_final_community_reports.py b/graphrag/index/workflows/v1/subflows/create_final_community_reports.py index 58881e2147..480744129b 100644 --- a/graphrag/index/workflows/v1/subflows/create_final_community_reports.py +++ b/graphrag/index/workflows/v1/subflows/create_final_community_reports.py @@ -4,7 +4,6 @@ """All the steps to transform community reports.""" from typing import cast -from uuid import uuid4 import pandas as pd from datashaper import ( @@ -17,36 +16,10 @@ from datashaper.table_store.types import VerbResult, create_verb_result from graphrag.index.cache import PipelineCache -from graphrag.index.graph.extractors.community_reports.schemas import ( - CLAIM_DESCRIPTION, - CLAIM_DETAILS, - CLAIM_ID, - CLAIM_STATUS, - CLAIM_SUBJECT, - CLAIM_TYPE, - EDGE_DEGREE, - EDGE_DESCRIPTION, - EDGE_DETAILS, - EDGE_ID, - EDGE_SOURCE, - EDGE_TARGET, - NODE_DEGREE, - NODE_DESCRIPTION, - NODE_DETAILS, - NODE_ID, - NODE_NAME, +from graphrag.index.flows.create_final_community_reports import ( + create_final_community_reports as create_final_community_reports_flow, ) -from graphrag.index.utils.ds_util import get_required_input_table -from graphrag.index.verbs.graph.report.create_community_reports import ( - create_community_reports_df, -) -from graphrag.index.verbs.graph.report.prepare_community_reports import ( - prepare_community_reports_df, -) -from graphrag.index.verbs.graph.report.restore_community_hierarchy import ( - restore_community_hierarchy_df, -) -from graphrag.index.verbs.text.embed.text_embed import text_embed_df +from graphrag.index.utils.ds_util import get_named_input_table, get_required_input_table @verb(name="create_final_community_reports", treats_input_tables_as_immutable=True) @@ -55,134 +28,38 @@ async def create_final_community_reports( callbacks: VerbCallbacks, cache: PipelineCache, strategy: dict, - full_content_text_embed: dict, - summary_text_embed: dict, - title_text_embed: dict, async_mode: AsyncType = AsyncType.AsyncIO, num_threads: int = 4, - skip_full_content_embedding: bool = False, - skip_summary_embedding: bool = False, - skip_title_embedding: bool = False, - covariates_enabled: bool = False, + full_content_text_embed: dict | None = None, + summary_text_embed: dict | None = None, + title_text_embed: dict | None = None, **_kwargs: dict, ) -> VerbResult: """All the steps to transform community reports.""" - nodes = _prep_nodes(cast(pd.DataFrame, input.get_input())) - edges = _prep_edges( - cast(pd.DataFrame, get_required_input_table(input, "relationships").table) - ) + nodes = cast(pd.DataFrame, input.get_input()) + edges = cast(pd.DataFrame, get_required_input_table(input, "relationships").table) - claims = None - if covariates_enabled: - claims = _prep_claims( - cast(pd.DataFrame, get_required_input_table(input, "covariates").table) - ) - - community_hierarchy = restore_community_hierarchy_df(nodes) + claims = get_named_input_table(input, "covariates") + if claims: + claims = cast(pd.DataFrame, claims.table) - local_contexts = prepare_community_reports_df( - nodes, edges, claims, callbacks, strategy.get("max_input_length", 16_000) - ) - - community_reports = await create_community_reports_df( - local_contexts, + output = await create_final_community_reports_flow( nodes, - community_hierarchy, + edges, + claims, callbacks, cache, strategy, - async_mode=async_mode, - num_threads=num_threads, - ) - - community_reports["id"] = 
community_reports["community"].apply( - lambda _x: str(uuid4()) + async_mode, + num_threads, + full_content_text_embed, + summary_text_embed, + title_text_embed, ) - # Embed full content if not skipped - if not skip_full_content_embedding: - community_reports = await text_embed_df( - community_reports, - callbacks, - cache, - column="full_content", - strategy=full_content_text_embed["strategy"], - to="full_content_embedding", - embedding_name="community_report_full_content", - ) - - # Embed summary if not skipped - if not skip_summary_embedding: - community_reports = await text_embed_df( - community_reports, - callbacks, - cache, - column="summary", - strategy=summary_text_embed["strategy"], - to="summary_embedding", - embedding_name="community_report_summary", - ) - - # Embed title if not skipped - if not skip_title_embedding: - community_reports = await text_embed_df( - community_reports, - callbacks, - cache, - column="title", - strategy=title_text_embed["strategy"], - to="title_embedding", - embedding_name="community_report_title", - ) - return create_verb_result( cast( Table, - community_reports, + output, ) ) - - -def _prep_nodes(input: pd.DataFrame) -> pd.DataFrame: - input = input.fillna(value={NODE_DESCRIPTION: "No Description"}) - # merge values of four columns into a map column - input[NODE_DETAILS] = input.apply( - lambda x: { - NODE_ID: x[NODE_ID], - NODE_NAME: x[NODE_NAME], - NODE_DESCRIPTION: x[NODE_DESCRIPTION], - NODE_DEGREE: x[NODE_DEGREE], - }, - axis=1, - ) - return input - - -def _prep_edges(input: pd.DataFrame) -> pd.DataFrame: - input = input.fillna(value={NODE_DESCRIPTION: "No Description"}) - input[EDGE_DETAILS] = input.apply( - lambda x: { - EDGE_ID: x[EDGE_ID], - EDGE_SOURCE: x[EDGE_SOURCE], - EDGE_TARGET: x[EDGE_TARGET], - EDGE_DESCRIPTION: x[EDGE_DESCRIPTION], - EDGE_DEGREE: x[EDGE_DEGREE], - }, - axis=1, - ) - return input - - -def _prep_claims(input: pd.DataFrame) -> pd.DataFrame: - input = input.fillna(value={NODE_DESCRIPTION: "No Description"}) - input[CLAIM_DETAILS] = input.apply( - lambda x: { - CLAIM_ID: x[CLAIM_ID], - CLAIM_SUBJECT: x[CLAIM_SUBJECT], - CLAIM_TYPE: x[CLAIM_TYPE], - CLAIM_STATUS: x[CLAIM_STATUS], - CLAIM_DESCRIPTION: x[CLAIM_DESCRIPTION], - }, - axis=1, - ) - return input diff --git a/graphrag/index/workflows/v1/subflows/create_final_covariates.py b/graphrag/index/workflows/v1/subflows/create_final_covariates.py index 028a880b3e..8a6c3c9040 100644 --- a/graphrag/index/workflows/v1/subflows/create_final_covariates.py +++ b/graphrag/index/workflows/v1/subflows/create_final_covariates.py @@ -4,7 +4,6 @@ """All the steps to extract and format covariates.""" from typing import Any, cast -from uuid import uuid4 import pandas as pd from datashaper import ( @@ -17,8 +16,8 @@ from datashaper.table_store.types import VerbResult, create_verb_result from graphrag.index.cache import PipelineCache -from graphrag.index.verbs.covariates.extract_covariates.extract_covariates import ( - extract_covariates_df, +from graphrag.index.flows.create_final_covariates import ( + create_final_covariates as create_final_covariates_flow, ) @@ -32,12 +31,13 @@ async def create_final_covariates( strategy: dict[str, Any] | None, async_mode: AsyncType = AsyncType.AsyncIO, entity_types: list[str] | None = None, - **kwargs: dict, + num_threads: int = 4, + **_kwargs: dict, ) -> VerbResult: """All the steps to extract and format covariates.""" source = cast(pd.DataFrame, input.get_input()) - covariates = await extract_covariates_df( + output = await 
create_final_covariates_flow( source, cache, callbacks, @@ -46,33 +46,7 @@ async def create_final_covariates( strategy, async_mode, entity_types, - **kwargs, + num_threads, ) - covariates["id"] = covariates["covariate_type"].apply(lambda _x: str(uuid4())) - covariates["human_readable_id"] = (covariates.index + 1).astype(str) - covariates.rename(columns={"chunk_id": "text_unit_id"}, inplace=True) - - return create_verb_result( - cast( - Table, - covariates[ - [ - "id", - "human_readable_id", - "covariate_type", - "type", - "description", - "subject_id", - "object_id", - "status", - "start_date", - "end_date", - "source_text", - "text_unit_id", - "document_ids", - "n_tokens", - ] - ], - ) - ) + return create_verb_result(cast(Table, output)) diff --git a/graphrag/index/workflows/v1/subflows/create_final_documents.py b/graphrag/index/workflows/v1/subflows/create_final_documents.py index cb8accaa27..4512ee17b5 100644 --- a/graphrag/index/workflows/v1/subflows/create_final_documents.py +++ b/graphrag/index/workflows/v1/subflows/create_final_documents.py @@ -15,7 +15,9 @@ from datashaper.table_store.types import VerbResult, create_verb_result from graphrag.index.cache import PipelineCache -from graphrag.index.verbs.text.embed.text_embed import text_embed_df +from graphrag.index.flows.create_final_documents import ( + create_final_documents as create_final_documents_flow, +) @verb( @@ -26,23 +28,17 @@ async def create_final_documents( input: VerbInput, callbacks: VerbCallbacks, cache: PipelineCache, - text_embed: dict, - skip_embedding: bool = False, + text_embed: dict | None = None, **_kwargs: dict, ) -> VerbResult: """All the steps to transform final documents.""" source = cast(pd.DataFrame, input.get_input()) - source.rename(columns={"text_units": "text_unit_ids"}, inplace=True) - - if not skip_embedding: - source = await text_embed_df( - source, - callbacks, - cache, - column="raw_content", - strategy=text_embed["strategy"], - to="raw_content_embedding", - ) + output = await create_final_documents_flow( + source, + callbacks, + cache, + text_embed, + ) - return create_verb_result(cast(Table, source)) + return create_verb_result(cast(Table, output)) diff --git a/graphrag/index/workflows/v1/subflows/create_final_entities.py b/graphrag/index/workflows/v1/subflows/create_final_entities.py index 1a285d366b..54a10ebfe7 100644 --- a/graphrag/index/workflows/v1/subflows/create_final_entities.py +++ b/graphrag/index/workflows/v1/subflows/create_final_entities.py @@ -15,9 +15,9 @@ from datashaper.table_store.types import VerbResult, create_verb_result from graphrag.index.cache import PipelineCache -from graphrag.index.verbs.graph.unpack import unpack_graph_df -from graphrag.index.verbs.text.embed.text_embed import text_embed_df -from graphrag.index.verbs.text.split import text_split_df +from graphrag.index.flows.create_final_entities import ( + create_final_entities as create_final_entities_flow, +) @verb( @@ -30,68 +30,17 @@ async def create_final_entities( cache: PipelineCache, name_text_embed: dict, description_text_embed: dict, - skip_name_embedding: bool = False, - skip_description_embedding: bool = False, **_kwargs: dict, ) -> VerbResult: """All the steps to transform final entities.""" - table = cast(pd.DataFrame, input.get_input()) - - # Process nodes - nodes = ( - unpack_graph_df(table, callbacks, "clustered_graph", "nodes") - .rename(columns={"label": "name"}) - .loc[ - :, - [ - "id", - "name", - "type", - "description", - "human_readable_id", - "graph_embedding", - "source_id", - ], - ] - 
.drop_duplicates(subset="id") + source = cast(pd.DataFrame, input.get_input()) + + output = await create_final_entities_flow( + source, + callbacks, + cache, + name_text_embed, + description_text_embed, ) - nodes = nodes.loc[nodes["name"].notna()] - - # Split 'source_id' column into 'text_unit_ids' - nodes = text_split_df( - nodes, column="source_id", separator=",", to="text_unit_ids" - ).drop(columns=["source_id"]) - - # Embed name if not skipped - if not skip_name_embedding: - nodes = await text_embed_df( - nodes, - callbacks, - cache, - column="name", - strategy=name_text_embed["strategy"], - to="name_embedding", - embedding_name="entity_name", - ) - - # Embed description if not skipped - if not skip_description_embedding: - # Concatenate 'name' and 'description' and embed - nodes = await text_embed_df( - nodes.assign(name_description=nodes["name"] + ":" + nodes["description"]), - callbacks, - cache, - column="name_description", - strategy=description_text_embed["strategy"], - to="description_embedding", - embedding_name="entity_name_description", - ) - - # Drop rows with NaN 'description_embedding' if not using vector store - if not description_text_embed.get("strategy", {}).get("vector_store"): - nodes = nodes.loc[nodes["description_embedding"].notna()] - nodes.drop(columns="name_description", inplace=True) - - # Return final result - return create_verb_result(cast(Table, nodes)) + return create_verb_result(cast(Table, output)) diff --git a/graphrag/index/workflows/v1/subflows/create_final_nodes.py b/graphrag/index/workflows/v1/subflows/create_final_nodes.py index 4bdbb6f154..9a1754cb83 100644 --- a/graphrag/index/workflows/v1/subflows/create_final_nodes.py +++ b/graphrag/index/workflows/v1/subflows/create_final_nodes.py @@ -14,65 +14,37 @@ ) from datashaper.table_store.types import VerbResult, create_verb_result -from graphrag.index.verbs.graph.layout.layout_graph import layout_graph_df -from graphrag.index.verbs.graph.unpack import unpack_graph_df +from graphrag.index.flows.create_final_nodes import ( + create_final_nodes as create_final_nodes_flow, +) +from graphrag.index.storage import PipelineStorage @verb(name="create_final_nodes", treats_input_tables_as_immutable=True) -def create_final_nodes( +async def create_final_nodes( input: VerbInput, callbacks: VerbCallbacks, + storage: PipelineStorage, strategy: dict[str, Any], level_for_node_positions: int, + snapshot_top_level_nodes: bool = False, **_kwargs: dict, ) -> VerbResult: """All the steps to transform final nodes.""" - table = cast(pd.DataFrame, input.get_input()) - - laid_out_entity_graph = cast( - pd.DataFrame, - layout_graph_df( - table, - callbacks, - strategy, - embeddings_column="embeddings", - graph_column="clustered_graph", - to="node_positions", - graph_to="positioned_graph", - ), - ) - - nodes = cast( - pd.DataFrame, - unpack_graph_df( - laid_out_entity_graph, callbacks, column="positioned_graph", type="nodes" - ), - ) - - nodes_without_positions = nodes.drop(columns=["x", "y"]) - - nodes = nodes[nodes["level"] == level_for_node_positions].reset_index(drop=True) - nodes = cast(pd.DataFrame, nodes[["id", "x", "y"]]) - - # TODO: original workflow saved an optional snapshot of top level nodes - # Combining the verbs loses the `storage` injection, so it would fail - # verb arg: snapshot_top_level_nodes: bool, - # (name: "top_level_nodes", formats: ["json"]) - - nodes.rename(columns={"id": "top_level_node_id"}, inplace=True) - nodes["top_level_node_id"] = nodes["top_level_node_id"].astype(str) - - joined = 
nodes_without_positions.merge( - nodes, - left_on="id", - right_on="top_level_node_id", - how="inner", + source = cast(pd.DataFrame, input.get_input()) + + output = await create_final_nodes_flow( + source, + callbacks, + storage, + strategy, + level_for_node_positions, + snapshot_top_level_nodes, ) - joined.rename(columns={"label": "title", "cluster": "community"}, inplace=True) return create_verb_result( cast( Table, - joined, + output, ) ) diff --git a/graphrag/index/workflows/v1/subflows/create_final_relationships.py b/graphrag/index/workflows/v1/subflows/create_final_relationships.py index d7d7fa9565..fd21a4f5d6 100644 --- a/graphrag/index/workflows/v1/subflows/create_final_relationships.py +++ b/graphrag/index/workflows/v1/subflows/create_final_relationships.py @@ -3,7 +3,7 @@ """All the steps to transform final relationships.""" -from typing import Any, cast +from typing import cast import pandas as pd from datashaper import ( @@ -15,12 +15,10 @@ from datashaper.table_store.types import VerbResult, create_verb_result from graphrag.index.cache import PipelineCache -from graphrag.index.utils.ds_util import get_required_input_table -from graphrag.index.verbs.graph.compute_edge_combined_degree import ( - compute_edge_combined_degree_df, +from graphrag.index.flows.create_final_relationships import ( + create_final_relationships as create_final_relationships_flow, ) -from graphrag.index.verbs.graph.unpack import unpack_graph_df -from graphrag.index.verbs.text.embed.text_embed import text_embed_df +from graphrag.index.utils.ds_util import get_required_input_table @verb( @@ -31,69 +29,19 @@ async def create_final_relationships( input: VerbInput, callbacks: VerbCallbacks, cache: PipelineCache, - text_embed: dict, - skip_embedding: bool = False, + text_embed: dict | None = None, **_kwargs: dict, ) -> VerbResult: """All the steps to transform final relationships.""" - table = cast(pd.DataFrame, input.get_input()) + source = cast(pd.DataFrame, input.get_input()) nodes = cast(pd.DataFrame, get_required_input_table(input, "nodes").table) - graph_edges = unpack_graph_df(table, callbacks, "clustered_graph", "edges") - - graph_edges.rename(columns={"source_id": "text_unit_ids"}, inplace=True) - - filtered = cast( - pd.DataFrame, graph_edges[graph_edges["level"] == 0].reset_index(drop=True) - ) - - if not skip_embedding: - filtered = await text_embed_df( - filtered, - callbacks, - cache, - column="description", - strategy=text_embed["strategy"], - to="description_embedding", - embedding_name="relationship_description", - ) - - pruned_edges = filtered.drop(columns=["level"]) - - filtered_nodes = cast( - pd.DataFrame, - nodes[nodes["level"] == 0].reset_index(drop=True)[["title", "degree"]], + output = await create_final_relationships_flow( + source, + nodes, + callbacks, + cache, + text_embed, ) - edge_combined_degree = compute_edge_combined_degree_df( - pruned_edges, - filtered_nodes, - to="rank", - node_name_column="title", - node_degree_column="degree", - edge_source_column="source", - edge_target_column="target", - ) - - edge_combined_degree["human_readable_id"] = edge_combined_degree[ - "human_readable_id" - ].astype(str) - edge_combined_degree["text_unit_ids"] = _to_array( - edge_combined_degree["text_unit_ids"], "," - ) - - return create_verb_result(cast(Table, edge_combined_degree)) - - -# from datashaper, we should be able to inline this -def _to_array(column, delimiter: str): - def convert_value(value: Any) -> list: - if pd.isna(value): - return [] - if isinstance(value, list): - return value 
- if isinstance(value, str): - return value.split(delimiter) - return [value] - - return column.apply(convert_value) + return create_verb_result(cast(Table, output)) diff --git a/graphrag/index/workflows/v1/subflows/create_final_text_units.py b/graphrag/index/workflows/v1/subflows/create_final_text_units.py index f9b6d29f92..2747256219 100644 --- a/graphrag/index/workflows/v1/subflows/create_final_text_units.py +++ b/graphrag/index/workflows/v1/subflows/create_final_text_units.py @@ -16,7 +16,10 @@ ) from graphrag.index.cache import PipelineCache -from graphrag.index.verbs.text.embed.text_embed import text_embed_df +from graphrag.index.flows.create_final_text_units import ( + create_final_text_units as create_final_text_units_flow, +) +from graphrag.index.utils.ds_util import get_named_input_table, get_required_input_table @verb(name="create_final_text_units", treats_input_tables_as_immutable=True) @@ -24,104 +27,30 @@ async def create_final_text_units( input: VerbInput, callbacks: VerbCallbacks, cache: PipelineCache, - text_embed: dict, - skip_embedding: bool = False, - covariates_enabled: bool = False, + text_embed: dict | None = None, **_kwargs: dict, ) -> VerbResult: """All the steps to transform the text units.""" - table = cast(pd.DataFrame, input.get_input()) - others = input.get_others() - - selected = table.loc[:, ["id", "chunk", "document_ids", "n_tokens"]].rename( - columns={"chunk": "text"} - ) - - final_entities = cast(pd.DataFrame, others[0]) - final_relationships = cast(pd.DataFrame, others[1]) - entity_join = _entities(final_entities) - relationship_join = _relationships(final_relationships) - - entity_joined = _join(selected, entity_join) - relationship_joined = _join(entity_joined, relationship_join) - final_joined = relationship_joined - - if covariates_enabled: - final_covariates = cast(pd.DataFrame, others[2]) - covariate_join = _covariates(final_covariates) - final_joined = _join(relationship_joined, covariate_join) - - aggregated = final_joined.groupby("id", sort=False).agg("first").reset_index() - - if not skip_embedding: - aggregated = await text_embed_df( - aggregated, - callbacks, - cache, - column="text", - strategy=text_embed["strategy"], - to="text_embedding", - ) - - is_using_vector_store = ( - text_embed.get("strategy", {}).get("vector_store", None) is not None - ) - - final = aggregated[ - [ - "id", - "text", - *([] if (skip_embedding or is_using_vector_store) else ["text_embedding"]), - "n_tokens", - "document_ids", - "entity_ids", - "relationship_ids", - *([] if not covariates_enabled else ["covariate_ids"]), - ] - ] - return create_verb_result(cast(Table, final)) - - -def _entities(df: pd.DataFrame) -> pd.DataFrame: - selected = df.loc[:, ["id", "text_unit_ids"]] - unrolled = selected.explode(["text_unit_ids"]).reset_index(drop=True) - - return ( - unrolled.groupby("text_unit_ids", sort=False) - .agg(entity_ids=("id", "unique")) - .reset_index() - .rename(columns={"text_unit_ids": "id"}) + source = cast(pd.DataFrame, input.get_input()) + final_entities = cast( + pd.DataFrame, get_required_input_table(input, "entities").table ) - - -def _relationships(df: pd.DataFrame) -> pd.DataFrame: - selected = df.loc[:, ["id", "text_unit_ids"]] - unrolled = selected.explode(["text_unit_ids"]).reset_index(drop=True) - - return ( - unrolled.groupby("text_unit_ids", sort=False) - .agg(relationship_ids=("id", "unique")) - .reset_index() - .rename(columns={"text_unit_ids": "id"}) + final_relationships = cast( + pd.DataFrame, get_required_input_table(input, 
"relationships").table ) - - -def _covariates(df: pd.DataFrame) -> pd.DataFrame: - selected = df.loc[:, ["id", "text_unit_id"]] - - return ( - selected.groupby("text_unit_id", sort=False) - .agg(covariate_ids=("id", "unique")) - .reset_index() - .rename(columns={"text_unit_id": "id"}) + final_covariates = get_named_input_table(input, "covariates") + + if final_covariates: + final_covariates = cast(pd.DataFrame, final_covariates.table) + + output = await create_final_text_units_flow( + source, + final_entities, + final_relationships, + final_covariates, + callbacks, + cache, + text_embed, ) - -def _join(left, right): - return left.merge( - right, - left_on="id", - right_on="id", - how="left", - suffixes=["_1", "_2"], - ) + return create_verb_result(cast(Table, output)) diff --git a/graphrag/index/workflows/v1/subflows/create_summarized_entities.py b/graphrag/index/workflows/v1/subflows/create_summarized_entities.py index 6ec042c418..2d5c917d5d 100644 --- a/graphrag/index/workflows/v1/subflows/create_summarized_entities.py +++ b/graphrag/index/workflows/v1/subflows/create_summarized_entities.py @@ -1,7 +1,7 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -"""All the steps to transform final documents.""" +"""All the steps to summarize entities.""" from typing import Any, cast @@ -15,11 +15,10 @@ from datashaper.table_store.types import VerbResult, create_verb_result from graphrag.index.cache import PipelineCache -from graphrag.index.storage import PipelineStorage -from graphrag.index.verbs.entities.summarize.description_summarize import ( - summarize_descriptions_df, +from graphrag.index.flows.create_summarized_entities import ( + create_summarized_entities as create_summarized_entities_flow, ) -from graphrag.index.verbs.snapshot_rows import snapshot_rows_df +from graphrag.index.storage import PipelineStorage @verb( @@ -36,26 +35,17 @@ async def create_summarized_entities( graphml_snapshot_enabled: bool = False, **_kwargs: dict, ) -> VerbResult: - """All the steps to transform final documents.""" + """All the steps to summarize entities.""" source = cast(pd.DataFrame, input.get_input()) - summarized = await summarize_descriptions_df( + output = await create_summarized_entities_flow( source, cache, callbacks, - column="entity_graph", - to="entity_graph", - strategy=strategy, - num_threads=num_threads, + storage, + strategy, + num_threads, + graphml_snapshot_enabled, ) - if graphml_snapshot_enabled: - await snapshot_rows_df( - summarized, - column="entity_graph", - base_name="summarized_graph", - storage=storage, - formats=[{"format": "text", "extension": "graphml"}], - ) - - return create_verb_result(cast(Table, summarized)) + return create_verb_result(cast(Table, output)) diff --git a/tests/verbs/test_create_base_entity_graph.py b/tests/verbs/test_create_base_entity_graph.py index 082ca7280c..0294ff45d9 100644 --- a/tests/verbs/test_create_base_entity_graph.py +++ b/tests/verbs/test_create_base_entity_graph.py @@ -14,7 +14,6 @@ get_workflow_output, load_expected, load_input_tables, - remove_disabled_steps, ) @@ -24,15 +23,18 @@ async def test_create_base_entity_graph(): ]) expected = load_expected(workflow_name) + storage = MemoryPipelineStorage() + config = get_config_for_workflow(workflow_name) - steps = remove_disabled_steps(build_steps(config)) + steps = build_steps(config) actual = await get_workflow_output( input_tables, { "steps": steps, }, + storage=storage, ) # the serialization of the graph may differ so we can't assert the dataframes directly @@ 
-52,6 +54,8 @@ async def test_create_base_entity_graph(): actual_graph_0.number_of_edges() == expected_graph_0.number_of_edges() ), "Graphml edge count differs" + assert len(storage.keys()) == 0, "Storage should be empty" + async def test_create_base_entity_graph_with_embeddings(): input_tables = load_input_tables([ @@ -63,7 +67,7 @@ async def test_create_base_entity_graph_with_embeddings(): config["embed_graph_enabled"] = True - steps = remove_disabled_steps(build_steps(config)) + steps = build_steps(config) actual = await get_workflow_output( input_tables, @@ -90,7 +94,7 @@ async def test_create_base_entity_graph_with_snapshots(): config["graphml_snapshot"] = True - steps = remove_disabled_steps(build_steps(config)) + steps = build_steps(config) actual = await get_workflow_output( input_tables, diff --git a/tests/verbs/test_create_base_extracted_entities.py b/tests/verbs/test_create_base_extracted_entities.py index 6483471b86..029126c1cb 100644 --- a/tests/verbs/test_create_base_extracted_entities.py +++ b/tests/verbs/test_create_base_extracted_entities.py @@ -21,6 +21,8 @@ async def test_create_base_extracted_entities(): input_tables = load_input_tables(["workflow:create_base_text_units"]) expected = load_expected(workflow_name) + storage = MemoryPipelineStorage() + config = get_config_for_workflow(workflow_name) del config["entity_extract"]["strategy"]["llm"] @@ -32,6 +34,7 @@ async def test_create_base_extracted_entities(): { "steps": steps, }, + storage=storage, ) # let's parse a sample of the raw graphml @@ -43,6 +46,8 @@ async def test_create_base_extracted_entities(): assert actual.columns == expected.columns + assert len(storage.keys()) == 0, "Storage should be empty" + async def test_create_base_extracted_entities_with_snapshots(): input_tables = load_input_tables(["workflow:create_base_text_units"]) diff --git a/tests/verbs/test_create_final_community_reports.py b/tests/verbs/test_create_final_community_reports.py index 2a5004cbe0..b4f5ecf7fc 100644 --- a/tests/verbs/test_create_final_community_reports.py +++ b/tests/verbs/test_create_final_community_reports.py @@ -12,7 +12,6 @@ get_workflow_output, load_expected, load_input_tables, - remove_disabled_steps, ) @@ -29,7 +28,7 @@ async def test_create_final_community_reports(): # deleting the llm config results in a default mock injection in run_graph_intelligence del config["create_community_reports"]["strategy"]["llm"] - steps = remove_disabled_steps(build_steps(config)) + steps = build_steps(config) actual = await get_workflow_output( input_tables, @@ -68,7 +67,7 @@ async def test_create_final_community_reports_with_embeddings(): config["skip_title_embedding"] = False config["community_report_title_embed"]["strategy"]["type"] = "mock" - steps = remove_disabled_steps(build_steps(config)) + steps = build_steps(config) actual = await get_workflow_output( input_tables, diff --git a/tests/verbs/test_create_final_documents.py b/tests/verbs/test_create_final_documents.py index e70a578a98..c25b69ad52 100644 --- a/tests/verbs/test_create_final_documents.py +++ b/tests/verbs/test_create_final_documents.py @@ -12,7 +12,6 @@ get_workflow_output, load_expected, load_input_tables, - remove_disabled_steps, ) @@ -26,7 +25,7 @@ async def test_create_final_documents(): config["skip_raw_content_embedding"] = True - steps = remove_disabled_steps(build_steps(config)) + steps = build_steps(config) actual = await get_workflow_output( input_tables, @@ -51,7 +50,7 @@ async def test_create_final_documents_with_embeddings(): # just override the strategy 
to mock so the rest of the required parameters are in place config["document_raw_content_embed"]["strategy"]["type"] = "mock" - steps = remove_disabled_steps(build_steps(config)) + steps = build_steps(config) actual = await get_workflow_output( input_tables, diff --git a/tests/verbs/test_create_final_entities.py b/tests/verbs/test_create_final_entities.py index b32757aa77..d63776aef4 100644 --- a/tests/verbs/test_create_final_entities.py +++ b/tests/verbs/test_create_final_entities.py @@ -12,7 +12,6 @@ get_workflow_output, load_expected, load_input_tables, - remove_disabled_steps, ) @@ -27,7 +26,7 @@ async def test_create_final_entities(): config["skip_name_embedding"] = True config["skip_description_embedding"] = True - steps = remove_disabled_steps(build_steps(config)) + steps = build_steps(config) actual = await get_workflow_output( input_tables, @@ -65,7 +64,7 @@ async def test_create_final_entities_with_name_embeddings(): config["skip_description_embedding"] = True config["entity_name_embed"]["strategy"]["type"] = "mock" - steps = remove_disabled_steps(build_steps(config)) + steps = build_steps(config) actual = await get_workflow_output( input_tables, @@ -92,7 +91,7 @@ async def test_create_final_entities_with_description_embeddings(): config["skip_description_embedding"] = False config["entity_name_description_embed"]["strategy"]["type"] = "mock" - steps = remove_disabled_steps(build_steps(config)) + steps = build_steps(config) actual = await get_workflow_output( input_tables, @@ -119,7 +118,7 @@ async def test_create_final_entities_with_name_and_description_embeddings(): config["entity_name_description_embed"]["strategy"]["type"] = "mock" config["entity_name_embed"]["strategy"]["type"] = "mock" - steps = remove_disabled_steps(build_steps(config)) + steps = build_steps(config) actual = await get_workflow_output( input_tables, diff --git a/tests/verbs/test_create_final_nodes.py b/tests/verbs/test_create_final_nodes.py index 7f0e20428a..cd8d4dd99a 100644 --- a/tests/verbs/test_create_final_nodes.py +++ b/tests/verbs/test_create_final_nodes.py @@ -1,6 +1,7 @@ # Copyright (c) 2024 Microsoft Corporation. 
# Licensed under the MIT License +from graphrag.index.storage.memory_pipeline_storage import MemoryPipelineStorage from graphrag.index.workflows.v1.create_final_nodes import ( build_steps, workflow_name, @@ -12,7 +13,6 @@ get_workflow_output, load_expected, load_input_tables, - remove_disabled_steps, ) @@ -22,19 +22,56 @@ async def test_create_final_nodes(): ]) expected = load_expected(workflow_name) + storage = MemoryPipelineStorage() + config = get_config_for_workflow(workflow_name) # default config turns UMAP off, which translates into false for layout # we don't have graph embeddings in the test data, so this will fail if True config["layout_graph_enabled"] = False - steps = remove_disabled_steps(build_steps(config)) + steps = build_steps(config) actual = await get_workflow_output( input_tables, { "steps": steps, }, + storage=storage, ) compare_outputs(actual, expected) + + assert len(storage.keys()) == 0, "Storage should be empty" + + +async def test_create_final_nodes_with_snapshot(): + input_tables = load_input_tables([ + "workflow:create_base_entity_graph", + ]) + expected = load_expected(workflow_name) + + storage = MemoryPipelineStorage() + + config = get_config_for_workflow(workflow_name) + + # default config turns UMAP off, which translates into false for layout + # we don't have graph embeddings in the test data, so this will fail if True + config["layout_graph_enabled"] = False + config["snapshot_top_level_nodes"] = True + + steps = build_steps(config) + + actual = await get_workflow_output( + input_tables, + { + "steps": steps, + }, + storage=storage, + ) + + assert actual.shape == expected.shape, "Graph dataframe shapes differ" + + assert storage.keys() == [ + "top_level_nodes.json", + ], "Graph snapshot keys differ" diff --git a/tests/verbs/test_create_final_relationships.py b/tests/verbs/test_create_final_relationships.py index 53a3755ca5..d3ac5c807a 100644 --- a/tests/verbs/test_create_final_relationships.py +++ b/tests/verbs/test_create_final_relationships.py @@ -12,7 +12,6 @@ get_workflow_output, load_expected, load_input_tables, - remove_disabled_steps, ) @@ -27,7 +26,7 @@ async def test_create_final_relationships(): config["skip_description_embedding"] = True - steps = remove_disabled_steps(build_steps(config)) + steps = build_steps(config) actual = await get_workflow_output( input_tables, @@ -53,7 +52,7 @@ async def test_create_final_relationships_with_embeddings(): # just override the strategy to mock so the rest of the required parameters are in place config["relationship_description_embed"]["strategy"]["type"] = "mock" - steps = remove_disabled_steps(build_steps(config)) + steps = build_steps(config) actual = await get_workflow_output( input_tables, diff --git a/tests/verbs/test_create_final_text_units.py b/tests/verbs/test_create_final_text_units.py index 161e3c1732..4a3bee0a94 100644 --- a/tests/verbs/test_create_final_text_units.py +++ b/tests/verbs/test_create_final_text_units.py @@ -12,7 +12,6 @@ get_workflow_output, load_expected, load_input_tables, - remove_disabled_steps, ) @@ -30,7 +29,7 @@ async def test_create_final_text_units(): config["covariates_enabled"] = True config["skip_text_unit_embedding"] = True - steps = remove_disabled_steps(build_steps(config)) + steps = build_steps(config) actual = await get_workflow_output( input_tables, @@ -56,7 +55,7 @@ async def test_create_final_text_units_no_covariates(): config["covariates_enabled"] = False config["skip_text_unit_embedding"] = True - steps = remove_disabled_steps(build_steps(config)) + steps = 
build_steps(config) actual = await get_workflow_output( input_tables, @@ -90,7 +89,7 @@ async def test_create_final_text_units_with_embeddings(): # just override the strategy to mock so the rest of the required parameters are in place config["text_unit_text_embed"]["strategy"]["type"] = "mock" - steps = remove_disabled_steps(build_steps(config)) + steps = build_steps(config) actual = await get_workflow_output( input_tables, diff --git a/tests/verbs/test_create_summarized_entities.py b/tests/verbs/test_create_summarized_entities.py index 51fe57e2ee..7d9ac9b55b 100644 --- a/tests/verbs/test_create_summarized_entities.py +++ b/tests/verbs/test_create_summarized_entities.py @@ -14,7 +14,6 @@ get_workflow_output, load_expected, load_input_tables, - remove_disabled_steps, ) @@ -24,17 +23,20 @@ async def test_create_summarized_entities(): ]) expected = load_expected(workflow_name) + storage = MemoryPipelineStorage() + config = get_config_for_workflow(workflow_name) del config["summarize_descriptions"]["strategy"]["llm"] - steps = remove_disabled_steps(build_steps(config)) + steps = build_steps(config) actual = await get_workflow_output( input_tables, { "steps": steps, }, + storage=storage, ) # the serialization of the graph may differ so we can't assert the dataframes directly @@ -61,6 +63,8 @@ async def test_create_summarized_entities(): == "This is a MOCK response for the LLM. It is summarized!" ) + assert len(storage.keys()) == 0, "Storage should be empty" + async def test_create_summarized_entities_with_snapshots(): input_tables = load_input_tables([ @@ -75,7 +79,7 @@ async def test_create_summarized_entities_with_snapshots(): del config["summarize_descriptions"]["strategy"]["llm"] config["graphml_snapshot"] = True - steps = remove_disabled_steps(build_steps(config)) + steps = build_steps(config) actual = await get_workflow_output( input_tables, diff --git a/tests/verbs/util.py b/tests/verbs/util.py index 68dec6ba50..4dff53996a 100644 --- a/tests/verbs/util.py +++ b/tests/verbs/util.py @@ -10,7 +10,6 @@ from graphrag.config import create_graphrag_config from graphrag.index import ( PipelineWorkflowConfig, - PipelineWorkflowStep, create_pipeline_config, ) from graphrag.index.run.utils import _create_run_context @@ -87,7 +86,7 @@ def compare_outputs( assert len(actual) == len( expected - ), f"Expected: {len(expected)}, Actual: {len(actual)}" + ), f"Expected: {len(expected)} rows, Actual: {len(actual)} rows" for column in cols: assert column in actual.columns @@ -102,9 +101,3 @@ def compare_outputs( print("Actual:") print(actual[column]) raise - - -def remove_disabled_steps( - steps: list[PipelineWorkflowStep], -) -> list[PipelineWorkflowStep]: - return [step for step in steps if step.get("enabled", True)]
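
Every subflow touched by this patch follows the same shape: the DataShaper verb becomes a thin adapter that unwraps its input table, delegates to the corresponding pure-pandas flow function under graphrag/index/flows/, and wraps the returned DataFrame back into a VerbResult. A minimal, self-contained sketch of that shape follows; the function and column names are hypothetical stand-ins for illustration, not the actual graphrag signatures.

import pandas as pd

# Hypothetical names throughout -- an illustrative sketch of the verb/flow
# split used in this patch, not actual graphrag code.

def create_example_units(source: pd.DataFrame, text_column: str = "text") -> pd.DataFrame:
    """Pure-pandas 'flow': all transformation logic lives here, no DataShaper types."""
    output = source.copy()
    # count whitespace-separated tokens per row as a stand-in transformation
    output["n_tokens"] = output[text_column].str.split().str.len()
    return output


def create_example_units_verb(source: pd.DataFrame, text_column: str = "text") -> pd.DataFrame:
    """Thin 'subflow' wrapper: unwrap the input, delegate to the flow, wrap the result.

    In graphrag the unwrap is input.get_input() and the wrap is create_verb_result();
    plain DataFrames are used here to keep the sketch self-contained.
    """
    return create_example_units(source, text_column)


if __name__ == "__main__":
    df = pd.DataFrame({"id": [1, 2], "text": ["hello world", "graph rag index"]})
    print(create_example_units_verb(df))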