Skip to content

Commit

Permalink
Theres a LOT oh my god
Browse files Browse the repository at this point in the history
  • Loading branch information
Jay Chia committed Oct 1, 2024
1 parent 7d8947c commit f30efca
Show file tree
Hide file tree
Showing 6 changed files with 855 additions and 113 deletions.
3 changes: 0 additions & 3 deletions docs/source/api_docs/sql.rst
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
SQL
===

SQL Functions
-------------

.. autofunction:: daft.sql

.. autofunction:: daft.sql_expr
Expand Down
220 changes: 110 additions & 110 deletions src/daft-sql/src/modules/aggs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,116 @@ use crate::{

pub struct SQLModuleAggs;

impl SQLModule for SQLModuleAggs {
fn register(parent: &mut SQLFunctions) {
use AggExpr::*;
// HACK TO USE AggExpr as an enum rather than a
let nil = Arc::new(Expr::Literal(LiteralValue::Null));
parent.add_fn(
"count",
Count(nil.clone(), daft_core::count_mode::CountMode::Valid),
);
parent.add_fn("sum", Sum(nil.clone()));
parent.add_fn("avg", Mean(nil.clone()));
parent.add_fn("mean", Mean(nil.clone()));
parent.add_fn("min", Min(nil.clone()));
parent.add_fn("max", Max(nil.clone()));
}
}

impl SQLFunction for AggExpr {
fn to_expr(&self, inputs: &[FunctionArg], planner: &SQLPlanner) -> SQLPlannerResult<ExprRef> {
// COUNT(*) needs a bit of extra handling, so we process that outside of `to_expr`
if let Self::Count(_, _) = self {
handle_count(inputs, planner)
} else {
let inputs = self.args_to_expr_unnamed(inputs, planner)?;
to_expr(self, inputs.as_slice())
}
}

fn docstrings(&self, alias: &str) -> String {
match self {
Self::Count(_, _) => static_docs::COUNT_DOCSTRING.to_string(),
Self::Sum(_) => static_docs::SUM_DOCSTRING.to_string(),
Self::Mean(_) => static_docs::AVG_DOCSTRING.replace("{}", alias),
Self::Min(_) => static_docs::MIN_DOCSTRING.to_string(),
Self::Max(_) => static_docs::MAX_DOCSTRING.to_string(),
e => unimplemented!("Need to implement docstrings for {e}"),

Check warning on line 52 in src/daft-sql/src/modules/aggs.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-sql/src/modules/aggs.rs#L52

Added line #L52 was not covered by tests
}
}

fn arg_names(&self) -> &'static [&'static str] {
match self {
Self::Count(_, _) | Self::Sum(_) | Self::Mean(_) | Self::Min(_) | Self::Max(_) => {
&["input"]
}
e => unimplemented!("Need to implement arg names for {e}"),

Check warning on line 61 in src/daft-sql/src/modules/aggs.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-sql/src/modules/aggs.rs#L61

Added line #L61 was not covered by tests
}
}
}

fn handle_count(inputs: &[FunctionArg], planner: &SQLPlanner) -> SQLPlannerResult<ExprRef> {
Ok(match inputs {
[FunctionArg::Unnamed(FunctionArgExpr::Wildcard)] => match planner.relation_opt() {
Some(rel) => {
let schema = rel.schema();
col(schema.fields[0].name.clone())
.count(daft_core::count_mode::CountMode::All)
.alias("count")
}
None => unsupported_sql_err!("Wildcard is not supported in this context"),
},
[FunctionArg::Unnamed(FunctionArgExpr::QualifiedWildcard(name))] => {
match planner.relation_opt() {
Some(rel) if name.to_string() == rel.name => {
let schema = rel.schema();
col(schema.fields[0].name.clone())
.count(daft_core::count_mode::CountMode::All)
.alias("count")
}
_ => unsupported_sql_err!("Wildcard is not supported in this context"),
}
}
[expr] => {
// SQL default COUNT ignores nulls
let input = planner.plan_function_arg(expr)?;
input.count(daft_core::count_mode::CountMode::Valid)
}
_ => unsupported_sql_err!("COUNT takes exactly one argument"),
})
}

pub(crate) fn to_expr(expr: &AggExpr, args: &[ExprRef]) -> SQLPlannerResult<ExprRef> {
match expr {
AggExpr::Count(_, _) => unreachable!("count should be handled by by this point"),
AggExpr::Sum(_) => {
ensure!(args.len() == 1, "sum takes exactly one argument");
Ok(args[0].clone().sum())
}
AggExpr::ApproxCountDistinct(_) => unsupported_sql_err!("approx_percentile"),
AggExpr::ApproxPercentile(_) => unsupported_sql_err!("approx_percentile"),
AggExpr::ApproxSketch(_, _) => unsupported_sql_err!("approx_sketch"),
AggExpr::MergeSketch(_, _) => unsupported_sql_err!("merge_sketch"),
AggExpr::Mean(_) => {
ensure!(args.len() == 1, "mean takes exactly one argument");
Ok(args[0].clone().mean())
}
AggExpr::Min(_) => {
ensure!(args.len() == 1, "min takes exactly one argument");
Ok(args[0].clone().min())
}
AggExpr::Max(_) => {
ensure!(args.len() == 1, "max takes exactly one argument");
Ok(args[0].clone().max())
}
AggExpr::AnyValue(_, _) => unsupported_sql_err!("any_value"),
AggExpr::List(_) => unsupported_sql_err!("list"),
AggExpr::Concat(_) => unsupported_sql_err!("concat"),
AggExpr::MapGroups { .. } => unsupported_sql_err!("map_groups"),
}
}

