diff --git a/graphrag/index/flows/generate_text_embeddings.py b/graphrag/index/flows/generate_text_embeddings.py index 13982eb61b..6118864dcb 100644 --- a/graphrag/index/flows/generate_text_embeddings.py +++ b/graphrag/index/flows/generate_text_embeddings.py @@ -27,6 +27,7 @@ log = logging.getLogger(__name__) + async def generate_text_embeddings( final_documents: pd.DataFrame, final_relationships: pd.DataFrame, @@ -56,45 +57,44 @@ async def generate_text_embeddings( embedding_param_map = { document_raw_content_embedding: { "data": documents_embeddings, - "column_to_embed":"raw_content", + "column_to_embed": "raw_content", "filename": "create_final_documents_raw_content_embeddings", }, - relationship_description_embedding:{ + relationship_description_embedding: { "data": relationships_embeddings, - "column_to_embed":"description", + "column_to_embed": "description", "filename": "create_final_relationships_description_embeddings", }, - text_unit_text_embedding:{ + text_unit_text_embedding: { "data": text_units_embeddings, - "column_to_embed":"text", + "column_to_embed": "text", "filename": "create_final_text_units_text_embeddings", }, - entity_name_embedding:{ + entity_name_embedding: { "data": entities_embeddings, - "column_to_embed":"name", + "column_to_embed": "name", "filename": "create_final_entities_name_embeddings", }, - entity_description_embedding:{ + entity_description_embedding: { "data": entities_embeddings, - "column_to_embed":"name_description", + "column_to_embed": "name_description", "filename": "create_final_entities_description_embeddings", }, - community_title_embedding:{ + community_title_embedding: { "data": community_reports_embeddings, - "column_to_embed":"title", + "column_to_embed": "title", "filename": "create_final_community_reports_title_embeddings", }, - community_summary_embedding:{ + community_summary_embedding: { "data": community_reports_embeddings, - "column_to_embed":"summary", + "column_to_embed": "summary", "filename": "create_final_community_reports_summary_embeddings", }, - community_full_content_embedding:{ + community_full_content_embedding: { "data": community_reports_embeddings, - "column_to_embed":"full_content", + "column_to_embed": "full_content", "filename": "create_final_community_reports_full_content_embeddings", }, - } if base_text_embed: @@ -105,7 +105,7 @@ async def generate_text_embeddings( cache=cache, storage=storage, base_text_embed=base_text_embed, - **embedding_param_map[field] + **embedding_param_map[field], ) return pd.DataFrame() @@ -127,9 +127,9 @@ async def _run_and_snapshot_embeddings( "collection_name": filename, "store_in_table": True, } - + base_text_embed["strategy"]["vector_store"].update(new_vector_store) - + data["embedding"] = await embed_text( data, callbacks, diff --git a/graphrag/index/workflows/v1/subflows/generate_text_embeddings.py b/graphrag/index/workflows/v1/subflows/generate_text_embeddings.py index 9b8f13e227..7e77f84cb8 100644 --- a/graphrag/index/workflows/v1/subflows/generate_text_embeddings.py +++ b/graphrag/index/workflows/v1/subflows/generate_text_embeddings.py @@ -33,7 +33,7 @@ async def generate_text_embeddings( cache: PipelineCache, storage: PipelineStorage, embedded_fields: set[str], - base_text_embed: dict | None = None, + base_text_embed: dict, **_kwargs: dict, ) -> VerbResult: """All the steps to generate embeddings.""" diff --git a/graphrag/vector_stores/lancedb.py b/graphrag/vector_stores/lancedb.py index 9d8b24af6b..74b153c0f4 100644 --- a/graphrag/vector_stores/lancedb.py +++ b/graphrag/vector_stores/lancedb.py @@ -6,9 +6,9 @@ import json from typing import Any -import lancedb as lancedb import pyarrow as pa +import lancedb as lancedb from graphrag.model.types import TextEmbedder from .base import ( diff --git a/tests/verbs/test_generate_text_embeddings.py b/tests/verbs/test_generate_text_embeddings.py new file mode 100644 index 0000000000..d6095881c7 --- /dev/null +++ b/tests/verbs/test_generate_text_embeddings.py @@ -0,0 +1,164 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +from io import BytesIO + +import pandas as pd + +from graphrag.index.config.embeddings import ( + all_embeddings, +) +from graphrag.index.run.utils import create_run_context +from graphrag.index.workflows.v1.generate_text_embeddings import ( + build_steps, + workflow_name, +) + +from .util import ( + get_config_for_workflow, + get_workflow_output, + load_input_tables, +) + + +async def test_generate_text_embeddings(): + input_tables = load_input_tables( + inputs=[ + "workflow:create_final_documents", + "workflow:create_final_relationships", + "workflow:create_final_text_units", + "workflow:create_final_entities", + "workflow:create_final_community_reports", + ] + ) + context = create_run_context(None, None, None) + + config = get_config_for_workflow(workflow_name) + + config["text_embed"]["strategy"]["type"] = "mock" + config["embedded_fields"] = all_embeddings + config["text_embed"]["strategy"]["vector_store"] = { + "type": "lancedb", + "db_uri": "./lancedb", + "store_in_table": True, + } + + steps = build_steps(config) + + await get_workflow_output( + input_tables, + { + "steps": steps, + }, + context, + ) + + parquet_files = context.storage.keys() + assert "create_final_documents_raw_content_embeddings.parquet" in parquet_files + assert "create_final_relationships_description_embeddings.parquet" in parquet_files + assert "create_final_text_units_text_embeddings.parquet" in parquet_files + assert "create_final_entities_name_embeddings.parquet" in parquet_files + assert "create_final_entities_description_embeddings.parquet" in parquet_files + assert "create_final_community_reports_title_embeddings.parquet" in parquet_files + assert "create_final_community_reports_summary_embeddings.parquet" in parquet_files + assert ( + "create_final_community_reports_full_content_embeddings.parquet" + in parquet_files + ) + + create_final_documents_raw_content_embeddings_buffer = BytesIO( + await context.storage.get( + "create_final_documents_raw_content_embeddings.parquet", as_bytes=True + ) + ) + create_final_documents_raw_content_embeddings = pd.read_parquet( + create_final_documents_raw_content_embeddings_buffer + ) + assert len(create_final_documents_raw_content_embeddings.columns) == 2 + assert "id" in create_final_documents_raw_content_embeddings.columns + assert "embedding" in create_final_documents_raw_content_embeddings.columns + + create_final_relationships_description_embeddings_buffer = BytesIO( + await context.storage.get( + "create_final_relationships_description_embeddings.parquet", as_bytes=True + ) + ) + create_final_relationships_description_embeddings = pd.read_parquet( + create_final_relationships_description_embeddings_buffer + ) + assert len(create_final_relationships_description_embeddings.columns) == 2 + assert "id" in create_final_relationships_description_embeddings.columns + assert "embedding" in create_final_relationships_description_embeddings.columns + + create_final_text_units_text_embeddings_buffer = BytesIO( + await context.storage.get( + "create_final_text_units_text_embeddings.parquet", as_bytes=True + ) + ) + create_final_text_units_text_embeddings = pd.read_parquet( + create_final_text_units_text_embeddings_buffer + ) + assert len(create_final_text_units_text_embeddings.columns) == 2 + assert "id" in create_final_text_units_text_embeddings.columns + assert "embedding" in create_final_text_units_text_embeddings.columns + + create_final_entities_name_embeddings_buffer = BytesIO( + await context.storage.get( + "create_final_entities_name_embeddings.parquet", as_bytes=True + ) + ) + create_final_entities_name_embeddings = pd.read_parquet( + create_final_entities_name_embeddings_buffer + ) + assert len(create_final_entities_name_embeddings.columns) == 2 + assert "id" in create_final_entities_name_embeddings.columns + assert "embedding" in create_final_entities_name_embeddings.columns + + create_final_entities_description_embeddings_buffer = BytesIO( + await context.storage.get( + "create_final_entities_description_embeddings.parquet", as_bytes=True + ) + ) + create_final_entities_description_embeddings = pd.read_parquet( + create_final_entities_description_embeddings_buffer + ) + assert len(create_final_entities_description_embeddings.columns) == 2 + assert "id" in create_final_entities_description_embeddings.columns + assert "embedding" in create_final_entities_description_embeddings.columns + + create_final_community_reports_title_embeddings_buffer = BytesIO( + await context.storage.get( + "create_final_community_reports_title_embeddings.parquet", as_bytes=True + ) + ) + create_final_community_reports_title_embeddings = pd.read_parquet( + create_final_community_reports_title_embeddings_buffer + ) + assert len(create_final_community_reports_title_embeddings.columns) == 2 + assert "id" in create_final_community_reports_title_embeddings.columns + assert "embedding" in create_final_community_reports_title_embeddings.columns + + create_final_community_reports_summary_embeddings_buffer = BytesIO( + await context.storage.get( + "create_final_community_reports_summary_embeddings.parquet", as_bytes=True + ) + ) + create_final_community_reports_summary_embeddings = pd.read_parquet( + create_final_community_reports_summary_embeddings_buffer + ) + assert len(create_final_community_reports_summary_embeddings.columns) == 2 + assert "id" in create_final_community_reports_summary_embeddings.columns + assert "embedding" in create_final_community_reports_summary_embeddings.columns + + create_final_community_reports_full_content_embeddings_buffer = BytesIO( + await context.storage.get( + "create_final_community_reports_full_content_embeddings.parquet", + as_bytes=True, + ) + ) + create_final_community_reports_full_content_embeddings = pd.read_parquet( + create_final_community_reports_full_content_embeddings_buffer + ) + assert len(create_final_community_reports_full_content_embeddings.columns) == 2 + assert "id" in create_final_community_reports_full_content_embeddings.columns + assert "embedding" in create_final_community_reports_full_content_embeddings.columns