Skip to content

Commit

Permalink
smoke test for generate_text_embeddings workflow
Browse files Browse the repository at this point in the history
  • Loading branch information
gaudyb committed Oct 26, 2024
1 parent 6cc1481 commit 218787e
Show file tree
Hide file tree
Showing 4 changed files with 185 additions and 21 deletions.
38 changes: 19 additions & 19 deletions graphrag/index/flows/generate_text_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

log = logging.getLogger(__name__)


async def generate_text_embeddings(
final_documents: pd.DataFrame,
final_relationships: pd.DataFrame,
Expand Down Expand Up @@ -56,45 +57,44 @@ async def generate_text_embeddings(
embedding_param_map = {
document_raw_content_embedding: {
"data": documents_embeddings,
"column_to_embed":"raw_content",
"column_to_embed": "raw_content",
"filename": "create_final_documents_raw_content_embeddings",
},
relationship_description_embedding:{
relationship_description_embedding: {
"data": relationships_embeddings,
"column_to_embed":"description",
"column_to_embed": "description",
"filename": "create_final_relationships_description_embeddings",
},
text_unit_text_embedding:{
text_unit_text_embedding: {
"data": text_units_embeddings,
"column_to_embed":"text",
"column_to_embed": "text",
"filename": "create_final_text_units_text_embeddings",
},
entity_name_embedding:{
entity_name_embedding: {
"data": entities_embeddings,
"column_to_embed":"name",
"column_to_embed": "name",
"filename": "create_final_entities_name_embeddings",
},
entity_description_embedding:{
entity_description_embedding: {
"data": entities_embeddings,
"column_to_embed":"name_description",
"column_to_embed": "name_description",
"filename": "create_final_entities_description_embeddings",
},
community_title_embedding:{
community_title_embedding: {
"data": community_reports_embeddings,
"column_to_embed":"title",
"column_to_embed": "title",
"filename": "create_final_community_reports_title_embeddings",
},
community_summary_embedding:{
community_summary_embedding: {
"data": community_reports_embeddings,
"column_to_embed":"summary",
"column_to_embed": "summary",
"filename": "create_final_community_reports_summary_embeddings",
},
community_full_content_embedding:{
community_full_content_embedding: {
"data": community_reports_embeddings,
"column_to_embed":"full_content",
"column_to_embed": "full_content",
"filename": "create_final_community_reports_full_content_embeddings",
},

}

if base_text_embed:
Expand All @@ -105,7 +105,7 @@ async def generate_text_embeddings(
cache=cache,
storage=storage,
base_text_embed=base_text_embed,
**embedding_param_map[field]
**embedding_param_map[field],
)

return pd.DataFrame()
Expand All @@ -127,9 +127,9 @@ async def _run_and_snapshot_embeddings(
"collection_name": filename,
"store_in_table": True,
}

base_text_embed["strategy"]["vector_store"].update(new_vector_store)

data["embedding"] = await embed_text(
data,
callbacks,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ async def generate_text_embeddings(
cache: PipelineCache,
storage: PipelineStorage,
embedded_fields: set[str],
base_text_embed: dict | None = None,
base_text_embed: dict,
**_kwargs: dict,
) -> VerbResult:
"""All the steps to generate embeddings."""
Expand Down
2 changes: 1 addition & 1 deletion graphrag/vector_stores/lancedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
import json
from typing import Any

import lancedb as lancedb
import pyarrow as pa

import lancedb as lancedb
from graphrag.model.types import TextEmbedder

from .base import (
Expand Down
164 changes: 164 additions & 0 deletions tests/verbs/test_generate_text_embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

from io import BytesIO

import pandas as pd

from graphrag.index.config.embeddings import (
all_embeddings,
)
from graphrag.index.run.utils import create_run_context
from graphrag.index.workflows.v1.generate_text_embeddings import (
build_steps,
workflow_name,
)

from .util import (
get_config_for_workflow,
get_workflow_output,
load_input_tables,
)


async def test_generate_text_embeddings():
input_tables = load_input_tables(
inputs=[
"workflow:create_final_documents",
"workflow:create_final_relationships",
"workflow:create_final_text_units",
"workflow:create_final_entities",
"workflow:create_final_community_reports",
]
)
context = create_run_context(None, None, None)

config = get_config_for_workflow(workflow_name)

config["text_embed"]["strategy"]["type"] = "mock"
config["embedded_fields"] = all_embeddings
config["text_embed"]["strategy"]["vector_store"] = {
"type": "lancedb",
"db_uri": "./lancedb",
"store_in_table": True,
}

steps = build_steps(config)

await get_workflow_output(
input_tables,
{
"steps": steps,
},
context,
)

parquet_files = context.storage.keys()
assert "create_final_documents_raw_content_embeddings.parquet" in parquet_files
assert "create_final_relationships_description_embeddings.parquet" in parquet_files
assert "create_final_text_units_text_embeddings.parquet" in parquet_files
assert "create_final_entities_name_embeddings.parquet" in parquet_files
assert "create_final_entities_description_embeddings.parquet" in parquet_files
assert "create_final_community_reports_title_embeddings.parquet" in parquet_files
assert "create_final_community_reports_summary_embeddings.parquet" in parquet_files
assert (
"create_final_community_reports_full_content_embeddings.parquet"
in parquet_files
)

create_final_documents_raw_content_embeddings_buffer = BytesIO(
await context.storage.get(
"create_final_documents_raw_content_embeddings.parquet", as_bytes=True
)
)
create_final_documents_raw_content_embeddings = pd.read_parquet(
create_final_documents_raw_content_embeddings_buffer
)
assert len(create_final_documents_raw_content_embeddings.columns) == 2
assert "id" in create_final_documents_raw_content_embeddings.columns
assert "embedding" in create_final_documents_raw_content_embeddings.columns

create_final_relationships_description_embeddings_buffer = BytesIO(
await context.storage.get(
"create_final_relationships_description_embeddings.parquet", as_bytes=True
)
)
create_final_relationships_description_embeddings = pd.read_parquet(
create_final_relationships_description_embeddings_buffer
)
assert len(create_final_relationships_description_embeddings.columns) == 2
assert "id" in create_final_relationships_description_embeddings.columns
assert "embedding" in create_final_relationships_description_embeddings.columns

create_final_text_units_text_embeddings_buffer = BytesIO(
await context.storage.get(
"create_final_text_units_text_embeddings.parquet", as_bytes=True
)
)
create_final_text_units_text_embeddings = pd.read_parquet(
create_final_text_units_text_embeddings_buffer
)
assert len(create_final_text_units_text_embeddings.columns) == 2
assert "id" in create_final_text_units_text_embeddings.columns
assert "embedding" in create_final_text_units_text_embeddings.columns

create_final_entities_name_embeddings_buffer = BytesIO(
await context.storage.get(
"create_final_entities_name_embeddings.parquet", as_bytes=True
)
)
create_final_entities_name_embeddings = pd.read_parquet(
create_final_entities_name_embeddings_buffer
)
assert len(create_final_entities_name_embeddings.columns) == 2
assert "id" in create_final_entities_name_embeddings.columns
assert "embedding" in create_final_entities_name_embeddings.columns

create_final_entities_description_embeddings_buffer = BytesIO(
await context.storage.get(
"create_final_entities_description_embeddings.parquet", as_bytes=True
)
)
create_final_entities_description_embeddings = pd.read_parquet(
create_final_entities_description_embeddings_buffer
)
assert len(create_final_entities_description_embeddings.columns) == 2
assert "id" in create_final_entities_description_embeddings.columns
assert "embedding" in create_final_entities_description_embeddings.columns

create_final_community_reports_title_embeddings_buffer = BytesIO(
await context.storage.get(
"create_final_community_reports_title_embeddings.parquet", as_bytes=True
)
)
create_final_community_reports_title_embeddings = pd.read_parquet(
create_final_community_reports_title_embeddings_buffer
)
assert len(create_final_community_reports_title_embeddings.columns) == 2
assert "id" in create_final_community_reports_title_embeddings.columns
assert "embedding" in create_final_community_reports_title_embeddings.columns

create_final_community_reports_summary_embeddings_buffer = BytesIO(
await context.storage.get(
"create_final_community_reports_summary_embeddings.parquet", as_bytes=True
)
)
create_final_community_reports_summary_embeddings = pd.read_parquet(
create_final_community_reports_summary_embeddings_buffer
)
assert len(create_final_community_reports_summary_embeddings.columns) == 2
assert "id" in create_final_community_reports_summary_embeddings.columns
assert "embedding" in create_final_community_reports_summary_embeddings.columns

create_final_community_reports_full_content_embeddings_buffer = BytesIO(
await context.storage.get(
"create_final_community_reports_full_content_embeddings.parquet",
as_bytes=True,
)
)
create_final_community_reports_full_content_embeddings = pd.read_parquet(
create_final_community_reports_full_content_embeddings_buffer
)
assert len(create_final_community_reports_full_content_embeddings.columns) == 2
assert "id" in create_final_community_reports_full_content_embeddings.columns
assert "embedding" in create_final_community_reports_full_content_embeddings.columns

0 comments on commit 218787e

Please sign in to comment.