Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New workflow to generate embeddings in a single workflow #1296

Merged
merged 56 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
24c830b
New workflow to generate embeddings in a single workflow
gaudyb Oct 18, 2024
0d54218
New workflow to generate embeddings in a single workflow
gaudyb Oct 18, 2024
411b679
version change
gaudyb Oct 18, 2024
00690cd
clean tests without any embeddings references
gaudyb Oct 18, 2024
922a7cf
clean tests without any embeddings references
gaudyb Oct 18, 2024
0a13615
remove code
gaudyb Oct 18, 2024
cbe94c1
Merge branch 'main' into new_workflow
AlonsoGuevara Oct 21, 2024
977d025
Merge branch 'main' into new_workflow
AlonsoGuevara Oct 21, 2024
9dd9fc8
Merge remote-tracking branch 'origin/main' into new_workflow
gaudyb Oct 24, 2024
efd926b
feedback implemented
gaudyb Oct 24, 2024
a889c3e
merge conflict fixed
gaudyb Oct 24, 2024
8c07704
changes in logic
gaudyb Oct 25, 2024
a646e2c
feedback implemented
gaudyb Oct 25, 2024
6cc1481
store in table bug fixed
gaudyb Oct 25, 2024
218787e
smoke test for generate_text_embeddings workflow
gaudyb Oct 26, 2024
fda481e
Merge remote-tracking branch 'origin/main' into new_workflow
gaudyb Oct 26, 2024
a958332
smoke test fix
gaudyb Oct 26, 2024
106dcd9
add generate_text_embeddings to the list of transient workflows
gaudyb Oct 28, 2024
70c693b
smoke tests
gaudyb Oct 29, 2024
d422b83
fix
gaudyb Oct 29, 2024
4f3925c
Merge remote-tracking branch 'origin/main' into new_workflow
gaudyb Oct 29, 2024
3f3890d
Merge branch 'main' into new_workflow
jgbradley1 Oct 29, 2024
4a18ce7
ruff formatting updates
jgbradley1 Oct 29, 2024
05d205a
fix
gaudyb Oct 29, 2024
7a3925c
smoke test fixed
gaudyb Oct 29, 2024
7ba505d
smoke test fixed
gaudyb Oct 29, 2024
e789930
fix lancedb import
gaudyb Oct 29, 2024
e57faa6
smoke test fix
gaudyb Oct 29, 2024
f61ebe4
ignore sorting
gaudyb Oct 29, 2024
a9c0a73
smoke test fixed
gaudyb Oct 29, 2024
5c0359b
smoke test fixed
gaudyb Oct 29, 2024
054b6a2
check smoke test
gaudyb Oct 29, 2024
38db420
smoke test fixed
gaudyb Oct 29, 2024
3a4a04a
change config for vector store
gaudyb Oct 30, 2024
3829cf0
format fix
gaudyb Oct 30, 2024
3a2e7da
vector store changes
gaudyb Oct 30, 2024
76925e2
revert debug profile back to empty filepath
jgbradley1 Oct 30, 2024
7d600ed
merge conflict solved
gaudyb Oct 30, 2024
399393d
merge conflict solved
gaudyb Oct 30, 2024
b2343f7
format fixed
gaudyb Oct 30, 2024
b65cfd7
merge conflict solved
gaudyb Oct 30, 2024
7db6667
format fixed
gaudyb Oct 30, 2024
7c86c6e
fix return dataframe
gaudyb Oct 30, 2024
2c10a7f
Merge remote-tracking branch 'origin/main' into new_workflow
gaudyb Oct 31, 2024
3eca0b6
snapshot fix
gaudyb Oct 31, 2024
46a27c2
format fix
gaudyb Oct 31, 2024
0931df8
embeddings param implemented
gaudyb Oct 31, 2024
d4e6d1e
validation fixes
gaudyb Oct 31, 2024
dd11006
fix map
gaudyb Oct 31, 2024
dff8839
fix map
gaudyb Oct 31, 2024
4a2211a
fix properties
gaudyb Oct 31, 2024
9fed605
config updates
gaudyb Oct 31, 2024
cab502b
smoke test fixed
gaudyb Oct 31, 2024
8a26964
settings change
gaudyb Oct 31, 2024
05ab9cb
Update collection config and rework back-compat
natoverse Nov 1, 2024
bd1c99a
Replace . with - for embedding store
natoverse Nov 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .semversioner/next-release/minor-20241018204541069382.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "minor",
"description": "embeddings moved to a different workflow"
}
1 change: 1 addition & 0 deletions graphrag/config/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ class TextEmbeddingTarget(str, Enum):