mod static_docs {
pub(crate) const COUNT_DOCSTRING: &str =
"Counts the number of non-null elements in the input expression.
Expand Down Expand Up @@ -211,113 +321,3 @@ Example:
╰───────╯
(Showing first 1 of 1 rows)";
}

impl SQLModule for SQLModuleAggs {
fn register(parent: &mut SQLFunctions) {
use AggExpr::*;
// HACK TO USE AggExpr as an enum rather than a
let nil = Arc::new(Expr::Literal(LiteralValue::Null));
parent.add_fn(
"count",
Count(nil.clone(), daft_core::count_mode::CountMode::Valid),
);
parent.add_fn("sum", Sum(nil.clone()));
parent.add_fn("avg", Mean(nil.clone()));
parent.add_fn("mean", Mean(nil.clone()));
parent.add_fn("min", Min(nil.clone()));
parent.add_fn("max", Max(nil.clone()));
}
}

impl SQLFunction for AggExpr {
fn to_expr(&self, inputs: &[FunctionArg], planner: &SQLPlanner) -> SQLPlannerResult<ExprRef> {
// COUNT(*) needs a bit of extra handling, so we process that outside of `to_expr`
if let Self::Count(_, _) = self {
handle_count(inputs, planner)
} else {
let inputs = self.args_to_expr_unnamed(inputs, planner)?;
to_expr(self, inputs.as_slice())
}
}

fn docstrings(&self, alias: &str) -> String {
match self {
Self::Count(_, _) => static_docs::COUNT_DOCSTRING.to_string(),
Self::Sum(_) => static_docs::SUM_DOCSTRING.to_string(),
Self::Mean(_) => static_docs::AVG_DOCSTRING.replace("{}", alias),
Self::Min(_) => static_docs::MIN_DOCSTRING.to_string(),
Self::Max(_) => static_docs::MAX_DOCSTRING.to_string(),
e => unimplemented!("Need to implement docstrings for {e}"),
}
}

fn arg_names(&self) -> &'static [&'static str] {
match self {
Self::Count(_, _) | Self::Sum(_) | Self::Mean(_) | Self::Min(_) | Self::Max(_) => {
&["input"]
}
e => unimplemented!("Need to implement arg names for {e}"),
}
}
}

fn handle_count(inputs: &[FunctionArg], planner: &SQLPlanner) -> SQLPlannerResult<ExprRef> {
Ok(match inputs {
[FunctionArg::Unnamed(FunctionArgExpr::Wildcard)] => match planner.relation_opt() {
Some(rel) => {
let schema = rel.schema();
col(schema.fields[0].name.clone())
.count(daft_core::count_mode::CountMode::All)
.alias("count")
}
None => unsupported_sql_err!("Wildcard is not supported in this context"),
},
[FunctionArg::Unnamed(FunctionArgExpr::QualifiedWildcard(name))] => {
match planner.relation_opt() {
Some(rel) if name.to_string() == rel.name => {
let schema = rel.schema();
col(schema.fields[0].name.clone())
.count(daft_core::count_mode::CountMode::All)
.alias("count")
}
_ => unsupported_sql_err!("Wildcard is not supported in this context"),
}
}
[expr] => {
// SQL default COUNT ignores nulls
let input = planner.plan_function_arg(expr)?;
input.count(daft_core::count_mode::CountMode::Valid)
}
_ => unsupported_sql_err!("COUNT takes exactly one argument"),
})
}

pub(crate) fn to_expr(expr: &AggExpr, args: &[ExprRef]) -> SQLPlannerResult<ExprRef> {
match expr {
AggExpr::Count(_, _) => unreachable!("count should be handled by by this point"),
AggExpr::Sum(_) => {
ensure!(args.len() == 1, "sum takes exactly one argument");
Ok(args[0].clone().sum())
}
AggExpr::ApproxCountDistinct(_) => unsupported_sql_err!("approx_percentile"),
AggExpr::ApproxPercentile(_) => unsupported_sql_err!("approx_percentile"),
AggExpr::ApproxSketch(_, _) => unsupported_sql_err!("approx_sketch"),
AggExpr::MergeSketch(_, _) => unsupported_sql_err!("merge_sketch"),
AggExpr::Mean(_) => {
ensure!(args.len() == 1, "mean takes exactly one argument");
Ok(args[0].clone().mean())
}
AggExpr::Min(_) => {
ensure!(args.len() == 1, "min takes exactly one argument");
Ok(args[0].clone().min())
}
AggExpr::Max(_) => {
ensure!(args.len() == 1, "max takes exactly one argument");
Ok(args[0].clone().max())
}
AggExpr::AnyValue(_, _) => unsupported_sql_err!("any_value"),
AggExpr::List(_) => unsupported_sql_err!("list"),
AggExpr::Concat(_) => unsupported_sql_err!("concat"),
AggExpr::MapGroups { .. } => unsupported_sql_err!("map_groups"),
}
}
Loading

0 comments on commit f30efca

Please sign in to comment.