diff --git a/daft/iceberg/iceberg_scan.py b/daft/iceberg/iceberg_scan.py index 274330d821..eb96364d13 100644 --- a/daft/iceberg/iceberg_scan.py +++ b/daft/iceberg/iceberg_scan.py @@ -2,10 +2,63 @@ from pyiceberg.catalog import Catalog, load_catalog from pyiceberg.io.pyarrow import schema_to_pyarrow +from pyiceberg.partitioning import PartitionField as IcebergPartitionField +from pyiceberg.partitioning import PartitionSpec as IcebergPartitionSpec +from pyiceberg.schema import Schema as IcebergSchema from pyiceberg.table import Table -from daft.io.scan import ScanOperator -from daft.logical.schema import Schema +from daft.datatype import DataType +from daft.expressions.expressions import col +from daft.io.scan import PartitionField, ScanOperator +from daft.logical.schema import Field, Schema + + +def _iceberg_partition_field_to_daft_partition_field( + iceberg_schema: IcebergSchema, pfield: IcebergPartitionField +) -> PartitionField: + name = pfield.name + source_id = pfield.source_id + source_field = iceberg_schema.find_field(source_id) + source_name = source_field.name + daft_field = Field.create( + source_name, DataType.from_arrow_type(schema_to_pyarrow(iceberg_schema.find_type(source_name))) + ) + transform = pfield.transform + iceberg_result_type = transform.result_type(source_field.field_type) + arrow_result_type = schema_to_pyarrow(iceberg_result_type) + daft_result_type = DataType.from_arrow_type(arrow_result_type) + result_field = Field.create(name, daft_result_type) + + from pyiceberg.transforms import ( + DayTransform, + HourTransform, + IdentityTransform, + MonthTransform, + YearTransform, + ) + + expr = None + if isinstance(transform, IdentityTransform): + expr = col(source_name) + if source_name != name: + expr = expr.alias(name) + elif isinstance(transform, YearTransform): + expr = col(source_name).dt.year().alias(name) + elif isinstance(transform, MonthTransform): + expr = col(source_name).dt.month().alias(name) + elif isinstance(transform, 
DayTransform): + expr = col(source_name).dt.day().alias(name) + elif isinstance(transform, HourTransform): + raise NotImplementedError("HourTransform not implemented, Please make an issue!") + else: + raise NotImplementedError(f"{transform} not implemented, Please make an issue!") + + assert expr is not None + return PartitionField(result_field, daft_field, transform=expr) + + +def iceberg_partition_spec_to_fields(iceberg_schema: IcebergSchema, spec: IcebergPartitionSpec) -> list[PartitionField]: + return [_iceberg_partition_field_to_daft_partition_field(iceberg_schema, field) for field in spec.fields] class IcebergScanOperator(ScanOperator): @@ -14,10 +67,14 @@ def __init__(self, iceberg_table: Table) -> None: self._table = iceberg_table arrow_schema = schema_to_pyarrow(iceberg_table.schema()) self._schema = Schema.from_pyarrow_schema(arrow_schema) + self._partition_keys = iceberg_partition_spec_to_fields(self._table.schema(), self._table.spec()) def schema(self) -> Schema: return self._schema + def partitioning_keys(self) -> list[PartitionField]: + return self._partition_keys + def catalog() -> Catalog: return load_catalog( diff --git a/daft/io/scan.py b/daft/io/scan.py index 93708e2048..afaaf7e08f 100644 --- a/daft/io/scan.py +++ b/daft/io/scan.py @@ -3,7 +3,8 @@ import abc from dataclasses import dataclass -from daft.logical.schema import Schema +from daft.expressions.expressions import Expression +from daft.logical.schema import Field, Schema @dataclass(frozen=True) @@ -13,14 +14,21 @@ class ScanTask: limit: int | None +@dataclass(frozen=True) +class PartitionField: + field: Field + source_field: Field + transform: Expression + + class ScanOperator(abc.ABC): @abc.abstractmethod def schema(self) -> Schema: raise NotImplementedError() - # @abc.abstractmethod - # def partitioning_keys(self) -> list[Field]: - # raise NotImplementedError() + @abc.abstractmethod + def partitioning_keys(self) -> list[PartitionField]: + raise NotImplementedError() # 
@abc.abstractmethod # def num_partitions(self) -> int: diff --git a/daft/logical/schema.py index 7f33be83b9..89fad30b31 100644 --- a/daft/logical/schema.py +++ b/daft/logical/schema.py @@ -33,6 +33,11 @@ def _from_pyfield(field: _PyField) -> Field: f._field = field return f + @staticmethod + def create(name: str, dtype: DataType) -> Field: + pyfield = _PyField.create(name, dtype._dtype) + return Field._from_pyfield(pyfield) + @property def name(self): return self._field.name() diff --git a/src/daft-core/src/python/field.rs index daae3e63b2..6529edd863 100644 --- a/src/daft-core/src/python/field.rs +++ b/src/daft-core/src/python/field.rs @@ -13,6 +13,11 @@ pub struct PyField { #[pymethods] impl PyField { + #[staticmethod] + pub fn create(name: &str, data_type: PyDataType) -> PyResult<Self> { + Ok(datatypes::Field::new(name, data_type.dtype).into()) + } + pub fn name(&self) -> PyResult<String> { Ok(self.field.name.clone()) }