Remove aggregate_df from final communities and final text units (#1179)
* Remove aggregate_df from final communities and final text units

* Semver

* Ruff and format

* Format

* Format

* Fix tests, ruff and checks

* Remove some leftover prints

* Removed _final_join method
AlonsoGuevara authored Sep 23, 2024
1 parent fbc483e commit be7d3eb
Showing 3 changed files with 63 additions and 148 deletions.
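The substance of this commit is swapping datashaper's declarative aggregate_df verb for plain pandas groupby with named aggregation. A minimal sketch of the equivalence, with column names mirroring the combined_clusters frame in the diff below but entirely made-up values:

import pandas as pd

# Toy stand-in for combined_clusters; the values are hypothetical.
df = pd.DataFrame({
    "cluster": [1, 1, 2],
    "level_x": [0, 0, 0],
    "id_y": ["e1", "e1", "e2"],          # edge ids from the join
    "source_id_x": ["t1", "t2", "t3"],   # text unit ids
})

# aggregate_df expressed this declaratively, e.g.
#   {"column": "id_y", "to": "relationship_ids", "operation": "array_agg_distinct"}
# In pandas, the "unique" aggregation collects the distinct values per group,
# and named aggregation assigns the output column, so the verb is unnecessary.
out = (
    df.groupby(["cluster", "level_x"], sort=False)
    .agg(
        relationship_ids=("id_y", "unique"),
        text_unit_ids=("source_id_x", "unique"),
    )
    .reset_index()
)
print(out)  # one row per (cluster, level_x) with arrays of distinct ids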
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20240920221112632172.json
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Remove aggregate_df from final communities and final text units"
+}
63 changes: 22 additions & 41 deletions graphrag/index/workflows/v1/subflows/create_final_communities.py
@@ -15,7 +15,6 @@
 from datashaper.table_store.types import VerbResult, create_verb_result

 from graphrag.index.verbs.graph.unpack import unpack_graph_df
-from graphrag.index.verbs.overrides.aggregate import aggregate_df


 @verb(name="create_final_communities", treats_input_tables_as_immutable=True)
@@ -30,54 +29,35 @@ def create_final_communities(
     graph_nodes = unpack_graph_df(table, callbacks, "clustered_graph", "nodes")
     graph_edges = unpack_graph_df(table, callbacks, "clustered_graph", "edges")

+    # Merge graph_nodes with graph_edges for both source and target matches
     source_clusters = graph_nodes.merge(
-        graph_edges,
-        left_on="label",
-        right_on="source",
-        how="inner",
+        graph_edges, left_on="label", right_on="source", how="inner"
     )

     target_clusters = graph_nodes.merge(
-        graph_edges,
-        left_on="label",
-        right_on="target",
-        how="inner",
+        graph_edges, left_on="label", right_on="target", how="inner"
     )

-    concatenated_clusters = pd.concat(
-        [source_clusters, target_clusters], ignore_index=True
-    )
+    # Concatenate the source and target clusters
+    clusters = pd.concat([source_clusters, target_clusters], ignore_index=True)

-    # level_x is the left side of the join
-    # level_y is the right side of the join
-    # we only want to keep the clusters that are the same on both sides
-    combined_clusters = concatenated_clusters[
-        concatenated_clusters["level_x"] == concatenated_clusters["level_y"]
+    # Keep only rows where level_x == level_y
+    combined_clusters = clusters[
+        clusters["level_x"] == clusters["level_y"]
     ].reset_index(drop=True)

-    cluster_relationships = aggregate_df(
-        cast(Table, combined_clusters),
-        aggregations=[
-            {
-                "column": "id_y",  # this is the id of the edge from the join steps above
-                "to": "relationship_ids",
-                "operation": "array_agg_distinct",
-            },
-            {
-                "column": "source_id_x",
-                "to": "text_unit_ids",
-                "operation": "array_agg_distinct",
-            },
-        ],
-        groupby=[
-            "cluster",
-            "level_x",  # level_x is the left side of the join
-        ],
+    cluster_relationships = (
+        combined_clusters.groupby(["cluster", "level_x"], sort=False)
+        .agg(
+            relationship_ids=("id_y", "unique"), text_unit_ids=("source_id_x", "unique")
+        )
+        .reset_index()
     )

-    all_clusters = aggregate_df(
-        graph_nodes,
-        aggregations=[{"column": "cluster", "to": "id", "operation": "any"}],
-        groupby=["cluster", "level"],
+    all_clusters = (
+        graph_nodes.groupby(["cluster", "level"], sort=False)
+        .agg(id=("cluster", "first"))
+        .reset_index()
     )

     joined = all_clusters.merge(
Expand All @@ -94,14 +74,15 @@ def create_final_communities(
return create_verb_result(
cast(
Table,
filtered[
filtered.loc[
:,
[
"id",
"title",
"level",
"relationship_ids",
"text_unit_ids",
]
],
],
)
)
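A smaller change in that last hunk moves the final projection from filtered[[...]] to filtered.loc[:, [...]]. Both select the same columns; .loc just names the row and column axes explicitly. A quick sanity check on a hypothetical frame (the "extra" column and all values are made up):

import pandas as pd

filtered = pd.DataFrame({
    "id": ["c1"], "title": ["Community 1"], "level": [0],
    "relationship_ids": [["r1"]], "text_unit_ids": [["t1"]],
    "extra": ["dropped"],  # hypothetical column the projection removes
})

cols = ["id", "title", "level", "relationship_ids", "text_unit_ids"]

# Equivalent projections: bracket indexing with a list vs explicit .loc.
assert filtered[cols].equals(filtered.loc[:, cols])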
144 changes: 37 additions & 107 deletions graphrag/index/workflows/v1/subflows/create_final_text_units_pre_embedding.py
@@ -5,12 +5,11 @@

 from typing import cast

+import pandas as pd
 from datashaper.engine.verbs.verb_input import VerbInput
 from datashaper.engine.verbs.verbs_mapping import verb
 from datashaper.table_store.types import Table, VerbResult, create_verb_result

-from graphrag.index.verbs.overrides.aggregate import aggregate_df
-

 @verb(
     name="create_final_text_units_pre_embedding", treats_input_tables_as_immutable=True
@@ -21,15 +20,15 @@ def create_final_text_units_pre_embedding(
     **_kwargs: dict,
 ) -> VerbResult:
     """All the steps to transform before we embed the text units."""
-    table = input.get_input()
+    table = cast(pd.DataFrame, input.get_input())
     others = input.get_others()

-    selected = cast(Table, table[["id", "chunk", "document_ids", "n_tokens"]]).rename(
+    selected = table.loc[:, ["id", "chunk", "document_ids", "n_tokens"]].rename(
         columns={"chunk": "text"}
     )

-    final_entities = others[0]
-    final_relationships = others[1]
+    final_entities = cast(pd.DataFrame, others[0])
+    final_relationships = cast(pd.DataFrame, others[1])
     entity_join = _entities(final_entities)
     relationship_join = _relationships(final_relationships)
@@ -38,116 +37,47 @@
     final_joined = relationship_joined

     if covariates_enabled:
-        final_covariates = others[2]
+        final_covariates = cast(pd.DataFrame, others[2])
         covariate_join = _covariates(final_covariates)
         final_joined = _join(relationship_joined, covariate_join)

-    aggregated = _final_aggregation(final_joined, covariates_enabled)
+    aggregated = final_joined.groupby("id", sort=False).agg("first").reset_index()

-    return create_verb_result(aggregated)
-
-
-def _final_aggregation(table, covariates_enabled):
-    aggregations = [
-        {
-            "column": "text",
-            "operation": "any",
-            "to": "text",
-        },
-        {
-            "column": "n_tokens",
-            "operation": "any",
-            "to": "n_tokens",
-        },
-        {
-            "column": "document_ids",
-            "operation": "any",
-            "to": "document_ids",
-        },
-        {
-            "column": "entity_ids",
-            "operation": "any",
-            "to": "entity_ids",
-        },
-        {
-            "column": "relationship_ids",
-            "operation": "any",
-            "to": "relationship_ids",
-        },
-    ]
-    if covariates_enabled:
-        aggregations.append({
-            "column": "covariate_ids",
-            "operation": "any",
-            "to": "covariate_ids",
-        })
-    return aggregate_df(
-        table,
-        aggregations,
-        ["id"],
-    )
+    return create_verb_result(cast(Table, aggregated))


-def _entities(table):
-    selected = cast(Table, table[["id", "text_unit_ids"]])
-    unrolled = selected.explode("text_unit_ids").reset_index(drop=True)
-    return aggregate_df(
-        unrolled,
-        [
-            {
-                "column": "id",
-                "operation": "array_agg_distinct",
-                "to": "entity_ids",
-            },
-            {
-                "column": "text_unit_ids",
-                "operation": "any",
-                "to": "id",
-            },
-        ],
-        ["text_unit_ids"],
+def _entities(df: pd.DataFrame) -> pd.DataFrame:
+    selected = df.loc[:, ["id", "text_unit_ids"]]
+    unrolled = selected.explode(["text_unit_ids"]).reset_index(drop=True)
+
+    return (
+        unrolled.groupby("text_unit_ids", sort=False)
+        .agg(entity_ids=("id", "unique"))
+        .reset_index()
+        .rename(columns={"text_unit_ids": "id"})
     )


-def _relationships(table):
-    selected = cast(Table, table[["id", "text_unit_ids"]])
-    unrolled = selected.explode("text_unit_ids").reset_index(drop=True)
-    aggregated = aggregate_df(
-        unrolled,
-        [
-            {
-                "column": "id",
-                "operation": "array_agg_distinct",
-                "to": "relationship_ids",
-            },
-            {
-                "column": "text_unit_ids",
-                "operation": "any",
-                "to": "id",
-            },
-        ],
-        ["text_unit_ids"],
-    )
-    return aggregated[["id", "relationship_ids"]]
+def _relationships(df: pd.DataFrame) -> pd.DataFrame:
+    selected = df.loc[:, ["id", "text_unit_ids"]]
+    unrolled = selected.explode(["text_unit_ids"]).reset_index(drop=True)
+
+    return (
+        unrolled.groupby("text_unit_ids", sort=False)
+        .agg(relationship_ids=("id", "unique"))
+        .reset_index()
+        .rename(columns={"text_unit_ids": "id"})
+    )


-def _covariates(table):
-    selected = cast(Table, table[["id", "text_unit_id"]])
-    return aggregate_df(
-        selected,
-        [
-            {
-                "column": "id",
-                "operation": "array_agg_distinct",
-                "to": "covariate_ids",
-            },
-            {
-                "column": "text_unit_id",
-                "operation": "any",
-                "to": "id",
-            },
-        ],
-        ["text_unit_id"],
+def _covariates(df: pd.DataFrame) -> pd.DataFrame:
+    selected = df.loc[:, ["id", "text_unit_id"]]
+
+    return (
+        selected.groupby("text_unit_id", sort=False)
+        .agg(covariate_ids=("id", "unique"))
+        .reset_index()
+        .rename(columns={"text_unit_id": "id"})
     )
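All three helpers now share the same explode-then-groupby shape: unroll the list column, group by text unit, and collect distinct ids back into arrays, inverting the mapping. A usage sketch of the _entities variant on made-up data:

import pandas as pd

# Hypothetical entities frame: each entity lists the text units it appears in.
entities = pd.DataFrame({
    "id": ["e1", "e2"],
    "text_unit_ids": [["t1", "t2"], ["t1"]],
})

# Invert the mapping: one row per text unit, with its distinct entity ids.
by_text_unit = (
    entities.loc[:, ["id", "text_unit_ids"]]
    .explode(["text_unit_ids"])
    .reset_index(drop=True)
    .groupby("text_unit_ids", sort=False)
    .agg(entity_ids=("id", "unique"))
    .reset_index()
    .rename(columns={"text_unit_ids": "id"})
)
print(by_text_unit)  # id=t1 -> [e1, e2]; id=t2 -> [e1]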

