Skip to content

Commit

Permalink
Add temporal type inference and parsing test coverage, misc. fixes an…
Browse files Browse the repository at this point in the history
…d extension thereof.
  • Loading branch information
clarkzinzow committed Dec 1, 2023
1 parent 007425f commit 6ba9130
Show file tree
Hide file tree
Showing 7 changed files with 265 additions and 63 deletions.
26 changes: 23 additions & 3 deletions src/daft-decoding/src/deserialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ pub(crate) const ALL_NAIVE_TIMESTAMP_FMTS: &[&str] = &[
];
pub(crate) const ALL_TIMESTAMP_FMTS: &[&str] = &[ISO8601, RFC3339_WITH_SPACE];

pub(crate) const ISO8601_DATE: &str = "%Y-%m-%d";
pub(crate) const ISO8601_DATE_SLASHES: &str = "%Y/%m/%d";
pub(crate) const ALL_NAIVE_DATE_FMTS: &[&str] = &[ISO8601_DATE, ISO8601_DATE_SLASHES];

// Ideally this trait should not be needed and both `csv` and `csv_async` crates would share
// the same `ByteRecord` struct. Unfortunately, they do not and thus we must use generics
// over this trait and materialize the generics for each struct.
Expand Down Expand Up @@ -152,6 +156,20 @@ fn deserialize_null<B: ByteRecordGeneric>(rows: &[B], _: usize) -> Box<dyn Array
Box::new(NullArray::new(DataType::Null, rows.len()))
}

#[inline]
pub fn deserialize_naive_date(string: &str, fmt_idx: &mut usize) -> Option<chrono::NaiveDate> {
// TODO(Clark): Parse as all candidate formats in a single pass.
for i in 0..ALL_NAIVE_DATE_FMTS.len() {
let idx = (i + *fmt_idx) % ALL_NAIVE_DATE_FMTS.len();
let fmt = ALL_NAIVE_DATE_FMTS[idx];
if let Ok(dt) = chrono::NaiveDate::parse_from_str(string, fmt) {
*fmt_idx = idx;
return Some(dt);
}
}
None
}

