diff --git a/daft/io/_delta_lake.py b/daft/io/_delta_lake.py
index a3c60e8c8e..c58b7d0560 100644
--- a/daft/io/_delta_lake.py
+++ b/daft/io/_delta_lake.py
@@ -10,6 +10,12 @@
 from daft.io.catalog import DataCatalogTable
 from daft.logical.builder import LogicalPlanBuilder
 
+_UNITY_CATALOG_AVAILABLE = True
+try:
+    from daft.unity_catalog import UnityCatalogTable
+except ImportError:
+    _UNITY_CATALOG_AVAILABLE = False
+
 
 def read_delta_lake(
     table: Union[str, DataCatalogTable],
@@ -25,7 +31,7 @@
 
 @PublicAPI
 def read_deltalake(
-    table: Union[str, DataCatalogTable],
+    table: Union[str, DataCatalogTable, "UnityCatalogTable"],
     io_config: Optional["IOConfig"] = None,
     _multithreaded_io: Optional[bool] = None,
 ) -> DataFrame:
@@ -56,20 +62,27 @@ def read_deltalake(
     """
     from daft.delta_lake.delta_lake_scan import DeltaLakeScanOperator
 
-    io_config = context.get_context().daft_planning_config.default_io_config if io_config is None else io_config
-
     # If running on Ray, we want to limit the amount of concurrency and requests being made.
     # This is because each Ray worker process receives its own pool of thread workers and connections
     multithreaded_io = not context.get_context().is_ray_runner if _multithreaded_io is None else _multithreaded_io
+
+    io_config = context.get_context().daft_planning_config.default_io_config if io_config is None else io_config
     storage_config = StorageConfig.native(NativeStorageConfig(multithreaded_io, io_config))
 
     if isinstance(table, str):
         table_uri = table
     elif isinstance(table, DataCatalogTable):
         table_uri = table.table_uri(io_config)
+    elif _UNITY_CATALOG_AVAILABLE and isinstance(table, UnityCatalogTable):
+        table_uri = table.table_uri
+
+        # Override the storage_config with the one provided by Unity Catalog
+        table_io_config = table.io_config
+        if table_io_config is not None:
+            storage_config = StorageConfig.native(NativeStorageConfig(multithreaded_io, table_io_config))
     else:
         raise ValueError(
-            f"table argument must be a table URI string or a DataCatalogTable instance, but got: {type(table)}, {table}"
+            f"table argument must be a table URI string, DataCatalogTable or UnityCatalogTable instance, but got: {type(table)}, {table}"
         )
 
     delta_lake_operator = DeltaLakeScanOperator(table_uri, storage_config=storage_config)
diff --git a/daft/unity_catalog/__init__.py b/daft/unity_catalog/__init__.py
new file mode 100644
index 0000000000..ce559ec33c
--- /dev/null
+++ b/daft/unity_catalog/__init__.py
@@ -0,0 +1,3 @@
+from .unity_catalog import UnityCatalog, UnityCatalogTable
+
+__all__ = ["UnityCatalog", "UnityCatalogTable"]
diff --git a/daft/unity_catalog/unity_catalog.py b/daft/unity_catalog/unity_catalog.py
new file mode 100644
index 0000000000..417facfba7
--- /dev/null
+++ b/daft/unity_catalog/unity_catalog.py
@@ -0,0 +1,115 @@
+from __future__ import annotations
+
+import dataclasses
+from typing import Callable
+
+import unitycatalog
+
+from daft.io import IOConfig, S3Config
+
+
+@dataclasses.dataclass(frozen=True)
+class UnityCatalogTable:
+    table_uri: str
+    io_config: IOConfig | None
+
+
+class UnityCatalog:
+    """Client to access the Unity Catalog
+
+    Unity Catalog is an open source data catalog that can be self-hosted or hosted by Databricks.
+
+    Example of reading a DataFrame from a table in Unity Catalog hosted by Databricks:
+
+    >>> cat = UnityCatalog("https://<databricks_workspace_id>.cloud.databricks.com", token="my-token")
+    >>> table = cat.load_table("my_catalog.my_schema.my_table")
+    >>> df = daft.read_delta_lake(table)
+    """
+
+    def __init__(self, endpoint: str, token: str | None = None):
+        self._client = unitycatalog.Unitycatalog(
+            base_url=endpoint.rstrip("/") + "/api/2.1/unity-catalog/",
+            default_headers={"Authorization": f"Bearer {token}"},
+        )
+
+    def _paginate_to_completion(
+        self,
+        client_func_call: Callable[[unitycatalog.Unitycatalog, str | None], tuple[list[str] | None, str | None]],
+    ) -> list[str]:
+        results = []
+
+        # Make first request
+        new_results, next_page_token = client_func_call(self._client, None)
+        if new_results is not None:
+            results.extend(new_results)
+
+        # Exhaust pages
+        while next_page_token is not None and next_page_token != "":
+            new_results, next_page_token = client_func_call(self._client, next_page_token)
+            if new_results is not None:
+                results.extend(new_results)
+
+        return results
+
+    def list_catalogs(self) -> list[str]:
+        def _paginated_list_catalogs(client: unitycatalog.Unitycatalog, page_token: str | None):
+            response = client.catalogs.list(page_token=page_token)
+            next_page_token = response.next_page_token
+            if response.catalogs is None:
+                return None, next_page_token
+            return [c.name for c in response.catalogs], next_page_token
+
+        return self._paginate_to_completion(_paginated_list_catalogs)
+
+    def list_schemas(self, catalog_name: str) -> list[str]:
+        def _paginated_list_schemas(client: unitycatalog.Unitycatalog, page_token: str | None):
+            response = client.schemas.list(catalog_name=catalog_name, page_token=page_token)
+            next_page_token = response.next_page_token
+            if response.schemas is None:
+                return None, next_page_token
+            return [s.full_name for s in response.schemas], next_page_token
+
+        return self._paginate_to_completion(_paginated_list_schemas)
+
+    def list_tables(self, schema_name: str) -> list[str]:
+        if schema_name.count(".") != 1:
+            raise ValueError(
+                f"Expected fully-qualified schema name with format `catalog_name`.`schema_name`, but received: {schema_name}"
+            )
+
+        catalog_name, schema_name = schema_name.split(".")
+
+        def _paginated_list_tables(client: unitycatalog.Unitycatalog, page_token: str | None):
+            response = client.tables.list(catalog_name=catalog_name, schema_name=schema_name, page_token=page_token)
+            next_page_token = response.next_page_token
+            if response.tables is None:
+                return None, next_page_token
+            return [f"{t.catalog_name}.{t.schema_name}.{t.name}" for t in response.tables], next_page_token
+
+        return self._paginate_to_completion(_paginated_list_tables)
+
+    def load_table(self, table_name: str) -> UnityCatalogTable:
+        # Load the table ID
+        table_info = self._client.tables.retrieve(table_name)
+        table_id = table_info.table_id
+        storage_location = table_info.storage_location
+
+        # Grab credentials from Unity Catalog and place them into the table
+        temp_table_credentials = self._client.temporary_table_credentials.create(operation="READ", table_id=table_id)
+        aws_temp_credentials = temp_table_credentials.aws_temp_credentials
+        io_config = (
+            IOConfig(
+                s3=S3Config(
+                    key_id=aws_temp_credentials.access_key_id,
+                    access_key=aws_temp_credentials.secret_access_key,
+                    session_token=aws_temp_credentials.session_token,
+                )
+            )
+            if aws_temp_credentials is not None
+            else None
+        )
+
+        return UnityCatalogTable(
+            table_uri=storage_location,
+            io_config=io_config,
+        )
diff --git a/docs/source/user_guide/integrations.rst b/docs/source/user_guide/integrations.rst
index a548b2f518..de296676f5 100644
--- a/docs/source/user_guide/integrations.rst
+++ b/docs/source/user_guide/integrations.rst
@@ -4,6 +4,7 @@ Integrations
 .. toctree::
 
     integrations/ray
+    integrations/unity-catalog
     integrations/iceberg
     integrations/delta_lake
     integrations/hudi
diff --git a/docs/source/user_guide/integrations/delta_lake.rst b/docs/source/user_guide/integrations/delta_lake.rst
index 2a2f7a0ce4..777d7c4d51 100644
--- a/docs/source/user_guide/integrations/delta_lake.rst
+++ b/docs/source/user_guide/integrations/delta_lake.rst
@@ -74,6 +74,7 @@ Write to Delta Lake
 You can use `write_deltalake` to write a Daft DataFrame to a Delta table:
 
 .. code:: python
+
     df.write_deltalake("tmp/daft-table", mode="overwrite")
 
 
diff --git a/docs/source/user_guide/integrations/unity-catalog.rst b/docs/source/user_guide/integrations/unity-catalog.rst
new file mode 100644
index 0000000000..05e9fa20ea
--- /dev/null
+++ b/docs/source/user_guide/integrations/unity-catalog.rst
@@ -0,0 +1,62 @@
+Unity Catalog
+=============
+
+`Unity Catalog <https://github.com/unitycatalog/unitycatalog>`_ is an open source catalog developed by Databricks.
+Users of Unity Catalog are able to work with data assets such as tables (Parquet, CSV, Iceberg, Delta), volumes
+(storing raw files), functions, and models.
+
+.. WARNING::
+
+    These APIs are in beta and may be subject to change as the Unity Catalog continues to be developed.
+
+Connecting to the Unity Catalog
+*******************************
+
+Daft includes a client abstraction for connecting to the Unity Catalog.
+
+.. code:: python
+
+    from daft.unity_catalog import UnityCatalog
+
+    unity = UnityCatalog(
+        endpoint="https://<databricks_workspace_id>.cloud.databricks.com",
+        # Authentication can be retrieved from your provider of Unity Catalog
+        token="my-token",
+    )
+
+    # See all available catalogs
+    print(unity.list_catalogs())
+
+    # See available schemas in a given catalog
+    print(unity.list_schemas("my_catalog_name"))
+
+    # See available tables in a given schema
+    print(unity.list_tables("my_catalog_name.my_schema_name"))
+
+Loading a Daft DataFrame from a Delta Lake table in Unity Catalog
+*****************************************************************
+
+.. code:: python
+
+    unity_table = unity.load_table("my_catalog_name.my_schema_name.my_table_name")
+
+    df = daft.read_delta_lake(unity_table)
+    df.show()
+
+Any subsequent filter operations on the Daft ``df`` DataFrame object will be correctly optimized to take advantage of Delta Lake features:
+
+.. code:: python
+
+    # Filter which takes advantage of partition pruning capabilities of Delta Lake
+    df = df.where(df["partition_key"] < 1000)
+    df.show()
+
+See also :doc:`delta_lake` for more information about how to work with the Delta Lake tables provided by the Unity Catalog.
+
+Roadmap
+*******
+
+1. Volumes integration for reading objects from volumes (e.g. images and documents)
+2. Unity Iceberg integration for reading tables using the Iceberg interface instead of the Delta Lake interface
+
+Please open issues on the Daft repository if you have use cases that Daft does not currently cover!
diff --git a/pyproject.toml b/pyproject.toml
index 78f921cb7a..6ea9f5aa1f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,7 +22,7 @@ readme = "README.rst"
 requires-python = ">=3.8"
 
 [project.optional-dependencies]
-all = ["getdaft[aws, azure, gcp, ray, pandas, numpy, iceberg, deltalake, sql]"]
+all = ["getdaft[aws, azure, gcp, ray, pandas, numpy, iceberg, deltalake, sql, unity]"]
 aws = ["boto3"]
 azure = []
 deltalake = ["deltalake"]
@@ -40,6 +40,7 @@ ray = [
   "packaging"
 ]
 sql = ["connectorx", "sqlalchemy", "sqlglot"]
+unity = ["unitycatalog"]
 viz = []
 
 [project.urls]
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 9aca0f1b51..29b0ef07d3 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -52,6 +52,7 @@ deltalake==0.15.3; platform_system != "Windows" and python_version >= '3.8'
 
 # Databricks
 databricks-sdk==0.12.0
+unitycatalog==0.1.1
 
 #SQL
 sqlalchemy==2.0.25; python_version >= '3.8'
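
Taken together, the new `UnityCatalog` client, the `UnityCatalogTable` dataclass, and the updated `read_deltalake` entrypoint are intended to compose roughly as in the sketch below. This is only an illustration based on the code added above: the workspace endpoint, token, and table names are placeholders, and an S3-backed Delta table is assumed so that the temporary AWS credential path in `load_table` applies.

    import daft
    from daft.unity_catalog import UnityCatalog

    # Placeholder workspace endpoint and token -- substitute your own values.
    unity = UnityCatalog(
        endpoint="https://<databricks_workspace_id>.cloud.databricks.com",
        token="my-token",
    )

    # Optional discovery helpers added in unity_catalog.py.
    print(unity.list_catalogs())
    print(unity.list_tables("my_catalog.my_schema"))

    # load_table returns a UnityCatalogTable carrying the storage location and,
    # for S3-backed tables, an IOConfig built from temporary READ credentials.
    unity_table = unity.load_table("my_catalog.my_schema.my_table")

    # read_deltalake now accepts the UnityCatalogTable directly and, when present,
    # uses the table's IOConfig instead of the default planning config.
    df = daft.read_deltalake(unity_table)
    df = df.where(df["partition_key"] < 1000)  # partition pruning still applies
    df.show()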