Skip to content

Commit

Permalink
[FEAT] Add Unity Catalog support (#2377)
Browse files Browse the repository at this point in the history
1. Adds a new `daft.io.UnityCatalog` class
2. Adds some basic methods on that class, only implements
`list_schemas`, `list_tables` and `load_tables` right now
3. Adds integrations with `daft.read_delta_lake` to make this work
4. Ensure that the `io_config` is correctly propagated by requesting
credentials from unity catalog

<img width="484" alt="image"
src="https://github.com/Eventual-Inc/Daft/assets/17691182/a6c2b670-7d61-4c39-b068-a4c2f207d54c">

---------

Co-authored-by: Jay Chia <[email protected]@users.noreply.github.com>
  • Loading branch information
jaychia and Jay Chia authored Jun 14, 2024
1 parent 041a73a commit 395ebe8
Show file tree
Hide file tree
Showing 8 changed files with 202 additions and 5 deletions.
21 changes: 17 additions & 4 deletions daft/io/_delta_lake.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@
from daft.io.catalog import DataCatalogTable
from daft.logical.builder import LogicalPlanBuilder

_UNITY_CATALOG_AVAILABLE = True
try:
from daft.unity_catalog import UnityCatalogTable
except ImportError:
_UNITY_CATALOG_AVAILABLE = False


def read_delta_lake(
table: Union[str, DataCatalogTable],
Expand All @@ -25,7 +31,7 @@ def read_delta_lake(

@PublicAPI
def read_deltalake(
table: Union[str, DataCatalogTable],
table: Union[str, DataCatalogTable, "UnityCatalogTable"],
io_config: Optional["IOConfig"] = None,
_multithreaded_io: Optional[bool] = None,
) -> DataFrame:
Expand Down Expand Up @@ -56,20 +62,27 @@ def read_deltalake(
"""
from daft.delta_lake.delta_lake_scan import DeltaLakeScanOperator

io_config = context.get_context().daft_planning_config.default_io_config if io_config is None else io_config

# If running on Ray, we want to limit the amount of concurrency and requests being made.
# This is because each Ray worker process receives its own pool of thread workers and connections
multithreaded_io = not context.get_context().is_ray_runner if _multithreaded_io is None else _multithreaded_io

io_config = context.get_context().daft_planning_config.default_io_config if io_config is None else io_config
storage_config = StorageConfig.native(NativeStorageConfig(multithreaded_io, io_config))

if isinstance(table, str):
table_uri = table
elif isinstance(table, DataCatalogTable):
table_uri = table.table_uri(io_config)
elif _UNITY_CATALOG_AVAILABLE and isinstance(table, UnityCatalogTable):
table_uri = table.table_uri

# Override the storage_config with the one provided by Unity catalog
table_io_config = table.io_config
if table_io_config is not None:
storage_config = StorageConfig.native(NativeStorageConfig(multithreaded_io, table_io_config))
else:
raise ValueError(
f"table argument must be a table URI string or a DataCatalogTable instance, but got: {type(table)}, {table}"
f"table argument must be a table URI string, DataCatalogTable or UnityCatalogTable instance, but got: {type(table)}, {table}"
)
delta_lake_operator = DeltaLakeScanOperator(table_uri, storage_config=storage_config)

Expand Down
3 changes: 3 additions & 0 deletions daft/unity_catalog/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .unity_catalog import UnityCatalog, UnityCatalogTable

__all__ = ["UnityCatalog", "UnityCatalogTable"]
115 changes: 115 additions & 0 deletions daft/unity_catalog/unity_catalog.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
from __future__ import annotations

import dataclasses
from typing import Callable

import unitycatalog

from daft.io import IOConfig, S3Config


@dataclasses.dataclass(frozen=True)
class UnityCatalogTable:
table_uri: str
io_config: IOConfig | None


class UnityCatalog:
"""Client to access the Unity Catalog
Unity Catalog is an open-sourced data catalog that can be self-hosted, or hosted by Databricks.
Example of reading a dataframe from a table in Unity Catalog hosted by Databricks:
>>> cat = UnityCatalog("https://<databricks_workspace_id>.cloud.databricks.com", token="my-token")
>>> table = cat.load_table("my_catalog.my_schema.my_table")
>>> df = daft.read_delta_lake(table)
"""

def __init__(self, endpoint: str, token: str | None = None):
self._client = unitycatalog.Unitycatalog(
base_url=endpoint.rstrip("/") + "/api/2.1/unity-catalog/",
default_headers={"Authorization": f"Bearer {token}"},
)

def _paginate_to_completion(
self,
client_func_call: Callable[[unitycatalog.Unitycatalog, str | None], tuple[list[str] | None, str | None]],
) -> list[str]:
results = []

# Make first request
new_results, next_page_token = client_func_call(self._client, None)
if new_results is not None:
results.extend(new_results)

# Exhaust pages
while next_page_token is not None and next_page_token != "":
new_results, next_page_token = client_func_call(self._client, next_page_token)
if new_results is not None:
results.extend(new_results)

return results

def list_catalogs(self) -> list[str]:
def _paginated_list_catalogs(client: unitycatalog.Unitycatalog, page_token: str | None):
response = client.catalogs.list(page_token=page_token)
next_page_token = response.next_page_token
if response.catalogs is None:
return None, next_page_token
return [c.name for c in response.catalogs], next_page_token

return self._paginate_to_completion(_paginated_list_catalogs)

def list_schemas(self, catalog_name: str) -> list[str]:
def _paginated_list_schemas(client: unitycatalog.Unitycatalog, page_token: str | None):
response = client.schemas.list(catalog_name=catalog_name, page_token=page_token)
next_page_token = response.next_page_token
if response.schemas is None:
return None, next_page_token
return [s.full_name for s in response.schemas], next_page_token

return self._paginate_to_completion(_paginated_list_schemas)

def list_tables(self, schema_name: str):
if schema_name.count(".") != 1:
raise ValueError(
f"Expected fully-qualified schema name with format `catalog_name`.`schema_name`, but received: {schema_name}"
)

catalog_name, schema_name = schema_name.split(".")

def _paginated_list_tables(client: unitycatalog.Unitycatalog, page_token: str | None):
response = client.tables.list(catalog_name=catalog_name, schema_name=schema_name, page_token=page_token)
next_page_token = response.next_page_token
if response.tables is None:
return None, next_page_token
return [f"{t.catalog_name}.{t.schema_name}.{t.name}" for t in response.tables], next_page_token

return self._paginate_to_completion(_paginated_list_tables)

def load_table(self, table_name: str) -> UnityCatalogTable:
# Load the table ID
table_info = self._client.tables.retrieve(table_name)
table_id = table_info.table_id
storage_location = table_info.storage_location

# Grab credentials from Unity catalog and place it into the Table
temp_table_credentials = self._client.temporary_table_credentials.create(operation="READ", table_id=table_id)
aws_temp_credentials = temp_table_credentials.aws_temp_credentials
io_config = (
IOConfig(
s3=S3Config(
key_id=aws_temp_credentials.access_key_id,
access_key=aws_temp_credentials.secret_access_key,
session_token=aws_temp_credentials.session_token,
)
)
if aws_temp_credentials is not None
else None
)

return UnityCatalogTable(
table_uri=storage_location,
io_config=io_config,
)
1 change: 1 addition & 0 deletions docs/source/user_guide/integrations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Integrations
.. toctree::

integrations/ray
integrations/unity-catalog
integrations/iceberg
integrations/delta_lake
integrations/hudi
Expand Down
1 change: 1 addition & 0 deletions docs/source/user_guide/integrations/delta_lake.rst
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ Write to Delta Lake
You can use `write_deltalake` to write a Daft DataFrame to a Delta table:

.. code:: python
df.write_deltalake("tmp/daft-table", mode="overwrite")
Expand Down
62 changes: 62 additions & 0 deletions docs/source/user_guide/integrations/unity-catalog.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
Unity Catalog
=============

`Unity Catalog <https://github.com/unitycatalog/unitycatalog//>`_ is an open-sourced catalog developed by Databricks.
Users of Unity Catalog are able to work with data assets such as tables (Parquet, CSV, Iceberg, Delta), volumes
(storing raw files), functions and models.

.. WARNING::

These APIs are in beta and may be subject to change as the Unity Catalog continues to be developed.

Connecting to the Unity Catalog
*******************************

Daft includes an abstraction for the Unity Catalog.

.. code:: python
from daft.unity_catalog import UnityCatalog
unity = UnityCatalog(
endpoint="https://<databricks_workspace_id>.cloud.databricks.com",
# Authentication can be retrieved from your provider of Unity Catalog
token="my-token",
)
# See all available catalogs
print(unity.list_catalogs())
# See available schemas in a given catalog
print(unity.list_schemas("my_catalog_name"))
# See available tables in a given schema
print(unity.list_tables("my_catalog_name.my_schema_name"))
Loading a Daft Dataframe from a Delta Lake table in Unity Catalog
*****************************************************************

.. code:: python
unity_table = unity.load_table("my_catalog_name.my_schema_name.my_table_name")
df = daft.read_delta_lake(unity_table)
df.show()
Any subsequent filter operations on the Daft ``df`` DataFrame object will be correctly optimized to take advantage of DeltaLake features

.. code:: python
# Filter which takes advantage of partition pruning capabilities of Delta Lake
df = df.where(df["partition_key"] < 1000)
df.show()
See also :doc:`delta_lake` for more information about how to work with the Delta Lake tables provided by the Unity Catalog.

Roadmap
*******

1. Volumes integration for reading objects from volumes (e.g. images and documents)
2. Unity Iceberg integration for reading tables using the Iceberg interface instead of the Delta Lake interface

Please make issues on the Daft repository if you have any use-cases that Daft does not currently cover!
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ readme = "README.rst"
requires-python = ">=3.8"

[project.optional-dependencies]
all = ["getdaft[aws, azure, gcp, ray, pandas, numpy, iceberg, deltalake, sql]"]
all = ["getdaft[aws, azure, gcp, ray, pandas, numpy, iceberg, deltalake, sql, unity]"]
aws = ["boto3"]
azure = []
deltalake = ["deltalake"]
Expand All @@ -40,6 +40,7 @@ ray = [
"packaging"
]
sql = ["connectorx", "sqlalchemy", "sqlglot"]
unity = ["unitycatalog"]
viz = []

[project.urls]
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ deltalake==0.15.3; platform_system != "Windows" and python_version >= '3.8'

# Databricks
databricks-sdk==0.12.0
unitycatalog==0.1.1

#SQL
sqlalchemy==2.0.25; python_version >= '3.8'
Expand Down

0 comments on commit 395ebe8

Please sign in to comment.