diff --git a/daft/execution/execution_step.py b/daft/execution/execution_step.py index 56201514ef..896c7f3e8b 100644 --- a/daft/execution/execution_step.py +++ b/daft/execution/execution_step.py @@ -4,10 +4,7 @@ import pathlib import sys from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Generic, TypeVar - -if TYPE_CHECKING: - import fsspec +from typing import Generic, TypeVar if sys.version_info < (3, 8): from typing_extensions import Protocol @@ -22,6 +19,7 @@ JsonSourceConfig, ParquetSourceConfig, ResourceRequest, + StorageConfig, ) from daft.expressions import Expression, ExpressionsProjection, col from daft.logical.map_partition_ops import MapPartitionOp @@ -316,7 +314,7 @@ class ReadFile(SingleOutputInstruction): # Max number of rows to read. limit_rows: int | None schema: Schema - fs: fsspec.AbstractFileSystem | None + storage_config: StorageConfig columns_to_read: list[str] | None file_format_config: FileFormatConfig @@ -363,18 +361,18 @@ def _handle_tabular_files_scan( ) file_format = self.file_format_config.file_format() - config = self.file_format_config.config + format_config = self.file_format_config.config if file_format == FileFormat.Csv: - assert isinstance(config, CsvSourceConfig) + assert isinstance(format_config, CsvSourceConfig) table = Table.concat( [ table_io.read_csv( file=fp, schema=self.schema, - fs=self.fs, + storage_config=self.storage_config, csv_options=TableParseCSVOptions( - delimiter=config.delimiter, - header_index=0 if config.has_headers else None, + delimiter=format_config.delimiter, + header_index=0 if format_config.has_headers else None, ), read_options=read_options, ) @@ -382,29 +380,27 @@ def _handle_tabular_files_scan( ] ) elif file_format == FileFormat.Json: - assert isinstance(config, JsonSourceConfig) + assert isinstance(format_config, JsonSourceConfig) table = Table.concat( [ table_io.read_json( file=fp, schema=self.schema, - fs=self.fs, + storage_config=self.storage_config, read_options=read_options, ) for fp in filepaths ] ) elif file_format == FileFormat.Parquet: - assert isinstance(config, ParquetSourceConfig) + assert isinstance(format_config, ParquetSourceConfig) table = Table.concat( [ table_io.read_parquet( file=fp, schema=self.schema, - fs=self.fs, + storage_config=self.storage_config, read_options=read_options, - io_config=config.io_config, - use_native_downloader=config.use_native_downloader, ) for fp in filepaths ] diff --git a/daft/execution/physical_plan.py b/daft/execution/physical_plan.py index 041118e61a..1ff6de696c 100644 --- a/daft/execution/physical_plan.py +++ b/daft/execution/physical_plan.py @@ -16,14 +16,17 @@ import math import pathlib from collections import deque -from typing import TYPE_CHECKING, Generator, Iterator, TypeVar, Union - -if TYPE_CHECKING: - import fsspec +from typing import Generator, Iterator, TypeVar, Union from loguru import logger -from daft.daft import FileFormat, FileFormatConfig, JoinType, ResourceRequest +from daft.daft import ( + FileFormat, + FileFormatConfig, + JoinType, + ResourceRequest, + StorageConfig, +) from daft.execution import execution_step from daft.execution.execution_step import ( Instruction, @@ -67,7 +70,7 @@ def file_read( # Max number of rows to read. limit_rows: int | None, schema: Schema, - fs: fsspec.AbstractFileSystem | None, + storage_config: StorageConfig, columns_to_read: list[str] | None, file_format_config: FileFormatConfig, ) -> InProgressPhysicalPlan[PartitionT]: @@ -99,7 +102,7 @@ def file_read( file_rows=file_rows[i], limit_rows=limit_rows, schema=schema, - fs=fs, + storage_config=storage_config, columns_to_read=columns_to_read, file_format_config=file_format_config, ), diff --git a/daft/execution/physical_plan_factory.py b/daft/execution/physical_plan_factory.py index 3189381c97..c90cd148ea 100644 --- a/daft/execution/physical_plan_factory.py +++ b/daft/execution/physical_plan_factory.py @@ -39,7 +39,7 @@ def _get_physical_plan(node: LogicalPlan, psets: dict[str, list[PartitionT]]) -> child_plan=child_plan, limit_rows=node._limit_rows, schema=node._schema, - fs=node._fs, + storage_config=node._storage_config, columns_to_read=node._column_names, file_format_config=node._file_format_config, ) diff --git a/daft/execution/rust_physical_plan_shim.py b/daft/execution/rust_physical_plan_shim.py index 587679bc62..2ffa9057a3 100644 --- a/daft/execution/rust_physical_plan_shim.py +++ b/daft/execution/rust_physical_plan_shim.py @@ -11,6 +11,7 @@ PySchema, PyTable, ResourceRequest, + StorageConfig, ) from daft.execution import execution_step, physical_plan from daft.expressions import Expression, ExpressionsProjection @@ -26,6 +27,7 @@ def tabular_scan( columns_to_read: list[str], file_info_table: PyTable, file_format_config: FileFormatConfig, + storage_config: StorageConfig, limit: int, ) -> physical_plan.InProgressPhysicalPlan[PartitionT]: # TODO(Clark): Fix this Ray runner hack. @@ -43,7 +45,7 @@ def tabular_scan( child_plan=file_info_iter, limit_rows=limit, schema=Schema._from_pyschema(schema), - fs=None, + storage_config=storage_config, columns_to_read=columns_to_read, file_format_config=file_format_config, ) diff --git a/daft/filesystem.py b/daft/filesystem.py index 13df24d715..eb56e02a0a 100644 --- a/daft/filesystem.py +++ b/daft/filesystem.py @@ -25,7 +25,7 @@ _resolve_filesystem_and_path, ) -from daft.daft import FileFormat, FileFormatConfig, FileInfos, ParquetSourceConfig +from daft.daft import FileFormat, FileInfos, NativeStorageConfig, StorageConfig from daft.table import Table _CACHED_FSES: dict[str, FileSystem] = {} @@ -294,8 +294,9 @@ def _path_is_glob(path: str) -> bool: def glob_path_with_stats( path: str, - file_format_config: FileFormatConfig | None, + file_format: FileFormat | None, fs: fsspec.AbstractFileSystem, + storage_config: StorageConfig | None, ) -> FileInfos: """Glob a path, returning a list ListingInfo.""" protocol = get_protocol_from_path(path) @@ -326,10 +327,9 @@ def glob_path_with_stats( raise FileNotFoundError(f"File or directory not found: {path}") # Set number of rows if available. - if file_format_config is not None and file_format_config.file_format() == FileFormat.Parquet: - config = file_format_config.config - assert isinstance(config, ParquetSourceConfig) - if config.use_native_downloader: + if file_format is not None and file_format == FileFormat.Parquet: + config = storage_config.config if storage_config is not None else None + if config is not None and isinstance(config, NativeStorageConfig): parquet_statistics = Table.read_parquet_statistics( list(filepaths_to_infos.keys()), config.io_config ).to_pydict() diff --git a/daft/io/_csv.py b/daft/io/_csv.py index b99553a9b1..3aff76211d 100644 --- a/daft/io/_csv.py +++ b/daft/io/_csv.py @@ -5,7 +5,12 @@ import fsspec from daft.api_annotations import PublicAPI -from daft.daft import CsvSourceConfig, FileFormatConfig +from daft.daft import ( + CsvSourceConfig, + FileFormatConfig, + PythonStorageConfig, + StorageConfig, +) from daft.dataframe import DataFrame from daft.datatype import DataType from daft.io.common import _get_tabular_files_scan @@ -52,5 +57,6 @@ def read_csv( csv_config = CsvSourceConfig(delimiter=delimiter, has_headers=has_headers) file_format_config = FileFormatConfig.from_csv_config(csv_config) - builder = _get_tabular_files_scan(path, schema_hints, file_format_config, fs) + storage_config = StorageConfig.python(PythonStorageConfig(fs)) + builder = _get_tabular_files_scan(path, schema_hints, file_format_config, storage_config=storage_config) return DataFrame(builder) diff --git a/daft/io/_json.py b/daft/io/_json.py index f29aaa67bf..fc37ec4b82 100644 --- a/daft/io/_json.py +++ b/daft/io/_json.py @@ -5,7 +5,12 @@ import fsspec from daft.api_annotations import PublicAPI -from daft.daft import FileFormatConfig, JsonSourceConfig +from daft.daft import ( + FileFormatConfig, + JsonSourceConfig, + PythonStorageConfig, + StorageConfig, +) from daft.dataframe import DataFrame from daft.datatype import DataType from daft.io.common import _get_tabular_files_scan @@ -40,5 +45,6 @@ def read_json( json_config = JsonSourceConfig() file_format_config = FileFormatConfig.from_json_config(json_config) - builder = _get_tabular_files_scan(path, schema_hints, file_format_config, fs) + storage_config = StorageConfig.python(PythonStorageConfig(fs)) + builder = _get_tabular_files_scan(path, schema_hints, file_format_config, storage_config=storage_config) return DataFrame(builder) diff --git a/daft/io/_parquet.py b/daft/io/_parquet.py index 621e54f202..1e6558dd9b 100644 --- a/daft/io/_parquet.py +++ b/daft/io/_parquet.py @@ -5,7 +5,13 @@ import fsspec from daft.api_annotations import PublicAPI -from daft.daft import FileFormatConfig, ParquetSourceConfig +from daft.daft import ( + FileFormatConfig, + NativeStorageConfig, + ParquetSourceConfig, + PythonStorageConfig, + StorageConfig, +) from daft.dataframe import DataFrame from daft.datatype import DataType from daft.io.common import _get_tabular_files_scan @@ -47,8 +53,11 @@ def read_parquet( if isinstance(path, list) and len(path) == 0: raise ValueError(f"Cannot read DataFrame from from empty list of Parquet filepaths") - parquet_config = ParquetSourceConfig(use_native_downloader=use_native_downloader, io_config=io_config) - file_format_config = FileFormatConfig.from_parquet_config(parquet_config) + file_format_config = FileFormatConfig.from_parquet_config(ParquetSourceConfig()) + if use_native_downloader: + storage_config = StorageConfig.native(NativeStorageConfig(io_config)) + else: + storage_config = StorageConfig.python(PythonStorageConfig(fs)) - builder = _get_tabular_files_scan(path, schema_hints, file_format_config, fs) + builder = _get_tabular_files_scan(path, schema_hints, file_format_config, storage_config=storage_config, fs=fs) return DataFrame(builder) diff --git a/daft/io/common.py b/daft/io/common.py index 03dd25471d..0027b75128 100644 --- a/daft/io/common.py +++ b/daft/io/common.py @@ -1,13 +1,16 @@ from __future__ import annotations -import fsspec +from typing import TYPE_CHECKING from daft.context import get_context -from daft.daft import FileFormatConfig, LogicalPlanBuilder +from daft.daft import FileFormatConfig, LogicalPlanBuilder, StorageConfig from daft.datatype import DataType from daft.logical.builder import LogicalPlanBuilder from daft.logical.schema import Schema +if TYPE_CHECKING: + import fsspec + def _get_schema_from_hints(hints: dict[str, DataType]) -> Schema: if isinstance(hints, dict): @@ -20,20 +23,21 @@ def _get_tabular_files_scan( path: str | list[str], schema_hints: dict[str, DataType] | None, file_format_config: FileFormatConfig, - fs: fsspec.AbstractFileSystem | None, + storage_config: StorageConfig, + fs: fsspec.AbstractFileSystem | None = None, ) -> LogicalPlanBuilder: """Returns a TabularFilesScan LogicalPlan for a given glob filepath.""" paths = path if isinstance(path, list) else [str(path)] schema_hint = _get_schema_from_hints(schema_hints) if schema_hints is not None else None # Glob the path using the Runner runner_io = get_context().runner().runner_io() - file_infos = runner_io.glob_paths_details(paths, file_format_config, fs) + file_infos = runner_io.glob_paths_details(paths, file_format_config, fs=fs, storage_config=storage_config) # Infer schema if no hints provided inferred_or_provided_schema = ( schema_hint if schema_hint is not None - else runner_io.get_schema_from_first_filepath(file_infos, file_format_config, fs) + else runner_io.get_schema_from_first_filepath(file_infos, file_format_config, storage_config) ) # Construct plan builder_cls = get_context().logical_plan_builder_class() @@ -41,6 +45,6 @@ def _get_tabular_files_scan( file_infos=file_infos, schema=inferred_or_provided_schema, file_format_config=file_format_config, - fs=fs, + storage_config=storage_config, ) return builder diff --git a/daft/logical/builder.py b/daft/logical/builder.py index 11b2740936..c1f56cb09e 100644 --- a/daft/logical/builder.py +++ b/daft/logical/builder.py @@ -4,8 +4,6 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING -import fsspec - from daft.daft import ( FileFormat, FileFormatConfig, @@ -14,6 +12,7 @@ PartitionScheme, PartitionSpec, ResourceRequest, + StorageConfig, ) from daft.expressions.expressions import Expression from daft.logical.schema import Schema @@ -84,7 +83,7 @@ def from_tabular_scan( file_infos: FileInfos, schema: Schema, file_format_config: FileFormatConfig, - fs: fsspec.AbstractFileSystem | None, + storage_config: StorageConfig, ) -> LogicalPlanBuilder: pass diff --git a/daft/logical/logical_plan.py b/daft/logical/logical_plan.py index 1e5c94c19e..6a085f9336 100644 --- a/daft/logical/logical_plan.py +++ b/daft/logical/logical_plan.py @@ -7,8 +7,6 @@ from pprint import pformat from typing import TYPE_CHECKING, Any, Generic, TypeVar -import fsspec - from daft.context import get_context from daft.daft import ( FileFormat, @@ -18,6 +16,7 @@ PartitionScheme, PartitionSpec, ResourceRequest, + StorageConfig, ) from daft.datatype import DataType from daft.errors import ExpressionTypeError @@ -116,7 +115,7 @@ def from_tabular_scan( file_infos: FileInfos, schema: Schema, file_format_config: FileFormatConfig, - fs: fsspec.AbstractFileSystem | None, + storage_config: StorageConfig, ) -> PyLogicalPlanBuilder: file_infos_table = Table._from_pytable(file_infos.to_table()) partition = LocalPartitionSet({0: file_infos_table}) @@ -132,7 +131,7 @@ def from_tabular_scan( predicate=None, columns=None, file_format_config=file_format_config, - fs=fs, + storage_config=storage_config, filepaths_child=filepath_plan, # WARNING: This is currently hardcoded to be the same number of partitions as rows!! This is because we emit # one partition per filepath. This will change in the future and our logic here should change accordingly. @@ -438,7 +437,7 @@ def __init__( *, schema: Schema, file_format_config: FileFormatConfig, - fs: fsspec.AbstractFileSystem | None, + storage_config: StorageConfig, predicate: ExpressionsProjection | None = None, columns: list[str] | None = None, filepaths_child: LogicalPlan, @@ -465,7 +464,7 @@ def __init__( self._column_names = columns self._columns = self._schema self._file_format_config = file_format_config - self._fs = fs + self._storage_config = storage_config self._limit_rows = limit_rows self._register_child(filepaths_child) @@ -503,7 +502,7 @@ def rebuild(self) -> LogicalPlan: return TabularFilesScan( schema=self.schema(), file_format_config=self._file_format_config, - fs=self._fs, + storage_config=self._storage_config, predicate=self._predicate if self._predicate is not None else None, columns=self._column_names, filepaths_child=child, @@ -514,7 +513,7 @@ def copy_with_new_children(self, new_children: list[LogicalPlan]) -> LogicalPlan return TabularFilesScan( schema=self.schema(), file_format_config=self._file_format_config, - fs=self._fs, + storage_config=self._storage_config, predicate=self._predicate, columns=self._column_names, filepaths_child=new_children[0], diff --git a/daft/logical/optimizer.py b/daft/logical/optimizer.py index d25b3a2389..938178e9cd 100644 --- a/daft/logical/optimizer.py +++ b/daft/logical/optimizer.py @@ -333,7 +333,7 @@ def _push_down_local_limit_into_scan(self, parent: LocalLimit, child: TabularFil predicate=child._predicate, columns=child._column_names, file_format_config=child._file_format_config, - fs=child._fs, + storage_config=child._storage_config, filepaths_child=child._filepaths_child, limit_rows=new_limit_rows, ) @@ -356,7 +356,7 @@ def _push_down_projections_into_scan(self, parent: Projection, child: TabularFil predicate=child._predicate, columns=ordered_required_columns, file_format_config=child._file_format_config, - fs=child._fs, + storage_config=child._storage_config, filepaths_child=child._filepaths_child, ) if any(not e._is_column() for e in parent._projection): diff --git a/daft/logical/rust_logical_plan.py b/daft/logical/rust_logical_plan.py index 1193dc2f9f..83860aeb27 100644 --- a/daft/logical/rust_logical_plan.py +++ b/daft/logical/rust_logical_plan.py @@ -3,12 +3,10 @@ import pathlib from typing import TYPE_CHECKING -import fsspec - from daft import col from daft.daft import CountMode, FileFormat, FileFormatConfig, FileInfos, JoinType from daft.daft import LogicalPlanBuilder as _LogicalPlanBuilder -from daft.daft import PartitionScheme, PartitionSpec, ResourceRequest +from daft.daft import PartitionScheme, PartitionSpec, ResourceRequest, StorageConfig from daft.expressions.expressions import Expression from daft.logical.builder import LogicalPlanBuilder from daft.logical.schema import Schema @@ -63,11 +61,9 @@ def from_tabular_scan( file_infos: FileInfos, schema: Schema, file_format_config: FileFormatConfig, - fs: fsspec.AbstractFileSystem | None, + storage_config: StorageConfig, ) -> RustLogicalPlanBuilder: - if fs is not None: - raise ValueError("fsspec filesystems not supported for Rust query planner.") - builder = _LogicalPlanBuilder.table_scan(file_infos, schema._schema, file_format_config) + builder = _LogicalPlanBuilder.table_scan(file_infos, schema._schema, file_format_config, storage_config) return cls(builder) def project( diff --git a/daft/runners/pyrunner.py b/daft/runners/pyrunner.py index e41878a6b3..1ad9dcd7e5 100644 --- a/daft/runners/pyrunner.py +++ b/daft/runners/pyrunner.py @@ -3,13 +3,18 @@ import multiprocessing from concurrent import futures from dataclasses import dataclass -from typing import Iterable, Iterator +from typing import TYPE_CHECKING, Iterable, Iterator -import fsspec import psutil from loguru import logger -from daft.daft import FileFormatConfig, FileInfos, ResourceRequest +from daft.daft import ( + FileFormatConfig, + FileInfos, + PythonStorageConfig, + ResourceRequest, + StorageConfig, +) from daft.execution import physical_plan from daft.execution.execution_step import Instruction, MaterializedResult, PartitionTask from daft.filesystem import get_filesystem_from_path, glob_path_with_stats @@ -27,6 +32,9 @@ from daft.runners.runner import Runner from daft.table import Table +if TYPE_CHECKING: + import fsspec + @dataclass class LocalPartitionSet(PartitionSet[Table]): @@ -73,13 +81,19 @@ def glob_paths_details( source_paths: list[str], file_format_config: FileFormatConfig | None = None, fs: fsspec.AbstractFileSystem | None = None, + storage_config: StorageConfig | None = None, ) -> FileInfos: + if fs is None and storage_config is not None: + config = storage_config.config + if isinstance(config, PythonStorageConfig): + fs = config.fs file_infos = FileInfos() + file_format = file_format_config.file_format() if file_format_config is not None else None for source_path in source_paths: if fs is None: fs = get_filesystem_from_path(source_path) - path_file_infos = glob_path_with_stats(source_path, file_format_config, fs) + path_file_infos = glob_path_with_stats(source_path, file_format, fs, storage_config) if len(path_file_infos) == 0: raise FileNotFoundError(f"No files found at {source_path}") @@ -92,12 +106,12 @@ def get_schema_from_first_filepath( self, file_infos: FileInfos, file_format_config: FileFormatConfig, - fs: fsspec.AbstractFileSystem | None, + storage_config: StorageConfig, ) -> Schema: if len(file_infos) == 0: raise ValueError("No files to get schema from") # Naively retrieve the first filepath in the PartitionSet - return runner_io.sample_schema(file_infos[0].file_path, file_format_config, fs) + return runner_io.sample_schema(file_infos[0].file_path, file_format_config, storage_config) class PyRunner(Runner[Table]): diff --git a/daft/runners/ray_runner.py b/daft/runners/ray_runner.py index b3ff3b16d0..a980071959 100644 --- a/daft/runners/ray_runner.py +++ b/daft/runners/ray_runner.py @@ -8,7 +8,6 @@ from queue import Queue from typing import TYPE_CHECKING, Any, Iterable, Iterator -import fsspec import pyarrow as pa from loguru import logger @@ -23,7 +22,13 @@ ) raise -from daft.daft import FileFormatConfig, FileInfos, ResourceRequest +from daft.daft import ( + FileFormatConfig, + FileInfos, + PythonStorageConfig, + ResourceRequest, + StorageConfig, +) from daft.datatype import DataType from daft.execution.execution_step import ( FanoutInstruction, @@ -49,6 +54,7 @@ if TYPE_CHECKING: import dask + import fsspec import pandas as pd from ray.data.block import Block as RayDatasetBlock from ray.data.dataset import Dataset as RayDataset @@ -69,13 +75,18 @@ def _glob_path_into_file_infos( paths: list[str], file_format_config: FileFormatConfig | None, fs: fsspec.AbstractFileSystem | None, + storage_config: StorageConfig | None, ) -> Table: + if fs is None and storage_config is not None: + config = storage_config.config + if isinstance(config, PythonStorageConfig): + fs = config.fs file_infos = FileInfos() for path in paths: if fs is None: fs = get_filesystem_from_path(path) - path_file_infos = glob_path_with_stats(path, file_format_config, fs) + path_file_infos = glob_path_with_stats(path, file_format_config, fs, storage_config) if len(path_file_infos) == 0: raise FileNotFoundError(f"No files found at {path}") file_infos.extend(path_file_infos) @@ -123,11 +134,11 @@ def remote_len_partition(p: Table) -> int: def sample_schema_from_filepath( first_file_path: str, file_format_config: FileFormatConfig, - fs: fsspec.AbstractFileSystem | None, + storage_config: StorageConfig | None, ) -> Schema: """Ray remote function to run schema sampling on top of a Table containing a single filepath""" # Currently just samples the Schema from the first file - return runner_io.sample_schema(first_file_path, file_format_config, fs) + return runner_io.sample_schema(first_file_path, file_format_config, storage_config) @dataclass @@ -208,15 +219,18 @@ def glob_paths_details( source_paths: list[str], file_format_config: FileFormatConfig | None = None, fs: fsspec.AbstractFileSystem | None = None, + storage_config: StorageConfig | None = None, ) -> FileInfos: # Synchronously fetch the file infos, for now. - return ray.get(_glob_path_into_file_infos.remote(source_paths, file_format_config, fs=fs)) + return ray.get( + _glob_path_into_file_infos.remote(source_paths, file_format_config, fs=fs, storage_config=storage_config) + ) def get_schema_from_first_filepath( self, file_infos: FileInfos, file_format_config: FileFormatConfig, - fs: fsspec.AbstractFileSystem | None, + storage_config: StorageConfig | None, ) -> Schema: if len(file_infos) == 0: raise ValueError("No files to get schema from") @@ -226,7 +240,7 @@ def get_schema_from_first_filepath( sample_schema_from_filepath.remote( first_path, file_format_config, - fs, + storage_config, ) ) diff --git a/daft/runners/runner_io.py b/daft/runners/runner_io.py index 45be394dbb..eff6c57caa 100644 --- a/daft/runners/runner_io.py +++ b/daft/runners/runner_io.py @@ -1,9 +1,7 @@ from __future__ import annotations from abc import abstractmethod -from typing import TypeVar - -import fsspec +from typing import TYPE_CHECKING, TypeVar from daft.daft import ( CsvSourceConfig, @@ -12,12 +10,15 @@ FileInfos, JsonSourceConfig, ParquetSourceConfig, + StorageConfig, ) -from daft.filesystem import get_filesystem_from_path from daft.logical.schema import Schema from daft.runners.partitioning import TableParseCSVOptions from daft.table import schema_inference +if TYPE_CHECKING: + import fsspec + PartitionT = TypeVar("PartitionT") @@ -33,6 +34,7 @@ def glob_paths_details( source_path: list[str], file_format_config: FileFormatConfig | None = None, fs: fsspec.AbstractFileSystem | None = None, + storage_config: StorageConfig | None = None, ) -> FileInfos: """Globs the specified filepath to construct a FileInfos object containing file and dir metadata. @@ -49,7 +51,7 @@ def get_schema_from_first_filepath( self, file_info: FileInfos, file_format_config: FileFormatConfig, - fs: fsspec.AbstractFileSystem | None, + storage_config: StorageConfig, ) -> Schema: raise NotImplementedError() @@ -57,19 +59,16 @@ def get_schema_from_first_filepath( def sample_schema( filepath: str, file_format_config: FileFormatConfig, - fs: fsspec.AbstractFileSystem | None, + storage_config: StorageConfig, ) -> Schema: """Helper method that samples a schema from the specified source""" - if fs is None: - fs = get_filesystem_from_path(filepath) - file_format = file_format_config.file_format() config = file_format_config.config if file_format == FileFormat.Csv: assert isinstance(config, CsvSourceConfig) return schema_inference.from_csv( file=filepath, - fs=fs, + storage_config=storage_config, csv_options=TableParseCSVOptions( delimiter=config.delimiter, header_index=0 if config.has_headers else None, @@ -79,15 +78,13 @@ def sample_schema( assert isinstance(config, JsonSourceConfig) return schema_inference.from_json( file=filepath, - fs=fs, + storage_config=storage_config, ) elif file_format == FileFormat.Parquet: assert isinstance(config, ParquetSourceConfig) return schema_inference.from_parquet( file=filepath, - fs=fs, - io_config=config.io_config, - use_native_downloader=config.use_native_downloader, + storage_config=storage_config, ) else: raise NotImplementedError(f"Schema inference for {file_format} not implemented") diff --git a/daft/table/schema_inference.py b/daft/table/schema_inference.py index 13bbea3ce1..c9b06e651b 100644 --- a/daft/table/schema_inference.py +++ b/daft/table/schema_inference.py @@ -1,12 +1,12 @@ from __future__ import annotations import pathlib -from typing import TYPE_CHECKING import pyarrow.csv as pacsv import pyarrow.json as pajson import pyarrow.parquet as papq +from daft.daft import NativeStorageConfig, PythonStorageConfig, StorageConfig from daft.datatype import DataType from daft.filesystem import _resolve_paths_and_filesystem from daft.logical.schema import Schema @@ -14,15 +14,10 @@ from daft.table import Table from daft.table.table_io import FileInput, _open_stream -if TYPE_CHECKING: - import fsspec - - from daft.io import IOConfig - def from_csv( file: FileInput, - fs: fsspec.AbstractFileSystem | None = None, + storage_config: StorageConfig | None = None, csv_options: TableParseCSVOptions = TableParseCSVOptions(), ) -> Schema: """Infers a Schema from a CSV file @@ -35,10 +30,15 @@ def from_csv( Returns: Schema: Inferred Schema from the CSV """ - # Have PyArrow generate the column names if user specifies that there are no headers pyarrow_autogenerate_column_names = csv_options.header_index is None + if storage_config is not None: + config = storage_config.config + assert isinstance(config, PythonStorageConfig) + fs = config.fs + else: + fs = None with _open_stream(file, fs) as f: table = pacsv.read_csv( f, @@ -55,7 +55,7 @@ def from_csv( def from_json( file: FileInput, - fs: fsspec.AbstractFileSystem | None = None, + storage_config: StorageConfig | None = None, ) -> Schema: """Reads a Schema from a JSON file @@ -66,6 +66,12 @@ def from_json( Returns: Schema: Inferred Schema from the JSON """ + if storage_config is not None: + config = storage_config.config + assert isinstance(config, PythonStorageConfig) + fs = config.fs + else: + fs = None with _open_stream(file, fs) as f: table = pajson.read_json(f) @@ -74,14 +80,23 @@ def from_json( def from_parquet( file: FileInput, - fs: fsspec.AbstractFileSystem | None = None, - io_config: IOConfig | None = None, - use_native_downloader: bool = False, + storage_config: StorageConfig | None = None, ) -> Schema: """Infers a Schema from a Parquet file""" - if use_native_downloader: - assert isinstance(file, (str, pathlib.Path)) - return Schema.from_parquet(str(file), io_config=io_config) + if storage_config is not None: + config = storage_config.config + if isinstance(config, NativeStorageConfig): + assert isinstance( + file, (str, pathlib.Path) + ), "Native downloader only works on string inputs to read_parquet" + assert isinstance(file, (str, pathlib.Path)) + io_config = config.io_config + return Schema.from_parquet(str(file), io_config=io_config) + + assert isinstance(config, PythonStorageConfig) + fs = config.fs + else: + fs = None if not isinstance(file, (str, pathlib.Path)): # BytesIO path. diff --git a/daft/table/table_io.py b/daft/table/table_io.py index b68b8ba23f..fe94dc88a4 100644 --- a/daft/table/table_io.py +++ b/daft/table/table_io.py @@ -3,7 +3,7 @@ import contextlib import pathlib from collections.abc import Generator -from typing import IO, TYPE_CHECKING, Union +from typing import IO, Union from uuid import uuid4 import fsspec @@ -14,6 +14,7 @@ from pyarrow import parquet as papq from pyarrow.fs import FileSystem +from daft.daft import NativeStorageConfig, PythonStorageConfig, StorageConfig from daft.expressions import ExpressionsProjection from daft.filesystem import _resolve_paths_and_filesystem from daft.logical.schema import Schema @@ -24,9 +25,6 @@ ) from daft.table import Table -if TYPE_CHECKING: - from daft.io import IOConfig - FileInput = Union[pathlib.Path, str, IO[bytes]] @@ -69,7 +67,7 @@ def _cast_table_to_schema(table: Table, read_options: TableReadOptions, schema: def read_json( file: FileInput, schema: Schema, - fs: fsspec.AbstractFileSystem | None = None, + storage_config: StorageConfig | None = None, read_options: TableReadOptions = TableReadOptions(), ) -> Table: """Reads a Table from a JSON file @@ -83,6 +81,12 @@ def read_json( Returns: Table: Parsed Table from JSON """ + if storage_config is not None: + config = storage_config.config + assert isinstance(config, PythonStorageConfig) + fs = config.fs + else: + fs = None with _open_stream(file, fs) as f: table = pajson.read_json(f) @@ -99,11 +103,9 @@ def read_json( def read_parquet( file: FileInput, schema: Schema, - fs: fsspec.AbstractFileSystem | None = None, + storage_config: StorageConfig | None = None, read_options: TableReadOptions = TableReadOptions(), parquet_options: TableParseParquetOptions = TableParseParquetOptions(), - io_config: IOConfig | None = None, - use_native_downloader: bool = False, ) -> Table: """Reads a Table from a Parquet file @@ -116,16 +118,25 @@ def read_parquet( Returns: Table: Parsed Table from Parquet """ - if use_native_downloader: - assert isinstance(file, (str, pathlib.Path)), "Native downloader only works on string inputs to read_parquet" - tbl = Table.read_parquet( - str(file), - columns=read_options.column_names, - num_rows=read_options.num_rows, - io_config=io_config, - coerce_int96_timestamp_unit=parquet_options.coerce_int96_timestamp_unit, - ) - return _cast_table_to_schema(tbl, read_options=read_options, schema=schema) + if storage_config is not None: + config = storage_config.config + if isinstance(config, NativeStorageConfig): + assert isinstance( + file, (str, pathlib.Path) + ), "Native downloader only works on string inputs to read_parquet" + tbl = Table.read_parquet( + str(file), + columns=read_options.column_names, + num_rows=read_options.num_rows, + io_config=config.io_config, + coerce_int96_timestamp_unit=parquet_options.coerce_int96_timestamp_unit, + ) + return _cast_table_to_schema(tbl, read_options=read_options, schema=schema) + + assert isinstance(config, PythonStorageConfig) + fs = config.fs + else: + fs = None f: IO if not isinstance(file, (str, pathlib.Path)): @@ -167,7 +178,7 @@ def read_parquet( def read_csv( file: FileInput, schema: Schema, - fs: fsspec.AbstractFileSystem | None = None, + storage_config: StorageConfig | None = None, csv_options: TableParseCSVOptions = TableParseCSVOptions(), read_options: TableReadOptions = TableReadOptions(), ) -> Table: @@ -184,6 +195,12 @@ def read_csv( Returns: Table: Parsed Table from CSV """ + if storage_config is not None: + config = storage_config.config + assert isinstance(config, PythonStorageConfig) + fs = config.fs + else: + fs = None with _open_stream(file, fs) as f: table = pacsv.read_csv( diff --git a/src/daft-plan/src/builder.rs b/src/daft-plan/src/builder.rs index 591abb2a19..6815ea9674 100644 --- a/src/daft-plan/src/builder.rs +++ b/src/daft-plan/src/builder.rs @@ -8,7 +8,7 @@ use crate::{ sink_info::{OutputFileInfo, SinkInfo}, source_info::{ ExternalInfo as ExternalSourceInfo, FileFormatConfig, FileInfos as InputFileInfos, - SourceInfo, + PyStorageConfig, SourceInfo, StorageConfig, }, FileFormat, JoinType, PartitionScheme, PartitionSpec, PhysicalPlanScheduler, ResourceRequest, }; @@ -64,14 +64,16 @@ impl LogicalPlanBuilder { file_infos: InputFileInfos, schema: Arc, file_format_config: Arc, + storage_config: Arc, ) -> DaftResult { - Self::table_scan_with_limit(file_infos, schema, file_format_config, None) + Self::table_scan_with_limit(file_infos, schema, file_format_config, storage_config, None) } pub fn table_scan_with_limit( file_infos: InputFileInfos, schema: Arc, file_format_config: Arc, + storage_config: Arc, limit: Option, ) -> DaftResult { let num_partitions = file_infos.len(); @@ -79,6 +81,7 @@ impl LogicalPlanBuilder { schema.clone(), file_infos.into(), file_format_config, + storage_config, )); let partition_spec = PartitionSpec::new_internal(PartitionScheme::Unknown, num_partitions, None); @@ -269,11 +272,15 @@ impl PyLogicalPlanBuilder { file_infos: InputFileInfos, schema: PySchema, file_format_config: PyFileFormatConfig, + storage_config: PyStorageConfig, ) -> PyResult { - Ok( - LogicalPlanBuilder::table_scan(file_infos, schema.into(), file_format_config.into())? - .into(), - ) + Ok(LogicalPlanBuilder::table_scan( + file_infos, + schema.into(), + file_format_config.into(), + storage_config.into(), + )? + .into()) } pub fn project( diff --git a/src/daft-plan/src/lib.rs b/src/daft-plan/src/lib.rs index 0e81647a72..cf42fd592d 100644 --- a/src/daft-plan/src/lib.rs +++ b/src/daft-plan/src/lib.rs @@ -23,12 +23,12 @@ pub use partitioning::{PartitionScheme, PartitionSpec}; pub use physical_plan::PhysicalPlanScheduler; pub use resource_request::ResourceRequest; pub use source_info::{ - CsvSourceConfig, FileFormat, FileInfo, FileInfos, JsonSourceConfig, ParquetSourceConfig, - PyFileFormatConfig, + CsvSourceConfig, FileFormat, FileInfo, FileInfos, JsonSourceConfig, NativeStorageConfig, + ParquetSourceConfig, PyFileFormatConfig, PyStorageConfig, }; #[cfg(feature = "python")] -use pyo3::prelude::*; +use {pyo3::prelude::*, source_info::PythonStorageConfig}; #[cfg(feature = "python")] pub fn register_modules(_py: Python, parent: &PyModule) -> PyResult<()> { @@ -45,6 +45,9 @@ pub fn register_modules(_py: Python, parent: &PyModule) -> PyResult<()> { parent.add_class::()?; parent.add_class::()?; parent.add_class::()?; + parent.add_class::()?; + parent.add_class::()?; + parent.add_class::()?; Ok(()) } diff --git a/src/daft-plan/src/ops/source.rs b/src/daft-plan/src/ops/source.rs index f5b8454d4a..684cc8c42e 100644 --- a/src/daft-plan/src/ops/source.rs +++ b/src/daft-plan/src/ops/source.rs @@ -69,6 +69,7 @@ impl Source { source_schema, file_infos, file_format_config, + storage_config, }) => { res.push(format!("Source: {:?}", file_format_config.var_name())); for fp in file_infos.file_paths.iter() { @@ -76,6 +77,7 @@ impl Source { } res.push(format!("File schema = {}", source_schema.short_string())); res.push(format!("Format-specific config = {:?}", file_format_config)); + res.push(format!("Storage config = {:?}", storage_config)); } #[cfg(feature = "python")] SourceInfo::InMemoryInfo(_) => {} diff --git a/src/daft-plan/src/physical_plan.rs b/src/daft-plan/src/physical_plan.rs index 620b3518c6..2c87e143f5 100644 --- a/src/daft-plan/src/physical_plan.rs +++ b/src/daft-plan/src/physical_plan.rs @@ -3,7 +3,8 @@ use { crate::{ sink_info::OutputFileInfo, source_info::{ - ExternalInfo, FileFormat, FileFormatConfig, FileInfos, InMemoryInfo, PyFileFormatConfig, + ExternalInfo, FileFormat, FileFormatConfig, FileInfos, InMemoryInfo, + PyFileFormatConfig, PyStorageConfig, StorageConfig, }, }, daft_core::python::schema::PySchema, @@ -124,6 +125,7 @@ fn tabular_scan( projection_schema: &SchemaRef, file_infos: &Arc, file_format_config: &Arc, + storage_config: &Arc, limit: &Option, ) -> PyResult { let columns_to_read = projection_schema @@ -140,6 +142,7 @@ fn tabular_scan( columns_to_read, file_infos.to_table()?, PyFileFormatConfig::from(file_format_config.clone()), + PyStorageConfig::from(storage_config.clone()), *limit, ))?; Ok(py_iter.into()) @@ -203,6 +206,7 @@ impl PhysicalPlan { source_schema, file_infos, file_format_config, + storage_config, .. }, limit, @@ -213,6 +217,7 @@ impl PhysicalPlan { projection_schema, file_infos, file_format_config, + storage_config, limit, ), PhysicalPlan::TabularScanCsv(TabularScanCsv { @@ -222,6 +227,7 @@ impl PhysicalPlan { source_schema, file_infos, file_format_config, + storage_config, .. }, limit, @@ -232,6 +238,7 @@ impl PhysicalPlan { projection_schema, file_infos, file_format_config, + storage_config, limit, ), PhysicalPlan::TabularScanJson(TabularScanJson { @@ -241,6 +248,7 @@ impl PhysicalPlan { source_schema, file_infos, file_format_config, + storage_config, .. }, limit, @@ -251,6 +259,7 @@ impl PhysicalPlan { projection_schema, file_infos, file_format_config, + storage_config, limit, ), PhysicalPlan::Project(Project { diff --git a/src/daft-plan/src/source_info.rs b/src/daft-plan/src/source_info.rs index 4e75150389..4dc248a802 100644 --- a/src/daft-plan/src/source_info.rs +++ b/src/daft-plan/src/source_info.rs @@ -1,7 +1,4 @@ -use std::{ - hash::{Hash, Hasher}, - sync::Arc, -}; +use std::{hash::Hash, sync::Arc}; use arrow2::array::Array; use common_error::DaftResult; @@ -15,7 +12,7 @@ use daft_table::Table; #[cfg(feature = "python")] use { - common_io_config::python::IOConfig as PyIOConfig, + common_io_config::python, daft_table::python::PyTable, pyo3::{ exceptions::{PyKeyError, PyValueError}, @@ -25,11 +22,15 @@ use { types::{PyBytes, PyTuple}, IntoPy, PyObject, PyResult, Python, ToPyObject, }, + serde::{ + de::{Error as DeError, Visitor}, + ser::Error as SerError, + Deserializer, Serializer, + }, + std::{fmt, hash::Hasher}, }; -use serde::de::{Error as DeError, Visitor}; -use serde::{ser::Error as SerError, Deserialize, Deserializer, Serialize, Serializer}; -use std::fmt; +use serde::{Deserialize, Serialize}; #[derive(Debug, PartialEq, Eq, Hash)] pub enum SourceInfo { @@ -99,6 +100,55 @@ where d.deserialize_bytes(PyObjectVisitor) } +#[derive(Serialize)] +#[serde(transparent)] +#[cfg(feature = "python")] +struct PyObjSerdeWrapper<'a>(#[serde(serialize_with = "serialize_py_object")] &'a PyObject); + +#[cfg(feature = "python")] +fn serialize_py_object_optional(obj: &Option, s: S) -> Result +where + S: Serializer, +{ + match obj { + Some(obj) => s.serialize_some(&PyObjSerdeWrapper(obj)), + None => s.serialize_none(), + } +} + +struct OptPyObjectVisitor; + +#[cfg(feature = "python")] +impl<'de> Visitor<'de> for OptPyObjectVisitor { + type Value = Option; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a byte array containing the pickled partition bytes") + } + + fn visit_some(self, deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserialize_py_object(deserializer).map(Some) + } + + fn visit_none(self) -> Result + where + E: DeError, + { + Ok(None) + } +} + +#[cfg(feature = "python")] +fn deserialize_py_object_optional<'de, D>(d: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + d.deserialize_option(OptPyObjectVisitor) +} + #[cfg(feature = "python")] #[derive(Debug, Clone, Serialize, Deserialize)] pub struct InMemoryInfo { @@ -157,6 +207,7 @@ pub struct ExternalInfo { pub source_schema: SchemaRef, pub file_infos: Arc, pub file_format_config: Arc, + pub storage_config: Arc, } impl ExternalInfo { @@ -164,11 +215,159 @@ impl ExternalInfo { source_schema: SchemaRef, file_infos: Arc, file_format_config: Arc, + storage_config: Arc, ) -> Self { Self { source_schema, file_infos, file_format_config, + storage_config, + } + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(transparent)] +#[cfg_attr( + feature = "python", + pyclass(module = "daft.daft", name = "StorageConfig") +)] +pub struct PyStorageConfig(Arc); + +#[cfg(feature = "python")] +#[pymethods] +impl PyStorageConfig { + #[new] + #[pyo3(signature = (*args))] + pub fn new(args: &PyTuple) -> PyResult { + match args.len() { + // Create dummy inner StorageConfig, to be overridden by __setstate__. + 0 => Ok(Arc::new(StorageConfig::Native( + NativeStorageConfig::new_internal(None).into(), + )) + .into()), + _ => Err(PyValueError::new_err(format!( + "expected no arguments to make new PyStorageConfig, got : {}", + args.len() + ))), + } + } + #[staticmethod] + fn native(config: NativeStorageConfig) -> Self { + Self(Arc::new(StorageConfig::Native(config.into()))) + } + + #[staticmethod] + fn python(config: PythonStorageConfig) -> Self { + Self(Arc::new(StorageConfig::Python(config))) + } + + #[getter] + fn get_config(&self, py: Python) -> PyObject { + use StorageConfig::*; + + match self.0.as_ref() { + Native(config) => config.as_ref().clone().into_py(py), + Python(config) => config.clone().into_py(py), + } + } +} + +impl_bincode_py_state_serialization!(PyStorageConfig); + +impl From for Arc { + fn from(value: PyStorageConfig) -> Self { + value.0 + } +} + +impl From> for PyStorageConfig { + fn from(value: Arc) -> Self { + Self(value) + } +} + +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)] +pub enum StorageConfig { + Native(Arc), + #[cfg(feature = "python")] + Python(PythonStorageConfig), +} + +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)] +#[cfg_attr(feature = "python", pyclass(module = "daft.daft"))] +pub struct NativeStorageConfig { + pub io_config: Option, +} + +impl NativeStorageConfig { + pub fn new_internal(io_config: Option) -> Self { + Self { io_config } + } +} + +#[cfg(feature = "python")] +#[pymethods] +impl NativeStorageConfig { + #[new] + pub fn new(io_config: Option) -> Self { + Self::new_internal(io_config.map(|c| c.config)) + } + + #[getter] + pub fn io_config(&self) -> Option { + self.io_config.clone().map(|c| c.into()) + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[cfg(feature = "python")] +#[cfg_attr(feature = "python", pyclass(module = "daft.daft", get_all))] +pub struct PythonStorageConfig { + #[serde( + serialize_with = "serialize_py_object_optional", + deserialize_with = "deserialize_py_object_optional", + default + )] + pub fs: Option, +} + +#[cfg(feature = "python")] +#[pymethods] +impl PythonStorageConfig { + #[new] + pub fn new(fs: Option) -> Self { + Self { fs } + } +} + +#[cfg(feature = "python")] +impl PartialEq for PythonStorageConfig { + fn eq(&self, other: &Self) -> bool { + Python::with_gil(|py| match (&self.fs, &other.fs) { + (Some(self_fs), Some(other_fs)) => self_fs.as_ref(py).eq(other_fs.as_ref(py)).unwrap(), + (None, None) => true, + _ => false, + }) + } +} + +#[cfg(feature = "python")] +impl Eq for PythonStorageConfig {} + +#[cfg(feature = "python")] +impl Hash for PythonStorageConfig { + fn hash(&self, state: &mut H) { + let py_obj_hash = self + .fs + .as_ref() + .map(|fs| Python::with_gil(|py| fs.as_ref(py).hash())) + .transpose(); + match py_obj_hash { + // If Python object is None OR is hashable, hash the Option of the Python-side hash. + Ok(py_obj_hash) => py_obj_hash.hash(state), + // Fall back to hashing the pickled Python object. + Err(_) => Some(serde_json::to_vec(self).unwrap()).hash(state), } } } @@ -402,30 +601,14 @@ impl FileFormatConfig { #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)] #[cfg_attr(feature = "python", pyclass(module = "daft.daft"))] -pub struct ParquetSourceConfig { - pub use_native_downloader: bool, - pub io_config: Box>, -} +pub struct ParquetSourceConfig; #[cfg(feature = "python")] #[pymethods] impl ParquetSourceConfig { #[new] - fn new(use_native_downloader: bool, io_config: Option) -> Self { - Self { - use_native_downloader, - io_config: io_config.map(|c| c.config).into(), - } - } - - #[getter] - pub fn get_use_native_downloader(&self) -> PyResult { - Ok(self.use_native_downloader) - } - - #[getter] - fn get_io_config(&self) -> PyResult> { - Ok(self.io_config.clone().map(|c| c.into())) + fn new() -> Self { + Self {} } } diff --git a/src/daft-plan/src/test/mod.rs b/src/daft-plan/src/test/mod.rs index 428a0cb029..c08c12008c 100644 --- a/src/daft-plan/src/test/mod.rs +++ b/src/daft-plan/src/test/mod.rs @@ -4,8 +4,8 @@ use daft_core::{datatypes::Field, schema::Schema}; use crate::{ builder::LogicalPlanBuilder, - source_info::{FileFormatConfig, FileInfos}, - JsonSourceConfig, + source_info::{FileFormatConfig, FileInfos, StorageConfig}, + JsonSourceConfig, NativeStorageConfig, }; /// Create a dummy scan node containing the provided fields in its schema. @@ -15,6 +15,7 @@ pub fn dummy_scan_node(fields: Vec) -> LogicalPlanBuilder { FileInfos::new_internal(vec!["/foo".to_string()], vec![None], vec![None]).into(), schema, FileFormatConfig::Json(JsonSourceConfig {}).into(), + StorageConfig::Native(NativeStorageConfig::new_internal(None).into()).into(), ) .unwrap() } @@ -26,6 +27,7 @@ pub fn dummy_scan_node_with_limit(fields: Vec, limit: Option) -> L FileInfos::new_internal(vec!["/foo".to_string()], vec![None], vec![None]).into(), schema, FileFormatConfig::Json(JsonSourceConfig {}).into(), + StorageConfig::Native(NativeStorageConfig::new_internal(None).into()).into(), limit, ) .unwrap() diff --git a/tests/dataframe/test_creation.py b/tests/dataframe/test_creation.py index 06b752197d..ab2207bb2f 100644 --- a/tests/dataframe/test_creation.py +++ b/tests/dataframe/test_creation.py @@ -384,8 +384,8 @@ def test_create_dataframe_multiple_csvs(valid_data: list[dict[str, float]]) -> N @pytest.mark.skipif( - (get_context().runner_config.name not in {"py"}) or get_context().use_rust_planner, - reason="requires PyRunner and old query planner to be in use", + get_context().runner_config.name not in {"py"}, + reason="requires PyRunner to be in use", ) def test_create_dataframe_csv_custom_fs(valid_data: list[dict[str, float]]) -> None: with tempfile.NamedTemporaryFile("w") as f: @@ -557,8 +557,8 @@ def test_create_dataframe_multiple_jsons(valid_data: list[dict[str, float]]) -> @pytest.mark.skipif( - (get_context().runner_config.name not in {"py"}) or get_context().use_rust_planner, - reason="requires PyRunner and old query planner to be in use", + get_context().runner_config.name not in {"py"}, + reason="requires PyRunner to be in use", ) def test_create_dataframe_json_custom_fs(valid_data: list[dict[str, float]]) -> None: with tempfile.NamedTemporaryFile("w") as f: @@ -693,8 +693,8 @@ def test_create_dataframe_multiple_parquets(valid_data: list[dict[str, float]], @pytest.mark.skipif( - (get_context().runner_config.name not in {"py"}) or get_context().use_rust_planner, - reason="requires PyRunner and old query planner to be in use", + get_context().runner_config.name not in {"py"}, + reason="requires PyRunner to be in use", ) def test_create_dataframe_parquet_custom_fs(valid_data: list[dict[str, float]]) -> None: with tempfile.NamedTemporaryFile("w") as f: diff --git a/tests/integration/io/parquet/test_reads_public_data.py b/tests/integration/io/parquet/test_reads_public_data.py index 0a6d001fac..d552b9d476 100644 --- a/tests/integration/io/parquet/test_reads_public_data.py +++ b/tests/integration/io/parquet/test_reads_public_data.py @@ -197,9 +197,6 @@ def read_parquet_with_pyarrow(path) -> pa.Table: @pytest.mark.integration() -@pytest.mark.skipif( - daft.context.get_context().use_rust_planner, reason="Custom fsspec filesystems not supported in new query planner" -) @pytest.mark.parametrize( "multithreaded_io", [False, True], @@ -213,9 +210,6 @@ def test_parquet_read_table(parquet_file, public_storage_io_config, multithreade @pytest.mark.integration() -@pytest.mark.skipif( - daft.context.get_context().use_rust_planner, reason="Custom fsspec filesystems not supported in new query planner" -) @pytest.mark.parametrize( "multithreaded_io", [False, True], @@ -233,9 +227,6 @@ def test_parquet_read_table_bulk(parquet_file, public_storage_io_config, multith @pytest.mark.integration() -@pytest.mark.skipif( - daft.context.get_context().use_rust_planner, reason="Custom fsspec filesystems not supported in new query planner" -) def test_parquet_read_df(parquet_file, public_storage_io_config): _, url = parquet_file # This is a hack until we remove `fsspec.info`, `fsspec.glob` and `fsspec.glob` from `daft.read_parquet`. diff --git a/tests/integration/io/test_url_download_public_aws_s3.py b/tests/integration/io/test_url_download_public_aws_s3.py index 85817a0ed7..306f4d89a6 100644 --- a/tests/integration/io/test_url_download_public_aws_s3.py +++ b/tests/integration/io/test_url_download_public_aws_s3.py @@ -7,9 +7,6 @@ @pytest.mark.integration() -@pytest.mark.skipif( - daft.context.get_context().use_rust_planner, reason="Custom fsspec filesystems not supported in new query planner" -) def test_url_download_aws_s3_public_bucket_custom_s3fs(small_images_s3_paths): fs = s3fs.S3FileSystem(anon=True) data = {"urls": small_images_s3_paths} @@ -23,9 +20,6 @@ def test_url_download_aws_s3_public_bucket_custom_s3fs(small_images_s3_paths): @pytest.mark.integration() -@pytest.mark.skipif( - daft.context.get_context().use_rust_planner, reason="Custom fsspec filesystems not supported in new query planner" -) def test_url_download_aws_s3_public_bucket_custom_s3fs_wrong_region(small_images_s3_paths): fs = s3fs.S3FileSystem(anon=True) data = {"urls": small_images_s3_paths} diff --git a/tests/integration/io/test_url_download_s3_minio.py b/tests/integration/io/test_url_download_s3_minio.py index a6446d49c1..48e17002c8 100644 --- a/tests/integration/io/test_url_download_s3_minio.py +++ b/tests/integration/io/test_url_download_s3_minio.py @@ -7,9 +7,6 @@ @pytest.mark.integration() -@pytest.mark.skipif( - daft.context.get_context().use_rust_planner, reason="Custom fsspec filesystems not supported in new query planner" -) def test_url_download_minio_custom_s3fs(minio_io_config, minio_image_data_fixture, image_data): urls = minio_image_data_fixture fs = s3fs.S3FileSystem( diff --git a/tests/optimizer/test_pushdown_clauses_into_scan.py b/tests/optimizer/test_pushdown_clauses_into_scan.py index 2105a080ca..2baf1d583f 100644 --- a/tests/optimizer/test_pushdown_clauses_into_scan.py +++ b/tests/optimizer/test_pushdown_clauses_into_scan.py @@ -27,7 +27,7 @@ def test_push_projection_scan_all_cols(valid_data_json_path: str, optimizer): predicate=df_unoptimized_scan._get_current_builder()._plan._predicate, columns=["sepal_length"], file_format_config=df_unoptimized_scan._get_current_builder()._plan._file_format_config, - fs=df_unoptimized_scan._get_current_builder()._plan._fs, + storage_config=df_unoptimized_scan._get_current_builder()._plan._storage_config, filepaths_child=df_unoptimized_scan._get_current_builder()._plan._filepaths_child, ).to_builder() ) @@ -47,7 +47,7 @@ def test_push_projection_scan_all_cols_alias(valid_data_json_path: str, optimize predicate=df_unoptimized_scan._get_current_builder()._plan._predicate, columns=["sepal_length"], file_format_config=df_unoptimized_scan._get_current_builder()._plan._file_format_config, - fs=df_unoptimized_scan._get_current_builder()._plan._fs, + storage_config=df_unoptimized_scan._get_current_builder()._plan._storage_config, filepaths_child=df_unoptimized_scan._get_current_builder()._plan._filepaths_child, ).to_builder() ) @@ -68,7 +68,7 @@ def test_push_projection_scan_some_cols_aliases(valid_data_json_path: str, optim predicate=df_unoptimized_scan._get_current_builder()._plan._predicate, columns=["sepal_length", "sepal_width"], file_format_config=df_unoptimized_scan._get_current_builder()._plan._file_format_config, - fs=df_unoptimized_scan._get_current_builder()._plan._fs, + storage_config=df_unoptimized_scan._get_current_builder()._plan._storage_config, filepaths_child=df_unoptimized_scan._get_current_builder()._plan._filepaths_child, ).to_builder() ) diff --git a/tests/table/table_io/test_parquet.py b/tests/table/table_io/test_parquet.py index ec85766983..0806e62f4c 100644 --- a/tests/table/table_io/test_parquet.py +++ b/tests/table/table_io/test_parquet.py @@ -10,6 +10,7 @@ import pytest import daft +from daft.daft import NativeStorageConfig, PythonStorageConfig, StorageConfig from daft.datatype import DataType, TimeUnit from daft.logical.schema import Schema from daft.runners.partitioning import TableParseParquetOptions, TableReadOptions @@ -19,6 +20,13 @@ PYARROW_GE_13_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) >= (13, 0, 0) +def storage_config_from_use_native_downloader(use_native_downloader: bool) -> StorageConfig: + if use_native_downloader: + return StorageConfig.native(NativeStorageConfig(None)) + else: + return StorageConfig.python(PythonStorageConfig(None)) + + def test_read_input(tmpdir): tmpdir = pathlib.Path(tmpdir) data = pa.Table.from_pydict({"foo": [1, 2, 3]}) @@ -60,11 +68,11 @@ def _parquet_write_helper(data: pa.Table, row_group_size: int = None, papq_write ) @pytest.mark.parametrize("use_native_downloader", [True, False]) def test_parquet_infer_schema(data, expected_dtype, use_native_downloader): - # HACK: Pyarrow 13 changed their schema parsing behavior so we receive DataType.list(..) instead of DataType.list(..) # However, our native downloader still parses DataType.list(..) regardless of PyArrow version if PYARROW_GE_13_0_0 and not use_native_downloader and expected_dtype == DataType.list(DataType.int64()): expected_dtype = DataType.list(DataType.int64()) + storage_config = storage_config_from_use_native_downloader(use_native_downloader) with _parquet_write_helper( pa.Table.from_pydict( @@ -74,7 +82,7 @@ def test_parquet_infer_schema(data, expected_dtype, use_native_downloader): } ) ) as f: - schema = schema_inference.from_parquet(f, use_native_downloader=use_native_downloader) + schema = schema_inference.from_parquet(f, storage_config=storage_config) assert schema == Schema._from_field_name_and_types([("id", DataType.int64()), ("data", expected_dtype)]) @@ -114,7 +122,8 @@ def test_parquet_read_data(data, expected_data_series, use_native_downloader): "data": expected_data_series, } ) - table = table_io.read_parquet(f, schema, use_native_downloader=use_native_downloader) + storage_config = storage_config_from_use_native_downloader(use_native_downloader) + table = table_io.read_parquet(f, schema, storage_config=storage_config) assert table.to_arrow() == expected.to_arrow(), f"Expected:\n{expected}\n\nReceived:\n{table}" @@ -137,8 +146,9 @@ def test_parquet_read_data_limit_rows(row_group_size, use_native_downloader): "data": [1, 2], } ) + storage_config = storage_config_from_use_native_downloader(use_native_downloader) table = table_io.read_parquet( - f, schema, read_options=TableReadOptions(num_rows=2), use_native_downloader=use_native_downloader + f, schema, read_options=TableReadOptions(num_rows=2), storage_config=storage_config ) assert table.to_arrow() == expected.to_arrow(), f"Expected:\n{expected}\n\nReceived:\n{table}" @@ -159,8 +169,9 @@ def test_parquet_read_data_select_columns(use_native_downloader): "data": [1, 2, None], } ) + storage_config = storage_config_from_use_native_downloader(use_native_downloader) table = table_io.read_parquet( - f, schema, read_options=TableReadOptions(column_names=["data"]), use_native_downloader=use_native_downloader + f, schema, read_options=TableReadOptions(column_names=["data"]), storage_config=storage_config ) assert table.to_arrow() == expected.to_arrow(), f"Expected:\n{expected}\n\nReceived:\n{table}" @@ -199,11 +210,12 @@ def test_parquet_read_int96_timestamps(use_deprecated_int96_timestamps, use_nati ) as f: schema = Schema._from_field_name_and_types(schema) expected = Table.from_pydict(data) + storage_config = storage_config_from_use_native_downloader(use_native_downloader) table = table_io.read_parquet( f, schema, read_options=TableReadOptions(column_names=schema.column_names()), - use_native_downloader=use_native_downloader, + storage_config=storage_config, ) assert table.to_arrow() == expected.to_arrow(), f"Expected:\n{expected}\n\nReceived:\n{table}" @@ -236,12 +248,13 @@ def test_parquet_read_int96_timestamps_overflow(coerce_to, use_native_downloader ) as f: schema = Schema._from_field_name_and_types(schema) expected = Table.from_pydict(data) + storage_config = storage_config_from_use_native_downloader(use_native_downloader) table = table_io.read_parquet( f, schema, read_options=TableReadOptions(column_names=schema.column_names()), parquet_options=TableParseParquetOptions(coerce_int96_timestamp_unit=coerce_to), - use_native_downloader=use_native_downloader, + storage_config=storage_config, ) assert table.to_arrow() == expected.to_arrow(), f"Expected:\n{expected}\n\nReceived:\n{table}"