Merge branch main into colin/streaming-catalog-writes
Colin Ho authored and Colin Ho committed Oct 1, 2024
2 parents 2972d8f + b2dabf6 commit 138da54
Showing 19 changed files with 758 additions and 58 deletions.
7 changes: 6 additions & 1 deletion daft/iceberg/iceberg_write.py
@@ -4,6 +4,8 @@
from typing import TYPE_CHECKING, Any, Iterator, List, Tuple

from daft import Expression, col
from daft.datatype import DataType
from daft.io.common import _get_schema_from_dict
from daft.table import MicroPartition
from daft.table.partitioning import PartitionedTable, partition_strings_to_path

@@ -211,7 +213,10 @@ def visitor(self, partition_record: "IcebergRecord") -> "IcebergWriteVisitors.Fi
return self.FileVisitor(self, partition_record)

def to_metadata(self) -> MicroPartition:
return MicroPartition.from_pydict({"data_file": self.data_files})
col_name = "data_file"
if len(self.data_files) == 0:
return MicroPartition.empty(_get_schema_from_dict({col_name: DataType.python()}))
return MicroPartition.from_pydict({col_name: self.data_files})


def partitioned_table_to_iceberg_iter(
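Note on the to_metadata change above: when no data files were written, building the table from a pydict with an empty list would not carry the intended Python-object dtype for the data_file column, so the empty case now returns an explicitly typed empty MicroPartition, presumably so downstream consumers see a stable schema. A minimal sketch of the pattern, assuming daft's MicroPartition, DataType, and the internal _get_schema_from_dict helper behave as used in this diff:

    from daft.datatype import DataType
    from daft.io.common import _get_schema_from_dict
    from daft.table import MicroPartition

    def files_to_metadata(data_files: list) -> MicroPartition:
        # No files written: return a zero-row table that still declares
        # "data_file" as a Python-object column, so callers see a stable schema.
        if len(data_files) == 0:
            return MicroPartition.empty(_get_schema_from_dict({"data_file": DataType.python()}))
        # Otherwise infer the column from the populated pydict as before.
        return MicroPartition.from_pydict({"data_file": data_files})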
22 changes: 15 additions & 7 deletions daft/table/table_io.py
@@ -22,13 +22,15 @@
PythonStorageConfig,
StorageConfig,
)
from daft.datatype import DataType
from daft.dependencies import pa, pacsv, pads, pajson, pq
from daft.expressions import ExpressionsProjection, col
from daft.filesystem import (
_resolve_paths_and_filesystem,
canonicalize_protocol,
get_protocol_from_path,
)
from daft.io.common import _get_schema_from_dict
from daft.logical.schema import Schema
from daft.runners.partitioning import (
TableParseCSVOptions,
@@ -426,16 +428,22 @@ def __call__(self, written_file):
self.parent.paths.append(written_file.path)
self.parent.partition_indices.append(self.idx)

def __init__(self, partition_values: MicroPartition | None, path_key: str = "path"):
def __init__(self, partition_values: MicroPartition | None, schema: Schema):
self.paths: list[str] = []
self.partition_indices: list[int] = []
self.partition_values = partition_values
self.path_key = path_key
self.path_key = schema.column_names()[
0
] # I kept this from our original code, but idk why it's the first column name -kevin
self.schema = schema

def visitor(self, partition_idx: int) -> TabularWriteVisitors.FileVisitor:
return self.FileVisitor(self, partition_idx)

def to_metadata(self) -> MicroPartition:
if len(self.paths) == 0:
return MicroPartition.empty(self.schema)

metadata: dict[str, Any] = {self.path_key: self.paths}

if self.partition_values:
@@ -488,10 +496,7 @@ def write_tabular(

partitioned = PartitionedTable(table, partition_cols)

# I kept this from our original code, but idk why it's the first column name -kevin
path_key = schema.column_names()[0]

visitors = TabularWriteVisitors(partitioned.partition_values(), path_key)
visitors = TabularWriteVisitors(partitioned.partition_values(), schema)

for i, (part_table, part_path) in enumerate(partitioned_table_to_hive_iter(partitioned, resolved_path)):
size_bytes = part_table.nbytes
@@ -685,7 +690,10 @@ def visitor(self, partition_values: dict[str, str | None]) -> DeltaLakeWriteVisi
return self.FileVisitor(self, partition_values)

def to_metadata(self) -> MicroPartition:
return MicroPartition.from_pydict({"add_action": self.add_actions})
col_name = "add_action"
if len(self.add_actions) == 0:
return MicroPartition.empty(_get_schema_from_dict({col_name: DataType.python()}))
return MicroPartition.from_pydict({col_name: self.add_actions})


def write_deltalake(
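For a sense of the user-visible effect of the schema-aware empty-write paths above, a hedged sketch; the path and the exact columns of the returned metadata are illustrative assumptions, not taken from this commit:

    import daft

    # A zero-row DataFrame: the filter removes every row.
    df = daft.from_pydict({"a": [1, 2, 3]}).where(daft.col("a") > 100)

    # With the empty-write handling above, the metadata DataFrame returned by the
    # write should still carry its expected columns (e.g. the file path column)
    # instead of coming back untyped.
    result = df.write_parquet("/tmp/daft_empty_write_example")  # hypothetical local path
    result.show()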
2 changes: 1 addition & 1 deletion src/daft-core/src/array/ops/utf8.rs
@@ -342,7 +342,7 @@ pub enum PadPlacement {
Right,
}

#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash, Default)]
pub struct Utf8NormalizeOptions {
pub remove_punct: bool,
pub lowercase: bool,
6 changes: 3 additions & 3 deletions src/daft-functions/src/count_matches.rs
@@ -7,9 +7,9 @@ use daft_dsl::{
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
struct CountMatchesFunction {
pub(super) whole_words: bool,
pub(super) case_sensitive: bool,
pub struct CountMatchesFunction {
pub whole_words: bool,
pub case_sensitive: bool,
}

#[typetag::serde]
8 changes: 4 additions & 4 deletions src/daft-functions/src/minhash.rs
@@ -7,10 +7,10 @@ use daft_dsl::{
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub(super) struct MinHashFunction {
num_hashes: usize,
ngram_size: usize,
seed: u32,
pub struct MinHashFunction {
pub num_hashes: usize,
pub ngram_size: usize,
pub seed: u32,
}

#[typetag::serde]
10 changes: 5 additions & 5 deletions src/daft-functions/src/tokenize/decode.rs
@@ -66,11 +66,11 @@ fn tokenize_decode_series(
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub(super) struct TokenizeDecodeFunction {
pub(super) tokens_path: String,
pub(super) io_config: Option<Arc<IOConfig>>,
pub(super) pattern: Option<String>,
pub(super) special_tokens: Option<String>,
pub struct TokenizeDecodeFunction {
pub tokens_path: String,
pub io_config: Option<Arc<IOConfig>>,
pub pattern: Option<String>,
pub special_tokens: Option<String>,
}

#[typetag::serde]
12 changes: 6 additions & 6 deletions src/daft-functions/src/tokenize/encode.rs
@@ -70,12 +70,12 @@ fn tokenize_encode_series(
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub(super) struct TokenizeEncodeFunction {
pub(super) tokens_path: String,
pub(super) io_config: Option<Arc<IOConfig>>,
pub(super) pattern: Option<String>,
pub(super) special_tokens: Option<String>,
pub(super) use_special_tokens: bool,
pub struct TokenizeEncodeFunction {
pub tokens_path: String,
pub io_config: Option<Arc<IOConfig>>,
pub pattern: Option<String>,
pub special_tokens: Option<String>,
pub use_special_tokens: bool,
}

#[typetag::serde]
4 changes: 2 additions & 2 deletions src/daft-functions/src/tokenize/mod.rs
@@ -1,7 +1,7 @@
use daft_dsl::{functions::ScalarFunction, ExprRef};
use daft_io::IOConfig;
use decode::TokenizeDecodeFunction;
use encode::TokenizeEncodeFunction;
pub use decode::TokenizeDecodeFunction;
pub use encode::TokenizeEncodeFunction;

mod bpe;
mod decode;
57 changes: 55 additions & 2 deletions src/daft-sql/src/functions.rs
@@ -1,6 +1,7 @@
use std::{collections::HashMap, sync::Arc};

use daft_dsl::ExprRef;
use hashing::SQLModuleHashing;
use once_cell::sync::Lazy;
use sqlparser::ast::{
Function, FunctionArg, FunctionArgExpr, FunctionArgOperator, FunctionArguments,
@@ -18,6 +19,7 @@ pub(crate) static SQL_FUNCTIONS: Lazy<SQLFunctions> = Lazy::new(|| {
let mut functions = SQLFunctions::new();
functions.register::<SQLModuleAggs>();
functions.register::<SQLModuleFloat>();
functions.register::<SQLModuleHashing>();
functions.register::<SQLModuleImage>();
functions.register::<SQLModuleJson>();
functions.register::<SQLModuleList>();
@@ -103,6 +105,54 @@ impl SQLFunctionArguments {
pub fn get_named(&self, name: &str) -> Option<&ExprRef> {
self.named.get(name)
}

pub fn try_get_named<T: SQLLiteral>(&self, name: &str) -> Result<Option<T>, PlannerError> {
self.named
.get(name)
.map(|expr| T::from_expr(expr))
.transpose()
}
}

pub trait SQLLiteral {
fn from_expr(expr: &ExprRef) -> Result<Self, PlannerError>
where
Self: Sized;
}

impl SQLLiteral for String {
fn from_expr(expr: &ExprRef) -> Result<Self, PlannerError>
where
Self: Sized,
{
let e = expr
.as_literal()
.and_then(|lit| lit.as_str())
.ok_or_else(|| PlannerError::invalid_operation("Expected a string literal"))?;
Ok(e.to_string())
}
}

impl SQLLiteral for i64 {
fn from_expr(expr: &ExprRef) -> Result<Self, PlannerError>
where
Self: Sized,
{
expr.as_literal()
.and_then(|lit| lit.as_i64())
.ok_or_else(|| PlannerError::invalid_operation("Expected an integer literal"))
}
}

impl SQLLiteral for bool {
fn from_expr(expr: &ExprRef) -> Result<Self, PlannerError>
where
Self: Sized,
{
expr.as_literal()
.and_then(|lit| lit.as_bool())
.ok_or_else(|| PlannerError::invalid_operation("Expected a boolean literal"))
}
}

impl SQLFunctions {
@@ -214,7 +264,7 @@ impl SQLPlanner {
}
positional_args.insert(idx, self.try_unwrap_function_arg_expr(arg)?);
}
_ => unsupported_sql_err!("unsupported function argument type"),
other => unsupported_sql_err!("unsupported function argument type: {other}, valid function arguments for this function are: {expected_named:?}."),
}
}

@@ -235,7 +285,10 @@
}
}

fn try_unwrap_function_arg_expr(&self, expr: &FunctionArgExpr) -> SQLPlannerResult<ExprRef> {
pub(crate) fn try_unwrap_function_arg_expr(
&self,
expr: &FunctionArgExpr,
) -> SQLPlannerResult<ExprRef> {
match expr {
FunctionArgExpr::Expr(expr) => self.plan_expr(expr),
_ => unsupported_sql_err!("Wildcard function args not yet supported"),
2 changes: 1 addition & 1 deletion src/daft-sql/src/lib.rs
@@ -295,7 +295,7 @@ mod tests {
#[case::starts_with("select starts_with(utf8, 'a') as starts_with from tbl1")]
#[case::contains("select contains(utf8, 'a') as contains from tbl1")]
#[case::split("select split(utf8, '.') as split from tbl1")]
#[case::replace("select replace(utf8, 'a', 'b') as replace from tbl1")]
#[case::replace("select regexp_replace(utf8, 'a', 'b') as replace from tbl1")]
#[case::length("select length(utf8) as length from tbl1")]
#[case::lower("select lower(utf8) as lower from tbl1")]
#[case::upper("select upper(utf8) as upper from tbl1")]
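The updated test case exercises regexp_replace rather than replace. A small usage sketch via the Python SQL entrypoint, assuming daft.sql resolves in-scope DataFrames by name (an assumption about the SQL API, not taken from this commit):

    import daft

    df = daft.from_pydict({"utf8": ["daft", "data", "query"]})

    # Mirrors the updated test case: regexp_replace over a utf8 column.
    out = daft.sql("SELECT regexp_replace(utf8, 'a', 'b') AS replaced FROM df")
    out.show()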
111 changes: 111 additions & 0 deletions src/daft-sql/src/modules/hashing.rs
@@ -0,0 +1,111 @@
use daft_dsl::ExprRef;
use daft_functions::{
hash::hash,
minhash::{minhash, MinHashFunction},
};
use sqlparser::ast::FunctionArg;

use super::SQLModule;
use crate::{
error::{PlannerError, SQLPlannerResult},
functions::{SQLFunction, SQLFunctionArguments, SQLFunctions},
unsupported_sql_err,
};

pub struct SQLModuleHashing;

impl SQLModule for SQLModuleHashing {
fn register(parent: &mut SQLFunctions) {
parent.add_fn("hash", SQLHash);
parent.add_fn("minhash", SQLMinhash);
}
}

pub struct SQLHash;

impl SQLFunction for SQLHash {
fn to_expr(
&self,
inputs: &[FunctionArg],
planner: &crate::planner::SQLPlanner,
) -> SQLPlannerResult<ExprRef> {
match inputs {
[input] => {
let input = planner.plan_function_arg(input)?;
Ok(hash(input, None))
}
[input, seed] => {
let input = planner.plan_function_arg(input)?;
match seed {
FunctionArg::Named { name, arg, .. } if name.value == "seed" => {
let seed = planner.try_unwrap_function_arg_expr(arg)?;
Ok(hash(input, Some(seed)))
}
arg @ FunctionArg::Unnamed(_) => {
let seed = planner.plan_function_arg(arg)?;
Ok(hash(input, Some(seed)))
}
_ => unsupported_sql_err!("Invalid arguments for hash: '{inputs:?}'"),
}
}
_ => unsupported_sql_err!("Invalid arguments for hash: '{inputs:?}'"),
}
}
}

pub struct SQLMinhash;

impl TryFrom<SQLFunctionArguments> for MinHashFunction {
type Error = PlannerError;

fn try_from(args: SQLFunctionArguments) -> Result<Self, Self::Error> {
let num_hashes = args
.get_named("num_hashes")
.ok_or_else(|| PlannerError::invalid_operation("num_hashes is required"))?
.as_literal()
.and_then(|lit| lit.as_i64())
.ok_or_else(|| PlannerError::invalid_operation("num_hashes must be an integer"))?
as usize;

let ngram_size = args
.get_named("ngram_size")
.ok_or_else(|| PlannerError::invalid_operation("ngram_size is required"))?
.as_literal()
.and_then(|lit| lit.as_i64())
.ok_or_else(|| PlannerError::invalid_operation("ngram_size must be an integer"))?
as usize;
let seed = args
.get_named("seed")
.map(|arg| {
arg.as_literal()
.and_then(|lit| lit.as_i64())
.ok_or_else(|| PlannerError::invalid_operation("num_hashes must be an integer"))
})
.transpose()?
.unwrap_or(1) as u32;
Ok(Self {
num_hashes,
ngram_size,
seed,
})
}
}

impl SQLFunction for SQLMinhash {
fn to_expr(
&self,
inputs: &[FunctionArg],
planner: &crate::planner::SQLPlanner,
) -> SQLPlannerResult<ExprRef> {
match inputs {
[input, args @ ..] => {
let input = planner.plan_function_arg(input)?;
let args: MinHashFunction =
planner.plan_function_args(args, &["num_hashes", "ngram_size", "seed"], 0)?;

Ok(minhash(input, args.num_hashes, args.ngram_size, args.seed))
}
_ => unsupported_sql_err!("Invalid arguments for minhash: '{inputs:?}'"),
}
}
}
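A hedged usage sketch of the SQL surface registered above, via the Python entrypoint. The named-argument syntax shown (name := value) and scope-based table resolution in daft.sql are assumptions about the SQL frontend, not taken from this commit:

    import daft

    df = daft.from_pydict({"text": ["the quick brown fox", "jumps over the lazy dog"]})

    # hash with the default seed, hash with an explicit seed, and minhash with its
    # required named arguments, matching SQLHash and SQLMinhash above.
    out = daft.sql("""
        SELECT
            hash(text) AS h,
            hash(text, 42) AS h_seeded,
            minhash(text, num_hashes := 16, ngram_size := 3, seed := 1) AS mh
        FROM df
    """)
    out.show()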
1 change: 1 addition & 0 deletions src/daft-sql/src/modules/mod.rs
@@ -2,6 +2,7 @@ use crate::functions::SQLFunctions;

pub mod aggs;
pub mod float;
pub mod hashing;
pub mod image;
pub mod json;
pub mod list;