Load query from blob (#1095)
* Moved query loading from file to a helper function

* Added loading parquets from blob to the function

* Resolved adlfs async error

* Debugging cleanup and small fixes

* Added connection string support

* Semversioner and ruff fixes

* Completed testing for merge with main

* More ruff changes

* Fixed unbound-vars warning

* Rewrote the function to use storage utils

* Removed unused vars

---------

Co-authored-by: Kenny Zhang <[email protected]>
KennyZhang1 and Kenny Zhang authored Sep 5, 2024
1 parent 044516f commit 27c5468
Showing 2 changed files with 76 additions and 27 deletions.
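
The net effect of the diff below is that the query CLI no longer resolves a local data_path and calls pd.read_parquet directly; every pipeline output table is read through the storage abstraction, so the same code path serves local files and Azure Blob Storage. A minimal sketch of that pattern, not part of the commit itself (it assumes config is an already-loaded GraphRagConfig and uses a placeholder ./ragtest project root):

import asyncio

from graphrag.index.create_pipeline_config import create_pipeline_config
from graphrag.utils.storage import _create_storage, _load_table_from_storage

# Build the pipeline storage object from the project config; per the commit
# message, a blob-type storage config makes this read from Azure Blob Storage
# instead of the local filesystem.
pipeline_config = create_pipeline_config(config)  # config: assumed GraphRagConfig
storage = _create_storage(root_dir="./ragtest", config=pipeline_config.storage)

# The table load is async, so the synchronous CLI drives it with asyncio.run().
final_nodes = asyncio.run(
    _load_table_from_storage(name="create_final_nodes.parquet", storage=storage)
)
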
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20240904161252783119.json
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "add querying from azure blob storage"
+}
99 changes: 72 additions & 27 deletions graphrag/query/cli.py
@@ -9,8 +9,14 @@

 import pandas as pd

-from graphrag.config import load_config, resolve_path
+from graphrag.config import (
+    GraphRagConfig,
+    load_config,
+    resolve_path,
+)
 from graphrag.index.create_pipeline_config import create_pipeline_config
 from graphrag.index.progress import PrintProgressReporter
+from graphrag.utils.storage import _create_storage, _load_table_from_storage

 from . import api

@@ -36,17 +42,21 @@ def run_global_search(
     if data_dir:
         config.storage.base_dir = str(resolve_path(data_dir, root))

-    data_path = Path(config.storage.base_dir).resolve()
-
-    final_nodes: pd.DataFrame = pd.read_parquet(
-        data_path / "create_final_nodes.parquet"
-    )
-    final_entities: pd.DataFrame = pd.read_parquet(
-        data_path / "create_final_entities.parquet"
-    )
-    final_community_reports: pd.DataFrame = pd.read_parquet(
-        data_path / "create_final_community_reports.parquet"
+    dataframe_dict = _resolve_parquet_files(
+        root_dir=root_dir,
+        config=config,
+        parquet_list=[
+            "create_final_nodes.parquet",
+            "create_final_entities.parquet",
+            "create_final_community_reports.parquet",
+        ],
+        optional_list=[],
     )
+    final_nodes: pd.DataFrame = dataframe_dict["create_final_nodes"]
+    final_entities: pd.DataFrame = dataframe_dict["create_final_entities"]
+    final_community_reports: pd.DataFrame = dataframe_dict[
+        "create_final_community_reports"
+    ]

     # call the Query API
     if streaming:
@@ -112,23 +122,26 @@ def run_local_search(
     if data_dir:
         config.storage.base_dir = str(resolve_path(data_dir, root))

-    data_path = Path(config.storage.base_dir).resolve()
-
-    final_nodes = pd.read_parquet(data_path / "create_final_nodes.parquet")
-    final_community_reports = pd.read_parquet(
-        data_path / "create_final_community_reports.parquet"
-    )
-    final_text_units = pd.read_parquet(data_path / "create_final_text_units.parquet")
-    final_relationships = pd.read_parquet(
-        data_path / "create_final_relationships.parquet"
-    )
-    final_entities = pd.read_parquet(data_path / "create_final_entities.parquet")
-    final_covariates_path = data_path / "create_final_covariates.parquet"
-    final_covariates = (
-        pd.read_parquet(final_covariates_path)
-        if final_covariates_path.exists()
-        else None
+    dataframe_dict = _resolve_parquet_files(
+        root_dir=root_dir,
+        config=config,
+        parquet_list=[
+            "create_final_nodes.parquet",
+            "create_final_community_reports.parquet",
+            "create_final_text_units.parquet",
+            "create_final_relationships.parquet",
+            "create_final_entities.parquet",
+        ],
+        optional_list=["create_final_covariates.parquet"],
     )
+    final_nodes: pd.DataFrame = dataframe_dict["create_final_nodes"]
+    final_community_reports: pd.DataFrame = dataframe_dict[
+        "create_final_community_reports"
+    ]
+    final_text_units: pd.DataFrame = dataframe_dict["create_final_text_units"]
+    final_relationships: pd.DataFrame = dataframe_dict["create_final_relationships"]
+    final_entities: pd.DataFrame = dataframe_dict["create_final_entities"]
+    final_covariates: pd.DataFrame | None = dataframe_dict["create_final_covariates"]

     # call the Query API
     if streaming:
@@ -179,3 +192,35 @@ async def run_streaming_search():
     # NOTE: we return the response and context data here purely as a complete demonstration of the API.
     # External users should use the API directly to get the response and context data.
     return response, context_data
+
+
+def _resolve_parquet_files(
+    root_dir: str,
+    config: GraphRagConfig,
+    parquet_list: list[str],
+    optional_list: list[str],
+) -> dict[str, pd.DataFrame]:
+    """Read parquet files to a dataframe dict."""
+    dataframe_dict = {}
+    pipeline_config = create_pipeline_config(config)
+    storage_obj = _create_storage(root_dir=root_dir, config=pipeline_config.storage)
+    for parquet_file in parquet_list:
+        df_key = parquet_file.split(".")[0]
+        df_value = asyncio.run(
+            _load_table_from_storage(name=parquet_file, storage=storage_obj)
+        )
+        dataframe_dict[df_key] = df_value
+
+    # for optional parquet files, set the dict entry to None instead of erroring out if it does not exist
+    for optional_file in optional_list:
+        file_exists = asyncio.run(storage_obj.has(optional_file))
+        df_key = optional_file.split(".")[0]
+        if file_exists:
+            df_value = asyncio.run(
+                _load_table_from_storage(name=optional_file, storage=storage_obj)
+            )
+            dataframe_dict[df_key] = df_value
+        else:
+            dataframe_dict[df_key] = None
+
+    return dataframe_dict
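
For reference, a hypothetical call to the new helper mirroring run_local_search above: files in parquet_list are loaded unconditionally, while files in optional_list come back as None when they do not exist in storage.

dataframe_dict = _resolve_parquet_files(
    root_dir="./ragtest",  # placeholder project root, not from the diff
    config=config,  # assumed: a GraphRagConfig loaded for that root
    parquet_list=["create_final_nodes.parquet"],
    optional_list=["create_final_covariates.parquet"],
)
final_nodes = dataframe_dict["create_final_nodes"]
final_covariates = dataframe_dict["create_final_covariates"]  # None if absent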
