New workflow to generate embeddings in a single workflow #1296

Merged: 56 commits, merged on Nov 1, 2024

Changes from all commits

Commits
24c830b
New workflow to generate embeddings in a single workflow
gaudyb Oct 18, 2024
0d54218
New workflow to generate embeddings in a single workflow
gaudyb Oct 18, 2024
411b679
version change
gaudyb Oct 18, 2024
00690cd
clean tests without any embeddings references
gaudyb Oct 18, 2024
922a7cf
clean tests without any embeddings references
gaudyb Oct 18, 2024
0a13615
remove code
gaudyb Oct 18, 2024
cbe94c1
Merge branch 'main' into new_workflow
AlonsoGuevara Oct 21, 2024
977d025
Merge branch 'main' into new_workflow
AlonsoGuevara Oct 21, 2024
9dd9fc8
Merge remote-tracking branch 'origin/main' into new_workflow
gaudyb Oct 24, 2024
efd926b
feedback implemented
gaudyb Oct 24, 2024
a889c3e
merge conflict fixed
gaudyb Oct 24, 2024
8c07704
changes in logic
gaudyb Oct 25, 2024
a646e2c
feedback implemented
gaudyb Oct 25, 2024
6cc1481
store in table bug fixed
gaudyb Oct 25, 2024
218787e
smoke test for generate_text_embeddings workflow
gaudyb Oct 26, 2024
fda481e
Merge remote-tracking branch 'origin/main' into new_workflow
gaudyb Oct 26, 2024
a958332
smoke test fix
gaudyb Oct 26, 2024
106dcd9
add generate_text_embeddings to the list of transient workflows
gaudyb Oct 28, 2024
70c693b
smoke tests
gaudyb Oct 29, 2024
d422b83
fix
gaudyb Oct 29, 2024
4f3925c
Merge remote-tracking branch 'origin/main' into new_workflow
gaudyb Oct 29, 2024
3f3890d
Merge branch 'main' into new_workflow
jgbradley1 Oct 29, 2024
4a18ce7
ruff formatting updates
jgbradley1 Oct 29, 2024
05d205a
fix
gaudyb Oct 29, 2024
7a3925c
smoke test fixed
gaudyb Oct 29, 2024
7ba505d
smoke test fixed
gaudyb Oct 29, 2024
e789930
fix lancedb import
gaudyb Oct 29, 2024
e57faa6
smoke test fix
gaudyb Oct 29, 2024
f61ebe4
ignore sorting
gaudyb Oct 29, 2024
a9c0a73
smoke test fixed
gaudyb Oct 29, 2024
5c0359b
smoke test fixed
gaudyb Oct 29, 2024
054b6a2
check smoke test
gaudyb Oct 29, 2024
38db420
smoke test fixed
gaudyb Oct 29, 2024
3a4a04a
change config for vector store
gaudyb Oct 30, 2024
3829cf0
format fix
gaudyb Oct 30, 2024
3a2e7da
vector store changes
gaudyb Oct 30, 2024
76925e2
revert debug profile back to empty filepath
jgbradley1 Oct 30, 2024
7d600ed
merge conflict solved
gaudyb Oct 30, 2024
399393d
merge conflict solved
gaudyb Oct 30, 2024
b2343f7
format fixed
gaudyb Oct 30, 2024
b65cfd7
merge conflict solved
gaudyb Oct 30, 2024
7db6667
format fixed
gaudyb Oct 30, 2024
7c86c6e
fix return dataframe
gaudyb Oct 30, 2024
2c10a7f
Merge remote-tracking branch 'origin/main' into new_workflow
gaudyb Oct 31, 2024
3eca0b6
snapshot fix
gaudyb Oct 31, 2024
46a27c2
format fix
gaudyb Oct 31, 2024
0931df8
embeddings param implemented
gaudyb Oct 31, 2024
d4e6d1e
validation fixes
gaudyb Oct 31, 2024
dd11006
fix map
gaudyb Oct 31, 2024
dff8839
fix map
gaudyb Oct 31, 2024
4a2211a
fix properties
gaudyb Oct 31, 2024
9fed605
config updates
gaudyb Oct 31, 2024
cab502b
smoke test fixed
gaudyb Oct 31, 2024
8a26964
settings change
gaudyb Oct 31, 2024
05ab9cb
Update collection config and rework back-compat
natoverse Nov 1, 2024
bd1c99a
Replace . with - for embedding store
natoverse Nov 1, 2024
4 changes: 4 additions & 0 deletions .semversioner/next-release/minor-20241018204541069382.json
@@ -0,0 +1,4 @@
{
"type": "minor",
"description": "embeddings moved to a different workflow"
}
6 changes: 3 additions & 3 deletions docs/config/json_yaml.md
Expand Up @@ -85,16 +85,16 @@ This is the base LLM configuration section. Other steps may override this config
- `async_mode` (see Async Mode top-level config)
- `batch_size` **int** - The maximum batch size to use.
- `batch_max_tokens` **int** - The maximum batch # of tokens.
- `target` **required|all** - Determines which set of embeddings to emit.
- `skip` **list[str]** - Which embeddings to skip.
- `target` **required|all|none** - Determines which set of embeddings to emit.
- `skip` **list[str]** - Which embeddings to skip. Only useful if target=all to customize the list.
- `vector_store` **dict** - The vector store to use. Configured for lancedb by default.
- `type` **str** - `lancedb` or `azure_ai_search`. Default=`lancedb`
- `db_uri` **str** (only for lancedb) - The database uri. Default=`storage.base_dir/lancedb`
- `url` **str** (only for AI Search) - AI Search endpoint
- `api_key` **str** (optional - only for AI Search) - The AI Search api key to use.
- `audience` **str** (only for AI Search) - Audience for managed identity token if managed identity authentication is used.
- `overwrite` **bool** (only used at index creation time) - Overwrite collection if it exist. Default=`True`
- `collection_name` **str** - The name of a vector collection. Default=`entity_description_embeddings`
- `container_name` **str** - The name of a vector container. This stores all indexes (tables) for a given dataset ingest. Default=`default`
- `strategy` **dict** - Fully override the text-embedding strategy.

## chunks
Expand Down
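
For orientation, here is a minimal sketch of the embeddings settings described above, written as a plain Python dict whose keys mirror the documented YAML fields. The concrete values (the skip list, the db path) are illustrative assumptions, not part of this diff:

```python
# Illustrative embeddings settings mirroring the documented keys.
# `target` may now be "required", "all", or "none"; `skip` only matters when target="all".
embeddings_settings = {
    "target": "all",
    "skip": ["entity.name", "document.raw_content"],  # assumed example values
    "vector_store": {
        "type": "lancedb",            # or "azure_ai_search"
        "db_uri": "output/lancedb",   # lancedb only
        "container_name": "default",  # holds all index tables for one dataset ingest
        "overwrite": True,            # only honored at index creation time
    },
}

print(embeddings_settings["vector_store"]["container_name"])  # default
```
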
2 changes: 1 addition & 1 deletion docs/examples_notebooks/local_search.ipynb
Expand Up @@ -108,7 +108,7 @@
"# load description embeddings to an in-memory lancedb vectorstore\n",
"# to connect to a remote db, specify url and port values.\n",
"description_embedding_store = LanceDBVectorStore(\n",
" collection_name=\"entity_description_embeddings\",\n",
" collection_name=\"entity.description\",\n",
")\n",
"description_embedding_store.connect(db_uri=LANCEDB_URI)\n",
"entity_description_embeddings = store_entity_semantic_embeddings(\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@
"entities = read_indexer_entities(entity_df, entity_embedding_df, COMMUNITY_LEVEL)\n",
"\n",
"description_embedding_store = LanceDBVectorStore(\n",
" collection_name=\"entity_description_embeddings\",\n",
" collection_name=\"entity.description\",\n",
")\n",
"description_embedding_store.connect(db_uri=LANCEDB_URI)\n",
"entity_description_embeddings = store_entity_semantic_embeddings(\n",
Expand Down
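
For context, a sketch of the updated notebook setup shown in these hunks: the store is now opened against the embedding-named collection instead of the old entity_description_embeddings table. The local LANCEDB_URI value is an assumption, not part of the diff:

```python
# Sketch of the notebook change above: connect to the LanceDB store using the
# new embedding-named collection rather than "entity_description_embeddings".
from graphrag.vector_stores.lancedb import LanceDBVectorStore

LANCEDB_URI = "./output/lancedb"  # assumed local output directory

description_embedding_store = LanceDBVectorStore(
    collection_name="entity.description",
)
description_embedding_store.connect(db_uri=LANCEDB_URI)
```
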
27 changes: 20 additions & 7 deletions graphrag/api/index.py
Expand Up @@ -59,13 +59,7 @@ async def build_index(
msg = "Cannot resume and update a run at the same time."
raise ValueError(msg)

# TODO: must update filepath of lancedb (if used) until the new config engine has been implemented
# TODO: remove the type ignore annotations below once the new config engine has been refactored
vector_store_type = config.embeddings.vector_store["type"] # type: ignore
if vector_store_type == VectorStoreType.LanceDB:
db_uri = config.embeddings.vector_store["db_uri"] # type: ignore
lancedb_dir = Path(config.root_dir).resolve() / db_uri
config.embeddings.vector_store["db_uri"] = str(lancedb_dir) # type: ignore
config = _patch_vector_config(config)

pipeline_config = create_pipeline_config(config)
pipeline_cache = (
Expand All @@ -90,3 +84,22 @@ async def build_index(
progress_reporter.success(output.workflow)
progress_reporter.info(str(output.result))
return outputs


def _patch_vector_config(config: GraphRagConfig):
"""Back-compat patch to ensure a default vector store configuration."""
if not config.embeddings.vector_store:
config.embeddings.vector_store = {
"type": "lancedb",
"db_uri": "output/lancedb",
"container_name": "default",
"overwrite": True,
}
# TODO: must update filepath of lancedb (if used) until the new config engine has been implemented
# TODO: remove the type ignore annotations below once the new config engine has been refactored
vector_store_type = config.embeddings.vector_store["type"] # type: ignore
if vector_store_type == VectorStoreType.LanceDB:
db_uri = config.embeddings.vector_store["db_uri"] # type: ignore
lancedb_dir = Path(config.root_dir).resolve() / db_uri
config.embeddings.vector_store["db_uri"] = str(lancedb_dir) # type: ignore
return config
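
To make the back-compat patch concrete, a small standalone sketch of the path resolution it performs; the root_dir and db_uri values are assumptions:

```python
from pathlib import Path

# Assumed values matching the defaults injected by _patch_vector_config.
root_dir = "/home/user/my-project"
db_uri = "output/lancedb"

# A relative db_uri is re-rooted under the project directory before the
# pipeline config is created, so LanceDB always writes inside the project.
lancedb_dir = Path(root_dir).resolve() / db_uri
print(lancedb_dir)  # e.g. /home/user/my-project/output/lancedb
```
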
136 changes: 55 additions & 81 deletions graphrag/api/query.py
Expand Up @@ -182,56 +182,22 @@ async def local_search(
------
TODO: Document any exceptions to expect.
"""
#################################### BEGIN PATCH ####################################
# TODO: remove the following patch that checks for a vector_store prior to v1 release
# TODO: this is a backwards compatibility patch that injects the default vector_store settings into the config if it is not present
# Only applicable in situations involving a local vector_store (lancedb). The general idea:
# if vector_store not in config:
# 1. assume user is running local if vector_store is not in config
# 2. insert default vector_store in config
# 3 .create lancedb vector_store instance
# 4. upload vector embeddings from the input dataframes to the vector_store
backwards_compatible = False
if not config.embeddings.vector_store:
backwards_compatible = True
from graphrag.query.input.loaders.dfs import store_entity_semantic_embeddings
from graphrag.vector_stores.lancedb import LanceDBVectorStore

config.embeddings.vector_store = {
"type": "lancedb",
"db_uri": f"{Path(config.storage.base_dir)}/lancedb",
"collection_name": "entity_description_embeddings",
"overwrite": True,
}
_entities = read_indexer_entities(nodes, entities, community_level)
description_embedding_store = LanceDBVectorStore(
db_uri=config.embeddings.vector_store["db_uri"],
collection_name=config.embeddings.vector_store["collection_name"],
overwrite=config.embeddings.vector_store["overwrite"],
)
description_embedding_store.connect(
db_uri=config.embeddings.vector_store["db_uri"]
)
# dump embeddings from the entities list to the description_embedding_store
store_entity_semantic_embeddings(
entities=_entities, vectorstore=description_embedding_store
)
#################################### END PATCH ####################################
config = _patch_vector_store(config, nodes, entities, community_level)

# TODO: update filepath of lancedb (if used) until the new config engine has been implemented
# TODO: remove the type ignore annotations below once the new config engine has been refactored
vector_store_type = config.embeddings.vector_store.get("type") # type: ignore
vector_store_args = config.embeddings.vector_store
if vector_store_type == VectorStoreType.LanceDB and not backwards_compatible:
if vector_store_type == VectorStoreType.LanceDB:
db_uri = config.embeddings.vector_store["db_uri"] # type: ignore
lancedb_dir = Path(config.root_dir).resolve() / db_uri
vector_store_args["db_uri"] = str(lancedb_dir) # type: ignore

reporter.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore
if not backwards_compatible: # can remove this check and always set the description_embedding_store before v1 release
description_embedding_store = _get_embedding_description_store(
config_args=vector_store_args, # type: ignore
)

description_embedding_store = _get_embedding_description_store(
config_args=vector_store_args, # type: ignore
)

_entities = read_indexer_entities(nodes, entities, community_level)
_covariates = read_indexer_covariates(covariates) if covariates is not None else []
Expand Down Expand Up @@ -289,56 +255,22 @@ async def local_search_streaming(
------
TODO: Document any exceptions to expect.
"""
#################################### BEGIN PATCH ####################################
# TODO: remove the following patch that checks for a vector_store prior to v1 release
# TODO: this is a backwards compatibility patch that injects the default vector_store settings into the config if it is not present
# Only applicable in situations involving a local vector_store (lancedb). The general idea:
# if vector_store not in config:
# 1. assume user is running local if vector_store is not in config
# 2. insert default vector_store in config
# 3 .create lancedb vector_store instance
# 4. upload vector embeddings from the input dataframes to the vector_store
backwards_compatible = False
if not config.embeddings.vector_store:
backwards_compatible = True
from graphrag.query.input.loaders.dfs import store_entity_semantic_embeddings
from graphrag.vector_stores.lancedb import LanceDBVectorStore

config.embeddings.vector_store = {
"type": "lancedb",
"db_uri": f"{Path(config.storage.base_dir)}/lancedb",
"collection_name": "entity_description_embeddings",
"overwrite": True,
}
_entities = read_indexer_entities(nodes, entities, community_level)
description_embedding_store = LanceDBVectorStore(
db_uri=config.embeddings.vector_store["db_uri"],
collection_name=config.embeddings.vector_store["collection_name"],
overwrite=config.embeddings.vector_store["overwrite"],
)
description_embedding_store.connect(
db_uri=config.embeddings.vector_store["db_uri"]
)
# dump embeddings from the entities list to the description_embedding_store
store_entity_semantic_embeddings(
entities=_entities, vectorstore=description_embedding_store
)
#################################### END PATCH ####################################
config = _patch_vector_store(config, nodes, entities, community_level)

# TODO: must update filepath of lancedb (if used) until the new config engine has been implemented
# TODO: remove the type ignore annotations below once the new config engine has been refactored
vector_store_type = config.embeddings.vector_store.get("type") # type: ignore
vector_store_args = config.embeddings.vector_store
if vector_store_type == VectorStoreType.LanceDB and not backwards_compatible:
if vector_store_type == VectorStoreType.LanceDB:
db_uri = config.embeddings.vector_store["db_uri"] # type: ignore
lancedb_dir = Path(config.root_dir).resolve() / db_uri
vector_store_args["db_uri"] = str(lancedb_dir) # type: ignore

reporter.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore
if not backwards_compatible: # can remove this check and always set the description_embedding_store before v1 release
description_embedding_store = _get_embedding_description_store(
config_args=vector_store_args, # type: ignore
)

description_embedding_store = _get_embedding_description_store(
config_args=vector_store_args,  # type: ignore
)

_entities = read_indexer_entities(nodes, entities, community_level)
_covariates = read_indexer_covariates(covariates) if covariates is not None else []
Expand Down Expand Up @@ -368,13 +300,55 @@ async def local_search_streaming(
yield stream_chunk


def _patch_vector_store(
config: GraphRagConfig,
nodes: pd.DataFrame,
entities: pd.DataFrame,
community_level: int,
) -> GraphRagConfig:
# TODO: remove the following patch that checks for a vector_store prior to v1 release
# TODO: this is a backwards compatibility patch that injects the default vector_store settings into the config if it is not present
# Only applicable in situations involving a local vector_store (lancedb). The general idea:
# if vector_store not in config:
# 1. assume user is running local if vector_store is not in config
# 2. insert default vector_store in config
# 3 .create lancedb vector_store instance
# 4. upload vector embeddings from the input dataframes to the vector_store
if not config.embeddings.vector_store:
from graphrag.query.input.loaders.dfs import store_entity_semantic_embeddings
from graphrag.vector_stores.lancedb import LanceDBVectorStore

config.embeddings.vector_store = {
"type": "lancedb",
"db_uri": f"{Path(config.storage.base_dir)}/lancedb",
"container_name": "default",
"overwrite": True,
}
description_embedding_store = LanceDBVectorStore(
db_uri=config.embeddings.vector_store["db_uri"],
collection_name="default-entity-description",
overwrite=config.embeddings.vector_store["overwrite"],
)
description_embedding_store.connect(
db_uri=config.embeddings.vector_store["db_uri"]
)
# dump embeddings from the entities list to the description_embedding_store
_entities = read_indexer_entities(nodes, entities, community_level)
store_entity_semantic_embeddings(
entities=_entities, vectorstore=description_embedding_store
)
return config


def _get_embedding_description_store(
config_args: dict,
):
"""Get the embedding description store."""
vector_store_type = config_args["type"]
collection_name = f"{config_args['container_name']}-entity-description"
description_embedding_store = VectorStoreFactory.get_vector_store(
vector_store_type=vector_store_type, kwargs=config_args
vector_store_type=vector_store_type,
kwargs={**config_args, "collection_name": collection_name},
)
description_embedding_store.connect(**config_args)
return description_embedding_store
Expand Down
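
A quick sketch of the collection naming now applied in _get_embedding_description_store; the container name shown is an assumed default:

```python
# The entity-description collection is namespaced by the configured
# container_name instead of the fixed "entity_description_embeddings" name.
config_args = {"type": "lancedb", "container_name": "default"}  # assumed values
collection_name = f"{config_args['container_name']}-entity-description"
print(collection_name)  # default-entity-description
```
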
5 changes: 4 additions & 1 deletion graphrag/cli/query.py
Expand Up @@ -115,6 +115,7 @@ def run_local_search(
config.storage.base_dir = str(data_dir) if data_dir else config.storage.base_dir
resolve_paths(config)

# TODO remove optional create_final_entities_description_embeddings.parquet to delete backwards compatibility
dataframe_dict = _resolve_parquet_files(
root_dir=root_dir,
config=config,
Expand All @@ -125,7 +126,9 @@ def run_local_search(
"create_final_relationships.parquet",
"create_final_entities.parquet",
],
optional_list=["create_final_covariates.parquet"],
optional_list=[
"create_final_covariates.parquet",
],
)
final_nodes: pd.DataFrame = dataframe_dict["create_final_nodes"]
final_community_reports: pd.DataFrame = dataframe_dict[
Expand Down
1 change: 1 addition & 0 deletions graphrag/config/create_graphrag_config.py
Expand Up @@ -414,6 +414,7 @@ def hydrate_parallelization_params(
raw_entities=reader.bool("raw_entities") or defs.SNAPSHOTS_RAW_ENTITIES,
top_level_nodes=reader.bool("top_level_nodes")
or defs.SNAPSHOTS_TOP_LEVEL_NODES,
embeddings=reader.bool("embeddings") or defs.SNAPSHOTS_EMBEDDINGS,
)
with reader.envvar_prefix(Section.umap), reader.use(values.get("umap")):
umap_model = UmapConfig(
Expand Down
3 changes: 2 additions & 1 deletion graphrag/config/defaults.py
Expand Up @@ -82,6 +82,7 @@
SNAPSHOTS_GRAPHML = False
SNAPSHOTS_RAW_ENTITIES = False
SNAPSHOTS_TOP_LEVEL_NODES = False
SNAPSHOTS_EMBEDDINGS = False
STORAGE_BASE_DIR = "output"
STORAGE_TYPE = StorageType.file
SUMMARIZE_DESCRIPTIONS_MAX_LENGTH = 500
Expand All @@ -91,7 +92,7 @@
VECTOR_STORE = f"""
type: {VectorStoreType.LanceDB.value}
db_uri: '{(Path(STORAGE_BASE_DIR) / "lancedb")!s}'
collection_name: entity_description_embeddings
collection_name: default
overwrite: true\
"""

Expand Down
1 change: 1 addition & 0 deletions graphrag/config/enums.py
Expand Up @@ -86,6 +86,7 @@ class TextEmbeddingTarget(str, Enum):

all = "all"
required = "required"
none = "none"

def __repr__(self):
"""Get a string representation."""
Expand Down
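
A standalone sketch of the extended enum, reproduced here only so the snippet runs without the package; with target set to none, no embeddings are emitted:

```python
from enum import Enum

# Mirror of TextEmbeddingTarget as shown above (the real class lives in
# graphrag.config.enums); reproduced so this snippet is self-contained.
class TextEmbeddingTarget(str, Enum):
    all = "all"
    required = "required"
    none = "none"

target = TextEmbeddingTarget("none")
print(target is TextEmbeddingTarget.none)  # True
```
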
4 changes: 4 additions & 0 deletions graphrag/config/models/snapshots_config.py
Expand Up @@ -23,3 +23,7 @@ class SnapshotsConfig(BaseModel):
description="A flag indicating whether to take snapshots of top-level nodes.",
default=defs.SNAPSHOTS_TOP_LEVEL_NODES,
)
embeddings: bool = Field(
description="A flag indicating whether to take snapshots of embeddings.",
default=defs.SNAPSHOTS_EMBEDDINGS,
)
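
A minimal sketch of how the new flag might sit alongside the existing snapshot options; the surrounding values are assumptions showing the defaults:

```python
# Mirrors the SnapshotsConfig fields above: the new `embeddings` flag defaults
# to False and, when enabled, takes snapshots of embeddings during indexing.
snapshots_settings = {
    "graphml": False,
    "raw_entities": False,
    "top_level_nodes": False,
    "embeddings": True,  # new field introduced by this PR
}
print(snapshots_settings)
```
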
22 changes: 22 additions & 0 deletions graphrag/index/config/__init__.py
Expand Up @@ -11,6 +11,18 @@
PipelineMemoryCacheConfig,
PipelineNoneCacheConfig,
)
from .embeddings import (
all_embeddings,
community_full_content_embedding,
community_summary_embedding,
community_title_embedding,
document_raw_content_embedding,
entity_description_embedding,
entity_name_embedding,
relationship_description_embedding,
required_embeddings,
text_unit_text_embedding,
)
from .input import (
PipelineCSVInputConfig,
PipelineInputConfig,
Expand Down Expand Up @@ -66,4 +78,14 @@
"PipelineWorkflowConfig",
"PipelineWorkflowReference",
"PipelineWorkflowStep",
"all_embeddings",
"community_full_content_embedding",
"community_summary_embedding",
"community_title_embedding",
"document_raw_content_embedding",
"entity_description_embedding",
"entity_name_embedding",
"relationship_description_embedding",
"required_embeddings",
"text_unit_text_embedding",
]
25 changes: 25 additions & 0 deletions graphrag/index/config/embeddings.py
@@ -0,0 +1,25 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing embeddings values."""

entity_name_embedding = "entity.name"
entity_description_embedding = "entity.description"
relationship_description_embedding = "relationship.description"
document_raw_content_embedding = "document.raw_content"
community_title_embedding = "community.title"
community_summary_embedding = "community.summary"
community_full_content_embedding = "community.full_content"
text_unit_text_embedding = "text_unit.text"

all_embeddings: set[str] = {
entity_name_embedding,
entity_description_embedding,
relationship_description_embedding,
document_raw_content_embedding,
community_title_embedding,
community_summary_embedding,
community_full_content_embedding,
text_unit_text_embedding,
}
required_embeddings: set[str] = {entity_description_embedding}
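
As a worked example of how these names combine with the target/skip settings documented earlier, here is an assumed selection helper; it sketches the documented behavior rather than the exact pipeline code:

```python
# Embedding names as defined above in graphrag/index/config/embeddings.py.
entity_name_embedding = "entity.name"
entity_description_embedding = "entity.description"
relationship_description_embedding = "relationship.description"
document_raw_content_embedding = "document.raw_content"
community_title_embedding = "community.title"
community_summary_embedding = "community.summary"
community_full_content_embedding = "community.full_content"
text_unit_text_embedding = "text_unit.text"

all_embeddings: set[str] = {
    entity_name_embedding,
    entity_description_embedding,
    relationship_description_embedding,
    document_raw_content_embedding,
    community_title_embedding,
    community_summary_embedding,
    community_full_content_embedding,
    text_unit_text_embedding,
}
required_embeddings: set[str] = {entity_description_embedding}


def select_embeddings(target: str, skip: set[str] | None = None) -> set[str]:
    """Assumed selection logic: target picks the base set, skip prunes it (target='all' only)."""
    skip = skip or set()
    if target == "none":
        return set()
    if target == "required":
        return set(required_embeddings)
    return all_embeddings - skip  # target == "all"


print(sorted(select_embeddings("all", {entity_name_embedding, document_raw_content_embedding})))
```
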