Skip to content

Commit

Permalink
Collapse create final entities (#1220)
Browse files Browse the repository at this point in the history
* Collapse create_final_entities

* Update smoke tests

* Semver

* Remove prints

* Update embedding assertions
  • Loading branch information
natoverse authored Sep 26, 2024
1 parent 3217013 commit ce71bcf
Show file tree
Hide file tree
Showing 11 changed files with 258 additions and 106 deletions.
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20240925235837094282.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Collapse create-final-entities."
}
104 changes: 5 additions & 99 deletions graphrag/index/workflows/v1/create_final_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,110 +24,16 @@ def build_steps(
)
skip_name_embedding = config.get("skip_name_embedding", False)
skip_description_embedding = config.get("skip_description_embedding", False)
is_using_vector_store = (
entity_name_embed_config.get("strategy", {}).get("vector_store", None)
is not None
)

return [
{
"verb": "unpack_graph",
"verb": "create_final_entities",
"args": {
"column": "clustered_graph",
"type": "nodes",
"skip_name_embedding": skip_name_embedding,
"skip_description_embedding": skip_description_embedding,
"name_text_embed": entity_name_embed_config,
"description_text_embed": entity_name_description_embed_config,
},
"input": {"source": "workflow:create_base_entity_graph"},
},
{"verb": "rename", "args": {"columns": {"label": "title"}}},
{
"verb": "select",
"args": {
"columns": [
"id",
"title",
"type",
"description",
"human_readable_id",
"graph_embedding",
"source_id",
],
},
},
{
# create_base_entity_graph has multiple levels of clustering, which means there are multiple graphs with the same entities
# this dedupes the entities so that there is only one of each entity
"verb": "dedupe",
"args": {"columns": ["id"]},
},
{"verb": "rename", "args": {"columns": {"title": "name"}}},
{
# ELIMINATE EMPTY NAMES
"verb": "filter",
"args": {
"column": "name",
"criteria": [
{
"type": "value",
"operator": "is not empty",
}
],
},
},
{
"verb": "text_split",
"args": {"separator": ",", "column": "source_id", "to": "text_unit_ids"},
},
{"verb": "drop", "args": {"columns": ["source_id"]}},
{
"verb": "text_embed",
"enabled": not skip_name_embedding,
"args": {
"embedding_name": "entity_name",
"column": "name",
"to": "name_embedding",
**entity_name_embed_config,
},
},
{
"verb": "merge",
"enabled": not skip_description_embedding,
"args": {
"strategy": "concat",
"columns": ["name", "description"],
"to": "name_description",
"delimiter": ":",
"preserveSource": True,
},
},
{
"verb": "text_embed",
"enabled": not skip_description_embedding,
"args": {
"embedding_name": "entity_name_description",
"column": "name_description",
"to": "description_embedding",
**entity_name_description_embed_config,
},
},
{
"verb": "drop",
"enabled": not skip_description_embedding,
"args": {
"columns": ["name_description"],
},
},
{
# ELIMINATE EMPTY DESCRIPTION EMBEDDINGS
"verb": "filter",
"enabled": not skip_description_embedding and not is_using_vector_store,
"args": {
"column": "description_embedding",
"criteria": [
{
"type": "value",
"operator": "is not empty",
}
],
},
},
]
2 changes: 2 additions & 0 deletions graphrag/index/workflows/v1/subflows/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .create_final_communities import create_final_communities
from .create_final_covariates import create_final_covariates
from .create_final_documents import create_final_documents
from .create_final_entities import create_final_entities
from .create_final_nodes import create_final_nodes
from .create_final_relationships import (
create_final_relationships,
Expand All @@ -20,6 +21,7 @@
"create_final_communities",
"create_final_covariates",
"create_final_documents",
"create_final_entities",
"create_final_nodes",
"create_final_relationships",
"create_final_text_units",
Expand Down
105 changes: 105 additions & 0 deletions graphrag/index/workflows/v1/subflows/create_final_entities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""All the steps to transform final entities."""

from typing import cast

import pandas as pd
from datashaper import (
Table,
VerbCallbacks,
VerbInput,
verb,
)
from datashaper.table_store.types import VerbResult, create_verb_result

from graphrag.index.cache import PipelineCache
from graphrag.index.verbs.graph.unpack import unpack_graph_df
from graphrag.index.verbs.text.embed.text_embed import text_embed_df
from graphrag.index.verbs.text.split import text_split_df


@verb(
    name="create_final_entities",
    treats_input_tables_as_immutable=True,
)
async def create_final_entities(
    input: VerbInput,
    callbacks: VerbCallbacks,
    cache: PipelineCache,
    name_text_embed: dict,
    description_text_embed: dict,
    skip_name_embedding: bool = False,
    skip_description_embedding: bool = False,
    **_kwargs: dict,
) -> VerbResult:
    """Build the final entities table from the clustered entity graph.

    The nodes of the "clustered_graph" column are unpacked, reduced to a fixed
    set of columns, de-duplicated by "id", stripped of rows whose name is
    missing, and given a "text_unit_ids" column split out of "source_id".
    Name and name+description embeddings are then added unless skipped.

    Args:
        input: Verb input whose table holds the "clustered_graph" column.
        callbacks: Verb callbacks forwarded to the unpack and embed helpers.
        cache: Pipeline cache forwarded to the embed helper.
        name_text_embed: Embedding config; its "strategy" entry is used for
            the name embedding.
        description_text_embed: Embedding config; its "strategy" entry is used
            for the name+description embedding.
        skip_name_embedding: When true, no "name_embedding" column is produced.
        skip_description_embedding: When true, no "description_embedding"
            column is produced.
        **_kwargs: Extra arguments, ignored.

    Returns:
        A verb result wrapping the final entities table.
    """
    source = cast(pd.DataFrame, input.get_input())

    entities = unpack_graph_df(source, callbacks, "clustered_graph", "nodes")
    entities.rename(columns={"label": "name"}, inplace=True)

    kept_columns = [
        "id",
        "name",
        "type",
        "description",
        "human_readable_id",
        "graph_embedding",
        "source_id",
    ]
    entities = cast(pd.DataFrame, entities[kept_columns])

    # create_base_entity_graph clusters at multiple levels, so the same entity
    # shows up in several graphs; keep a single row per entity id
    entities.drop_duplicates(subset="id", inplace=True)

    # rows without a name are discarded
    has_name = entities["name"].notna()
    named = cast(pd.DataFrame, entities[has_name].reset_index(drop=True))

    # "source_id" is comma-separated; expand it into "text_unit_ids" and drop it
    result = text_split_df(
        named, column="source_id", separator=",", to="text_unit_ids"
    )
    result.drop(columns=["source_id"], inplace=True)

    if not skip_name_embedding:
        result = await text_embed_df(
            result,
            callbacks,
            cache,
            column="name",
            strategy=name_text_embed["strategy"],
            to="name_embedding",
            embedding_name="entity_name",
        )

    if skip_description_embedding:
        return create_verb_result(cast(Table, result))

    # the description embedding is computed over "name:description", which
    # lives in a temporary column that is removed again right after embedding
    result["name_description"] = result["name"] + ":" + result["description"]
    result = await text_embed_df(
        result,
        callbacks,
        cache,
        column="name_description",
        strategy=description_text_embed["strategy"],
        to="description_embedding",
        embedding_name="entity_name_description",
    )
    result.drop(columns=["name_description"], inplace=True)

    # rows lacking a description embedding are only dropped when the strategy
    # configures no vector store
    # NOTE(review): presumably embeddings go to the store rather than the
    # column in that case - confirm against text_embed_df
    vector_store = description_text_embed.get("strategy", {}).get(
        "vector_store", None
    )
    if vector_store is None:
        has_embedding = result["description_embedding"].notna()
        result = result[has_embedding].reset_index(drop=True)

    return create_verb_result(cast(Table, result))
2 changes: 1 addition & 1 deletion tests/fixtures/min-csv/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
"description",
"graph_embedding"
],
"subworkflows": 11,
"subworkflows": 1,
"max_runtime": 300
},
"create_final_relationships": {
Expand Down
2 changes: 1 addition & 1 deletion tests/fixtures/text/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
"description",
"graph_embedding"
],
"subworkflows": 11,
"subworkflows": 1,
"max_runtime": 300
},
"create_final_relationships": {
Expand Down
2 changes: 1 addition & 1 deletion tests/verbs/test_create_final_documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,4 @@ async def test_create_final_documents_with_embeddings():
assert "raw_content_embedding" in actual.columns
assert len(actual.columns) == len(expected.columns) + 1
# the mock impl returns an array of 3 floats for each embedding
assert len(actual["raw_content_embedding"][0]) == 3
assert len(actual["raw_content_embedding"][:1][0]) == 3
Loading

0 comments on commit ce71bcf

Please sign in to comment.