all = "all"
required = "required"
none = "none"
AlonsoGuevara marked this conversation as resolved.
Show resolved Hide resolved

def __repr__(self):
"""Get a string representation."""
Expand Down
22 changes: 22 additions & 0 deletions graphrag/index/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,18 @@
PipelineMemoryCacheConfig,
PipelineNoneCacheConfig,
)
from .embeddings import (
all_embeddings,
community_full_content_embedding,
community_summary_embedding,
community_title_embedding,
document_raw_content_embedding,
entity_description_embedding,
entity_name_embedding,
relationship_description_embedding,
required_embeddings,
text_unit_text_embedding,
)
from .input import (
PipelineCSVInputConfig,
PipelineInputConfig,
Expand Down Expand Up @@ -66,4 +78,14 @@
"PipelineWorkflowConfig",
"PipelineWorkflowReference",
"PipelineWorkflowStep",
"all_embeddings",
"community_full_content_embedding",
"community_summary_embedding",
"community_title_embedding",
"document_raw_content_embedding",
"entity_description_embedding",
"entity_name_embedding",
"relationship_description_embedding",
"required_embeddings",
"text_unit_text_embedding",
]
25 changes: 25 additions & 0 deletions graphrag/index/config/embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing embeddings values."""

# Names of the individual embeddings, grouped by the prefix before the dot.
entity_name_embedding = "entity.name"
entity_description_embedding = "entity.description"

relationship_description_embedding = "relationship.description"

document_raw_content_embedding = "document.raw_content"

text_unit_text_embedding = "text_unit.text"

community_title_embedding = "community.title"
community_summary_embedding = "community.summary"
community_full_content_embedding = "community.full_content"

# The subset of embedding names exposed as "required".
required_embeddings: set[str] = {entity_description_embedding}

# Every embedding name defined in this module (a separate set object from
# ``required_embeddings``, which it contains).
all_embeddings: set[str] = required_embeddings | {
    entity_name_embedding,
    relationship_description_embedding,
    document_raw_content_embedding,
    text_unit_text_embedding,
    community_title_embedding,
    community_summary_embedding,
    community_full_content_embedding,
}
161 changes: 39 additions & 122 deletions graphrag/index/create_pipeline_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
PipelineMemoryCacheConfig,
PipelineNoneCacheConfig,
)
from graphrag.index.config.embeddings import (
all_embeddings,
required_embeddings,
)
from graphrag.index.config.input import (
PipelineCSVInputConfig,
PipelineInputConfigTypes,
Expand Down Expand Up @@ -55,6 +59,7 @@
create_final_community_reports,
create_final_covariates,
create_final_documents,
create_final_embeddings,
create_final_entities,
create_final_nodes,
create_final_relationships,
Expand All @@ -63,29 +68,6 @@

log = logging.getLogger(__name__)


entity_name_embedding = "entity.name"
entity_description_embedding = "entity.description"
relationship_description_embedding = "relationship.description"
document_raw_content_embedding = "document.raw_content"
community_title_embedding = "community.title"
community_summary_embedding = "community.summary"
community_full_content_embedding = "community.full_content"
text_unit_text_embedding = "text_unit.text"

all_embeddings: set[str] = {
entity_name_embedding,
entity_description_embedding,
relationship_description_embedding,
document_raw_content_embedding,
community_title_embedding,
community_summary_embedding,
community_full_content_embedding,
text_unit_text_embedding,
}
required_embeddings: set[str] = {entity_description_embedding}


builtin_document_attributes: set[str] = {
"id",
"source",
Expand Down Expand Up @@ -121,11 +103,12 @@ def create_pipeline_config(settings: GraphRagConfig, verbose=False) -> PipelineC
storage=_get_storage_config(settings),
cache=_get_cache_config(settings),
workflows=[
*_document_workflows(settings, embedded_fields),
*_text_unit_workflows(settings, covariates_enabled, embedded_fields),
*_graph_workflows(settings, embedded_fields),
*_community_workflows(settings, covariates_enabled, embedded_fields),
*_document_workflows(settings),
*_text_unit_workflows(settings, covariates_enabled),
*_graph_workflows(settings),
*_community_workflows(settings, covariates_enabled),
*(_covariate_workflows(settings) if covariates_enabled else []),
*(_embeddings_workflows(settings, embedded_fields)),
],
)

Expand All @@ -138,9 +121,11 @@ def create_pipeline_config(settings: GraphRagConfig, verbose=False) -> PipelineC
def _get_embedded_fields(settings: GraphRagConfig) -> set[str]:
match settings.embeddings.target:
case TextEmbeddingTarget.all:
return all_embeddings - {*settings.embeddings.skip}
return all_embeddings.difference(settings.embeddings.skip)
case TextEmbeddingTarget.required:
return required_embeddings
case TextEmbeddingTarget.none:
return set()
case _:
msg = f"Unknown embeddings target: {settings.embeddings.target}"
raise ValueError(msg)
Expand All @@ -163,11 +148,8 @@ def _log_llm_settings(settings: GraphRagConfig) -> None:


def _document_workflows(
settings: GraphRagConfig, embedded_fields: set[str]
settings: GraphRagConfig,
) -> list[PipelineWorkflowReference]:
skip_document_raw_content_embedding = (
document_raw_content_embedding not in embedded_fields
)
return [
PipelineWorkflowReference(
name=create_final_documents,
Expand All @@ -176,15 +158,6 @@ def _document_workflows(
{*(settings.input.document_attribute_columns)}
- builtin_document_attributes
),
"document_raw_content_embed": _get_embedding_settings(
settings.embeddings,
"document_raw_content",
{
"title_column": "raw_content",
"collection_name": "final_documents_raw_content_embedding",
},
),
"skip_raw_content_embedding": skip_document_raw_content_embedding,
},
),
]
Expand All @@ -193,9 +166,7 @@ def _document_workflows(
def _text_unit_workflows(
settings: GraphRagConfig,
covariates_enabled: bool,
embedded_fields: set[str],
) -> list[PipelineWorkflowReference]:
skip_text_unit_embedding = text_unit_text_embedding not in embedded_fields
return [
PipelineWorkflowReference(
name=create_base_text_units,
Expand All @@ -211,13 +182,7 @@ def _text_unit_workflows(
PipelineWorkflowReference(
name=create_final_text_units,
config={
"text_unit_text_embed": _get_embedding_settings(
settings.embeddings,
"text_unit_text",
{"title_column": "text", "collection_name": "text_units_embedding"},
),
"covariates_enabled": covariates_enabled,
"skip_text_unit_embedding": skip_text_unit_embedding,
},
),
]
Expand Down Expand Up @@ -247,16 +212,7 @@ def _get_embedding_settings(
}


def _graph_workflows(
settings: GraphRagConfig, embedded_fields: set[str]
) -> list[PipelineWorkflowReference]:
skip_entity_name_embedding = entity_name_embedding not in embedded_fields
skip_entity_description_embedding = (
entity_description_embedding not in embedded_fields
)
skip_relationship_description_embedding = (
relationship_description_embedding not in embedded_fields
)
def _graph_workflows(settings: GraphRagConfig) -> list[PipelineWorkflowReference]:
return [
PipelineWorkflowReference(
name=create_base_entity_graph,
Expand Down Expand Up @@ -286,40 +242,11 @@ def _graph_workflows(
),
PipelineWorkflowReference(
name=create_final_entities,
config={
"entity_name_embed": _get_embedding_settings(
settings.embeddings,
"entity_name",
{
"title_column": "name",
"collection_name": "entity_name_embeddings",
},
),
"entity_name_description_embed": _get_embedding_settings(
settings.embeddings,
"entity_name_description",
{
"title_column": "description",
"collection_name": "entity_description_embeddings",
},
),
"skip_name_embedding": skip_entity_name_embedding,
"skip_description_embedding": skip_entity_description_embedding,
},
config={},
),
PipelineWorkflowReference(
name=create_final_relationships,
config={
"relationship_description_embed": _get_embedding_settings(
settings.embeddings,
"relationship_description",
{
"title_column": "description",
"collection_name": "relationships_description_embeddings",
},
),
"skip_description_embedding": skip_relationship_description_embedding,
},
config={},
),
PipelineWorkflowReference(
name=create_final_nodes,
Expand All @@ -332,52 +259,21 @@ def _graph_workflows(


def _community_workflows(
settings: GraphRagConfig, covariates_enabled: bool, embedded_fields: set[str]
settings: GraphRagConfig, covariates_enabled: bool
) -> list[PipelineWorkflowReference]:
skip_community_title_embedding = community_title_embedding not in embedded_fields
skip_community_summary_embedding = (
community_summary_embedding not in embedded_fields
)
skip_community_full_content_embedding = (
community_full_content_embedding not in embedded_fields
)
return [
PipelineWorkflowReference(name=create_final_communities),
PipelineWorkflowReference(
name=create_final_community_reports,
config={
"covariates_enabled": covariates_enabled,
"skip_title_embedding": skip_community_title_embedding,
"skip_summary_embedding": skip_community_summary_embedding,
"skip_full_content_embedding": skip_community_full_content_embedding,
"create_community_reports": {
**settings.community_reports.parallelization.model_dump(),
"async_mode": settings.community_reports.async_mode,
"strategy": settings.community_reports.resolved_strategy(
settings.root_dir
),
},
"community_report_full_content_embed": _get_embedding_settings(
settings.embeddings,
"community_report_full_content",
{
"title_column": "full_content",
"collection_name": "final_community_reports_full_content_embedding",
},
),
"community_report_summary_embed": _get_embedding_settings(
settings.embeddings,
"community_report_summary",
{
"title_column": "summary",
"collection_name": "final_community_reports_summary_embedding",
},
),
"community_report_title_embed": _get_embedding_settings(
settings.embeddings,
"community_report_title",
{"title_column": "title"},
),
},
),
]
Expand All @@ -401,6 +297,27 @@ def _covariate_workflows(
]


def _embeddings_workflows(
    settings: GraphRagConfig, embedded_fields: set[str]
) -> list[PipelineWorkflowReference]:
    """Build the workflow reference list for the embeddings step.

    Args:
        settings: configuration whose ``embeddings`` section is handed to
            ``_get_embedding_settings``.
        embedded_fields: set of embedding names, passed through unchanged
            under the ``"embedded_fields"`` config key.

    Returns:
        A one-element list holding the ``create_final_embeddings`` reference.
    """
    # A single shared embedding configuration is produced for the workflow.
    text_embed_settings = _get_embedding_settings(
        settings.embeddings,
        "column_to_embed",
        {
            "title_column": "column_to_embed",
            "collection_name": "embedding",
        },
    )
    workflow_config = {
        "text_embed": text_embed_settings,
        "embedded_fields": embedded_fields,
    }
    return [
        PipelineWorkflowReference(
            name=create_final_embeddings,
            config=workflow_config,
        )
    ]


def _get_pipeline_input_config(
settings: GraphRagConfig,
) -> PipelineInputConfigTypes:
Expand Down
Loading
Loading