#[inline]
pub fn deserialize_naive_datetime(
string: &str,
Expand Down Expand Up @@ -237,13 +255,15 @@ pub fn deserialize_column<B: ByteRecordGeneric>(
lexical_core::parse::<f64>(bytes).ok()
}),
Date32 => deserialize_primitive(rows, column, datatype, |bytes| {
let mut last_fmt_idx = 0;
to_utf8(bytes)
.and_then(|x| x.parse::<chrono::NaiveDate>().ok())
.and_then(|x| deserialize_naive_date(x, &mut last_fmt_idx))
.map(|x| x.num_days_from_ce() - temporal_conversions::EPOCH_DAYS_FROM_CE)
}),
Date64 => deserialize_primitive(rows, column, datatype, |bytes| {
let mut last_fmt_idx = 0;
to_utf8(bytes)
.and_then(|x| x.parse::<chrono::NaiveDateTime>().ok())
.and_then(|x| deserialize_naive_datetime(x, &mut last_fmt_idx))
.map(|x| x.timestamp_millis())
}),
Time32(time_unit) => deserialize_primitive(rows, column, datatype, |bytes| {
Expand Down Expand Up @@ -313,7 +333,7 @@ pub fn deserialize_column<B: ByteRecordGeneric>(
}

// Return the factor by how small is a time unit compared to seconds
fn get_factor_from_timeunit(time_unit: TimeUnit) -> u32 {
pub fn get_factor_from_timeunit(time_unit: TimeUnit) -> u32 {
match time_unit {
TimeUnit::Second => 1,
TimeUnit::Millisecond => 1_000,
Expand Down
28 changes: 20 additions & 8 deletions src/daft-decoding/src/inference.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use arrow2::datatypes::{DataType, TimeUnit};
use chrono::Timelike;

use crate::deserialize::{ALL_NAIVE_TIMESTAMP_FMTS, ALL_TIMESTAMP_FMTS};
use crate::deserialize::{ALL_NAIVE_DATE_FMTS, ALL_NAIVE_TIMESTAMP_FMTS, ALL_TIMESTAMP_FMTS};

/// Infers [`DataType`] from `bytes`
/// # Implementation
Expand Down Expand Up @@ -35,12 +35,15 @@ pub fn infer(bytes: &[u8]) -> arrow2::datatypes::DataType {
pub fn infer_string(string: &str) -> DataType {
if is_date(string) {
DataType::Date32
} else if is_time(string) {
DataType::Time32(TimeUnit::Millisecond)
} else if let Some(time_unit) = is_naive_datetime(string) {
DataType::Timestamp(time_unit, None)
} else if let Some(time_unit) = is_time(string) {
DataType::Time32(time_unit)
} else if let Some((time_unit, offset)) = is_datetime(string) {
// NOTE: We try to parse as a non-naive datatime (with timezone information) first,
// since is_datetime() will return false if timezone information is not present in the string,
// while is_naive_datetime() will ignore timezone information in the string.
DataType::Timestamp(time_unit, Some(offset))
} else if let Some(time_unit) = is_naive_datetime(string) {
DataType::Timestamp(time_unit, None)
} else {
DataType::Utf8
}
Expand All @@ -63,11 +66,20 @@ fn is_integer(bytes: &[u8]) -> bool {
}

fn is_date(string: &str) -> bool {
string.parse::<chrono::NaiveDate>().is_ok()
for fmt in ALL_NAIVE_DATE_FMTS {
if chrono::NaiveDate::parse_from_str(string, fmt).is_ok() {
return true;
}
}
false
}

fn is_time(string: &str) -> bool {
string.parse::<chrono::NaiveTime>().is_ok()
fn is_time(string: &str) -> Option<TimeUnit> {
if let Ok(t) = string.parse::<chrono::NaiveTime>() {
let time_unit = nanoseconds_to_time_unit(t.nanosecond());
return Some(time_unit);
}
None
}

fn is_naive_datetime(string: &str) -> Option<TimeUnit> {
Expand Down
63 changes: 54 additions & 9 deletions src/daft-json/src/decoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@ use arrow2::error::{Error, Result};
use arrow2::offset::Offsets;
use arrow2::temporal_conversions;
use arrow2::types::{f16, NativeType, Offset};
use daft_decoding::deserialize::{deserialize_datetime, deserialize_naive_datetime};
use chrono::{Datelike, Timelike};
use daft_decoding::deserialize::{
deserialize_datetime, deserialize_naive_date, deserialize_naive_datetime,
get_factor_from_timeunit,
};
use indexmap::IndexMap;
use json_deserializer::{Number, Value};

Expand Down Expand Up @@ -135,20 +139,20 @@ fn deserialize_into<'a, A: Borrow<Value<'a>>>(target: &mut Box<dyn MutableArray>
}
DataType::Int8 => deserialize_primitive_into::<_, i8>(target, rows, deserialize_int_into),
DataType::Int16 => deserialize_primitive_into::<_, i16>(target, rows, deserialize_int_into),
DataType::Int32
| DataType::Date32
| DataType::Time32(_)
| DataType::Interval(IntervalUnit::YearMonth) => {
DataType::Int32 | DataType::Interval(IntervalUnit::YearMonth) => {
deserialize_primitive_into::<_, i32>(target, rows, deserialize_int_into)
}
DataType::Date32 | DataType::Time32(_) => {
deserialize_primitive_into::<_, i32>(target, rows, deserialize_date_into)
}
DataType::Interval(IntervalUnit::DayTime) => {
unimplemented!("There is no natural representation of DayTime in JSON.")
}
DataType::Int64 | DataType::Date64 | DataType::Time64(_) | DataType::Duration(_) => {
DataType::Int64 | DataType::Duration(_) => {
deserialize_primitive_into::<_, i64>(target, rows, deserialize_int_into)
}
DataType::Timestamp(..) => {
deserialize_primitive_into::<_, i64>(target, rows, deserialize_timestamp_into)
DataType::Timestamp(..) | DataType::Date64 | DataType::Time64(_) => {
deserialize_primitive_into::<_, i64>(target, rows, deserialize_datetime_into)
}
DataType::UInt8 => deserialize_primitive_into::<_, u8>(target, rows, deserialize_int_into),
DataType::UInt16 => {
Expand Down Expand Up @@ -369,7 +373,34 @@ where
}
}

fn deserialize_timestamp_into<'a, A: Borrow<Value<'a>>>(
fn deserialize_date_into<'a, A: Borrow<Value<'a>>>(
target: &mut MutablePrimitiveArray<i32>,
rows: &[A],
) {
let dtype = target.data_type().clone();
let mut last_fmt_idx = 0;
let iter = rows.iter().map(|row| match row.borrow() {
Value::Number(v) => Some(deserialize_int_single(*v)),
Value::String(v) => match dtype {
DataType::Time32(tu) => {
let factor = get_factor_from_timeunit(tu);
v.parse::<chrono::NaiveTime>().ok().map(|x| {
(x.hour() * 3_600 * factor
+ x.minute() * 60 * factor
+ x.second() * factor
+ x.nanosecond() / (1_000_000_000 / factor)) as i32
})
}
DataType::Date32 => deserialize_naive_date(v, &mut last_fmt_idx)
.map(|x| x.num_days_from_ce() - temporal_conversions::EPOCH_DAYS_FROM_CE),
_ => unreachable!(),
},
_ => None,
});
target.extend_trusted_len(iter);
}

fn deserialize_datetime_into<'a, A: Borrow<Value<'a>>>(
target: &mut MutablePrimitiveArray<i64>,
rows: &[A],
) {
Expand All @@ -378,6 +409,19 @@ fn deserialize_timestamp_into<'a, A: Borrow<Value<'a>>>(
let iter = rows.iter().map(|row| match row.borrow() {
Value::Number(v) => Some(deserialize_int_single(*v)),
Value::String(v) => match dtype {
DataType::Time64(tu) => {
let factor = get_factor_from_timeunit(tu) as u64;
v.parse::<chrono::NaiveTime>().ok().map(|x| {
(x.hour() as u64 * 3_600 * factor
+ x.minute() as u64 * 60 * factor
+ x.second() as u64 * factor
+ x.nanosecond() as u64 / (1_000_000_000 / factor))
as i64
})
}
DataType::Date64 => {
deserialize_naive_datetime(v, &mut last_fmt_idx).map(|x| x.timestamp_millis())
}
DataType::Timestamp(tu, None) => deserialize_naive_datetime(v, &mut last_fmt_idx)
.and_then(|dt| match tu {
TimeUnit::Second => Some(dt.timestamp()),
Expand All @@ -386,6 +430,7 @@ fn deserialize_timestamp_into<'a, A: Borrow<Value<'a>>>(
TimeUnit::Nanosecond => dt.timestamp_nanos_opt(),
}),
DataType::Timestamp(tu, Some(ref tz)) => {
let tz = if tz == "Z" { "UTC" } else { tz };
let tz = temporal_conversions::parse_offset(tz).unwrap();
deserialize_datetime(v, &tz, &mut last_fmt_idx).and_then(|dt| match tu {
TimeUnit::Second => Some(dt.timestamp()),
Expand Down
118 changes: 81 additions & 37 deletions src/daft-json/src/inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ fn infer_array(values: &[Value]) -> Result<DataType> {

let dt = if !types.is_empty() {
let types = types.into_iter().collect::<Vec<_>>();
coerce_data_type(&types)
coerce_data_type(types)
} else {
DataType::Null
};
Expand Down Expand Up @@ -101,7 +101,7 @@ pub(crate) fn column_types_map_to_fields(
.map(|(name, dtype_set)| {
let dtypes = dtype_set.into_iter().collect::<Vec<_>>();
// Get consolidated dtype for column.
let dtype = coerce_data_type(dtypes.as_slice());
let dtype = coerce_data_type(dtypes);
arrow2::datatypes::Field::new(name, dtype, true)
})
.collect::<Vec<_>>()
Expand All @@ -113,21 +113,21 @@ pub(crate) fn column_types_map_to_fields(
/// * Lists and scalars are coerced to a list of a compatible scalar
/// * Structs contain the union of all fields
/// * All other types are coerced to `Utf8`
pub(crate) fn coerce_data_type<A: Borrow<DataType> + std::fmt::Debug>(datatypes: &[A]) -> DataType {
pub(crate) fn coerce_data_type(datatypes: Vec<DataType>) -> DataType {
// Drop null dtype from the dtype set.
let datatypes = datatypes
.iter()
.into_iter()
.filter(|dt| !matches!((*dt).borrow(), DataType::Null))
.collect::<Vec<_>>();

if datatypes.is_empty() {
return DataType::Null;
}

let are_all_equal = datatypes.windows(2).all(|w| w[0].borrow() == w[1].borrow());
let are_all_equal = datatypes.windows(2).all(|w| w[0] == w[1]);

if are_all_equal {
return datatypes[0].borrow().clone();
return datatypes.into_iter().next().unwrap();
}

let are_all_structs = datatypes
Expand All @@ -136,23 +136,23 @@ pub(crate) fn coerce_data_type<A: Borrow<DataType> + std::fmt::Debug>(datatypes:

if are_all_structs {
// All structs => union of all field dtypes (these may have equal names).
let fields = datatypes.iter().fold(vec![], |mut acc, dt| {
if let DataType::Struct(new_fields) = (*dt).borrow() {
let fields = datatypes.into_iter().fold(vec![], |mut acc, dt| {
if let DataType::Struct(new_fields) = dt {
acc.extend(new_fields);
};
acc
});
// Group fields by unique names.
let fields = fields.iter().fold(
IndexMap::<&String, HashSet<&DataType>>::new(),
let fields = fields.into_iter().fold(
IndexMap::<String, HashSet<DataType>>::new(),
|mut acc, field| {
match acc.entry(&field.name) {
match acc.entry(field.name) {
indexmap::map::Entry::Occupied(mut v) => {
v.get_mut().insert(&field.data_type);
v.get_mut().insert(field.data_type);
}
indexmap::map::Entry::Vacant(v) => {
let mut a = HashSet::new();
a.insert(&field.data_type);
a.insert(field.data_type);
v.insert(a);
}
}
Expand All @@ -164,32 +164,76 @@ pub(crate) fn coerce_data_type<A: Borrow<DataType> + std::fmt::Debug>(datatypes:
.into_iter()
.map(|(name, dts)| {
let dts = dts.into_iter().collect::<Vec<_>>();
Field::new(name, coerce_data_type(&dts), true)
Field::new(name, coerce_data_type(dts), true)
})
.collect();
return DataType::Struct(fields);
} else if datatypes.len() > 2 {
// TODO(Clark): Return an error for uncoercible types.
return DataType::Utf8;
}
let (lhs, rhs) = (datatypes[0].borrow(), datatypes[1].borrow());

return match (lhs, rhs) {
(lhs, rhs) if lhs == rhs => lhs.clone(),
(DataType::List(lhs), DataType::List(rhs)) => {
let inner = coerce_data_type(&[lhs.data_type(), rhs.data_type()]);
DataType::List(Box::new(Field::new(ITEM_NAME, inner, true)))
}
(scalar, DataType::List(list)) | (DataType::List(list), scalar) => {
let inner = coerce_data_type(&[scalar, list.data_type()]);
DataType::List(Box::new(Field::new(ITEM_NAME, inner, true)))
}
(DataType::Float64, DataType::Int64) | (DataType::Int64, DataType::Float64) => {
DataType::Float64
}
(DataType::Int64, DataType::Boolean) | (DataType::Boolean, DataType::Int64) => {
DataType::Int64
}
(_, _) => DataType::Utf8,
};
datatypes
.into_iter()
.reduce(|lhs, rhs| {
match (lhs, rhs) {
(lhs, rhs) if lhs == rhs => lhs,
(DataType::Utf8, _) | (_, DataType::Utf8) => DataType::Utf8,
(DataType::List(lhs), DataType::List(rhs)) => {
let inner =
coerce_data_type(vec![lhs.data_type().clone(), rhs.data_type().clone()]);
DataType::List(Box::new(Field::new(ITEM_NAME, inner, true)))
}
(scalar, DataType::List(list)) | (DataType::List(list), scalar) => {
let inner = coerce_data_type(vec![scalar, list.data_type().clone()]);
DataType::List(Box::new(Field::new(ITEM_NAME, inner, true)))
}
(DataType::Float64, DataType::Int64) | (DataType::Int64, DataType::Float64) => {
DataType::Float64
}
(DataType::Int64, DataType::Boolean) | (DataType::Boolean, DataType::Int64) => {
DataType::Int64
}
(DataType::Time32(left_tu), DataType::Time32(right_tu)) => {
// Set unified time unit to the highest granularity time unit.
let unified_tu = if left_tu == right_tu
|| time_unit_to_ordinal(&left_tu) > time_unit_to_ordinal(&right_tu)
{
left_tu
} else {
right_tu
};
DataType::Time32(unified_tu)
}
(
DataType::Timestamp(left_tu, left_tz),
DataType::Timestamp(right_tu, right_tz),
) => {
// Set unified time unit to the highest granularity time unit.
let unified_tu = if left_tu == right_tu
|| time_unit_to_ordinal(&left_tu) > time_unit_to_ordinal(&right_tu)
{
left_tu
} else {
right_tu
};
// Set unified time zone to UTC.
let unified_tz = if left_tz == right_tz {
left_tz.clone()
} else {
Some("Z".to_string())
};
DataType::Timestamp(unified_tu, unified_tz)
}
(_, _) => DataType::Utf8,
}
})
.unwrap()
}

fn time_unit_to_ordinal(tu: &arrow2::datatypes::TimeUnit) -> usize {
use arrow2::datatypes::TimeUnit;

match tu {
TimeUnit::Second => 0,
TimeUnit::Millisecond => 1,
TimeUnit::Microsecond => 2,
TimeUnit::Nanosecond => 3,
}
}
Loading

0 comments on commit 6ba9130

Please sign in to comment.