diff --git a/.semversioner/next-release/minor-20241018204541069382.json b/.semversioner/next-release/minor-20241018204541069382.json new file mode 100644 index 0000000000..d1b09fd986 --- /dev/null +++ b/.semversioner/next-release/minor-20241018204541069382.json @@ -0,0 +1,4 @@ +{ + "type": "minor", + "description": "embeddings are now generated in a dedicated workflow" +} diff --git a/docs/config/json_yaml.md b/docs/config/json_yaml.md index bc2c5646dc..07aa184c84 100644 --- a/docs/config/json_yaml.md +++ b/docs/config/json_yaml.md @@ -85,8 +85,8 @@ This is the base LLM configuration section. Other steps may override this config - `async_mode` (see Async Mode top-level config) - `batch_size` **int** - The maximum batch size to use. - `batch_max_tokens` **int** - The maximum number of tokens per batch. -- `target` **required|all** - Determines which set of embeddings to emit. -- `skip` **list[str]** - Which embeddings to skip. +- `target` **required|all|none** - Determines which set of embeddings to emit. +- `skip` **list[str]** - Which embeddings to skip. Only applies when `target=all`; use it to trim the emitted list. - `vector_store` **dict** - The vector store to use. Configured for lancedb by default. - `type` **str** - `lancedb` or `azure_ai_search`. Default=`lancedb` - `db_uri` **str** (only for lancedb) - The database uri. Default=`storage.base_dir/lancedb` - `api_key` **str** (optional - only for AI Search) - The AI Search api key to use. - `audience` **str** (only for AI Search) - Audience for managed identity token if managed identity authentication is used. - `overwrite` **bool** (only used at index creation time) - Overwrite the collection if it exists. Default=`True` - - `collection_name` **str** - The name of a vector collection. Default=`entity_description_embeddings` + - `container_name` **str** - The name of the vector container. This stores all indexes (tables) for a given dataset ingest. Default=`default` - `strategy` **dict** - Fully override the text-embedding strategy.
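To make the new `target`/`skip` semantics concrete, here is a minimal sketch (not shipped code) of how these settings resolve to the set of emitted embeddings; it mirrors `_get_embedded_fields` in `create_pipeline_config.py` later in this diff, and the field names come from the new `graphrag/index/config/embeddings.py`:

```python
# Sketch only: resolution of `target` and `skip` into embedded fields.
all_embeddings: set[str] = {
    "entity.name", "entity.description", "relationship.description",
    "document.raw_content", "community.title", "community.summary",
    "community.full_content", "text_unit.text",
}
required_embeddings: set[str] = {"entity.description"}


def resolve_embedded_fields(target: str, skip: list[str]) -> set[str]:
    match target:
        case "all":
            return all_embeddings.difference(skip)  # `skip` only matters here
        case "required":
            return required_embeddings
        case "none":
            return set()
        case _:
            msg = f"Unknown embeddings target: {target}"
            raise ValueError(msg)


# e.g. target="all", skip=["entity.name"] emits every field except entity.name
```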
## chunks diff --git a/docs/examples_notebooks/local_search.ipynb b/docs/examples_notebooks/local_search.ipynb index 0c692f02a9..af4dccbbd1 100644 --- a/docs/examples_notebooks/local_search.ipynb +++ b/docs/examples_notebooks/local_search.ipynb @@ -108,7 +108,7 @@ "# load description embeddings to an in-memory lancedb vectorstore\n", "# to connect to a remote db, specify url and port values.\n", "description_embedding_store = LanceDBVectorStore(\n", - " collection_name=\"entity_description_embeddings\",\n", + " collection_name=\"entity.description\",\n", ")\n", "description_embedding_store.connect(db_uri=LANCEDB_URI)\n", "entity_description_embeddings = store_entity_semantic_embeddings(\n", diff --git a/examples_notebooks/community_contrib/yfiles-jupyter-graphs/graph-visualization.ipynb b/examples_notebooks/community_contrib/yfiles-jupyter-graphs/graph-visualization.ipynb index bb2e93d8f0..3874b172b1 100644 --- a/examples_notebooks/community_contrib/yfiles-jupyter-graphs/graph-visualization.ipynb +++ b/examples_notebooks/community_contrib/yfiles-jupyter-graphs/graph-visualization.ipynb @@ -299,7 +299,7 @@ "entities = read_indexer_entities(entity_df, entity_embedding_df, COMMUNITY_LEVEL)\n", "\n", "description_embedding_store = LanceDBVectorStore(\n", - " collection_name=\"entity_description_embeddings\",\n", + " collection_name=\"entity.description\",\n", ")\n", "description_embedding_store.connect(db_uri=LANCEDB_URI)\n", "entity_description_embeddings = store_entity_semantic_embeddings(\n", diff --git a/graphrag/api/index.py b/graphrag/api/index.py index 666f86425e..e5924dfd9b 100644 --- a/graphrag/api/index.py +++ b/graphrag/api/index.py @@ -59,13 +59,7 @@ async def build_index( msg = "Cannot resume and update a run at the same time." raise ValueError(msg) - # TODO: must update filepath of lancedb (if used) until the new config engine has been implemented - # TODO: remove the type ignore annotations below once the new config engine has been refactored - vector_store_type = config.embeddings.vector_store["type"] # type: ignore - if vector_store_type == VectorStoreType.LanceDB: - db_uri = config.embeddings.vector_store["db_uri"] # type: ignore - lancedb_dir = Path(config.root_dir).resolve() / db_uri - config.embeddings.vector_store["db_uri"] = str(lancedb_dir) # type: ignore + config = _patch_vector_config(config) pipeline_config = create_pipeline_config(config) pipeline_cache = ( @@ -90,3 +84,22 @@ async def build_index( progress_reporter.success(output.workflow) progress_reporter.info(str(output.result)) return outputs + + +def _patch_vector_config(config: GraphRagConfig): + """Back-compat patch to ensure a default vector store configuration.""" + if not config.embeddings.vector_store: + config.embeddings.vector_store = { + "type": "lancedb", + "db_uri": "output/lancedb", + "container_name": "default", + "overwrite": True, + } + # TODO: must update filepath of lancedb (if used) until the new config engine has been implemented + # TODO: remove the type ignore annotations below once the new config engine has been refactored + vector_store_type = config.embeddings.vector_store["type"] # type: ignore + if vector_store_type == VectorStoreType.LanceDB: + db_uri = config.embeddings.vector_store["db_uri"] # type: ignore + lancedb_dir = Path(config.root_dir).resolve() / db_uri + config.embeddings.vector_store["db_uri"] = str(lancedb_dir) # type: ignore + return config diff --git a/graphrag/api/query.py b/graphrag/api/query.py index d6d3bde385..d643e9433b 100644 --- a/graphrag/api/query.py +++ 
b/graphrag/api/query.py @@ -182,56 +182,22 @@ async def local_search( ------ TODO: Document any exceptions to expect. """ - #################################### BEGIN PATCH #################################### - # TODO: remove the following patch that checks for a vector_store prior to v1 release - # TODO: this is a backwards compatibility patch that injects the default vector_store settings into the config if it is not present - # Only applicable in situations involving a local vector_store (lancedb). The general idea: - # if vector_store not in config: - # 1. assume user is running local if vector_store is not in config - # 2. insert default vector_store in config - # 3 .create lancedb vector_store instance - # 4. upload vector embeddings from the input dataframes to the vector_store - backwards_compatible = False - if not config.embeddings.vector_store: - backwards_compatible = True - from graphrag.query.input.loaders.dfs import store_entity_semantic_embeddings - from graphrag.vector_stores.lancedb import LanceDBVectorStore - - config.embeddings.vector_store = { - "type": "lancedb", - "db_uri": f"{Path(config.storage.base_dir)}/lancedb", - "collection_name": "entity_description_embeddings", - "overwrite": True, - } - _entities = read_indexer_entities(nodes, entities, community_level) - description_embedding_store = LanceDBVectorStore( - db_uri=config.embeddings.vector_store["db_uri"], - collection_name=config.embeddings.vector_store["collection_name"], - overwrite=config.embeddings.vector_store["overwrite"], - ) - description_embedding_store.connect( - db_uri=config.embeddings.vector_store["db_uri"] - ) - # dump embeddings from the entities list to the description_embedding_store - store_entity_semantic_embeddings( - entities=_entities, vectorstore=description_embedding_store - ) - #################################### END PATCH #################################### + config = _patch_vector_store(config, nodes, entities, community_level) # TODO: update filepath of lancedb (if used) until the new config engine has been implemented # TODO: remove the type ignore annotations below once the new config engine has been refactored vector_store_type = config.embeddings.vector_store.get("type") # type: ignore vector_store_args = config.embeddings.vector_store - if vector_store_type == VectorStoreType.LanceDB and not backwards_compatible: + if vector_store_type == VectorStoreType.LanceDB: db_uri = config.embeddings.vector_store["db_uri"] # type: ignore lancedb_dir = Path(config.root_dir).resolve() / db_uri vector_store_args["db_uri"] = str(lancedb_dir) # type: ignore reporter.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore - if not backwards_compatible: # can remove this check and always set the description_embedding_store before v1 release - description_embedding_store = _get_embedding_description_store( - config_args=vector_store_args, # type: ignore - ) + + description_embedding_store = _get_embedding_description_store( + config_args=vector_store_args, # type: ignore + ) _entities = read_indexer_entities(nodes, entities, community_level) _covariates = read_indexer_covariates(covariates) if covariates is not None else [] @@ -289,56 +255,22 @@ async def local_search_streaming( ------ TODO: Document any exceptions to expect. 
""" - #################################### BEGIN PATCH #################################### - # TODO: remove the following patch that checks for a vector_store prior to v1 release - # TODO: this is a backwards compatibility patch that injects the default vector_store settings into the config if it is not present - # Only applicable in situations involving a local vector_store (lancedb). The general idea: - # if vector_store not in config: - # 1. assume user is running local if vector_store is not in config - # 2. insert default vector_store in config - # 3 .create lancedb vector_store instance - # 4. upload vector embeddings from the input dataframes to the vector_store - backwards_compatible = False - if not config.embeddings.vector_store: - backwards_compatible = True - from graphrag.query.input.loaders.dfs import store_entity_semantic_embeddings - from graphrag.vector_stores.lancedb import LanceDBVectorStore - - config.embeddings.vector_store = { - "type": "lancedb", - "db_uri": f"{Path(config.storage.base_dir)}/lancedb", - "collection_name": "entity_description_embeddings", - "overwrite": True, - } - _entities = read_indexer_entities(nodes, entities, community_level) - description_embedding_store = LanceDBVectorStore( - db_uri=config.embeddings.vector_store["db_uri"], - collection_name=config.embeddings.vector_store["collection_name"], - overwrite=config.embeddings.vector_store["overwrite"], - ) - description_embedding_store.connect( - db_uri=config.embeddings.vector_store["db_uri"] - ) - # dump embeddings from the entities list to the description_embedding_store - store_entity_semantic_embeddings( - entities=_entities, vectorstore=description_embedding_store - ) - #################################### END PATCH #################################### + config = _patch_vector_store(config, nodes, entities, community_level) # TODO: must update filepath of lancedb (if used) until the new config engine has been implemented # TODO: remove the type ignore annotations below once the new config engine has been refactored vector_store_type = config.embeddings.vector_store.get("type") # type: ignore vector_store_args = config.embeddings.vector_store - if vector_store_type == VectorStoreType.LanceDB and not backwards_compatible: + if vector_store_type == VectorStoreType.LanceDB: db_uri = config.embeddings.vector_store["db_uri"] # type: ignore lancedb_dir = Path(config.root_dir).resolve() / db_uri vector_store_args["db_uri"] = str(lancedb_dir) # type: ignore reporter.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore - if not backwards_compatible: # can remove this check and always set the description_embedding_store before v1 release - description_embedding_store = _get_embedding_description_store( - config_args=vector_store_args, # type: ignore - ) + + description_embedding_store = _get_embedding_description_store( + conf_args=vector_store_args, # type: ignore + ) _entities = read_indexer_entities(nodes, entities, community_level) _covariates = read_indexer_covariates(covariates) if covariates is not None else [] @@ -368,13 +300,55 @@ async def local_search_streaming( yield stream_chunk +def _patch_vector_store( + config: GraphRagConfig, + nodes: pd.DataFrame, + entities: pd.DataFrame, + community_level: int, +) -> GraphRagConfig: + # TODO: remove the following patch that checks for a vector_store prior to v1 release + # TODO: this is a backwards compatibility patch that injects the default vector_store settings into the config if it is not present + # Only applicable in 
situations involving a local vector_store (lancedb). The general idea: + # if vector_store not in config: + # 1. assume user is running local if vector_store is not in config + # 2. insert default vector_store in config + # 3. create a lancedb vector_store instance + # 4. upload vector embeddings from the input dataframes to the vector_store + if not config.embeddings.vector_store: + from graphrag.query.input.loaders.dfs import store_entity_semantic_embeddings + from graphrag.vector_stores.lancedb import LanceDBVectorStore + + config.embeddings.vector_store = { + "type": "lancedb", + "db_uri": f"{Path(config.storage.base_dir)}/lancedb", + "container_name": "default", + "overwrite": True, + } + description_embedding_store = LanceDBVectorStore( + db_uri=config.embeddings.vector_store["db_uri"], + collection_name="default-entity-description", + overwrite=config.embeddings.vector_store["overwrite"], + ) + description_embedding_store.connect( + db_uri=config.embeddings.vector_store["db_uri"] + ) + # dump embeddings from the entities list to the description_embedding_store + _entities = read_indexer_entities(nodes, entities, community_level) + store_entity_semantic_embeddings( + entities=_entities, vectorstore=description_embedding_store + ) + return config + + def _get_embedding_description_store( config_args: dict, ): """Get the embedding description store.""" vector_store_type = config_args["type"] + collection_name = f"{config_args['container_name']}-entity-description" description_embedding_store = VectorStoreFactory.get_vector_store( - vector_store_type=vector_store_type, kwargs=config_args + vector_store_type=vector_store_type, + kwargs={**config_args, "collection_name": collection_name}, ) description_embedding_store.connect(**config_args) return description_embedding_store diff --git a/graphrag/cli/query.py b/graphrag/cli/query.py index d2f5f9e67f..70a9d4955e 100644 --- a/graphrag/cli/query.py +++ b/graphrag/cli/query.py @@ -115,6 +115,7 @@ def run_local_search( config.storage.base_dir = str(data_dir) if data_dir else config.storage.base_dir resolve_paths(config) + # TODO: remove optional create_final_entities_description_embeddings.parquet once backwards compatibility is dropped dataframe_dict = _resolve_parquet_files( root_dir=root_dir, config=config, @@ -125,7 +126,9 @@ def run_local_search( "create_final_relationships.parquet", "create_final_entities.parquet", ], - optional_list=["create_final_covariates.parquet"], + optional_list=[ + "create_final_covariates.parquet", + ], ) final_nodes: pd.DataFrame = dataframe_dict["create_final_nodes"] final_community_reports: pd.DataFrame = dataframe_dict[ diff --git a/graphrag/config/create_graphrag_config.py b/graphrag/config/create_graphrag_config.py index baf5810961..61fd7ce6f2 100644 --- a/graphrag/config/create_graphrag_config.py +++ b/graphrag/config/create_graphrag_config.py @@ -414,6 +414,7 @@ def hydrate_parallelization_params( raw_entities=reader.bool("raw_entities") or defs.SNAPSHOTS_RAW_ENTITIES, top_level_nodes=reader.bool("top_level_nodes") or defs.SNAPSHOTS_TOP_LEVEL_NODES, + embeddings=reader.bool("embeddings") or defs.SNAPSHOTS_EMBEDDINGS, ) with reader.envvar_prefix(Section.umap), reader.use(values.get("umap")): umap_model = UmapConfig( diff --git a/graphrag/config/defaults.py b/graphrag/config/defaults.py index d4ad18db0f..2da4949783 100644 --- a/graphrag/config/defaults.py +++ b/graphrag/config/defaults.py @@ -82,6 +82,7 @@ SNAPSHOTS_GRAPHML = False SNAPSHOTS_RAW_ENTITIES = False SNAPSHOTS_TOP_LEVEL_NODES = False
+SNAPSHOTS_EMBEDDINGS = False STORAGE_BASE_DIR = "output" STORAGE_TYPE = StorageType.file SUMMARIZE_DESCRIPTIONS_MAX_LENGTH = 500 @@ -91,7 +92,7 @@ VECTOR_STORE = f""" type: {VectorStoreType.LanceDB.value} db_uri: '{(Path(STORAGE_BASE_DIR) / "lancedb")!s}' - collection_name: entity_description_embeddings + container_name: default overwrite: true\ """ diff --git a/graphrag/config/enums.py b/graphrag/config/enums.py index 8741cf74ae..99d385dff0 100644 --- a/graphrag/config/enums.py +++ b/graphrag/config/enums.py @@ -86,6 +86,7 @@ class TextEmbeddingTarget(str, Enum): all = "all" required = "required" + none = "none" def __repr__(self): """Get a string representation.""" diff --git a/graphrag/config/models/snapshots_config.py b/graphrag/config/models/snapshots_config.py index 08293fb7a7..d46d0ed4ae 100644 --- a/graphrag/config/models/snapshots_config.py +++ b/graphrag/config/models/snapshots_config.py @@ -23,3 +23,7 @@ class SnapshotsConfig(BaseModel): description="A flag indicating whether to take snapshots of top-level nodes.", default=defs.SNAPSHOTS_TOP_LEVEL_NODES, ) + embeddings: bool = Field( + description="A flag indicating whether to take snapshots of embeddings.", + default=defs.SNAPSHOTS_EMBEDDINGS, + ) diff --git a/graphrag/index/config/__init__.py b/graphrag/index/config/__init__.py index 3c40762a84..885e6ad386 100644 --- a/graphrag/index/config/__init__.py +++ b/graphrag/index/config/__init__.py @@ -11,6 +11,18 @@ PipelineMemoryCacheConfig, PipelineNoneCacheConfig, ) +from .embeddings import ( + all_embeddings, + community_full_content_embedding, + community_summary_embedding, + community_title_embedding, + document_raw_content_embedding, + entity_description_embedding, + entity_name_embedding, + relationship_description_embedding, + required_embeddings, + text_unit_text_embedding, +) from .input import ( PipelineCSVInputConfig, PipelineInputConfig, @@ -66,4 +78,14 @@ "PipelineWorkflowConfig", "PipelineWorkflowReference", "PipelineWorkflowStep", + "all_embeddings", + "community_full_content_embedding", + "community_summary_embedding", + "community_title_embedding", + "document_raw_content_embedding", + "entity_description_embedding", + "entity_name_embedding", + "relationship_description_embedding", + "required_embeddings", + "text_unit_text_embedding", ] diff --git a/graphrag/index/config/embeddings.py b/graphrag/index/config/embeddings.py new file mode 100644 index 0000000000..ab00e4566e --- /dev/null +++ b/graphrag/index/config/embeddings.py @@ -0,0 +1,25 @@ +# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License + +"""A module containing embeddings values.""" + +entity_name_embedding = "entity.name" +entity_description_embedding = "entity.description" +relationship_description_embedding = "relationship.description" +document_raw_content_embedding = "document.raw_content" +community_title_embedding = "community.title" +community_summary_embedding = "community.summary" +community_full_content_embedding = "community.full_content" +text_unit_text_embedding = "text_unit.text" + +all_embeddings: set[str] = { + entity_name_embedding, + entity_description_embedding, + relationship_description_embedding, + document_raw_content_embedding, + community_title_embedding, + community_summary_embedding, + community_full_content_embedding, + text_unit_text_embedding, +} +required_embeddings: set[str] = {entity_description_embedding} diff --git a/graphrag/index/create_pipeline_config.py b/graphrag/index/create_pipeline_config.py index d6e7bff878..be5f97d7e8 100644 --- a/graphrag/index/create_pipeline_config.py +++ b/graphrag/index/create_pipeline_config.py @@ -22,6 +22,10 @@ PipelineMemoryCacheConfig, PipelineNoneCacheConfig, ) +from graphrag.index.config.embeddings import ( + all_embeddings, + required_embeddings, +) from graphrag.index.config.input import ( PipelineCSVInputConfig, PipelineInputConfigTypes, @@ -56,33 +60,11 @@ create_final_nodes, create_final_relationships, create_final_text_units, + generate_text_embeddings, ) log = logging.getLogger(__name__) - -entity_name_embedding = "entity.name" -entity_description_embedding = "entity.description" -relationship_description_embedding = "relationship.description" -document_raw_content_embedding = "document.raw_content" -community_title_embedding = "community.title" -community_summary_embedding = "community.summary" -community_full_content_embedding = "community.full_content" -text_unit_text_embedding = "text_unit.text" - -all_embeddings: set[str] = { - entity_name_embedding, - entity_description_embedding, - relationship_description_embedding, - document_raw_content_embedding, - community_title_embedding, - community_summary_embedding, - community_full_content_embedding, - text_unit_text_embedding, -} -required_embeddings: set[str] = {entity_description_embedding} - - builtin_document_attributes: set[str] = { "id", "source", @@ -121,11 +103,12 @@ def create_pipeline_config(settings: GraphRagConfig, verbose=False) -> PipelineC ), cache=_get_cache_config(settings), workflows=[ - *_document_workflows(settings, embedded_fields), - *_text_unit_workflows(settings, covariates_enabled, embedded_fields), - *_graph_workflows(settings, embedded_fields), - *_community_workflows(settings, covariates_enabled, embedded_fields), + *_document_workflows(settings), + *_text_unit_workflows(settings, covariates_enabled), + *_graph_workflows(settings), + *_community_workflows(settings, covariates_enabled), *(_covariate_workflows(settings) if covariates_enabled else []), + *(_embeddings_workflows(settings, embedded_fields)), ], ) @@ -138,9 +121,11 @@ def create_pipeline_config(settings: GraphRagConfig, verbose=False) -> PipelineC def _get_embedded_fields(settings: GraphRagConfig) -> set[str]: match settings.embeddings.target: case TextEmbeddingTarget.all: - return all_embeddings - {*settings.embeddings.skip} + return all_embeddings.difference(settings.embeddings.skip) case TextEmbeddingTarget.required: return required_embeddings + case TextEmbeddingTarget.none: + return set() case _: msg = f"Unknown embeddings target: {settings.embeddings.target}" 
raise ValueError(msg) @@ -163,11 +148,8 @@ def _log_llm_settings(settings: GraphRagConfig) -> None: def _document_workflows( - settings: GraphRagConfig, embedded_fields: set[str] + settings: GraphRagConfig, ) -> list[PipelineWorkflowReference]: - skip_document_raw_content_embedding = ( - document_raw_content_embedding not in embedded_fields - ) return [ PipelineWorkflowReference( name=create_final_documents, @@ -176,15 +158,6 @@ def _document_workflows( {*(settings.input.document_attribute_columns)} - builtin_document_attributes ), - "document_raw_content_embed": _get_embedding_settings( - settings.embeddings, - "document_raw_content", - { - "title_column": "raw_content", - "collection_name": "final_documents_raw_content_embedding", - }, - ), - "skip_raw_content_embedding": skip_document_raw_content_embedding, }, ), ] @@ -193,9 +166,7 @@ def _document_workflows( def _text_unit_workflows( settings: GraphRagConfig, covariates_enabled: bool, - embedded_fields: set[str], ) -> list[PipelineWorkflowReference]: - skip_text_unit_embedding = text_unit_text_embedding not in embedded_fields return [ PipelineWorkflowReference( name=create_base_text_units, @@ -211,13 +182,7 @@ def _text_unit_workflows( PipelineWorkflowReference( name=create_final_text_units, config={ - "text_unit_text_embed": _get_embedding_settings( - settings.embeddings, - "text_unit_text", - {"title_column": "text", "collection_name": "text_units_embedding"}, - ), "covariates_enabled": covariates_enabled, - "skip_text_unit_embedding": skip_text_unit_embedding, }, ), ] @@ -225,7 +190,6 @@ def _text_unit_workflows( def _get_embedding_settings( settings: TextEmbeddingConfig, - embedding_name: str, vector_store_params: dict | None = None, ) -> dict: vector_store_settings = settings.vector_store @@ -243,20 +207,10 @@ def _get_embedding_settings( # This ensures the vector store config is part of the strategy and not the global config return { "strategy": strategy, - "embedding_name": embedding_name, } -def _graph_workflows( - settings: GraphRagConfig, embedded_fields: set[str] -) -> list[PipelineWorkflowReference]: - skip_entity_name_embedding = entity_name_embedding not in embedded_fields - skip_entity_description_embedding = ( - entity_description_embedding not in embedded_fields - ) - skip_relationship_description_embedding = ( - relationship_description_embedding not in embedded_fields - ) +def _graph_workflows(settings: GraphRagConfig) -> list[PipelineWorkflowReference]: return [ PipelineWorkflowReference( name=create_base_entity_graph, @@ -286,40 +240,11 @@ def _graph_workflows( ), PipelineWorkflowReference( name=create_final_entities, - config={ - "entity_name_embed": _get_embedding_settings( - settings.embeddings, - "entity_name", - { - "title_column": "name", - "collection_name": "entity_name_embeddings", - }, - ), - "entity_name_description_embed": _get_embedding_settings( - settings.embeddings, - "entity_name_description", - { - "title_column": "description", - "collection_name": "entity_description_embeddings", - }, - ), - "skip_name_embedding": skip_entity_name_embedding, - "skip_description_embedding": skip_entity_description_embedding, - }, + config={}, ), PipelineWorkflowReference( name=create_final_relationships, - config={ - "relationship_description_embed": _get_embedding_settings( - settings.embeddings, - "relationship_description", - { - "title_column": "description", - "collection_name": "relationships_description_embeddings", - }, - ), - "skip_description_embedding": skip_relationship_description_embedding, - }, + 
config={}, ), PipelineWorkflowReference( name=create_final_nodes, @@ -332,24 +257,14 @@ def _graph_workflows( def _community_workflows( - settings: GraphRagConfig, covariates_enabled: bool, embedded_fields: set[str] + settings: GraphRagConfig, covariates_enabled: bool ) -> list[PipelineWorkflowReference]: - skip_community_title_embedding = community_title_embedding not in embedded_fields - skip_community_summary_embedding = ( - community_summary_embedding not in embedded_fields - ) - skip_community_full_content_embedding = ( - community_full_content_embedding not in embedded_fields - ) return [ PipelineWorkflowReference(name=create_final_communities), PipelineWorkflowReference( name=create_final_community_reports, config={ "covariates_enabled": covariates_enabled, - "skip_title_embedding": skip_community_title_embedding, - "skip_summary_embedding": skip_community_summary_embedding, - "skip_full_content_embedding": skip_community_full_content_embedding, "create_community_reports": { **settings.community_reports.parallelization.model_dump(), "async_mode": settings.community_reports.async_mode, @@ -357,27 +272,6 @@ def _community_workflows( settings.root_dir ), }, - "community_report_full_content_embed": _get_embedding_settings( - settings.embeddings, - "community_report_full_content", - { - "title_column": "full_content", - "collection_name": "final_community_reports_full_content_embedding", - }, - ), - "community_report_summary_embed": _get_embedding_settings( - settings.embeddings, - "community_report_summary", - { - "title_column": "summary", - "collection_name": "final_community_reports_summary_embedding", - }, - ), - "community_report_title_embed": _get_embedding_settings( - settings.embeddings, - "community_report_title", - {"title_column": "title"}, - ), }, ), ] @@ -401,6 +295,21 @@ def _covariate_workflows( ] +def _embeddings_workflows( + settings: GraphRagConfig, embedded_fields: set[str] +) -> list[PipelineWorkflowReference]: + return [ + PipelineWorkflowReference( + name=generate_text_embeddings, + config={ + "snapshot_embeddings": settings.snapshots.embeddings, + "text_embed": _get_embedding_settings(settings.embeddings), + "embedded_fields": embedded_fields, + }, + ), + ] + + def _get_pipeline_input_config( settings: GraphRagConfig, ) -> PipelineInputConfigTypes: diff --git a/graphrag/index/flows/create_final_community_reports.py b/graphrag/index/flows/create_final_community_reports.py index f6d49a10fa..c0101c8d96 100644 --- a/graphrag/index/flows/create_final_community_reports.py +++ b/graphrag/index/flows/create_final_community_reports.py @@ -31,7 +31,6 @@ NODE_ID, NODE_NAME, ) -from graphrag.index.operations.embed_text import embed_text from graphrag.index.operations.summarize_communities import ( prepare_community_reports, restore_community_hierarchy, @@ -49,9 +48,6 @@ async def create_final_community_reports( summarization_strategy: dict, async_mode: AsyncType = AsyncType.AsyncIO, num_threads: int = 4, - full_content_text_embed: dict | None = None, - summary_text_embed: dict | None = None, - title_text_embed: dict | None = None, ) -> pd.DataFrame: """All the steps to transform community reports.""" nodes = _prep_nodes(nodes_input) @@ -86,39 +82,6 @@ async def create_final_community_reports( lambda _x: str(uuid4()) ) - # Embed full content if not skipped - if full_content_text_embed: - community_reports["full_content_embedding"] = await embed_text( - community_reports, - callbacks, - cache, - column="full_content", - strategy=full_content_text_embed["strategy"], - 
embedding_name="community_report_full_content", - ) - - # Embed summary if not skipped - if summary_text_embed: - community_reports["summary_embedding"] = await embed_text( - community_reports, - callbacks, - cache, - column="summary", - strategy=summary_text_embed["strategy"], - embedding_name="community_report_summary", - ) - - # Embed title if not skipped - if title_text_embed: - community_reports["title_embedding"] = await embed_text( - community_reports, - callbacks, - cache, - column="title", - strategy=title_text_embed["strategy"], - embedding_name="community_report_title", - ) - # Merge by community and it with communities to add size and period return community_reports.merge( communities_input.loc[:, ["id", "size", "period"]], diff --git a/graphrag/index/flows/create_final_documents.py b/graphrag/index/flows/create_final_documents.py index 29f28e52b1..fcd213e314 100644 --- a/graphrag/index/flows/create_final_documents.py +++ b/graphrag/index/flows/create_final_documents.py @@ -4,21 +4,12 @@ """All the steps to transform final documents.""" import pandas as pd -from datashaper import ( - VerbCallbacks, -) -from graphrag.index.cache import PipelineCache -from graphrag.index.operations.embed_text import embed_text - -async def create_final_documents( +def create_final_documents( documents: pd.DataFrame, text_units: pd.DataFrame, - callbacks: VerbCallbacks, - cache: PipelineCache, document_attribute_columns: list[str] | None = None, - raw_content_text_embed: dict | None = None, ) -> pd.DataFrame: """All the steps to transform final documents.""" exploded = ( @@ -72,13 +63,4 @@ async def create_final_documents( # Drop the original attribute columns after collapsing them rejoined.drop(columns=document_attribute_columns, inplace=True) - if raw_content_text_embed: - rejoined["raw_content_embedding"] = await embed_text( - rejoined, - callbacks, - cache, - column="raw_content", - strategy=raw_content_text_embed["strategy"], - ) - return rejoined diff --git a/graphrag/index/flows/create_final_entities.py b/graphrag/index/flows/create_final_entities.py index 9601cb3169..5c03977d40 100644 --- a/graphrag/index/flows/create_final_entities.py +++ b/graphrag/index/flows/create_final_entities.py @@ -8,18 +8,13 @@ VerbCallbacks, ) -from graphrag.index.cache import PipelineCache -from graphrag.index.operations.embed_text import embed_text from graphrag.index.operations.split_text import split_text from graphrag.index.operations.unpack_graph import unpack_graph -async def create_final_entities( +def create_final_entities( entity_graph: pd.DataFrame, callbacks: VerbCallbacks, - cache: PipelineCache, - name_text_embed: dict | None = None, - description_text_embed: dict | None = None, ) -> pd.DataFrame: """All the steps to transform final entities.""" # Process nodes @@ -44,37 +39,6 @@ async def create_final_entities( nodes = nodes.loc[nodes["name"].notna()] # Split 'source_id' column into 'text_unit_ids' - nodes = split_text( + return split_text( nodes, column="source_id", separator=",", to="text_unit_ids" ).drop(columns=["source_id"]) - - # Embed name if not skipped - if name_text_embed: - nodes["name_embedding"] = await embed_text( - nodes, - callbacks, - cache, - column="name", - strategy=name_text_embed["strategy"], - embedding_name="entity_name", - ) - - # Embed description if not skipped - if description_text_embed: - # Concatenate 'name' and 'description' and embed - nodes["name_description"] = nodes["name"] + ":" + nodes["description"] - nodes["description_embedding"] = await embed_text( - 
nodes, - callbacks, - cache, - column="name_description", - strategy=description_text_embed["strategy"], - embedding_name="entity_name_description", - ) - - # Drop rows with NaN 'description_embedding' if not using vector store - if not description_text_embed.get("strategy", {}).get("vector_store"): - nodes = nodes.loc[nodes["description_embedding"].notna()] - nodes.drop(columns="name_description", inplace=True) - - return nodes diff --git a/graphrag/index/flows/create_final_relationships.py b/graphrag/index/flows/create_final_relationships.py index 0f283cf621..0e61008ae3 100644 --- a/graphrag/index/flows/create_final_relationships.py +++ b/graphrag/index/flows/create_final_relationships.py @@ -10,20 +10,16 @@ VerbCallbacks, ) -from graphrag.index.cache import PipelineCache from graphrag.index.operations.compute_edge_combined_degree import ( compute_edge_combined_degree, ) -from graphrag.index.operations.embed_text import embed_text from graphrag.index.operations.unpack_graph import unpack_graph -async def create_final_relationships( +def create_final_relationships( entity_graph: pd.DataFrame, nodes: pd.DataFrame, callbacks: VerbCallbacks, - cache: PipelineCache, - description_text_embed: dict | None = None, ) -> pd.DataFrame: """All the steps to transform final relationships.""" graph_edges = unpack_graph(entity_graph, callbacks, "clustered_graph", "edges") @@ -34,16 +30,6 @@ async def create_final_relationships( pd.DataFrame, graph_edges[graph_edges["level"] == 0].reset_index(drop=True) ) - if description_text_embed: - filtered["description_embedding"] = await embed_text( - filtered, - callbacks, - cache, - column="description", - strategy=description_text_embed["strategy"], - embedding_name="relationship_description", - ) - pruned_edges = filtered.drop(columns=["level"]) filtered_nodes = nodes[nodes["level"] == 0].reset_index(drop=True) diff --git a/graphrag/index/flows/create_final_text_units.py b/graphrag/index/flows/create_final_text_units.py index a63d797fa0..6f6ed70572 100644 --- a/graphrag/index/flows/create_final_text_units.py +++ b/graphrag/index/flows/create_final_text_units.py @@ -6,22 +6,13 @@ from typing import cast import pandas as pd -from datashaper import ( - VerbCallbacks, -) -from graphrag.index.cache import PipelineCache -from graphrag.index.operations.embed_text import embed_text - -async def create_final_text_units( +def create_final_text_units( text_units: pd.DataFrame, final_entities: pd.DataFrame, final_relationships: pd.DataFrame, final_covariates: pd.DataFrame | None, - callbacks: VerbCallbacks, - cache: PipelineCache, - text_text_embed: dict | None = None, ) -> pd.DataFrame: """All the steps to transform the text units.""" selected = text_units.loc[:, ["id", "chunk", "document_ids", "n_tokens"]].rename( @@ -41,30 +32,12 @@ async def create_final_text_units( aggregated = final_joined.groupby("id", sort=False).agg("first").reset_index() - is_using_vector_store = False - if text_text_embed: - aggregated["text_embedding"] = await embed_text( - aggregated, - callbacks, - cache, - column="text", - strategy=text_text_embed["strategy"], - ) - is_using_vector_store = ( - text_text_embed.get("strategy", {}).get("vector_store", None) is not None - ) - return cast( pd.DataFrame, aggregated[ [ "id", "text", - *( - [] - if (not text_text_embed or is_using_vector_store) - else ["text_embedding"] - ), "n_tokens", "document_ids", "entity_ids", diff --git a/graphrag/index/flows/generate_text_embeddings.py b/graphrag/index/flows/generate_text_embeddings.py new file mode 100644 
index 0000000000..c4518e2a1c --- /dev/null +++ b/graphrag/index/flows/generate_text_embeddings.py @@ -0,0 +1,146 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""All the steps to generate text embeddings.""" + +import logging + +import pandas as pd +from datashaper import ( + VerbCallbacks, +) + +from graphrag.index.cache import PipelineCache +from graphrag.index.config.embeddings import ( + community_full_content_embedding, + community_summary_embedding, + community_title_embedding, + document_raw_content_embedding, + entity_description_embedding, + entity_name_embedding, + relationship_description_embedding, + text_unit_text_embedding, +) +from graphrag.index.operations.embed_text import embed_text +from graphrag.index.operations.snapshot import snapshot +from graphrag.index.storage import PipelineStorage + +log = logging.getLogger(__name__) + + +async def generate_text_embeddings( + final_documents: pd.DataFrame | None, + final_relationships: pd.DataFrame | None, + final_text_units: pd.DataFrame | None, + final_entities: pd.DataFrame | None, + final_community_reports: pd.DataFrame | None, + callbacks: VerbCallbacks, + cache: PipelineCache, + storage: PipelineStorage, + text_embed_config: dict, + embedded_fields: set[str], + embeddings_snapshot_enabled: bool = False, +) -> None: + """All the steps to generate all embeddings.""" + embedding_param_map = { + document_raw_content_embedding: { + "data": final_documents.loc[:, ["id", "raw_content"]] + if final_documents is not None + else None, + "column_to_embed": "raw_content", + }, + relationship_description_embedding: { + "data": final_relationships.loc[:, ["id", "description"]] + if final_relationships is not None + else None, + "column_to_embed": "description", + }, + text_unit_text_embedding: { + "data": final_text_units.loc[:, ["id", "text"]] + if final_text_units is not None + else None, + "column_to_embed": "text", + }, + entity_name_embedding: { + "data": final_entities.loc[:, ["id", "name", "description"]] + if final_entities is not None + else None, + "column_to_embed": "name", + }, + entity_description_embedding: { + "data": final_entities.loc[:, ["id", "name", "description"]].assign( + name_description=lambda df: df["name"] + ":" + df["description"] + ) + if final_entities is not None + else None, + "column_to_embed": "name_description", + }, + community_title_embedding: { + "data": final_community_reports.loc[ + :, ["id", "full_content", "summary", "title"] + ] + if final_community_reports is not None + else None, + "column_to_embed": "title", + }, + community_summary_embedding: { + "data": final_community_reports.loc[ + :, ["id", "full_content", "summary", "title"] + ] + if final_community_reports is not None + else None, + "column_to_embed": "summary", + }, + community_full_content_embedding: { + "data": final_community_reports.loc[ + :, ["id", "full_content", "summary", "title"] + ] + if final_community_reports is not None + else None, + "column_to_embed": "full_content", + }, + } + + log.info("Creating embeddings") + for field in embedded_fields: + await _run_and_snapshot_embeddings( + name=field, + callbacks=callbacks, + cache=cache, + storage=storage, + text_embed_config=text_embed_config, + embeddings_snapshot_enabled=embeddings_snapshot_enabled, + **embedding_param_map[field], + ) + + +async def _run_and_snapshot_embeddings( + name: str, + data: pd.DataFrame, + column_to_embed: str, + callbacks: VerbCallbacks, + cache: PipelineCache, + storage: PipelineStorage, + text_embed_config:
dict, + embeddings_snapshot_enabled: bool, +) -> None: + """All the steps to generate a single embedding.""" + if text_embed_config: + data["embedding"] = await embed_text( + data, + callbacks, + cache, + embed_column=column_to_embed, + embedding_name=name, + strategy=text_embed_config["strategy"], + ) + + data = data.loc[:, ["id", "embedding"]] + + if embeddings_snapshot_enabled: + await snapshot( + data, + name=f"embeddings.{name}", + storage=storage, + formats=["parquet"], + ) diff --git a/graphrag/index/init_content.py b/graphrag/index/init_content.py index 00a852df18..869202790a 100644 --- a/graphrag/index/init_content.py +++ b/graphrag/index/init_content.py @@ -49,7 +49,7 @@ # api_key: # if not set, will attempt to use managed identity. Expects the `Search Index Data Contributor` RBAC role in this case. # audience: # if using managed identity, the audience to use for the token # overwrite: true # or false. Only applicable at index creation time - # collection_name: # the name of the collection to use + # container_name: # the name of the container to use. Default: 'default' llm: api_key: ${{GRAPHRAG_API_KEY}} type: {defs.EMBEDDING_TYPE.value} # or azure_openai_embedding diff --git a/graphrag/index/operations/embed_text/embed_text.py b/graphrag/index/operations/embed_text/embed_text.py index df4f400e62..49114592ef 100644 --- a/graphrag/index/operations/embed_text/embed_text.py +++ b/graphrag/index/operations/embed_text/embed_text.py @@ -42,9 +42,11 @@ async def embed_text( input: pd.DataFrame, callbacks: VerbCallbacks, cache: PipelineCache, - column: str, + embed_column: str, strategy: dict, - embedding_name: str = "default", + embedding_name: str, + id_column: str = "id", + title_column: str | None = None, ): """ Embed a piece of text into a vector space. The operation outputs a new column containing a mapping between doc_id and vector.
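For orientation, here is a hedged usage sketch of the renamed `embed_text` parameters; the `callbacks`/`cache` objects and the strategy contents are illustrative assumptions drawn from this diff's call sites, not verbatim pipeline code:

```python
# Hypothetical call site for the new signature. "openai" is one of the
# built-in embedding strategy types; the vector_store dict is assumed.
async def embed_entity_descriptions(df, callbacks, cache, vector_store_config):
    return await embed_text(
        df,                                   # frame with "id" and the embed column
        callbacks,
        cache,
        embed_column="name_description",      # replaces the old `column` argument
        strategy={"type": "openai", "vector_store": vector_store_config},
        embedding_name="entity.description",  # now required; keys the collection name
    )
```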
@@ -91,18 +93,19 @@ async def embed_text( input, callbacks, cache, - column, + embed_column, strategy, vector_store, vector_store_workflow_config, - vector_store_config.get("store_in_table", False), + id_column=id_column, + title_column=title_column, ) return await _text_embed_in_memory( input, callbacks, cache, - column, + embed_column, strategy, ) @@ -111,14 +114,14 @@ async def _text_embed_in_memory( input: pd.DataFrame, callbacks: VerbCallbacks, cache: PipelineCache, - column: str, + embed_column: str, strategy: dict, ): strategy_type = strategy["type"] strategy_exec = load_strategy(strategy_type) strategy_args = {**strategy} - texts: list[str] = input[column].to_numpy().tolist() + texts: list[str] = input[embed_column].to_numpy().tolist() result = await strategy_exec(texts, callbacks, cache, strategy_args) return result.embeddings @@ -128,11 +131,12 @@ async def _text_embed_with_vector_store( input: pd.DataFrame, callbacks: VerbCallbacks, cache: PipelineCache, - column: str, + embed_column: str, strategy: dict[str, Any], vector_store: BaseVectorStore, vector_store_config: dict, - store_in_table: bool = False, + id_column: str = "id", + title_column: str | None = None, ): strategy_type = strategy["type"] strategy_exec = load_strategy(strategy_type) @@ -142,24 +146,24 @@ async def _text_embed_with_vector_store( insert_batch_size: int = ( vector_store_config.get("batch_size") or DEFAULT_EMBEDDING_BATCH_SIZE ) - title_column: str = vector_store_config.get("title_column", "title") - id_column: str = vector_store_config.get("id_column", "id") + overwrite: bool = vector_store_config.get("overwrite", True) - if column not in input.columns: + if embed_column not in input.columns: + msg = f"Column {embed_column} not found in input dataframe with columns {input.columns}" + raise ValueError(msg) + title = title_column or embed_column + if title not in input.columns: msg = ( - f"Column {column} not found in input dataframe with columns {input.columns}" + f"Column {title} not found in input dataframe with columns {input.columns}" ) raise ValueError(msg) - if title_column not in input.columns: - msg = f"Column {title_column} not found in input dataframe with columns {input.columns}" - raise ValueError(msg) if id_column not in input.columns: msg = f"Column {id_column} not found in input dataframe with columns {input.columns}" raise ValueError(msg) total_rows = 0 - for row in input[column]: + for row in input[embed_column]: if isinstance(row, list): total_rows += len(row) else: @@ -172,8 +176,8 @@ async def _text_embed_with_vector_store( while insert_batch_size * i < input.shape[0]: batch = input.iloc[insert_batch_size * i : insert_batch_size * (i + 1)] - texts: list[str] = batch[column].to_numpy().tolist() - titles: list[str] = batch[title_column].to_numpy().tolist() + texts: list[str] = batch[embed_column].to_numpy().tolist() + titles: list[str] = batch[title].to_numpy().tolist() ids: list[str] = batch[id_column].to_numpy().tolist() result = await strategy_exec( texts, @@ -181,7 +185,7 @@ async def _text_embed_with_vector_store( cache, strategy_args, ) - if store_in_table and result.embeddings: + if result.embeddings: embeddings = [ embedding for embedding in result.embeddings if embedding is not None ] @@ -204,10 +208,7 @@ async def _text_embed_with_vector_store( starting_index += len(documents) i += 1 - if store_in_table: - return all_results - - return None + return all_results def _create_vector_store( @@ -226,12 +227,10 @@ def _create_vector_store( def _get_collection_name(vector_store_config: 
dict, embedding_name: str) -> str: - collection_name = vector_store_config.get("collection_name") - if not collection_name: - collection_names = vector_store_config.get("collection_names", {}) - collection_name = collection_names.get(embedding_name, embedding_name) + container_name = vector_store_config.get("container_name") + collection_name = f"{container_name}.{embedding_name}".replace(".", "-") - msg = f"using vector store {vector_store_config.get('type')} with collection_name {collection_name} for embedding {embedding_name}" + msg = f"using vector store {vector_store_config.get('type')} with container_name {container_name} for embedding {embedding_name}: {collection_name}" log.info(msg) return collection_name diff --git a/graphrag/index/operations/snapshot.py b/graphrag/index/operations/snapshot.py index 2b85c3ea80..c889595649 100644 --- a/graphrag/index/operations/snapshot.py +++ b/graphrag/index/operations/snapshot.py @@ -17,8 +17,8 @@ async def snapshot( """Take an entire snapshot of the tabular data.""" for fmt in formats: if fmt == "parquet": - await storage.set(name + ".parquet", input.to_parquet()) + await storage.set(f"{name}.parquet", input.to_parquet()) elif fmt == "json": await storage.set( - name + ".json", input.to_json(orient="records", lines=True) + f"{name}.json", input.to_json(orient="records", lines=True) ) diff --git a/graphrag/index/update/entities.py b/graphrag/index/update/entities.py index 68ce15177c..2e16f4f05f 100644 --- a/graphrag/index/update/entities.py +++ b/graphrag/index/update/entities.py @@ -181,7 +181,8 @@ async def _run_entity_description_embedding( entities_df, callbacks, cache, - column="name_description", + embed_column="name_description", + embedding_name="entity.description", strategy=embed_config.get("strategy", {}), ) return entities_df.drop(columns=["name_description"]) diff --git a/graphrag/index/workflows/default_workflows.py b/graphrag/index/workflows/default_workflows.py index 896a861a0b..5a3d176b56 100644 --- a/graphrag/index/workflows/default_workflows.py +++ b/graphrag/index/workflows/default_workflows.py @@ -67,7 +67,13 @@ from .v1.create_final_text_units import ( workflow_name as create_final_text_units, ) +from .v1.generate_text_embeddings import ( + build_steps as build_generate_text_embeddings_steps, +) +from .v1.generate_text_embeddings import ( + workflow_name as generate_text_embeddings, +) default_workflows: WorkflowDefinitions = { create_base_entity_graph: build_create_base_entity_graph_steps, @@ -80,4 +86,5 @@ create_final_covariates: build_create_final_covariates_steps, create_final_entities: build_create_final_entities_steps, create_final_communities: build_create_final_communities_steps, + generate_text_embeddings: build_generate_text_embeddings_steps, } diff --git a/graphrag/index/workflows/v1/create_final_community_reports.py b/graphrag/index/workflows/v1/create_final_community_reports.py index 3077ac6960..9303a32d2c 100644 --- a/graphrag/index/workflows/v1/create_final_community_reports.py +++ b/graphrag/index/workflows/v1/create_final_community_reports.py @@ -23,19 +23,6 @@ def build_steps( async_mode = create_community_reports_config.get("async_mode") num_threads = create_community_reports_config.get("num_threads") - base_text_embed = config.get("text_embed", {}) - community_report_full_content_embed_config = config.get( - "community_report_full_content_embed", base_text_embed - ) - community_report_summary_embed_config = config.get( - "community_report_summary_embed", base_text_embed - ) - 
community_report_title_embed_config = config.get( - "community_report_title_embed", base_text_embed - ) - skip_title_embedding = config.get("skip_title_embedding", False) - skip_summary_embedding = config.get("skip_summary_embedding", False) - skip_full_content_embedding = config.get("skip_full_content_embedding", False) input = { "source": "workflow:create_final_nodes", "relationships": "workflow:create_final_relationships", @@ -48,21 +35,6 @@ def build_steps( { "verb": "create_final_community_reports", "args": { - "full_content_text_embed": ( - community_report_full_content_embed_config - if not skip_full_content_embedding - else None - ), - "summary_text_embed": ( - community_report_summary_embed_config - if not skip_summary_embedding - else None - ), - "title_text_embed": ( - community_report_title_embed_config - if not skip_title_embedding - else None - ), "summarization_strategy": summarization_strategy, "async_mode": async_mode, "num_threads": num_threads, diff --git a/graphrag/index/workflows/v1/create_final_documents.py b/graphrag/index/workflows/v1/create_final_documents.py index d4dc96b924..65e22b35ee 100644 --- a/graphrag/index/workflows/v1/create_final_documents.py +++ b/graphrag/index/workflows/v1/create_final_documents.py @@ -19,23 +19,11 @@ def build_steps( ## Dependencies * `workflow:create_final_text_units` """ - base_text_embed = config.get("text_embed", {}) - document_raw_content_embed_config = config.get( - "document_raw_content_embed", base_text_embed - ) - skip_raw_content_embedding = config.get("skip_raw_content_embedding", False) document_attribute_columns = config.get("document_attribute_columns", []) return [ { "verb": "create_final_documents", - "args": { - "document_attribute_columns": document_attribute_columns, - "raw_content_text_embed": ( - document_raw_content_embed_config - if not skip_raw_content_embedding - else None - ), - }, + "args": {"document_attribute_columns": document_attribute_columns}, "input": { "source": DEFAULT_INPUT_NAME, "text_units": "workflow:create_final_text_units", diff --git a/graphrag/index/workflows/v1/create_final_entities.py b/graphrag/index/workflows/v1/create_final_entities.py index 7242800f39..50ee56d8e5 100644 --- a/graphrag/index/workflows/v1/create_final_entities.py +++ b/graphrag/index/workflows/v1/create_final_entities.py @@ -3,13 +3,16 @@ """A module containing build_steps method definition.""" +import logging + from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep workflow_name = "create_final_entities" +log = logging.getLogger(__name__) def build_steps( - config: PipelineWorkflowConfig, + config: PipelineWorkflowConfig, # noqa: ARG001 ) -> list[PipelineWorkflowStep]: """ Create the final entities table. 
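Tying the `container_name` changes together: the `_get_collection_name` helper earlier in this diff derives each vector-store collection from the configured container and the embedding field name. A small self-contained sketch of that rule (example values only):

```python
def derived_collection_name(container_name: str, embedding_name: str) -> str:
    # Same rule as _get_collection_name: join with "." and flatten "." to "-".
    return f"{container_name}.{embedding_name}".replace(".", "-")


# With the default container, entity descriptions land in the
# "default-entity-description" collection, the same name the query-side
# _get_embedding_description_store constructs.
assert derived_collection_name("default", "entity.description") == "default-entity-description"
```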
@@ -17,26 +20,10 @@ def build_steps( ## Dependencies * `workflow:create_base_entity_graph` """ - base_text_embed = config.get("text_embed", {}) - entity_name_embed_config = config.get("entity_name_embed", base_text_embed) - entity_name_description_embed_config = config.get( - "entity_name_description_embed", base_text_embed - ) - - skip_name_embedding = config.get("skip_name_embedding", False) - skip_description_embedding = config.get("skip_description_embedding", False) - return [ { "verb": "create_final_entities", - "args": { - "name_text_embed": entity_name_embed_config - if not skip_name_embedding - else None, - "description_text_embed": entity_name_description_embed_config - if not skip_description_embedding - else None, - }, + "args": {}, "input": {"source": "workflow:create_base_entity_graph"}, }, ] diff --git a/graphrag/index/workflows/v1/create_final_relationships.py b/graphrag/index/workflows/v1/create_final_relationships.py index 90e2145107..3eaff05b0e 100644 --- a/graphrag/index/workflows/v1/create_final_relationships.py +++ b/graphrag/index/workflows/v1/create_final_relationships.py @@ -3,13 +3,17 @@ """A module containing build_steps method definition.""" +import logging + from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep workflow_name = "create_final_relationships" +log = logging.getLogger(__name__) + def build_steps( - config: PipelineWorkflowConfig, + config: PipelineWorkflowConfig, # noqa: ARG001 ) -> list[PipelineWorkflowStep]: """ Create the final relationships table. @@ -18,19 +22,10 @@ def build_steps( * `workflow:create_base_entity_graph` * `workflow:create_final_nodes` """ - base_text_embed = config.get("text_embed", {}) - relationship_description_embed_config = config.get( - "relationship_description_embed", base_text_embed - ) - skip_description_embedding = config.get("skip_description_embedding", False) return [ { "verb": "create_final_relationships", - "args": { - "description_text_embed": relationship_description_embed_config - if not skip_description_embedding - else None, - }, + "args": {}, "input": { "source": "workflow:create_base_entity_graph", "nodes": "workflow:create_final_nodes", diff --git a/graphrag/index/workflows/v1/create_final_text_units.py b/graphrag/index/workflows/v1/create_final_text_units.py index 0638253a62..31015b0d01 100644 --- a/graphrag/index/workflows/v1/create_final_text_units.py +++ b/graphrag/index/workflows/v1/create_final_text_units.py @@ -19,10 +19,6 @@ def build_steps( * `workflow:create_final_entities` * `workflow:create_final_communities` """ - base_text_embed = config.get("text_embed", {}) - text_unit_text_embed_config = config.get("text_unit_text_embed", base_text_embed) - - skip_text_unit_embedding = config.get("skip_text_unit_embedding", False) covariates_enabled = config.get("covariates_enabled", False) input = { @@ -37,11 +33,7 @@ def build_steps( return [ { "verb": "create_final_text_units", - "args": { - "text_text_embed": text_unit_text_embed_config - if not skip_text_unit_embedding - else None, - }, + "args": {}, "input": input, }, ] diff --git a/graphrag/index/workflows/v1/generate_text_embeddings.py b/graphrag/index/workflows/v1/generate_text_embeddings.py new file mode 100644 index 0000000000..58a8a44afd --- /dev/null +++ b/graphrag/index/workflows/v1/generate_text_embeddings.py @@ -0,0 +1,49 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""A module containing build_steps method definition.""" + +import logging + +from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep + +log = logging.getLogger(__name__) + +workflow_name = "generate_text_embeddings" + +input = { + "source": "workflow:create_final_documents", + "relationships": "workflow:create_final_relationships", + "text_units": "workflow:create_final_text_units", + "entities": "workflow:create_final_entities", + "community_reports": "workflow:create_final_community_reports", +} + + +def build_steps( + config: PipelineWorkflowConfig, +) -> list[PipelineWorkflowStep]: + """ + Create the final embeddings files. + + ## Dependencies + * `workflow:create_final_documents` + * `workflow:create_final_relationships` + * `workflow:create_final_text_units` + * `workflow:create_final_entities` + * `workflow:create_final_community_reports` + """ + text_embed = config.get("text_embed", {}) + embedded_fields = config.get("embedded_fields", {}) + embeddings_snapshot_enabled = config.get("snapshot_embeddings", False) + return [ + { + "verb": "generate_text_embeddings", + "args": { + "text_embed": text_embed, + "embedded_fields": embedded_fields, + "embeddings_snapshot_enabled": embeddings_snapshot_enabled, + }, + "input": input, + }, + ] diff --git a/graphrag/index/workflows/v1/subflows/__init__.py b/graphrag/index/workflows/v1/subflows/__init__.py index 37e2036743..8e080c5c71 100644 --- a/graphrag/index/workflows/v1/subflows/__init__.py +++ b/graphrag/index/workflows/v1/subflows/__init__.py @@ -15,6 +15,7 @@ create_final_relationships, ) from .create_final_text_units import create_final_text_units +from .generate_text_embeddings import generate_text_embeddings __all__ = [ "create_base_entity_graph", @@ -27,4 +28,5 @@ "create_final_nodes", "create_final_relationships", "create_final_text_units", + "generate_text_embeddings", ] diff --git a/graphrag/index/workflows/v1/subflows/create_final_community_reports.py b/graphrag/index/workflows/v1/subflows/create_final_community_reports.py index bd67fd23aa..ec52f7d564 100644 --- a/graphrag/index/workflows/v1/subflows/create_final_community_reports.py +++ b/graphrag/index/workflows/v1/subflows/create_final_community_reports.py @@ -30,9 +30,6 @@ async def create_final_community_reports( summarization_strategy: dict, async_mode: AsyncType = AsyncType.AsyncIO, num_threads: int = 4, - full_content_text_embed: dict | None = None, - summary_text_embed: dict | None = None, - title_text_embed: dict | None = None, **_kwargs: dict, ) -> VerbResult: """All the steps to transform community reports.""" @@ -57,9 +54,6 @@ async def create_final_community_reports( summarization_strategy, async_mode=async_mode, num_threads=num_threads, - full_content_text_embed=full_content_text_embed, - summary_text_embed=summary_text_embed, - title_text_embed=title_text_embed, ) return create_verb_result( diff --git a/graphrag/index/workflows/v1/subflows/create_final_documents.py b/graphrag/index/workflows/v1/subflows/create_final_documents.py index 191bbd9ed0..34e2e5e018 100644 --- a/graphrag/index/workflows/v1/subflows/create_final_documents.py +++ b/graphrag/index/workflows/v1/subflows/create_final_documents.py @@ -8,13 +8,11 @@ import pandas as pd from datashaper import ( Table, - VerbCallbacks, VerbInput, verb, ) from datashaper.table_store.types import VerbResult, create_verb_result -from graphrag.index.cache import PipelineCache from graphrag.index.flows.create_final_documents import ( create_final_documents as 
diff --git a/graphrag/index/workflows/v1/subflows/__init__.py b/graphrag/index/workflows/v1/subflows/__init__.py
index 37e2036743..8e080c5c71 100644
--- a/graphrag/index/workflows/v1/subflows/__init__.py
+++ b/graphrag/index/workflows/v1/subflows/__init__.py
@@ -15,6 +15,7 @@
     create_final_relationships,
 )
 from .create_final_text_units import create_final_text_units
+from .generate_text_embeddings import generate_text_embeddings
 
 __all__ = [
     "create_base_entity_graph",
@@ -27,4 +28,5 @@
     "create_final_nodes",
     "create_final_relationships",
     "create_final_text_units",
+    "generate_text_embeddings",
 ]
diff --git a/graphrag/index/workflows/v1/subflows/create_final_community_reports.py b/graphrag/index/workflows/v1/subflows/create_final_community_reports.py
index bd67fd23aa..ec52f7d564 100644
--- a/graphrag/index/workflows/v1/subflows/create_final_community_reports.py
+++ b/graphrag/index/workflows/v1/subflows/create_final_community_reports.py
@@ -30,9 +30,6 @@ async def create_final_community_reports(
     summarization_strategy: dict,
     async_mode: AsyncType = AsyncType.AsyncIO,
     num_threads: int = 4,
-    full_content_text_embed: dict | None = None,
-    summary_text_embed: dict | None = None,
-    title_text_embed: dict | None = None,
     **_kwargs: dict,
 ) -> VerbResult:
     """All the steps to transform community reports."""
@@ -57,9 +54,6 @@ async def create_final_community_reports(
         summarization_strategy,
         async_mode=async_mode,
         num_threads=num_threads,
-        full_content_text_embed=full_content_text_embed,
-        summary_text_embed=summary_text_embed,
-        title_text_embed=title_text_embed,
     )
 
     return create_verb_result(
diff --git a/graphrag/index/workflows/v1/subflows/create_final_documents.py b/graphrag/index/workflows/v1/subflows/create_final_documents.py
index 191bbd9ed0..34e2e5e018 100644
--- a/graphrag/index/workflows/v1/subflows/create_final_documents.py
+++ b/graphrag/index/workflows/v1/subflows/create_final_documents.py
@@ -8,13 +8,11 @@
 import pandas as pd
 from datashaper import (
     Table,
-    VerbCallbacks,
     VerbInput,
     verb,
 )
 from datashaper.table_store.types import VerbResult, create_verb_result
 
-from graphrag.index.cache import PipelineCache
 from graphrag.index.flows.create_final_documents import (
     create_final_documents as create_final_documents_flow,
 )
@@ -25,25 +23,15 @@
     name="create_final_documents",
     treats_input_tables_as_immutable=True,
 )
-async def create_final_documents(
+def create_final_documents(
     input: VerbInput,
-    callbacks: VerbCallbacks,
-    cache: PipelineCache,
     document_attribute_columns: list[str] | None = None,
-    raw_content_text_embed: dict | None = None,
     **_kwargs: dict,
 ) -> VerbResult:
     """All the steps to transform final documents."""
     source = cast(pd.DataFrame, input.get_input())
     text_units = cast(pd.DataFrame, get_required_input_table(input, "text_units").table)
 
-    output = await create_final_documents_flow(
-        source,
-        text_units,
-        callbacks,
-        cache,
-        document_attribute_columns=document_attribute_columns,
-        raw_content_text_embed=raw_content_text_embed,
-    )
+    output = create_final_documents_flow(source, text_units, document_attribute_columns)
 
     return create_verb_result(cast(Table, output))
diff --git a/graphrag/index/workflows/v1/subflows/create_final_entities.py b/graphrag/index/workflows/v1/subflows/create_final_entities.py
index fa3ae8981b..74fe1abcce 100644
--- a/graphrag/index/workflows/v1/subflows/create_final_entities.py
+++ b/graphrag/index/workflows/v1/subflows/create_final_entities.py
@@ -14,7 +14,6 @@
 )
 from datashaper.table_store.types import VerbResult, create_verb_result
 
-from graphrag.index.cache import PipelineCache
 from graphrag.index.flows.create_final_entities import (
     create_final_entities as create_final_entities_flow,
 )
@@ -24,23 +23,17 @@
     name="create_final_entities",
     treats_input_tables_as_immutable=True,
 )
-async def create_final_entities(
+def create_final_entities(
     input: VerbInput,
     callbacks: VerbCallbacks,
-    cache: PipelineCache,
-    name_text_embed: dict | None = None,
-    description_text_embed: dict | None = None,
     **_kwargs: dict,
 ) -> VerbResult:
     """All the steps to transform final entities."""
     source = cast(pd.DataFrame, input.get_input())
 
-    output = await create_final_entities_flow(
+    output = create_final_entities_flow(
         source,
         callbacks,
-        cache,
-        name_text_embed=name_text_embed,
-        description_text_embed=description_text_embed,
     )
 
     return create_verb_result(cast(Table, output))
diff --git a/graphrag/index/workflows/v1/subflows/create_final_relationships.py b/graphrag/index/workflows/v1/subflows/create_final_relationships.py
index 24222846ba..b502e34f3d 100644
--- a/graphrag/index/workflows/v1/subflows/create_final_relationships.py
+++ b/graphrag/index/workflows/v1/subflows/create_final_relationships.py
@@ -14,7 +14,6 @@
 )
 from datashaper.table_store.types import VerbResult, create_verb_result
 
-from graphrag.index.cache import PipelineCache
 from graphrag.index.flows.create_final_relationships import (
     create_final_relationships as create_final_relationships_flow,
 )
@@ -25,23 +24,19 @@
     name="create_final_relationships",
     treats_input_tables_as_immutable=True,
 )
-async def create_final_relationships(
+def create_final_relationships(
     input: VerbInput,
     callbacks: VerbCallbacks,
-    cache: PipelineCache,
-    description_text_embed: dict | None = None,
     **_kwargs: dict,
 ) -> VerbResult:
     """All the steps to transform final relationships."""
     source = cast(pd.DataFrame, input.get_input())
     nodes = cast(pd.DataFrame, get_required_input_table(input, "nodes").table)
 
-    output = await create_final_relationships_flow(
-        source,
-        nodes,
-        callbacks,
-        cache,
-        description_text_embed=description_text_embed,
+    output = create_final_relationships_flow(
+        entity_graph=source,
+        nodes=nodes,
+        callbacks=callbacks,
     )
 
     return create_verb_result(cast(Table, output))
diff --git a/graphrag/index/workflows/v1/subflows/create_final_text_units.py b/graphrag/index/workflows/v1/subflows/create_final_text_units.py
index 37ddb10697..a950dcb974 100644
--- a/graphrag/index/workflows/v1/subflows/create_final_text_units.py
+++ b/graphrag/index/workflows/v1/subflows/create_final_text_units.py
@@ -8,14 +8,12 @@
 import pandas as pd
 from datashaper import (
     Table,
-    VerbCallbacks,
     VerbInput,
     VerbResult,
     create_verb_result,
     verb,
 )
 
-from graphrag.index.cache import PipelineCache
 from graphrag.index.flows.create_final_text_units import (
     create_final_text_units as create_final_text_units_flow,
 )
@@ -26,10 +24,7 @@
 @verb(name="create_final_text_units", treats_input_tables_as_immutable=True)
 async def create_final_text_units(
     input: VerbInput,
-    callbacks: VerbCallbacks,
-    cache: PipelineCache,
     runtime_storage: PipelineStorage,
-    text_text_embed: dict | None = None,
     **_kwargs: dict,
 ) -> VerbResult:
     """All the steps to transform the text units."""
@@ -45,14 +40,11 @@ async def create_final_text_units(
     if final_covariates:
         final_covariates = cast(pd.DataFrame, final_covariates.table)
 
-    output = await create_final_text_units_flow(
+    output = create_final_text_units_flow(
         text_units,
         final_entities,
         final_relationships,
         final_covariates,
-        callbacks,
-        cache,
-        text_text_embed=text_text_embed,
     )
 
     return create_verb_result(cast(Table, output))
diff --git a/graphrag/index/workflows/v1/subflows/generate_text_embeddings.py b/graphrag/index/workflows/v1/subflows/generate_text_embeddings.py
new file mode 100644
index 0000000000..403be28404
--- /dev/null
+++ b/graphrag/index/workflows/v1/subflows/generate_text_embeddings.py
@@ -0,0 +1,70 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""All the steps to generate embeddings."""
+
+import logging
+from typing import cast
+
+import pandas as pd
+from datashaper import (
+    Table,
+    VerbCallbacks,
+    VerbInput,
+    VerbResult,
+    create_verb_result,
+    verb,
+)
+
+from graphrag.index.cache import PipelineCache
+from graphrag.index.flows.generate_text_embeddings import (
+    generate_text_embeddings as generate_text_embeddings_flow,
+)
+from graphrag.index.storage import PipelineStorage
+from graphrag.index.utils.ds_util import get_required_input_table
+
+log = logging.getLogger(__name__)
+
+
+@verb(name="generate_text_embeddings", treats_input_tables_as_immutable=True)
+async def generate_text_embeddings(
+    input: VerbInput,
+    callbacks: VerbCallbacks,
+    cache: PipelineCache,
+    storage: PipelineStorage,
+    text_embed: dict,
+    embedded_fields: set[str],
+    embeddings_snapshot_enabled: bool = False,
+    **_kwargs: dict,
+) -> VerbResult:
+    """All the steps to generate embeddings."""
+    source = cast(pd.DataFrame, input.get_input())
+    final_relationships = cast(
+        pd.DataFrame, get_required_input_table(input, "relationships").table
+    )
+    final_text_units = cast(
+        pd.DataFrame, get_required_input_table(input, "text_units").table
+    )
+    final_entities = cast(
+        pd.DataFrame, get_required_input_table(input, "entities").table
+    )
+
+    final_community_reports = cast(
+        pd.DataFrame, get_required_input_table(input, "community_reports").table
+    )
+
+    await generate_text_embeddings_flow(
+        final_documents=source,
+        final_relationships=final_relationships,
+        final_text_units=final_text_units,
+        final_entities=final_entities,
+        final_community_reports=final_community_reports,
+        callbacks=callbacks,
+        cache=cache,
+        storage=storage,
+        text_embed_config=text_embed,
+        embedded_fields=embedded_fields,
+        embeddings_snapshot_enabled=embeddings_snapshot_enabled,
+    )
+
+    return create_verb_result(cast(Table, pd.DataFrame()))
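When `embeddings_snapshot_enabled` is on, the flow persists one table per embedded field to pipeline storage. A sketch of inspecting such a snapshot, assuming file-based storage rooted at a hypothetical `output/` directory; the filename pattern and the two-column `id`/`embedding` shape are what the new verb test at the end of this change asserts:

```python
# Hypothetical inspection of a snapshot artifact; paths are illustrative.
from pathlib import Path

import pandas as pd

artifacts = Path("output")  # assumed artifacts directory
snapshot = pd.read_parquet(artifacts / "embeddings.entity.description.parquet")

print(snapshot.columns.tolist())          # ["id", "embedding"]
print(len(snapshot["embedding"].iloc[0])) # embedding vector length
```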
diff --git a/graphrag/vector_stores/lancedb.py b/graphrag/vector_stores/lancedb.py
index 9d8b24af6b..0d64a9edac 100644
--- a/graphrag/vector_stores/lancedb.py
+++ b/graphrag/vector_stores/lancedb.py
@@ -3,10 +3,9 @@
 
 """The LanceDB vector storage implementation package."""
 
-import json
+import json  # noqa: I001
 from typing import Any
 
-import lancedb as lancedb
 import pyarrow as pa
 
 from graphrag.model.types import TextEmbedder
@@ -16,6 +15,7 @@
     VectorStoreDocument,
     VectorStoreSearchResult,
 )
+import lancedb
 
 
 class LanceDBVectorStore(BaseVectorStore):
diff --git a/tests/fixtures/azure/settings.yml b/tests/fixtures/azure/settings.yml
index cb61192b9d..5ecec80990 100644
--- a/tests/fixtures/azure/settings.yml
+++ b/tests/fixtures/azure/settings.yml
@@ -6,9 +6,7 @@ embeddings:
     type: "azure_ai_search"
     url: ${AZURE_AI_SEARCH_URL_ENDPOINT}
     api_key: ${AZURE_AI_SEARCH_API_KEY}
-    collection_name: "azure_ci"
-  entity_name_description:
-    title_column: "name"
+    container_name: "azure_ci"
 
 input:
   type: blob
diff --git a/tests/fixtures/min-csv/config.json b/tests/fixtures/min-csv/config.json
index 5b135145d4..0b79d8dbca 100644
--- a/tests/fixtures/min-csv/config.json
+++ b/tests/fixtures/min-csv/config.json
@@ -8,7 +8,8 @@
             2500
         ],
         "subworkflows": 1,
-        "max_runtime": 150
+        "max_runtime": 150,
+        "expected_artifacts": 0
     },
     "create_base_entity_graph": {
         "row_range": [
@@ -16,7 +17,8 @@
             2500
         ],
         "subworkflows": 1,
-        "max_runtime": 300
+        "max_runtime": 300,
+        "expected_artifacts": 1
     },
     "create_final_entities": {
         "row_range": [
@@ -29,7 +31,8 @@
             "graph_embedding"
         ],
         "subworkflows": 1,
-        "max_runtime": 300
+        "max_runtime": 300,
+        "expected_artifacts": 1
     },
     "create_final_relationships": {
         "row_range": [
@@ -37,7 +40,8 @@
             6000
         ],
         "subworkflows": 1,
-        "max_runtime": 150
+        "max_runtime": 150,
+        "expected_artifacts": 1
     },
     "create_final_nodes": {
         "row_range": [
@@ -52,7 +56,8 @@
             "level"
         ],
         "subworkflows": 1,
-        "max_runtime": 150
+        "max_runtime": 150,
+        "expected_artifacts": 1
     },
     "create_final_communities": {
         "row_range": [
@@ -60,7 +65,8 @@
             2500
         ],
         "subworkflows": 1,
-        "max_runtime": 150
+        "max_runtime": 150,
+        "expected_artifacts": 1
     },
     "create_final_community_reports": {
         "row_range": [
@@ -78,7 +84,8 @@
             "findings"
         ],
         "subworkflows": 1,
-        "max_runtime": 300
+        "max_runtime": 300,
+        "expected_artifacts": 1
     },
     "create_final_text_units": {
         "row_range": [
@@ -90,7 +97,8 @@
             "entity_ids"
         ],
         "subworkflows": 1,
-        "max_runtime": 150
+        "max_runtime": 150,
+        "expected_artifacts": 1
     },
     "create_final_documents": {
         "row_range": [
@@ -98,7 +106,17 @@
             2500
         ],
         "subworkflows": 1,
-        "max_runtime": 150
+        "max_runtime": 150,
+        "expected_artifacts": 1
+    },
+    "generate_text_embeddings": {
+        "row_range": [
+            1,
+            2500
+        ],
+        "subworkflows": 1,
+        "max_runtime": 150,
+        "expected_artifacts": 1
     }
 },
 "query_config": [
diff --git a/tests/fixtures/min-csv/settings.yml b/tests/fixtures/min-csv/settings.yml
index 3b6954e25a..b3e5b4e92a 100644
--- a/tests/fixtures/min-csv/settings.yml
+++ b/tests/fixtures/min-csv/settings.yml
@@ -5,19 +5,8 @@ embeddings:
   vector_store:
     type: "lancedb"
     db_uri: "./tests/fixtures/min-csv/lancedb"
-    collection_name: "lancedb_ci"
+    container_name: "lancedb_ci"
     overwrite: True
-    store_in_table: True
-  entity_name_description:
-    title_column: "name"
-    # id_column: "id"
-  # entity_name: ...
-  # relationship_description: ...
-  # community_report_full_content: ...
-  # community_report_summary: ...
-  # community_report_title: ...
-  # document_raw_content: ...
-  # text_unit_text: ...
 
 storage:
   type: file # or blob
@@ -29,4 +18,7 @@ reporting:
   type: file # or console, blob
   base_dir: "output/${timestamp}/reports"
   # connection_string:
-  # container_name:
\ No newline at end of file
+  # container_name:
+
+snapshots:
+  embeddings: True
\ No newline at end of file
diff --git a/tests/fixtures/text/config.json b/tests/fixtures/text/config.json
index f6d945b9e1..d6a0ecfa54 100644
--- a/tests/fixtures/text/config.json
+++ b/tests/fixtures/text/config.json
@@ -8,7 +8,8 @@
             2500
         ],
         "subworkflows": 1,
-        "max_runtime": 150
+        "max_runtime": 150,
+        "expected_artifacts": 0
     },
     "create_final_covariates": {
         "row_range": [
@@ -25,7 +26,8 @@
             "source_text"
         ],
         "subworkflows": 1,
-        "max_runtime": 300
+        "max_runtime": 300,
+        "expected_artifacts": 1
     },
     "create_base_entity_graph": {
         "row_range": [
@@ -33,7 +35,8 @@
             2500
         ],
         "subworkflows": 1,
-        "max_runtime": 300
+        "max_runtime": 300,
+        "expected_artifacts": 1
     },
     "create_final_entities": {
         "row_range": [
@@ -46,7 +49,8 @@
             "graph_embedding"
         ],
         "subworkflows": 1,
-        "max_runtime": 300
+        "max_runtime": 300,
+        "expected_artifacts": 1
     },
     "create_final_relationships": {
         "row_range": [
@@ -54,7 +58,8 @@
             6000
         ],
         "subworkflows": 1,
-        "max_runtime": 150
+        "max_runtime": 150,
+        "expected_artifacts": 1
     },
     "create_final_nodes": {
         "row_range": [
@@ -69,7 +74,8 @@
             "level"
         ],
         "subworkflows": 1,
-        "max_runtime": 150
+        "max_runtime": 150,
+        "expected_artifacts": 1
     },
     "create_final_communities": {
         "row_range": [
@@ -77,7 +83,8 @@
             2500
         ],
         "subworkflows": 1,
-        "max_runtime": 150
+        "max_runtime": 150,
+        "expected_artifacts": 1
     },
     "create_final_community_reports": {
         "row_range": [
@@ -95,7 +102,8 @@
             "findings"
         ],
         "subworkflows": 1,
-        "max_runtime": 300
+        "max_runtime": 300,
+        "expected_artifacts": 1
     },
     "create_final_text_units": {
         "row_range": [
@@ -107,7 +115,8 @@
             "entity_ids"
         ],
         "subworkflows": 1,
-        "max_runtime": 150
+        "max_runtime": 150,
+        "expected_artifacts": 1
     },
     "create_final_documents": {
         "row_range": [
@@ -115,7 +124,17 @@
             2500
         ],
         "subworkflows": 1,
-        "max_runtime": 150
+        "max_runtime": 150,
+        "expected_artifacts": 1
+    },
+    "generate_text_embeddings": {
+        "row_range": [
+            1,
+            2500
+        ],
+        "subworkflows": 1,
+        "max_runtime": 150,
+        "expected_artifacts": 1
     }
 },
 "query_config": [
diff --git a/tests/fixtures/text/settings.yml b/tests/fixtures/text/settings.yml
index 37d7f09c38..51cccf3d15 100644
--- a/tests/fixtures/text/settings.yml
+++ b/tests/fixtures/text/settings.yml
@@ -6,10 +6,7 @@ embeddings:
     type: "azure_ai_search"
     url: ${AZURE_AI_SEARCH_URL_ENDPOINT}
     api_key: ${AZURE_AI_SEARCH_API_KEY}
-    collection_name: "simple_text_ci"
-    store_in_table: True
-  entity_name_description:
-    title_column: "name"
+    container_name: "simple_text_ci"
 
 community_reports:
   prompt: "prompts/community_report.txt"
@@ -27,4 +24,7 @@ reporting:
   type: file # or console, blob
   base_dir: "output/${timestamp}/reports"
   # connection_string:
-  # container_name:
\ No newline at end of file
+  # container_name:
+
+snapshots:
+  embeddings: True
\ No newline at end of file
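The fixture migration above is mechanical: `collection_name` becomes `container_name`, the per-field embed blocks (`entity_name_description`, `store_in_table`, and the commented field list) disappear, and an optional `snapshots.embeddings` flag requests the parquet snapshots. Expressed as a Python dict for reference (values mirror the lancedb fixture; illustrative only):

```python
# Illustrative equivalent of the migrated lancedb fixture settings.
embeddings_settings = {
    "vector_store": {
        "type": "lancedb",
        "db_uri": "./tests/fixtures/min-csv/lancedb",
        "container_name": "lancedb_ci",  # replaces the old collection_name
        "overwrite": True,
    }
}
snapshots_settings = {"embeddings": True}
```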
diff --git a/tests/smoke/test_fixtures.py b/tests/smoke/test_fixtures.py
index 385f32bd22..a668ef9b22 100644
--- a/tests/smoke/test_fixtures.py
+++ b/tests/smoke/test_fixtures.py
@@ -172,6 +172,7 @@ def __assert_indexer_outputs(
         stats = json.loads((artifacts / "stats.json").read_bytes().decode("utf-8"))
 
         # Check all workflows run
+        expected_artifacts = 0
         expected_workflows = set(workflow_config.keys())
         workflows = set(stats["workflows"].keys())
         assert (
@@ -180,6 +181,10 @@ def __assert_indexer_outputs(
 
         # [OPTIONAL] Check runtime
         for workflow in expected_workflows:
+            # Check expected artifacts
+            expected_artifacts = expected_artifacts + workflow_config[workflow].get(
+                "expected_artifacts", 1
+            )
             # Check max runtime
             max_runtime = workflow_config[workflow].get("max_runtime", None)
             if max_runtime:
@@ -189,38 +194,40 @@ def __assert_indexer_outputs(
 
         # Check artifacts
         artifact_files = os.listdir(artifacts)
-        # check that the number of workflows matches the number of artifacts, but:
-        # (1) do not count workflows with only transient output
-        # (2) account for the stats.json file
-        transient_workflows = [
-            "workflow:create_base_text_units",
-        ]
+
+        # check that the expected number of artifacts were written (plus stats.json)
         assert (
-            len(artifact_files)
-            == (len(expected_workflows) - len(transient_workflows) + 1)
+            len(artifact_files) == (expected_artifacts + 1)
-        ), f"Expected {len(expected_workflows) + 1} artifacts, found: {len(artifact_files)}"
+        ), f"Expected {expected_artifacts + 1} artifacts, found: {len(artifact_files)}"
 
         for artifact in artifact_files:
             if artifact.endswith(".parquet"):
                 output_df = pd.read_parquet(artifacts / artifact)
                 artifact_name = artifact.split(".")[0]
-                workflow = workflow_config[artifact_name]
-
-                # Check number of rows between range
-                assert (
-                    workflow["row_range"][0]
-                    <= len(output_df)
-                    <= workflow["row_range"][1]
-                ), f"Expected between {workflow['row_range'][0]} and {workflow['row_range'][1]}, found: {len(output_df)} for file: {artifact}"
-
-                # Get non-nan rows
-                nan_df = output_df.loc[
-                    :, ~output_df.columns.isin(workflow.get("nan_allowed_columns", []))
-                ]
-                nan_df = nan_df[nan_df.isna().any(axis=1)]
-                assert (
-                    len(nan_df) == 0
-                ), f"Found {len(nan_df)} rows with NaN values for file: {artifact} on columns: {nan_df.columns[nan_df.isna().any()].tolist()}"
+                try:
+                    workflow = workflow_config[artifact_name]
+
+                    # Check number of rows between range
+                    assert (
+                        workflow["row_range"][0]
+                        <= len(output_df)
+                        <= workflow["row_range"][1]
+                    ), f"Expected between {workflow['row_range'][0]} and {workflow['row_range'][1]}, found: {len(output_df)} for file: {artifact}"
+
+                    # Get non-nan rows
+                    nan_df = output_df.loc[
+                        :,
+                        ~output_df.columns.isin(
+                            workflow.get("nan_allowed_columns", [])
+                        ),
+                    ]
+                    nan_df = nan_df[nan_df.isna().any(axis=1)]
+                    assert (
+                        len(nan_df) == 0
+                    ), f"Found {len(nan_df)} rows with NaN values for file: {artifact} on columns: {nan_df.columns[nan_df.isna().any()].tolist()}"
+                except KeyError:
+                    log.warning("No workflow config found for %s", artifact_name)
 
     def __run_query(self, root: Path, query_config: dict[str, str]):
         command = [
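The smoke test now derives its artifact count from configuration instead of a hard-coded transient-workflow list: each workflow contributes its `expected_artifacts` (defaulting to 1), and stats.json accounts for the final +1. The arithmetic as a standalone sketch, with hypothetical fixture values drawn from the config files above:

```python
# Standalone sketch of the artifact-count rule applied by the smoke test.
workflow_config = {
    "create_base_text_units": {"expected_artifacts": 0},  # transient output only
    "create_final_documents": {},                         # defaults to 1
    "generate_text_embeddings": {"expected_artifacts": 1},
}

expected_artifacts = sum(
    cfg.get("expected_artifacts", 1) for cfg in workflow_config.values()
)
expected_files = expected_artifacts + 1  # stats.json is always written
assert expected_files == 3
```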
config["skip_summary_embedding"] = False - config["community_report_summary_embed"]["strategy"]["type"] = "mock" - config["skip_title_embedding"] = False - config["community_report_title_embed"]["strategy"]["type"] = "mock" - - steps = build_steps(config) - - actual = await get_workflow_output( - input_tables, - { - "steps": steps, - }, - ) - - assert len(actual.columns) == len(expected.columns) + 3 - assert "full_content_embedding" in actual.columns - assert len(actual["full_content_embedding"][:1][0]) == 3 - assert "summary_embedding" in actual.columns - assert len(actual["summary_embedding"][:1][0]) == 3 - assert "title_embedding" in actual.columns - assert len(actual["title_embedding"][:1][0]) == 3 - - async def test_create_final_community_reports_missing_llm_throws(): input_tables = load_input_tables([ "workflow:create_final_nodes", diff --git a/tests/verbs/test_create_final_documents.py b/tests/verbs/test_create_final_documents.py index ac42fb7ed2..6d8138088e 100644 --- a/tests/verbs/test_create_final_documents.py +++ b/tests/verbs/test_create_final_documents.py @@ -23,8 +23,6 @@ async def test_create_final_documents(): config = get_config_for_workflow(workflow_name) - config["skip_raw_content_embedding"] = True - steps = build_steps(config) actual = await get_workflow_output( @@ -37,34 +35,6 @@ async def test_create_final_documents(): compare_outputs(actual, expected) -async def test_create_final_documents_with_embeddings(): - input_tables = load_input_tables([ - "workflow:create_final_text_units", - ]) - expected = load_expected(workflow_name) - - config = get_config_for_workflow(workflow_name) - - config["skip_raw_content_embedding"] = False - # default config has a detailed standard embed config - # just override the strategy to mock so the rest of the required parameters are in place - config["document_raw_content_embed"]["strategy"]["type"] = "mock" - - steps = build_steps(config) - - actual = await get_workflow_output( - input_tables, - { - "steps": steps, - }, - ) - - assert "raw_content_embedding" in actual.columns - assert len(actual.columns) == len(expected.columns) + 1 - # the mock impl returns an array of 3 floats for each embedding - assert len(actual["raw_content_embedding"][:1][0]) == 3 - - async def test_create_final_documents_with_attribute_columns(): input_tables = load_input_tables(["workflow:create_final_text_units"]) expected = load_expected(workflow_name) diff --git a/tests/verbs/test_create_final_entities.py b/tests/verbs/test_create_final_entities.py index d63776aef4..e65c09639e 100644 --- a/tests/verbs/test_create_final_entities.py +++ b/tests/verbs/test_create_final_entities.py @@ -23,9 +23,6 @@ async def test_create_final_entities(): config = get_config_for_workflow(workflow_name) - config["skip_name_embedding"] = True - config["skip_description_embedding"] = True - steps = build_steps(config) actual = await get_workflow_output( @@ -50,83 +47,3 @@ async def test_create_final_entities(): ], ) assert len(actual.columns) == len(expected.columns) - 1 - - -async def test_create_final_entities_with_name_embeddings(): - input_tables = load_input_tables([ - "workflow:create_base_entity_graph", - ]) - expected = load_expected(workflow_name) - - config = get_config_for_workflow(workflow_name) - - config["skip_name_embedding"] = False - config["skip_description_embedding"] = True - config["entity_name_embed"]["strategy"]["type"] = "mock" - - steps = build_steps(config) - - actual = await get_workflow_output( - input_tables, - { - "steps": steps, - }, - ) - - assert 
"name_embedding" in actual.columns - assert len(actual.columns) == len(expected.columns) - # the mock impl returns an array of 3 floats for each embedding - assert len(actual["name_embedding"][:1][0]) == 3 - - -async def test_create_final_entities_with_description_embeddings(): - input_tables = load_input_tables([ - "workflow:create_base_entity_graph", - ]) - expected = load_expected(workflow_name) - - config = get_config_for_workflow(workflow_name) - - config["skip_name_embedding"] = True - config["skip_description_embedding"] = False - config["entity_name_description_embed"]["strategy"]["type"] = "mock" - - steps = build_steps(config) - - actual = await get_workflow_output( - input_tables, - { - "steps": steps, - }, - ) - - assert "description_embedding" in actual.columns - assert len(actual.columns) == len(expected.columns) - assert len(actual["description_embedding"][:1][0]) == 3 - - -async def test_create_final_entities_with_name_and_description_embeddings(): - input_tables = load_input_tables([ - "workflow:create_base_entity_graph", - ]) - expected = load_expected(workflow_name) - - config = get_config_for_workflow(workflow_name) - - config["skip_name_embedding"] = False - config["skip_description_embedding"] = False - config["entity_name_description_embed"]["strategy"]["type"] = "mock" - config["entity_name_embed"]["strategy"]["type"] = "mock" - - steps = build_steps(config) - - actual = await get_workflow_output( - input_tables, - { - "steps": steps, - }, - ) - - assert "description_embedding" in actual.columns - assert len(actual.columns) == len(expected.columns) + 1 - assert len(actual["description_embedding"][:1][0]) == 3 diff --git a/tests/verbs/test_create_final_relationships.py b/tests/verbs/test_create_final_relationships.py index d3ac5c807a..c5877014ff 100644 --- a/tests/verbs/test_create_final_relationships.py +++ b/tests/verbs/test_create_final_relationships.py @@ -24,8 +24,6 @@ async def test_create_final_relationships(): config = get_config_for_workflow(workflow_name) - config["skip_description_embedding"] = True - steps = build_steps(config) actual = await get_workflow_output( @@ -36,32 +34,3 @@ async def test_create_final_relationships(): ) compare_outputs(actual, expected) - - -async def test_create_final_relationships_with_embeddings(): - input_tables = load_input_tables([ - "workflow:create_base_entity_graph", - "workflow:create_final_nodes", - ]) - expected = load_expected(workflow_name) - - config = get_config_for_workflow(workflow_name) - - config["skip_description_embedding"] = False - # default config has a detailed standard embed config - # just override the strategy to mock so the rest of the required parameters are in place - config["relationship_description_embed"]["strategy"]["type"] = "mock" - - steps = build_steps(config) - - actual = await get_workflow_output( - input_tables, - { - "steps": steps, - }, - ) - - assert "description_embedding" in actual.columns - assert len(actual.columns) == len(expected.columns) + 1 - # the mock impl returns an array of 3 floats for each embedding - assert len(actual["description_embedding"][:1][0]) == 3 diff --git a/tests/verbs/test_create_final_text_units.py b/tests/verbs/test_create_final_text_units.py index 9560d06ae2..1aa919d153 100644 --- a/tests/verbs/test_create_final_text_units.py +++ b/tests/verbs/test_create_final_text_units.py @@ -33,7 +33,6 @@ async def test_create_final_text_units(): config = get_config_for_workflow(workflow_name) config["covariates_enabled"] = True - config["skip_text_unit_embedding"] = 
diff --git a/tests/verbs/test_create_final_text_units.py b/tests/verbs/test_create_final_text_units.py
index 9560d06ae2..1aa919d153 100644
--- a/tests/verbs/test_create_final_text_units.py
+++ b/tests/verbs/test_create_final_text_units.py
@@ -33,7 +33,6 @@ async def test_create_final_text_units():
     config = get_config_for_workflow(workflow_name)
 
     config["covariates_enabled"] = True
-    config["skip_text_unit_embedding"] = True
 
     steps = build_steps(config)
 
@@ -65,7 +64,6 @@ async def test_create_final_text_units_no_covariates():
     config = get_config_for_workflow(workflow_name)
 
     config["covariates_enabled"] = False
-    config["skip_text_unit_embedding"] = True
 
     steps = build_steps(config)
 
@@ -83,41 +81,3 @@ async def test_create_final_text_units_no_covariates():
         expected,
         ["id", "text", "n_tokens", "document_ids", "entity_ids", "relationship_ids"],
     )
-
-
-async def test_create_final_text_units_with_embeddings():
-    input_tables = load_input_tables([
-        "workflow:create_base_text_units",
-        "workflow:create_final_entities",
-        "workflow:create_final_relationships",
-        "workflow:create_final_covariates",
-    ])
-    expected = load_expected(workflow_name)
-
-    context = create_run_context(None, None, None)
-    await context.runtime_storage.set(
-        "base_text_units", input_tables["workflow:create_base_text_units"]
-    )
-
-    config = get_config_for_workflow(workflow_name)
-
-    config["covariates_enabled"] = True
-    config["skip_text_unit_embedding"] = False
-    # default config has a detailed standard embed config
-    # just override the strategy to mock so the rest of the required parameters are in place
-    config["text_unit_text_embed"]["strategy"]["type"] = "mock"
-
-    steps = build_steps(config)
-
-    actual = await get_workflow_output(
-        input_tables,
-        {
-            "steps": steps,
-        },
-        context=context,
-    )
-
-    assert "text_embedding" in actual.columns
-    assert len(actual.columns) == len(expected.columns) + 1
-    # the mock impl returns an array of 3 floats for each embedding
-    assert len(actual["text_embedding"][:1][0]) == 3
diff --git a/tests/verbs/test_generate_text_embeddings.py b/tests/verbs/test_generate_text_embeddings.py
new file mode 100644
index 0000000000..4f192ca1fd
--- /dev/null
+++ b/tests/verbs/test_generate_text_embeddings.py
@@ -0,0 +1,82 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+from io import BytesIO
+
+import pandas as pd
+
+from graphrag.index.config.embeddings import (
+    all_embeddings,
+)
+from graphrag.index.run.utils import create_run_context
+from graphrag.index.workflows.v1.generate_text_embeddings import (
+    build_steps,
+    workflow_name,
+)
+
+from .util import (
+    get_config_for_workflow,
+    get_workflow_output,
+    load_input_tables,
+)
+
+
+async def test_generate_text_embeddings():
+    input_tables = load_input_tables(
+        inputs=[
+            "workflow:create_final_documents",
+            "workflow:create_final_relationships",
+            "workflow:create_final_text_units",
+            "workflow:create_final_entities",
+            "workflow:create_final_community_reports",
+        ]
+    )
+    context = create_run_context(None, None, None)
+
+    config = get_config_for_workflow(workflow_name)
+
+    config["text_embed"]["strategy"]["type"] = "mock"
+    config["snapshot_embeddings"] = True
+
+    config["embedded_fields"] = all_embeddings
+
+    steps = build_steps(config)
+
+    await get_workflow_output(
+        input_tables,
+        {
+            "steps": steps,
+        },
+        context,
+    )
+
+    parquet_files = context.storage.keys()
+
+    for field in all_embeddings:
+        assert f"embeddings.{field}.parquet" in parquet_files
+
+    # entity description should always be here, so assert its format
+    entity_description_embeddings_buffer = BytesIO(
+        await context.storage.get(
+            "embeddings.entity.description.parquet", as_bytes=True
+        )
+    )
+    entity_description_embeddings = pd.read_parquet(
+        entity_description_embeddings_buffer
+    )
+    assert len(entity_description_embeddings.columns) == 2
+    assert "id" in entity_description_embeddings.columns
+    assert "embedding" in entity_description_embeddings.columns
+
+    # every other embedding is optional, but we've turned them all on, so check a random one
+    document_raw_content_embeddings_buffer = BytesIO(
+        await context.storage.get(
+            "embeddings.document.raw_content.parquet", as_bytes=True
+        )
+    )
+    document_raw_content_embeddings = pd.read_parquet(
+        document_raw_content_embeddings_buffer
+    )
+    assert len(document_raw_content_embeddings.columns) == 2
+    assert "id" in document_raw_content_embeddings.columns
+    assert "embedding" in document_raw_content_embeddings.columns
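Since the test sets `embedded_fields` to `all_embeddings`, every known field should yield exactly one snapshot file. A sketch of enumerating the expected artifact names, assuming a development install of graphrag; `all_embeddings` is the same set the test imports:

```python
# Hypothetical enumeration of every snapshot the test above expects to find.
from graphrag.index.config.embeddings import all_embeddings

expected_snapshots = sorted(f"embeddings.{field}.parquet" for field in all_embeddings)
for name in expected_snapshots:
    print(name)  # e.g. embeddings.entity.description.parquet
```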