diff --git a/.semversioner/next-release/minor-20241113010525824646.json b/.semversioner/next-release/minor-20241113010525824646.json new file mode 100644 index 0000000000..f7cd8962f9 --- /dev/null +++ b/.semversioner/next-release/minor-20241113010525824646.json @@ -0,0 +1,4 @@ +{ + "type": "minor", + "description": "Data model changes." +} diff --git a/.semversioner/next-release/patch-20241031230557819462.json b/.semversioner/next-release/patch-20241031230557819462.json new file mode 100644 index 0000000000..3f5ea8a3d1 --- /dev/null +++ b/.semversioner/next-release/patch-20241031230557819462.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "Cleanup of artifact outputs/schemas." +} diff --git a/docs/config/env_vars.md b/docs/config/env_vars.md index 21a11422ec..6632ace9e4 100644 --- a/docs/config/env_vars.md +++ b/docs/config/env_vars.md @@ -9,8 +9,8 @@ If the embedding target is `all`, and you want to only embed a subset of these f ### Embedded Fields - `text_unit.text` -- `document.raw_content` -- `entity.name` +- `document.text` +- `entity.title` - `entity.description` - `relationship.description` - `community.title` diff --git a/docs/examples_notebooks/drift_search.ipynb b/docs/examples_notebooks/drift_search.ipynb index 327c9513e4..c1b06a564d 100644 --- a/docs/examples_notebooks/drift_search.ipynb +++ b/docs/examples_notebooks/drift_search.ipynb @@ -204,7 +204,7 @@ "# load description embeddings to an in-memory lancedb vectorstore\n", "# to connect to a remote db, specify url and port values.\n", "description_embedding_store = LanceDBVectorStore(\n", - " collection_name=\"entity_description_embeddings\",\n", + " collection_name=\"default-entity-description\",\n", ")\n", "description_embedding_store.connect(db_uri=LANCEDB_URI)\n", "entity_description_embeddings = store_entity_semantic_embeddings(\n", diff --git a/docs/examples_notebooks/local_search.ipynb b/docs/examples_notebooks/local_search.ipynb index af4dccbbd1..8f9afb3504 100644 --- a/docs/examples_notebooks/local_search.ipynb +++ b/docs/examples_notebooks/local_search.ipynb @@ -108,7 +108,7 @@ "# load description embeddings to an in-memory lancedb vectorstore\n", "# to connect to a remote db, specify url and port values.\n", "description_embedding_store = LanceDBVectorStore(\n", - " collection_name=\"entity.description\",\n", + " collection_name=\"default-entity-description\",\n", ")\n", "description_embedding_store.connect(db_uri=LANCEDB_URI)\n", "entity_description_embeddings = store_entity_semantic_embeddings(\n", diff --git a/docs/index/default_dataflow.md b/docs/index/default_dataflow.md index 7ac990ffe2..1397d8d85f 100644 --- a/docs/index/default_dataflow.md +++ b/docs/index/default_dataflow.md @@ -9,7 +9,8 @@ The knowledge model is a specification for data outputs that conform to our data - `Entity` - An entity extracted from a TextUnit. These represent people, places, events, or some other entity-model that you provide. - `Relationship` - A relationship between two entities. These are generated from the covariates. - `Covariate` - Extracted claim information, which contains statements about entities which may be time-bound. -- `Community Report` - Once entities are generated, we perform hierarchical community detection on them and generate reports for each community in this hierarchy. +- `Community` - Once the graph of entities and relationships is built, we perform hierarchical community detection on them to create a clustering structure. 
+- `Community Report` - The contents of each community are summarized into a generated report, useful for human reading and downstream search.
 - `Node` - This table contains layout information for rendered graph-views of the Entities and Documents which have been embedded and clustered.
 
 ## The Default Configuration Workflow
