diff --git a/.semversioner/next-release/patch-20240904161252783119.json b/.semversioner/next-release/patch-20240904161252783119.json
new file mode 100644
index 0000000000..1bde7b3e8c
--- /dev/null
+++ b/.semversioner/next-release/patch-20240904161252783119.json
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "add querying from azure blob storage"
+}
diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py
index 3ac76f81c8..872a4d0993 100644
--- a/graphrag/query/cli.py
+++ b/graphrag/query/cli.py
@@ -9,8 +9,14 @@
 
 import pandas as pd
 
-from graphrag.config import load_config, resolve_path
+from graphrag.config import (
+    GraphRagConfig,
+    load_config,
+    resolve_path,
+)
+from graphrag.index.create_pipeline_config import create_pipeline_config
 from graphrag.index.progress import PrintProgressReporter
+from graphrag.utils.storage import _create_storage, _load_table_from_storage
 
 from . import api
 
@@ -36,17 +42,21 @@ def run_global_search(
     if data_dir:
         config.storage.base_dir = str(resolve_path(data_dir, root))
 
-    data_path = Path(config.storage.base_dir).resolve()
-
-    final_nodes: pd.DataFrame = pd.read_parquet(
-        data_path / "create_final_nodes.parquet"
-    )
-    final_entities: pd.DataFrame = pd.read_parquet(
-        data_path / "create_final_entities.parquet"
-    )
-    final_community_reports: pd.DataFrame = pd.read_parquet(
-        data_path / "create_final_community_reports.parquet"
+    dataframe_dict = _resolve_parquet_files(
+        root_dir=root_dir,
+        config=config,
+        parquet_list=[
+            "create_final_nodes.parquet",
+            "create_final_entities.parquet",
+            "create_final_community_reports.parquet",
+        ],
+        optional_list=[],
     )
+    final_nodes: pd.DataFrame = dataframe_dict["create_final_nodes"]
+    final_entities: pd.DataFrame = dataframe_dict["create_final_entities"]
+    final_community_reports: pd.DataFrame = dataframe_dict[
+        "create_final_community_reports"
+    ]
 
     # call the Query API
     if streaming:
@@ -112,23 +122,26 @@ def run_local_search(
     if data_dir:
         config.storage.base_dir = str(resolve_path(data_dir, root))
 
-    data_path = Path(config.storage.base_dir).resolve()
-
-    final_nodes = pd.read_parquet(data_path / "create_final_nodes.parquet")
-    final_community_reports = pd.read_parquet(
-        data_path / "create_final_community_reports.parquet"
-    )
-    final_text_units = pd.read_parquet(data_path / "create_final_text_units.parquet")
-    final_relationships = pd.read_parquet(
-        data_path / "create_final_relationships.parquet"
-    )
-    final_entities = pd.read_parquet(data_path / "create_final_entities.parquet")
-    final_covariates_path = data_path / "create_final_covariates.parquet"
-    final_covariates = (
-        pd.read_parquet(final_covariates_path)
-        if final_covariates_path.exists()
-        else None
+    dataframe_dict = _resolve_parquet_files(
+        root_dir=root_dir,
+        config=config,
+        parquet_list=[
+            "create_final_nodes.parquet",
+            "create_final_community_reports.parquet",
+            "create_final_text_units.parquet",
+            "create_final_relationships.parquet",
+            "create_final_entities.parquet",
+        ],
+        optional_list=["create_final_covariates.parquet"],
     )
+    final_nodes: pd.DataFrame = dataframe_dict["create_final_nodes"]
+    final_community_reports: pd.DataFrame = dataframe_dict[
+        "create_final_community_reports"
+    ]
+    final_text_units: pd.DataFrame = dataframe_dict["create_final_text_units"]
+    final_relationships: pd.DataFrame = dataframe_dict["create_final_relationships"]
+    final_entities: pd.DataFrame = dataframe_dict["create_final_entities"]
+    final_covariates: pd.DataFrame | None = dataframe_dict["create_final_covariates"]
 
     # call the Query API
     if streaming:
@@ -179,3 +192,35 @@ async def run_streaming_search():
     # NOTE: we return the response and context data here purely as a complete demonstration of the API.
     # External users should use the API directly to get the response and context data.
     return response, context_data
+
+
+def _resolve_parquet_files(
+    root_dir: str,
+    config: GraphRagConfig,
+    parquet_list: list[str],
+    optional_list: list[str],
+) -> dict[str, pd.DataFrame]:
+    """Read parquet files to a dataframe dict."""
+    dataframe_dict = {}
+    pipeline_config = create_pipeline_config(config)
+    storage_obj = _create_storage(root_dir=root_dir, config=pipeline_config.storage)
+    for parquet_file in parquet_list:
+        df_key = parquet_file.split(".")[0]
+        df_value = asyncio.run(
+            _load_table_from_storage(name=parquet_file, storage=storage_obj)
+        )
+        dataframe_dict[df_key] = df_value
+
+    # for optional parquet files, set the dict entry to None instead of erroring out if it does not exist
+    for optional_file in optional_list:
+        file_exists = asyncio.run(storage_obj.has(optional_file))
+        df_key = optional_file.split(".")[0]
+        if file_exists:
+            df_value = asyncio.run(
+                _load_table_from_storage(name=optional_file, storage=storage_obj)
+            )
+            dataframe_dict[df_key] = df_value
+        else:
+            dataframe_dict[df_key] = None
+
+    return dataframe_dict
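
The core of this change is `_resolve_parquet_files`: instead of calling `pd.read_parquet` against a local `data_path`, the query CLI now resolves its tables through the pipeline's storage layer (`_create_storage` on the config produced by `create_pipeline_config`), so the same code path serves local files and Azure Blob Storage, and optional tables such as `create_final_covariates.parquet` resolve to `None` rather than raising. Below is a minimal, self-contained sketch of that required/optional loading pattern; `DirStorage` and `resolve_tables` are hypothetical stand-ins for graphrag's storage object and helper (only the async `has`/load surface used here is modeled), and the paths are illustrative:

```python
import asyncio
from pathlib import Path

import pandas as pd


class DirStorage:
    """Hypothetical file-backed stand-in for the pipeline storage object."""

    def __init__(self, base_dir: str) -> None:
        self._base = Path(base_dir)

    async def has(self, name: str) -> bool:
        # Mirrors the async existence check used for optional files.
        return (self._base / name).exists()

    async def load(self, name: str) -> pd.DataFrame:
        # Stand-in for _load_table_from_storage: stored bytes -> dataframe.
        return pd.read_parquet(self._base / name)


def resolve_tables(
    storage: DirStorage,
    required: list[str],
    optional: list[str],
) -> dict[str, pd.DataFrame | None]:
    """Required tables must load; optional tables become None when absent."""
    tables: dict[str, pd.DataFrame | None] = {}
    for name in required:
        # A missing required table raises, as in the CLI.
        tables[name.split(".")[0]] = asyncio.run(storage.load(name))
    for name in optional:
        key = name.split(".")[0]
        exists = asyncio.run(storage.has(name))
        tables[key] = asyncio.run(storage.load(name)) if exists else None
    return tables


if __name__ == "__main__":
    storage = DirStorage("output/artifacts")  # illustrative output directory
    tables = resolve_tables(
        storage,
        required=["create_final_nodes.parquet"],
        optional=["create_final_covariates.parquet"],
    )
    print({k: None if v is None else v.shape for k, v in tables.items()})
```

Because the storage object is selected from the pipeline config, pointing the `storage` section of the project's settings at a blob container instead of the local file system should be all that is needed for `run_global_search` and `run_local_search` to read their parquet artifacts directly from Azure Blob Storage.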