diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 3affeecc4c..0a7d2de10a 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -582,6 +582,12 @@ jobs: run: | uv pip install -r requirements-dev.txt dist/${{ env.package-name }}-*x86_64*.whl --force-reinstall rm -rf daft + - name: Install ODBC Driver 18 for SQL Server + run: | + curl https://packages.microsoft.com/keys/microsoft.asc | sudo apt-key add - + sudo add-apt-repository https://packages.microsoft.com/ubuntu/$(lsb_release -rs)/prod + sudo apt-get update + sudo ACCEPT_EULA=Y apt-get install -y msodbcsql18 - name: Spin up services run: | pushd ./tests/integration/sql/docker-compose/ diff --git a/Cargo.lock b/Cargo.lock index c57309e10f..3fcc294fef 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1960,6 +1960,7 @@ dependencies = [ "daft-core", "daft-csv", "daft-dsl", + "daft-functions", "daft-io", "daft-json", "daft-micropartition", diff --git a/daft/delta_lake/delta_lake_scan.py b/daft/delta_lake/delta_lake_scan.py index a5c7b94b29..f1357c2fa7 100644 --- a/daft/delta_lake/delta_lake_scan.py +++ b/daft/delta_lake/delta_lake_scan.py @@ -17,6 +17,7 @@ ScanTask, StorageConfig, ) +from daft.io.aws_config import boto3_client_from_s3_config from daft.io.object_store_options import io_config_to_storage_options from daft.io.scan import PartitionField, ScanOperator from daft.logical.schema import Schema @@ -43,6 +44,24 @@ def __init__( deltalake_sdk_io_config = storage_config.config.io_config scheme = urlparse(table_uri).scheme if scheme == "s3" or scheme == "s3a": + # Try to get region from boto3 + if deltalake_sdk_io_config.s3.region_name is None: + from botocore.exceptions import BotoCoreError + + try: + client = boto3_client_from_s3_config("s3", deltalake_sdk_io_config.s3) + response = client.get_bucket_location(Bucket=urlparse(table_uri).netloc) + except BotoCoreError as e: + logger.warning( + "Failed to get the S3 bucket region using existing storage config, will attempt to get it from the environment instead. 
Error from boto3: %s", + e, + ) + else: + deltalake_sdk_io_config = deltalake_sdk_io_config.replace( + s3=deltalake_sdk_io_config.s3.replace(region_name=response["LocationConstraint"]) + ) + + # Try to get config from the environment if any([deltalake_sdk_io_config.s3.key_id is None, deltalake_sdk_io_config.s3.region_name is None]): try: s3_config_from_env = S3Config.from_env() diff --git a/daft/io/aws_config.py b/daft/io/aws_config.py new file mode 100644 index 0000000000..7f0e9e3dff --- /dev/null +++ b/daft/io/aws_config.py @@ -0,0 +1,21 @@ +from typing import TYPE_CHECKING + +from daft.daft import S3Config + +if TYPE_CHECKING: + import boto3 + + +def boto3_client_from_s3_config(service: str, s3_config: S3Config) -> "boto3.client": + import boto3 + + return boto3.client( + service, + region_name=s3_config.region_name, + use_ssl=s3_config.use_ssl, + verify=s3_config.verify_ssl, + endpoint_url=s3_config.endpoint_url, + aws_access_key_id=s3_config.key_id, + aws_secret_access_key=s3_config.access_key, + aws_session_token=s3_config.session_token, + ) diff --git a/daft/io/catalog.py b/daft/io/catalog.py index 1183caa8ab..62cb16e672 100644 --- a/daft/io/catalog.py +++ b/daft/io/catalog.py @@ -5,6 +5,7 @@ from typing import Optional from daft.daft import IOConfig +from daft.io.aws_config import boto3_client_from_s3_config class DataCatalogType(Enum): @@ -42,20 +43,8 @@ def table_uri(self, io_config: IOConfig) -> str: """ if self.catalog == DataCatalogType.GLUE: # Use boto3 to get the table from AWS Glue Data Catalog. - import boto3 + glue = boto3_client_from_s3_config("glue", io_config.s3) - s3_config = io_config.s3 - - glue = boto3.client( - "glue", - region_name=s3_config.region_name, - use_ssl=s3_config.use_ssl, - verify=s3_config.verify_ssl, - endpoint_url=s3_config.endpoint_url, - aws_access_key_id=s3_config.key_id, - aws_secret_access_key=s3_config.access_key, - aws_session_token=s3_config.session_token, - ) if self.catalog_id is not None: # Allow cross account access, table.catalog_id should be the target account id glue_table = glue.get_table( diff --git a/daft/sql/sql_scan.py b/daft/sql/sql_scan.py index 4f0f9a35c7..4d3156ae80 100644 --- a/daft/sql/sql_scan.py +++ b/daft/sql/sql_scan.py @@ -161,14 +161,18 @@ def _get_num_rows(self) -> int: def _attempt_partition_bounds_read(self, num_scan_tasks: int) -> tuple[Any, PartitionBoundStrategy]: try: - # Try to get percentiles using percentile_cont + # Try to get percentiles using percentile_disc. + # Favor percentile_disc over percentile_cont because we want exact values to do <= and >= comparisons. 
percentiles = [i / num_scan_tasks for i in range(num_scan_tasks + 1)] + # Use the OVER clause for SQL Server + over_clause = "OVER ()" if self.conn.dialect in ["mssql", "tsql"] else "" percentile_sql = self.conn.construct_sql_query( self.sql, projection=[ - f"percentile_disc({percentile}) WITHIN GROUP (ORDER BY {self._partition_col}) AS bound_{i}" + f"percentile_disc({percentile}) WITHIN GROUP (ORDER BY {self._partition_col}) {over_clause} AS bound_{i}" for i, percentile in enumerate(percentiles) ], + limit=1, ) pa_table = self.conn.execute_sql_query(percentile_sql) return pa_table, PartitionBoundStrategy.PERCENTILE diff --git a/requirements-dev.txt b/requirements-dev.txt index 069d699e31..ea9a982c65 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -67,6 +67,8 @@ trino[sqlalchemy]==0.328.0; python_version >= '3.8' PyMySQL==1.1.0; python_version >= '3.8' psycopg2-binary==2.9.9; python_version >= '3.8' sqlglot==23.3.0; python_version >= '3.8' +pyodbc==5.1.0; python_version >= '3.8' + # AWS s3fs==2023.12.0; python_version >= '3.8' # on old versions of s3fs's pinned botocore, they neglected to pin urllib3<2 which leads to: diff --git a/src/daft-dsl/src/expr/mod.rs b/src/daft-dsl/src/expr/mod.rs index 873f9013bd..567a2d35d8 100644 --- a/src/daft-dsl/src/expr/mod.rs +++ b/src/daft-dsl/src/expr/mod.rs @@ -990,21 +990,9 @@ impl Expr { to_sql_inner(inner, buffer)?; write!(buffer, ") IS NOT NULL") } - Expr::IfElse { - if_true, - if_false, - predicate, - } => { - write!(buffer, "CASE WHEN ")?; - to_sql_inner(predicate, buffer)?; - write!(buffer, " THEN ")?; - to_sql_inner(if_true, buffer)?; - write!(buffer, " ELSE ")?; - to_sql_inner(if_false, buffer)?; - write!(buffer, " END") - } // TODO: Implement SQL translations for these expressions if possible - Expr::Agg(..) + Expr::IfElse { .. } + | Expr::Agg(..) | Expr::Cast(..) | Expr::IsIn(..) | Expr::Between(..) 
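Note on the `daft/delta_lake/delta_lake_scan.py` change above: when no S3 region is configured, the scan now asks S3 itself for the bucket's region via the new `boto3_client_from_s3_config` helper, and only falls back to environment configuration if that lookup fails. A minimal standalone sketch of the same lookup (the `resolve_bucket_region` helper and the bucket name are illustrative, not part of the PR); note that `get_bucket_location` reports a null `LocationConstraint` for buckets in `us-east-1`:

```python
from typing import Optional

from botocore.exceptions import BotoCoreError

from daft.daft import S3Config
from daft.io.aws_config import boto3_client_from_s3_config


def resolve_bucket_region(bucket: str, s3_config: S3Config) -> Optional[str]:
    """Best-effort region lookup; None means 'fall back to the environment'."""
    try:
        client = boto3_client_from_s3_config("s3", s3_config)
        response = client.get_bucket_location(Bucket=bucket)
    except BotoCoreError:
        return None
    # S3 returns a null LocationConstraint for us-east-1 buckets.
    return response["LocationConstraint"] or "us-east-1"


print(resolve_bucket_region("my-delta-table-bucket", S3Config.from_env()))
```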
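And on the `daft/sql/sql_scan.py` change: `percentile_disc` is an ordinary ordered-set aggregate on most backends, but SQL Server only exposes it as an analytic function, so the projection needs an empty `OVER ()` clause there and the query then yields one identical row per input row, which is why a `limit=1` is also added. A rough sketch of the shape of the bounds query (illustrative only, not the exact SQL that `construct_sql_query` emits):

```python
num_scan_tasks = 2
partition_col = "id"
dialect = "mssql"  # taken from the SQLAlchemy connection in the real code

percentiles = [i / num_scan_tasks for i in range(num_scan_tasks + 1)]  # [0.0, 0.5, 1.0]
# SQL Server requires the empty OVER () clause; other dialects use the plain aggregate form.
over_clause = "OVER ()" if dialect in ("mssql", "tsql") else ""
projection = ", ".join(
    f"percentile_disc({p}) WITHIN GROUP (ORDER BY {partition_col}) {over_clause} AS bound_{i}"
    for i, p in enumerate(percentiles)
)
# The window form repeats the same bounds on every row, hence the added limit of 1.
print(f"SELECT {projection} FROM (SELECT * FROM example) AS subquery")
```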
diff --git a/src/daft-local-execution/Cargo.toml b/src/daft-local-execution/Cargo.toml index cd061c1c35..932462215d 100644 --- a/src/daft-local-execution/Cargo.toml +++ b/src/daft-local-execution/Cargo.toml @@ -7,6 +7,7 @@ common-tracing = {path = "../common/tracing", default-features = false} daft-core = {path = "../daft-core", default-features = false} daft-csv = {path = "../daft-csv", default-features = false} daft-dsl = {path = "../daft-dsl", default-features = false} +daft-functions = {path = "../daft-functions", default-features = false} daft-io = {path = "../daft-io", default-features = false} daft-json = {path = "../daft-json", default-features = false} daft-micropartition = {path = "../daft-micropartition", default-features = false} diff --git a/src/daft-local-execution/src/intermediate_ops/explode.rs b/src/daft-local-execution/src/intermediate_ops/explode.rs new file mode 100644 index 0000000000..774be696a8 --- /dev/null +++ b/src/daft-local-execution/src/intermediate_ops/explode.rs @@ -0,0 +1,42 @@ +use std::sync::Arc; + +use common_error::DaftResult; +use daft_dsl::ExprRef; +use daft_functions::list::explode; +use tracing::instrument; + +use super::intermediate_op::{ + IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState, +}; +use crate::pipeline::PipelineResultType; + +pub struct ExplodeOperator { + to_explode: Vec, +} + +impl ExplodeOperator { + pub fn new(to_explode: Vec) -> Self { + Self { + to_explode: to_explode.into_iter().map(explode).collect(), + } + } +} + +impl IntermediateOperator for ExplodeOperator { + #[instrument(skip_all, name = "ExplodeOperator::execute")] + fn execute( + &self, + _idx: usize, + input: &PipelineResultType, + _state: Option<&mut Box>, + ) -> DaftResult { + let out = input.as_data().explode(&self.to_explode)?; + Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new( + out, + )))) + } + + fn name(&self) -> &'static str { + "ExplodeOperator" + } +} diff --git a/src/daft-local-execution/src/intermediate_ops/mod.rs b/src/daft-local-execution/src/intermediate_ops/mod.rs index 098bbcbfbe..7d97464e24 100644 --- a/src/daft-local-execution/src/intermediate_ops/mod.rs +++ b/src/daft-local-execution/src/intermediate_ops/mod.rs @@ -1,6 +1,7 @@ pub mod aggregate; pub mod anti_semi_hash_join_probe; pub mod buffer; +pub mod explode; pub mod filter; pub mod inner_hash_join_probe; pub mod intermediate_op; diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs index b5531fcf84..eccece1a56 100644 --- a/src/daft-local-execution/src/pipeline.rs +++ b/src/daft-local-execution/src/pipeline.rs @@ -10,8 +10,8 @@ use daft_core::{ use daft_dsl::{col, join::get_common_join_keys, Expr}; use daft_micropartition::MicroPartition; use daft_physical_plan::{ - EmptyScan, Filter, HashAggregate, HashJoin, InMemoryScan, Limit, LocalPhysicalPlan, Pivot, - Project, Sample, Sort, UnGroupedAggregate, Unpivot, + Concat, EmptyScan, Explode, Filter, HashAggregate, HashJoin, InMemoryScan, Limit, + LocalPhysicalPlan, Pivot, Project, Sample, Sort, UnGroupedAggregate, Unpivot, }; use daft_plan::{populate_aggregation_stages, JoinType}; use daft_table::ProbeState; @@ -22,12 +22,13 @@ use crate::{ channel::PipelineChannel, intermediate_ops::{ aggregate::AggregateOperator, anti_semi_hash_join_probe::AntiSemiProbeOperator, - filter::FilterOperator, inner_hash_join_probe::InnerHashJoinProbeOperator, - intermediate_op::IntermediateNode, pivot::PivotOperator, project::ProjectOperator, - sample::SampleOperator, 
unpivot::UnpivotOperator, + explode::ExplodeOperator, filter::FilterOperator, + inner_hash_join_probe::InnerHashJoinProbeOperator, intermediate_op::IntermediateNode, + pivot::PivotOperator, project::ProjectOperator, sample::SampleOperator, + unpivot::UnpivotOperator, }, sinks::{ - aggregate::AggregateSink, blocking_sink::BlockingSinkNode, + aggregate::AggregateSink, blocking_sink::BlockingSinkNode, concat::ConcatSink, hash_join_build::HashJoinBuildSink, limit::LimitSink, outer_hash_join_probe::OuterHashJoinProbeSink, sort::SortSink, streaming_sink::StreamingSinkNode, @@ -145,6 +146,13 @@ pub fn physical_plan_to_pipeline( let child_node = physical_plan_to_pipeline(input, psets)?; IntermediateNode::new(Arc::new(filter_op), vec![child_node]).boxed() } + LocalPhysicalPlan::Explode(Explode { + input, to_explode, .. + }) => { + let explode_op = ExplodeOperator::new(to_explode.clone()); + let child_node = physical_plan_to_pipeline(input, psets)?; + IntermediateNode::new(Arc::new(explode_op), vec![child_node]).boxed() + } LocalPhysicalPlan::Limit(Limit { input, num_rows, .. }) => { @@ -152,12 +160,11 @@ pub fn physical_plan_to_pipeline( let child_node = physical_plan_to_pipeline(input, psets)?; StreamingSinkNode::new(Arc::new(sink), vec![child_node]).boxed() } - LocalPhysicalPlan::Concat(_) => { - todo!("concat") - // let sink = ConcatSink::new(); - // let left_child = physical_plan_to_pipeline(input, psets)?; - // let right_child = physical_plan_to_pipeline(other, psets)?; - // PipelineNode::double_sink(sink, left_child, right_child) + LocalPhysicalPlan::Concat(Concat { input, other, .. }) => { + let left_child = physical_plan_to_pipeline(input, psets)?; + let right_child = physical_plan_to_pipeline(other, psets)?; + let sink = ConcatSink {}; + StreamingSinkNode::new(Arc::new(sink), vec![left_child, right_child]).boxed() } LocalPhysicalPlan::UnGroupedAggregate(UnGroupedAggregate { input, diff --git a/src/daft-local-execution/src/sinks/concat.rs b/src/daft-local-execution/src/sinks/concat.rs index 010bed0aaf..5b98cb84c6 100644 --- a/src/daft-local-execution/src/sinks/concat.rs +++ b/src/daft-local-execution/src/sinks/concat.rs @@ -1,61 +1,68 @@ -// use std::sync::Arc; - -// use common_error::DaftResult; -// use daft_micropartition::MicroPartition; -// use tracing::instrument; - -// use super::sink::{Sink, SinkResultType}; - -// #[derive(Clone)] -// pub struct ConcatSink { -// result_left: Vec>, -// result_right: Vec>, -// } - -// impl ConcatSink { -// pub fn new() -> Self { -// Self { -// result_left: Vec::new(), -// result_right: Vec::new(), -// } -// } - -// #[instrument(skip_all, name = "ConcatSink::sink")] -// fn sink_left(&mut self, input: &Arc) -> DaftResult { -// self.result_left.push(input.clone()); -// Ok(SinkResultType::NeedMoreInput) -// } - -// #[instrument(skip_all, name = "ConcatSink::sink")] -// fn sink_right(&mut self, input: &Arc) -> DaftResult { -// self.result_right.push(input.clone()); -// Ok(SinkResultType::NeedMoreInput) -// } -// } - -// impl Sink for ConcatSink { -// fn sink(&mut self, index: usize, input: &Arc) -> DaftResult { -// match index { -// 0 => self.sink_left(input), -// 1 => self.sink_right(input), -// _ => panic!("concat only supports 2 inputs, got {index}"), -// } -// } - -// fn in_order(&self) -> bool { -// true -// } - -// fn num_inputs(&self) -> usize { -// 2 -// } - -// #[instrument(skip_all, name = "ConcatSink::finalize")] -// fn finalize(self: Box) -> DaftResult>> { -// Ok(self -// .result_left -// .into_iter() -// .chain(self.result_right.into_iter()) 
-// .collect())
-// }
-// }
+use std::sync::Arc;
+
+use common_error::{DaftError, DaftResult};
+use daft_micropartition::MicroPartition;
+use tracing::instrument;
+
+use super::streaming_sink::{StreamingSink, StreamingSinkOutput, StreamingSinkState};
+use crate::pipeline::PipelineResultType;
+
+struct ConcatSinkState {
+    // The index of the last morsel of data that was received, which should be strictly non-decreasing.
+    pub curr_idx: usize,
+}
+impl StreamingSinkState for ConcatSinkState {
+    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
+        self
+    }
+}
+
+pub struct ConcatSink {}
+
+impl StreamingSink for ConcatSink {
+    /// Execute for the ConcatSink operator does not do any computation and simply returns the input data.
+    /// It only expects that the indices of the input data are strictly non-decreasing.
+    /// TODO(Colin): If maintain_order is false, technically we could accept any index. Make this optimization later.
+    #[instrument(skip_all, name = "ConcatSink::sink")]
+    fn execute(
+        &self,
+        index: usize,
+        input: &PipelineResultType,
+        state: &mut dyn StreamingSinkState,
+    ) -> DaftResult<StreamingSinkOutput> {
+        let state = state
+            .as_any_mut()
+            .downcast_mut::<ConcatSinkState>()
+            .expect("ConcatSink should have ConcatSinkState");
+
+        // If the index is the same as the current index or one more than the current index, then we can accept the morsel.
+        if state.curr_idx == index || state.curr_idx + 1 == index {
+            state.curr_idx = index;
+            Ok(StreamingSinkOutput::NeedMoreInput(Some(
+                input.as_data().clone(),
+            )))
+        } else {
+            Err(DaftError::ComputeError(format!("Concat sink received out-of-order data. Expected index to be {} or {}, but got {}.", state.curr_idx, state.curr_idx + 1, index)))
+        }
+    }
+
+    fn name(&self) -> &'static str {
+        "Concat"
+    }
+
+    fn finalize(
+        &self,
+        _states: Vec<Box<dyn StreamingSinkState>>,
+    ) -> DaftResult<Option<Arc<MicroPartition>>> {
+        Ok(None)
+    }
+
+    fn make_state(&self) -> Box<dyn StreamingSinkState> {
+        Box::new(ConcatSinkState { curr_idx: 0 })
+    }
+
+    /// Since the ConcatSink does not do any computation, it does not need to spawn multiple workers.
+    fn max_concurrency(&self) -> usize {
+        1
+    }
+}
diff --git a/src/daft-local-execution/src/sinks/streaming_sink.rs b/src/daft-local-execution/src/sinks/streaming_sink.rs
index 0a7000af8f..6e8a022cdb 100644
--- a/src/daft-local-execution/src/sinks/streaming_sink.rs
+++ b/src/daft-local-execution/src/sinks/streaming_sink.rs
@@ -26,18 +26,30 @@ pub enum StreamingSinkOutput {
 }
 
 pub trait StreamingSink: Send + Sync {
+    /// Execute the StreamingSink operator on the morsel of input data,
+    /// received from the child with the given index,
+    /// with the given state.
     fn execute(
         &self,
         index: usize,
         input: &PipelineResultType,
         state: &mut dyn StreamingSinkState,
     ) -> DaftResult<StreamingSinkOutput>;
+
+    /// Finalize the StreamingSink operator, with the given states from each worker.
     fn finalize(
         &self,
         states: Vec<Box<dyn StreamingSinkState>>,
     ) -> DaftResult<Option<Arc<MicroPartition>>>;
+
+    /// The name of the StreamingSink operator.
     fn name(&self) -> &'static str;
+
+    /// Create a new worker-local state for this StreamingSink.
     fn make_state(&self) -> Box<dyn StreamingSinkState>;
+
+    /// The maximum number of concurrent workers that can be spawned for this sink.
+    /// Each worker will have its own StreamingSinkState.
     fn max_concurrency(&self) -> usize {
         *NUM_CPUS
     }
@@ -118,6 +130,8 @@ impl StreamingSinkNode {
         output_receiver
     }
 
+    // Forwards input from the children to the workers in a round-robin fashion.
+    // Always exhausts the input from one child before moving to the next.
async fn forward_input_to_workers( receivers: Vec, worker_senders: Vec>, diff --git a/src/daft-parquet/src/file.rs b/src/daft-parquet/src/file.rs index 84dfa5a83e..730d0b3315 100644 --- a/src/daft-parquet/src/file.rs +++ b/src/daft-parquet/src/file.rs @@ -313,10 +313,9 @@ pub struct ParquetFileReader { impl ParquetFileReader { const DEFAULT_CHUNK_SIZE: usize = 2048; - // Set to a very high number 256MB to guard against unbounded large - // downloads from remote storage, which likely indicates corrupted Parquet data - // See: https://github.com/Eventual-Inc/Daft/issues/1551 - const MAX_HEADER_SIZE: usize = 256 * 1024 * 1024; + // Set to 2GB because that's the maximum size of strings allowable by Parquet (using i32 offsets). + // See issue: https://github.com/Eventual-Inc/Daft/issues/3007 + const MAX_PAGE_SIZE: usize = 2 * 1024 * 1024 * 1024; fn new( uri: String, @@ -473,7 +472,7 @@ impl ParquetFileReader { range_reader, vec![], Arc::new(|_, _| true), - Self::MAX_HEADER_SIZE, + Self::MAX_PAGE_SIZE, ) .with_context( |_| UnableToCreateParquetPageStreamSnafu:: { @@ -638,7 +637,7 @@ impl ParquetFileReader { range_reader, vec![], Arc::new(|_, _| true), - Self::MAX_HEADER_SIZE, + Self::MAX_PAGE_SIZE, ) .with_context(|_| { UnableToCreateParquetPageStreamSnafu:: { @@ -821,7 +820,7 @@ impl ParquetFileReader { range_reader, vec![], Arc::new(|_, _| true), - Self::MAX_HEADER_SIZE, + Self::MAX_PAGE_SIZE, ) .with_context(|_| { UnableToCreateParquetPageStreamSnafu:: { diff --git a/src/daft-physical-plan/src/lib.rs b/src/daft-physical-plan/src/lib.rs index 75aa616394..ba20720855 100644 --- a/src/daft-physical-plan/src/lib.rs +++ b/src/daft-physical-plan/src/lib.rs @@ -3,8 +3,8 @@ mod local_plan; mod translate; pub use local_plan::{ - Concat, EmptyScan, Filter, HashAggregate, HashJoin, InMemoryScan, Limit, LocalPhysicalPlan, - LocalPhysicalPlanRef, PhysicalScan, PhysicalWrite, Pivot, Project, Sample, Sort, - UnGroupedAggregate, Unpivot, + Concat, EmptyScan, Explode, Filter, HashAggregate, HashJoin, InMemoryScan, Limit, + LocalPhysicalPlan, LocalPhysicalPlanRef, PhysicalScan, PhysicalWrite, Pivot, Project, Sample, + Sort, UnGroupedAggregate, Unpivot, }; pub use translate::translate; diff --git a/src/daft-physical-plan/src/local_plan.rs b/src/daft-physical-plan/src/local_plan.rs index 4ba861798e..94672c2463 100644 --- a/src/daft-physical-plan/src/local_plan.rs +++ b/src/daft-physical-plan/src/local_plan.rs @@ -15,7 +15,7 @@ pub enum LocalPhysicalPlan { Project(Project), Filter(Filter), Limit(Limit), - // Explode(Explode), + Explode(Explode), Unpivot(Unpivot), Sort(Sort), // Split(Split), @@ -107,6 +107,20 @@ impl LocalPhysicalPlan { .arced() } + pub(crate) fn explode( + input: LocalPhysicalPlanRef, + to_explode: Vec, + schema: SchemaRef, + ) -> LocalPhysicalPlanRef { + Self::Explode(Explode { + input, + to_explode, + schema, + plan_stats: PlanStats {}, + }) + .arced() + } + pub(crate) fn project( input: LocalPhysicalPlanRef, projection: Vec, @@ -272,6 +286,7 @@ impl LocalPhysicalPlan { | Self::Sort(Sort { schema, .. }) | Self::Sample(Sample { schema, .. }) | Self::HashJoin(HashJoin { schema, .. }) + | Self::Explode(Explode { schema, .. }) | Self::Unpivot(Unpivot { schema, .. }) | Self::Concat(Concat { schema, .. }) => schema, Self::InMemoryScan(InMemoryScan { info, .. 
}) => &info.source_schema, @@ -323,6 +338,14 @@ pub struct Limit { pub plan_stats: PlanStats, } +#[derive(Debug)] +pub struct Explode { + pub input: LocalPhysicalPlanRef, + pub to_explode: Vec, + pub schema: SchemaRef, + pub plan_stats: PlanStats, +} + #[derive(Debug)] pub struct Sort { pub input: LocalPhysicalPlanRef, diff --git a/src/daft-physical-plan/src/translate.rs b/src/daft-physical-plan/src/translate.rs index 726b3232d5..7dcb0f552b 100644 --- a/src/daft-physical-plan/src/translate.rs +++ b/src/daft-physical-plan/src/translate.rs @@ -158,6 +158,14 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { log::warn!("Repartition Not supported for Local Executor!; This will be a No-Op"); translate(&repartition.input) } + LogicalPlan::Explode(explode) => { + let input = translate(&explode.input)?; + Ok(LocalPhysicalPlan::explode( + input, + explode.to_explode.clone(), + explode.exploded_schema.clone(), + )) + } _ => todo!("{} not yet implemented", plan.name()), } } diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index e651b6528f..55823e5843 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -11,7 +11,7 @@ use daft_functions::numeric::{ceil::ceil, floor::floor}; use daft_plan::{LogicalPlanBuilder, LogicalPlanRef}; use sqlparser::{ ast::{ - ArrayElemTypeDef, BinaryOperator, CastKind, ExactNumberInfo, ExcludeSelectItem, + ArrayElemTypeDef, BinaryOperator, CastKind, Distinct, ExactNumberInfo, ExcludeSelectItem, GroupByExpr, Ident, Query, SelectItem, Statement, StructField, Subscript, TableAlias, TableWithJoins, TimezoneInfo, UnaryOperator, Value, WildcardAdditionalOptions, }, @@ -202,6 +202,15 @@ impl SQLPlanner { } } + match &selection.distinct { + Some(Distinct::Distinct) => { + let rel = self.relation_mut(); + rel.inner = rel.inner.distinct()?; + } + Some(Distinct::On(_)) => unsupported_sql_err!("DISTINCT ON"), + None => {} + } + if let Some(order_by) = &query.order_by { if order_by.interpolate.is_some() { unsupported_sql_err!("ORDER BY [query] [INTERPOLATE]"); @@ -1186,9 +1195,7 @@ fn check_select_features(selection: &sqlparser::ast::Select) -> SQLPlannerResult if selection.top.is_some() { unsupported_sql_err!("TOP"); } - if selection.distinct.is_some() { - unsupported_sql_err!("DISTINCT"); - } + if selection.into.is_some() { unsupported_sql_err!("INTO"); } diff --git a/src/parquet2/src/read/page/stream.rs b/src/parquet2/src/read/page/stream.rs index 25fb0fe6fc..523fc335b6 100644 --- a/src/parquet2/src/read/page/stream.rs +++ b/src/parquet2/src/read/page/stream.rs @@ -37,7 +37,7 @@ pub async fn get_page_stream_from_column_start<'a, R: AsyncRead + Unpin + Send>( reader: &'a mut R, scratch: Vec, pages_filter: PageFilter, - max_header_size: usize, + max_page_size: usize, ) -> Result> + 'a> { let page_metadata: PageMetaData = column_metadata.into(); Ok(_get_page_stream( @@ -47,7 +47,7 @@ pub async fn get_page_stream_from_column_start<'a, R: AsyncRead + Unpin + Send>( page_metadata.descriptor, scratch, pages_filter, - max_header_size, + max_page_size, )) } @@ -56,7 +56,7 @@ pub fn get_owned_page_stream_from_column_start( reader: R, scratch: Vec, pages_filter: PageFilter, - max_header_size: usize, + max_page_size: usize, ) -> Result>> { let page_metadata: PageMetaData = column_metadata.into(); Ok(_get_owned_page_stream( @@ -66,7 +66,7 @@ pub fn get_owned_page_stream_from_column_start( page_metadata.descriptor, scratch, pages_filter, - max_header_size, + max_page_size, )) } diff --git a/tests/dataframe/test_approx_count_distinct.py 
b/tests/dataframe/test_approx_count_distinct.py index 78d2a7b181..68d7057ca0 100644 --- a/tests/dataframe/test_approx_count_distinct.py +++ b/tests/dataframe/test_approx_count_distinct.py @@ -2,12 +2,7 @@ import pytest import daft -from daft import col, context - -pytestmark = pytest.mark.skipif( - context.get_context().daft_execution_config.enable_native_executor is True, - reason="Native executor fails for these tests", -) +from daft import col TESTS = [ [[], 0], diff --git a/tests/dataframe/test_concat.py b/tests/dataframe/test_concat.py index 07e06df59c..f3caf56bb1 100644 --- a/tests/dataframe/test_concat.py +++ b/tests/dataframe/test_concat.py @@ -2,13 +2,6 @@ import pytest -from daft import context - -pytestmark = pytest.mark.skipif( - context.get_context().daft_execution_config.enable_native_executor is True, - reason="Native executor fails for these tests", -) - def test_simple_concat(make_df): df1 = make_df({"foo": [1, 2, 3]}) diff --git a/tests/dataframe/test_explode.py b/tests/dataframe/test_explode.py index 0e8dbd73d2..26416f9938 100644 --- a/tests/dataframe/test_explode.py +++ b/tests/dataframe/test_explode.py @@ -3,14 +3,8 @@ import pyarrow as pa import pytest -from daft import context from daft.expressions import col -pytestmark = pytest.mark.skipif( - context.get_context().daft_execution_config.enable_native_executor is True, - reason="Native executor fails for these tests", -) - @pytest.mark.parametrize( "data", diff --git a/tests/dataframe/test_sample.py b/tests/dataframe/test_sample.py index 791e2a2211..109b9f332b 100644 --- a/tests/dataframe/test_sample.py +++ b/tests/dataframe/test_sample.py @@ -2,8 +2,6 @@ import pytest -from daft import context - def test_sample_fraction(make_df, valid_data: list[dict[str, float]]) -> None: df = make_df(valid_data) @@ -100,10 +98,6 @@ def test_sample_without_replacement(make_df, valid_data: list[dict[str, float]]) assert pylist[0] != pylist[1] -@pytest.mark.skipif( - context.get_context().daft_execution_config.enable_native_executor is True, - reason="Native executor fails for concat", -) def test_sample_with_concat(make_df, valid_data: list[dict[str, float]]) -> None: df1 = make_df(valid_data) df2 = make_df(valid_data) diff --git a/tests/dataframe/test_transform.py b/tests/dataframe/test_transform.py index 277c378bad..a698b6e7fd 100644 --- a/tests/dataframe/test_transform.py +++ b/tests/dataframe/test_transform.py @@ -3,12 +3,6 @@ import pytest import daft -from daft import context - -pytestmark = pytest.mark.skipif( - context.get_context().daft_execution_config.enable_native_executor is True, - reason="Native executor fails for these tests", -) def add_1(df): diff --git a/tests/dataframe/test_wildcard.py b/tests/dataframe/test_wildcard.py index 3497912be5..e732292c53 100644 --- a/tests/dataframe/test_wildcard.py +++ b/tests/dataframe/test_wildcard.py @@ -1,14 +1,9 @@ import pytest import daft -from daft import col, context +from daft import col from daft.exceptions import DaftCoreException -pytestmark = pytest.mark.skipif( - context.get_context().daft_execution_config.enable_native_executor is True, - reason="Native executor fails for these tests", -) - def test_wildcard_select(): df = daft.from_pydict( diff --git a/tests/integration/sql/conftest.py b/tests/integration/sql/conftest.py index f5c01dccc6..e202eed471 100644 --- a/tests/integration/sql/conftest.py +++ b/tests/integration/sql/conftest.py @@ -26,6 +26,7 @@ "trino://user@localhost:8080/memory/default", "postgresql://username:password@localhost:5432/postgres", 
"mysql+pymysql://username:password@localhost:3306/mysql", + "mssql+pyodbc://SA:StrongPassword!@127.0.0.1:1433/master?driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes", ] TEST_TABLE_NAME = "example" EMPTY_TEST_TABLE_NAME = "empty_table" diff --git a/tests/integration/sql/docker-compose/docker-compose.yml b/tests/integration/sql/docker-compose/docker-compose.yml index 11c391b0d3..b8eb8c3eba 100644 --- a/tests/integration/sql/docker-compose/docker-compose.yml +++ b/tests/integration/sql/docker-compose/docker-compose.yml @@ -31,6 +31,18 @@ services: volumes: - mysql_data:/var/lib/mysql + azuresqledge: + image: mcr.microsoft.com/azure-sql-edge + container_name: azuresqledge + environment: + ACCEPT_EULA: "Y" + MSSQL_SA_PASSWORD: "StrongPassword!" + ports: + - 1433:1433 + volumes: + - azuresqledge_data:/var/opt/mssql + volumes: postgres_data: mysql_data: + azuresqledge_data: diff --git a/tests/integration/sql/test_sql.py b/tests/integration/sql/test_sql.py index ff02ebaac4..7983be00c7 100644 --- a/tests/integration/sql/test_sql.py +++ b/tests/integration/sql/test_sql.py @@ -141,6 +141,10 @@ def test_sql_read_with_partition_num_without_partition_col(test_db) -> None: ) @pytest.mark.parametrize("num_partitions", [1, 2]) def test_sql_read_with_binary_filter_pushdowns(test_db, column, operator, value, num_partitions, pdf) -> None: + # Skip invalid comparisons for bool_col + if column == "bool_col" and operator not in ("=", "!="): + pytest.skip(f"Operator {operator} not valid for bool_col") + df = daft.read_sql( f"SELECT * FROM {TEST_TABLE_NAME}", test_db, @@ -204,13 +208,15 @@ def test_sql_read_with_not_null_filter_pushdowns(test_db, num_partitions, pdf) - @pytest.mark.integration() @pytest.mark.parametrize("num_partitions", [1, 2]) -def test_sql_read_with_if_else_filter_pushdown(test_db, num_partitions, pdf) -> None: +def test_sql_read_with_non_pushdowned_predicate(test_db, num_partitions, pdf) -> None: df = daft.read_sql( f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col="id", num_partitions=num_partitions, ) + + # If_else is not supported as a pushdown to read_sql, but it should still work df = df.where((df["id"] > 100).if_else(df["float_col"] > 150, df["float_col"] < 50)) pdf = pdf[(pdf["id"] > 100) & (pdf["float_col"] > 150) | (pdf["float_col"] < 50)] diff --git a/tests/sql/test_list_exprs.py b/tests/sql/test_list_exprs.py index 9b76735e44..2f0799fb71 100644 --- a/tests/sql/test_list_exprs.py +++ b/tests/sql/test_list_exprs.py @@ -1,8 +1,7 @@ import pyarrow as pa -import pytest import daft -from daft import col, context +from daft import col from daft.daft import CountMode from daft.sql.sql import SQLCatalog @@ -62,8 +61,6 @@ def test_list_counts(): def test_list_explode(): - if context.get_context().daft_execution_config.enable_native_executor is True: - pytest.skip("Native executor fails for these tests") df = daft.from_pydict({"col": [[1, 2, 3], [1, 2], [1, None, 4], []]}) catalog = SQLCatalog({"test": df}) expected = df.explode(col("col")) diff --git a/tests/sql/test_sql.py b/tests/sql/test_sql.py index 8b8cce43b5..6bcd716854 100644 --- a/tests/sql/test_sql.py +++ b/tests/sql/test_sql.py @@ -214,3 +214,10 @@ def test_sql_tbl_alias(): catalog = SQLCatalog({"df": daft.from_pydict({"n": [1, 2, 3]})}) df = daft.sql("SELECT df_alias.n FROM df AS df_alias where df_alias.n = 2", catalog) assert df.collect().to_pydict() == {"n": [2]} + + +def test_sql_distinct(): + df = daft.from_pydict({"n": [1, 1, 2, 2]}) + actual = daft.sql("SELECT DISTINCT n FROM 
df").collect().to_pydict() + expected = df.distinct().collect().to_pydict() + assert actual == expected