From dd711359955e97b2338b4c372a263764753c317d Mon Sep 17 00:00:00 2001 From: KennyZhang1 <90438893+KennyZhang1@users.noreply.github.com> Date: Thu, 22 Aug 2024 13:39:55 -0400 Subject: [PATCH] Change lancedb placement (#996) * changed placement of lancedb dir to under /artifacts * ruff checks and semversioner * added support for static paths * added support for streaming * more ruff changes * ruff format changes * removed string concat for path formation * added more ruff checks * removed os.join usage * more ruff fixes and removed unneccesary path creations * replaced cast calls with str() --------- Co-authored-by: Kenny Zhang --- .gitignore | 1 + .../next-release/patch-20240821135138469990.json | 4 ++++ graphrag/query/api.py | 14 ++++++++++++++ graphrag/query/cli.py | 9 ++++++--- 4 files changed, 25 insertions(+), 3 deletions(-) create mode 100644 .semversioner/next-release/patch-20240821135138469990.json diff --git a/.gitignore b/.gitignore index bff8e24810..0919e27c7c 100644 --- a/.gitignore +++ b/.gitignore @@ -42,6 +42,7 @@ lancedb/ .DS_Store *.log* .venv +venv/ .conda .tmp diff --git a/.semversioner/next-release/patch-20240821135138469990.json b/.semversioner/next-release/patch-20240821135138469990.json new file mode 100644 index 0000000000..f6494e37f8 --- /dev/null +++ b/.semversioner/next-release/patch-20240821135138469990.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "change-lancedb-placement" +} diff --git a/graphrag/query/api.py b/graphrag/query/api.py index 788de8c0ca..c374260218 100644 --- a/graphrag/query/api.py +++ b/graphrag/query/api.py @@ -18,12 +18,14 @@ """ from collections.abc import AsyncGenerator +from pathlib import Path from typing import Any import pandas as pd from pydantic import validate_call from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.config.resolve_timestamp_path import resolve_timestamp_path from graphrag.index.progress.types import PrintProgressReporter from graphrag.model.entity import Entity from graphrag.query.structured_search.base import SearchResult # noqa: TCH001 @@ -147,6 +149,7 @@ async def global_search_streaming( @validate_call(config={"arbitrary_types_allowed": True}) async def local_search( + root_dir: str | None, config: GraphRagConfig, nodes: pd.DataFrame, entities: pd.DataFrame, @@ -192,6 +195,11 @@ async def local_search( vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB) _entities = read_indexer_entities(nodes, entities, community_level) + + base_dir = Path(str(root_dir)) / config.storage.base_dir + resolved_base_dir = resolve_timestamp_path(base_dir) + lancedb_dir = resolved_base_dir / "lancedb" + vector_store_args.update({"db_uri": str(lancedb_dir)}) description_embedding_store = _get_embedding_description_store( entities=_entities, vector_store_type=vector_store_type, @@ -219,6 +227,7 @@ async def local_search( @validate_call(config={"arbitrary_types_allowed": True}) async def local_search_streaming( + root_dir: str | None, config: GraphRagConfig, nodes: pd.DataFrame, entities: pd.DataFrame, @@ -261,6 +270,11 @@ async def local_search_streaming( vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB) _entities = read_indexer_entities(nodes, entities, community_level) + + base_dir = Path(str(root_dir)) / config.storage.base_dir + resolved_base_dir = resolve_timestamp_path(base_dir) + lancedb_dir = resolved_base_dir / "lancedb" + vector_store_args.update({"db_uri": str(lancedb_dir)}) description_embedding_store = _get_embedding_description_store( entities=_entities, vector_store_type=vector_store_type, diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py index de625e912a..ed7dc4566a 100644 --- a/graphrag/query/cli.py +++ b/graphrag/query/cli.py @@ -7,7 +7,6 @@ import re import sys from pathlib import Path -from typing import cast import pandas as pd @@ -15,6 +14,7 @@ GraphRagConfig, create_graphrag_config, ) +from graphrag.config.resolve_timestamp_path import resolve_timestamp_path from graphrag.index.progress import PrintProgressReporter from . import api @@ -137,6 +137,7 @@ async def run_streaming_search(): context_data = None get_context_data = True async for stream_chunk in api.local_search_streaming( + root_dir=root_dir, config=config, nodes=final_nodes, entities=final_entities, @@ -162,6 +163,7 @@ async def run_streaming_search(): # not streaming response, context_data = asyncio.run( api.local_search( + root_dir=root_dir, config=config, nodes=final_nodes, entities=final_entities, @@ -185,12 +187,13 @@ def _configure_paths_and_settings( root_dir: str | None, config_filepath: str | None, ) -> tuple[str, str | None, GraphRagConfig]: + config = _create_graphrag_config(root_dir, config_filepath) if data_dir is None and root_dir is None: msg = "Either data_dir or root_dir must be provided." raise ValueError(msg) if data_dir is None: - data_dir = _infer_data_dir(cast(str, root_dir)) - config = _create_graphrag_config(root_dir, config_filepath) + base_dir = Path(str(root_dir)) / config.storage.base_dir + data_dir = str(resolve_timestamp_path(base_dir)) return data_dir, root_dir, config