diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index 309b6ec26f..0b2dcf8af7 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -130,7 +130,6 @@ jobs: publish: name: Publish wheels to PYPI and Anaconda - if: ${{ (github.ref == 'refs/heads/main') }} runs-on: ubuntu-latest needs: - build-and-test @@ -166,7 +165,7 @@ jobs: run: conda install -q -y anaconda-client "urllib3<2.0" - name: Upload wheels to anaconda nightly - if: ${{ success() && (env.IS_SCHEDULE_DISPATCH == 'true' || env.IS_PUSH == 'true') }} + if: ${{ success() && (((env.IS_SCHEDULE_DISPATCH == 'true') && (github.ref == 'refs/heads/main')) || env.IS_PUSH == 'true') }} shell: bash -el {0} env: DAFT_STAGING_UPLOAD_TOKEN: ${{ secrets.DAFT_STAGING_UPLOAD_TOKEN }} diff --git a/Cargo.lock b/Cargo.lock index f56c399227..ed2640e265 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1088,12 +1088,14 @@ dependencies = [ "openssl-sys", "pyo3", "pyo3-log", + "regex", "reqwest", "serde", "serde_json", "snafu", "tempfile", "tokio", + "tokio-stream", "url", ] @@ -2066,9 +2068,9 @@ checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771" [[package]] name = "memchr" -version = "2.5.0" +version = "2.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" +checksum = "8f232d6ef707e1956a43342693d2a31e72989554d58299d7a88738cc95b0d35c" [[package]] name = "memoffset" @@ -2834,9 +2836,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.9.1" +version = "1.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2eae68fc220f7cf2532e4494aded17545fce192d59cd996e0fe7887f4ceb575" +checksum = "697061221ea1b4a94a624f67d0ae2bfe4e22b8a17b6a192afb11046542cc8c47" dependencies = [ "aho-corasick", "memchr", @@ -2846,9 +2848,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.3.2" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83d3daa6976cffb758ec878f108ba0e062a45b2d6ca3a2cca965338855476caf" +checksum = "c2f401f4955220693b56f8ec66ee9c78abffd8d1c4f23dc41a23839eb88f0795" dependencies = [ "aho-corasick", "memchr", @@ -2857,9 +2859,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.7.3" +version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2ab07dc67230e4a4718e70fd5c20055a4334b121f1f9db8fe63ef39ce9b8c846" +checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da" [[package]] name = "reqwest" diff --git a/Cargo.toml b/Cargo.toml index 3a9f5a9ba3..19ea900845 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -80,6 +80,7 @@ rand = "^0.8" serde_json = "1.0.104" snafu = "0.7.4" tokio = {version = "1.32.0", features = ["net", "time", "bytes", "process", "signal", "macros", "rt", "rt-multi-thread"]} +tokio-stream = {version = "0.1.14", features = ["fs"]} [workspace.dependencies.arrow2] git = "https://github.com/Eventual-Inc/arrow2" diff --git a/daft/daft.pyi b/daft/daft.pyi index c4d9eee604..025b0b3db5 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -555,6 +555,7 @@ class PyExpr: def utf8_endswith(self, pattern: PyExpr) -> PyExpr: ... def utf8_startswith(self, pattern: PyExpr) -> PyExpr: ... def utf8_contains(self, pattern: PyExpr) -> PyExpr: ... + def utf8_split(self, pattern: PyExpr) -> PyExpr: ... def utf8_length(self) -> PyExpr: ... def image_decode(self) -> PyExpr: ... 
def image_encode(self, image_format: ImageFormat) -> PyExpr: ... @@ -617,6 +618,7 @@ class PySeries: def utf8_endswith(self, pattern: PySeries) -> PySeries: ... def utf8_startswith(self, pattern: PySeries) -> PySeries: ... def utf8_contains(self, pattern: PySeries) -> PySeries: ... + def utf8_split(self, pattern: PySeries) -> PySeries: ... def utf8_length(self) -> PySeries: ... def is_nan(self) -> PySeries: ... def dt_date(self) -> PySeries: ... @@ -673,7 +675,9 @@ class PhysicalPlanScheduler: A work scheduler for physical query plans. """ - def to_partition_tasks(self, psets: dict[str, list[PartitionT]]) -> physical_plan.MaterializedPhysicalPlan: ... + def to_partition_tasks( + self, psets: dict[str, list[PartitionT]], is_ray_runner: bool + ) -> physical_plan.MaterializedPhysicalPlan: ... class LogicalPlanBuilder: """ diff --git a/daft/execution/rust_physical_plan_shim.py b/daft/execution/rust_physical_plan_shim.py index 2ffa9057a3..76afe65f8d 100644 --- a/daft/execution/rust_physical_plan_shim.py +++ b/daft/execution/rust_physical_plan_shim.py @@ -2,7 +2,6 @@ from typing import Iterator, TypeVar, cast -from daft.context import get_context from daft.daft import ( FileFormat, FileFormatConfig, @@ -29,10 +28,11 @@ def tabular_scan( file_format_config: FileFormatConfig, storage_config: StorageConfig, limit: int, + is_ray_runner: bool, ) -> physical_plan.InProgressPhysicalPlan[PartitionT]: # TODO(Clark): Fix this Ray runner hack. part = Table._from_pytable(file_info_table) - if get_context().is_ray_runner: + if is_ray_runner: import ray parts = [ray.put(part)] diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 2eaf5eb290..d73f8b4709 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -572,7 +572,7 @@ def endswith(self, suffix: str | Expression) -> Expression: suffix_expr = Expression._to_expression(suffix) return Expression._from_pyexpr(self._expr.utf8_endswith(suffix_expr._expr)) - def startswith(self, prefix: str) -> Expression: + def startswith(self, prefix: str | Expression) -> Expression: """Checks whether each string starts with the given pattern in a string column Example: @@ -587,6 +587,22 @@ def startswith(self, prefix: str) -> Expression: prefix_expr = Expression._to_expression(prefix) return Expression._from_pyexpr(self._expr.utf8_startswith(prefix_expr._expr)) + def split(self, pattern: str | Expression) -> Expression: + """Splits each string on the given pattern into a list of one or more strings. + + Example: + >>> col("x").str.split(",") + >>> col("x").str.split(col("pattern")) + + Args: + pattern: The pattern on which each string should be split, or a column to pick such patterns from. + + Returns: + Expression: A List[Utf8] expression containing the string splits for each string in the column.
+ """ + pattern_expr = Expression._to_expression(pattern) + return Expression._from_pyexpr(self._expr.utf8_split(pattern_expr._expr)) + def concat(self, other: str) -> Expression: """Concatenates two string expressions together diff --git a/daft/planner/planner.py b/daft/planner/planner.py index 5ee5a66346..1120f88e83 100644 --- a/daft/planner/planner.py +++ b/daft/planner/planner.py @@ -12,5 +12,7 @@ class PhysicalPlanScheduler(ABC): """ @abstractmethod - def to_partition_tasks(self, psets: dict[str, list[PartitionT]]) -> physical_plan.MaterializedPhysicalPlan: + def to_partition_tasks( + self, psets: dict[str, list[PartitionT]], is_ray_runner: bool + ) -> physical_plan.MaterializedPhysicalPlan: pass diff --git a/daft/planner/py_planner.py b/daft/planner/py_planner.py index bf5321a7bb..a9ab2b90f8 100644 --- a/daft/planner/py_planner.py +++ b/daft/planner/py_planner.py @@ -9,5 +9,7 @@ class PyPhysicalPlanScheduler(PhysicalPlanScheduler): def __init__(self, plan: logical_plan.LogicalPlan): self._plan = plan - def to_partition_tasks(self, psets: dict[str, list[PartitionT]]) -> physical_plan.MaterializedPhysicalPlan: + def to_partition_tasks( + self, psets: dict[str, list[PartitionT]], is_ray_runner: bool + ) -> physical_plan.MaterializedPhysicalPlan: return physical_plan.materialize(physical_plan_factory._get_physical_plan(self._plan, psets)) diff --git a/daft/planner/rust_planner.py b/daft/planner/rust_planner.py index cd1c00fd83..9ea74edcd5 100644 --- a/daft/planner/rust_planner.py +++ b/daft/planner/rust_planner.py @@ -9,5 +9,7 @@ class RustPhysicalPlanScheduler(PhysicalPlanScheduler): def __init__(self, scheduler: _PhysicalPlanScheduler): self._scheduler = scheduler - def to_partition_tasks(self, psets: dict[str, list[PartitionT]]) -> physical_plan.MaterializedPhysicalPlan: - return physical_plan.materialize(self._scheduler.to_partition_tasks(psets)) + def to_partition_tasks( + self, psets: dict[str, list[PartitionT]], is_ray_runner: bool + ) -> physical_plan.MaterializedPhysicalPlan: + return physical_plan.materialize(self._scheduler.to_partition_tasks(psets, is_ray_runner)) diff --git a/daft/runners/pyrunner.py b/daft/runners/pyrunner.py index 1ad9dcd7e5..f291727b13 100644 --- a/daft/runners/pyrunner.py +++ b/daft/runners/pyrunner.py @@ -148,7 +148,7 @@ def run_iter(self, builder: LogicalPlanBuilder) -> Iterator[Table]: if entry.value is not None } # Get executable tasks from planner. - tasks = plan_scheduler.to_partition_tasks(psets) + tasks = plan_scheduler.to_partition_tasks(psets, is_ray_runner=False) with profiler("profile_PyRunner.run_{datetime.now().isoformat()}.json"): partitions_gen = self._physical_plan_to_partitions(tasks) diff --git a/daft/runners/ray_runner.py b/daft/runners/ray_runner.py index d596974ff9..a7dd27f355 100644 --- a/daft/runners/ray_runner.py +++ b/daft/runners/ray_runner.py @@ -432,7 +432,7 @@ def _run_plan( from loguru import logger # Get executable tasks from plan scheduler. - tasks = plan_scheduler.to_partition_tasks(psets) + tasks = plan_scheduler.to_partition_tasks(psets, is_ray_runner=True) # Note: For autoscaling clusters, we will probably want to query cores dynamically. # Keep in mind this call takes about 0.3ms. 
diff --git a/daft/series.py b/daft/series.py index d81196c391..2430dbf31d 100644 --- a/daft/series.py +++ b/daft/series.py @@ -534,6 +534,12 @@ def contains(self, pattern: Series) -> Series: assert self._series is not None and pattern._series is not None return Series._from_pyseries(self._series.utf8_contains(pattern._series)) + def split(self, pattern: Series) -> Series: + if not isinstance(pattern, Series): + raise ValueError(f"expected another Series but got {type(pattern)}") + assert self._series is not None and pattern._series is not None + return Series._from_pyseries(self._series.utf8_split(pattern._series)) + def concat(self, other: Series) -> Series: if not isinstance(other, Series): raise ValueError(f"expected another Series but got {type(other)}") diff --git a/src/common/error/src/error.rs b/src/common/error/src/error.rs index 604b139e48..c61a491685 100644 --- a/src/common/error/src/error.rs +++ b/src/common/error/src/error.rs @@ -20,6 +20,7 @@ pub enum DaftError { path: String, source: GenericError, }, + InternalError(String), External(GenericError), } @@ -31,7 +32,8 @@ impl std::error::Error for DaftError { | DaftError::TypeError(_) | DaftError::ComputeError(_) | DaftError::ArrowError(_) - | DaftError::ValueError(_) => None, + | DaftError::ValueError(_) + | DaftError::InternalError(_) => None, DaftError::IoError(io_error) => Some(io_error), DaftError::FileNotFound { source, .. } | DaftError::External(source) => Some(&**source), #[cfg(feature = "python")] @@ -96,6 +98,7 @@ impl Display for DaftError { Self::ComputeError(s) => write!(f, "DaftError::ComputeError {s}"), Self::ArrowError(s) => write!(f, "DaftError::ArrowError {s}"), Self::ValueError(s) => write!(f, "DaftError::ValueError {s}"), + Self::InternalError(s) => write!(f, "DaftError::InternalError {s}"), #[cfg(feature = "python")] Self::PyO3Error(e) => write!(f, "DaftError::PyO3Error {e}"), Self::IoError(e) => write!(f, "DaftError::IoError {e}"), diff --git a/src/daft-core/src/array/ops/full.rs b/src/daft-core/src/array/ops/full.rs index f39c904709..f2f86c5298 100644 --- a/src/daft-core/src/array/ops/full.rs +++ b/src/daft-core/src/array/ops/full.rs @@ -129,7 +129,8 @@ impl FullNull for ListArray { Self::new( Field::new(name, dtype.clone()), empty_flat_child, - OffsetsBuffer::try_from(repeat(0).take(length).collect::<Vec<_>>()).unwrap(), + OffsetsBuffer::try_from(repeat(0).take(length + 1).collect::<Vec<_>>()) + .unwrap(), Some(validity), ) } diff --git a/src/daft-core/src/array/ops/utf8.rs b/src/daft-core/src/array/ops/utf8.rs index 262559776f..7f9987078b 100644 --- a/src/daft-core/src/array/ops/utf8.rs +++ b/src/daft-core/src/array/ops/utf8.rs @@ -1,10 +1,66 @@ -use crate::datatypes::{BooleanArray, UInt64Array, Utf8Array}; +use crate::{ + array::ListArray, + datatypes::{BooleanArray, Field, UInt64Array, Utf8Array}, + DataType, Series, +}; use arrow2; use common_error::{DaftError, DaftResult}; use super::{as_arrow::AsArrow, full::FullNull}; +fn split_array_on_patterns<'a, T, U>( + arr_iter: T, + pattern_iter: U, + buffer_len: usize, + name: &str, +) -> DaftResult<ListArray> +where + T: arrow2::trusted_len::TrustedLen + Iterator<Item = Option<&'a str>>, + U: Iterator<Item = Option<&'a str>>, +{ + // This will overallocate by pattern_len * N_i, where N_i is the number of pattern occurrences in the ith string in arr_iter. + let mut splits = arrow2::array::MutableUtf8Array::with_capacity(buffer_len); + // arr_iter implements TrustedLen, so we can always use size_hint().1 as the exact length of the iterator.
The only + // time this would fail is if the length of the iterator exceeds usize::MAX, which should never happen for an i64 + // offset array, since the array length can't exceed i64::MAX on 64-bit machines. + let arr_len = arr_iter.size_hint().1.unwrap(); + let mut offsets = arrow2::offset::Offsets::new(); + let mut validity = arrow2::bitmap::MutableBitmap::with_capacity(arr_len); + for (val, pat) in arr_iter.zip(pattern_iter) { + let mut num_splits = 0i64; + match (val, pat) { + (Some(val), Some(pat)) => { + for split in val.split(pat) { + splits.push(Some(split)); + num_splits += 1; + } + validity.push(true); + } + (_, _) => { + validity.push(false); + } + } + offsets.try_push(num_splits)?; + } + // Shrink splits capacity to current length, since we will have overallocated if any of the patterns actually occurred in the strings. + splits.shrink_to_fit(); + let splits: arrow2::array::Utf8Array<i64> = splits.into(); + let offsets: arrow2::offset::OffsetsBuffer<i64> = offsets.into(); + let validity: Option<arrow2::bitmap::Bitmap> = match validity.unset_bits() { + 0 => None, + _ => Some(validity.into()), + }; + let flat_child = + Series::try_from(("splits", Box::new(splits) as Box<dyn arrow2::array::Array>))?; + Ok(ListArray::new( + Field::new(name, DataType::List(Box::new(DataType::Utf8))), + flat_child, + offsets, + validity, + )) +} + impl Utf8Array { pub fn endswith(&self, pattern: &Utf8Array) -> DaftResult<BooleanArray> { self.binary_broadcasted_compare(pattern, |data: &str, pat: &str| data.ends_with(pat)) @@ -18,6 +74,65 @@ impl Utf8Array { self.binary_broadcasted_compare(pattern, |data: &str, pat: &str| data.contains(pat)) } + pub fn split(&self, pattern: &Utf8Array) -> DaftResult<ListArray> { + let self_arrow = self.as_arrow(); + let pattern_arrow = pattern.as_arrow(); + // Handle all-null cases. + if self_arrow + .validity() + .map_or(false, |v| v.unset_bits() == v.len()) + || pattern_arrow + .validity() + .map_or(false, |v| v.unset_bits() == v.len()) + { + return Ok(ListArray::full_null( + self.name(), + &DataType::List(Box::new(DataType::Utf8)), + std::cmp::max(self.len(), pattern.len()), + )); + // Handle empty cases.
+ } else if self.is_empty() || pattern.is_empty() { + return Ok(ListArray::empty( + self.name(), + &DataType::List(Box::new(DataType::Utf8)), + )); + } + let buffer_len = self_arrow.values().len(); + match (self.len(), pattern.len()) { + // Matching len case: + (self_len, pattern_len) if self_len == pattern_len => split_array_on_patterns( + self_arrow.into_iter(), + pattern_arrow.into_iter(), + buffer_len, + self.name(), + ), + // Broadcast pattern case: + (self_len, 1) => { + let pattern_scalar_value = pattern.get(0).unwrap(); + split_array_on_patterns( + self_arrow.into_iter(), + std::iter::repeat(Some(pattern_scalar_value)).take(self_len), + buffer_len, + self.name(), + ) + } + // Broadcast self case: + (1, pattern_len) => { + let self_scalar_value = self.get(0).unwrap(); + split_array_on_patterns( + std::iter::repeat(Some(self_scalar_value)).take(pattern_len), + pattern_arrow.into_iter(), + buffer_len * pattern_len, + self.name(), + ) + } + // Mismatched len case: + (self_len, pattern_len) => Err(DaftError::ComputeError(format!( + "lhs and rhs have different length arrays: {self_len} vs {pattern_len}" + ))), + } + } + pub fn length(&self) -> DaftResult<UInt64Array> { let self_arrow = self.as_arrow(); let arrow_result = self_arrow diff --git a/src/daft-core/src/python/series.rs b/src/daft-core/src/python/series.rs index a388585940..f3c8cb3fa0 100644 --- a/src/daft-core/src/python/series.rs +++ b/src/daft-core/src/python/series.rs @@ -247,6 +247,10 @@ impl PySeries { Ok(self.series.utf8_contains(&pattern.series)?.into()) } + pub fn utf8_split(&self, pattern: &Self) -> PyResult<Self> { + Ok(self.series.utf8_split(&pattern.series)?.into()) + } + pub fn utf8_length(&self) -> PyResult<Self> { Ok(self.series.utf8_length()?.into()) } diff --git a/src/daft-core/src/series/ops/utf8.rs b/src/daft-core/src/series/ops/utf8.rs index bea305d9fc..fb2539b64e 100644 --- a/src/daft-core/src/series/ops/utf8.rs +++ b/src/daft-core/src/series/ops/utf8.rs @@ -32,6 +32,15 @@ impl Series { } } + pub fn utf8_split(&self, pattern: &Series) -> DaftResult<Series> { + match self.data_type() { + DataType::Utf8 => Ok(self.utf8()?.split(pattern.utf8()?)?.into_series()), + dt => Err(DaftError::TypeError(format!( + "Split not implemented for type {dt}" + ))), + } + } + pub fn utf8_length(&self) -> DaftResult<Series> { match self.data_type() { DataType::Utf8 => Ok(self.utf8()?.length()?.into_series()), diff --git a/src/daft-dsl/src/functions/utf8/mod.rs b/src/daft-dsl/src/functions/utf8/mod.rs index cd23c0883b..5c8901147e 100644 --- a/src/daft-dsl/src/functions/utf8/mod.rs +++ b/src/daft-dsl/src/functions/utf8/mod.rs @@ -1,12 +1,14 @@ mod contains; mod endswith; mod length; +mod split; mod startswith; use contains::ContainsEvaluator; use endswith::EndswithEvaluator; use length::LengthEvaluator; use serde::{Deserialize, Serialize}; +use split::SplitEvaluator; use startswith::StartswithEvaluator; use crate::Expr; @@ -18,6 +20,7 @@ pub enum Utf8Expr { EndsWith, StartsWith, Contains, + Split, Length, } @@ -29,6 +32,7 @@ impl Utf8Expr { EndsWith => &EndswithEvaluator {}, StartsWith => &StartswithEvaluator {}, Contains => &ContainsEvaluator {}, + Split => &SplitEvaluator {}, Length => &LengthEvaluator {}, } } @@ -55,6 +59,13 @@ pub fn contains(data: &Expr, pattern: &Expr) -> Expr { } } +pub fn split(data: &Expr, pattern: &Expr) -> Expr { + Expr::Function { + func: super::FunctionExpr::Utf8(Utf8Expr::Split), + inputs: vec![data.clone(), pattern.clone()], + } +} + pub fn length(data: &Expr) -> Expr { Expr::Function { func: super::FunctionExpr::Utf8(Utf8Expr::Length),
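Taken together, the kernel (Utf8Array::split), the series op (utf8_split), and the Utf8Expr::Split wiring above surface in Python as Expression.str.split and Series.str.split. A short usage sketch; the expected output is taken from the series and table tests later in this diff:

import daft
from daft import col

df = daft.from_pydict({"x": ["a,b,c", "d,e", "f"]})
# Split each string on a literal pattern; the resulting column is List[Utf8].
df = df.with_column("parts", col("x").str.split(","))
# Per the tests below: [["a", "b", "c"], ["d", "e"], ["f"]]
print(df.to_pydict()["parts"])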
diff --git a/src/daft-dsl/src/functions/utf8/split.rs b/src/daft-dsl/src/functions/utf8/split.rs new file mode 100644 index 0000000000..8d2c238b70 --- /dev/null +++ b/src/daft-dsl/src/functions/utf8/split.rs @@ -0,0 +1,50 @@ +use crate::Expr; +use daft_core::{ + datatypes::{DataType, Field}, + schema::Schema, + series::Series, +}; + +use common_error::{DaftError, DaftResult}; + +use super::super::FunctionEvaluator; + +pub(super) struct SplitEvaluator {} + +impl FunctionEvaluator for SplitEvaluator { + fn fn_name(&self) -> &'static str { + "split" + } + + fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult<Field> { + match inputs { + [data, pattern] => match (data.to_field(schema), pattern.to_field(schema)) { + (Ok(data_field), Ok(pattern_field)) => { + match (&data_field.dtype, &pattern_field.dtype) { + (DataType::Utf8, DataType::Utf8) => { + Ok(Field::new(data_field.name, DataType::List(Box::new(DataType::Utf8)))) + } + _ => Err(DaftError::TypeError(format!( + "Expects inputs to split to be utf8, but received {data_field} and {pattern_field}", + ))), + } + } + (Err(e), _) | (_, Err(e)) => Err(e), + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 2 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series], _: &Expr) -> DaftResult<Series> { + match inputs { + [data, pattern] => data.utf8_split(pattern), + _ => Err(DaftError::ValueError(format!( + "Expected 2 input args, got {}", + inputs.len() + ))), + } + } +} diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index c64afc4271..cb61044339 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -291,6 +291,11 @@ impl PyExpr { Ok(contains(&self.expr, &pattern.expr).into()) } + pub fn utf8_split(&self, pattern: &Self) -> PyResult<Self> { + use crate::functions::utf8::split; + Ok(split(&self.expr, &pattern.expr).into()) + } + pub fn utf8_length(&self) -> PyResult<Self> { use crate::functions::utf8::length; Ok(length(&self.expr).into()) diff --git a/src/daft-io/Cargo.toml b/src/daft-io/Cargo.toml index e18cd06c00..4662f3e33d 100644 --- a/src/daft-io/Cargo.toml +++ b/src/daft-io/Cargo.toml @@ -24,10 +24,12 @@ log = {workspace = true} openssl-sys = {version = "0.9.93", features = ["vendored"]} pyo3 = {workspace = true, optional = true} pyo3-log = {workspace = true, optional = true} +regex = {version = "1.9.5"} serde = {workspace = true} serde_json = {workspace = true} snafu = {workspace = true} tokio = {workspace = true} +tokio-stream = {workspace = true} url = "2.4.0" [dependencies.reqwest] diff --git a/src/daft-io/src/http.rs b/src/daft-io/src/http.rs index a98728d91b..5c2f3b252f 100644 --- a/src/daft-io/src/http.rs +++ b/src/daft-io/src/http.rs @@ -3,13 +3,22 @@ use std::{num::ParseIntError, ops::Range, string::FromUtf8Error, sync::Arc}; use async_trait::async_trait; use futures::{StreamExt, TryStreamExt}; +use lazy_static::lazy_static; +use regex::Regex; use reqwest::header::{CONTENT_LENGTH, RANGE}; use snafu::{IntoError, ResultExt, Snafu}; +use url::Position; -use crate::object_io::LSResult; +use crate::object_io::{FileMetadata, FileType, LSResult}; use super::object_io::{GetResult, ObjectSource}; +lazy_static!
{ + // Taken from: https://stackoverflow.com/a/15926317/3821154 + static ref HTML_A_TAG_HREF_RE: Regex = + Regex::new(r#"<(a|A)\s+(?:[^>]*?\s+)?(href|HREF)=["'](?P<url>[^"']+)"#).unwrap(); +} + #[derive(Debug, Snafu)] enum Error { #[snafu(display("Unable to connect to {}: {}", path, source))] @@ -45,7 +54,15 @@ enum Error { #[snafu(display( "Unable to parse data as Utf8 while reading header for file: {path}. {source}" ))] - UnableToParseUtf8 { path: String, source: FromUtf8Error }, + UnableToParseUtf8Header { path: String, source: FromUtf8Error }, + + #[snafu(display( + "Unable to parse data as Utf8 while reading body for file: {path}. {source}" + ))] + UnableToParseUtf8Body { + path: String, + source: reqwest::Error, + }, #[snafu(display( "Unable to parse data as Integer while reading header for file: {path}. {source}" ))] UnableToParseInteger { path: String, source: ParseIntError }, } +/// Finds and retrieves FileMetadata from HTML text +/// +/// This function will look for `<a href=...>` tags and return all the links that it finds as +/// absolute URLs +fn _get_file_metadata_from_html(path: &str, text: &str) -> super::Result<Vec<FileMetadata>> { + let path_url = url::Url::parse(path).with_context(|_| InvalidUrlSnafu { path })?; + let metas = HTML_A_TAG_HREF_RE + .captures_iter(text) + .map(|captures| { + // Parse the matched URL into an absolute URL + let matched_url = captures.name("url").unwrap().as_str(); + let absolute_path = if let Ok(parsed_matched_url) = url::Url::parse(matched_url) { + // matched_url is already an absolute path + parsed_matched_url + } else if matched_url.starts_with('/') { + // matched_url is a path relative to the origin of `path` + let base = url::Url::parse(&path_url[..Position::BeforePath]).unwrap(); + base.join(matched_url) + .with_context(|_| InvalidUrlSnafu { path: matched_url })? + } else { + // matched_url is a path relative to `path` and needs to be joined + path_url + .join(matched_url) + .with_context(|_| InvalidUrlSnafu { path: matched_url })?
+ }; + + // Ignore any links that are not descendants of `path` to avoid cycles + let relative = path_url.make_relative(&absolute_path); + match relative { + None => { + return Ok(None); + } + Some(relative_path) + if relative_path.is_empty() || relative_path.starts_with("..") => + { + return Ok(None); + } + _ => (), + }; + + let filetype = if matched_url.ends_with('/') { + FileType::Directory + } else { + FileType::File + }; + Ok(Some(FileMetadata { + filepath: absolute_path.to_string(), + // NOTE: This is consistent with fsspec behavior, but we may choose to HEAD the files to grab Content-Length + // for populating `size` if necessary + size: None, + filetype, + })) + }) + .collect::<super::Result<Vec<_>>>()?; + + Ok(metas.into_iter().flatten().collect()) +} + pub(crate) struct HttpSource { client: reqwest::Client, } @@ -135,8 +210,9 @@ impl ObjectSource for HttpSource { let headers = response.headers(); match headers.get(CONTENT_LENGTH) { Some(v) => { - let size_bytes = String::from_utf8(v.as_bytes().to_vec()) - .with_context(|_| UnableToParseUtf8Snafu::<String> { path: uri.into() })?; + let size_bytes = String::from_utf8(v.as_bytes().to_vec()).with_context(|_| { + UnableToParseUtf8HeaderSnafu::<String> { path: uri.into() } + })?; Ok(size_bytes .parse() @@ -148,11 +224,52 @@ async fn ls( &self, - _path: &str, + path: &str, _delimiter: Option<&str>, _continuation_token: Option<&str>, ) -> super::Result<LSResult> { - unimplemented!("http ls"); + let request = self.client.get(path); + let response = request + .send() + .await + .context(UnableToConnectSnafu::<String> { path: path.into() })? + .error_for_status() + .with_context(|_| UnableToOpenFileSnafu { path })?; + + // Reconstruct the actual path of the request, which may have been redirected via a 301 + // This is important because downstream URL joining logic relies on proper trailing-slashes/index.html + let path = response.url().to_string(); + let path = if path.ends_with('/') { + format!("{}/", path.trim_end_matches('/')) + } else { + path + }; + + match response.headers().get("content-type") { + // If the content-type is text/html, we treat the data on this path as a traversable "directory" + Some(header_value) if header_value.to_str().map_or(false, |v| v == "text/html") => { + let text = response + .text() + .await + .with_context(|_| UnableToParseUtf8BodySnafu { + path: path.to_string(), + })?; + let file_metadatas = _get_file_metadata_from_html(path.as_str(), text.as_str())?; + Ok(LSResult { + files: file_metadatas, + continuation_token: None, + }) + } + // All other forms of content-type is treated as a raw file + _ => Ok(LSResult { + files: vec![FileMetadata { + filepath: path.to_string(), + filetype: FileType::File, + size: response.content_length(), + }], + continuation_token: None, + }), + } } } diff --git a/src/daft-io/src/local.rs b/src/daft-io/src/local.rs index ccbbedc622..6cf62e9634 100644 --- a/src/daft-io/src/local.rs +++ b/src/daft-io/src/local.rs @@ -2,12 +2,16 @@ use std::io::SeekFrom; use std::ops::Range; use std::path::PathBuf; -use crate::object_io::LSResult; +use crate::object_io::{self, FileMetadata, LSResult}; use super::object_io::{GetResult, ObjectSource}; use super::Result; use async_trait::async_trait; use bytes::Bytes; +use common_error::DaftError; +use futures::stream::BoxStream; +use futures::StreamExt; +use futures::TryStreamExt; use snafu::{ResultExt, Snafu}; use std::sync::Arc; use tokio::io::{AsyncReadExt, AsyncSeekExt}; @@ -33,6 +37,21 @@ enum Error { source: std::io::Error, }, + #[snafu(display("Unable to fetch
file metadata for file {}: {}", path, source))] + UnableToFetchFileMetadata { + path: String, + source: std::io::Error, + }, + + #[snafu(display("Unable to get entries for directory {}: {}", path, source))] + UnableToFetchDirectoryEntries { + path: String, + source: std::io::Error, + }, + + #[snafu(display("Unexpected symlink when processing directory {}: {}", path, source))] + UnexpectedSymlink { path: String, source: DaftError }, + #[snafu(display("Unable to parse URL \"{}\"", url.to_string_lossy()))] InvalidUrl { url: PathBuf, source: ParseError }, @@ -44,7 +63,9 @@ impl From<Error> for super::Error { fn from(error: Error) -> Self { use Error::*; match error { - UnableToOpenFile { path, source } => { + UnableToOpenFile { path, source } + | UnableToFetchFileMetadata { path, source } + | UnableToFetchDirectoryEntries { path, source } => { use std::io::ErrorKind::*; match source.kind() { NotFound => super::Error::NotFound { @@ -84,49 +105,104 @@ pub struct LocalFile { #[async_trait] impl ObjectSource for LocalSource { async fn get(&self, uri: &str, range: Option<Range<usize>>) -> super::Result<GetResult> { - const TO_STRIP: &str = "file://"; - if let Some(p) = uri.strip_prefix(TO_STRIP) { - let path = std::path::Path::new(p); + const LOCAL_PROTOCOL: &str = "file://"; + if let Some(uri) = uri.strip_prefix(LOCAL_PROTOCOL) { Ok(GetResult::File(LocalFile { - path: path.to_path_buf(), + path: uri.into(), range, })) } else { - return Err(Error::InvalidFilePath { - path: uri.to_string(), - } - .into()); + Err(Error::InvalidFilePath { path: uri.into() }.into()) } } async fn get_size(&self, uri: &str) -> super::Result<usize> { - const TO_STRIP: &str = "file://"; - if let Some(p) = uri.strip_prefix(TO_STRIP) { - let path = std::path::Path::new(p); - let file = tokio::fs::File::open(path) - .await - .context(UnableToOpenFileSnafu { - path: path.to_string_lossy(), - })?; - let metadata = file.metadata().await.context(UnableToOpenFileSnafu { - path: path.to_string_lossy(), - })?; - return Ok(metadata.len() as usize); - } else { - return Err(Error::InvalidFilePath { + const LOCAL_PROTOCOL: &str = "file://"; + let Some(uri) = uri.strip_prefix(LOCAL_PROTOCOL) else { + return Err(Error::InvalidFilePath { path: uri.into() }.into()); + }; + let meta = tokio::fs::metadata(uri) + .await + .context(UnableToFetchFileMetadataSnafu { path: uri.to_string(), - } - .into()); - } + })?; + Ok(meta.len() as usize) } async fn ls( &self, - _path: &str, + path: &str, _delimiter: Option<&str>, _continuation_token: Option<&str>, ) -> super::Result<LSResult> { - unimplemented!("local ls"); + let s = self.iter_dir(path, None, None).await?; + let files = s.try_collect::<Vec<_>>().await?; + Ok(LSResult { + files, + continuation_token: None, + }) + } + + async fn iter_dir( + &self, + uri: &str, + _delimiter: Option<&str>, + _limit: Option<usize>, + ) -> super::Result<BoxStream<super::Result<FileMetadata>>> { + const LOCAL_PROTOCOL: &str = "file://"; + let Some(uri) = uri.strip_prefix(LOCAL_PROTOCOL) else { + return Err(Error::InvalidFilePath { path: uri.into() }.into()); + }; + let meta = + tokio::fs::metadata(uri) + .await + .with_context(|_| UnableToFetchFileMetadataSnafu { + path: uri.to_string(), + })?; + if meta.file_type().is_file() { + // Provided uri points to a file, so only return that file.
+ return Ok(futures::stream::iter([Ok(FileMetadata { + filepath: format!("{}{}", LOCAL_PROTOCOL, uri), + size: Some(meta.len()), + filetype: object_io::FileType::File, + })]) + .boxed()); + } + let dir_entries = tokio::fs::read_dir(uri).await.with_context(|_| { + UnableToFetchDirectoryEntriesSnafu { + path: uri.to_string(), + } + })?; + let dir_stream = tokio_stream::wrappers::ReadDirStream::new(dir_entries); + let uri = Arc::new(uri.to_string()); + let file_meta_stream = dir_stream.then(move |entry| { + let uri = uri.clone(); + async move { + let entry = entry.with_context(|_| UnableToFetchDirectoryEntriesSnafu { + path: uri.to_string(), + })?; + let meta = tokio::fs::metadata(entry.path()).await.with_context(|_| { + UnableToFetchFileMetadataSnafu { + path: entry.path().to_string_lossy().to_string(), + } + })?; + Ok(FileMetadata { + filepath: format!( + "{}{}{}", + LOCAL_PROTOCOL, + entry.path().to_string_lossy(), + if meta.is_dir() { "/" } else { "" } + ), + size: Some(meta.len()), + filetype: meta.file_type().try_into().with_context(|_| { + UnexpectedSymlinkSnafu { + path: entry.path().to_string_lossy().to_string(), + } + })?, + }) + } + }); + Ok(file_meta_stream.boxed()) } } @@ -171,16 +247,15 @@ pub(crate) async fn collect_file(local_file: LocalFile) -> Result<Bytes> { #[cfg(test)] mod tests { - use std::io::Write; - use crate::object_io::ObjectSource; + use crate::object_io::{FileMetadata, FileType, ObjectSource}; use crate::Result; use crate::{HttpSource, LocalSource}; - #[tokio::test] - async fn test_full_get_from_local() -> Result<()> { - let mut file1 = tempfile::NamedTempFile::new().unwrap(); + async fn write_remote_parquet_to_local_file( + f: &mut tempfile::NamedTempFile, + ) -> Result<bytes::Bytes> { let parquet_file_path = "https://daft-public-data.s3.us-west-2.amazonaws.com/test_fixtures/parquet_small/0dad4c3f-da0d-49db-90d8-98684571391b-0.parquet"; let parquet_expected_md5 = "929674747af64a98aceaa6d895863bd3"; @@ -190,15 +265,22 @@ let all_bytes = bytes.as_ref(); let checksum = format!("{:x}", md5::compute(all_bytes)); assert_eq!(checksum, parquet_expected_md5); - file1.write_all(all_bytes).unwrap(); - file1.flush().unwrap(); + f.write_all(all_bytes).unwrap(); + f.flush().unwrap(); + Ok(bytes) + } + + #[tokio::test] + async fn test_local_full_get() -> Result<()> { + let mut file1 = tempfile::NamedTempFile::new().unwrap(); + let bytes = write_remote_parquet_to_local_file(&mut file1).await?; let parquet_file_path = format!("file://{}", file1.path().to_str().unwrap()); let client = LocalSource::get_client().await?; let try_all_bytes = client.get(&parquet_file_path, None).await?.bytes().await?; - assert_eq!(try_all_bytes.len(), all_bytes.len()); - assert_eq!(try_all_bytes.as_ref(), all_bytes); + assert_eq!(try_all_bytes.len(), bytes.len()); + assert_eq!(try_all_bytes, bytes); let first_bytes = client .get_range(&parquet_file_path, 0..10) .await? .bytes() .await?; assert_eq!(first_bytes.len(), 10); - assert_eq!(first_bytes.as_ref(), &all_bytes[..10]); + assert_eq!(first_bytes.as_ref(), &bytes[..10]); let first_bytes = client .get_range(&parquet_file_path, 10..100) .await? .bytes() .await?; assert_eq!(first_bytes.len(), 90); - assert_eq!(first_bytes.as_ref(), &all_bytes[10..100]); + assert_eq!(first_bytes.as_ref(), &bytes[10..100]); let last_bytes = client - .get_range( - &parquet_file_path, - (all_bytes.len() - 10)..(all_bytes.len() + 10), - ) + .get_range(&parquet_file_path, (bytes.len() - 10)..(bytes.len() + 10)) .await?
.bytes() .await?; assert_eq!(last_bytes.len(), 10); - assert_eq!(last_bytes.as_ref(), &all_bytes[(all_bytes.len() - 10)..]); + assert_eq!(last_bytes.as_ref(), &bytes[(bytes.len() - 10)..]); let size_from_get_size = client.get_size(parquet_file_path.as_str()).await?; - assert_eq!(size_from_get_size, all_bytes.len()); + assert_eq!(size_from_get_size, bytes.len()); + + Ok(()) + } + + #[tokio::test] + async fn test_local_full_ls() -> Result<()> { + let dir = tempfile::tempdir().unwrap(); + let mut file1 = tempfile::NamedTempFile::new_in(dir.path()).unwrap(); + write_remote_parquet_to_local_file(&mut file1).await?; + let mut file2 = tempfile::NamedTempFile::new_in(dir.path()).unwrap(); + write_remote_parquet_to_local_file(&mut file2).await?; + let mut file3 = tempfile::NamedTempFile::new_in(dir.path()).unwrap(); + write_remote_parquet_to_local_file(&mut file3).await?; + let dir_path = format!("file://{}", dir.path().to_string_lossy()); + let client = LocalSource::get_client().await?; + + let ls_result = client.ls(dir_path.as_ref(), None, None).await?; + let mut files = ls_result.files.clone(); + // Ensure stable sort ordering of file paths before comparing with expected payload. + files.sort_by(|a, b| a.filepath.cmp(&b.filepath)); + let mut expected = vec![ + FileMetadata { + filepath: format!("file://{}", file1.path().to_string_lossy()), + size: Some(file1.as_file().metadata().unwrap().len()), + filetype: FileType::File, + }, + FileMetadata { + filepath: format!("file://{}", file2.path().to_string_lossy()), + size: Some(file2.as_file().metadata().unwrap().len()), + filetype: FileType::File, + }, + FileMetadata { + filepath: format!("file://{}", file3.path().to_string_lossy()), + size: Some(file3.as_file().metadata().unwrap().len()), + filetype: FileType::File, + }, + ]; + expected.sort_by(|a, b| a.filepath.cmp(&b.filepath)); + assert_eq!(files, expected); + assert_eq!(ls_result.continuation_token, None); Ok(()) } diff --git a/src/daft-io/src/object_io.rs b/src/daft-io/src/object_io.rs index 98bc23de31..9613d387d1 100644 --- a/src/daft-io/src/object_io.rs +++ b/src/daft-io/src/object_io.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use async_trait::async_trait; use bytes::Bytes; +use common_error::DaftError; use futures::stream::{BoxStream, Stream}; use futures::StreamExt; use tokio::sync::mpsc::Sender; @@ -52,12 +53,32 @@ impl GetResult { } } -#[derive(Debug)] +#[derive(Debug, Clone, PartialEq)] pub enum FileType { File, Directory, } -#[derive(Debug)] + +impl TryFrom<std::fs::FileType> for FileType { + type Error = DaftError; + + fn try_from(value: std::fs::FileType) -> Result<Self, Self::Error> { + if value.is_dir() { + Ok(Self::Directory) + } else if value.is_file() { + Ok(Self::File) + } else if value.is_symlink() { + Err(DaftError::InternalError(format!("Symlinks should never be encountered when constructing FileMetadata, but got: {:?}", value))) + } else { + unreachable!( + "Can only be a directory, file, or symlink, but got: {:?}", + value + ) + } + } +} + +#[derive(Debug, Clone, PartialEq)] pub struct FileMetadata { pub filepath: String, pub size: Option<u64>, diff --git a/src/daft-io/src/s3_like.rs b/src/daft-io/src/s3_like.rs index 2b4614e110..8dba518e46 100644 --- a/src/daft-io/src/s3_like.rs +++ b/src/daft-io/src/s3_like.rs @@ -541,10 +541,12 @@ impl S3LikeSource { } } + #[allow(clippy::too_many_arguments)] #[async_recursion] async fn _list_impl( &self, _permit: SemaphorePermit<'async_recursion>, + scheme: &str, bucket: &str, key: &str, delimiter: String, @@ -587,7 +589,7 @@ } else { request.send().await }; -
let uri = &format!("s3://{bucket}/{key}"); + let uri = &format!("{scheme}://{bucket}/{key}"); match response { Ok(v) => { let dirs = v.common_prefixes(); @@ -604,7 +606,10 @@ if let Some(dirs) = dirs { for d in dirs { let fmeta = FileMetadata { - filepath: format!("s3://{bucket}/{}", d.prefix().unwrap_or_default()), + filepath: format!( + "{scheme}://{bucket}/{}", + d.prefix().unwrap_or_default() + ), size: None, filetype: FileType::Directory, }; @@ -614,7 +619,10 @@ if let Some(files) = files { for f in files { let fmeta = FileMetadata { - filepath: format!("s3://{bucket}/{}", f.key().unwrap_or_default()), + filepath: format!( + "{scheme}://{bucket}/{}", + f.key().unwrap_or_default() + ), size: Some(f.size() as u64), filetype: FileType::File, }; @@ -646,6 +654,7 @@ log::debug!("S3 Region of {uri} different than client {:?} vs {:?} Attempting List in that region with new client", new_region, region); self._list_impl( _permit, + scheme, bucket, key, delimiter, @@ -694,6 +703,7 @@ impl ObjectSource for S3LikeSource { continuation_token: Option<&str>, ) -> super::Result<LSResult> { let parsed = url::Url::parse(path).with_context(|_| InvalidUrlSnafu { path })?; + let scheme = parsed.scheme(); let delimiter = delimiter.unwrap_or("/"); let bucket = match parsed.host_str() { @@ -723,6 +733,7 @@ self._list_impl( permit, + scheme, bucket, &key, delimiter.into(), @@ -742,6 +753,7 @@ let mut lsr = self ._list_impl( permit, + scheme, bucket, key, delimiter.into(), @@ -749,7 +761,7 @@ &self.default_region, ) .await?; - let target_path = format!("s3://{bucket}/{key}"); + let target_path = format!("{scheme}://{bucket}/{key}"); lsr.files.retain(|f| f.filepath == target_path); if lsr.files.is_empty() { diff --git a/src/daft-plan/src/physical_plan.rs b/src/daft-plan/src/physical_plan.rs index 9c87facaa9..9766c530ac 100644 --- a/src/daft-plan/src/physical_plan.rs +++ b/src/daft-plan/src/physical_plan.rs @@ -64,8 +64,12 @@ pub struct PhysicalPlanScheduler { #[pymethods] impl PhysicalPlanScheduler { /// Converts the contained physical plan into an iterator of executable partition tasks.
- pub fn to_partition_tasks(&self, psets: HashMap<String, Vec<PyObject>>) -> PyResult<PyObject> { - Python::with_gil(|py| self.plan.to_partition_tasks(py, &psets)) + pub fn to_partition_tasks( + &self, + psets: HashMap<String, Vec<PyObject>>, + is_ray_runner: bool, + ) -> PyResult<PyObject> { + Python::with_gil(|py| self.plan.to_partition_tasks(py, &psets, is_ray_runner)) } } @@ -98,6 +102,7 @@ impl PartitionIterator { } #[cfg(feature = "python")] +#[allow(clippy::too_many_arguments)] fn tabular_scan( py: Python<'_>, source_schema: &SchemaRef, @@ -106,6 +111,7 @@ fn tabular_scan( file_format_config: &Arc<FileFormatConfig>, storage_config: &Arc<StorageConfig>, limit: &Option<usize>, + is_ray_runner: bool, ) -> PyResult<PyObject> { let columns_to_read = projection_schema .fields @@ -123,6 +129,7 @@ PyFileFormatConfig::from(file_format_config.clone()), PyStorageConfig::from(storage_config.clone()), *limit, + is_ray_runner, ))?; Ok(py_iter.into()) } @@ -162,6 +169,7 @@ impl PhysicalPlan { &self, py: Python<'_>, psets: &HashMap<String, Vec<PyObject>>, + is_ray_runner: bool, ) -> PyResult<PyObject> { match self { PhysicalPlan::InMemoryScan(InMemoryScan { @@ -198,6 +206,7 @@ file_format_config, storage_config, limit, + is_ray_runner, ), PhysicalPlan::TabularScanCsv(TabularScanCsv { projection_schema, @@ -219,6 +228,7 @@ file_format_config, storage_config, limit, + is_ray_runner, ), PhysicalPlan::TabularScanJson(TabularScanJson { projection_schema, @@ -240,13 +250,14 @@ file_format_config, storage_config, limit, + is_ray_runner, ), PhysicalPlan::Project(Project { input, projection, resource_request, }) => { - let upstream_iter = input.to_partition_tasks(py, psets)?; + let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; let projection_pyexprs: Vec<PyExpr> = projection .iter() .map(|expr| PyExpr::from(expr.clone())) @@ -258,7 +269,7 @@ Ok(py_iter.into()) } PhysicalPlan::Filter(Filter { input, predicate }) => { - let upstream_iter = input.to_partition_tasks(py, psets)?; + let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; let expressions_mod = py.import(pyo3::intern!(py, "daft.expressions.expressions"))?; let py_predicate = expressions_mod @@ -287,7 +298,7 @@ limit, num_partitions, }) => { - let upstream_iter = input.to_partition_tasks(py, psets)?; + let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; let py_physical_plan = py.import(pyo3::intern!(py, "daft.execution.physical_plan"))?; let local_limit_iter = py_physical_plan @@ -299,7 +310,7 @@ Ok(global_limit_iter.into()) } PhysicalPlan::Explode(Explode { input, to_explode }) => { - let upstream_iter = input.to_partition_tasks(py, psets)?; + let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; let explode_pyexprs: Vec<PyExpr> = to_explode .iter() .map(|expr| PyExpr::from(expr.clone())) @@ -316,7 +327,7 @@ descending, num_partitions, }) => { - let upstream_iter = input.to_partition_tasks(py, psets)?; + let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; let sort_by_pyexprs: Vec<PyExpr> = sort_by .iter() .map(|expr| PyExpr::from(expr.clone())) @@ -337,7 +348,7 @@ input_num_partitions, output_num_partitions, }) => { - let upstream_iter = input.to_partition_tasks(py, psets)?; + let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; let py_iter = py .import(pyo3::intern!(py, "daft.execution.physical_plan"))? .getattr(pyo3::intern!(py, "split"))?
@@ -345,7 +356,7 @@ Ok(py_iter.into()) } PhysicalPlan::Flatten(Flatten { input }) => { - let upstream_iter = input.to_partition_tasks(py, psets)?; + let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; let py_iter = py .import(pyo3::intern!(py, "daft.execution.physical_plan"))? .getattr(pyo3::intern!(py, "flatten_plan"))? @@ -356,7 +367,7 @@ input, num_partitions, }) => { - let upstream_iter = input.to_partition_tasks(py, psets)?; + let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; let py_iter = py .import(pyo3::intern!(py, "daft.execution.physical_plan"))? .getattr(pyo3::intern!(py, "fanout_random"))? @@ -368,7 +379,7 @@ num_partitions, partition_by, }) => { - let upstream_iter = input.to_partition_tasks(py, psets)?; + let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; let partition_by_pyexprs: Vec<PyExpr> = partition_by .iter() .map(|expr| PyExpr::from(expr.clone())) @@ -383,7 +394,7 @@ "FanoutByRange not implemented, since only use case (sorting) doesn't need it yet." ), PhysicalPlan::ReduceMerge(ReduceMerge { input }) => { - let upstream_iter = input.to_partition_tasks(py, psets)?; + let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; let py_iter = py .import(pyo3::intern!(py, "daft.execution.rust_physical_plan_shim"))? .getattr(pyo3::intern!(py, "reduce_merge"))? @@ -396,7 +407,7 @@ input, .. }) => { - let upstream_iter = input.to_partition_tasks(py, psets)?; + let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; let aggs_as_pyexprs: Vec<PyExpr> = aggregations .iter() .map(|agg_expr| PyExpr::from(Expr::Agg(agg_expr.clone()))) @@ -416,7 +427,7 @@ num_from, num_to, }) => { - let upstream_iter = input.to_partition_tasks(py, psets)?; + let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; let py_iter = py .import(pyo3::intern!(py, "daft.execution.physical_plan"))? .getattr(pyo3::intern!(py, "coalesce"))? @@ -424,8 +435,8 @@ Ok(py_iter.into()) } PhysicalPlan::Concat(Concat { other, input }) => { - let upstream_input_iter = input.to_partition_tasks(py, psets)?; - let upstream_other_iter = other.to_partition_tasks(py, psets)?; + let upstream_input_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; + let upstream_other_iter = other.to_partition_tasks(py, psets, is_ray_runner)?; let py_iter = py .import(pyo3::intern!(py, "daft.execution.physical_plan"))? .getattr(pyo3::intern!(py, "concat"))? @@ -440,8 +451,8 @@ join_type, ..
}) => { - let upstream_left_iter = left.to_partition_tasks(py, psets)?; - let upstream_right_iter = right.to_partition_tasks(py, psets)?; + let upstream_left_iter = left.to_partition_tasks(py, psets, is_ray_runner)?; + let upstream_right_iter = right.to_partition_tasks(py, psets, is_ray_runner)?; let left_on_pyexprs: Vec<PyExpr> = left_on .iter() .map(|expr| PyExpr::from(expr.clone())) @@ -474,7 +485,7 @@ input, }) => tabular_write( py, - input.to_partition_tasks(py, psets)?, + input.to_partition_tasks(py, psets, is_ray_runner)?, file_format, schema, root_dir, @@ -493,7 +504,7 @@ input, }) => tabular_write( py, - input.to_partition_tasks(py, psets)?, + input.to_partition_tasks(py, psets, is_ray_runner)?, file_format, schema, root_dir, @@ -512,7 +523,7 @@ input, }) => tabular_write( py, - input.to_partition_tasks(py, psets)?, + input.to_partition_tasks(py, psets, is_ray_runner)?, file_format, schema, root_dir, diff --git a/src/daft-plan/src/planner.rs b/src/daft-plan/src/planner.rs index 5a38d0ed02..31b2f5e0fb 100644 --- a/src/daft-plan/src/planner.rs +++ b/src/daft-plan/src/planner.rs @@ -171,7 +171,7 @@ pub fn plan(logical_plan: &LogicalPlan) -> DaftResult<PhysicalPlan> { let input_physical = plan(input)?; Ok(PhysicalPlan::Coalesce(Coalesce::new( input_physical.into(), - logical_plan.partition_spec().num_partitions, + input.partition_spec().num_partitions, *num_to, ))) } diff --git a/tests/dataframe/test_repartition.py b/tests/dataframe/test_repartition.py index 92c96e5721..88d6281088 100644 --- a/tests/dataframe/test_repartition.py +++ b/tests/dataframe/test_repartition.py @@ -7,3 +7,9 @@ def test_into_partitions_some_empty() -> None: data = {"foo": [1, 2, 3]} df = daft.from_pydict(data).into_partitions(32).collect() assert df.to_pydict() == data + + +def test_into_partitions_coalesce() -> None: + data = {"foo": list(range(100))} + df = daft.from_pydict(data).into_partitions(20).into_partitions(1).collect() + assert df.to_pydict() == data diff --git a/tests/expressions/typing/test_str.py b/tests/expressions/typing/test_str.py index 99d85ac460..99ab473190 100644 --- a/tests/expressions/typing/test_str.py +++ b/tests/expressions/typing/test_str.py @@ -15,6 +15,7 @@ pytest.param(lambda data, pat: data.str.contains(pat), id="contains"), pytest.param(lambda data, pat: data.str.startswith(pat), id="startswith"), pytest.param(lambda data, pat: data.str.endswith(pat), id="endswith"), + pytest.param(lambda data, pat: data.str.split(pat), id="split"), pytest.param(lambda data, pat: data.str.concat(pat), id="concat"), ], ) diff --git a/tests/integration/docker-compose/nginx-serve-static-files.conf b/tests/integration/docker-compose/nginx-serve-static-files.conf index 0c097a5096..9673ecd43b 100644 --- a/tests/integration/docker-compose/nginx-serve-static-files.conf +++ b/tests/integration/docker-compose/nginx-serve-static-files.conf @@ -11,7 +11,7 @@ http { listen [::]:8080; resolver 127.0.0.11; - autoindex off; + autoindex on; server_name _; server_tokens off; diff --git a/tests/integration/io/conftest.py b/tests/integration/io/conftest.py index 13076bca96..e950cafda1 100644 --- a/tests/integration/io/conftest.py +++ b/tests/integration/io/conftest.py @@ -129,21 +129,32 @@ def mount_data_nginx(nginx_config: tuple[str, pathlib.Path], folder: pathlib.Pat """ server_url, static_assets_tmpdir = nginx_config - # Copy data - for root, dirs, files in os.walk(folder, topdown=False): - for file in files: - shutil.copy2(os.path.join(root, file),
os.path.join(static_assets_tmpdir, file)) - for dir in dirs: - shutil.copytree(os.path.join(root, dir), os.path.join(static_assets_tmpdir, dir)) - - yield [f"{server_url}/{p.relative_to(folder)}" for p in folder.glob("**/*") if p.is_file()] - - # Delete data - for root, dirs, files in os.walk(static_assets_tmpdir, topdown=False): - for file in files: - os.remove(os.path.join(root, file)) - for dir in dirs: - os.rmdir(os.path.join(root, dir)) + # Cleanup any old stuff in mount folder + for item in os.listdir(static_assets_tmpdir): + path = static_assets_tmpdir / item + if path.is_dir(): + shutil.rmtree(path) + else: + os.remove(path) + + # Copy data to mount folder + for item in os.listdir(folder): + src = folder / item + dest = static_assets_tmpdir / item + if src.is_dir(): + shutil.copytree(str(src), str(dest)) + else: + shutil.copy2(src, dest) + + try: + yield [f"{server_url}/{p.relative_to(folder)}" for p in folder.glob("**/*") if p.is_file()] + finally: + for item in os.listdir(static_assets_tmpdir): + path = static_assets_tmpdir / item + if path.is_dir(): + shutil.rmtree(static_assets_tmpdir / item) + else: + os.remove(static_assets_tmpdir / item) ### diff --git a/tests/integration/io/test_list_files_gcs.py b/tests/integration/io/test_list_files_gcs.py index 640d9d5be4..1053e25fa1 100644 --- a/tests/integration/io/test_list_files_gcs.py +++ b/tests/integration/io/test_list_files_gcs.py @@ -3,9 +3,11 @@ import gcsfs import pytest -from daft.daft import io_list +from daft.daft import GCSConfig, IOConfig, io_list BUCKET = "daft-public-data-gs" +DEFAULT_GCS_CONFIG = GCSConfig(project_id=None, anonymous=None) +ANON_GCS_CONFIG = GCSConfig(project_id=None, anonymous=True) def gcsfs_recursive_list(fs, path) -> list: @@ -49,28 +51,32 @@ def compare_gcs_result(daft_ls_result: list, fsspec_result: list): ], ) @pytest.mark.parametrize("recursive", [False, True]) -def test_gs_flat_directory_listing(path, recursive): +@pytest.mark.parametrize("gcs_config", [DEFAULT_GCS_CONFIG, ANON_GCS_CONFIG]) +def test_gs_flat_directory_listing(path, recursive, gcs_config): fs = gcsfs.GCSFileSystem() - daft_ls_result = io_list(path, recursive=recursive) + daft_ls_result = io_list(path, recursive=recursive, io_config=IOConfig(gcs=gcs_config)) fsspec_result = gcsfs_recursive_list(fs, path) if recursive else fs.ls(path, detail=True) compare_gcs_result(daft_ls_result, fsspec_result) @pytest.mark.integration() @pytest.mark.parametrize("recursive", [False, True]) -def test_gs_single_file_listing(recursive): +@pytest.mark.parametrize("gcs_config", [DEFAULT_GCS_CONFIG, ANON_GCS_CONFIG]) +def test_gs_single_file_listing(recursive, gcs_config): path = f"gs://{BUCKET}/test_ls/file.txt" fs = gcsfs.GCSFileSystem() - daft_ls_result = io_list(path, recursive=recursive) + daft_ls_result = io_list(path, recursive=recursive, io_config=IOConfig(gcs=gcs_config)) fsspec_result = gcsfs_recursive_list(fs, path) if recursive else fs.ls(path, detail=True) compare_gcs_result(daft_ls_result, fsspec_result) @pytest.mark.integration() -def test_gs_notfound(): +@pytest.mark.parametrize("recursive", [False, True]) +@pytest.mark.parametrize("gcs_config", [DEFAULT_GCS_CONFIG, ANON_GCS_CONFIG]) +def test_gs_notfound(recursive, gcs_config): path = f"gs://{BUCKET}/test_ls/MISSING" fs = gcsfs.GCSFileSystem() with pytest.raises(FileNotFoundError): fs.ls(path, detail=True) with pytest.raises(FileNotFoundError, match=path): - io_list(path) + io_list(path, recursive=recursive, io_config=IOConfig(gcs=gcs_config)) diff --git 
a/tests/integration/io/test_list_files_http.py b/tests/integration/io/test_list_files_http.py new file mode 100644 index 0000000000..2471133c3f --- /dev/null +++ b/tests/integration/io/test_list_files_http.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest +from fsspec.implementations.http import HTTPFileSystem + +from daft.daft import io_list +from tests.integration.io.conftest import mount_data_nginx + + +def compare_http_result(daft_ls_result: list, fsspec_result: list): + daft_files = [(f["path"], f["type"].lower(), f["size"]) for f in daft_ls_result] + httpfs_files = [(f["name"], f["type"], f["size"]) for f in fsspec_result] + assert len(daft_files) == len(httpfs_files) + assert sorted(daft_files) == sorted(httpfs_files) + + +@pytest.fixture(scope="module") +def nginx_http_url(nginx_config, tmpdir_factory): + tmpdir = tmpdir_factory.mktemp("test-list-http") + data_path = Path(tmpdir) + (Path(data_path) / "file.txt").touch() + (Path(data_path) / "test_ls").mkdir() + (Path(data_path) / "test_ls" / "file.txt").touch() + (Path(data_path) / "test_ls" / "paginated-10-files").mkdir() + for i in range(10): + (Path(data_path) / "test_ls" / "paginated-10-files" / f"file.{i}.txt").touch() + + with mount_data_nginx(nginx_config, data_path): + yield nginx_config[0] + + +@pytest.mark.integration() +@pytest.mark.parametrize( + "path", + [ + f"", + f"/", + f"/test_ls", + f"/test_ls/", + f"/test_ls//", + f"/test_ls/paginated-10-files/", + ], +) +def test_http_flat_directory_listing(path, nginx_http_url): + http_path = f"{nginx_http_url}{path}" + fs = HTTPFileSystem() + fsspec_result = fs.ls(http_path, detail=True) + daft_ls_result = io_list(http_path) + compare_http_result(daft_ls_result, fsspec_result) + + +@pytest.mark.integration() +def test_http_single_file_listing(nginx_http_url): + path = f"{nginx_http_url}/test_ls/file.txt" + daft_ls_result = io_list(path) + + # NOTE: FSSpec will return size 0 list for this case, but we want to return 1 element to be + # consistent with behavior of our other file listing utilities + # fs = HTTPFileSystem() + # fsspec_result = fs.ls(path, detail=True) + + assert len(daft_ls_result) == 1 + assert daft_ls_result[0] == {"path": path, "size": 0, "type": "File"} + + +@pytest.mark.integration() +def test_http_notfound(nginx_http_url): + path = f"{nginx_http_url}/test_ls/MISSING" + fs = HTTPFileSystem() + with pytest.raises(FileNotFoundError, match=path): + fs.ls(path, detail=True) + + with pytest.raises(FileNotFoundError, match=path): + io_list(path) + + +@pytest.mark.integration() +@pytest.mark.parametrize( + "path", + [ + f"", + f"/", + ], +) +def test_http_flat_directory_listing_recursive(path, nginx_http_url): + http_path = f"{nginx_http_url}/{path}" + fs = HTTPFileSystem() + fsspec_result = list(fs.glob(http_path.rstrip("/") + "/**", detail=True).values()) + daft_ls_result = io_list(http_path, recursive=True) + compare_http_result(daft_ls_result, fsspec_result) + + +@pytest.mark.integration() +def test_http_listing_absolute_urls(nginx_config, tmpdir): + nginx_http_url, _ = nginx_config + + tmpdir = Path(tmpdir) + test_manifest_file = tmpdir / "index.html" + test_manifest_file.write_text( + f""" + <a href="{nginx_http_url}/other.html">this is an absolute path to a file</a> + <a href="{nginx_http_url}/dir/">this is an absolute path to a dir</a> + """ + ) + + with mount_data_nginx(nginx_config, tmpdir): + http_path = f"{nginx_http_url}/index.html" + daft_ls_result = io_list(http_path, recursive=False) + + # NOTE: Cannot use fsspec here because they do not correctly find the links + #
fsspec_result = fs.ls(http_path, detail=True) + # compare_http_result(daft_ls_result, fsspec_result) + + assert daft_ls_result == [ + {"type": "File", "path": f"{nginx_http_url}/other.html", "size": None}, + {"type": "Directory", "path": f"{nginx_http_url}/dir/", "size": None}, + ] + + +@pytest.mark.integration() +def test_http_listing_absolute_base_urls(nginx_config, tmpdir): + nginx_http_url, _ = nginx_config + + tmpdir = Path(tmpdir) + test_manifest_file = tmpdir / "index.html" + test_manifest_file.write_text( + f""" + <a href="/other.html">this is an absolute base path to a file</a> + <a href="/dir/">this is an absolute base path to a dir</a> + """ + ) + + with mount_data_nginx(nginx_config, tmpdir): + http_path = f"{nginx_http_url}/index.html" + daft_ls_result = io_list(http_path, recursive=False) + + # NOTE: Cannot use fsspec here because they do not correctly find the links + # fsspec_result = fs.ls(http_path, detail=True) + # compare_http_result(daft_ls_result, fsspec_result) + + assert daft_ls_result == [ + {"type": "File", "path": f"{nginx_http_url}/other.html", "size": None}, + {"type": "Directory", "path": f"{nginx_http_url}/dir/", "size": None}, + ] diff --git a/tests/integration/io/test_list_files_local.py b/tests/integration/io/test_list_files_local.py new file mode 100644 index 0000000000..dfd016038b --- /dev/null +++ b/tests/integration/io/test_list_files_local.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +import pytest +from fsspec.implementations.local import LocalFileSystem + +from daft.daft import io_list + + +def local_recursive_list(fs, path) -> list: + all_results = [] + curr_level_result = fs.ls(path, detail=True) + for item in curr_level_result: + if item["type"] == "directory": + new_path = item["name"] + all_results.extend(local_recursive_list(fs, new_path)) + item["name"] += "/" + all_results.append(item) + else: + all_results.append(item) + return all_results + + +def compare_local_result(daft_ls_result: list, fs_result: list): + daft_files = [(f["path"], f["type"].lower()) for f in daft_ls_result] + fs_files = [(f'file://{f["name"]}', f["type"]) for f in fs_result] + assert sorted(daft_files) == sorted(fs_files) + + +@pytest.mark.integration() +@pytest.mark.parametrize("include_protocol", [False, True]) +def test_flat_directory_listing(tmp_path, include_protocol): + d = tmp_path / "dir" + d.mkdir() + files = ["a", "b", "c"] + for name in files: + p = d / name + p.touch() + d = str(d) + if include_protocol: + d = "file://" + d + daft_ls_result = io_list(d) + fs = LocalFileSystem() + fs_result = fs.ls(d, detail=True) + compare_local_result(daft_ls_result, fs_result) + + +@pytest.mark.integration() +@pytest.mark.parametrize("include_protocol", [False, True]) +def test_recursive_directory_listing(tmp_path, include_protocol): + d = tmp_path / "dir" + d.mkdir() + files = ["a", "b/bb", "c/cc/ccc"] + for name in files: + p = d + segments = name.split("/") + for intermediate_dir in segments[:-1]: + p /= intermediate_dir + p.mkdir() + p /= segments[-1] + p.touch() + d = str(d) + if include_protocol: + d = "file://" + d + daft_ls_result = io_list(d, recursive=True) + fs = LocalFileSystem() + fs_result = local_recursive_list(fs, d) + compare_local_result(daft_ls_result, fs_result) + + +@pytest.mark.integration() +@pytest.mark.parametrize("include_protocol", [False, True]) +@pytest.mark.parametrize( + "recursive", + [False, True], +) +def test_single_file_directory_listing(tmp_path, include_protocol, recursive): + d = tmp_path / "dir" + d.mkdir() + files = ["a", "b/bb", "c/cc/ccc"] + for name in files:
+        p = d
+        segments = name.split("/")
+        for intermediate_dir in segments[:-1]:
+            p /= intermediate_dir
+            p.mkdir()
+        p /= segments[-1]
+        p.touch()
+    p = f"{d}/c/cc/ccc"
+    if include_protocol:
+        p = "file://" + p
+    daft_ls_result = io_list(p, recursive=recursive)
+    fs_result = [{"name": f"{d}/c/cc/ccc", "type": "file"}]
+    assert len(daft_ls_result) == 1
+    compare_local_result(daft_ls_result, fs_result)
+
+
+@pytest.mark.integration()
+@pytest.mark.parametrize("include_protocol", [False, True])
+def test_missing_file_path(tmp_path, include_protocol):
+    d = tmp_path / "dir"
+    d.mkdir()
+    files = ["a", "b/bb", "c/cc/ccc"]
+    for name in files:
+        p = d
+        segments = name.split("/")
+        for intermediate_dir in segments[:-1]:
+            p /= intermediate_dir
+            p.mkdir()
+        p /= segments[-1]
+        p.touch()
+    p = f"{d}/c/cc/ddd"
+    if include_protocol:
+        p = "file://" + p
+    with pytest.raises(FileNotFoundError, match=f"File: {d}/c/cc/ddd not found"):
+        daft_ls_result = io_list(p, recursive=True)
diff --git a/tests/series/test_utf8_ops.py b/tests/series/test_utf8_ops.py
index dafd294423..1019769501 100644
--- a/tests/series/test_utf8_ops.py
+++ b/tests/series/test_utf8_ops.py
@@ -102,6 +102,104 @@ def test_series_utf8_compare_invalid_inputs(funcname, bad_series) -> None:
         getattr(s.str, funcname)(bad_series)
 
 
+@pytest.mark.parametrize(
+    ["data", "patterns", "expected"],
+    [
+        # Single-character pattern.
+        (["a,b,c", "d,e", "f", "g,h"], [","], [["a", "b", "c"], ["d", "e"], ["f"], ["g", "h"]]),
+        # Multi-character pattern.
+        (["abbcbbd", "bb", "bbe", "fbb"], ["bb"], [["a", "c", "d"], ["", ""], ["", "e"], ["f", ""]]),
+        # Empty pattern (character-splitting).
+        (["foo", "bar"], [""], [["", "f", "o", "o", ""], ["", "b", "a", "r", ""]]),
+    ],
+)
+def test_series_utf8_split_broadcast_pattern(data, patterns, expected) -> None:
+    s = Series.from_arrow(pa.array(data))
+    patterns = Series.from_arrow(pa.array(patterns))
+    result = s.str.split(patterns)
+    assert result.to_pylist() == expected
+
+
+@pytest.mark.parametrize(
+    ["data", "patterns", "expected"],
+    [
+        (["a,b,c", "a:b:c", "a;b;c", "a.b.c"], [",", ":", ";", "."], [["a", "b", "c"]] * 4),
+        (["aabbccdd"] * 4, ["aa", "bb", "cc", "dd"], [["", "bbccdd"], ["aa", "ccdd"], ["aabb", "dd"], ["aabbcc", ""]]),
+    ],
+)
+def test_series_utf8_split_multi_pattern(data, patterns, expected) -> None:
+    s = Series.from_arrow(pa.array(data))
+    patterns = Series.from_arrow(pa.array(patterns))
+    result = s.str.split(patterns)
+    assert result.to_pylist() == expected
+
+
+@pytest.mark.parametrize(
+    ["data", "patterns", "expected"],
+    [
+        (["aabbccdd"], ["aa", "bb", "cc", "dd"], [["", "bbccdd"], ["aa", "ccdd"], ["aabb", "dd"], ["aabbcc", ""]]),
+    ],
+)
+def test_series_utf8_split_broadcast_arr(data, patterns, expected) -> None:
+    s = Series.from_arrow(pa.array(data))
+    patterns = Series.from_arrow(pa.array(patterns))
+    result = s.str.split(patterns)
+    assert result.to_pylist() == expected
+
+
+@pytest.mark.parametrize(
+    ["data", "patterns", "expected"],
+    [
+        # Mixed-in nulls.
+        (["a,b,c", None, "a;b;c", "a.b.c"], [",", ":", None, "."], [["a", "b", "c"], None, None, ["a", "b", "c"]]),
+        # All null data.
+        ([None] * 4, [","] * 4, [None] * 4),
+        # All null patterns.
+        (["foo"] * 4, [None] * 4, [None] * 4),
+        # Broadcasted null data.
+        ([None], [","] * 4, [None] * 4),
+        # Broadcasted null pattern.
+ (["foo"] * 4, [None], [None] * 4), + ], +) +def test_series_utf8_split_nulls(data, patterns, expected) -> None: + s = Series.from_arrow(pa.array(data, type=pa.string())) + patterns = Series.from_arrow(pa.array(patterns, type=pa.string())) + result = s.str.split(patterns) + assert result.to_pylist() == expected + + +@pytest.mark.parametrize( + ["data", "patterns", "expected"], + [ + # Empty data. + ([[], [","] * 4, []]), + # Empty patterns. + ([["foo"] * 4, [], []]), + ], +) +def test_series_utf8_split_empty_arrs(data, patterns, expected) -> None: + s = Series.from_arrow(pa.array(data, type=pa.string())) + patterns = Series.from_arrow(pa.array(patterns, type=pa.string())) + result = s.str.split(patterns) + assert result.to_pylist() == expected + + +@pytest.mark.parametrize( + "patterns", + [ + # Wrong number of elements, not broadcastable + Series.from_arrow(pa.array([",", "."], type=pa.string())), + # Bad input type + object(), + ], +) +def test_series_utf8_split_invalid_inputs(patterns) -> None: + s = Series.from_arrow(pa.array(["a,b,c", "d, e", "f"])) + with pytest.raises(ValueError): + s.str.split(patterns) + + def test_series_utf8_length() -> None: s = Series.from_arrow(pa.array(["foo", "barbaz", "quux"])) result = s.str.length() diff --git a/tests/table/utf8/test_split.py b/tests/table/utf8/test_split.py new file mode 100644 index 0000000000..0da7735b1d --- /dev/null +++ b/tests/table/utf8/test_split.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +import pytest + +from daft.expressions import col, lit +from daft.table import Table + + +@pytest.mark.parametrize( + ["expr", "data", "expected"], + [ + (col("col").str.split(","), ["a,b,c", "d,e", "f", "g,h"], [["a", "b", "c"], ["d", "e"], ["f"], ["g", "h"]]), + ( + col("col").str.split(lit(",")), + ["a,b,c", "d,e", "f", "g,h"], + [["a", "b", "c"], ["d", "e"], ["f"], ["g", "h"]], + ), + ( + col("col").str.split(col("emptystrings") + lit(",")), + ["a,b,c", "d,e", "f", "g,h"], + [["a", "b", "c"], ["d", "e"], ["f"], ["g", "h"]], + ), + ], +) +def test_series_utf8_split_broadcast_pattern(expr, data, expected) -> None: + table = Table.from_pydict({"col": data, "emptystrings": ["", "", "", ""]}) + result = table.eval_expression_list([expr]) + assert result.to_pydict() == {"col": expected}