diff --git a/src/daft-core/src/array/ops/utf8.rs b/src/daft-core/src/array/ops/utf8.rs
index ebac895e20..fed18367e1 100644
--- a/src/daft-core/src/array/ops/utf8.rs
+++ b/src/daft-core/src/array/ops/utf8.rs
@@ -342,7 +342,7 @@ pub enum PadPlacement {
     Right,
 }
 
-#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
+#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash, Default)]
 pub struct Utf8NormalizeOptions {
     pub remove_punct: bool,
     pub lowercase: bool,
diff --git a/src/daft-functions/src/count_matches.rs b/src/daft-functions/src/count_matches.rs
index 89df9274a9..a5b5596681 100644
--- a/src/daft-functions/src/count_matches.rs
+++ b/src/daft-functions/src/count_matches.rs
@@ -7,9 +7,9 @@ use daft_dsl::{
 use serde::{Deserialize, Serialize};
 
 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
-struct CountMatchesFunction {
-    pub(super) whole_words: bool,
-    pub(super) case_sensitive: bool,
+pub struct CountMatchesFunction {
+    pub whole_words: bool,
+    pub case_sensitive: bool,
 }
 
 #[typetag::serde]
diff --git a/src/daft-functions/src/tokenize/decode.rs b/src/daft-functions/src/tokenize/decode.rs
index 30a713f993..e486f274e8 100644
--- a/src/daft-functions/src/tokenize/decode.rs
+++ b/src/daft-functions/src/tokenize/decode.rs
@@ -66,11 +66,11 @@ fn tokenize_decode_series(
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
-pub(super) struct TokenizeDecodeFunction {
-    pub(super) tokens_path: String,
-    pub(super) io_config: Option<Arc<IOConfig>>,
-    pub(super) pattern: Option<String>,
-    pub(super) special_tokens: Option<String>,
+pub struct TokenizeDecodeFunction {
+    pub tokens_path: String,
+    pub io_config: Option<Arc<IOConfig>>,
+    pub pattern: Option<String>,
+    pub special_tokens: Option<String>,
 }
 
 #[typetag::serde]
diff --git a/src/daft-functions/src/tokenize/encode.rs b/src/daft-functions/src/tokenize/encode.rs
index e36f9be4d2..a101cf930f 100644
--- a/src/daft-functions/src/tokenize/encode.rs
+++ b/src/daft-functions/src/tokenize/encode.rs
@@ -70,12 +70,12 @@ fn tokenize_encode_series(
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
-pub(super) struct TokenizeEncodeFunction {
-    pub(super) tokens_path: String,
-    pub(super) io_config: Option<Arc<IOConfig>>,
-    pub(super) pattern: Option<String>,
-    pub(super) special_tokens: Option<String>,
-    pub(super) use_special_tokens: bool,
+pub struct TokenizeEncodeFunction {
+    pub tokens_path: String,
+    pub io_config: Option<Arc<IOConfig>>,
+    pub pattern: Option<String>,
+    pub special_tokens: Option<String>,
+    pub use_special_tokens: bool,
 }
 
 #[typetag::serde]
diff --git a/src/daft-functions/src/tokenize/mod.rs b/src/daft-functions/src/tokenize/mod.rs
index 8eb1aee7e1..564ca79226 100644
--- a/src/daft-functions/src/tokenize/mod.rs
+++ b/src/daft-functions/src/tokenize/mod.rs
@@ -1,7 +1,7 @@
 use daft_dsl::{functions::ScalarFunction, ExprRef};
 use daft_io::IOConfig;
-use decode::TokenizeDecodeFunction;
-use encode::TokenizeEncodeFunction;
+pub use decode::TokenizeDecodeFunction;
+pub use encode::TokenizeEncodeFunction;
 
 mod bpe;
 mod decode;
diff --git a/src/daft-sql/src/functions.rs b/src/daft-sql/src/functions.rs
index 8a82fcd80a..ae04cfdb84 100644
--- a/src/daft-sql/src/functions.rs
+++ b/src/daft-sql/src/functions.rs
@@ -103,6 +103,54 @@ impl SQLFunctionArguments {
     pub fn get_named(&self, name: &str) -> Option<&ExprRef> {
         self.named.get(name)
     }
+
+    pub fn try_get_named<T: SQLLiteral>(&self, name: &str) -> Result<Option<T>, PlannerError> {
+        self.named
+            .get(name)
+            .map(|expr| T::from_expr(expr))
+            .transpose()
+    }
+}
+
+pub trait SQLLiteral {
+    fn from_expr(expr: &ExprRef) -> Result<Self, PlannerError>
+    where
+        Self: Sized;
+}
+
+impl SQLLiteral for String {
+    fn from_expr(expr: &ExprRef) -> Result<Self, PlannerError>
+    where
+        Self: Sized,
+    {
+        let e = expr
+            .as_literal()
+            .and_then(|lit| lit.as_str())
+            .ok_or_else(|| PlannerError::invalid_operation("Expected a string literal"))?;
+        Ok(e.to_string())
+    }
+}
+
+impl SQLLiteral for i64 {
+    fn from_expr(expr: &ExprRef) -> Result<Self, PlannerError>
+    where
+        Self: Sized,
+    {
+        expr.as_literal()
+            .and_then(|lit| lit.as_i64())
+            .ok_or_else(|| PlannerError::invalid_operation("Expected an integer literal"))
+    }
+}
+
+impl SQLLiteral for bool {
+    fn from_expr(expr: &ExprRef) -> Result<Self, PlannerError>
+    where
+        Self: Sized,
+    {
+        expr.as_literal()
+            .and_then(|lit| lit.as_bool())
+            .ok_or_else(|| PlannerError::invalid_operation("Expected a boolean literal"))
+    }
 }
 
 impl SQLFunctions {
@@ -214,7 +262,7 @@ impl SQLPlanner {
                 }
                 positional_args.insert(idx, self.try_unwrap_function_arg_expr(arg)?);
             }
-            _ => unsupported_sql_err!("unsupported function argument type"),
+            other => unsupported_sql_err!("unsupported function argument type: {other}, valid function arguments for this function are: {expected_named:?}."),
         }
     }
diff --git a/src/daft-sql/src/lib.rs b/src/daft-sql/src/lib.rs
index 310c256e27..8138edd283 100644
--- a/src/daft-sql/src/lib.rs
+++ b/src/daft-sql/src/lib.rs
@@ -293,7 +293,7 @@ mod tests {
     #[case::starts_with("select starts_with(utf8, 'a') as starts_with from tbl1")]
     #[case::contains("select contains(utf8, 'a') as contains from tbl1")]
     #[case::split("select split(utf8, '.') as split from tbl1")]
-    #[case::replace("select replace(utf8, 'a', 'b') as replace from tbl1")]
+    #[case::replace("select regexp_replace(utf8, 'a', 'b') as replace from tbl1")]
     #[case::length("select length(utf8) as length from tbl1")]
     #[case::lower("select lower(utf8) as lower from tbl1")]
     #[case::upper("select upper(utf8) as upper from tbl1")]
diff --git a/src/daft-sql/src/modules/utf8.rs b/src/daft-sql/src/modules/utf8.rs
index 263a8bd9e7..6e7b6b68a5 100644
--- a/src/daft-sql/src/modules/utf8.rs
+++ b/src/daft-sql/src/modules/utf8.rs
@@ -1,12 +1,22 @@
+use daft_core::array::ops::Utf8NormalizeOptions;
 use daft_dsl::{
-    functions::{self, utf8::Utf8Expr},
+    functions::{
+        self,
+        utf8::{normalize, Utf8Expr},
+    },
     ExprRef, LiteralValue,
 };
+use daft_functions::{
+    count_matches::{utf8_count_matches, CountMatchesFunction},
+    tokenize::{tokenize_decode, tokenize_encode, TokenizeDecodeFunction, TokenizeEncodeFunction},
+};
 
 use super::SQLModule;
 use crate::{
-    ensure, error::SQLPlannerResult, functions::SQLFunction, invalid_operation_err,
-    unsupported_sql_err,
+    ensure,
+    error::{PlannerError, SQLPlannerResult},
+    functions::{SQLFunction, SQLFunctionArguments},
+    invalid_operation_err, unsupported_sql_err,
 };
 
 pub struct SQLModuleUtf8;
@@ -17,16 +27,18 @@ impl SQLModule for SQLModuleUtf8 {
         parent.add_fn("ends_with", EndsWith);
         parent.add_fn("starts_with", StartsWith);
         parent.add_fn("contains", Contains);
-        parent.add_fn("split", Split(true));
+        parent.add_fn("split", Split(false));
         // TODO add split variants
         // parent.add("split", f(Split(false)));
-        parent.add_fn("match", Match);
-        parent.add_fn("extract", Extract(0));
-        parent.add_fn("extract_all", ExtractAll(0));
-        parent.add_fn("replace", Replace(true));
+        parent.add_fn("regexp_match", Match);
+        parent.add_fn("regexp_extract", Extract(0));
+        parent.add_fn("regexp_extract_all", ExtractAll(0));
+        parent.add_fn("regexp_replace", Replace(true));
+        parent.add_fn("regexp_split", Split(true));
         // TODO add replace variants
        // parent.add("replace", f(Replace(false)));
         parent.add_fn("length", Length);
+        parent.add_fn("length_bytes", LengthBytes);
         parent.add_fn("lower", Lower);
         parent.add_fn("upper", Upper);
         parent.add_fn("lstrip", Lstrip);
@@ -39,13 +51,13 @@ impl SQLModule for SQLModuleUtf8 {
         parent.add_fn("rpad", Rpad);
         parent.add_fn("lpad", Lpad);
         parent.add_fn("repeat", Repeat);
-        parent.add_fn("like", Like);
-        parent.add_fn("ilike", Ilike);
-        parent.add_fn("substr", Substr);
+
         parent.add_fn("to_date", ToDate("".to_string()));
         parent.add_fn("to_datetime", ToDatetime("".to_string(), None));
-        // TODO add normalization variants.
-        // parent.add("normalize", f(Normalize(Default::default())));
+        parent.add_fn("count_matches", SQLCountMatches);
+        parent.add_fn("normalize", SQLNormalize);
+        parent.add_fn("tokenize_encode", SQLTokenizeEncode);
+        parent.add_fn("tokenize_decode", SQLTokenizeDecode);
     }
 }
@@ -78,19 +90,44 @@ fn to_expr(expr: &Utf8Expr, args: &[ExprRef]) -> SQLPlannerResult<ExprRef> {
             ensure!(args.len() == 2, "contains takes exactly two arguments");
             Ok(contains(args[0].clone(), args[1].clone()))
         }
-        Split(_) => {
+        Split(true) => {
+            ensure!(args.len() == 2, "split takes exactly two arguments");
+            Ok(split(args[0].clone(), args[1].clone(), true))
+        }
+        Split(false) => {
             ensure!(args.len() == 2, "split takes exactly two arguments");
             Ok(split(args[0].clone(), args[1].clone(), false))
         }
         Match => {
-            unsupported_sql_err!("match")
-        }
-        Extract(_) => {
-            unsupported_sql_err!("extract")
-        }
-        ExtractAll(_) => {
-            unsupported_sql_err!("extract_all")
+            ensure!(args.len() == 2, "regexp_match takes exactly two arguments");
+            Ok(match_(args[0].clone(), args[1].clone()))
         }
+        Extract(_) => match args {
+            [input, pattern] => Ok(extract(input.clone(), pattern.clone(), 0)),
+            [input, pattern, idx] => {
+                let idx = idx.as_literal().and_then(|lit| lit.as_i64()).ok_or_else(|| {
+                    PlannerError::invalid_operation(format!("Expected a literal integer for the third argument of regexp_extract, found {:?}", idx))
+                })?;
+
+                Ok(extract(input.clone(), pattern.clone(), idx as usize))
+            }
+            _ => {
+                invalid_operation_err!("regexp_extract takes exactly two or three arguments")
+            }
+        },
+        ExtractAll(_) => match args {
+            [input, pattern] => Ok(extract_all(input.clone(), pattern.clone(), 0)),
+            [input, pattern, idx] => {
+                let idx = idx.as_literal().and_then(|lit| lit.as_i64()).ok_or_else(|| {
+                    PlannerError::invalid_operation(format!("Expected a literal integer for the third argument of regexp_extract, found {:?}", idx))
+                })?;
+
+                Ok(extract_all(input.clone(), pattern.clone(), idx as usize))
+            }
+            _ => {
+                invalid_operation_err!("regexp_extract_all takes exactly two or three arguments")
+            }
+        },
         Replace(_) => {
             ensure!(args.len() == 3, "replace takes exactly three arguments");
             Ok(replace(
@@ -101,10 +138,10 @@ fn to_expr(expr: &Utf8Expr, args: &[ExprRef]) -> SQLPlannerResult<ExprRef> {
             ))
         }
         Like => {
-            unsupported_sql_err!("like")
+            unreachable!("like should be handled by the parser")
         }
         Ilike => {
-            unsupported_sql_err!("ilike")
+            unreachable!("ilike should be handled by the parser")
         }
         Length => {
             ensure!(args.len() == 1, "length takes exactly one argument");
@@ -163,8 +200,7 @@ fn to_expr(expr: &Utf8Expr, args: &[ExprRef]) -> SQLPlannerResult<ExprRef> {
             Ok(repeat(args[0].clone(), args[1].clone()))
         }
         Substr => {
-            ensure!(args.len() == 3, "substr takes exactly three arguments");
-            Ok(substr(args[0].clone(), args[1].clone(), args[2].clone()))
+            unreachable!("substr should be handled by the parser")
         }
         ToDate(_) => {
             ensure!(args.len() == 2, "to_date takes exactly two arguments");
@@ -195,3 +231,233 @@ fn to_expr(expr: &Utf8Expr, args: &[ExprRef]) -> SQLPlannerResult<ExprRef> {
         }
     }
 }
+
+pub struct SQLCountMatches;
+
+impl TryFrom<SQLFunctionArguments> for CountMatchesFunction {
+    type Error = PlannerError;
+
+    fn try_from(args: SQLFunctionArguments) -> Result<Self, Self::Error> {
+        let whole_words = args.try_get_named("whole_words")?.unwrap_or(false);
+        let case_sensitive = args.try_get_named("case_sensitive")?.unwrap_or(true);
+
+        Ok(Self {
+            whole_words,
+            case_sensitive,
+        })
+    }
+}
+
+impl SQLFunction for SQLCountMatches {
+    fn to_expr(
+        &self,
+        inputs: &[sqlparser::ast::FunctionArg],
+        planner: &crate::planner::SQLPlanner,
+    ) -> SQLPlannerResult<ExprRef> {
+        match inputs {
+            [input, pattern] => {
+                let input = planner.plan_function_arg(input)?;
+                let pattern = planner.plan_function_arg(pattern)?;
+                Ok(utf8_count_matches(input, pattern, false, true))
+            }
+            [input, pattern, args @ ..] => {
+                let input = planner.plan_function_arg(input)?;
+                let pattern = planner.plan_function_arg(pattern)?;
+                let args: CountMatchesFunction =
+                    planner.plan_function_args(args, &["whole_words", "case_sensitive"], 0)?;
+
+                Ok(utf8_count_matches(
+                    input,
+                    pattern,
+                    args.whole_words,
+                    args.case_sensitive,
+                ))
+            }
+            _ => Err(PlannerError::invalid_operation(format!(
+                "Invalid arguments for count_matches: '{inputs:?}'"
+            ))),
+        }
+    }
+}
+
+pub struct SQLNormalize;
+
+impl TryFrom<SQLFunctionArguments> for Utf8NormalizeOptions {
+    type Error = PlannerError;
+
+    fn try_from(args: SQLFunctionArguments) -> Result<Self, Self::Error> {
+        let remove_punct = args.try_get_named("remove_punct")?.unwrap_or(false);
+        let lowercase = args.try_get_named("lowercase")?.unwrap_or(false);
+        let nfd_unicode = args.try_get_named("nfd_unicode")?.unwrap_or(false);
+        let white_space = args.try_get_named("white_space")?.unwrap_or(false);
+
+        Ok(Self {
+            remove_punct,
+            lowercase,
+            nfd_unicode,
+            white_space,
+        })
+    }
+}
+
+impl SQLFunction for SQLNormalize {
+    fn to_expr(
+        &self,
+        inputs: &[sqlparser::ast::FunctionArg],
+        planner: &crate::planner::SQLPlanner,
+    ) -> SQLPlannerResult<ExprRef> {
+        match inputs {
+            [input] => {
+                let input = planner.plan_function_arg(input)?;
+                Ok(normalize(input, Utf8NormalizeOptions::default()))
+            }
+            [input, args @ ..] => {
+                let input = planner.plan_function_arg(input)?;
+                let args: Utf8NormalizeOptions = planner.plan_function_args(
+                    args,
+                    &["remove_punct", "lowercase", "nfd_unicode", "white_space"],
+                    0,
+                )?;
+                Ok(normalize(input, args))
+            }
+            _ => invalid_operation_err!("Invalid arguments for normalize"),
+        }
+    }
+}
+
+pub struct SQLTokenizeEncode;
+impl TryFrom<SQLFunctionArguments> for TokenizeEncodeFunction {
+    type Error = PlannerError;
+
+    fn try_from(args: SQLFunctionArguments) -> Result<Self, Self::Error> {
+        if args.get_named("io_config").is_some() {
+            return Err(PlannerError::invalid_operation(
+                "io_config argument is not yet supported for tokenize_encode",
+            ));
+        }
+
+        let tokens_path = args.try_get_named("tokens_path")?.ok_or_else(|| {
+            PlannerError::invalid_operation("tokens_path argument is required for tokenize_encode")
+        })?;
+
+        let pattern = args.try_get_named("pattern")?;
+        let special_tokens = args.try_get_named("special_tokens")?;
+        let use_special_tokens = args.try_get_named("use_special_tokens")?.unwrap_or(false);
+
+        Ok(Self {
+            tokens_path,
+            pattern,
+            special_tokens,
+            use_special_tokens,
+            io_config: None,
+        })
+    }
+}
+
+impl SQLFunction for SQLTokenizeEncode {
+    fn to_expr(
+        &self,
+        inputs: &[sqlparser::ast::FunctionArg],
+        planner: &crate::planner::SQLPlanner,
+    ) -> SQLPlannerResult<ExprRef> {
+        match inputs {
+            [input, tokens_path] => {
+                let input = planner.plan_function_arg(input)?;
+                let tokens_path = planner.plan_function_arg(tokens_path)?;
+                let tokens_path = tokens_path
+                    .as_literal()
+                    .and_then(|lit| lit.as_str())
+                    .ok_or_else(|| {
+                        PlannerError::invalid_operation("tokens_path argument must be a string")
+                    })?;
+                Ok(tokenize_encode(input, tokens_path, None, None, None, false))
+            }
+            [input, args @ ..] => {
+                let input = planner.plan_function_arg(input)?;
+                let args: TokenizeEncodeFunction = planner.plan_function_args(
+                    args,
+                    &[
+                        "tokens_path",
+                        "pattern",
+                        "special_tokens",
+                        "use_special_tokens",
+                    ],
+                    1, // tokens_path can be named or positional
+                )?;
+                Ok(tokenize_encode(
+                    input,
+                    &args.tokens_path,
+                    None,
+                    args.pattern.as_deref(),
+                    args.special_tokens.as_deref(),
+                    args.use_special_tokens,
+                ))
+            }
+            _ => invalid_operation_err!("Invalid arguments for tokenize_encode"),
+        }
+    }
+}
+
+pub struct SQLTokenizeDecode;
+impl TryFrom<SQLFunctionArguments> for TokenizeDecodeFunction {
+    type Error = PlannerError;
+
+    fn try_from(args: SQLFunctionArguments) -> Result<Self, Self::Error> {
+        if args.get_named("io_config").is_some() {
+            return Err(PlannerError::invalid_operation(
+                "io_config argument is not yet supported for tokenize_decode",
+            ));
+        }
+
+        let tokens_path = args.try_get_named("tokens_path")?.ok_or_else(|| {
+            PlannerError::invalid_operation("tokens_path argument is required for tokenize_decode")
+        })?;
+
+        let pattern = args.try_get_named("pattern")?;
+        let special_tokens = args.try_get_named("special_tokens")?;
+
+        Ok(Self {
+            tokens_path,
+            pattern,
+            special_tokens,
+            io_config: None,
+        })
+    }
+}
+impl SQLFunction for SQLTokenizeDecode {
+    fn to_expr(
+        &self,
+        inputs: &[sqlparser::ast::FunctionArg],
+        planner: &crate::planner::SQLPlanner,
+    ) -> SQLPlannerResult<ExprRef> {
+        match inputs {
+            [input, tokens_path] => {
+                let input = planner.plan_function_arg(input)?;
+                let tokens_path = planner.plan_function_arg(tokens_path)?;
+                let tokens_path = tokens_path
+                    .as_literal()
+                    .and_then(|lit| lit.as_str())
+                    .ok_or_else(|| {
+                        PlannerError::invalid_operation("tokens_path argument must be a string")
+                    })?;
+                Ok(tokenize_decode(input, tokens_path, None, None, None))
+            }
+            [input, args @ ..] => {
+                let input = planner.plan_function_arg(input)?;
+                let args: TokenizeDecodeFunction = planner.plan_function_args(
+                    args,
+                    &["tokens_path", "pattern", "special_tokens"],
+                    1, // tokens_path can be named or positional
+                )?;
+                Ok(tokenize_decode(
+                    input,
+                    &args.tokens_path,
+                    None,
+                    args.pattern.as_deref(),
+                    args.special_tokens.as_deref(),
+                ))
+            }
+            _ => invalid_operation_err!("Invalid arguments for tokenize_decode"),
+        }
+    }
+}
diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs
index 2aefab9f96..b76fc66fc0 100644
--- a/src/daft-sql/src/planner.rs
+++ b/src/daft-sql/src/planner.rs
@@ -617,7 +617,26 @@ impl SQLPlanner {
             SQLExpr::Ceil { expr, .. } => Ok(ceil(self.plan_expr(expr)?)),
             SQLExpr::Floor { expr, .. } => Ok(floor(self.plan_expr(expr)?)),
             SQLExpr::Position { .. } => unsupported_sql_err!("POSITION"),
-            SQLExpr::Substring { .. } => unsupported_sql_err!("SUBSTRING"),
+            SQLExpr::Substring {
+                expr,
+                substring_from,
+                substring_for,
+                special: true, // We only support SUBSTRING(expr, start, length) syntax
+            } => {
+                let (Some(substring_from), Some(substring_for)) = (substring_from, substring_for)
+                else {
+                    unsupported_sql_err!("SUBSTRING")
+                };
+
+                let expr = self.plan_expr(expr)?;
+                let start = self.plan_expr(substring_from)?;
+                let length = self.plan_expr(substring_for)?;
+
+                Ok(daft_dsl::functions::utf8::substr(expr, start, length))
+            }
+            SQLExpr::Substring { special: false, .. } => {
+                unsupported_sql_err!("`SUBSTRING(expr [FROM start] [FOR len])` syntax")
+            }
             SQLExpr::Trim { .. } => unsupported_sql_err!("TRIM"),
             SQLExpr::Overlay { .. } => unsupported_sql_err!("OVERLAY"),
             SQLExpr::Collate { .. } => unsupported_sql_err!("COLLATE"),
diff --git a/tests/sql/test_utf8_exprs.py b/tests/sql/test_utf8_exprs.py
new file mode 100644
index 0000000000..12b53a9ebc
--- /dev/null
+++ b/tests/sql/test_utf8_exprs.py
@@ -0,0 +1,113 @@
+import daft
+from daft import col
+
+
+def test_utf8_exprs():
+    df = daft.from_pydict(
+        {
+            "a": [
+                "a",
+                "df_daft",
+                "foo",
+                "bar",
+                "baz",
+                "lorΓ©m",
+                "ipsum",
+                "dolor",
+                "sit",
+                "amet",
+                "😊",
+                "🌟",
+                "πŸŽ‰",
+                "This is a longer with some words",
+                "THIS is ",
+                "",
+            ],
+        }
+    )
+
+    sql = """
+    SELECT
+        ends_with(a, 'a') as ends_with_a,
+        starts_with(a, 'a') as starts_with_a,
+        contains(a, 'a') as contains_a,
+        split(a, ' ') as split_a,
+        regexp_match(a, 'ba.') as match_a,
+        regexp_extract(a, 'ba.') as extract_a,
+        regexp_extract_all(a, 'ba.') as extract_all_a,
+        regexp_replace(a, 'ba.', 'foo') as replace_a,
+        regexp_split(a, '\\s+') as regexp_split_a,
+        length(a) as length_a,
+        length_bytes(a) as length_bytes_a,
+        lower(a) as lower_a,
+        lstrip(a) as lstrip_a,
+        rstrip(a) as rstrip_a,
+        reverse(a) as reverse_a,
+        capitalize(a) as capitalize_a,
+        left(a, 4) as left_a,
+        right(a, 4) as right_a,
+        find(a, 'a') as find_a,
+        rpad(a, 10, '<') as rpad_a,
+        lpad(a, 10, '>') as lpad_a,
+        repeat(a, 2) as repeat_a,
+        a like 'a%' as like_a,
+        a ilike 'a%' as ilike_a,
+        substring(a, 1, 3) as substring_a,
+        count_matches(a, 'a') as count_matches_a_0,
+        count_matches(a, 'a', case_sensitive := true) as count_matches_a_1,
+        count_matches(a, 'a', case_sensitive := false, whole_words := false) as count_matches_a_2,
+        count_matches(a, 'a', case_sensitive := true, whole_words := true) as count_matches_a_3,
+        normalize(a) as normalize_a,
+        normalize(a, remove_punct:=true) as normalize_remove_punct_a,
+        normalize(a, remove_punct:=true, lowercase:=true) as normalize_remove_punct_lower_a,
+        normalize(a, remove_punct:=true, lowercase:=true, white_space:=true) as normalize_remove_punct_lower_ws_a,
+        tokenize_encode(a, 'r50k_base') as tokenize_encode_a,
+        tokenize_decode(tokenize_encode(a, 'r50k_base'), 'r50k_base') as tokenize_decode_a
+    FROM df
+    """
+    actual = daft.sql(sql).collect()
+    expected = (
+        df.select(
+            col("a").str.endswith("a").alias("ends_with_a"),
+            col("a").str.startswith("a").alias("starts_with_a"),
+            col("a").str.contains("a").alias("contains_a"),
+            col("a").str.split(" ").alias("split_a"),
+            col("a").str.match("ba.").alias("match_a"),
+            col("a").str.extract("ba.").alias("extract_a"),
+            col("a").str.extract_all("ba.").alias("extract_all_a"),
+            col("a").str.split(r"\s+", regex=True).alias("regexp_split_a"),
+            col("a").str.replace("ba.", "foo", regex=True).alias("replace_a"),
+            col("a").str.length().alias("length_a"),
+            col("a").str.length_bytes().alias("length_bytes_a"),
+            col("a").str.lower().alias("lower_a"),
+            col("a").str.lstrip().alias("lstrip_a"),
+            col("a").str.rstrip().alias("rstrip_a"),
+            col("a").str.reverse().alias("reverse_a"),
+            col("a").str.capitalize().alias("capitalize_a"),
+            col("a").str.left(4).alias("left_a"),
+            col("a").str.right(4).alias("right_a"),
+            col("a").str.find("a").alias("find_a"),
+            col("a").str.rpad(10, "<").alias("rpad_a"),
+            col("a").str.lpad(10, ">").alias("lpad_a"),
+            col("a").str.repeat(2).alias("repeat_a"),
+            col("a").str.like("a%").alias("like_a"),
+            col("a").str.ilike("a%").alias("ilike_a"),
+            col("a").str.substr(1, 3).alias("substring_a"),
+            col("a").str.count_matches("a").alias("count_matches_a_0"),
+            col("a").str.count_matches("a", case_sensitive=True).alias("count_matches_a_1"),
+            col("a").str.count_matches("a", case_sensitive=False, whole_words=False).alias("count_matches_a_2"),
+            col("a").str.count_matches("a", case_sensitive=True, whole_words=True).alias("count_matches_a_3"),
+            col("a").str.normalize().alias("normalize_a"),
+            col("a").str.normalize(remove_punct=True).alias("normalize_remove_punct_a"),
+            col("a").str.normalize(remove_punct=True, lowercase=True).alias("normalize_remove_punct_lower_a"),
+            col("a")
+            .str.normalize(remove_punct=True, lowercase=True, white_space=True)
+            .alias("normalize_remove_punct_lower_ws_a"),
+            col("a").str.tokenize_encode("r50k_base").alias("tokenize_encode_a"),
+            col("a").str.tokenize_encode("r50k_base").str.tokenize_decode("r50k_base").alias("tokenize_decode_a"),
+        )
+        .collect()
+        .to_pydict()
+    )
+    actual = actual.to_pydict()
+    assert actual == expected
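
---

Note for reviewers: the `SQLLiteral` trait plus `SQLFunctionArguments::try_get_named` added in `functions.rs` is the extension point this diff establishes for typed named arguments. A minimal sketch of how a future SQL function would plug into it, mirroring the `SQLCountMatches` pattern above. All names here (`SQLMyFunc`, `MyFuncArgs`, `my_func_expr`, the `threshold` argument) are hypothetical and not part of this diff:

```rust
use daft_dsl::ExprRef;
use sqlparser::ast::FunctionArg;

use crate::{
    error::{PlannerError, SQLPlannerResult},
    functions::{SQLFunction, SQLFunctionArguments},
    invalid_operation_err,
    planner::SQLPlanner,
};

// Hypothetical argument struct; `threshold` is optional and defaults to 3.
struct MyFuncArgs {
    threshold: i64,
}

impl TryFrom<SQLFunctionArguments> for MyFuncArgs {
    type Error = PlannerError;

    fn try_from(args: SQLFunctionArguments) -> Result<Self, Self::Error> {
        // `try_get_named::<i64>` dispatches through `SQLLiteral::from_expr`,
        // so a non-integer literal surfaces as a PlannerError rather than a panic.
        let threshold = args.try_get_named("threshold")?.unwrap_or(3);
        Ok(Self { threshold })
    }
}

struct SQLMyFunc;

impl SQLFunction for SQLMyFunc {
    fn to_expr(&self, inputs: &[FunctionArg], planner: &SQLPlanner) -> SQLPlannerResult<ExprRef> {
        match inputs {
            [input, args @ ..] => {
                let input = planner.plan_function_arg(input)?;
                // 0 positional args allowed after `input`: `threshold` must be named.
                let args: MyFuncArgs = planner.plan_function_args(args, &["threshold"], 0)?;
                Ok(my_func_expr(input, args.threshold)) // hypothetical expression constructor
            }
            _ => invalid_operation_err!("my_func requires at least one argument"),
        }
    }
}
```

As with `SQLCountMatches` and the tokenize functions, positional inputs are planned first and everything after them is funneled through `plan_function_args`, which validates names against the allow-list and then defers to the `TryFrom<SQLFunctionArguments>` impl for typed extraction and defaults.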