From bb506a4aaaab76ba0c86cb48d0be587ee0af9bd7 Mon Sep 17 00:00:00 2001
From: Jay Chia <17691182+jaychia@users.noreply.github.com>
Date: Sun, 17 Nov 2024 13:23:53 -0800
Subject: [PATCH] [FEAT] Daft Catalog API (#3036)

Adds a `DaftCatalog` API to help cement Daft's catalog data access patterns.

Here is the intended UX:

```python
import daft

###
# Registering external catalogs
#
# We recognize `PyIceberg.Catalog`s and `UnityCatalog` for now
# TODO: This should be configurable via a Daft catalog config file (.toml or .yaml)
###

from pyiceberg.catalog import load_catalog

catalog = load_catalog(...)
daft.catalog.register_python_catalog(catalog)

###
# Adding named tables
###

df = daft.from_pydict({"foo": [1, 2, 3]})
daft.register_table("foo", df)

###
# Reading tables
###

df1 = daft.read_table("foo")  # first priority is named tables
df2 = daft.read_table("x.y.z")  # next priority is the registered default catalog
df2 = daft.read_table("default.x.y.z")  # equivalent to the previous call
df3 = daft.read_table("my_other_catalog.x.y.z")  # Supports named catalogs other than the default one
```

Other APIs which would be nice as follow-ons:

- [ ] Integrate this with the SQLCatalog API that our SQL stuff uses
- [ ] Detection of catalog from a YAML `~/.daft.yaml` config file
- [ ] Allow for configuring table access (e.g. `daft.read_table("iceberg_table", options=daft.catalog.IcebergReadOptions(...))`)
- [ ] Implementations for other catalogs that aren't Python catalogs and can support other table types (e.g. Hive and Delta):
  - [ ] `daft.catalog.register_aws_glue()`
  - [ ] `daft.catalog.register_hive_metastore()`
- [ ] `df.write_table("table_name", mode="overwrite|append", create_if_missing=True)`
- [ ] `df.upsert("table_name", match_columns={...}, update_columns={...}, insert_columns={...})`
- [ ] DDL: allow for easy creation of tables, erroring out if the selected backend does not support a given table format
  - [ ] `daft.catalog.create_table_parquet(...)`
  - [ ] `daft.catalog.create_table_iceberg(...)`
  - [ ] `daft.catalog.create_table_deltalake(...)`
- [ ] `daft.catalog.list_tables(...)`

---------

Co-authored-by: Jay Chia
---
 Cargo.lock                                    |  24 ++
 Cargo.toml                                    |  10 +
 daft/__init__.py                              |   3 +
 daft/catalog/__init__.py                      | 151 +++++++++++++
 daft/catalog/pyiceberg.py                     |  32 +++
 daft/catalog/python_catalog.py                |  24 ++
 daft/catalog/unity.py                         |  49 +++++
 daft/daft/catalog.pyi                         |  11 +
 daft/io/catalog.py                            |   9 +
 src/common/error/src/error.rs                 |   2 +
 src/daft-catalog/Cargo.toml                   |  15 ++
 src/daft-catalog/python-catalog/Cargo.toml    |  14 ++
 src/daft-catalog/python-catalog/src/lib.rs    |   2 +
 src/daft-catalog/python-catalog/src/python.rs | 168 ++++++++++++++
 src/daft-catalog/src/data_catalog.rs          |  13 ++
 src/daft-catalog/src/data_catalog_table.rs    |  11 +
 src/daft-catalog/src/errors.rs                |  57 +++++
 src/daft-catalog/src/lib.rs                   | 208 ++++++++++++++++++
 src/daft-catalog/src/python.rs                |  97 ++++++++
 src/lib.rs                                    |   4 +
 tests/integration/iceberg/conftest.py         |  62 +++---
 tests/integration/iceberg/test_cloud_load.py  |   7 +-
 .../iceberg/test_partition_pruning.py         |  28 ++-
 .../test_pyiceberg_written_table_load.py      |  32 +--
 tests/integration/iceberg/test_table_load.py  |  67 +++---
 25 files changed, 1021 insertions(+), 79 deletions(-)
 create mode 100644 daft/catalog/__init__.py
 create mode 100644 daft/catalog/pyiceberg.py
 create mode 100644 daft/catalog/python_catalog.py
 create mode 100644 daft/catalog/unity.py
 create mode 100644 daft/daft/catalog.pyi
 create mode 100644 src/daft-catalog/Cargo.toml
 create mode 100644 src/daft-catalog/python-catalog/Cargo.toml
 create mode 100644
src/daft-catalog/python-catalog/src/lib.rs create mode 100644 src/daft-catalog/python-catalog/src/python.rs create mode 100644 src/daft-catalog/src/data_catalog.rs create mode 100644 src/daft-catalog/src/data_catalog_table.rs create mode 100644 src/daft-catalog/src/errors.rs create mode 100644 src/daft-catalog/src/lib.rs create mode 100644 src/daft-catalog/src/python.rs diff --git a/Cargo.lock b/Cargo.lock index bf9dfaeb47..2d1afcbe3e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1860,6 +1860,8 @@ dependencies = [ "common-system-info", "common-tracing", "common-version", + "daft-catalog", + "daft-catalog-python-catalog", "daft-compression", "daft-connect", "daft-core", @@ -1894,6 +1896,28 @@ dependencies = [ "tikv-jemallocator", ] +[[package]] +name = "daft-catalog" +version = "0.3.0-dev0" +dependencies = [ + "common-error", + "daft-core", + "daft-logical-plan", + "lazy_static", + "pyo3", + "snafu", +] + +[[package]] +name = "daft-catalog-python-catalog" +version = "0.3.0-dev0" +dependencies = [ + "daft-catalog", + "daft-logical-plan", + "pyo3", + "snafu", +] + [[package]] name = "daft-compression" version = "0.3.0-dev0" diff --git a/Cargo.toml b/Cargo.toml index e050ab5368..cf16c58ca9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,8 @@ common-scan-info = {path = "src/common/scan-info", default-features = false} common-system-info = {path = "src/common/system-info", default-features = false} common-tracing = {path = "src/common/tracing", default-features = false} common-version = {path = "src/common/version", default-features = false} +daft-catalog = {path = "src/daft-catalog", default-features = false} +daft-catalog-python-catalog = {path = "src/daft-catalog/python-catalog", optional = true} daft-compression = {path = "src/daft-compression", default-features = false} daft-connect = {path = "src/daft-connect", optional = true} daft-core = {path = "src/daft-core", default-features = false} @@ -52,6 +54,8 @@ python = [ "common-display/python", "common-resource-request/python", "common-system-info/python", + "daft-catalog/python", + "daft-catalog-python-catalog/python", "daft-connect/python", "daft-core/python", "daft-csv/python", @@ -75,6 +79,11 @@ python = [ "daft-writers/python", "daft-table/python", "dep:daft-connect", + "common-daft-config/python", + "common-system-info/python", + "common-display/python", + "common-resource-request/python", + "dep:daft-catalog-python-catalog", "dep:pyo3", "dep:pyo3-log" ] @@ -140,6 +149,7 @@ members = [ "src/common/scan-info", "src/common/system-info", "src/common/treenode", + "src/daft-catalog", "src/daft-core", "src/daft-csv", "src/daft-dsl", diff --git a/daft/__init__.py b/daft/__init__.py index 0e145115c9..add31fc9e2 100644 --- a/daft/__init__.py +++ b/daft/__init__.py @@ -58,6 +58,7 @@ def refresh_logger() -> None: # Daft top-level imports ### +from daft.catalog import read_table, register_table from daft.context import set_execution_config, set_planning_config, execution_config_ctx, planning_config_ctx from daft.convert import ( from_arrow, @@ -129,6 +130,8 @@ def refresh_logger() -> None: "set_execution_config", "planning_config_ctx", "execution_config_ctx", + "read_table", + "register_table", "sql", "sql_expr", "to_struct", diff --git a/daft/catalog/__init__.py b/daft/catalog/__init__.py new file mode 100644 index 0000000000..438fd369d0 --- /dev/null +++ b/daft/catalog/__init__.py @@ -0,0 +1,151 @@ +"""The `daft.catalog` module contains functionality for Data Catalogs. 
+
+A Data Catalog can be understood as a system/service for users to discover, access and query their data.
+Most commonly, users' data is represented as a "table". Some more modern Data Catalogs such as Unity Catalog
+also expose other types of data including files, ML models, registered functions and more.
+
+Examples of Data Catalogs include AWS Glue, Hive Metastore, Apache Iceberg REST and Unity Catalog.
+
+Daft manages Data Catalogs by registering them in an internal meta-catalog, called the "DaftMetaCatalog". This
+is simply a collection of data catalogs, which Daft will attempt to detect from a user's current environment.
+
+**Default Catalog**
+
+Daft recognizes a default catalog which it will attempt to use when no specific catalog name is provided.
+
+```python
+# This will hit the default catalog
+daft.read_table("my_db.my_namespace.my_table")
+```
+
+**Named Tables**
+
+Daft also allows the registration of named tables, which have no catalog associated with them.
+
+Note that named tables take precedence over the default catalog's table names when resolving names.
+
+```python
+df = daft.from_pydict({"foo": [1, 2, 3]})
+
+daft.catalog.register_table(
+    "my_table",
+    df,
+)
+
+# Your table is now accessible from Daft-SQL, or Daft's `read_table`
+df1 = daft.read_table("my_table")
+df2 = daft.sql("SELECT * FROM my_table")
+```
+"""
+
+from __future__ import annotations
+
+from daft.daft import catalog as native_catalog
+from daft.logical.builder import LogicalPlanBuilder
+
+from daft.dataframe import DataFrame
+
+_PYICEBERG_AVAILABLE = False
+try:
+    from pyiceberg.catalog import Catalog as PyIcebergCatalog
+
+    _PYICEBERG_AVAILABLE = True
+except ImportError:
+    pass
+
+_UNITY_AVAILABLE = False
+try:
+    from daft.unity_catalog import UnityCatalog
+
+    _UNITY_AVAILABLE = True
+except ImportError:
+    pass
+
+__all__ = [
+    "read_table",
+    "register_python_catalog",
+    "unregister_catalog",
+    "register_table",
+]
+
+# Forward imports from the native catalog which don't require Python wrappers
+unregister_catalog = native_catalog.unregister_catalog
+
+
+def read_table(name: str) -> DataFrame:
+    """Finds a table with the specified name and reads it as a DataFrame.
+
+    The provided name can be any of the following, and Daft will resolve it with the following order of priority:
+
+    1. Name of a registered dataframe/SQL view (manually registered using `daft.register_table`): `"my_registered_table"`
+    2. Name of a table within the default catalog (without the catalog name), for example: `"my.table.name"`
+    3. Fully-qualified table path including the catalog name, for example: `"my_catalog.my.table.name"`
+
+    Args:
+        name: The identifier for the table to read
+
+    Returns:
+        A DataFrame containing the data from the specified table.
+    """
+    native_logical_plan_builder = native_catalog.read_table(name)
+    return DataFrame(LogicalPlanBuilder(native_logical_plan_builder))
+
+
+def register_table(name: str, dataframe: DataFrame) -> str:
+    """Register a DataFrame as a named table.
+
+    This function registers a DataFrame as a named table, making it accessible
+    via Daft-SQL or Daft's `read_table` function.
+
+    Args:
+        name (str): The name to register the table under.
+        dataframe (daft.DataFrame): The DataFrame to register as a table.
+
+    Returns:
+        str: The name of the registered table.
+
+    Example:
+        >>> df = daft.from_pydict({"foo": [1, 2, 3]})
+        >>> daft.catalog.register_table("my_table", df)
+        >>> daft.read_table("my_table")
+    """
+    return native_catalog.register_table(name, dataframe._builder._builder)
+
+
+def register_python_catalog(catalog: PyIcebergCatalog | UnityCatalog, name: str | None = None) -> str:
+    """Registers a Python catalog with Daft.
+
+    Currently supports:
+
+    * [PyIceberg Catalogs](https://py.iceberg.apache.org/api/)
+    * [Unity Catalog](https://www.getdaft.io/projects/docs/en/latest/user_guide/integrations/unity-catalog.html)
+
+    Args:
+        catalog (PyIcebergCatalog | UnityCatalog): The Python catalog to register.
+        name (str | None, optional): The name to register the catalog under. If None, this catalog is registered as the default catalog.
+
+    Returns:
+        str: The name of the registered catalog.
+
+    Raises:
+        ValueError: If an unsupported catalog type is provided.
+
+    Example:
+        >>> from pyiceberg.catalog import load_catalog
+        >>> catalog = load_catalog("my_catalog")
+        >>> daft.catalog.register_python_catalog(catalog, "my_daft_catalog")
+
+    """
+    python_catalog: PythonCatalog
+    if _PYICEBERG_AVAILABLE and isinstance(catalog, PyIcebergCatalog):
+        from daft.catalog.pyiceberg import PyIcebergCatalogAdaptor
+
+        python_catalog = PyIcebergCatalogAdaptor(catalog)
+    elif _UNITY_AVAILABLE and isinstance(catalog, UnityCatalog):
+        from daft.catalog.unity import UnityCatalogAdaptor
+
+        python_catalog = UnityCatalogAdaptor(catalog)
+    else:
+        raise ValueError(f"Unsupported catalog type: {type(catalog)}")
+
+    return native_catalog.register_python_catalog(python_catalog, name)
diff --git a/daft/catalog/pyiceberg.py b/daft/catalog/pyiceberg.py
new file mode 100644
index 0000000000..bde293a488
--- /dev/null
+++ b/daft/catalog/pyiceberg.py
@@ -0,0 +1,32 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from pyiceberg.catalog import Catalog as PyIcebergCatalog
+    from pyiceberg.table import Table as PyIcebergTable
+
+    from daft.dataframe import DataFrame
+
+from daft.catalog.python_catalog import PythonCatalog, PythonCatalogTable
+
+
+class PyIcebergCatalogAdaptor(PythonCatalog):
+    def __init__(self, pyiceberg_catalog: PyIcebergCatalog):
+        self._catalog = pyiceberg_catalog
+
+    def list_tables(self, prefix: str) -> list[str]:
+        return [".".join(tup) for tup in self._catalog.list_tables(prefix)]
+
+    def load_table(self, name: str) -> PyIcebergTableAdaptor:
+        return PyIcebergTableAdaptor(self._catalog.load_table(name))
+
+
+class PyIcebergTableAdaptor(PythonCatalogTable):
+    def __init__(self, pyiceberg_table: PyIcebergTable):
+        self._table = pyiceberg_table
+
+    def to_dataframe(self) -> DataFrame:
+        import daft
+
+        return daft.read_iceberg(self._table)
diff --git a/daft/catalog/python_catalog.py b/daft/catalog/python_catalog.py
new file mode 100644
index 0000000000..dc911f3766
--- /dev/null
+++ b/daft/catalog/python_catalog.py
@@ -0,0 +1,24 @@
+from __future__ import annotations
+
+from abc import abstractmethod
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from daft.dataframe import DataFrame
+
+
+class PythonCatalog:
+    """Wrapper class for various Python implementations of Data Catalogs"""
+
+    @abstractmethod
+    def list_tables(self, prefix: str) -> list[str]: ...
+
+    @abstractmethod
+    def load_table(self, name: str) -> PythonCatalogTable: ...
+ + +class PythonCatalogTable: + """Wrapper class for various Python implementations of Data Catalog Tables""" + + @abstractmethod + def to_dataframe(self) -> DataFrame: ... diff --git a/daft/catalog/unity.py b/daft/catalog/unity.py new file mode 100644 index 0000000000..aed0b0233e --- /dev/null +++ b/daft/catalog/unity.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from daft.dataframe import DataFrame + from daft.unity_catalog import UnityCatalog, UnityCatalogTable + +from daft.catalog.python_catalog import PythonCatalog, PythonCatalogTable + + +class UnityCatalogAdaptor(PythonCatalog): + def __init__(self, unity_catalog: UnityCatalog): + self._catalog = unity_catalog + + def list_tables(self, prefix: str) -> list[str]: + num_namespaces = prefix.count(".") + if prefix == "": + return [ + tbl + for cat in self._catalog.list_catalogs() + for schema in self._catalog.list_schemas(cat) + for tbl in self._catalog.list_tables(schema) + ] + elif num_namespaces == 0: + catalog_name = prefix + return [ + tbl for schema in self._catalog.list_schemas(catalog_name) for tbl in self._catalog.list_tables(schema) + ] + elif num_namespaces == 1: + schema_name = prefix + return [tbl for tbl in self._catalog.list_tables(schema_name)] + else: + raise ValueError( + f"Unrecognized catalog name or schema name, expected a '.'-separated namespace but received: {prefix}" + ) + + def load_table(self, name: str) -> UnityTableAdaptor: + return UnityTableAdaptor(self._catalog.load_table(name)) + + +class UnityTableAdaptor(PythonCatalogTable): + def __init__(self, unity_table: UnityCatalogTable): + self._table = unity_table + + def to_dataframe(self) -> DataFrame: + import daft + + return daft.read_deltalake(self._table) diff --git a/daft/daft/catalog.pyi b/daft/daft/catalog.pyi new file mode 100644 index 0000000000..55c5cb50c8 --- /dev/null +++ b/daft/daft/catalog.pyi @@ -0,0 +1,11 @@ +from typing import TYPE_CHECKING + +from daft.daft import LogicalPlanBuilder as PyLogicalPlanBuilder + +if TYPE_CHECKING: + from daft.catalog.python_catalog import PythonCatalog + +def read_table(name: str) -> PyLogicalPlanBuilder: ... +def register_table(name: str, plan_builder: PyLogicalPlanBuilder) -> str: ... +def register_python_catalog(catalog: PythonCatalog, catalog_name: str | None) -> str: ... +def unregister_catalog(catalog_name: str | None) -> bool: ... diff --git a/daft/io/catalog.py b/daft/io/catalog.py index 62cb16e672..859e69c6c0 100644 --- a/daft/io/catalog.py +++ b/daft/io/catalog.py @@ -34,6 +34,15 @@ class DataCatalogTable: table_name: str catalog_id: Optional[str] = None + def __post_init__(self): + import warnings + + warnings.warn( + "This API will soon be deprecated. Users should use the new functionality in daft.catalog.", + DeprecationWarning, + stacklevel=2, + ) + def table_uri(self, io_config: IOConfig) -> str: """ Get the URI of the table in the data catalog. 
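
The two ABCs in `daft/catalog/python_catalog.py` above are the full contract that the Rust side calls into: `list_tables(prefix)` returns fully-qualified table names, and `load_table(name)` returns an object that can materialize a DataFrame. As a quick illustration of that contract (not part of this diff; `InMemoryCatalog` and `InMemoryTable` are hypothetical names), a third adaptor could be as small as:

```python
from __future__ import annotations

import daft
from daft.catalog.python_catalog import PythonCatalog, PythonCatalogTable


class InMemoryTable(PythonCatalogTable):
    """Hypothetical table that wraps an already-materialized DataFrame."""

    def __init__(self, df: daft.DataFrame):
        self._df = df

    def to_dataframe(self) -> daft.DataFrame:
        return self._df


class InMemoryCatalog(PythonCatalog):
    """Hypothetical catalog backed by a plain dict of DataFrames."""

    def __init__(self, tables: dict[str, daft.DataFrame]):
        self._tables = tables

    def list_tables(self, prefix: str) -> list[str]:
        # The contract: return the fully-qualified names matching the prefix
        return [name for name in self._tables if name.startswith(prefix)]

    def load_table(self, name: str) -> InMemoryTable:
        return InMemoryTable(self._tables[name])
```

Note that the public `register_python_catalog` wrapper above deliberately dispatches only on PyIceberg and Unity types and raises `ValueError` for anything else, so an object like this would have to be registered through the same native `daft.daft.catalog.register_python_catalog` hook that the built-in adaptors use.
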
diff --git a/src/common/error/src/error.rs b/src/common/error/src/error.rs
index 6553abc388..e0f4835852 100644
--- a/src/common/error/src/error.rs
+++ b/src/common/error/src/error.rs
@@ -50,6 +50,8 @@ pub enum DaftError {
     FromUtf8Error(#[from] std::string::FromUtf8Error),
     #[error("Not Yet Implemented: {0}")]
     NotImplemented(String),
+    #[error("DaftError::CatalogError {0}")]
+    CatalogError(String),
 }
 
 impl DaftError {
diff --git a/src/daft-catalog/Cargo.toml b/src/daft-catalog/Cargo.toml
new file mode 100644
index 0000000000..2e2870e134
--- /dev/null
+++ b/src/daft-catalog/Cargo.toml
@@ -0,0 +1,15 @@
+[dependencies]
+common-error = {path = "../common/error", default-features = false}
+daft-core = {path = "../daft-core", default-features = false}
+daft-logical-plan = {path = "../daft-logical-plan", default-features = false}
+lazy_static = {workspace = true}
+pyo3 = {workspace = true, optional = true}
+snafu.workspace = true
+
+[features]
+python = ["dep:pyo3", "common-error/python", "daft-logical-plan/python", "daft-core/python"]
+
+[package]
+name = "daft-catalog"
+edition.workspace = true
+version.workspace = true
diff --git a/src/daft-catalog/python-catalog/Cargo.toml b/src/daft-catalog/python-catalog/Cargo.toml
new file mode 100644
index 0000000000..1893bbd4ee
--- /dev/null
+++ b/src/daft-catalog/python-catalog/Cargo.toml
@@ -0,0 +1,14 @@
+[dependencies]
+daft-catalog = {path = "..", default-features = false}
+daft-logical-plan = {path = "../../daft-logical-plan", default-features = false}
+pyo3 = {workspace = true, optional = true}
+snafu = {workspace = true}
+
+[features]
+python = ["daft-catalog/python", "daft-logical-plan/python", "dep:pyo3"]
+
+[package]
+description = "Python implementations of Daft DataCatalogTable (backed by a PythonCatalog abstract class)"
+name = "daft-catalog-python-catalog"
+edition.workspace = true
+version.workspace = true
diff --git a/src/daft-catalog/python-catalog/src/lib.rs b/src/daft-catalog/python-catalog/src/lib.rs
new file mode 100644
index 0000000000..a098fa319b
--- /dev/null
+++ b/src/daft-catalog/python-catalog/src/lib.rs
@@ -0,0 +1,2 @@
+#[cfg(feature = "python")]
+pub mod python;
diff --git a/src/daft-catalog/python-catalog/src/python.rs b/src/daft-catalog/python-catalog/src/python.rs
new file mode 100644
index 0000000000..f9de6fffcc
--- /dev/null
+++ b/src/daft-catalog/python-catalog/src/python.rs
@@ -0,0 +1,168 @@
+use std::sync::Arc;
+
+use daft_catalog::{
+    errors::{Error as DaftCatalogError, Result},
+    DataCatalog, DataCatalogTable,
+};
+use daft_logical_plan::{LogicalPlanBuilder, PyLogicalPlanBuilder};
+use pyo3::prelude::*;
+use snafu::{ResultExt, Snafu};
+
+#[derive(Debug, Snafu)]
+pub enum Error {
+    #[snafu(display("Error while listing tables: {}", source))]
+    ListTables { source: pyo3::PyErr },
+
+    #[snafu(display("Error while getting table {}: {}", table_name, source))]
+    GetTable {
+        source: pyo3::PyErr,
+        table_name: String,
+    },
+
+    #[snafu(display("Error while reading table {} into Daft: {}", table_name, source))]
+    ReadTable {
+        source: pyo3::PyErr,
+        table_name: String,
+    },
+}
+
+impl From<Error> for DaftCatalogError {
+    fn from(value: Error) -> Self {
+        match value {
+            Error::ListTables { source } => DaftCatalogError::PythonError {
+                source,
+                context: "listing tables".to_string(),
+            },
+            Error::GetTable { source, table_name } => DaftCatalogError::PythonError {
+                source,
+                context: format!("getting table `{}`", table_name),
+            },
+            Error::ReadTable { source, table_name } => DaftCatalogError::PythonError {
+                source,
+                context: format!("reading table `{}`", table_name),
+            },
+        }
+    }
+}
+
+/// Wrapper around a `daft.catalog.python_catalog.PythonCatalogTable`
+pub struct PythonTable {
+    table_name: String,
+    table_pyobj: PyObject,
+}
+
+impl PythonTable {
+    pub fn new(table_name: String, table_pyobj: PyObject) -> Self {
+        Self {
+            table_name,
+            table_pyobj,
+        }
+    }
+}
+
+impl DataCatalogTable for PythonTable {
+    fn to_logical_plan_builder(&self) -> daft_catalog::errors::Result<LogicalPlanBuilder> {
+        Python::with_gil(|py| {
+            let dataframe = self
+                .table_pyobj
+                .bind(py)
+                .getattr("to_dataframe")
+                .and_then(|to_dataframe_method| to_dataframe_method.call0())
+                .with_context(|_| ReadTableSnafu {
+                    table_name: self.table_name.clone(),
+                })?;
+
+            let py_logical_plan_builder = dataframe
+                .getattr("_builder")
+                .unwrap()
+                .getattr("_builder")
+                .unwrap();
+            let py_logical_plan_builder = py_logical_plan_builder
+                .downcast::<PyLogicalPlanBuilder>()
+                .unwrap();
+            Ok(py_logical_plan_builder.borrow().builder.clone())
+        })
+    }
+}
+
+/// Wrapper around a `daft.catalog.python_catalog.PythonCatalog`
+pub struct PythonCatalog {
+    python_catalog_pyobj: PyObject,
+}
+
+impl PythonCatalog {
+    pub fn new(python_catalog_pyobj: PyObject) -> Self {
+        Self {
+            python_catalog_pyobj,
+        }
+    }
+}
+
+impl DataCatalog for PythonCatalog {
+    fn list_tables(&self, prefix: &str) -> Result<Vec<String>> {
+        Python::with_gil(|py| {
+            let python_catalog = self.python_catalog_pyobj.bind(py);
+
+            Ok(python_catalog
+                .getattr("list_tables")
+                .and_then(|list_tables_method| list_tables_method.call1((prefix,)))
+                .and_then(|tables| tables.extract::<Vec<String>>())
+                .with_context(|_| ListTablesSnafu)?)
+        })
+    }
+
+    fn get_table(&self, name: &str) -> Result<Option<Box<dyn DataCatalogTable>>> {
+        Python::with_gil(|py| {
+            let python_catalog = self.python_catalog_pyobj.bind(py);
+            let load_table_method =
+                python_catalog
+                    .getattr("load_table")
+                    .with_context(|_| GetTableSnafu {
+                        table_name: name.to_string(),
+                    })?;
+            let table = load_table_method
+                .call1((name,))
+                .with_context(|_| GetTableSnafu {
+                    table_name: name.to_string(),
+                })?;
+            let python_table = PythonTable::new(name.to_string(), table.unbind());
+            Ok(Some(Box::new(python_table) as Box<dyn DataCatalogTable>))
+        })
+    }
+}
+
+/// Registers a PythonCatalog instance
+///
+/// This function registers a Python-based catalog with the Daft catalog system.
+///
+/// Args:
+///     python_catalog_obj (PyObject): The Python catalog object to register.
+///     catalog_name (Optional[str]): The name to give the catalog. If None, a default name will be used.
+///
+/// Returns:
+///     str: The name of the registered catalog (always "default" in the current implementation).
+///
+/// Raises:
+///     PyErr: If there's an error during the registration process.
+///
+/// Example:
+///     >>> import daft
+///     >>> from my_python_catalog import MyPythonCatalog
+///     >>> python_catalog = MyPythonCatalog()
+///     >>> daft.register_python_catalog(python_catalog, "my_catalog")
+///     'default'
+#[pyfunction]
+#[pyo3(name = "register_python_catalog")]
+pub fn py_register_python_catalog(
+    python_catalog_obj: PyObject,
+    catalog_name: Option<&str>,
+) -> PyResult<String> {
+    let catalog = PythonCatalog::new(python_catalog_obj);
+    daft_catalog::global_catalog::register_catalog(Arc::new(catalog), catalog_name);
+    Ok("default".to_string())
+}
+
+pub fn register_modules(parent: &Bound<PyModule>) -> PyResult<()> {
+    parent.add_wrapped(wrap_pyfunction!(py_register_python_catalog))?;
+    Ok(())
+}
diff --git a/src/daft-catalog/src/data_catalog.rs b/src/daft-catalog/src/data_catalog.rs
new file mode 100644
index 0000000000..0cf6d45708
--- /dev/null
+++ b/src/daft-catalog/src/data_catalog.rs
@@ -0,0 +1,13 @@
+use crate::{data_catalog_table::DataCatalogTable, errors::Result};
+
+/// DataCatalog is a catalog of data sources
+///
+/// It allows registering and retrieving data sources, as well as querying their schemas.
+/// The catalog is used by the query planner to resolve table references in queries.
+pub trait DataCatalog: Sync + Send {
+    /// Lists the fully-qualified names of tables in the catalog with the specified prefix
+    fn list_tables(&self, prefix: &str) -> Result<Vec<String>>;
+
+    /// Retrieves a [`DataCatalogTable`] from this [`DataCatalog`] if it exists
+    fn get_table(&self, name: &str) -> Result<Option<Box<dyn DataCatalogTable>>>;
+}
diff --git a/src/daft-catalog/src/data_catalog_table.rs b/src/daft-catalog/src/data_catalog_table.rs
new file mode 100644
index 0000000000..ba9dcb00f3
--- /dev/null
+++ b/src/daft-catalog/src/data_catalog_table.rs
@@ -0,0 +1,11 @@
+use daft_logical_plan::LogicalPlanBuilder;
+
+use crate::errors;
+
+/// A Table in a Data Catalog
+///
+/// This is a trait because there are many different implementations of this, for example
+/// Iceberg, DeltaLake, Hive and more.
+pub trait DataCatalogTable {
+    fn to_logical_plan_builder(&self) -> errors::Result<LogicalPlanBuilder>;
+}
diff --git a/src/daft-catalog/src/errors.rs b/src/daft-catalog/src/errors.rs
new file mode 100644
index 0000000000..2c6c1de6fe
--- /dev/null
+++ b/src/daft-catalog/src/errors.rs
@@ -0,0 +1,57 @@
+use snafu::Snafu;
+
+#[derive(Debug, Snafu)]
+pub enum Error {
+    #[snafu(display(
+        "Failed to find specified table identifier {} in the requested catalog {}",
+        table_id,
+        catalog_name
+    ))]
+    TableNotFound {
+        catalog_name: String,
+        table_id: String,
+    },
+
+    #[snafu(display("Catalog not found: {}", name))]
+    CatalogNotFound { name: String },
+
+    #[snafu(display(
+        "Invalid table name (expected only alphanumeric characters and '_'): {}",
+        name
+    ))]
+    InvalidTableName { name: String },
+
+    #[cfg(feature = "python")]
+    #[snafu(display("Python error during {}: {}", context, source))]
+    PythonError {
+        source: pyo3::PyErr,
+        context: String,
+    },
+}
+
+pub type Result<T, E = Error> = std::result::Result<T, E>;
+
+impl From<Error> for common_error::DaftError {
+    fn from(err: Error) -> Self {
+        match &err {
+            Error::TableNotFound { .. }
+            | Error::CatalogNotFound { .. }
+            | Error::InvalidTableName { .. } => {
+                common_error::DaftError::CatalogError(err.to_string())
+            }
+            #[cfg(feature = "python")]
+            Error::PythonError { .. } => common_error::DaftError::CatalogError(err.to_string()),
+        }
+    }
+}
+
+#[cfg(feature = "python")]
+use pyo3::PyErr;
+
+#[cfg(feature = "python")]
+impl From<Error> for PyErr {
+    fn from(value: Error) -> Self {
+        let daft_error: common_error::DaftError = value.into();
+        daft_error.into()
+    }
+}
diff --git a/src/daft-catalog/src/lib.rs b/src/daft-catalog/src/lib.rs
new file mode 100644
index 0000000000..8492e6ae37
--- /dev/null
+++ b/src/daft-catalog/src/lib.rs
@@ -0,0 +1,208 @@
+mod data_catalog;
+mod data_catalog_table;
+pub mod errors;
+
+// Export public-facing traits
+use std::{collections::HashMap, default, sync::Arc};
+
+use daft_logical_plan::LogicalPlanBuilder;
+pub use data_catalog::DataCatalog;
+pub use data_catalog_table::DataCatalogTable;
+
+#[cfg(feature = "python")]
+pub mod python;
+
+use errors::{Error, Result};
+
+pub mod global_catalog {
+    use std::sync::{Arc, RwLock};
+
+    use lazy_static::lazy_static;
+
+    use crate::{DaftMetaCatalog, DataCatalog};
+
+    lazy_static! {
+        pub(crate) static ref GLOBAL_DAFT_META_CATALOG: RwLock<DaftMetaCatalog> =
+            RwLock::new(DaftMetaCatalog::new_from_env());
+    }
+
+    /// Register a DataCatalog with the global DaftMetaCatalog
+    pub fn register_catalog(catalog: Arc<dyn DataCatalog>, name: Option<&str>) {
+        GLOBAL_DAFT_META_CATALOG
+            .write()
+            .unwrap()
+            .register_catalog(catalog, name);
+    }
+
+    /// Unregisters a catalog with the global DaftMetaCatalog
+    pub fn unregister_catalog(name: Option<&str>) -> bool {
+        GLOBAL_DAFT_META_CATALOG
+            .write()
+            .unwrap()
+            .unregister_catalog(name)
+    }
+}
+
+/// Name of the default catalog
+static DEFAULT_CATALOG_NAME: &str = "default";
+
+/// The [`DaftMetaCatalog`] is a catalog of [`DataCatalog`] implementations
+///
+/// Users of Daft can register various [`DataCatalog`] with Daft, enabling
+/// discovery of tables across various [`DataCatalog`] implementations.
+pub struct DaftMetaCatalog {
+    /// Map of catalog names to the DataCatalog impls.
+    ///
+    /// NOTE: The default catalog is always named "default"
+    data_catalogs: HashMap<String, Arc<dyn DataCatalog>>,
+
+    /// LogicalPlans that were "named" and registered with Daft
+    named_tables: HashMap<String, LogicalPlanBuilder>,
+}
+
+impl DaftMetaCatalog {
+    /// Create a `DaftMetaCatalog` from the current environment
+    pub fn new_from_env() -> Self {
+        // TODO: Parse a YAML file to produce the catalog
+        DaftMetaCatalog {
+            data_catalogs: default::Default::default(),
+            named_tables: default::Default::default(),
+        }
+    }
+
+    /// Register a new [`DataCatalog`] with the `DaftMetaCatalog`.
+    ///
+    /// # Arguments
+    ///
+    /// * `catalog` - The [`DataCatalog`] to register.
+    pub fn register_catalog(&mut self, catalog: Arc<dyn DataCatalog>, name: Option<&str>) {
+        let name = name.unwrap_or(DEFAULT_CATALOG_NAME);
+        self.data_catalogs.insert(name.to_string(), catalog);
+    }
+
+    /// Unregister a [`DataCatalog`] from the `DaftMetaCatalog`.
+    ///
+    /// # Arguments
+    ///
+    /// * `name` - The name of the catalog to unregister. If None, the default catalog will be unregistered.
+    ///
+    /// # Returns
+    ///
+    /// Returns `true` if a catalog was successfully unregistered, `false` otherwise.
+    pub fn unregister_catalog(&mut self, name: Option<&str>) -> bool {
+        let name = name.unwrap_or(DEFAULT_CATALOG_NAME);
+        self.data_catalogs.remove(name).is_some()
+    }
+
+    /// Registers a LogicalPlan with a name in the DaftMetaCatalog
+    pub fn register_named_table(&mut self, name: &str, view: LogicalPlanBuilder) -> Result<()> {
+        if !name.chars().all(|c| c.is_alphanumeric() || c == '_') {
+            return Err(Error::InvalidTableName {
+                name: name.to_string(),
+            });
+        }
+        self.named_tables.insert(name.to_string(), view);
+        Ok(())
+    }
+
+    /// Provides high-level functionality for reading a table of data against a [`DaftMetaCatalog`]
+    ///
+    /// Resolves the provided table_identifier against the catalog:
+    ///
+    /// 1. If there is an exact match for the provided `table_identifier` in the catalog's registered named tables, immediately return the exact match
+    /// 2. If the [`DaftMetaCatalog`] has a default catalog, we will attempt to resolve the `table_identifier` against the default catalog
+    /// 3. If the `table_identifier` is hierarchical (delimited by "."), use the first component as the Data Catalog name and resolve the rest of the components against
+    ///    the selected Data Catalog
+    pub fn read_table(&self, table_identifier: &str) -> errors::Result<LogicalPlanBuilder> {
+        // If the name is an exact match with a registered view, return it.
+        if let Some(view) = self.named_tables.get(table_identifier) {
+            return Ok(view.clone());
+        }
+
+        let mut searched_catalog_name = "default";
+        let mut searched_table_name = table_identifier;
+
+        // Check the default catalog for a match
+        if let Some(default_data_catalog) = self.data_catalogs.get(DEFAULT_CATALOG_NAME) {
+            if let Some(tbl) = default_data_catalog.get_table(table_identifier)? {
+                return tbl.as_ref().to_logical_plan_builder();
+            }
+        }
+
+        // Try to parse the catalog name from the provided table identifier by taking the first segment, split by '.'
+        if let Some((catalog_name, table_name)) = table_identifier.split_once('.') {
+            if let Some(data_catalog) = self.data_catalogs.get(catalog_name) {
+                searched_catalog_name = catalog_name;
+                searched_table_name = table_name;
+                if let Some(tbl) = data_catalog.get_table(table_name)? {
+                    return tbl.as_ref().to_logical_plan_builder();
+                }
+            }
+        }
+
+        // Return the error containing the last catalog/table pairing that we attempted to search on
+        Err(Error::TableNotFound {
+            catalog_name: searched_catalog_name.to_string(),
+            table_id: searched_table_name.to_string(),
+        })
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+
+    use daft_core::prelude::*;
+    use daft_logical_plan::{
+        ops::Source, source_info::PlaceHolderInfo, ClusteringSpec, LogicalPlan, LogicalPlanRef,
+        SourceInfo,
+    };
+
+    use super::*;
+
+    fn mock_plan() -> LogicalPlanRef {
+        let schema = Arc::new(
+            Schema::new(vec![
+                Field::new("text", DataType::Utf8),
+                Field::new("id", DataType::Int32),
+            ])
+            .unwrap(),
+        );
+        LogicalPlan::Source(Source {
+            output_schema: schema.clone(),
+            source_info: Arc::new(SourceInfo::PlaceHolder(PlaceHolderInfo {
+                source_schema: schema,
+                clustering_spec: Arc::new(ClusteringSpec::unknown()),
+                source_id: 0,
+            })),
+        })
+        .arced()
+    }
+
+    #[test]
+    fn test_register_and_unregister_named_table() {
+        let mut catalog = DaftMetaCatalog::new_from_env();
+        let plan = LogicalPlanBuilder::new(mock_plan(), None);
+
+        // Register a table
+        assert!(catalog
+            .register_named_table("test_table", plan.clone())
+            .is_ok());
+
+        // Try to register a table with an invalid name
+        assert!(catalog
            .register_named_table("invalid name", plan.clone())
+            .is_err());
+    }
+
+    #[test]
+    fn test_read_registered_table() {
+        let mut catalog = DaftMetaCatalog::new_from_env();
+        let plan = LogicalPlanBuilder::new(mock_plan(), None);
+
+        catalog.register_named_table("test_table", plan).unwrap();
+
+        assert!(catalog.read_table("test_table").is_ok());
+        assert!(catalog.read_table("non_existent_table").is_err());
+    }
+}
diff --git a/src/daft-catalog/src/python.rs b/src/daft-catalog/src/python.rs
new file mode 100644
index 0000000000..cc8848078f
--- /dev/null
+++ b/src/daft-catalog/src/python.rs
@@ -0,0 +1,97 @@
+use daft_logical_plan::PyLogicalPlanBuilder;
+use pyo3::prelude::*;
+
+use crate::global_catalog;
+
+/// Read a table from the specified `DaftMetaCatalog`.
+///
+/// This function reads a table from a `DaftMetaCatalog` and returns a PyLogicalPlanBuilder
+/// object representing the plan required to read the table.
+///
+/// The provided `table_identifier` can be:
+///
+/// 1. Name of a registered dataframe/SQL view (manually registered using `DaftMetaCatalog.register_named_table`)
+/// 2. Name of a table within the default catalog (without inputting the catalog name) for example: `"my.table.name"`
+/// 3. Name of a fully-qualified table path with the catalog name for example: `"my_catalog.my.table.name"`
+///
+/// Args:
+///     table_identifier (str): The identifier of the table to read.
+///
+/// Returns:
+///     PyLogicalPlanBuilder: A PyLogicalPlanBuilder object representing the table's data.
+///
+/// Raises:
+///     DaftError: If the table cannot be read or the specified table identifier is not found.
+///
+/// Example:
+///     >>> import daft
+///     >>> df = daft.read_table("foo")
+#[pyfunction]
+#[pyo3(name = "read_table")]
+fn py_read_table(table_identifier: &str) -> PyResult<PyLogicalPlanBuilder> {
+    let logical_plan_builder = global_catalog::GLOBAL_DAFT_META_CATALOG
+        .read()
+        .unwrap()
+        .read_table(table_identifier)?;
+    Ok(PyLogicalPlanBuilder::new(logical_plan_builder))
+}
+
+/// Register a table with the global catalog.
+///
+/// This function registers a table with the global `DaftMetaCatalog` using the provided
+/// table identifier and logical plan.
+///
+/// Args:
+///     table_identifier (str): The identifier to use for the registered table.
+///     logical_plan (PyLogicalPlanBuilder): The logical plan representing the table's data.
+///
+/// Returns:
+///     str: The table identifier used for registration.
+///
+/// Example:
+///     >>> import daft
+///     >>> df = daft.read_csv("data.csv")
+///     >>> daft.register_table("my_table", df)
+#[pyfunction]
+#[pyo3(name = "register_table")]
+fn py_register_table(
+    table_identifier: &str,
+    logical_plan: &PyLogicalPlanBuilder,
+) -> PyResult<String> {
+    global_catalog::GLOBAL_DAFT_META_CATALOG
+        .write()
+        .unwrap()
+        .register_named_table(table_identifier, logical_plan.builder.clone())?;
+    Ok(table_identifier.to_string())
+}
+
+/// Unregisters a catalog from the Daft catalog system
+///
+/// This function removes a previously registered catalog from the Daft catalog system.
+///
+/// Args:
+///     catalog_name (Optional[str]): The name of the catalog to unregister. If None, the default catalog will be unregistered.
+///
+/// Returns:
+///     bool: True if a catalog was successfully unregistered, False otherwise.
+///
+/// Example:
+///     >>> import daft
+///     >>> daft.unregister_catalog("my_catalog")
+///     True
+#[pyfunction]
+#[pyo3(name = "unregister_catalog")]
+pub fn py_unregister_catalog(catalog_name: Option<&str>) -> bool {
+    crate::global_catalog::unregister_catalog(catalog_name)
+}
+
+pub fn register_modules<'py>(parent: &Bound<'py, PyModule>) -> PyResult<Bound<'py, PyModule>> {
+    let module = PyModule::new_bound(parent.py(), "catalog")?;
+
+    module.add_wrapped(wrap_pyfunction!(py_read_table))?;
+    module.add_wrapped(wrap_pyfunction!(py_register_table))?;
+    module.add_wrapped(wrap_pyfunction!(py_unregister_catalog))?;
+
+    parent.add_submodule(&module)?;
+    Ok(module)
+}
diff --git a/src/lib.rs b/src/lib.rs
index 2b3efc2b4c..a44a2d632c 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -121,6 +121,10 @@ pub mod pylib {
         daft_functions_json::register_modules(m)?;
         daft_connect::register_modules(m)?;
 
+        // Register catalog module
+        let catalog_module = daft_catalog::python::register_modules(m)?;
+        daft_catalog_python_catalog::python::register_modules(&catalog_module)?;
+
         m.add_wrapped(wrap_pyfunction!(version))?;
         m.add_wrapped(wrap_pyfunction!(build_type))?;
         m.add_wrapped(wrap_pyfunction!(refresh_logger))?;
diff --git a/tests/integration/iceberg/conftest.py b/tests/integration/iceberg/conftest.py
index 79a59e8ade..6550571260 100644
--- a/tests/integration/iceberg/conftest.py
+++ b/tests/integration/iceberg/conftest.py
@@ -1,10 +1,13 @@
 from __future__ import annotations
 
-from typing import Generator, TypeVar
+from typing import Generator, Iterator, TypeVar
 
 import pyarrow as pa
 import pytest
 
+import daft
+import daft.catalog
+
 pyiceberg = pytest.importorskip("pyiceberg")
 
 PYARROW_LE_8_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) < (8, 0, 0)
@@ -44,18 +47,8 @@
 ]
 
 
-@tenacity.retry(
-    stop=tenacity.stop_after_delay(60),
-    retry=tenacity.retry_if_exception_type(pyiceberg.exceptions.NoSuchTableError),
-    wait=tenacity.wait_fixed(5),
-    reraise=True,
-)
-def _load_table(catalog, name) -> Table:
-    return catalog.load_table(f"default.{name}")
-
-
 @pytest.fixture(scope="session")
-def local_iceberg_catalog() -> Catalog:
+def local_iceberg_catalog() -> Iterator[tuple[str, Catalog]]:
     cat = load_catalog(
         "local",
         **{
@@ -66,23 +59,20 @@
         },
     )
+
+    # ensure all tables are available
     for name in local_tables_names:
         _load_table(cat, name)
-    return cat
-
-
-@pytest.fixture(scope="session", params=local_tables_names) -def local_iceberg_tables(request, local_iceberg_catalog) -> Table: - NAMESPACE = "default" - table_name = request.param - return local_iceberg_catalog.load_table(f"{NAMESPACE}.{table_name}") + catalog_name = "_local_iceberg_catalog" + daft.catalog.register_python_catalog(cat, name=catalog_name) + yield catalog_name, cat + daft.catalog.unregister_catalog(catalog_name=catalog_name) @pytest.fixture(scope="session") -def cloud_iceberg_catalog() -> Catalog: - return load_catalog( +def azure_iceberg_catalog() -> Iterator[tuple[str, Catalog]]: + cat = load_catalog( "default", **{ "uri": "sqlite:///tests/assets/pyiceberg_catalog.db", @@ -90,7 +80,29 @@ def cloud_iceberg_catalog() -> Catalog: }, ) + catalog_name = "_azure_iceberg_catalog" + daft.catalog.register_python_catalog(cat, name=catalog_name) + yield catalog_name, cat + daft.catalog.unregister_catalog(catalog_name=catalog_name) + + +@tenacity.retry( + stop=tenacity.stop_after_delay(60), + retry=tenacity.retry_if_exception_type(pyiceberg.exceptions.NoSuchTableError), + wait=tenacity.wait_fixed(5), + reraise=True, +) +def _load_table(catalog, name) -> Table: + return catalog.load_table(f"default.{name}") + + +@pytest.fixture(scope="function", params=local_tables_names) +def local_iceberg_tables(request, local_iceberg_catalog) -> Iterator[str]: + NAMESPACE = "default" + table_name = request.param + yield f"{NAMESPACE}.{table_name}" + -@pytest.fixture(scope="session", params=cloud_tables_names) -def cloud_iceberg_table(request, cloud_iceberg_catalog) -> Table: - return cloud_iceberg_catalog.load_table(request.param) +@pytest.fixture(scope="function", params=cloud_tables_names) +def azure_iceberg_table(request, azure_iceberg_catalog) -> Iterator[str]: + yield request.param diff --git a/tests/integration/iceberg/test_cloud_load.py b/tests/integration/iceberg/test_cloud_load.py index 210bb36b7d..32be0b85ef 100644 --- a/tests/integration/iceberg/test_cloud_load.py +++ b/tests/integration/iceberg/test_cloud_load.py @@ -7,8 +7,9 @@ @pytest.mark.integration -def test_daft_iceberg_cloud_table_load(cloud_iceberg_table): - df = daft.read_iceberg(cloud_iceberg_table) +def test_daft_iceberg_cloud_table_load(azure_iceberg_table, azure_iceberg_catalog): + catalog_name, pyiceberg_catalog = azure_iceberg_catalog + df = daft.read_table(f"{catalog_name}.{azure_iceberg_table}") daft_pandas = df.to_pandas() - iceberg_pandas = cloud_iceberg_table.scan().to_arrow().to_pandas() + iceberg_pandas = pyiceberg_catalog.load_table(azure_iceberg_table).scan().to_arrow().to_pandas() assert_df_equals(daft_pandas, iceberg_pandas, sort_key=[]) diff --git a/tests/integration/iceberg/test_partition_pruning.py b/tests/integration/iceberg/test_partition_pruning.py index 70737111d6..05d4059f39 100644 --- a/tests/integration/iceberg/test_partition_pruning.py +++ b/tests/integration/iceberg/test_partition_pruning.py @@ -16,8 +16,9 @@ @pytest.mark.integration() def test_daft_iceberg_table_predicate_pushdown_days(local_iceberg_catalog): - tab = local_iceberg_catalog.load_table("default.test_partitioned_by_days") - df = daft.read_iceberg(tab) + catalog_name, pyiceberg_catalog = local_iceberg_catalog + tab = pyiceberg_catalog.load_table("default.test_partitioned_by_days") + df = daft.read_table(f"{catalog_name}.default.test_partitioned_by_days") df = df.where(df["ts"] < date(2023, 3, 6)) df.collect() daft_pandas = df.to_pandas() @@ -70,8 +71,9 @@ def udf_func(obj): ), ) def 
test_daft_iceberg_table_predicate_pushdown_on_date_column(predicate, table, limit, local_iceberg_catalog): - tab = local_iceberg_catalog.load_table(f"default.{table}") - df = daft.read_iceberg(tab) + catalog_name, pyiceberg_catalog = local_iceberg_catalog + tab = pyiceberg_catalog.load_table(f"default.{table}") + df = daft.read_table(f"{catalog_name}.default.{table}") df = df.where(predicate(df["dt"])) if limit: df = df.limit(limit) @@ -110,8 +112,9 @@ def test_daft_iceberg_table_predicate_pushdown_on_date_column(predicate, table, ), ) def test_daft_iceberg_table_predicate_pushdown_on_timestamp_column(predicate, table, limit, local_iceberg_catalog): - tab = local_iceberg_catalog.load_table(f"default.{table}") - df = daft.read_iceberg(tab) + catalog_name, pyiceberg_catalog = local_iceberg_catalog + tab = pyiceberg_catalog.load_table(f"default.{table}") + df = daft.read_table(f"{catalog_name}.default.{table}") df = df.where(predicate(df["ts"])) if limit: df = df.limit(limit) @@ -151,8 +154,9 @@ def test_daft_iceberg_table_predicate_pushdown_on_timestamp_column(predicate, ta ), ) def test_daft_iceberg_table_predicate_pushdown_on_letter(predicate, table, limit, local_iceberg_catalog): - tab = local_iceberg_catalog.load_table(f"default.{table}") - df = daft.read_iceberg(tab) + catalog_name, pyiceberg_catalog = local_iceberg_catalog + tab = pyiceberg_catalog.load_table(f"default.{table}") + df = daft.read_table(f"{catalog_name}.default.{table}") df = df.where(predicate(df["letter"])) if limit: df = df.limit(limit) @@ -191,8 +195,9 @@ def test_daft_iceberg_table_predicate_pushdown_on_letter(predicate, table, limit ), ) def test_daft_iceberg_table_predicate_pushdown_on_number(predicate, table, limit, local_iceberg_catalog): - tab = local_iceberg_catalog.load_table(f"default.{table}") - df = daft.read_iceberg(tab) + catalog_name, pyiceberg_catalog = local_iceberg_catalog + tab = pyiceberg_catalog.load_table(f"default.{table}") + df = daft.read_table(f"{catalog_name}.default.{table}") df = df.where(predicate(df["number"])) if limit: df = df.limit(limit) @@ -208,7 +213,8 @@ def test_daft_iceberg_table_predicate_pushdown_on_number(predicate, table, limit @pytest.mark.integration() def test_daft_iceberg_table_predicate_pushdown_empty_scan(local_iceberg_catalog): - tab = local_iceberg_catalog.load_table("default.test_partitioned_by_months") + catalog_name, pyiceberg_catalog = local_iceberg_catalog + tab = pyiceberg_catalog.load_table("default.test_partitioned_by_months") df = daft.read_iceberg(tab) df = df.where(df["dt"] > date(2030, 1, 1)) df.collect() diff --git a/tests/integration/iceberg/test_pyiceberg_written_table_load.py b/tests/integration/iceberg/test_pyiceberg_written_table_load.py index e7d1f982c3..48bd261a75 100644 --- a/tests/integration/iceberg/test_pyiceberg_written_table_load.py +++ b/tests/integration/iceberg/test_pyiceberg_written_table_load.py @@ -13,53 +13,57 @@ @contextlib.contextmanager -def table_written_by_pyiceberg(local_iceberg_catalog): +def table_written_by_pyiceberg(local_pyiceberg_catalog): schema = pa.schema([("col", pa.int64()), ("mapCol", pa.map_(pa.int32(), pa.string()))]) data = {"col": [1, 2, 3], "mapCol": [[(1, "foo"), (2, "bar")], [(3, "baz")], [(4, "foobar")]]} arrow_table = pa.Table.from_pydict(data, schema=schema) + table_name = "pyiceberg.map_table" try: - table = local_iceberg_catalog.create_table("pyiceberg.map_table", schema=schema) + table = local_pyiceberg_catalog.create_table(table_name, schema=schema) table.append(arrow_table) - yield table + yield 
table_name except Exception as e: raise e finally: - local_iceberg_catalog.drop_table("pyiceberg.map_table") + local_pyiceberg_catalog.drop_table(table_name) @contextlib.contextmanager -def table_written_by_daft(local_iceberg_catalog): +def table_written_by_daft(local_pyiceberg_catalog): schema = pa.schema([("col", pa.int64()), ("mapCol", pa.map_(pa.int32(), pa.string()))]) data = {"col": [1, 2, 3], "mapCol": [[(1, "foo"), (2, "bar")], [(3, "baz")], [(4, "foobar")]]} arrow_table = pa.Table.from_pydict(data, schema=schema) + table_name = "pyiceberg.map_table" try: - table = local_iceberg_catalog.create_table("pyiceberg.map_table", schema=schema) + table = local_pyiceberg_catalog.create_table(table_name, schema=schema) df = daft.from_arrow(arrow_table) df.write_iceberg(table, mode="overwrite") table.refresh() - yield table + yield table_name except Exception as e: raise e finally: - local_iceberg_catalog.drop_table("pyiceberg.map_table") + local_pyiceberg_catalog.drop_table(table_name) @pytest.mark.integration() def test_pyiceberg_written_catalog(local_iceberg_catalog): - with table_written_by_pyiceberg(local_iceberg_catalog) as catalog_table: - df = daft.read_iceberg(catalog_table) + catalog_name, local_pyiceberg_catalog = local_iceberg_catalog + with table_written_by_pyiceberg(local_pyiceberg_catalog) as catalog_table_name: + df = daft.read_table(f"{catalog_name}.{catalog_table_name}") daft_pandas = df.to_pandas() - iceberg_pandas = catalog_table.scan().to_arrow().to_pandas() + iceberg_pandas = local_pyiceberg_catalog.load_table(catalog_table_name).scan().to_arrow().to_pandas() assert_df_equals(daft_pandas, iceberg_pandas, sort_key=[]) @pytest.mark.integration() @pytest.mark.skip def test_daft_written_catalog(local_iceberg_catalog): - with table_written_by_daft(local_iceberg_catalog) as catalog_table: - df = daft.read_iceberg(catalog_table) + catalog_name, local_pyiceberg_catalog = local_iceberg_catalog + with table_written_by_daft(local_pyiceberg_catalog) as catalog_table_name: + df = daft.read_table(f"{catalog_name}.{catalog_table_name}") daft_pandas = df.to_pandas() - iceberg_pandas = catalog_table.scan().to_arrow().to_pandas() + iceberg_pandas = local_pyiceberg_catalog.load_table(catalog_table_name).scan().to_arrow().to_pandas() assert_df_equals(daft_pandas, iceberg_pandas, sort_key=[]) diff --git a/tests/integration/iceberg/test_table_load.py b/tests/integration/iceberg/test_table_load.py index 5cc1f2bf7d..36ddb47f8b 100644 --- a/tests/integration/iceberg/test_table_load.py +++ b/tests/integration/iceberg/test_table_load.py @@ -15,9 +15,11 @@ @pytest.mark.integration() -def test_daft_iceberg_table_open(local_iceberg_tables): - df = daft.read_iceberg(local_iceberg_tables) - iceberg_schema = local_iceberg_tables.schema() +def test_daft_iceberg_table_open(local_iceberg_tables, local_iceberg_catalog): + catalog_name, pyiceberg_catalog = local_iceberg_catalog + tab = pyiceberg_catalog.load_table(local_iceberg_tables) + df = daft.read_table(f"{catalog_name}.{local_iceberg_tables}") + iceberg_schema = tab.schema() as_pyarrow_schema = schema_to_pyarrow(iceberg_schema) as_daft_schema = Schema.from_pyarrow_schema(as_pyarrow_schema) assert df.schema() == as_daft_schema @@ -51,16 +53,17 @@ def test_daft_iceberg_table_open(local_iceberg_tables): @pytest.mark.integration() @pytest.mark.parametrize("table_name", WORKING_SHOW_COLLECT) def test_daft_iceberg_table_show(table_name, local_iceberg_catalog): - tab = local_iceberg_catalog.load_table(f"default.{table_name}") - df = daft.read_iceberg(tab) + 
catalog_name, _ = local_iceberg_catalog + df = daft.read_table(f"{catalog_name}.default.{table_name}") df.show() @pytest.mark.integration() @pytest.mark.parametrize("table_name", WORKING_SHOW_COLLECT) def test_daft_iceberg_table_collect_correct(table_name, local_iceberg_catalog): - tab = local_iceberg_catalog.load_table(f"default.{table_name}") - df = daft.read_iceberg(tab) + catalog_name, pyiceberg_catalog = local_iceberg_catalog + tab = pyiceberg_catalog.load_table(f"default.{table_name}") + df = daft.read_table(f"{catalog_name}.default.{table_name}") df.collect() daft_pandas = df.to_pandas() iceberg_pandas = tab.scan().to_arrow().to_pandas() @@ -69,8 +72,9 @@ def test_daft_iceberg_table_collect_correct(table_name, local_iceberg_catalog): @pytest.mark.integration() def test_daft_iceberg_table_renamed_filtered_collect_correct(local_iceberg_catalog): - tab = local_iceberg_catalog.load_table("default.test_table_rename") - df = daft.read_iceberg(tab) + catalog_name, pyiceberg_catalog = local_iceberg_catalog + tab = pyiceberg_catalog.load_table("default.test_table_rename") + df = daft.read_table(f"{catalog_name}.default.test_table_rename") df = df.where(df["idx_renamed"] <= 1) daft_pandas = df.to_pandas() iceberg_pandas = tab.scan().to_arrow().to_pandas() @@ -80,8 +84,9 @@ def test_daft_iceberg_table_renamed_filtered_collect_correct(local_iceberg_catal @pytest.mark.integration() def test_daft_iceberg_table_renamed_column_pushdown_collect_correct(local_iceberg_catalog): - tab = local_iceberg_catalog.load_table("default.test_table_rename") - df = daft.read_iceberg(tab) + catalog_name, pyiceberg_catalog = local_iceberg_catalog + tab = pyiceberg_catalog.load_table("default.test_table_rename") + df = daft.read_table(f"{catalog_name}.default.test_table_rename") df = df.select("idx_renamed") daft_pandas = df.to_pandas() iceberg_pandas = tab.scan().to_arrow().to_pandas() @@ -91,8 +96,9 @@ def test_daft_iceberg_table_renamed_column_pushdown_collect_correct(local_iceber @pytest.mark.integration() def test_daft_iceberg_table_read_partition_column_identity(local_iceberg_catalog): - tab = local_iceberg_catalog.load_table("default.test_partitioned_by_identity") - df = daft.read_iceberg(tab) + catalog_name, pyiceberg_catalog = local_iceberg_catalog + tab = pyiceberg_catalog.load_table("default.test_partitioned_by_identity") + df = daft.read_table(f"{catalog_name}.default.test_partitioned_by_identity") df = df.select("ts", "number") daft_pandas = df.to_pandas() iceberg_pandas = tab.scan().to_arrow().to_pandas() @@ -102,8 +108,9 @@ def test_daft_iceberg_table_read_partition_column_identity(local_iceberg_catalog @pytest.mark.integration() def test_daft_iceberg_table_read_partition_column_identity_filter(local_iceberg_catalog): - tab = local_iceberg_catalog.load_table("default.test_partitioned_by_identity") - df = daft.read_iceberg(tab) + catalog_name, pyiceberg_catalog = local_iceberg_catalog + tab = pyiceberg_catalog.load_table("default.test_partitioned_by_identity") + df = daft.read_table(f"{catalog_name}.default.test_partitioned_by_identity") df = df.where(df["number"] > 0) df = df.select("ts") daft_pandas = df.to_pandas() @@ -118,8 +125,9 @@ def test_daft_iceberg_table_read_partition_column_identity_filter(local_iceberg_ ) @pytest.mark.integration() def test_daft_iceberg_table_read_partition_column_identity_filter_on_partkey(local_iceberg_catalog): - tab = local_iceberg_catalog.load_table("default.test_partitioned_by_identity") - df = daft.read_iceberg(tab) + catalog_name, pyiceberg_catalog = 
local_iceberg_catalog + tab = pyiceberg_catalog.load_table("default.test_partitioned_by_identity") + df = daft.read_table(f"{catalog_name}.default.test_partitioned_by_identity") df = df.select("ts") df = df.where(df["ts"] > datetime.date(2022, 3, 1)) daft_pandas = df.to_pandas() @@ -134,8 +142,9 @@ def test_daft_iceberg_table_read_partition_column_identity_filter_on_partkey(loc ) @pytest.mark.integration() def test_daft_iceberg_table_read_partition_column_identity_only(local_iceberg_catalog): - tab = local_iceberg_catalog.load_table("default.test_partitioned_by_identity") - df = daft.read_iceberg(tab) + catalog_name, pyiceberg_catalog = local_iceberg_catalog + tab = pyiceberg_catalog.load_table("default.test_partitioned_by_identity") + df = daft.read_table(f"{catalog_name}.default.test_partitioned_by_identity") df = df.select("ts") daft_pandas = df.to_pandas() iceberg_pandas = tab.scan().to_arrow().to_pandas() @@ -145,8 +154,9 @@ def test_daft_iceberg_table_read_partition_column_identity_only(local_iceberg_ca @pytest.mark.integration() def test_daft_iceberg_table_read_partition_column_transformed(local_iceberg_catalog): - tab = local_iceberg_catalog.load_table("default.test_partitioned_by_bucket") - df = daft.read_iceberg(tab) + catalog_name, pyiceberg_catalog = local_iceberg_catalog + tab = pyiceberg_catalog.load_table("default.test_partitioned_by_bucket") + df = daft.read_table(f"{catalog_name}.default.test_partitioned_by_bucket") df = df.select("number") daft_pandas = df.to_pandas() iceberg_pandas = tab.scan().to_arrow().to_pandas() @@ -156,11 +166,14 @@ def test_daft_iceberg_table_read_partition_column_transformed(local_iceberg_cata @pytest.mark.integration() def test_daft_iceberg_table_read_table_snapshot(local_iceberg_catalog): - tab = local_iceberg_catalog.load_table("default.test_snapshotting") + _catalog_name, pyiceberg_catalog = local_iceberg_catalog + tab = pyiceberg_catalog.load_table("default.test_snapshotting") + snapshots = tab.history() assert len(snapshots) == 2 for snapshot in snapshots: + # TODO: Provide read_table API for reading iceberg with snapshot ID daft_pandas = daft.read_iceberg(tab, snapshot_id=snapshot.snapshot_id).to_pandas() iceberg_pandas = tab.scan(snapshot_id=snapshot.snapshot_id).to_pandas() assert_df_equals(daft_pandas, iceberg_pandas, sort_key=[]) @@ -169,8 +182,9 @@ def test_daft_iceberg_table_read_table_snapshot(local_iceberg_catalog): @pytest.mark.integration() @pytest.mark.parametrize("table_name", ["test_positional_mor_deletes", "test_positional_mor_double_deletes"]) def test_daft_iceberg_table_mor_limit_collect_correct(table_name, local_iceberg_catalog): - tab = local_iceberg_catalog.load_table(f"default.{table_name}") - df = daft.read_iceberg(tab) + catalog_name, pyiceberg_catalog = local_iceberg_catalog + tab = pyiceberg_catalog.load_table(f"default.{table_name}") + df = daft.read_table(f"{catalog_name}.default.{table_name}") df = df.limit(10) df.collect() daft_pandas = df.to_pandas() @@ -185,8 +199,9 @@ def test_daft_iceberg_table_mor_limit_collect_correct(table_name, local_iceberg_ @pytest.mark.integration() @pytest.mark.parametrize("table_name", ["test_positional_mor_deletes", "test_positional_mor_double_deletes"]) def test_daft_iceberg_table_mor_predicate_collect_correct(table_name, local_iceberg_catalog): - tab = local_iceberg_catalog.load_table(f"default.{table_name}") - df = daft.read_iceberg(tab) + catalog_name, pyiceberg_catalog = local_iceberg_catalog + tab = pyiceberg_catalog.load_table(f"default.{table_name}") + df = 
daft.read_table(f"{catalog_name}.default.{table_name}") df = df.where(df["number"] > 5) df.collect() daft_pandas = df.to_pandas()
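

The patch is easiest to sanity-check against the three resolution tiers that `DaftMetaCatalog::read_table` implements: named tables first, then the default catalog, then an explicit catalog prefix. A short usage sketch of that order (the PyIceberg configuration and the `db.tbl` identifiers below are placeholders, not names from this diff):

```python
import daft
from pyiceberg.catalog import load_catalog

# (1) Named tables are checked first, by exact match on the registered name.
df = daft.from_pydict({"foo": [1, 2, 3]})
daft.register_table("my_table", df)
assert daft.read_table("my_table").to_pydict() == {"foo": [1, 2, 3]}

# (2) A catalog registered without a name becomes the "default" catalog,
# so identifiers that don't match a named table are tried against it next.
catalog = load_catalog("local", uri="http://localhost:8181")  # placeholder config
daft.catalog.register_python_catalog(catalog)
df2 = daft.read_table("db.tbl")

# (3) Otherwise the first '.'-separated component is treated as a catalog name.
daft.catalog.register_python_catalog(catalog, name="my_other_catalog")
df3 = daft.read_table("my_other_catalog.db.tbl")

# Cleanup: drop the default registration.
daft.catalog.unregister_catalog(None)
```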