Skip to content

Commit

Permalink
Change lancedb placement (#996)
Browse files Browse the repository at this point in the history
* changed placement of lancedb dir to under /artifacts

* ruff checks and semversioner

* added support for static paths

* added support for streaming

* more ruff changes

* ruff format changes

* removed string concat for path formation

* added more ruff checks

* removed os.join usage

* more ruff fixes and removed unneccesary path creations

* replaced cast calls with str()

---------

Co-authored-by: Kenny Zhang <[email protected]>
  • Loading branch information
KennyZhang1 and Kenny Zhang authored Aug 22, 2024
1 parent 4b9fdc0 commit dd71135
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 3 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ lancedb/
.DS_Store
*.log*
.venv
venv/
.conda
.tmp

Expand Down
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20240821135138469990.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "change-lancedb-placement"
}
14 changes: 14 additions & 0 deletions graphrag/query/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@
"""

from collections.abc import AsyncGenerator
from pathlib import Path
from typing import Any

import pandas as pd
from pydantic import validate_call

from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.config.resolve_timestamp_path import resolve_timestamp_path
from graphrag.index.progress.types import PrintProgressReporter
from graphrag.model.entity import Entity
from graphrag.query.structured_search.base import SearchResult # noqa: TCH001
Expand Down Expand Up @@ -147,6 +149,7 @@ async def global_search_streaming(

@validate_call(config={"arbitrary_types_allowed": True})
async def local_search(
root_dir: str | None,
config: GraphRagConfig,
nodes: pd.DataFrame,
entities: pd.DataFrame,
Expand Down Expand Up @@ -192,6 +195,11 @@ async def local_search(
vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB)

_entities = read_indexer_entities(nodes, entities, community_level)

base_dir = Path(str(root_dir)) / config.storage.base_dir
resolved_base_dir = resolve_timestamp_path(base_dir)
lancedb_dir = resolved_base_dir / "lancedb"
vector_store_args.update({"db_uri": str(lancedb_dir)})
description_embedding_store = _get_embedding_description_store(
entities=_entities,
vector_store_type=vector_store_type,
Expand Down Expand Up @@ -219,6 +227,7 @@ async def local_search(

@validate_call(config={"arbitrary_types_allowed": True})
async def local_search_streaming(
root_dir: str | None,
config: GraphRagConfig,
nodes: pd.DataFrame,
entities: pd.DataFrame,
Expand Down Expand Up @@ -261,6 +270,11 @@ async def local_search_streaming(
vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB)

_entities = read_indexer_entities(nodes, entities, community_level)

base_dir = Path(str(root_dir)) / config.storage.base_dir
resolved_base_dir = resolve_timestamp_path(base_dir)
lancedb_dir = resolved_base_dir / "lancedb"
vector_store_args.update({"db_uri": str(lancedb_dir)})
description_embedding_store = _get_embedding_description_store(
entities=_entities,
vector_store_type=vector_store_type,
Expand Down
9 changes: 6 additions & 3 deletions graphrag/query/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
import re
import sys
from pathlib import Path
from typing import cast

import pandas as pd

from graphrag.config import (
GraphRagConfig,
create_graphrag_config,
)
from graphrag.config.resolve_timestamp_path import resolve_timestamp_path
from graphrag.index.progress import PrintProgressReporter

from . import api
Expand Down Expand Up @@ -137,6 +137,7 @@ async def run_streaming_search():
context_data = None
get_context_data = True
async for stream_chunk in api.local_search_streaming(
root_dir=root_dir,
config=config,
nodes=final_nodes,
entities=final_entities,
Expand All @@ -162,6 +163,7 @@ async def run_streaming_search():
# not streaming
response, context_data = asyncio.run(
api.local_search(
root_dir=root_dir,
config=config,
nodes=final_nodes,
entities=final_entities,
Expand All @@ -185,12 +187,13 @@ def _configure_paths_and_settings(
root_dir: str | None,
config_filepath: str | None,
) -> tuple[str, str | None, GraphRagConfig]:
config = _create_graphrag_config(root_dir, config_filepath)
if data_dir is None and root_dir is None:
msg = "Either data_dir or root_dir must be provided."
raise ValueError(msg)
if data_dir is None:
data_dir = _infer_data_dir(cast(str, root_dir))
config = _create_graphrag_config(root_dir, config_filepath)
base_dir = Path(str(root_dir)) / config.storage.base_dir
data_dir = str(resolve_timestamp_path(base_dir))
return data_dir, root_dir, config


Expand Down

0 comments on commit dd71135

Please sign in to comment.