diff --git a/.semversioner/next-release/patch-20240923202146450500.json b/.semversioner/next-release/patch-20240923202146450500.json new file mode 100644 index 0000000000..2a1f5a3fc2 --- /dev/null +++ b/.semversioner/next-release/patch-20240923202146450500.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "Merge text_embed into create-final-relationships subflow." +} diff --git a/graphrag/index/verbs/text/embed/text_embed.py b/graphrag/index/verbs/text/embed/text_embed.py index 76ac97d76f..d2aa1e8f35 100644 --- a/graphrag/index/verbs/text/embed/text_embed.py +++ b/graphrag/index/verbs/text/embed/text_embed.py @@ -79,6 +79,23 @@ async def text_embed( <...> ``` """ + input_df = cast(pd.DataFrame, input.get_input()) + result_df = await text_embed_df( + input_df, callbacks, cache, column, strategy, **kwargs + ) + return TableContainer(table=result_df) + + +# TODO: this ultimately just creates a new column, so our embed function could just generate a series instead of updating the dataframe +async def text_embed_df( + input: pd.DataFrame, + callbacks: VerbCallbacks, + cache: PipelineCache, + column: str, + strategy: dict, + **kwargs, +): + """Embed a piece of text into a vector space.""" vector_store_config = strategy.get("vector_store") if vector_store_config: @@ -113,28 +130,28 @@ async def text_embed( async def _text_embed_in_memory( - input: VerbInput, + input: pd.DataFrame, callbacks: VerbCallbacks, cache: PipelineCache, column: str, strategy: dict, to: str, ): - output_df = cast(pd.DataFrame, input.get_input()) + output_df = input strategy_type = strategy["type"] strategy_exec = load_strategy(strategy_type) strategy_args = {**strategy} - input_table = input.get_input() + input_table = input texts: list[str] = input_table[column].to_numpy().tolist() result = await strategy_exec(texts, callbacks, cache, strategy_args) output_df[to] = result.embeddings - return TableContainer(table=output_df) + return output_df async def _text_embed_with_vector_store( - input: VerbInput, + input: pd.DataFrame, callbacks: VerbCallbacks, cache: PipelineCache, column: str, @@ -144,7 +161,7 @@ async def _text_embed_with_vector_store( store_in_table: bool = False, to: str = "", ): - output_df = cast(pd.DataFrame, input.get_input()) + output_df = input strategy_type = strategy["type"] strategy_exec = load_strategy(strategy_type) strategy_args = {**strategy} @@ -179,10 +196,8 @@ async def _text_embed_with_vector_store( all_results = [] - while insert_batch_size * i < input.get_input().shape[0]: - batch = input.get_input().iloc[ - insert_batch_size * i : insert_batch_size * (i + 1) - ] + while insert_batch_size * i < input.shape[0]: + batch = input.iloc[insert_batch_size * i : insert_batch_size * (i + 1)] texts: list[str] = batch[column].to_numpy().tolist() titles: list[str] = batch[title_column].to_numpy().tolist() ids: list[str] = batch[id_column].to_numpy().tolist() @@ -218,7 +233,7 @@ async def _text_embed_with_vector_store( if store_in_table: output_df[to] = all_results - return TableContainer(table=output_df) + return output_df def _create_vector_store( diff --git a/graphrag/index/workflows/v1/create_final_relationships.py b/graphrag/index/workflows/v1/create_final_relationships.py index 8e6396b253..c2f8f25894 100644 --- a/graphrag/index/workflows/v1/create_final_relationships.py +++ b/graphrag/index/workflows/v1/create_final_relationships.py @@ -23,30 +23,15 @@ def build_steps( "relationship_description_embed", base_text_embed ) skip_description_embedding = config.get("skip_description_embedding", False) - return [ { - "id": "pre_embedding", - "verb": "create_final_relationships_pre_embedding", - "input": {"source": "workflow:create_base_entity_graph"}, - }, - { - "id": "description_embedding", - "verb": "text_embed", - "enabled": not skip_description_embedding, + "verb": "create_final_relationships", "args": { - "embedding_name": "relationship_description", - "column": "description", - "to": "description_embedding", - **relationship_description_embed_config, + "skip_embedding": skip_description_embedding, + "text_embed": relationship_description_embed_config, }, - }, - { - "verb": "create_final_relationships_post_embedding", "input": { - "source": "pre_embedding" - if skip_description_embedding - else "description_embedding", + "source": "workflow:create_base_entity_graph", "nodes": "workflow:create_final_nodes", }, }, diff --git a/graphrag/index/workflows/v1/subflows/__init__.py b/graphrag/index/workflows/v1/subflows/__init__.py index c72f00ea23..ead84cfdbc 100644 --- a/graphrag/index/workflows/v1/subflows/__init__.py +++ b/graphrag/index/workflows/v1/subflows/__init__.py @@ -7,11 +7,8 @@ from .create_base_text_units import create_base_text_units from .create_final_communities import create_final_communities from .create_final_nodes import create_final_nodes -from .create_final_relationships_post_embedding import ( - create_final_relationships_post_embedding, -) -from .create_final_relationships_pre_embedding import ( - create_final_relationships_pre_embedding, +from .create_final_relationships import ( + create_final_relationships, ) from .create_final_text_units_pre_embedding import create_final_text_units_pre_embedding @@ -20,7 +17,6 @@ "create_base_text_units", "create_final_communities", "create_final_nodes", - "create_final_relationships_post_embedding", - "create_final_relationships_pre_embedding", + "create_final_relationships", "create_final_text_units_pre_embedding", ] diff --git a/graphrag/index/workflows/v1/subflows/create_final_relationships_post_embedding.py b/graphrag/index/workflows/v1/subflows/create_final_relationships.py similarity index 59% rename from graphrag/index/workflows/v1/subflows/create_final_relationships_post_embedding.py rename to graphrag/index/workflows/v1/subflows/create_final_relationships.py index 0e29e70104..02fc7948f3 100644 --- a/graphrag/index/workflows/v1/subflows/create_final_relationships_post_embedding.py +++ b/graphrag/index/workflows/v1/subflows/create_final_relationships.py @@ -1,37 +1,64 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -"""All the steps to transform final relationships after they are embedded.""" +"""All the steps to transform final relationships before they are embedded.""" from typing import Any, cast import pandas as pd from datashaper import ( Table, + VerbCallbacks, VerbInput, verb, ) from datashaper.table_store.types import VerbResult, create_verb_result +from graphrag.index.cache import PipelineCache from graphrag.index.utils.ds_util import get_required_input_table from graphrag.index.verbs.graph.compute_edge_combined_degree import ( compute_edge_combined_degree_df, ) +from graphrag.index.verbs.graph.unpack import unpack_graph_df +from graphrag.index.verbs.text.embed.text_embed import text_embed_df @verb( - name="create_final_relationships_post_embedding", + name="create_final_relationships", treats_input_tables_as_immutable=True, ) -def create_final_relationships_post_embedding( +async def create_final_relationships( input: VerbInput, + callbacks: VerbCallbacks, + cache: PipelineCache, + text_embed: dict, + skip_embedding: bool = False, **_kwargs: dict, ) -> VerbResult: - """All the steps to transform final relationships after they are embedded.""" + """All the steps to transform final relationships before they are embedded.""" table = cast(pd.DataFrame, input.get_input()) nodes = cast(pd.DataFrame, get_required_input_table(input, "nodes").table) - pruned_edges = table.drop(columns=["level"]) + graph_edges = unpack_graph_df(table, callbacks, "clustered_graph", "edges") + + graph_edges.rename(columns={"source_id": "text_unit_ids"}, inplace=True) + + filtered = cast( + pd.DataFrame, graph_edges[graph_edges["level"] == 0].reset_index(drop=True) + ) + + if not skip_embedding: + filtered = await text_embed_df( + filtered, + callbacks, + cache, + column="description", + strategy=text_embed["strategy"], + to="description_embedding", + embedding_name="relationship_description", + ) + + pruned_edges = filtered.drop(columns=["level"]) filtered_nodes = cast( pd.DataFrame, diff --git a/graphrag/index/workflows/v1/subflows/create_final_relationships_pre_embedding.py b/graphrag/index/workflows/v1/subflows/create_final_relationships_pre_embedding.py deleted file mode 100644 index bcc0f762b8..0000000000 --- a/graphrag/index/workflows/v1/subflows/create_final_relationships_pre_embedding.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""All the steps to transform final relationships before they are embedded.""" - -from typing import cast - -import pandas as pd -from datashaper import ( - Table, - VerbCallbacks, - VerbInput, - verb, -) -from datashaper.table_store.types import VerbResult, create_verb_result - -from graphrag.index.verbs.graph.unpack import unpack_graph_df - - -@verb( - name="create_final_relationships_pre_embedding", - treats_input_tables_as_immutable=True, -) -def create_final_relationships_pre_embedding( - input: VerbInput, - callbacks: VerbCallbacks, - **_kwargs: dict, -) -> VerbResult: - """All the steps to transform final relationships before they are embedded.""" - table = cast(pd.DataFrame, input.get_input()) - - graph_edges = unpack_graph_df(table, callbacks, "clustered_graph", "edges") - - graph_edges.rename(columns={"source_id": "text_unit_ids"}, inplace=True) - - filtered = graph_edges[graph_edges["level"] == 0].reset_index(drop=True) - - return create_verb_result(cast(Table, filtered)) diff --git a/tests/fixtures/min-csv/config.json b/tests/fixtures/min-csv/config.json index 378217b7f9..635bf9e5cd 100644 --- a/tests/fixtures/min-csv/config.json +++ b/tests/fixtures/min-csv/config.json @@ -52,7 +52,7 @@ 1, 2000 ], - "subworkflows": 2, + "subworkflows": 1, "max_runtime": 100 }, "create_final_nodes": { diff --git a/tests/fixtures/text/config.json b/tests/fixtures/text/config.json index 0987d642ce..cd66fde7eb 100644 --- a/tests/fixtures/text/config.json +++ b/tests/fixtures/text/config.json @@ -71,7 +71,7 @@ 1, 2000 ], - "subworkflows": 2, + "subworkflows": 1, "max_runtime": 100 }, "create_final_nodes": { diff --git a/tests/verbs/test_create_final_relationships.py b/tests/verbs/test_create_final_relationships.py index 6e86836874..87282cb8e9 100644 --- a/tests/verbs/test_create_final_relationships.py +++ b/tests/verbs/test_create_final_relationships.py @@ -37,3 +37,32 @@ async def test_create_final_relationships(): ) compare_outputs(actual, expected) + + +async def test_create_final_relationships_with_embeddings(): + input_tables = load_input_tables([ + "workflow:create_base_entity_graph", + "workflow:create_final_nodes", + ]) + expected = load_expected(workflow_name) + + config = get_config_for_workflow(workflow_name) + + config["skip_description_embedding"] = False + # default config has a detailed standard embed config + # just override the strategy to mock so the rest of the required parameters are in place + config["relationship_description_embed"]["strategy"]["type"] = "mock" + + steps = remove_disabled_steps(build_steps(config)) + + actual = await get_workflow_output( + input_tables, + { + "steps": steps, + }, + ) + + assert "description_embedding" in actual.columns + assert len(actual.columns) == len(expected.columns) + 1 + # the mock impl returns an array of 3 floats for each embedding + assert len(actual["description_embedding"][0]) == 3 diff --git a/tests/verbs/util.py b/tests/verbs/util.py index df2136e829..80779d31de 100644 --- a/tests/verbs/util.py +++ b/tests/verbs/util.py @@ -13,6 +13,7 @@ PipelineWorkflowStep, create_pipeline_config, ) +from graphrag.index.run.utils import _create_run_context def load_input_tables(inputs: list[str]) -> dict[str, pd.DataFrame]: @@ -61,7 +62,9 @@ async def get_workflow_output( input_tables=input_tables, ) - await workflow.run() + context = _create_run_context(None, None, None) + + await workflow.run(context=context) # if there's only one output, it is the default here, no name required return cast(pd.DataFrame, workflow.output())