From ce71bcf7fbe9811058f6bbc1eb725c4a1d960e7e Mon Sep 17 00:00:00 2001 From: Nathan Evans Date: Wed, 25 Sep 2024 17:35:44 -0700 Subject: [PATCH] Collapse create final entities (#1220) * Collapse create_final_entities * Update smoke tests * Semver * Remove prints * Update embedding assertions --- .../patch-20240925235837094282.json | 4 + .../workflows/v1/create_final_entities.py | 104 +------------- .../index/workflows/v1/subflows/__init__.py | 2 + .../v1/subflows/create_final_entities.py | 105 ++++++++++++++ tests/fixtures/min-csv/config.json | 2 +- tests/fixtures/text/config.json | 2 +- tests/verbs/test_create_final_documents.py | 2 +- tests/verbs/test_create_final_entities.py | 133 ++++++++++++++++++ .../verbs/test_create_final_relationships.py | 2 +- tests/verbs/test_create_final_text_units.py | 2 +- tests/verbs/util.py | 6 +- 11 files changed, 258 insertions(+), 106 deletions(-) create mode 100644 .semversioner/next-release/patch-20240925235837094282.json create mode 100644 graphrag/index/workflows/v1/subflows/create_final_entities.py create mode 100644 tests/verbs/test_create_final_entities.py diff --git a/.semversioner/next-release/patch-20240925235837094282.json b/.semversioner/next-release/patch-20240925235837094282.json new file mode 100644 index 0000000000..885095516f --- /dev/null +++ b/.semversioner/next-release/patch-20240925235837094282.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "Collapse create-final-entities." +} diff --git a/graphrag/index/workflows/v1/create_final_entities.py b/graphrag/index/workflows/v1/create_final_entities.py index 9d8b962b77..20ccef44a5 100644 --- a/graphrag/index/workflows/v1/create_final_entities.py +++ b/graphrag/index/workflows/v1/create_final_entities.py @@ -24,110 +24,16 @@ def build_steps( ) skip_name_embedding = config.get("skip_name_embedding", False) skip_description_embedding = config.get("skip_description_embedding", False) - is_using_vector_store = ( - entity_name_embed_config.get("strategy", {}).get("vector_store", None) - is not None - ) return [ { - "verb": "unpack_graph", + "verb": "create_final_entities", "args": { - "column": "clustered_graph", - "type": "nodes", + "skip_name_embedding": skip_name_embedding, + "skip_description_embedding": skip_description_embedding, + "name_text_embed": entity_name_embed_config, + "description_text_embed": entity_name_description_embed_config, }, "input": {"source": "workflow:create_base_entity_graph"}, }, - {"verb": "rename", "args": {"columns": {"label": "title"}}}, - { - "verb": "select", - "args": { - "columns": [ - "id", - "title", - "type", - "description", - "human_readable_id", - "graph_embedding", - "source_id", - ], - }, - }, - { - # create_base_entity_graph has multiple levels of clustering, which means there are multiple graphs with the same entities - # this dedupes the entities so that there is only one of each entity - "verb": "dedupe", - "args": {"columns": ["id"]}, - }, - {"verb": "rename", "args": {"columns": {"title": "name"}}}, - { - # ELIMINATE EMPTY NAMES - "verb": "filter", - "args": { - "column": "name", - "criteria": [ - { - "type": "value", - "operator": "is not empty", - } - ], - }, - }, - { - "verb": "text_split", - "args": {"separator": ",", "column": "source_id", "to": "text_unit_ids"}, - }, - {"verb": "drop", "args": {"columns": ["source_id"]}}, - { - "verb": "text_embed", - "enabled": not skip_name_embedding, - "args": { - "embedding_name": "entity_name", - "column": "name", - "to": "name_embedding", - **entity_name_embed_config, - }, - }, - { - "verb": "merge", - "enabled": not skip_description_embedding, - "args": { - "strategy": "concat", - "columns": ["name", "description"], - "to": "name_description", - "delimiter": ":", - "preserveSource": True, - }, - }, - { - "verb": "text_embed", - "enabled": not skip_description_embedding, - "args": { - "embedding_name": "entity_name_description", - "column": "name_description", - "to": "description_embedding", - **entity_name_description_embed_config, - }, - }, - { - "verb": "drop", - "enabled": not skip_description_embedding, - "args": { - "columns": ["name_description"], - }, - }, - { - # ELIMINATE EMPTY DESCRIPTION EMBEDDINGS - "verb": "filter", - "enabled": not skip_description_embedding and not is_using_vector_store, - "args": { - "column": "description_embedding", - "criteria": [ - { - "type": "value", - "operator": "is not empty", - } - ], - }, - }, ] diff --git a/graphrag/index/workflows/v1/subflows/__init__.py b/graphrag/index/workflows/v1/subflows/__init__.py index d8d201f389..86406de1d9 100644 --- a/graphrag/index/workflows/v1/subflows/__init__.py +++ b/graphrag/index/workflows/v1/subflows/__init__.py @@ -8,6 +8,7 @@ from .create_final_communities import create_final_communities from .create_final_covariates import create_final_covariates from .create_final_documents import create_final_documents +from .create_final_entities import create_final_entities from .create_final_nodes import create_final_nodes from .create_final_relationships import ( create_final_relationships, @@ -20,6 +21,7 @@ "create_final_communities", "create_final_covariates", "create_final_documents", + "create_final_entities", "create_final_nodes", "create_final_relationships", "create_final_text_units", diff --git a/graphrag/index/workflows/v1/subflows/create_final_entities.py b/graphrag/index/workflows/v1/subflows/create_final_entities.py new file mode 100644 index 0000000000..4b3489a154 --- /dev/null +++ b/graphrag/index/workflows/v1/subflows/create_final_entities.py @@ -0,0 +1,105 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""All the steps to transform final entities.""" + +from typing import cast + +import pandas as pd +from datashaper import ( + Table, + VerbCallbacks, + VerbInput, + verb, +) +from datashaper.table_store.types import VerbResult, create_verb_result + +from graphrag.index.cache import PipelineCache +from graphrag.index.verbs.graph.unpack import unpack_graph_df +from graphrag.index.verbs.text.embed.text_embed import text_embed_df +from graphrag.index.verbs.text.split import text_split_df + + +@verb( + name="create_final_entities", + treats_input_tables_as_immutable=True, +) +async def create_final_entities( + input: VerbInput, + callbacks: VerbCallbacks, + cache: PipelineCache, + name_text_embed: dict, + description_text_embed: dict, + skip_name_embedding: bool = False, + skip_description_embedding: bool = False, + **_kwargs: dict, +) -> VerbResult: + """All the steps to transform final entities.""" + table = cast(pd.DataFrame, input.get_input()) + + nodes = unpack_graph_df(table, callbacks, "clustered_graph", "nodes") + nodes.rename(columns={"label": "name"}, inplace=True) + + nodes = cast( + pd.DataFrame, + nodes[ + [ + "id", + "name", + "type", + "description", + "human_readable_id", + "graph_embedding", + "source_id", + ] + ], + ) + + # create_base_entity_graph has multiple levels of clustering, which means there are multiple graphs with the same entities + # this dedupes the entities so that there is only one of each entity + nodes.drop_duplicates(subset="id", inplace=True) + + # eliminate empty names + filtered = cast(pd.DataFrame, nodes[nodes["name"].notna()].reset_index(drop=True)) + + with_ids = text_split_df( + filtered, column="source_id", separator=",", to="text_unit_ids" + ) + with_ids.drop(columns=["source_id"], inplace=True) + + embedded = with_ids + + if not skip_name_embedding: + embedded = await text_embed_df( + embedded, + callbacks, + cache, + column="name", + strategy=name_text_embed["strategy"], + to="name_embedding", + embedding_name="entity_name", + ) + + if not skip_description_embedding: + # description embedding is a concat of the name + description, so we'll create a temporary column + embedded["name_description"] = embedded["name"] + ":" + embedded["description"] + embedded = await text_embed_df( + embedded, + callbacks, + cache, + column="name_description", + strategy=description_text_embed["strategy"], + to="description_embedding", + embedding_name="entity_name_description", + ) + embedded.drop(columns=["name_description"], inplace=True) + is_using_vector_store = ( + description_text_embed.get("strategy", {}).get("vector_store", None) + is not None + ) + if not is_using_vector_store: + embedded = embedded[embedded["description_embedding"].notna()].reset_index( + drop=True + ) + + return create_verb_result(cast(Table, embedded)) diff --git a/tests/fixtures/min-csv/config.json b/tests/fixtures/min-csv/config.json index 684062773e..09d6e418ff 100644 --- a/tests/fixtures/min-csv/config.json +++ b/tests/fixtures/min-csv/config.json @@ -44,7 +44,7 @@ "description", "graph_embedding" ], - "subworkflows": 11, + "subworkflows": 1, "max_runtime": 300 }, "create_final_relationships": { diff --git a/tests/fixtures/text/config.json b/tests/fixtures/text/config.json index 02f803dea2..204937f9d4 100644 --- a/tests/fixtures/text/config.json +++ b/tests/fixtures/text/config.json @@ -61,7 +61,7 @@ "description", "graph_embedding" ], - "subworkflows": 11, + "subworkflows": 1, "max_runtime": 300 }, "create_final_relationships": { diff --git a/tests/verbs/test_create_final_documents.py b/tests/verbs/test_create_final_documents.py index 7092b96312..e70a578a98 100644 --- a/tests/verbs/test_create_final_documents.py +++ b/tests/verbs/test_create_final_documents.py @@ -63,4 +63,4 @@ async def test_create_final_documents_with_embeddings(): assert "raw_content_embedding" in actual.columns assert len(actual.columns) == len(expected.columns) + 1 # the mock impl returns an array of 3 floats for each embedding - assert len(actual["raw_content_embedding"][0]) == 3 + assert len(actual["raw_content_embedding"][:1][0]) == 3 diff --git a/tests/verbs/test_create_final_entities.py b/tests/verbs/test_create_final_entities.py new file mode 100644 index 0000000000..b32757aa77 --- /dev/null +++ b/tests/verbs/test_create_final_entities.py @@ -0,0 +1,133 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +from graphrag.index.workflows.v1.create_final_entities import ( + build_steps, + workflow_name, +) + +from .util import ( + compare_outputs, + get_config_for_workflow, + get_workflow_output, + load_expected, + load_input_tables, + remove_disabled_steps, +) + + +async def test_create_final_entities(): + input_tables = load_input_tables([ + "workflow:create_base_entity_graph", + ]) + expected = load_expected(workflow_name) + + config = get_config_for_workflow(workflow_name) + + config["skip_name_embedding"] = True + config["skip_description_embedding"] = True + + steps = remove_disabled_steps(build_steps(config)) + + actual = await get_workflow_output( + input_tables, + { + "steps": steps, + }, + ) + + # ignore the description_embedding column, which is included in the expected output due to default config + compare_outputs( + actual, + expected, + columns=[ + "id", + "name", + "type", + "description", + "human_readable_id", + "graph_embedding", + "text_unit_ids", + ], + ) + assert len(actual.columns) == len(expected.columns) - 1 + + +async def test_create_final_entities_with_name_embeddings(): + input_tables = load_input_tables([ + "workflow:create_base_entity_graph", + ]) + expected = load_expected(workflow_name) + + config = get_config_for_workflow(workflow_name) + + config["skip_name_embedding"] = False + config["skip_description_embedding"] = True + config["entity_name_embed"]["strategy"]["type"] = "mock" + + steps = remove_disabled_steps(build_steps(config)) + + actual = await get_workflow_output( + input_tables, + { + "steps": steps, + }, + ) + + assert "name_embedding" in actual.columns + assert len(actual.columns) == len(expected.columns) + # the mock impl returns an array of 3 floats for each embedding + assert len(actual["name_embedding"][:1][0]) == 3 + + +async def test_create_final_entities_with_description_embeddings(): + input_tables = load_input_tables([ + "workflow:create_base_entity_graph", + ]) + expected = load_expected(workflow_name) + + config = get_config_for_workflow(workflow_name) + + config["skip_name_embedding"] = True + config["skip_description_embedding"] = False + config["entity_name_description_embed"]["strategy"]["type"] = "mock" + + steps = remove_disabled_steps(build_steps(config)) + + actual = await get_workflow_output( + input_tables, + { + "steps": steps, + }, + ) + + assert "description_embedding" in actual.columns + assert len(actual.columns) == len(expected.columns) + assert len(actual["description_embedding"][:1][0]) == 3 + + +async def test_create_final_entities_with_name_and_description_embeddings(): + input_tables = load_input_tables([ + "workflow:create_base_entity_graph", + ]) + expected = load_expected(workflow_name) + + config = get_config_for_workflow(workflow_name) + + config["skip_name_embedding"] = False + config["skip_description_embedding"] = False + config["entity_name_description_embed"]["strategy"]["type"] = "mock" + config["entity_name_embed"]["strategy"]["type"] = "mock" + + steps = remove_disabled_steps(build_steps(config)) + + actual = await get_workflow_output( + input_tables, + { + "steps": steps, + }, + ) + + assert "description_embedding" in actual.columns + assert len(actual.columns) == len(expected.columns) + 1 + assert len(actual["description_embedding"][:1][0]) == 3 diff --git a/tests/verbs/test_create_final_relationships.py b/tests/verbs/test_create_final_relationships.py index 87282cb8e9..53a3755ca5 100644 --- a/tests/verbs/test_create_final_relationships.py +++ b/tests/verbs/test_create_final_relationships.py @@ -65,4 +65,4 @@ async def test_create_final_relationships_with_embeddings(): assert "description_embedding" in actual.columns assert len(actual.columns) == len(expected.columns) + 1 # the mock impl returns an array of 3 floats for each embedding - assert len(actual["description_embedding"][0]) == 3 + assert len(actual["description_embedding"][:1][0]) == 3 diff --git a/tests/verbs/test_create_final_text_units.py b/tests/verbs/test_create_final_text_units.py index a64c9ff5a5..161e3c1732 100644 --- a/tests/verbs/test_create_final_text_units.py +++ b/tests/verbs/test_create_final_text_units.py @@ -102,4 +102,4 @@ async def test_create_final_text_units_with_embeddings(): assert "text_embedding" in actual.columns assert len(actual.columns) == len(expected.columns) + 1 # the mock impl returns an array of 3 floats for each embedding - assert len(actual["text_embedding"][0]) == 3 + assert len(actual["text_embedding"][:1][0]) == 3 diff --git a/tests/verbs/util.py b/tests/verbs/util.py index 884199284a..7dd948f8bf 100644 --- a/tests/verbs/util.py +++ b/tests/verbs/util.py @@ -88,12 +88,14 @@ def compare_outputs( assert column in actual.columns try: # dtypes can differ since the test data is read from parquet and our workflow runs in memory - assert_series_equal(actual[column], expected[column], check_dtype=False) + assert_series_equal( + actual[column], expected[column], check_dtype=False, check_index=False + ) except AssertionError: print("Expected:") print(expected[column]) print("Actual:") - print(actual[columns]) + print(actual[column]) raise