diff --git a/docs/index/outputs.md b/docs/index/outputs.md
new file mode 100644
index 0000000000..998995467c
--- /dev/null
+++ b/docs/index/outputs.md
@@ -0,0 +1,89 @@
+# Outputs
+
+The default pipeline produces a series of output tables that align with the [conceptual knowledge model](../index/default_dataflow.md). This page describes the detailed output table schemas. By default we write these tables out as parquet files on disk.
+
+## Shared fields
+All tables have two identifier fields:
+- id: str - Generated UUID, assuring global uniqueness
+- human_readable_id: int - This is an incremented short ID created per-run. For example, we use this short ID with generated summaries that print citations so they are easy to cross-reference visually.
+
+## create_final_communities
+This is a list of the final communities generated by Leiden. Communities are strictly hierarchical, subdividing into children as the cluster affinity is narrowed.
+- community: int - Leiden-generated cluster ID for the community. Note that these increment with depth, so they are unique through all levels of the community hierarchy. For this table, human_readable_id is a copy of the community ID rather than a plain increment.
+- level: int - Depth of the community in the hierarchy.
+- title: str - Friendly name of the community.
+- entity_ids - List of entities that are members of the community.
+- relationship_ids - List of relationships that are wholly within the community (source and target are both in the community).
+- text_unit_ids - List of text units represented within the community.
+- period - Date of ingest, used for incremental update merges.
+- size - Size of the community (entity count), used for incremental update merges.
+
+## create_final_community_reports
+This is the list of summarized reports for each community.
+- community: int - Short ID of the community this report applies to.
+- level: int - Level of the community this report applies to.
+- title: str - LLM-generated title for the report.
+- summary: str - LLM-generated summary of the report.
+- full_content: str - LLM-generated full report.
+- rank: float - LLM-derived relevance ranking of the report based on member entity salience.
+- rank_explanation - LLM-derived explanation of the rank.
+- findings: dict - LLM-derived list of the top 5-10 insights from the community. Contains `summary` and `explanation` values.
+- full_content_json - Full JSON output as returned by the LLM. Most fields are extracted into columns, but this JSON is sent for query summarization so we leave it to allow for prompt tuning to add fields/content by end users.
+- period - Date of ingest, used for incremental update merges.
+- size - Size of the community (entity count), used for incremental update merges.
+
+## create_final_covariates
+(Optional) If claim extraction is turned on, this is a list of the extracted covariates. Note that claims are typically oriented around identifying malicious behavior such as fraud, so they are not useful for all datasets.
+- covariate_type: str - This is always "claim" with our default covariates.
+- type: str - Nature of the claim type.
+- description: str - LLM-generated description of the behavior.
+- subject_id: str - Name of the source entity (that is performing the claimed behavior).
+- object_id: str - Name of the target entity (that the claimed behavior is performed on).
+- status: str [TRUE, FALSE, SUSPECTED] - LLM-derived assessment of the correctness of the claim.
+- start_date: str (ISO8601) - LLM-derived start of the claimed activity.
+- end_date: str (ISO8601) - LLM-derived end of the claimed activity.
+- source_text: str - Short string of text containing the claimed behavior.
+- text_unit_id: str - ID of the text unit the claim text was extracted from.
+
+## create_final_documents
+List of document content after import.
+- title: str - Filename, unless otherwise configured during CSV import.
+- text: str - Full text of the document.
+- text_unit_ids: str[] - List of text units (chunks) that were parsed from the document.
+- attributes: dict (optional) - If specified during CSV import, this is a dict of attributes for the document.
+
+## create_final_entities
+List of all entities found in the data by the LLM.
+- title: str - Name of the entity.
+- type: str - Type of the entity. By default this will be "organization", "person", "geo", or "event" unless configured differently or auto-tuning is used.
+- description: str - Textual description of the entity. Entities may be found in many text units, so this is an LLM-derived summary of all descriptions.
+- text_unit_ids: str[] - List of the text units containing the entity.
+
+## create_final_nodes
+This is graph-related information for the entities. It contains only information relevant to the graph such as community. There is an entry for each entity at every community level it is found within, so you may see "duplicate" entities.
+
+Note that the ID fields match those in create_final_entities and can be used for joining if additional information about a node is required.
+- title: str - Name of the referenced entity. Duplicated from create_final_entities for convenient cross-referencing.
+- community: int - Leiden community the node is found within. Entities are not always assigned a community (they may not be close enough to any), so they may have an ID of -1.
+- level: int - Level of the community the entity is in.
+- degree: int - Node degree (connectedness) in the graph.
+- x: float - X position of the node for visual layouts. If graph embeddings and UMAP are not turned on, this will be 0.
+- y: float - Y position of the node for visual layouts. If graph embeddings and UMAP are not turned on, this will be 0.
+
+## create_final_relationships
+List of all entity-to-entity relationships found in the data by the LLM. This is also the _edge list_ for the graph.
+- source: str - Name of the source entity.
+- target: str - Name of the target entity.
+- description: str - LLM-derived description of the relationship. Also see note for entity descriptions.
+- weight: float - Weight of the edge in the graph. This is summed from an LLM-derived "strength" measure for each relationship instance.
+- combined_degree: int - Sum of source and target node degrees.
+- text_unit_ids: str[] - List of text units the relationship was found within.
+
+## create_final_text_units
+List of all text chunks parsed from the input documents.
+- text: str - Raw full text of the chunk.
+- n_tokens: int - Number of tokens in the chunk. This should normally match the `chunk_size` config parameter, except for the last chunk which is often shorter.
+- document_ids: str[] - List of document IDs the chunk came from.
This is normally only 1 due to our default groupby, but for very short text documents (e.g., microblogs) it can be configured so text units span multiple documents. +- entity_ids: str[] - List of entities found in the text unit. +- relationships_ids: str[] - List of relationships found in the text unit. +- covariate_ids: str[] - Optional list of covariates found in the text unit. \ No newline at end of file diff --git a/examples_notebooks/community_contrib/yfiles-jupyter-graphs/graph-visualization.ipynb b/examples_notebooks/community_contrib/yfiles-jupyter-graphs/graph-visualization.ipynb index 3874b172b1..bbb696c899 100644 --- a/examples_notebooks/community_contrib/yfiles-jupyter-graphs/graph-visualization.ipynb +++ b/examples_notebooks/community_contrib/yfiles-jupyter-graphs/graph-visualization.ipynb @@ -299,7 +299,7 @@ "entities = read_indexer_entities(entity_df, entity_embedding_df, COMMUNITY_LEVEL)\n", "\n", "description_embedding_store = LanceDBVectorStore(\n", - " collection_name=\"entity.description\",\n", + " collection_name=\"default-entity-description\",\n", ")\n", "description_embedding_store.connect(db_uri=LANCEDB_URI)\n", "entity_description_embeddings = store_entity_semantic_embeddings(\n", diff --git a/graphrag/api/query.py b/graphrag/api/query.py index 047d4e3830..4c7941b5f1 100644 --- a/graphrag/api/query.py +++ b/graphrag/api/query.py @@ -25,6 +25,10 @@ from pydantic import validate_call from graphrag.config import GraphRagConfig +from graphrag.index.config.embeddings import ( + community_full_content_embedding, + entity_description_embedding, +) from graphrag.logging import PrintProgressReporter from graphrag.query.factories import ( get_drift_search_engine, @@ -42,6 +46,7 @@ ) from graphrag.query.structured_search.base import SearchResult # noqa: TCH001 from graphrag.utils.cli import redact +from graphrag.utils.embeddings import create_collection_name from graphrag.vector_stores import VectorStoreFactory, VectorStoreType from graphrag.vector_stores.base import BaseVectorStore @@ -228,7 +233,7 @@ async def local_search( description_embedding_store = _get_embedding_store( config_args=vector_store_args, # type: ignore - container_suffix="entity-description", + embedding_name=entity_description_embedding, ) _entities = read_indexer_entities(nodes, entities, community_level) @@ -302,7 +307,7 @@ async def local_search_streaming( description_embedding_store = _get_embedding_store( config_args=vector_store_args, # type: ignore - container_suffix="entity-description", + embedding_name=entity_description_embedding, ) _entities = read_indexer_entities(nodes, entities, community_level) @@ -385,12 +390,12 @@ async def drift_search( description_embedding_store = _get_embedding_store( config_args=vector_store_args, # type: ignore - container_suffix="entity-description", + embedding_name=entity_description_embedding, ) full_content_embedding_store = _get_embedding_store( config_args=vector_store_args, # type: ignore - container_suffix="community-full_content", + embedding_name=community_full_content_embedding, ) _entities = read_indexer_entities(nodes, entities, community_level) @@ -450,7 +455,10 @@ def _patch_vector_store( } description_embedding_store = LanceDBVectorStore( db_uri=config.embeddings.vector_store["db_uri"], - collection_name="default-entity-description", + collection_name=create_collection_name( + config.embeddings.vector_store["container_name"], + entity_description_embedding, + ), overwrite=config.embeddings.vector_store["overwrite"], ) 
description_embedding_store.connect( @@ -469,11 +477,7 @@ def _patch_vector_store( from graphrag.vector_stores.lancedb import LanceDBVectorStore community_reports = with_reports - collection_name = ( - config.embeddings.vector_store.get("container_name", "default") - if config.embeddings.vector_store - else "default" - ) + container_name = config.embeddings.vector_store["container_name"] # Store report embeddings _reports = read_indexer_reports( community_reports, @@ -485,7 +489,9 @@ def _patch_vector_store( full_content_embedding_store = LanceDBVectorStore( db_uri=config.embeddings.vector_store["db_uri"], - collection_name=f"{collection_name}-community-full_content", + collection_name=create_collection_name( + container_name, community_full_content_embedding + ), overwrite=config.embeddings.vector_store["overwrite"], ) full_content_embedding_store.connect( @@ -501,12 +507,12 @@ def _patch_vector_store( def _get_embedding_store( config_args: dict, - container_suffix: str, + embedding_name: str, ) -> BaseVectorStore: """Get the embedding description store.""" vector_store_type = config_args["type"] - collection_name = ( - f"{config_args.get('container_name', 'default')}-{container_suffix}" + collection_name = create_collection_name( + config_args.get("container_name", "default"), embedding_name ) embedding_store = VectorStoreFactory.get_vector_store( vector_store_type=vector_store_type, diff --git a/graphrag/index/config/__init__.py b/graphrag/index/config/__init__.py index 885e6ad386..847659cd24 100644 --- a/graphrag/index/config/__init__.py +++ b/graphrag/index/config/__init__.py @@ -16,9 +16,9 @@ community_full_content_embedding, community_summary_embedding, community_title_embedding, - document_raw_content_embedding, + document_text_embedding, entity_description_embedding, - entity_name_embedding, + entity_title_embedding, relationship_description_embedding, required_embeddings, text_unit_text_embedding, @@ -82,9 +82,9 @@ "community_full_content_embedding", "community_summary_embedding", "community_title_embedding", - "document_raw_content_embedding", + "document_text_embedding", "entity_description_embedding", - "entity_name_embedding", + "entity_title_embedding", "relationship_description_embedding", "required_embeddings", "text_unit_text_embedding", diff --git a/graphrag/index/config/embeddings.py b/graphrag/index/config/embeddings.py index 656dd8c0c8..02e9c912c4 100644 --- a/graphrag/index/config/embeddings.py +++ b/graphrag/index/config/embeddings.py @@ -3,20 +3,20 @@ """A module containing embeddings values.""" -entity_name_embedding = "entity.name" +entity_title_embedding = "entity.title" entity_description_embedding = "entity.description" relationship_description_embedding = "relationship.description" -document_raw_content_embedding = "document.raw_content" +document_text_embedding = "document.text" community_title_embedding = "community.title" community_summary_embedding = "community.summary" community_full_content_embedding = "community.full_content" text_unit_text_embedding = "text_unit.text" all_embeddings: set[str] = { - entity_name_embedding, + entity_title_embedding, entity_description_embedding, relationship_description_embedding, - document_raw_content_embedding, + document_text_embedding, community_title_embedding, community_summary_embedding, community_full_content_embedding, diff --git a/graphrag/index/flows/create_base_entity_graph.py b/graphrag/index/flows/create_base_entity_graph.py index fc85f08ff5..bded3ca6fb 100644 --- 
a/graphrag/index/flows/create_base_entity_graph.py +++ b/graphrag/index/flows/create_base_entity_graph.py @@ -30,8 +30,6 @@ async def create_base_entity_graph( callbacks: VerbCallbacks, cache: PipelineCache, storage: PipelineStorage, - text_column: str, - id_column: str, clustering_strategy: dict[str, Any], extraction_strategy: dict[str, Any] | None = None, extraction_num_threads: int = 4, @@ -52,8 +50,8 @@ async def create_base_entity_graph( text_units, callbacks, cache, - text_column=text_column, - id_column=id_column, + text_column="text", + id_column="id", strategy=extraction_strategy, async_mode=extraction_async_mode, entity_types=entity_types, diff --git a/graphrag/index/flows/create_base_text_units.py b/graphrag/index/flows/create_base_text_units.py index 7860a69b11..51a9dba643 100644 --- a/graphrag/index/flows/create_base_text_units.py +++ b/graphrag/index/flows/create_base_text_units.py @@ -24,8 +24,6 @@ async def create_base_text_units( documents: pd.DataFrame, callbacks: VerbCallbacks, storage: PipelineStorage, - chunk_column_name: str, - n_tokens_column_name: str, chunk_by_columns: list[str], chunk_strategy: dict[str, Any] | None = None, snapshot_transient_enabled: bool = False, @@ -65,21 +63,18 @@ async def create_base_text_units( chunked = chunked.explode("chunks") chunked.rename( columns={ - "chunks": chunk_column_name, + "chunks": "chunk", }, inplace=True, ) - chunked["chunk_id"] = chunked.apply( - lambda row: gen_md5_hash(row, [chunk_column_name]), axis=1 + chunked["id"] = chunked.apply(lambda row: gen_md5_hash(row, ["chunk"]), axis=1) + chunked[["document_ids", "chunk", "n_tokens"]] = pd.DataFrame( + chunked["chunk"].tolist(), index=chunked.index ) - chunked[["document_ids", chunk_column_name, n_tokens_column_name]] = pd.DataFrame( - chunked[chunk_column_name].tolist(), index=chunked.index - ) - chunked["id"] = chunked["chunk_id"] + # rename for downstream consumption + chunked.rename(columns={"chunk": "text"}, inplace=True) - output = cast( - pd.DataFrame, chunked[chunked[chunk_column_name].notna()].reset_index(drop=True) - ) + output = cast(pd.DataFrame, chunked[chunked["text"].notna()].reset_index(drop=True)) if snapshot_transient_enabled: await snapshot( diff --git a/graphrag/index/flows/create_final_communities.py b/graphrag/index/flows/create_final_communities.py index 08ebb9617a..a433545ed3 100644 --- a/graphrag/index/flows/create_final_communities.py +++ b/graphrag/index/flows/create_final_communities.py @@ -4,6 +4,8 @@ """All the steps to transform final communities.""" from datetime import datetime, timezone +from typing import cast +from uuid import uuid4 import pandas as pd from datashaper import ( @@ -41,40 +43,51 @@ def create_final_communities( cluster_relationships = ( combined_clusters.groupby(["cluster", "level_x"], sort=False) .agg( - relationship_ids=("id_y", "unique"), text_unit_ids=("source_id_x", "unique") + relationship_ids=("id_y", "unique"), + text_unit_ids=("source_id_x", "unique"), + entity_ids=("id_x", "unique"), ) .reset_index() ) all_clusters = ( graph_nodes.groupby(["cluster", "level"], sort=False) - .agg(id=("cluster", "first")) + .agg(community=("cluster", "first")) .reset_index() ) joined = all_clusters.merge( cluster_relationships, - left_on="id", + left_on="community", right_on="cluster", how="inner", ) - filtered = joined[joined["level"] == joined["level_x"]].reset_index(drop=True) + filtered = cast( + pd.DataFrame, + joined[joined["level"] == joined["level_x"]].reset_index(drop=True), + ) - filtered["title"] = "Community " + 
filtered["id"].astype(str) + filtered["id"] = filtered["community"].apply(lambda _x: str(uuid4())) + filtered["community"] = filtered["community"].astype(int) + filtered["human_readable_id"] = filtered["community"] + filtered["title"] = "Community " + filtered["community"].astype(str) # Add period timestamp to the community reports filtered["period"] = datetime.now(timezone.utc).date().isoformat() # Add size of the community - filtered["size"] = filtered.loc[:, "text_unit_ids"].apply(lambda x: len(x)) + filtered["size"] = filtered.loc[:, "entity_ids"].apply(lambda x: len(x)) return filtered.loc[ :, [ "id", - "title", + "human_readable_id", + "community", "level", + "title", + "entity_ids", "relationship_ids", "text_unit_ids", "period", diff --git a/graphrag/index/flows/create_final_community_reports.py b/graphrag/index/flows/create_final_community_reports.py index c0101c8d96..001844b5b0 100644 --- a/graphrag/index/flows/create_final_community_reports.py +++ b/graphrag/index/flows/create_final_community_reports.py @@ -41,7 +41,8 @@ async def create_final_community_reports( nodes_input: pd.DataFrame, edges_input: pd.DataFrame, - communities_input: pd.DataFrame, + entities: pd.DataFrame, + communities: pd.DataFrame, claims_input: pd.DataFrame | None, callbacks: VerbCallbacks, cache: PipelineCache, @@ -50,7 +51,9 @@ async def create_final_community_reports( num_threads: int = 4, ) -> pd.DataFrame: """All the steps to transform community reports.""" - nodes = _prep_nodes(nodes_input) + entities_df = entities.loc[:, ["id", "description"]] + nodes_df = nodes_input.merge(entities_df, on="id") + nodes = _prep_nodes(nodes_df) edges = _prep_edges(edges_input) claims = None @@ -78,19 +81,37 @@ async def create_final_community_reports( num_threads=num_threads, ) + community_reports["community"] = community_reports["community"].astype(int) + community_reports["human_readable_id"] = community_reports["community"] community_reports["id"] = community_reports["community"].apply( lambda _x: str(uuid4()) ) - # Merge by community and it with communities to add size and period - return community_reports.merge( - communities_input.loc[:, ["id", "size", "period"]], - left_on="community", - right_on="id", + # Merge with communities to add size and period + merged = community_reports.merge( + communities.loc[:, ["community", "size", "period"]], + on="community", how="left", copy=False, - suffixes=("", "_y"), - ).drop(columns=["id_y"]) + ) + return merged.loc[ + :, + [ + "id", + "human_readable_id", + "community", + "level", + "title", + "summary", + "full_content", + "rank", + "rank_explanation", + "findings", + "full_content_json", + "period", + "size", + ], + ] def _prep_nodes(input: pd.DataFrame) -> pd.DataFrame: diff --git a/graphrag/index/flows/create_final_covariates.py b/graphrag/index/flows/create_final_covariates.py index 09ec9f8fce..e04e7fe926 100644 --- a/graphrag/index/flows/create_final_covariates.py +++ b/graphrag/index/flows/create_final_covariates.py @@ -3,7 +3,7 @@ """All the steps to extract and format covariates.""" -from typing import Any, cast +from typing import Any from uuid import uuid4 import pandas as pd @@ -22,7 +22,6 @@ async def create_final_covariates( text_units: pd.DataFrame, callbacks: VerbCallbacks, cache: PipelineCache, - column: str, covariate_type: str, extraction_strategy: dict[str, Any] | None, async_mode: AsyncType = AsyncType.AsyncIO, @@ -30,40 +29,38 @@ async def create_final_covariates( num_threads: int = 4, ) -> pd.DataFrame: """All the steps to extract and format 
covariates.""" + # reassign the id because it will be overwritten in the output by a covariate one + # this also results in text_unit_id being copied to the output covariate table + text_units["text_unit_id"] = text_units["id"] covariates = await extract_covariates( text_units, callbacks, cache, - column, + "text", covariate_type, extraction_strategy, async_mode, entity_types, num_threads, ) - + text_units.drop(columns=["text_unit_id"], inplace=True) # don't pollute the global covariates["id"] = covariates["covariate_type"].apply(lambda _x: str(uuid4())) - covariates["human_readable_id"] = (covariates.index + 1).astype(str) - covariates.rename(columns={"chunk_id": "text_unit_id"}, inplace=True) + covariates["human_readable_id"] = covariates.index + 1 - return cast( - pd.DataFrame, - covariates[ - [ - "id", - "human_readable_id", - "covariate_type", - "type", - "description", - "subject_id", - "object_id", - "status", - "start_date", - "end_date", - "source_text", - "text_unit_id", - "document_ids", - "n_tokens", - ] + return covariates.loc[ + :, + [ + "id", + "human_readable_id", + "covariate_type", + "type", + "description", + "subject_id", + "object_id", + "status", + "start_date", + "end_date", + "source_text", + "text_unit_id", ], - ) + ] diff --git a/graphrag/index/flows/create_final_documents.py b/graphrag/index/flows/create_final_documents.py index fcd213e314..7df1d44f90 100644 --- a/graphrag/index/flows/create_final_documents.py +++ b/graphrag/index/flows/create_final_documents.py @@ -33,7 +33,7 @@ def create_final_documents( ) docs_with_text_units = joined.groupby("id", sort=False).agg( - text_units=("chunk_id", list) + text_unit_ids=("chunk_id", list) ) rejoined = docs_with_text_units.merge( @@ -43,10 +43,8 @@ def create_final_documents( copy=False, ).reset_index(drop=True) - rejoined.rename( - columns={"text": "raw_content", "text_units": "text_unit_ids"}, inplace=True - ) rejoined["id"] = rejoined["id"].astype(str) + rejoined["human_readable_id"] = rejoined.index + 1 # Convert attribute columns to strings and collapse them into a JSON object if document_attribute_columns: @@ -63,4 +61,16 @@ def create_final_documents( # Drop the original attribute columns after collapsing them rejoined.drop(columns=document_attribute_columns, inplace=True) - return rejoined + # set the final column order, but adjust for attributes + core_columns = [ + "id", + "human_readable_id", + "title", + "text", + "text_unit_ids", + ] + final_columns = [column for column in core_columns if column in rejoined.columns] + if document_attribute_columns: + final_columns.append("attributes") + + return rejoined.loc[:, final_columns] diff --git a/graphrag/index/flows/create_final_entities.py b/graphrag/index/flows/create_final_entities.py index 5c03977d40..48ac16e70a 100644 --- a/graphrag/index/flows/create_final_entities.py +++ b/graphrag/index/flows/create_final_entities.py @@ -8,7 +8,6 @@ VerbCallbacks, ) -from graphrag.index.operations.split_text import split_text from graphrag.index.operations.unpack_graph import unpack_graph @@ -20,25 +19,33 @@ def create_final_entities( # Process nodes nodes = ( unpack_graph(entity_graph, callbacks, "clustered_graph", "nodes") - .rename(columns={"label": "name"}) + .rename(columns={"label": "title"}) .loc[ :, [ "id", - "name", + "title", "type", "description", "human_readable_id", - "graph_embedding", "source_id", ], ] .drop_duplicates(subset="id") ) - nodes = nodes.loc[nodes["name"].notna()] + nodes = nodes.loc[nodes["title"].notna()] - # Split 'source_id' column into 
'text_unit_ids' - return split_text( - nodes, column="source_id", separator=",", to="text_unit_ids" - ).drop(columns=["source_id"]) + nodes["text_unit_ids"] = nodes["source_id"].str.split(",") + + return nodes.loc[ + :, + [ + "id", + "human_readable_id", + "title", + "type", + "description", + "text_unit_ids", + ], + ] diff --git a/graphrag/index/flows/create_final_nodes.py b/graphrag/index/flows/create_final_nodes.py index 8d87927eea..d2adcb34c8 100644 --- a/graphrag/index/flows/create_final_nodes.py +++ b/graphrag/index/flows/create_final_nodes.py @@ -58,16 +58,32 @@ async def create_final_nodes( formats=["json"], ) - nodes.rename(columns={"id": "top_level_node_id"}, inplace=True) - nodes["top_level_node_id"] = nodes["top_level_node_id"].astype(str) - joined = nodes_without_positions.merge( nodes, - left_on="id", - right_on="top_level_node_id", + on="id", how="inner", ) joined.rename(columns={"label": "title", "cluster": "community"}, inplace=True) + joined["community"] = joined["community"].fillna(-1).astype(int) + + # drop anything that isn't graph-related or needing to be preserved + # the rest can be looked up on the canonical entities table + joined.drop( + columns=["source_id", "type", "description", "size", "graph_embedding"], + inplace=True, + ) - # TODO: Find duplication source - return joined.drop_duplicates(subset=["title", "community"]) + deduped = joined.drop_duplicates(subset=["title", "community"]) + return deduped.loc[ + :, + [ + "id", + "human_readable_id", + "title", + "community", + "level", + "degree", + "x", + "y", + ], + ] diff --git a/graphrag/index/flows/create_final_relationships.py b/graphrag/index/flows/create_final_relationships.py index 0e61008ae3..f23215969a 100644 --- a/graphrag/index/flows/create_final_relationships.py +++ b/graphrag/index/flows/create_final_relationships.py @@ -35,22 +35,29 @@ def create_final_relationships( filtered_nodes = nodes[nodes["level"] == 0].reset_index(drop=True) filtered_nodes = cast(pd.DataFrame, filtered_nodes[["title", "degree"]]) - edge_combined_degree = compute_edge_combined_degree( + pruned_edges["combined_degree"] = compute_edge_combined_degree( pruned_edges, filtered_nodes, - to="rank", node_name_column="title", node_degree_column="degree", edge_source_column="source", edge_target_column="target", ) - edge_combined_degree["human_readable_id"] = edge_combined_degree[ - "human_readable_id" - ].astype(str) - edge_combined_degree["text_unit_ids"] = edge_combined_degree[ - "text_unit_ids" - ].str.split(",") + pruned_edges["text_unit_ids"] = pruned_edges["text_unit_ids"].str.split(",") # TODO: Find duplication source - return edge_combined_degree.drop_duplicates(subset=["source", "target"]) + deduped = pruned_edges.drop_duplicates(subset=["source", "target"]) + return deduped.loc[ + :, + [ + "id", + "human_readable_id", + "source", + "target", + "description", + "weight", + "combined_degree", + "text_unit_ids", + ], + ] diff --git a/graphrag/index/flows/create_final_text_units.py b/graphrag/index/flows/create_final_text_units.py index 6f6ed70572..adb2fd4beb 100644 --- a/graphrag/index/flows/create_final_text_units.py +++ b/graphrag/index/flows/create_final_text_units.py @@ -3,8 +3,6 @@ """All the steps to transform the text units.""" -from typing import cast - import pandas as pd @@ -15,9 +13,8 @@ def create_final_text_units( final_covariates: pd.DataFrame | None, ) -> pd.DataFrame: """All the steps to transform the text units.""" - selected = text_units.loc[:, ["id", "chunk", "document_ids", "n_tokens"]].rename( - 
columns={"chunk": "text"} - ) + selected = text_units.loc[:, ["id", "text", "document_ids", "n_tokens"]] + selected["human_readable_id"] = selected.index + 1 entity_join = _entities(final_entities) relationship_join = _relationships(final_relationships) @@ -32,20 +29,19 @@ def create_final_text_units( aggregated = final_joined.groupby("id", sort=False).agg("first").reset_index() - return cast( - pd.DataFrame, - aggregated[ - [ - "id", - "text", - "n_tokens", - "document_ids", - "entity_ids", - "relationship_ids", - *([] if final_covariates is None else ["covariate_ids"]), - ] + return aggregated.loc[ + :, + [ + "id", + "human_readable_id", + "text", + "n_tokens", + "document_ids", + "entity_ids", + "relationship_ids", + *([] if final_covariates is None else ["covariate_ids"]), ], - ) + ] def _entities(df: pd.DataFrame) -> pd.DataFrame: diff --git a/graphrag/index/flows/generate_text_embeddings.py b/graphrag/index/flows/generate_text_embeddings.py index a74e4ff6c2..258d01c768 100644 --- a/graphrag/index/flows/generate_text_embeddings.py +++ b/graphrag/index/flows/generate_text_embeddings.py @@ -15,9 +15,9 @@ community_full_content_embedding, community_summary_embedding, community_title_embedding, - document_raw_content_embedding, + document_text_embedding, entity_description_embedding, - entity_name_embedding, + entity_title_embedding, relationship_description_embedding, text_unit_text_embedding, ) @@ -43,61 +43,55 @@ async def generate_text_embeddings( ) -> None: """All the steps to generate all embeddings.""" embedding_param_map = { - document_raw_content_embedding: { - "data": final_documents.loc[:, ["id", "raw_content"]] + document_text_embedding: { + "data": final_documents.loc[:, ["id", "text"]] if final_documents is not None else None, - "column_to_embed": "raw_content", + "embed_column": "text", }, relationship_description_embedding: { "data": final_relationships.loc[:, ["id", "description"]] if final_relationships is not None else None, - "column_to_embed": "description", + "embed_column": "description", }, text_unit_text_embedding: { "data": final_text_units.loc[:, ["id", "text"]] if final_text_units is not None else None, - "column_to_embed": "text", + "embed_column": "text", }, - entity_name_embedding: { - "data": final_entities.loc[:, ["id", "name", "description"]] + entity_title_embedding: { + "data": final_entities.loc[:, ["id", "title"]] if final_entities is not None else None, - "column_to_embed": "name", + "embed_column": "title", }, entity_description_embedding: { - "data": final_entities.loc[:, ["id", "name", "description"]].assign( - name_description=lambda df: df["name"] + ":" + df["description"] + "data": final_entities.loc[:, ["id", "title", "description"]].assign( + title_description=lambda df: df["title"] + ":" + df["description"] ) if final_entities is not None else None, - "column_to_embed": "name_description", + "embed_column": "title_description", }, community_title_embedding: { - "data": final_community_reports.loc[ - :, ["id", "full_content", "summary", "title"] - ] + "data": final_community_reports.loc[:, ["id", "title"]] if final_community_reports is not None else None, - "column_to_embed": "title", + "embed_column": "title", }, community_summary_embedding: { - "data": final_community_reports.loc[ - :, ["id", "full_content", "summary", "title"] - ] + "data": final_community_reports.loc[:, ["id", "summary"]] if final_community_reports is not None else None, - "column_to_embed": "summary", + "embed_column": "summary", }, community_full_content_embedding: { - 
"data": final_community_reports.loc[ - :, ["id", "full_content", "summary", "title"] - ] + "data": final_community_reports.loc[:, ["id", "full_content"]] if final_community_reports is not None else None, - "column_to_embed": "full_content", + "embed_column": "full_content", }, } @@ -117,7 +111,7 @@ async def generate_text_embeddings( async def _run_and_snapshot_embeddings( name: str, data: pd.DataFrame, - column_to_embed: str, + embed_column: str, callbacks: VerbCallbacks, cache: PipelineCache, storage: PipelineStorage, @@ -130,7 +124,7 @@ async def _run_and_snapshot_embeddings( data, callbacks, cache, - embed_column=column_to_embed, + embed_column=embed_column, embedding_name=name, strategy=text_embed_config["strategy"], ) diff --git a/graphrag/index/graph/extractors/community_reports/schemas.py b/graphrag/index/graph/extractors/community_reports/schemas.py index a7df2ff327..9c8e29b66c 100644 --- a/graphrag/index/graph/extractors/community_reports/schemas.py +++ b/graphrag/index/graph/extractors/community_reports/schemas.py @@ -16,7 +16,7 @@ EDGE_SOURCE = "source" EDGE_TARGET = "target" EDGE_DESCRIPTION = "description" -EDGE_DEGREE = "rank" +EDGE_DEGREE = "combined_degree" EDGE_DETAILS = "edge_details" EDGE_WEIGHT = "weight" @@ -41,7 +41,7 @@ # COMMUNITY REPORT TABLE SCHEMA REPORT_ID = "id" -COMMUNITY_ID = "id" +COMMUNITY_ID = "community" COMMUNITY_LEVEL = "level" TITLE = "title" SUMMARY = "summary" diff --git a/graphrag/index/graph/extractors/graph/graph_extractor.py b/graphrag/index/graph/extractors/graph/graph_extractor.py index 49ca671a72..87bd8a4aa6 100644 --- a/graphrag/index/graph/extractors/graph/graph_extractor.py +++ b/graphrag/index/graph/extractors/graph/graph_extractor.py @@ -227,8 +227,8 @@ async def _process_results( str(source_doc_id), }) ) - node["entity_type"] = ( - entity_type if entity_type != "" else node["entity_type"] + node["type"] = ( + entity_type if entity_type != "" else node["type"] ) else: graph.add_node( diff --git a/graphrag/index/init_content.py b/graphrag/index/init_content.py index da3d19807d..6db1843dc4 100644 --- a/graphrag/index/init_content.py +++ b/graphrag/index/init_content.py @@ -154,6 +154,8 @@ graphml: false raw_entities: false top_level_nodes: false + embeddings: false + transient: false local_search: # text_unit_prop: {defs.LOCAL_SEARCH_TEXT_UNIT_PROP} diff --git a/graphrag/index/operations/compute_edge_combined_degree.py b/graphrag/index/operations/compute_edge_combined_degree.py index e0a81be0a3..1fec455c33 100644 --- a/graphrag/index/operations/compute_edge_combined_degree.py +++ b/graphrag/index/operations/compute_edge_combined_degree.py @@ -3,21 +3,20 @@ """A module containing compute_edge_combined_degree methods definition.""" +from typing import cast + import pandas as pd def compute_edge_combined_degree( edge_df: pd.DataFrame, node_degree_df: pd.DataFrame, - to: str, node_name_column: str, node_degree_column: str, edge_source_column: str, edge_target_column: str, -) -> pd.DataFrame: +) -> pd.Series: """Compute the combined degree for each edge in a graph.""" - if to in edge_df.columns: - return edge_df def join_to_degree(df: pd.DataFrame, column: str) -> pd.DataFrame: degree_column = _degree_colname(column) @@ -33,11 +32,11 @@ def join_to_degree(df: pd.DataFrame, column: str) -> pd.DataFrame: output_df = join_to_degree(edge_df, edge_source_column) output_df = join_to_degree(output_df, edge_target_column) - output_df[to] = ( + output_df["combined_degree"] = ( output_df[_degree_colname(edge_source_column)] + 
output_df[_degree_colname(edge_target_column)] ) - return output_df + return cast(pd.Series, output_df["combined_degree"]) def _degree_colname(column: str) -> str: diff --git a/graphrag/index/operations/embed_text/embed_text.py b/graphrag/index/operations/embed_text/embed_text.py index fe41d01cee..d6f0b3b9ae 100644 --- a/graphrag/index/operations/embed_text/embed_text.py +++ b/graphrag/index/operations/embed_text/embed_text.py @@ -12,6 +12,7 @@ from datashaper import VerbCallbacks from graphrag.index.cache import PipelineCache +from graphrag.utils.embeddings import create_collection_name from graphrag.vector_stores import ( BaseVectorStore, VectorStoreDocument, @@ -229,8 +230,8 @@ def _create_vector_store( def _get_collection_name(vector_store_config: dict, embedding_name: str) -> str: - container_name = vector_store_config.get("container_name") - collection_name = f"{container_name}.{embedding_name}".replace(".", "-") + container_name = vector_store_config.get("container_name", "default") + collection_name = create_collection_name(container_name, embedding_name) msg = f"using vector store {vector_store_config.get('type')} with container_name {container_name} for embedding {embedding_name}: {collection_name}" log.info(msg) diff --git a/graphrag/index/operations/split_text.py b/graphrag/index/operations/split_text.py deleted file mode 100644 index 7a9a90767a..0000000000 --- a/graphrag/index/operations/split_text.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""A module containing the split_text method definition.""" - -import pandas as pd - - -def split_text( - input: pd.DataFrame, column: str, to: str, separator: str = "," -) -> pd.DataFrame: - """Split a column into a list of strings.""" - output = input - - def _apply_split(row): - if row[column] is None or isinstance(row[column], list): - return row[column] - if row[column] == "": - return [] - if not isinstance(row[column], str): - message = f"Expected {column} to be a string, but got {type(row[column])}" - raise TypeError(message) - return row[column].split(separator) - - output[to] = output.apply(_apply_split, axis=1) - return output diff --git a/graphrag/index/operations/summarize_communities/prepare_community_reports.py b/graphrag/index/operations/summarize_communities/prepare_community_reports.py index 1a54f7d987..b402e6afc8 100644 --- a/graphrag/index/operations/summarize_communities/prepare_community_reports.py +++ b/graphrag/index/operations/summarize_communities/prepare_community_reports.py @@ -173,4 +173,7 @@ def get_edge_details(node_df: pd.DataFrame, edge_df: pd.DataFrame, name_col: str set_context_exceeds_flag(community_df, max_tokens) community_df[schemas.COMMUNITY_LEVEL] = level + community_df[node_community_column] = community_df[node_community_column].astype( + int + ) return community_df diff --git a/graphrag/index/operations/summarize_communities/restore_community_hierarchy.py b/graphrag/index/operations/summarize_communities/restore_community_hierarchy.py index 368e4b0586..2430472abd 100644 --- a/graphrag/index/operations/summarize_communities/restore_community_hierarchy.py +++ b/graphrag/index/operations/summarize_communities/restore_community_hierarchy.py @@ -36,20 +36,13 @@ def restore_community_hierarchy( # get unique levels, sorted in ascending order levels = sorted(community_levels.keys()) - community_hierarchy = [] for idx in range(len(levels) - 1): level = levels[idx] - log.debug("Level: %s", level) next_level = levels[idx + 1] 
current_level_communities = community_levels[level] next_level_communities = community_levels[next_level] - log.debug( - "Number of communities at level %s: %s", - level, - len(current_level_communities), - ) for current_community in current_level_communities: current_entities = current_level_communities[current_community] @@ -70,4 +63,6 @@ def restore_community_hierarchy( if entities_found == len(current_entities): break - return pd.DataFrame(community_hierarchy) + return pd.DataFrame( + community_hierarchy, + ) diff --git a/graphrag/index/operations/summarize_communities/summarize_communities.py b/graphrag/index/operations/summarize_communities/summarize_communities.py index a704fcc1cb..ad8c44077e 100644 --- a/graphrag/index/operations/summarize_communities/summarize_communities.py +++ b/graphrag/index/operations/summarize_communities/summarize_communities.py @@ -88,7 +88,7 @@ async def _generate_report( callbacks: VerbCallbacks, cache: PipelineCache, strategy: dict, - community_id: int | str, + community_id: int, community_level: int, community_context: str, ) -> CommunityReport | None: diff --git a/graphrag/index/update/communities.py b/graphrag/index/update/communities.py index 3dc3716701..6e64cdd826 100644 --- a/graphrag/index/update/communities.py +++ b/graphrag/index/update/communities.py @@ -35,9 +35,8 @@ def _merge_and_resolve_nodes( # Merge delta_nodes with merged_entities_df to get the new human_readable_id delta_nodes = delta_nodes.merge( - merged_entities_df[["name", "human_readable_id"]], - left_on="title", - right_on="name", + merged_entities_df[["id", "human_readable_id"]], + on="id", how="left", suffixes=("", "_new"), ) @@ -48,7 +47,7 @@ def _merge_and_resolve_nodes( ].combine_first(delta_nodes.loc[:, "human_readable_id"]) # Drop the auxiliary column from the merge - delta_nodes.drop(columns=["name", "human_readable_id_new"], inplace=True) + delta_nodes.drop(columns=["human_readable_id_new"], inplace=True) # Increment only the non-NaN values in delta_nodes["community"] community_id_mapping = { @@ -64,36 +63,29 @@ def _merge_and_resolve_nodes( # Concat the DataFrames concat_nodes = pd.concat([old_nodes, delta_nodes], ignore_index=True) columns_to_agg: dict[str, str | Callable] = { - col: "first" - for col in concat_nodes.columns - if col not in ["source_id", "level", "title"] + col: "first" for col in concat_nodes.columns if col not in ["level", "title"] } - # Specify custom aggregation for description and source_id - columns_to_agg.update({ - "source_id": lambda x: ",".join(str(i) for i in x.tolist()), - }) - merged_nodes = ( concat_nodes.groupby(["level", "title"]).agg(columns_to_agg).reset_index() ) - # Use description from merged_entities_df - merged_nodes = ( - merged_nodes.drop(columns=["description"]) - .merge( - merged_entities_df[["name", "description"]], - left_on="title", - right_on="name", - how="left", - ) - .drop(columns=["name"]) - ) - - # Mantain type compat with query - merged_nodes["community"] = ( - merged_nodes["community"].astype(pd.StringDtype()).astype("object") - ) + merged_nodes["community"] = merged_nodes["community"].astype(int) + merged_nodes["human_readable_id"] = merged_nodes["human_readable_id"].astype(int) + + merged_nodes = merged_nodes.loc[ + :, + [ + "id", + "human_readable_id", + "title", + "community", + "level", + "degree", + "x", + "y", + ], + ] return merged_nodes, community_id_mapping @@ -132,13 +124,13 @@ def _update_and_merge_communities( delta_communities["period"] = None # Look for community ids in community and replace them 
with the corresponding id in the mapping - delta_communities["id"] = ( - delta_communities["id"] - .astype("Int64") + delta_communities["community"] = ( + delta_communities["community"] + .astype(int) .apply(lambda x: community_id_mapping.get(x, x)) ) - old_communities["id"] = old_communities["id"].astype("Int64") + old_communities["community"] = old_communities["community"].astype(int) # Merge the final communities merged_communities = pd.concat( @@ -146,10 +138,27 @@ def _update_and_merge_communities( ) # Rename title - merged_communities["title"] = "Community " + merged_communities["id"].astype(str) - # Mantain type compat with query - merged_communities["id"] = merged_communities["id"].astype(str) - return merged_communities + merged_communities["title"] = "Community " + merged_communities["community"].astype( + str + ) + # Re-assign the human_readable_id + merged_communities["human_readable_id"] = merged_communities["community"] + + return merged_communities.loc[ + :, + [ + "id", + "human_readable_id", + "community", + "level", + "title", + "entity_ids", + "relationship_ids", + "text_unit_ids", + "period", + "size", + ], + ] def _update_and_merge_community_reports( @@ -188,22 +197,41 @@ def _update_and_merge_community_reports( # Look for community ids in community and replace them with the corresponding id in the mapping delta_community_reports["community"] = ( delta_community_reports["community"] - .astype("Int64") + .astype(int) .apply(lambda x: community_id_mapping.get(x, x)) ) - old_community_reports["community"] = old_community_reports["community"].astype( - "Int64" - ) + old_community_reports["community"] = old_community_reports["community"].astype(int) # Merge the final community reports merged_community_reports = pd.concat( [old_community_reports, delta_community_reports], ignore_index=True, copy=False ) - # Mantain type compat with query - merged_community_reports["community"] = ( - merged_community_reports["community"].astype(pd.StringDtype()).astype("object") - ) - - return merged_community_reports + # Maintain type compat with query + merged_community_reports["community"] = merged_community_reports[ + "community" + ].astype(int) + # Re-assign the human_readable_id + merged_community_reports["human_readable_id"] = merged_community_reports[ + "community" + ] + + return merged_community_reports.loc[ + :, + [ + "id", + "human_readable_id", + "community", + "level", + "title", + "summary", + "full_content", + "rank", + "rank_explanation", + "findings", + "full_content_json", + "period", + "size", + ], + ] diff --git a/graphrag/index/update/entities.py b/graphrag/index/update/entities.py index 02f83c6446..6f59bb81ef 100644 --- a/graphrag/index/update/entities.py +++ b/graphrag/index/update/entities.py @@ -4,6 +4,7 @@ """Entity related operations and utils for Incremental Indexing.""" import asyncio +import itertools import numpy as np import pandas as pd @@ -36,10 +37,10 @@ def _group_and_resolve_entities( dict The id mapping for existing entities. In the form of {df_b.id: df_a.id}. 
""" - # If a name exists in A and B, make a dictionary for {B.id : A.id} - merged = delta_entities_df[["id", "name"]].merge( - old_entities_df[["id", "name"]], - on="name", + # If a title exists in A and B, make a dictionary for {B.id : A.id} + merged = delta_entities_df[["id", "title"]].merge( + old_entities_df[["id", "title"]], + on="title", suffixes=("_B", "_A"), copy=False, ) @@ -55,17 +56,16 @@ def _group_and_resolve_entities( [old_entities_df, delta_entities_df], ignore_index=True, copy=False ) - # Group by name and resolve conflicts + # Group by title and resolve conflicts aggregated = ( - combined.groupby("name") + combined.groupby("title") .agg({ "id": "first", "type": "first", "human_readable_id": "first", - "graph_embedding": "first", "description": lambda x: list(x.astype(str)), # Ensure str # Concatenate nd.array into a single list - "text_unit_ids": lambda x: ",".join(str(i) for j in x.tolist() for i in j), + "text_unit_ids": lambda x: list(itertools.chain(*x.tolist())), }) .reset_index() ) @@ -78,11 +78,10 @@ def _group_and_resolve_entities( :, [ "id", - "name", - "description", - "type", "human_readable_id", - "graph_embedding", + "title", + "type", + "description", "text_unit_ids", ], ] @@ -123,7 +122,7 @@ async def process_row(row): if isinstance(description, list) and len(description) > 1: # Run entity summarization asynchronously result = await run_entity_summarization( - row["name"], description, callbacks, cache, strategy + row["title"], description, callbacks, cache, strategy ) return result.description # Handle case where description is a single-item list or not a list diff --git a/graphrag/index/update/incremental_index.py b/graphrag/index/update/incremental_index.py index 57d3cd2b43..89acb285e0 100644 --- a/graphrag/index/update/incremental_index.py +++ b/graphrag/index/update/incremental_index.py @@ -282,7 +282,7 @@ async def _update_entities( old_entities, delta_entities ) - # Re-run description summarization and embeddings + # Re-run description summarization merged_entities_df = await _run_entity_summarization( merged_entities_df, config, diff --git a/graphrag/index/update/relationships.py b/graphrag/index/update/relationships.py index 9df427972a..4d5899ad2f 100644 --- a/graphrag/index/update/relationships.py +++ b/graphrag/index/update/relationships.py @@ -52,9 +52,21 @@ def _update_and_merge_relationships( "source" ].transform("count") - # Recalculate the rank of the relationships (source degree + target degree) - final_relationships["rank"] = ( + # Recalculate the combined_degree of the relationships (source degree + target degree) + final_relationships["combined_degree"] = ( final_relationships["source_degree"] + final_relationships["target_degree"] ) - return final_relationships + return final_relationships.loc[ + :, + [ + "id", + "human_readable_id", + "source", + "target", + "description", + "weight", + "combined_degree", + "text_unit_ids", + ], + ] diff --git a/graphrag/index/workflows/v1/create_base_entity_graph.py b/graphrag/index/workflows/v1/create_base_entity_graph.py index 5e2f52d798..bb0c41ac57 100644 --- a/graphrag/index/workflows/v1/create_base_entity_graph.py +++ b/graphrag/index/workflows/v1/create_base_entity_graph.py @@ -19,11 +19,9 @@ def build_steps( Create the base table for the entity graph. 
## Dependencies - * `workflow:create_base_summarized_entities` + * `workflow:create_base_text_units` """ entity_extraction_config = config.get("entity_extract", {}) - text_column = entity_extraction_config.get("text_column", "chunk") - id_column = entity_extraction_config.get("id_column", "chunk_id") async_mode = entity_extraction_config.get("async_mode", AsyncType.AsyncIO) extraction_strategy = entity_extraction_config.get("strategy") extraction_num_threads = entity_extraction_config.get("num_threads", 4) @@ -96,8 +94,6 @@ def build_steps( { "verb": "create_base_entity_graph", "args": { - "text_column": text_column, - "id_column": id_column, "extraction_strategy": extraction_strategy, "extraction_num_threads": extraction_num_threads, "extraction_async_mode": async_mode, diff --git a/graphrag/index/workflows/v1/create_base_text_units.py b/graphrag/index/workflows/v1/create_base_text_units.py index 2889313d8f..efd7f7eab1 100644 --- a/graphrag/index/workflows/v1/create_base_text_units.py +++ b/graphrag/index/workflows/v1/create_base_text_units.py @@ -17,11 +17,9 @@ def build_steps( Create the base table for text units. ## Dependencies - None + (input dataframe) """ - chunk_column_name = config.get("chunk_column", "chunk") chunk_by_columns = config.get("chunk_by", []) or [] - n_tokens_column_name = config.get("n_tokens_column", "n_tokens") text_chunk_config = config.get("text_chunk", {}) chunk_strategy = text_chunk_config.get("strategy") @@ -30,8 +28,6 @@ def build_steps( { "verb": "create_base_text_units", "args": { - "chunk_column_name": chunk_column_name, - "n_tokens_column_name": n_tokens_column_name, "chunk_by_columns": chunk_by_columns, "chunk_strategy": chunk_strategy, "snapshot_transient_enabled": snapshot_transient, diff --git a/graphrag/index/workflows/v1/create_final_community_reports.py b/graphrag/index/workflows/v1/create_final_community_reports.py index 9303a32d2c..8a56583d84 100644 --- a/graphrag/index/workflows/v1/create_final_community_reports.py +++ b/graphrag/index/workflows/v1/create_final_community_reports.py @@ -26,6 +26,7 @@ def build_steps( input = { "source": "workflow:create_final_nodes", "relationships": "workflow:create_final_relationships", + "entities": "workflow:create_final_entities", "communities": "workflow:create_final_communities", } if covariates_enabled: diff --git a/graphrag/index/workflows/v1/create_final_covariates.py b/graphrag/index/workflows/v1/create_final_covariates.py index 0f8191abbd..6bdf32e4bb 100644 --- a/graphrag/index/workflows/v1/create_final_covariates.py +++ b/graphrag/index/workflows/v1/create_final_covariates.py @@ -26,15 +26,10 @@ def build_steps( async_mode = claim_extract_config.get("async_mode", AsyncType.AsyncIO) num_threads = claim_extract_config.get("num_threads") - chunk_column = config.get("chunk_column", "chunk") - chunk_id_column = config.get("chunk_id_column", "chunk_id") - return [ { "verb": "create_final_covariates", "args": { - "column": chunk_column, - "id_column": chunk_id_column, "covariate_type": "claim", "extraction_strategy": extraction_strategy, "async_mode": async_mode, diff --git a/graphrag/index/workflows/v1/create_final_documents.py b/graphrag/index/workflows/v1/create_final_documents.py index 65e22b35ee..5160cc5165 100644 --- a/graphrag/index/workflows/v1/create_final_documents.py +++ b/graphrag/index/workflows/v1/create_final_documents.py @@ -17,16 +17,16 @@ def build_steps( Create the final documents table. 
## Dependencies - * `workflow:create_final_text_units` + * `workflow:create_base_text_units` """ - document_attribute_columns = config.get("document_attribute_columns", []) + document_attribute_columns = config.get("document_attribute_columns", None) return [ { "verb": "create_final_documents", "args": {"document_attribute_columns": document_attribute_columns}, "input": { "source": DEFAULT_INPUT_NAME, - "text_units": "workflow:create_final_text_units", + "text_units": "workflow:create_base_text_units", }, }, ] diff --git a/graphrag/index/workflows/v1/subflows/create_base_entity_graph.py b/graphrag/index/workflows/v1/subflows/create_base_entity_graph.py index 61ab519933..09c8e8067c 100644 --- a/graphrag/index/workflows/v1/subflows/create_base_entity_graph.py +++ b/graphrag/index/workflows/v1/subflows/create_base_entity_graph.py @@ -30,8 +30,6 @@ async def create_base_entity_graph( cache: PipelineCache, storage: PipelineStorage, runtime_storage: PipelineStorage, - text_column: str, - id_column: str, clustering_strategy: dict[str, Any], extraction_strategy: dict[str, Any] | None, extraction_num_threads: int = 4, @@ -55,8 +53,6 @@ async def create_base_entity_graph( callbacks, cache, storage, - text_column, - id_column, clustering_strategy=clustering_strategy, extraction_strategy=extraction_strategy, extraction_num_threads=extraction_num_threads, diff --git a/graphrag/index/workflows/v1/subflows/create_base_text_units.py b/graphrag/index/workflows/v1/subflows/create_base_text_units.py index bdfea070cb..598c6cf10c 100644 --- a/graphrag/index/workflows/v1/subflows/create_base_text_units.py +++ b/graphrag/index/workflows/v1/subflows/create_base_text_units.py @@ -26,8 +26,6 @@ async def create_base_text_units( callbacks: VerbCallbacks, storage: PipelineStorage, runtime_storage: PipelineStorage, - chunk_column_name: str, - n_tokens_column_name: str, chunk_by_columns: list[str], chunk_strategy: dict[str, Any] | None = None, snapshot_transient_enabled: bool = False, @@ -40,8 +38,6 @@ async def create_base_text_units( source, callbacks, storage, - chunk_column_name, - n_tokens_column_name, chunk_by_columns, chunk_strategy=chunk_strategy, snapshot_transient_enabled=snapshot_transient_enabled, diff --git a/graphrag/index/workflows/v1/subflows/create_final_community_reports.py b/graphrag/index/workflows/v1/subflows/create_final_community_reports.py index ec52f7d564..b8f5984e8c 100644 --- a/graphrag/index/workflows/v1/subflows/create_final_community_reports.py +++ b/graphrag/index/workflows/v1/subflows/create_final_community_reports.py @@ -35,7 +35,7 @@ async def create_final_community_reports( """All the steps to transform community reports.""" nodes = cast(pd.DataFrame, input.get_input()) edges = cast(pd.DataFrame, get_required_input_table(input, "relationships").table) - + entities = cast(pd.DataFrame, get_required_input_table(input, "entities").table) communities = cast( pd.DataFrame, get_required_input_table(input, "communities").table ) @@ -47,6 +47,7 @@ async def create_final_community_reports( output = await create_final_community_reports_flow( nodes, edges, + entities, communities, claims, callbacks, diff --git a/graphrag/index/workflows/v1/subflows/create_final_covariates.py b/graphrag/index/workflows/v1/subflows/create_final_covariates.py index f699877006..d6b83ed70f 100644 --- a/graphrag/index/workflows/v1/subflows/create_final_covariates.py +++ b/graphrag/index/workflows/v1/subflows/create_final_covariates.py @@ -25,7 +25,6 @@ async def create_final_covariates( callbacks: VerbCallbacks, 
cache: PipelineCache, runtime_storage: PipelineStorage, - column: str, covariate_type: str, extraction_strategy: dict[str, Any] | None, async_mode: AsyncType = AsyncType.AsyncIO, @@ -40,7 +39,6 @@ async def create_final_covariates( text_units, callbacks, cache, - column, covariate_type, extraction_strategy, async_mode=async_mode, diff --git a/graphrag/index/workflows/v1/subflows/create_final_documents.py b/graphrag/index/workflows/v1/subflows/create_final_documents.py index 34e2e5e018..9b7a4e7559 100644 --- a/graphrag/index/workflows/v1/subflows/create_final_documents.py +++ b/graphrag/index/workflows/v1/subflows/create_final_documents.py @@ -16,21 +16,22 @@ from graphrag.index.flows.create_final_documents import ( create_final_documents as create_final_documents_flow, ) -from graphrag.index.utils.ds_util import get_required_input_table +from graphrag.index.storage import PipelineStorage @verb( name="create_final_documents", treats_input_tables_as_immutable=True, ) -def create_final_documents( +async def create_final_documents( input: VerbInput, + runtime_storage: PipelineStorage, document_attribute_columns: list[str] | None = None, **_kwargs: dict, ) -> VerbResult: """All the steps to transform final documents.""" source = cast(pd.DataFrame, input.get_input()) - text_units = cast(pd.DataFrame, get_required_input_table(input, "text_units").table) + text_units = await runtime_storage.get("base_text_units") output = create_final_documents_flow(source, text_units, document_attribute_columns) diff --git a/graphrag/model/community.py b/graphrag/model/community.py index d108d4e492..041aaa5e47 100644 --- a/graphrag/model/community.py +++ b/graphrag/model/community.py @@ -43,7 +43,7 @@ def from_dict( d: dict[str, Any], id_key: str = "id", title_key: str = "title", - short_id_key: str = "short_id", + short_id_key: str = "human_readable_id", level_key: str = "level", entities_key: str = "entity_ids", relationships_key: str = "relationship_ids", diff --git a/graphrag/model/community_report.py b/graphrag/model/community_report.py index 95547fa6c3..53c35a5117 100644 --- a/graphrag/model/community_report.py +++ b/graphrag/model/community_report.py @@ -46,8 +46,8 @@ def from_dict( d: dict[str, Any], id_key: str = "id", title_key: str = "title", - community_id_key: str = "community_id", - short_id_key: str = "short_id", + community_id_key: str = "community", + short_id_key: str = "human_readable_id", summary_key: str = "summary", full_content_key: str = "full_content", rank_key: str = "rank", diff --git a/graphrag/model/covariate.py b/graphrag/model/covariate.py index e00433128a..484ea16fae 100644 --- a/graphrag/model/covariate.py +++ b/graphrag/model/covariate.py @@ -30,9 +30,6 @@ class Covariate(Identified): text_unit_ids: list[str] | None = None """List of text unit IDs in which the covariate info appears (optional).""" - document_ids: list[str] | None = None - """List of document IDs in which the covariate info appears (optional).""" - attributes: dict[str, Any] | None = None @classmethod @@ -42,9 +39,8 @@ def from_dict( id_key: str = "id", subject_id_key: str = "subject_id", covariate_type_key: str = "covariate_type", - short_id_key: str = "short_id", + short_id_key: str = "human_readable_id", text_unit_ids_key: str = "text_unit_ids", - document_ids_key: str = "document_ids", attributes_key: str = "attributes", ) -> "Covariate": """Create a new covariate from the dict data.""" @@ -54,6 +50,5 @@ def from_dict( subject_id=d[subject_id_key], covariate_type=d.get(covariate_type_key, "claim"), 
text_unit_ids=d.get(text_unit_ids_key), - document_ids=d.get(document_ids_key), attributes=d.get(attributes_key), ) diff --git a/graphrag/model/document.py b/graphrag/model/document.py index b54a39ac91..2980318376 100644 --- a/graphrag/model/document.py +++ b/graphrag/model/document.py @@ -19,16 +19,10 @@ class Document(Named): text_unit_ids: list[str] = field(default_factory=list) """list of text units in the document.""" - raw_content: str = "" + text: str = "" """The raw text content of the document.""" - summary: str | None = None - """Summary of the document (optional).""" - - summary_embedding: list[float] | None = None - """The semantic embedding for the document summary (optional).""" - - raw_content_embedding: list[float] | None = None + text_embedding: list[float] | None = None """The semantic embedding for the document raw content (optional).""" attributes: dict[str, Any] | None = None @@ -39,13 +33,11 @@ def from_dict( cls, d: dict[str, Any], id_key: str = "id", - short_id_key: str = "short_id", + short_id_key: str = "human_readable_id", title_key: str = "title", type_key: str = "type", - raw_content_key: str = "raw_content", - summary_key: str = "summary", - summary_embedding_key: str = "summary_embedding", - raw_content_embedding_key: str = "raw_content_embedding", + text_key: str = "text", + text_embedding_key: str = "text_embedding", text_units_key: str = "text_units", attributes_key: str = "attributes", ) -> "Document": @@ -55,10 +47,8 @@ def from_dict( short_id=d.get(short_id_key), title=d[title_key], type=d.get(type_key, "text"), - raw_content=d[raw_content_key], - summary=d.get(summary_key), - summary_embedding=d.get(summary_embedding_key), - raw_content_embedding=d.get(raw_content_embedding_key), + text=d[text_key], + text_embedding=d.get(text_embedding_key), text_unit_ids=d.get(text_units_key, []), attributes=d.get(attributes_key), ) diff --git a/graphrag/model/entity.py b/graphrag/model/entity.py index 37c26342aa..a152abf2a5 100644 --- a/graphrag/model/entity.py +++ b/graphrag/model/entity.py @@ -25,18 +25,12 @@ class Entity(Named): name_embedding: list[float] | None = None """The semantic (i.e. text) embedding of the entity (optional).""" - graph_embedding: list[float] | None = None - """The graph embedding of the entity, likely from node2vec (optional).""" - community_ids: list[str] | None = None """The community IDs of the entity (optional).""" text_unit_ids: list[str] | None = None """List of text unit IDs in which the entity appears (optional).""" - document_ids: list[str] | None = None - """List of document IDs in which the entity appears (optional).""" - rank: int | None = 1 """Rank of the entity, used for sorting (optional). Higher rank indicates more important entity. 
This can be based on centrality or other metrics.""" @@ -48,16 +42,14 @@ def from_dict( cls, d: dict[str, Any], id_key: str = "id", - short_id_key: str = "short_id", + short_id_key: str = "human_readable_id", title_key: str = "title", type_key: str = "type", description_key: str = "description", description_embedding_key: str = "description_embedding", name_embedding_key: str = "name_embedding", - graph_embedding_key: str = "graph_embedding", community_key: str = "community", text_unit_ids_key: str = "text_unit_ids", - document_ids_key: str = "document_ids", rank_key: str = "degree", attributes_key: str = "attributes", ) -> "Entity": @@ -70,10 +62,8 @@ def from_dict( description=d.get(description_key), name_embedding=d.get(name_embedding_key), description_embedding=d.get(description_embedding_key), - graph_embedding=d.get(graph_embedding_key), community_ids=d.get(community_key), rank=d.get(rank_key, 1), text_unit_ids=d.get(text_unit_ids_key), - document_ids=d.get(document_ids_key), attributes=d.get(attributes_key), ) diff --git a/graphrag/model/relationship.py b/graphrag/model/relationship.py index fadd0aaa6f..54fb20c31c 100644 --- a/graphrag/model/relationship.py +++ b/graphrag/model/relationship.py @@ -31,8 +31,8 @@ class Relationship(Identified): text_unit_ids: list[str] | None = None """List of text unit IDs in which the relationship appears (optional).""" - document_ids: list[str] | None = None - """List of document IDs in which the relationship appears (optional).""" + rank: int | None = 1 + """Rank of the relationship, used for sorting (optional). Higher rank indicates more important relationship. This can be based on centrality or other metrics.""" attributes: dict[str, Any] | None = None """Additional attributes associated with the relationship (optional). 
To be included in the search prompt""" @@ -42,13 +42,13 @@ def from_dict( cls, d: dict[str, Any], id_key: str = "id", - short_id_key: str = "short_id", + short_id_key: str = "human_readable_id", source_key: str = "source", target_key: str = "target", description_key: str = "description", + rank_key: str = "rank", weight_key: str = "weight", text_unit_ids_key: str = "text_unit_ids", - document_ids_key: str = "document_ids", attributes_key: str = "attributes", ) -> "Relationship": """Create a new relationship from the dict data.""" @@ -57,9 +57,9 @@ def from_dict( short_id=d.get(short_id_key), source=d[source_key], target=d[target_key], + rank=d.get(rank_key, 1), description=d.get(description_key), weight=d.get(weight_key, 1.0), text_unit_ids=d.get(text_unit_ids_key), - document_ids=d.get(document_ids_key), attributes=d.get(attributes_key), ) diff --git a/graphrag/model/text_unit.py b/graphrag/model/text_unit.py index cff4ac01c1..b54ee9e5f8 100644 --- a/graphrag/model/text_unit.py +++ b/graphrag/model/text_unit.py @@ -42,7 +42,7 @@ def from_dict( cls, d: dict[str, Any], id_key: str = "id", - short_id_key: str = "short_id", + short_id_key: str = "human_readable_id", text_key: str = "text", text_embedding_key: str = "text_embedding", entities_key: str = "entity_ids", diff --git a/graphrag/query/context_builder/entity_extraction.py b/graphrag/query/context_builder/entity_extraction.py index 4b8767b87d..f7e1fbfe18 100644 --- a/graphrag/query/context_builder/entity_extraction.py +++ b/graphrag/query/context_builder/entity_extraction.py @@ -91,52 +91,6 @@ def map_query_to_entities( return included_entities + matched_entities -def find_nearest_neighbors_by_graph_embeddings( - entity_id: str, - graph_embedding_vectorstore: BaseVectorStore, - all_entities: list[Entity], - exclude_entity_names: list[str] | None = None, - embedding_vectorstore_key: str = EntityVectorStoreKey.ID, - k: int = 10, - oversample_scaler: int = 2, -) -> list[Entity]: - """Retrieve related entities by graph embeddings.""" - if exclude_entity_names is None: - exclude_entity_names = [] - # find nearest neighbors of this entity using graph embedding - query_entity = get_entity_by_key( - entities=all_entities, key=embedding_vectorstore_key, value=entity_id - ) - query_embedding = query_entity.graph_embedding if query_entity else None - - # oversample to account for excluded entities - if query_embedding: - matched_entities = [] - search_results = graph_embedding_vectorstore.similarity_search_by_vector( - query_embedding=query_embedding, k=k * oversample_scaler - ) - for result in search_results: - matched = get_entity_by_key( - entities=all_entities, - key=embedding_vectorstore_key, - value=result.document.id, - ) - if matched: - matched_entities.append(matched) - - # filter out excluded entities - if exclude_entity_names: - matched_entities = [ - entity - for entity in matched_entities - if entity.title not in exclude_entity_names - ] - matched_entities.sort(key=lambda x: x.rank, reverse=True) - return matched_entities[:k] - - return [] - - def find_nearest_neighbors_by_entity_rank( entity_name: str, all_entities: list[Entity], diff --git a/graphrag/query/context_builder/local_context.py b/graphrag/query/context_builder/local_context.py index c2f7527e33..78522af5a9 100644 --- a/graphrag/query/context_builder/local_context.py +++ b/graphrag/query/context_builder/local_context.py @@ -288,7 +288,12 @@ def _filter_relationships( ) # sort by attributes[links] first, then by ranking_attribute - if relationship_ranking_attribute == 
"weight": + if relationship_ranking_attribute == "rank": + out_network_relationships.sort( + key=lambda x: (x.attributes["links"], x.rank), # type: ignore + reverse=True, # type: ignore + ) + elif relationship_ranking_attribute == "weight": out_network_relationships.sort( key=lambda x: (x.attributes["links"], x.weight), # type: ignore reverse=True, # type: ignore diff --git a/graphrag/query/indexer_adapters.py b/graphrag/query/indexer_adapters.py index 86561a6544..9bd73d6531 100644 --- a/graphrag/query/indexer_adapters.py +++ b/graphrag/query/indexer_adapters.py @@ -40,7 +40,6 @@ def read_indexer_text_units(final_text_units: pd.DataFrame) -> list[TextUnit]: """Read in the Text Units from the raw indexing outputs.""" return read_text_units( df=final_text_units, - short_id_col=None, # expects a covariate map of type -> ids covariates_col=None, ) @@ -66,12 +65,19 @@ def read_indexer_covariates(final_covariates: pd.DataFrame) -> list[Covariate]: def read_indexer_relationships(final_relationships: pd.DataFrame) -> list[Relationship]: """Read in the Relationships from the raw indexing outputs.""" + # rank is for back-compat with older indexes + # TODO: remove for 1.0 + rank_col = ( + "combined_degree" + if "combined_degree" in final_relationships.columns + else "rank" + ) return read_relationships( df=final_relationships, short_id_col="human_readable_id", + rank_col=rank_col, description_embedding_col=None, - document_ids_col=None, - attributes_cols=["rank"], + attributes_cols=None, ) @@ -87,34 +93,40 @@ def read_indexer_reports( If not dynamic_community_selection, then select reports with the max community level that an entity belongs to. """ - report_df = final_community_reports - entity_df = final_nodes + reports_df = final_community_reports + nodes_df = final_nodes + if community_level is not None: - entity_df = _filter_under_community_level(entity_df, community_level) - report_df = _filter_under_community_level(report_df, community_level) + nodes_df = _filter_under_community_level(nodes_df, community_level) + reports_df = _filter_under_community_level(reports_df, community_level) if not dynamic_community_selection: # perform community level roll up - entity_df.loc[:, "community"] = entity_df["community"].fillna(-1) - entity_df.loc[:, "community"] = entity_df["community"].astype(int) + nodes_df.loc[:, "community"] = nodes_df["community"].fillna(-1) + nodes_df.loc[:, "community"] = nodes_df["community"].astype(int) + + nodes_df = nodes_df.groupby(["title"]).agg({"community": "max"}).reset_index() + filtered_community_df = nodes_df["community"].drop_duplicates() - entity_df = entity_df.groupby(["title"]).agg({"community": "max"}).reset_index() - entity_df["community"] = entity_df["community"].astype(str) - filtered_community_df = entity_df["community"].drop_duplicates() + # todo: pre 1.0 back-compat where community was a string + reports_df.loc[:, "community"] = reports_df["community"].fillna(-1) + reports_df.loc[:, "community"] = reports_df["community"].astype(int) - report_df = report_df.merge(filtered_community_df, on="community", how="inner") + reports_df = reports_df.merge( + filtered_community_df, on="community", how="inner" + ) if config and ( - content_embedding_col not in report_df.columns - or report_df.loc[:, content_embedding_col].isna().any() + content_embedding_col not in reports_df.columns + or reports_df.loc[:, content_embedding_col].isna().any() ): embedder = get_text_embedder(config) - report_df = embed_community_reports( - report_df, embedder, 
embedding_col=content_embedding_col + reports_df = embed_community_reports( + reports_df, embedder, embedding_col=content_embedding_col ) return read_community_reports( - df=report_df, + df=reports_df, id_col="id", short_id_col="community", summary_embedding_col=None, @@ -137,44 +149,42 @@ def read_indexer_entities( community_level: int | None, ) -> list[Entity]: """Read in the Entities from the raw indexing outputs.""" - entity_df = final_nodes - entity_embedding_df = final_entities + nodes_df = final_nodes + entities_df = final_entities if community_level is not None: - entity_df = _filter_under_community_level(entity_df, community_level) + nodes_df = _filter_under_community_level(nodes_df, community_level) - entity_df = cast(pd.DataFrame, entity_df[["title", "degree", "community"]]).rename( - columns={"title": "name", "degree": "rank"} - ) + nodes_df = cast(pd.DataFrame, nodes_df[["id", "degree", "community"]]) - entity_df["community"] = entity_df["community"].fillna(-1) - entity_df["community"] = entity_df["community"].astype(int) - entity_df["rank"] = entity_df["rank"].astype(int) + nodes_df["community"] = nodes_df["community"].fillna(-1) + nodes_df["community"] = nodes_df["community"].astype(int) + nodes_df["degree"] = nodes_df["degree"].astype(int) - # group entities by name and rank and remove duplicated community IDs - entity_df = ( - entity_df.groupby(["name", "rank"]).agg({"community": set}).reset_index() + # group entities by id and degree and remove duplicated community IDs + nodes_df = nodes_df.groupby(["id", "degree"]).agg({"community": set}).reset_index() + nodes_df["community"] = nodes_df["community"].apply(lambda x: [str(i) for i in x]) + final_df = nodes_df.merge(entities_df, on="id", how="inner").drop_duplicates( + subset=["id"] ) - entity_df["community"] = entity_df["community"].apply(lambda x: [str(i) for i in x]) - entity_df = entity_df.merge( - entity_embedding_df, on="name", how="inner" - ).drop_duplicates(subset=["name"]) + + # todo: pre 1.0 back-compat where title was name + if "title" not in final_df.columns: + final_df["title"] = final_df["name"] # read entity dataframe to knowledge model objects return read_entities( - df=entity_df, + df=final_df, id_col="id", - title_col="name", + title_col="title", type_col="type", short_id_col="human_readable_id", description_col="description", community_col="community", - rank_col="rank", + rank_col="degree", name_embedding_col=None, description_embedding_col="description_embedding", - graph_embedding_col=None, text_unit_ids_col="text_unit_ids", - document_ids_col=None, ) @@ -187,41 +197,50 @@ def read_indexer_communities( Reconstruct the community hierarchy information and add to the sub-community field. """ - community_df = final_communities - node_df = final_nodes - report_df = final_community_reports + communities_df = final_communities + nodes_df = final_nodes + reports_df = final_community_reports + + # todo: pre 1.0 back-compat! 
+ if "community" not in communities_df.columns: + communities_df["community"] = communities_df["id"] # ensure communities matches community reports - missing_reports = community_df[ - ~community_df.id.isin(report_df.community.unique()) - ].id.to_list() + missing_reports = communities_df[ + ~communities_df.community.isin(reports_df.community.unique()) + ].community.to_list() if len(missing_reports): log.warning("Missing reports for communities: %s", missing_reports) - community_df = community_df.loc[ - community_df.id.isin(report_df.community.unique()) + communities_df = communities_df.loc[ + communities_df.community.isin(reports_df.community.unique()) ] - node_df = node_df.loc[node_df.community.isin(report_df.community.unique())] + nodes_df = nodes_df.loc[nodes_df.community.isin(reports_df.community.unique())] # reconstruct the community hierarchy # note that restore_community_hierarchy only return communities with sub communities - community_hierarchy = restore_community_hierarchy(input=node_df) - community_hierarchy = ( - community_hierarchy.groupby(["community"]) - .agg({"sub_community": list}) - .reset_index() - .rename(columns={"community": "id", "sub_community": "sub_community_ids"}) - ) - # add sub community IDs to community DataFrame - community_df = community_df.merge(community_hierarchy, on="id", how="left") - # replace NaN sub community IDs with empty list - community_df.sub_community_ids = community_df.sub_community_ids.apply( - lambda x: x if isinstance(x, list) else [] - ) + community_hierarchy = restore_community_hierarchy(input=nodes_df) + + # small datasets can result in hierarchies that are only one deep, so the hierarchy will have no rows + if not community_hierarchy.empty: + community_hierarchy = ( + community_hierarchy.groupby(["community"]) + .agg({"sub_community": list}) + .reset_index() + .rename(columns={"sub_community": "sub_community_ids"}) + ) + # add sub community IDs to community DataFrame + communities_df = communities_df.merge( + community_hierarchy, on="community", how="left" + ) + # replace NaN sub community IDs with empty list + communities_df.sub_community_ids = communities_df.sub_community_ids.apply( + lambda x: x if isinstance(x, list) else [] + ) return read_communities( - community_df, + communities_df, id_col="id", - short_id_col="id", + short_id_col="community", title_col="title", level_col="level", entities_col=None, diff --git a/graphrag/query/input/loaders/dfs.py b/graphrag/query/input/loaders/dfs.py index 6f1cdb0e46..f144ad8a47 100644 --- a/graphrag/query/input/loaders/dfs.py +++ b/graphrag/query/input/loaders/dfs.py @@ -9,13 +9,11 @@ Community, CommunityReport, Covariate, - Document, Entity, Relationship, TextUnit, ) from graphrag.query.input.loaders.utils import ( - to_list, to_optional_dict, to_optional_float, to_optional_int, @@ -29,16 +27,14 @@ def read_entities( df: pd.DataFrame, id_col: str = "id", - short_id_col: str | None = "short_id", + short_id_col: str | None = "human_readable_id", title_col: str = "title", type_col: str | None = "type", description_col: str | None = "description", name_embedding_col: str | None = "name_embedding", description_embedding_col: str | None = "description_embedding", - graph_embedding_col: str | None = "graph_embedding", community_col: str | None = "community_ids", text_unit_ids_col: str | None = "text_unit_ids", - document_ids_col: str | None = "document_ids", rank_col: str | None = "degree", attributes_cols: list[str] | None = None, ) -> list[Entity]: @@ -55,10 +51,8 @@ def read_entities( 
description_embedding=to_optional_list( row, description_embedding_col, item_type=float ), - graph_embedding=to_optional_list(row, graph_embedding_col, item_type=float), community_ids=to_optional_list(row, community_col, item_type=str), text_unit_ids=to_optional_list(row, text_unit_ids_col), - document_ids=to_optional_list(row, document_ids_col), rank=to_optional_int(row, rank_col), attributes=( {col: row.get(col) for col in attributes_cols} @@ -92,28 +86,6 @@ def store_entity_semantic_embeddings( return vectorstore -def store_entity_behavior_embeddings( - entities: list[Entity], - vectorstore: BaseVectorStore, -) -> BaseVectorStore: - """Store entity behavior embeddings in a vectorstore.""" - documents = [ - VectorStoreDocument( - id=entity.id, - text=entity.description, - vector=entity.graph_embedding, - attributes=( - {"title": entity.title, **entity.attributes} - if entity.attributes - else {"title": entity.title} - ), - ) - for entity in entities - ] - vectorstore.load_documents(documents=documents) - return vectorstore - - def store_reports_semantic_embeddings( reports: list[CommunityReport], vectorstore: BaseVectorStore, @@ -139,14 +111,14 @@ def store_reports_semantic_embeddings( def read_relationships( df: pd.DataFrame, id_col: str = "id", - short_id_col: str | None = "short_id", + short_id_col: str | None = "human_readable_id", source_col: str = "source", target_col: str = "target", description_col: str | None = "description", + rank_col: str | None = "combined_degree", description_embedding_col: str | None = "description_embedding", weight_col: str | None = "weight", text_unit_ids_col: str | None = "text_unit_ids", - document_ids_col: str | None = "document_ids", attributes_cols: list[str] | None = None, ) -> list[Relationship]: """Read relationships from a dataframe.""" @@ -163,7 +135,7 @@ def read_relationships( ), weight=to_optional_float(row, weight_col), text_unit_ids=to_optional_list(row, text_unit_ids_col, item_type=str), - document_ids=to_optional_list(row, document_ids_col, item_type=str), + rank=to_optional_int(row, rank_col), attributes=( {col: row.get(col) for col in attributes_cols} if attributes_cols @@ -177,11 +149,10 @@ def read_relationships( def read_covariates( df: pd.DataFrame, id_col: str = "id", - short_id_col: str | None = "short_id", + short_id_col: str | None = "human_readable_id", subject_col: str = "subject_id", covariate_type_col: str | None = "type", text_unit_ids_col: str | None = "text_unit_ids", - document_ids_col: str | None = "document_ids", attributes_cols: list[str] | None = None, ) -> list[Covariate]: """Read covariates from a dataframe.""" @@ -195,7 +166,6 @@ def read_covariates( to_str(row, covariate_type_col) if covariate_type_col else "claim" ), text_unit_ids=to_optional_list(row, text_unit_ids_col, item_type=str), - document_ids=to_optional_list(row, document_ids_col, item_type=str), attributes=( {col: row.get(col) for col in attributes_cols} if attributes_cols @@ -209,7 +179,7 @@ def read_covariates( def read_communities( df: pd.DataFrame, id_col: str = "id", - short_id_col: str | None = "short_id", + short_id_col: str | None = "community", title_col: str = "title", level_col: str = "level", entities_col: str | None = "entity_ids", @@ -231,7 +201,7 @@ def read_communities( covariate_ids=to_optional_dict( row, covariates_col, key_type=str, value_type=str ), - sub_community_ids=to_optional_list(row, sub_communities_col, item_type=str), + sub_community_ids=to_optional_list(row, sub_communities_col), attributes=( {col: row.get(col) for col 
in attributes_cols} if attributes_cols @@ -245,7 +215,7 @@ def read_communities( def read_community_reports( df: pd.DataFrame, id_col: str = "id", - short_id_col: str | None = "short_id", + short_id_col: str | None = "community", title_col: str = "title", community_col: str = "community", summary_col: str = "summary", @@ -285,7 +255,6 @@ def read_community_reports( def read_text_units( df: pd.DataFrame, id_col: str = "id", - short_id_col: str | None = "short_id", text_col: str = "text", entities_col: str | None = "entity_ids", relationships_col: str | None = "relationship_ids", @@ -300,7 +269,7 @@ def read_text_units( for idx, row in df.iterrows(): chunk = TextUnit( id=to_str(row, id_col), - short_id=to_optional_str(row, short_id_col) if short_id_col else str(idx), + short_id=str(idx), text=to_str(row, text_col), entity_ids=to_optional_list(row, entities_col, item_type=str), relationship_ids=to_optional_list(row, relationships_col, item_type=str), @@ -318,43 +287,3 @@ def read_text_units( ) text_units.append(chunk) return text_units - - -def read_documents( - df: pd.DataFrame, - id_col: str = "id", - short_id_col: str = "short_id", - title_col: str = "title", - type_col: str = "type", - summary_col: str | None = "entities", - raw_content_col: str | None = "relationships", - summary_embedding_col: str | None = "summary_embedding", - content_embedding_col: str | None = "raw_content_embedding", - text_units_col: str | None = "text_units", - attributes_cols: list[str] | None = None, -) -> list[Document]: - """Read documents from a dataframe.""" - docs = [] - for idx, row in df.iterrows(): - doc = Document( - id=to_str(row, id_col), - short_id=to_optional_str(row, short_id_col) if short_id_col else str(idx), - title=to_str(row, title_col), - type=to_str(row, type_col), - summary=to_optional_str(row, summary_col), - raw_content=to_str(row, raw_content_col), - summary_embedding=to_optional_list( - row, summary_embedding_col, item_type=float - ), - raw_content_embedding=to_optional_list( - row, content_embedding_col, item_type=float - ), - text_units=to_list(row, text_units_col, item_type=str), # type: ignore - attributes=( - {col: row.get(col) for col in attributes_cols} - if attributes_cols - else None - ), - ) - docs.append(doc) - return docs diff --git a/graphrag/query/input/retrieval/relationships.py b/graphrag/query/input/retrieval/relationships.py index 2ef5ba38cb..2dec596ff3 100644 --- a/graphrag/query/input/retrieval/relationships.py +++ b/graphrag/query/input/retrieval/relationships.py @@ -27,9 +27,7 @@ def get_in_network_relationships( return selected_relationships # sort by ranking attribute - return sort_relationships_by_ranking_attribute( - selected_relationships, selected_entities, ranking_attribute - ) + return sort_relationships_by_rank(selected_relationships, ranking_attribute) def get_out_network_relationships( @@ -52,9 +50,7 @@ def get_out_network_relationships( and relationship.source not in selected_entity_names ] selected_relationships = source_relationships + target_relationships - return sort_relationships_by_ranking_attribute( - selected_relationships, selected_entities, ranking_attribute - ) + return sort_relationships_by_rank(selected_relationships, ranking_attribute) def get_candidate_relationships( @@ -81,35 +77,11 @@ def get_entities_from_relationships( return [entity for entity in entities if entity.title in selected_entity_names] -def calculate_relationship_combined_rank( - relationships: list[Relationship], - entities: list[Entity], - ranking_attribute: str = 
"rank", -) -> list[Relationship]: - """Calculate default rank for a relationship based on the combined rank of source and target entities.""" - entity_mappings = {entity.title: entity for entity in entities} - - for relationship in relationships: - if relationship.attributes is None: - relationship.attributes = {} - source = entity_mappings.get(relationship.source) - target = entity_mappings.get(relationship.target) - source_rank = source.rank if source and source.rank else 0 - target_rank = target.rank if target and target.rank else 0 - relationship.attributes[ranking_attribute] = source_rank + target_rank # type: ignore - return relationships - - -def sort_relationships_by_ranking_attribute( +def sort_relationships_by_rank( relationships: list[Relationship], - entities: list[Entity], ranking_attribute: str = "rank", ) -> list[Relationship]: - """ - Sort relationships by a ranking_attribute. - - If no ranking attribute exists, sort by combined rank of source and target entities. - """ + """Sort relationships by a ranking_attribute.""" if len(relationships) == 0: return relationships @@ -122,17 +94,10 @@ def sort_relationships_by_ranking_attribute( key=lambda x: int(x.attributes[ranking_attribute]) if x.attributes else 0, reverse=True, ) + elif ranking_attribute == "rank": + relationships.sort(key=lambda x: x.rank if x.rank else 0.0, reverse=True) elif ranking_attribute == "weight": relationships.sort(key=lambda x: x.weight if x.weight else 0.0, reverse=True) - else: - # ranking attribute do not exist, calculate rank = combined ranks of source and target - relationships = calculate_relationship_combined_rank( - relationships, entities, ranking_attribute - ) - relationships.sort( - key=lambda x: int(x.attributes[ranking_attribute]) if x.attributes else 0, - reverse=True, - ) return relationships diff --git a/graphrag/utils/embeddings.py b/graphrag/utils/embeddings.py new file mode 100644 index 0000000000..1431c8b510 --- /dev/null +++ b/graphrag/utils/embeddings.py @@ -0,0 +1,25 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Utilities for working with embeddings stores.""" + +from graphrag.index.config.embeddings import all_embeddings + + +def create_collection_name( + container_name: str, embedding_name: str, validate: bool = True +) -> str: + """ + Create a collection name for the embedding store. + + Within any given vector store, we can have multiple sets of embeddings organized into projects. + The `container` param is used for this partitioning, and is added as a prefix to the collection name for differentiation. + + The embedding name is fixed, with the available list defined in graphrag.index.config.embeddings + + Note that we use dot notation in our names, but many vector stores do not support this - so we convert to dashes. 
+ """ + if validate and embedding_name not in all_embeddings: + msg = f"Invalid embedding name: {embedding_name}" + raise KeyError(msg) + return f"{container_name}-{embedding_name}".replace(".", "-") diff --git a/mkdocs.yaml b/mkdocs.yaml index c2a04c081a..0f8a0a794c 100644 --- a/mkdocs.yaml +++ b/mkdocs.yaml @@ -29,6 +29,7 @@ nav: - Overview: "index/overview.md" - Architecture: "index/architecture.md" - Dataflow: "index/default_dataflow.md" + - Outputs: "index/outputs.md" - Configuration: - Overview: "config/overview.md" - Init Command: "config/init.md" diff --git a/tests/fixtures/min-csv/config.json b/tests/fixtures/min-csv/config.json index 28a2c87d71..11b2dceb51 100644 --- a/tests/fixtures/min-csv/config.json +++ b/tests/fixtures/min-csv/config.json @@ -27,8 +27,7 @@ ], "nan_allowed_columns": [ "type", - "description", - "graph_embedding" + "description" ], "subworkflows": 1, "max_runtime": 300, @@ -49,9 +48,7 @@ 2500 ], "nan_allowed_columns": [ - "entity_type", "description", - "graph_embedding", "community", "level" ], @@ -74,14 +71,15 @@ 2500 ], "nan_allowed_columns": [ - "community_id", "title", "summary", "full_content", "full_content_json", "rank", "rank_explanation", - "findings" + "findings", + "period", + "size" ], "subworkflows": 1, "max_runtime": 300, diff --git a/tests/fixtures/text/config.json b/tests/fixtures/text/config.json index 49cf03132b..0b5ca26581 100644 --- a/tests/fixtures/text/config.json +++ b/tests/fixtures/text/config.json @@ -45,8 +45,7 @@ ], "nan_allowed_columns": [ "type", - "description", - "graph_embedding" + "description" ], "subworkflows": 1, "max_runtime": 300, @@ -67,9 +66,7 @@ 2500 ], "nan_allowed_columns": [ - "entity_type", "description", - "graph_embedding", "community", "level" ], @@ -92,14 +89,15 @@ 2500 ], "nan_allowed_columns": [ - "community_id", "title", "summary", "full_content", "full_content_json", "rank", "rank_explanation", - "findings" + "findings", + "period", + "size" ], "subworkflows": 1, "max_runtime": 300, diff --git a/tests/unit/indexing/graph/extractors/community_reports/test_sort_context.py b/tests/unit/indexing/graph/extractors/community_reports/test_sort_context.py index 4d6e36c306..873eb45616 100644 --- a/tests/unit/indexing/graph/extractors/community_reports/test_sort_context.py +++ b/tests/unit/indexing/graph/extractors/community_reports/test_sort_context.py @@ -25,7 +25,7 @@ "source": "SCROOGE", "target": "ALI BABA", "description": "Scrooge recalls Ali Baba as a fond memory from his childhood readings", - "rank": 32, + "combined_degree": 32, }, ], "claim_details": [nan], @@ -46,7 +46,7 @@ "source": "SCROOGE", "target": "BELLE", "description": "Belle and Scrooge were once engaged, but their relationship ended due to Scrooge's growing obsession with wealth", - "rank": 32, + "combined_degree": 32, }, ], "claim_details": [nan], @@ -67,7 +67,7 @@ "source": "SCROOGE", "target": "CHRISTMAS", "description": "Scrooge's disdain for Christmas is a central theme, highlighting his miserliness and lack of compassion", - "rank": 32, + "combined_degree": 32, }, ], "claim_details": [nan], @@ -88,7 +88,7 @@ "source": "SCROOGE", "target": "CHRISTMAS DAY", "description": "Scrooge wakes up on Christmas Day with a changed heart, ready to celebrate and spread happiness", - "rank": 32, + "combined_degree": 32, }, ], "claim_details": [nan], @@ -109,7 +109,7 @@ "source": "SCROOGE", "target": "DUTCH MERCHANT", "description": "Scrooge's fireplace, built by the Dutch Merchant, serves as a focal point in his room where he encounters Marley's Ghost", 
- "rank": 32, + "combined_degree": 32, }, ], "claim_details": [nan], @@ -130,7 +130,7 @@ "source": "SCROOGE", "target": "FAN", "description": "Fan is Scrooge's sister, who shows love and care by bringing him home for Christmas", - "rank": 32, + "combined_degree": 32, }, ], "claim_details": [nan], @@ -151,7 +151,7 @@ "source": "SCROOGE", "target": "FRED", "description": "Scrooge accepts Fred's invitation to Christmas dinner, marking a significant step in repairing their relationship", - "rank": 32, + "combined_degree": 32, }, ], "claim_details": [nan], @@ -172,7 +172,7 @@ "source": "SCROOGE", "target": "GENTLEMAN", "description": "The gentleman approaches Scrooge to solicit donations for the poor, which Scrooge rebuffs", - "rank": 32, + "combined_degree": 32, }, ], "claim_details": [nan], @@ -193,7 +193,7 @@ "source": "SCROOGE", "target": "GHOST", "description": "The Ghost is taking Scrooge on a transformative journey by showing him scenes from his past, aiming to make him reflect on his life choices and their consequences. This spectral guide is not only focusing on Scrooge's personal history but also emphasizing the importance of Christmas and the need for a change in perspective. Through these vivid reenactments, the Ghost highlights the error of Scrooge's ways and the significant impact his actions have on others, including Tiny Tim. This experience is designed to enlighten Scrooge, encouraging him to reconsider his approach to life and the people around him.", - "rank": 32, + "combined_degree": 32, }, ], "claim_details": [nan], @@ -203,11 +203,15 @@ def test_sort_context(): ctx = sort_context(context) - assert num_tokens(ctx) == 827 if platform.system() == "Windows" else 826 - assert ctx is not None + assert ctx is not None, "Context is none" + num = num_tokens(ctx) + assert num == ( + 828 if platform.system() == "Windows" else 826 + ), f"num_tokens does not match the expected value for this platform (win = 828, else 826): {num}" def test_sort_context_max_tokens(): ctx = sort_context(context, max_tokens=800) - assert ctx is not None - assert num_tokens(ctx) <= 800 + assert ctx is not None, "Context is none" + num = num_tokens(ctx) + assert num <= 800, f"num_tokens is not less than or equal to 800: {num}" diff --git a/tests/unit/indexing/verbs/text/test_split.py b/tests/unit/indexing/verbs/text/test_split.py deleted file mode 100644 index abbb0eeff3..0000000000 --- a/tests/unit/indexing/verbs/text/test_split.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
-# Licensed under the MIT License -import unittest - -import pandas as pd -import pytest - -from graphrag.index.operations.split_text import split_text - - -class TestTextSplit(unittest.TestCase): - def test_empty_string(self): - input = pd.DataFrame([{"in": ""}]) - result = split_text(input, "in", "out", ",").to_dict(orient="records") - - assert len(result) == 1 - assert result[0]["out"] == [] - - def test_string_without_seperator(self): - input = pd.DataFrame([{"in": "test_string_without_seperator"}]) - result = split_text(input, "in", "out", ",").to_dict(orient="records") - - assert len(result) == 1 - assert result[0]["out"] == ["test_string_without_seperator"] - - def test_string_with_seperator(self): - input = pd.DataFrame([{"in": "test_1,test_2"}]) - result = split_text(input, "in", "out", ",").to_dict(orient="records") - - assert len(result) == 1 - assert result[0]["out"] == ["test_1", "test_2"] - - def test_row_with_list_as_column(self): - input = pd.DataFrame([{"in": ["test_1", "test_2"]}]) - result = split_text(input, "in", "out", ",").to_dict(orient="records") - - assert len(result) == 1 - assert result[0]["out"] == ["test_1", "test_2"] - - def test_non_string_column_throws_error(self): - input = pd.DataFrame([{"in": 5}]) - with pytest.raises(TypeError): - split_text(input, "in", "out", ",").to_dict(orient="records") - - def test_more_than_one_row_returns_correctly(self): - input = pd.DataFrame([{"in": "row_1_1,row_1_2"}, {"in": "row_2_1,row_2_2"}]) - result = split_text(input, "in", "out", ",").to_dict(orient="records") - - assert len(result) == 2 - assert result[0]["out"] == ["row_1_1", "row_1_2"] - assert result[1]["out"] == ["row_2_1", "row_2_2"] diff --git a/tests/unit/indexing/verbs/text/__init__.py b/tests/unit/utils/__init__.py similarity index 100% rename from tests/unit/indexing/verbs/text/__init__.py rename to tests/unit/utils/__init__.py diff --git a/tests/unit/utils/test_embeddings.py b/tests/unit/utils/test_embeddings.py new file mode 100644 index 0000000000..54a6b79647 --- /dev/null +++ b/tests/unit/utils/test_embeddings.py @@ -0,0 +1,21 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +import pytest + +from graphrag.utils.embeddings import create_collection_name + + +def test_create_collection_name(): + collection = create_collection_name("default", "entity.title") + assert collection == "default-entity-title" + + +def test_create_collection_name_invalid_embedding_throws(): + with pytest.raises(KeyError): + create_collection_name("default", "invalid.name") + + +def test_create_collection_name_invalid_embedding_does_not_throw(): + collection = create_collection_name("default", "invalid.name", validate=False) + assert collection == "default-invalid-name" diff --git a/tests/verbs/data/create_base_entity_graph.parquet b/tests/verbs/data/create_base_entity_graph.parquet index 0e50332cdb..528df2943e 100644 Binary files a/tests/verbs/data/create_base_entity_graph.parquet and b/tests/verbs/data/create_base_entity_graph.parquet differ diff --git a/tests/verbs/data/create_base_text_units.parquet b/tests/verbs/data/create_base_text_units.parquet index 3156191a63..61b082f6c3 100644 Binary files a/tests/verbs/data/create_base_text_units.parquet and b/tests/verbs/data/create_base_text_units.parquet differ diff --git a/tests/verbs/data/create_final_communities.parquet b/tests/verbs/data/create_final_communities.parquet index 6c7fbbf578..6c3f91a990 100644 Binary files a/tests/verbs/data/create_final_communities.parquet and b/tests/verbs/data/create_final_communities.parquet differ diff --git a/tests/verbs/data/create_final_community_reports.parquet b/tests/verbs/data/create_final_community_reports.parquet index f50493f42e..e9e35a928e 100644 Binary files a/tests/verbs/data/create_final_community_reports.parquet and b/tests/verbs/data/create_final_community_reports.parquet differ diff --git a/tests/verbs/data/create_final_covariates.parquet b/tests/verbs/data/create_final_covariates.parquet index 1d89c6104b..e28a01a320 100644 Binary files a/tests/verbs/data/create_final_covariates.parquet and b/tests/verbs/data/create_final_covariates.parquet differ diff --git a/tests/verbs/data/create_final_documents.parquet b/tests/verbs/data/create_final_documents.parquet index 4ff33ebf4d..5076952d4c 100644 Binary files a/tests/verbs/data/create_final_documents.parquet and b/tests/verbs/data/create_final_documents.parquet differ diff --git a/tests/verbs/data/create_final_entities.parquet b/tests/verbs/data/create_final_entities.parquet index 295f1cb915..fc4a90c7ee 100644 Binary files a/tests/verbs/data/create_final_entities.parquet and b/tests/verbs/data/create_final_entities.parquet differ diff --git a/tests/verbs/data/create_final_nodes.parquet b/tests/verbs/data/create_final_nodes.parquet index ecf9a9a673..b2f0b3e91e 100644 Binary files a/tests/verbs/data/create_final_nodes.parquet and b/tests/verbs/data/create_final_nodes.parquet differ diff --git a/tests/verbs/data/create_final_relationships.parquet b/tests/verbs/data/create_final_relationships.parquet index d7bfe658c0..fb885b0273 100644 Binary files a/tests/verbs/data/create_final_relationships.parquet and b/tests/verbs/data/create_final_relationships.parquet differ diff --git a/tests/verbs/data/create_final_text_units.parquet b/tests/verbs/data/create_final_text_units.parquet index a238493694..5191a2cb9d 100644 Binary files a/tests/verbs/data/create_final_text_units.parquet and b/tests/verbs/data/create_final_text_units.parquet differ diff --git a/tests/verbs/data/source_documents.parquet b/tests/verbs/data/source_documents.parquet new file mode 100644 index 0000000000..1d78b25919 Binary files /dev/null and 
b/tests/verbs/data/source_documents.parquet differ diff --git a/tests/verbs/test_create_final_communities.py b/tests/verbs/test_create_final_communities.py index f6c5b15c80..9f002f7c1d 100644 --- a/tests/verbs/test_create_final_communities.py +++ b/tests/verbs/test_create_final_communities.py @@ -36,9 +36,12 @@ async def test_create_final_communities(): context=context, ) - # ignore the period column, because it is recalculated every time + # ignore the period and id columns, because they are recalculated every time + assert "period" in expected.columns + assert "id" in expected.columns columns = list(expected.columns.values) columns.remove("period") + columns.remove("id") compare_outputs( actual, expected, diff --git a/tests/verbs/test_create_final_community_reports.py b/tests/verbs/test_create_final_community_reports.py index 6237762395..b799bffbfb 100644 --- a/tests/verbs/test_create_final_community_reports.py +++ b/tests/verbs/test_create_final_community_reports.py @@ -47,6 +47,7 @@ async def test_create_final_community_reports(): "workflow:create_final_nodes", "workflow:create_final_covariates", "workflow:create_final_relationships", + "workflow:create_final_entities", "workflow:create_final_communities", ]) expected = load_expected(workflow_name) @@ -79,6 +80,7 @@ async def test_create_final_community_reports_missing_llm_throws(): "workflow:create_final_nodes", "workflow:create_final_covariates", "workflow:create_final_relationships", + "workflow:create_final_entities", "workflow:create_final_communities", ]) diff --git a/tests/verbs/test_create_final_covariates.py b/tests/verbs/test_create_final_covariates.py index 077cfe8fc8..23a02bca67 100644 --- a/tests/verbs/test_create_final_covariates.py +++ b/tests/verbs/test_create_final_covariates.py @@ -59,13 +59,10 @@ async def test_create_final_covariates(): # assert all of the columns that covariates copied from the input assert_series_equal(actual["text_unit_id"], input["id"], check_names=False) - assert_series_equal(actual["text_unit_id"], input["chunk_id"], check_names=False) - assert_series_equal(actual["document_ids"], input["document_ids"]) - assert_series_equal(actual["n_tokens"], input["n_tokens"]) - # make sure the human ids are incrementing and cast to strings - assert actual["human_readable_id"][0] == "1" - assert actual["human_readable_id"][1] == "2" + # make sure the human ids are incrementing + assert actual["human_readable_id"][0] == 1 + assert actual["human_readable_id"][1] == 2 # check that the mock data is parsed and inserted into the correct columns assert actual["covariate_type"][0] == "claim" diff --git a/tests/verbs/test_create_final_documents.py b/tests/verbs/test_create_final_documents.py index 6d8138088e..eb84c90b22 100644 --- a/tests/verbs/test_create_final_documents.py +++ b/tests/verbs/test_create_final_documents.py @@ -1,6 +1,7 @@ # Copyright (c) 2024 Microsoft Corporation. 
# Licensed under the MIT License +from graphrag.index.run.utils import create_run_context from graphrag.index.workflows.v1.create_final_documents import ( build_steps, workflow_name, @@ -17,10 +18,15 @@ async def test_create_final_documents(): input_tables = load_input_tables([ - "workflow:create_final_text_units", + "workflow:create_base_text_units", ]) expected = load_expected(workflow_name) + context = create_run_context(None, None, None) + await context.runtime_storage.set( + "base_text_units", input_tables["workflow:create_base_text_units"] + ) + config = get_config_for_workflow(workflow_name) steps = build_steps(config) @@ -30,15 +36,21 @@ async def test_create_final_documents(): { "steps": steps, }, + context=context, ) compare_outputs(actual, expected) async def test_create_final_documents_with_attribute_columns(): - input_tables = load_input_tables(["workflow:create_final_text_units"]) + input_tables = load_input_tables(["workflow:create_base_text_units"]) expected = load_expected(workflow_name) + context = create_run_context(None, None, None) + await context.runtime_storage.set( + "base_text_units", input_tables["workflow:create_base_text_units"] + ) + config = get_config_for_workflow(workflow_name) config["document_attribute_columns"] = ["title"] @@ -50,11 +62,15 @@ async def test_create_final_documents_with_attribute_columns(): { "steps": steps, }, + context=context, ) # we should have dropped "title" and added "attributes" # our test dataframe does not have attributes, so we'll assert without it # and separately confirm it is in the output - compare_outputs(actual, expected, columns=["id", "text_unit_ids", "raw_content"]) - assert len(actual.columns) == 4 + compare_outputs( + actual, expected, columns=["id", "human_readable_id", "text", "text_unit_ids"] + ) + assert len(actual.columns) == 5 + assert "title" not in actual.columns assert "attributes" in actual.columns diff --git a/tests/verbs/test_create_final_text_units.py b/tests/verbs/test_create_final_text_units.py index 1aa919d153..c167ec7b29 100644 --- a/tests/verbs/test_create_final_text_units.py +++ b/tests/verbs/test_create_final_text_units.py @@ -76,8 +76,10 @@ async def test_create_final_text_units_no_covariates(): ) # we're short a covariate_ids column + columns = list(expected.columns.values) + columns.remove("covariate_ids") compare_outputs( actual, expected, - ["id", "text", "n_tokens", "document_ids", "entity_ids", "relationship_ids"], + columns=columns, ) diff --git a/tests/verbs/test_generate_text_embeddings.py b/tests/verbs/test_generate_text_embeddings.py index 4f192ca1fd..c0919501d8 100644 --- a/tests/verbs/test_generate_text_embeddings.py +++ b/tests/verbs/test_generate_text_embeddings.py @@ -69,14 +69,10 @@ async def test_generate_text_embeddings(): assert "embedding" in entity_description_embeddings.columns # every other embedding is optional but we've turned them all on, so check a random one - document_raw_content_embeddings_buffer = BytesIO( - await context.storage.get( - "embeddings.document.raw_content.parquet", as_bytes=True - ) - ) - document_raw_content_embeddings = pd.read_parquet( - document_raw_content_embeddings_buffer + document_text_embeddings_buffer = BytesIO( + await context.storage.get("embeddings.document.text.parquet", as_bytes=True) ) - assert len(document_raw_content_embeddings.columns) == 2 - assert "id" in document_raw_content_embeddings.columns - assert "embedding" in document_raw_content_embeddings.columns + document_text_embeddings = 
pd.read_parquet(document_text_embeddings_buffer) + assert len(document_text_embeddings.columns) == 2 + assert "id" in document_text_embeddings.columns + assert "embedding" in document_text_embeddings.columns diff --git a/tests/verbs/util.py b/tests/verbs/util.py index e0e8c6db9f..c53c4e5be4 100644 --- a/tests/verbs/util.py +++ b/tests/verbs/util.py @@ -23,12 +23,8 @@ def load_input_tables(inputs: list[str]) -> dict[str, pd.DataFrame]: # stick all the inputs in a map - Workflow looks them up by name input_tables: dict[str, pd.DataFrame] = {} - # all workflows implicitly receive the `input` source, which is formatted as a dataframe after loading from storage - # we'll simulate that by just loading one of our output parquets and converting back to equivalent dataframe - # so we aren't dealing with storage vagaries (which would become an integration test) - source = pd.read_parquet("tests/verbs/data/create_final_documents.parquet") - source.rename(columns={"raw_content": "text"}, inplace=True) - input_tables["source"] = cast(pd.DataFrame, source[["id", "text", "title"]]) + source = pd.read_parquet("tests/verbs/data/source_documents.parquet") + input_tables["source"] = source for input in inputs: # remove the workflow: prefix if it exists, because that is not part of the actual table filename