Merge branch 'main' of https://github.com/Eventual-Inc/Daft into tpcds
universalmind303 committed Oct 23, 2024
2 parents 4597ecc + c69ee3f commit c2b344c
Showing 31 changed files with 290 additions and 163 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/python-package.yml
@@ -582,6 +582,12 @@ jobs:
run: |
uv pip install -r requirements-dev.txt dist/${{ env.package-name }}-*x86_64*.whl --force-reinstall
rm -rf daft
- name: Install ODBC Driver 18 for SQL Server
run: |
curl https://packages.microsoft.com/keys/microsoft.asc | sudo apt-key add -
sudo add-apt-repository https://packages.microsoft.com/ubuntu/$(lsb_release -rs)/prod
sudo apt-get update
sudo ACCEPT_EULA=Y apt-get install -y msodbcsql18
- name: Spin up services
run: |
pushd ./tests/integration/sql/docker-compose/
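
The ODBC step above pairs with the pyodbc pin added to requirements-dev.txt further down. A minimal sketch of reading a SQL Server table through Daft once the driver is installed — the host, credentials, and table name are placeholders, and the connection-factory form of daft.read_sql is used so SQLAlchemy picks up the mssql+pyodbc driver:

import daft
from sqlalchemy import create_engine

# Placeholder connection string; "ODBC Driver 18 for SQL Server" matches the
# driver installed in the workflow step above.
URI = (
    "mssql+pyodbc://user:password@localhost:1433/testdb"
    "?driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes"
)

def create_conn():
    # daft.read_sql accepts a callable that returns a SQLAlchemy connection.
    return create_engine(URI).connect()

df = daft.read_sql("SELECT * FROM example_table", create_conn)
df.show()
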
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

19 changes: 19 additions & 0 deletions daft/delta_lake/delta_lake_scan.py
@@ -17,6 +17,7 @@
ScanTask,
StorageConfig,
)
from daft.io.aws_config import boto3_client_from_s3_config
from daft.io.object_store_options import io_config_to_storage_options
from daft.io.scan import PartitionField, ScanOperator
from daft.logical.schema import Schema
@@ -43,6 +44,24 @@ def __init__(
deltalake_sdk_io_config = storage_config.config.io_config
scheme = urlparse(table_uri).scheme
if scheme == "s3" or scheme == "s3a":
# Try to get region from boto3
if deltalake_sdk_io_config.s3.region_name is None:
from botocore.exceptions import BotoCoreError

try:
client = boto3_client_from_s3_config("s3", deltalake_sdk_io_config.s3)
response = client.get_bucket_location(Bucket=urlparse(table_uri).netloc)
except BotoCoreError as e:
logger.warning(
"Failed to get the S3 bucket region using existing storage config, will attempt to get it from the environment instead. Error from boto3: %s",
e,
)
else:
deltalake_sdk_io_config = deltalake_sdk_io_config.replace(
s3=deltalake_sdk_io_config.s3.replace(region_name=response["LocationConstraint"])
)

# Try to get config from the environment
if any([deltalake_sdk_io_config.s3.key_id is None, deltalake_sdk_io_config.s3.region_name is None]):
try:
s3_config_from_env = S3Config.from_env()
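
With the change above, a missing S3 region is resolved by asking boto3 for the bucket's location before falling back to the environment. A rough usage sketch — the bucket URI and credentials are placeholders:

import daft
from daft.io import IOConfig, S3Config

# No region_name supplied: the scan operator now calls get_bucket_location()
# via boto3, and only falls back to S3Config.from_env() if that fails.
io_config = IOConfig(s3=S3Config(key_id="<key id>", access_key="<secret>"))

df = daft.read_deltalake("s3://my-bucket/my-delta-table", io_config=io_config)
df.show()
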
21 changes: 21 additions & 0 deletions daft/io/aws_config.py
@@ -0,0 +1,21 @@
from typing import TYPE_CHECKING

from daft.daft import S3Config

if TYPE_CHECKING:
import boto3


def boto3_client_from_s3_config(service: str, s3_config: S3Config) -> "boto3.client":
import boto3

return boto3.client(
service,
region_name=s3_config.region_name,
use_ssl=s3_config.use_ssl,
verify=s3_config.verify_ssl,
endpoint_url=s3_config.endpoint_url,
aws_access_key_id=s3_config.key_id,
aws_secret_access_key=s3_config.access_key,
aws_session_token=s3_config.session_token,
)
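
A short sketch of how this helper is intended to be used — the same pattern the Delta Lake scan above relies on; the region and bucket name are placeholders:

from daft.daft import S3Config
from daft.io.aws_config import boto3_client_from_s3_config

# Build a boto3 client that honors whatever the user already set on S3Config.
s3_config = S3Config(region_name="us-west-2")
client = boto3_client_from_s3_config("s3", s3_config)

# e.g. resolve a bucket's region, as delta_lake_scan.py now does.
region = client.get_bucket_location(Bucket="my-bucket")["LocationConstraint"]
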
15 changes: 2 additions & 13 deletions daft/io/catalog.py
@@ -5,6 +5,7 @@
from typing import Optional

from daft.daft import IOConfig
from daft.io.aws_config import boto3_client_from_s3_config


class DataCatalogType(Enum):
@@ -42,20 +43,8 @@ def table_uri(self, io_config: IOConfig) -> str:
"""
if self.catalog == DataCatalogType.GLUE:
# Use boto3 to get the table from AWS Glue Data Catalog.
- import boto3
+ glue = boto3_client_from_s3_config("glue", io_config.s3)

- s3_config = io_config.s3
-
- glue = boto3.client(
- "glue",
- region_name=s3_config.region_name,
- use_ssl=s3_config.use_ssl,
- verify=s3_config.verify_ssl,
- endpoint_url=s3_config.endpoint_url,
- aws_access_key_id=s3_config.key_id,
- aws_secret_access_key=s3_config.access_key,
- aws_session_token=s3_config.session_token,
- )
if self.catalog_id is not None:
# Allow cross account access, table.catalog_id should be the target account id
glue_table = glue.get_table(
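
For reference, the cross-account path mentioned in the comment above comes down to passing CatalogId to boto3's Glue get_table call. A hedged sketch, with the account id, database, and table names as placeholders (the table's S3 location normally lives under Table.StorageDescriptor.Location in the Glue response):

from daft.daft import S3Config
from daft.io.aws_config import boto3_client_from_s3_config

glue = boto3_client_from_s3_config("glue", S3Config(region_name="us-west-2"))

# CatalogId is the target account for cross-account access; omit it to use
# the caller's own catalog.
response = glue.get_table(
    CatalogId="123456789012", DatabaseName="my_db", Name="my_table"
)
table_uri = response["Table"]["StorageDescriptor"]["Location"]
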
8 changes: 6 additions & 2 deletions daft/sql/sql_scan.py
@@ -161,14 +161,18 @@ def _get_num_rows(self) -> int:

def _attempt_partition_bounds_read(self, num_scan_tasks: int) -> tuple[Any, PartitionBoundStrategy]:
try:
- # Try to get percentiles using percentile_cont
+ # Try to get percentiles using percentile_disc.
+ # Favor percentile_disc over percentile_cont because we want exact values to do <= and >= comparisons.
percentiles = [i / num_scan_tasks for i in range(num_scan_tasks + 1)]
+ # Use the OVER clause for SQL Server
+ over_clause = "OVER ()" if self.conn.dialect in ["mssql", "tsql"] else ""
percentile_sql = self.conn.construct_sql_query(
self.sql,
projection=[
- f"percentile_disc({percentile}) WITHIN GROUP (ORDER BY {self._partition_col}) AS bound_{i}"
+ f"percentile_disc({percentile}) WITHIN GROUP (ORDER BY {self._partition_col}) {over_clause} AS bound_{i}"
for i, percentile in enumerate(percentiles)
],
limit=1,
)
pa_table = self.conn.execute_sql_query(percentile_sql)
return pa_table, PartitionBoundStrategy.PERCENTILE
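
The effect of the dialect check is easiest to see in the projection it builds. A small standalone sketch of the same clause construction (not the actual construct_sql_query output, just the list built in the loop above):

def bound_projection(partition_col: str, num_scan_tasks: int, dialect: str) -> list[str]:
    # Mirrors the code above: one percentile_disc(...) per bound, with an
    # empty OVER () appended only for SQL Server dialects.
    percentiles = [i / num_scan_tasks for i in range(num_scan_tasks + 1)]
    over_clause = "OVER ()" if dialect in ["mssql", "tsql"] else ""
    return [
        f"percentile_disc({p}) WITHIN GROUP (ORDER BY {partition_col}) {over_clause} AS bound_{i}"
        for i, p in enumerate(percentiles)
    ]

# bound_projection("id", 2, "mssql")[0]
# -> 'percentile_disc(0.0) WITHIN GROUP (ORDER BY id) OVER () AS bound_0'
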
2 changes: 2 additions & 0 deletions requirements-dev.txt
@@ -67,6 +67,8 @@ trino[sqlalchemy]==0.328.0; python_version >= '3.8'
PyMySQL==1.1.0; python_version >= '3.8'
psycopg2-binary==2.9.9; python_version >= '3.8'
sqlglot==23.3.0; python_version >= '3.8'
pyodbc==5.1.0; python_version >= '3.8'

# AWS
s3fs==2023.12.0; python_version >= '3.8'
# on old versions of s3fs's pinned botocore, they neglected to pin urllib3<2 which leads to:
16 changes: 2 additions & 14 deletions src/daft-dsl/src/expr/mod.rs
@@ -990,21 +990,9 @@ impl Expr {
to_sql_inner(inner, buffer)?;
write!(buffer, ") IS NOT NULL")
}
- Expr::IfElse {
- if_true,
- if_false,
- predicate,
- } => {
- write!(buffer, "CASE WHEN ")?;
- to_sql_inner(predicate, buffer)?;
- write!(buffer, " THEN ")?;
- to_sql_inner(if_true, buffer)?;
- write!(buffer, " ELSE ")?;
- to_sql_inner(if_false, buffer)?;
- write!(buffer, " END")
- }
// TODO: Implement SQL translations for these expressions if possible
- Expr::Agg(..)
+ Expr::IfElse { .. }
+ | Expr::Agg(..)
| Expr::Cast(..)
| Expr::IsIn(..)
| Expr::Between(..)
1 change: 1 addition & 0 deletions src/daft-local-execution/Cargo.toml
@@ -7,6 +7,7 @@ common-tracing = {path = "../common/tracing", default-features = false}
daft-core = {path = "../daft-core", default-features = false}
daft-csv = {path = "../daft-csv", default-features = false}
daft-dsl = {path = "../daft-dsl", default-features = false}
daft-functions = {path = "../daft-functions", default-features = false}
daft-io = {path = "../daft-io", default-features = false}
daft-json = {path = "../daft-json", default-features = false}
daft-micropartition = {path = "../daft-micropartition", default-features = false}
42 changes: 42 additions & 0 deletions src/daft-local-execution/src/intermediate_ops/explode.rs
@@ -0,0 +1,42 @@
use std::sync::Arc;

use common_error::DaftResult;
use daft_dsl::ExprRef;
use daft_functions::list::explode;
use tracing::instrument;

use super::intermediate_op::{
IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState,
};
use crate::pipeline::PipelineResultType;

pub struct ExplodeOperator {
to_explode: Vec<ExprRef>,
}

impl ExplodeOperator {
pub fn new(to_explode: Vec<ExprRef>) -> Self {
Self {
to_explode: to_explode.into_iter().map(explode).collect(),
}
}
}

impl IntermediateOperator for ExplodeOperator {
#[instrument(skip_all, name = "ExplodeOperator::execute")]
fn execute(
&self,
_idx: usize,
input: &PipelineResultType,
_state: Option<&mut Box<dyn IntermediateOperatorState>>,
) -> DaftResult<IntermediateOperatorResult> {
let out = input.as_data().explode(&self.to_explode)?;
Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new(
out,
))))
}

fn name(&self) -> &'static str {
"ExplodeOperator"
}
}
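
This operator backs DataFrame.explode when running on the new local executor. A quick sketch of the user-facing behavior; selecting the native runner via set_runner_native is an assumption about how the local executor is enabled in this version:

import daft

daft.context.set_runner_native()  # assumed entry point for the local executor

df = daft.from_pydict({"letters": ["a", "b"], "nums": [[1, 2], [3]]})

# Each list element becomes its own row; the other columns are repeated.
df.explode("nums").show()
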
1 change: 1 addition & 0 deletions src/daft-local-execution/src/intermediate_ops/mod.rs
@@ -1,6 +1,7 @@
pub mod aggregate;
pub mod anti_semi_hash_join_probe;
pub mod buffer;
pub mod explode;
pub mod filter;
pub mod inner_hash_join_probe;
pub mod intermediate_op;
31 changes: 19 additions & 12 deletions src/daft-local-execution/src/pipeline.rs
@@ -10,8 +10,8 @@ use daft_core::{
use daft_dsl::{col, join::get_common_join_keys, Expr};
use daft_micropartition::MicroPartition;
use daft_physical_plan::{
- EmptyScan, Filter, HashAggregate, HashJoin, InMemoryScan, Limit, LocalPhysicalPlan, Pivot,
- Project, Sample, Sort, UnGroupedAggregate, Unpivot,
+ Concat, EmptyScan, Explode, Filter, HashAggregate, HashJoin, InMemoryScan, Limit,
+ LocalPhysicalPlan, Pivot, Project, Sample, Sort, UnGroupedAggregate, Unpivot,
};
use daft_plan::{populate_aggregation_stages, JoinType};
use daft_table::ProbeState;
@@ -22,12 +22,13 @@ use crate::{
channel::PipelineChannel,
intermediate_ops::{
aggregate::AggregateOperator, anti_semi_hash_join_probe::AntiSemiProbeOperator,
- filter::FilterOperator, inner_hash_join_probe::InnerHashJoinProbeOperator,
- intermediate_op::IntermediateNode, pivot::PivotOperator, project::ProjectOperator,
- sample::SampleOperator, unpivot::UnpivotOperator,
+ explode::ExplodeOperator, filter::FilterOperator,
+ inner_hash_join_probe::InnerHashJoinProbeOperator, intermediate_op::IntermediateNode,
+ pivot::PivotOperator, project::ProjectOperator, sample::SampleOperator,
+ unpivot::UnpivotOperator,
},
sinks::{
- aggregate::AggregateSink, blocking_sink::BlockingSinkNode,
+ aggregate::AggregateSink, blocking_sink::BlockingSinkNode, concat::ConcatSink,
hash_join_build::HashJoinBuildSink, limit::LimitSink,
outer_hash_join_probe::OuterHashJoinProbeSink, sort::SortSink,
streaming_sink::StreamingSinkNode,
@@ -145,19 +146,25 @@ pub fn physical_plan_to_pipeline(
let child_node = physical_plan_to_pipeline(input, psets)?;
IntermediateNode::new(Arc::new(filter_op), vec![child_node]).boxed()
}
LocalPhysicalPlan::Explode(Explode {
input, to_explode, ..
}) => {
let explode_op = ExplodeOperator::new(to_explode.clone());
let child_node = physical_plan_to_pipeline(input, psets)?;
IntermediateNode::new(Arc::new(explode_op), vec![child_node]).boxed()
}
LocalPhysicalPlan::Limit(Limit {
input, num_rows, ..
}) => {
let sink = LimitSink::new(*num_rows as usize);
let child_node = physical_plan_to_pipeline(input, psets)?;
StreamingSinkNode::new(Arc::new(sink), vec![child_node]).boxed()
}
- LocalPhysicalPlan::Concat(_) => {
- todo!("concat")
- // let sink = ConcatSink::new();
- // let left_child = physical_plan_to_pipeline(input, psets)?;
- // let right_child = physical_plan_to_pipeline(other, psets)?;
- // PipelineNode::double_sink(sink, left_child, right_child)
+ LocalPhysicalPlan::Concat(Concat { input, other, .. }) => {
+ let left_child = physical_plan_to_pipeline(input, psets)?;
+ let right_child = physical_plan_to_pipeline(other, psets)?;
+ let sink = ConcatSink {};
+ StreamingSinkNode::new(Arc::new(sink), vec![left_child, right_child]).boxed()
}
LocalPhysicalPlan::UnGroupedAggregate(UnGroupedAggregate {
input,
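
Concat is now wired up as a streaming sink with two child pipelines instead of a todo!(). At the DataFrame level this is the path behind df.concat — a minimal sketch, again assuming the local executor is selected as above:

import daft

daft.context.set_runner_native()  # assumed entry point for the local executor

a = daft.from_pydict({"x": [1, 2]})
b = daft.from_pydict({"x": [3, 4]})

# Both children feed the same ConcatSink node in the local pipeline.
a.concat(b).collect()
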