Skip to content

Commit

Permalink
[PERF] Optimize string normalization (#2474)
Browse files Browse the repository at this point in the history
Slight optimization of string normalization. Speeds up that step around
2x (mostly by removing regex)
  • Loading branch information
Vince7778 authored Jul 8, 2024
1 parent f142166 commit e5ff7d7
Showing 1 changed file with 32 additions and 25 deletions.
57 changes: 32 additions & 25 deletions src/daft-core/src/array/ops/utf8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use common_error::{DaftError, DaftResult};
use itertools::Itertools;
use num_traits::NumCast;
use serde::{Deserialize, Serialize};
use unicode_normalization::UnicodeNormalization;
use unicode_normalization::{is_nfd_quick, IsNormalized, UnicodeNormalization};

use super::{as_arrow::AsArrow, full::FullNull};

Expand Down Expand Up @@ -1339,41 +1339,48 @@ impl Utf8Array {
}

pub fn normalize(&self, opts: Utf8NormalizeOptions) -> DaftResult<Utf8Array> {
let whitespace_regex = regex::Regex::new(r"\s+").unwrap();

let arrow_result = self
.as_arrow()
.iter()
.map(|maybe_s| {
Ok(Utf8Array::from_iter(
self.name(),
self.as_arrow().iter().map(|maybe_s| {
if let Some(s) = maybe_s {
let mut s = s.to_string();
let mut s = if opts.white_space {
s.trim().to_string()
} else {
s.to_string()
};

if opts.remove_punct {
s = s.chars().filter(|c| !c.is_ascii_punctuation()).collect();
}
let mut prev_white = true;
s = s
.chars()
.filter_map(|c| {
if !(opts.remove_punct && c.is_ascii_punctuation()
|| opts.white_space && c.is_whitespace())
{
prev_white = false;
Some(c)
} else if prev_white || (opts.remove_punct && c.is_ascii_punctuation())
{
None
} else {
prev_white = true;
Some(' ')
}
})
.collect();

if opts.lowercase {
s = s.to_lowercase();
}

if opts.white_space {
s = whitespace_regex
.replace_all(s.as_str().trim(), " ")
.to_string();
}

if opts.nfd_unicode {
if opts.nfd_unicode && is_nfd_quick(s.chars()) != IsNormalized::Yes {
s = s.nfd().collect();
}

Ok(Some(s))
Some(s)
} else {
Ok(None)
None
}
})
.collect::<DaftResult<arrow2::array::Utf8Array<i64>>>()?;

Ok(Utf8Array::from((self.name(), Box::new(arrow_result))))
}),
))
}

fn unary_broadcasted_op<ScalarKernel>(&self, operation: ScalarKernel) -> DaftResult<Utf8Array>
Expand Down

0 comments on commit e5ff7d7

Please sign in to comment.