Commit

PR feedback.
clarkzinzow committed Dec 6, 2023
1 parent 1a43fb5 commit 4a2c336
Showing 19 changed files with 254 additions and 241 deletions.
12 changes: 12 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
@@ -1,5 +1,6 @@
[dependencies]
common-daft-config = {path = "src/common/daft-config", default-features = false}
daft-compression = {path = "src/daft-compression", default-features = false}
daft-core = {path = "src/daft-core", default-features = false}
daft-csv = {path = "src/daft-csv", default-features = false}
daft-dsl = {path = "src/daft-dsl", default-features = false}
6 changes: 4 additions & 2 deletions daft/table/table_io.py
@@ -152,7 +152,7 @@ def read_parquet(
if isinstance(config, NativeStorageConfig):
assert isinstance(
file, (str, pathlib.Path)
), "Native downloader only works on string inputs to read_parquet"
), "Native downloader only works on string or Path inputs to read_parquet"
tbl = MicroPartition.read_parquet(
str(file),
columns=read_options.column_names,
@@ -245,7 +245,9 @@ def read_csv(
if storage_config is not None:
config = storage_config.config
if isinstance(config, NativeStorageConfig):
assert isinstance(file, (str, pathlib.Path)), "Native downloader only works on string inputs to read_csv"
assert isinstance(
file, (str, pathlib.Path)
), "Native downloader only works on string or Path inputs to read_csv"
has_header = csv_options.header_index is not None
csv_convert_options = CsvConvertOptions(
limit=read_options.num_rows,
9 changes: 9 additions & 0 deletions src/daft-compression/Cargo.toml
@@ -0,0 +1,9 @@
[dependencies]
async-compression = {workspace = true}
tokio = {workspace = true}
url = {workspace = true}

[package]
edition = {workspace = true}
name = "daft-compression"
version = {workspace = true}
File renamed without changes.
4 changes: 4 additions & 0 deletions src/daft-compression/src/lib.rs
@@ -0,0 +1,4 @@
//! Utilities for async decompression of data.
pub mod compression;

pub use compression::CompressionCodec;
1 change: 1 addition & 0 deletions src/daft-csv/Cargo.toml
@@ -9,6 +9,7 @@ chrono = {workspace = true}
chrono-tz = {workspace = true}
common-error = {path = "../common/error", default-features = false}
csv-async = "1.2.6"
daft-compression = {path = "../daft-compression", default-features = false}
daft-core = {path = "../daft-core", default-features = false}
daft-decoding = {path = "../daft-decoding"}
daft-io = {path = "../daft-io", default-features = false}
3 changes: 2 additions & 1 deletion src/daft-csv/src/metadata.rs
@@ -15,7 +15,8 @@ use tokio::{
use tokio_util::io::StreamReader;

use crate::{schema::merge_schema, CsvParseOptions};
use daft_decoding::{compression::CompressionCodec, inference::infer};
use daft_compression::CompressionCodec;
use daft_decoding::inference::infer;

const DEFAULT_COLUMN_PREFIX: &str = "column_";

109 changes: 52 additions & 57 deletions src/daft-csv/src/read.rs
@@ -27,7 +27,8 @@ use tokio_util::io::StreamReader;

use crate::ArrowSnafu;
use crate::{metadata::read_csv_schema_single, CsvConvertOptions, CsvParseOptions, CsvReadOptions};
use daft_decoding::{compression::CompressionCodec, deserialize::deserialize_column};
use daft_compression::CompressionCodec;
use daft_decoding::deserialize::deserialize_column;

trait ByteRecordChunkStream = Stream<Item = super::Result<Vec<ByteRecord>>>;
trait ColumnArrayChunkStream = Stream<
@@ -81,64 +82,58 @@ pub fn read_csv_bulk(
) -> DaftResult<Vec<Table>> {
let runtime_handle = get_runtime(multithreaded_io)?;
let _rt_guard = runtime_handle.enter();
let tables = runtime_handle
.block_on(async move {
// Launch a read task per URI, throttling the number of concurrent file reads to num_parallel tasks.
let task_stream = futures::stream::iter(uris.iter().enumerate().map(|(i, uri)| {
let (uri, convert_options, parse_options, read_options, io_client, io_stats) = (
uri.to_string(),
convert_options.clone(),
parse_options.clone(),
read_options.clone(),
io_client.clone(),
io_stats.clone(),
);
tokio::task::spawn(async move {
let table = read_csv_single_into_table(
uri.as_str(),
convert_options,
parse_options,
read_options,
io_client,
io_stats,
max_chunks_in_flight,
)
.await?;
Ok((i, table))
})
}));
let mut remaining_rows = convert_options
.as_ref()
.and_then(|opts| opts.limit.map(|limit| limit as i64));
task_stream
// Each task is annotated with its position in the output, so we can use unordered buffering to help mitigate stragglers
// and sort the task results at the end.
.buffer_unordered(num_parallel_tasks)
// Terminate the stream if we have already reached the row limit. With the upstream buffering, we will still read up to
// num_parallel_tasks redundant files.
.try_take_while(|result| {
match (result, remaining_rows) {
// Limit has been met, early-terminate.
(_, Some(rows_left)) if rows_left <= 0 => futures::future::ready(Ok(false)),
// Limit has not yet been met, update remaining limit slack and continue.
(Ok((_, table)), Some(rows_left)) => {
remaining_rows = Some(rows_left - table.len() as i64);
futures::future::ready(Ok(true))
}
// (1) No limit, never early-terminate.
// (2) Encountered error, propagate error to try_collect to allow it to short-circuit.
(_, None) | (Err(_), _) => futures::future::ready(Ok(true)),
}
})
.try_collect::<Vec<_>>()
let tables = runtime_handle.block_on(async move {
// Launch a read task per URI, throttling the number of concurrent file reads to num_parallel tasks.
let task_stream = futures::stream::iter(uris.iter().map(|uri| {
let (uri, convert_options, parse_options, read_options, io_client, io_stats) = (
uri.to_string(),
convert_options.clone(),
parse_options.clone(),
read_options.clone(),
io_client.clone(),
io_stats.clone(),
);
tokio::task::spawn(async move {
read_csv_single_into_table(
uri.as_str(),
convert_options,
parse_options,
read_options,
io_client,
io_stats,
max_chunks_in_flight,
)
.await
})
.context(super::JoinSnafu {})?;
})
.context(super::JoinSnafu {})
}));
let mut remaining_rows = convert_options
.as_ref()
.and_then(|opts| opts.limit.map(|limit| limit as i64));
task_stream
// Limit the number of file reads we have in flight at any given time.
.buffered(num_parallel_tasks)
// Terminate the stream if we have already reached the row limit. With the upstream buffering, we will still read up to
// num_parallel_tasks redundant files.
.try_take_while(|result| {
match (result, remaining_rows) {
// Limit has been met, early-terminate.
(_, Some(rows_left)) if rows_left <= 0 => futures::future::ready(Ok(false)),
// Limit has not yet been met, update remaining limit slack and continue.
(Ok(table), Some(rows_left)) => {
remaining_rows = Some(rows_left - table.len() as i64);
futures::future::ready(Ok(true))
}
// (1) No limit, never early-terminate.
// (2) Encountered error, propagate error to try_collect to allow it to short-circuit.
(_, None) | (Err(_), _) => futures::future::ready(Ok(true)),
}
})
.try_collect::<Vec<_>>()
.await
})?;

// Sort the task results by task index, yielding tables whose order matches the input URI order.
let mut collected = tables.into_iter().collect::<DaftResult<Vec<_>>>()?;
collected.sort_by_key(|(idx, _)| *idx);
Ok(collected.into_iter().map(|(_, v)| v).collect())
tables.into_iter().collect::<DaftResult<Vec<_>>>()
}

async fn read_csv_single_into_table(
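The rewritten read_csv_bulk above drops the per-task index bookkeeping: because buffered (unlike buffer_unordered) yields results in the order the tasks were queued, the output tables already match the input URI order and the final sort step is no longer needed. A standalone sketch of that pattern, using hypothetical names and assuming only the tokio (rt-multi-thread + macros) and futures crates rather than Daft's actual API:

use futures::{StreamExt, TryStreamExt};

// Stand-in for read_csv_single_into_table: pretend each "read" returns a row count.
async fn read_one(uri: String) -> usize {
    uri.len()
}

async fn read_all(
    uris: Vec<String>,
    limit: Option<i64>,
    num_parallel_tasks: usize,
) -> Result<Vec<usize>, tokio::task::JoinError> {
    let task_stream =
        futures::stream::iter(uris.into_iter().map(|uri| tokio::task::spawn(read_one(uri))));
    let mut remaining_rows = limit;
    task_stream
        // Poll up to num_parallel_tasks reads at once; results come back in input order.
        .buffered(num_parallel_tasks)
        // Stop consuming further results once the row limit has been satisfied.
        .try_take_while(|n_rows| {
            let keep_going = match remaining_rows {
                Some(rows_left) if rows_left <= 0 => false,
                Some(rows_left) => {
                    remaining_rows = Some(rows_left - *n_rows as i64);
                    true
                }
                None => true,
            };
            futures::future::ready(Ok(keep_going))
        })
        .try_collect::<Vec<_>>()
        .await
}

#[tokio::main]
async fn main() {
    let uris = vec!["a.csv".into(), "bb.csv".into(), "ccc.csv".into()];
    // Output order matches the input URI order; no post-hoc sort is needed.
    println!("{:?}", read_all(uris, Some(3), 2).await.unwrap());
}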
7 changes: 4 additions & 3 deletions src/daft-decoding/src/deserialize.rs
@@ -12,6 +12,8 @@ use csv_async::ByteRecord;
pub(crate) const ISO8601: &str = "%+";
pub(crate) const ISO8601_NO_TIME_ZONE: &str = "%Y-%m-%dT%H:%M:%S%.f";
pub(crate) const ISO8601_NO_TIME_ZONE_NO_FRACTIONAL: &str = "%Y-%m-%dT%H:%M:%S";
pub(crate) const ISO8601_DATE: &str = "%Y-%m-%d";
pub(crate) const ISO8601_DATE_SLASHES: &str = "%Y/%m/%d";
pub(crate) const RFC3339_WITH_SPACE: &str = "%Y-%m-%d %H:%M:%S%.f%:z";
pub(crate) const RFC3339_WITH_SPACE_NO_TIME_ZONE: &str = "%Y-%m-%d %H:%M:%S%.f";
pub(crate) const RFC3339_WITH_SPACE_NO_TIME_ZONE_NO_FRACTIONAL: &str = "%Y-%m-%d %H:%M:%S";
@@ -20,11 +22,10 @@ pub(crate) const ALL_NAIVE_TIMESTAMP_FMTS: &[&str] = &[
ISO8601_NO_TIME_ZONE_NO_FRACTIONAL,
RFC3339_WITH_SPACE_NO_TIME_ZONE,
RFC3339_WITH_SPACE_NO_TIME_ZONE_NO_FRACTIONAL,
ISO8601_DATE,
ISO8601_DATE_SLASHES,
];
pub(crate) const ALL_TIMESTAMP_FMTS: &[&str] = &[ISO8601, RFC3339_WITH_SPACE];

pub(crate) const ISO8601_DATE: &str = "%Y-%m-%d";
pub(crate) const ISO8601_DATE_SLASHES: &str = "%Y/%m/%d";
pub(crate) const ALL_NAIVE_DATE_FMTS: &[&str] = &[ISO8601_DATE, ISO8601_DATE_SLASHES];

// Ideally this trait should not be needed and both `csv` and `csv_async` crates would share
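For reference, the two date patterns hoisted above ALL_NAIVE_TIMESTAMP_FMTS are ordinary chrono format strings; a tiny standalone check (assuming only the chrono crate, not Daft's deserialization path) of what ISO8601_DATE and ISO8601_DATE_SLASHES match:

use chrono::NaiveDate;

fn main() {
    // "%Y-%m-%d" (ISO8601_DATE) and "%Y/%m/%d" (ISO8601_DATE_SLASHES) parse the same calendar date.
    let dashed = NaiveDate::parse_from_str("2023-12-06", "%Y-%m-%d").unwrap();
    let slashed = NaiveDate::parse_from_str("2023/12/06", "%Y/%m/%d").unwrap();
    assert_eq!(dashed, slashed);
}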
1 change: 0 additions & 1 deletion src/daft-decoding/src/lib.rs
@@ -1,4 +1,3 @@
//! Utilities for decoding data from various sources into both array data and metadata (e.g. schema inference)
pub mod compression;
pub mod deserialize;
pub mod inference;
1 change: 1 addition & 0 deletions src/daft-json/Cargo.toml
@@ -8,6 +8,7 @@ bytes = {workspace = true}
chrono = {workspace = true}
chrono-tz = {workspace = true}
common-error = {path = "../common/error", default-features = false}
daft-compression = {path = "../daft-compression", default-features = false}
daft-core = {path = "../daft-core", default-features = false}
daft-decoding = {path = "../daft-decoding"}
daft-io = {path = "../daft-io", default-features = false}
52 changes: 23 additions & 29 deletions src/daft-json/src/inference.rs
@@ -1,6 +1,6 @@
use std::{borrow::Borrow, collections::HashSet};

use arrow2::datatypes::{DataType, Field, Metadata, Schema};
use arrow2::datatypes::{DataType, Field, Metadata, Schema, TimeUnit};
use arrow2::error::{Error, Result};
use indexmap::IndexMap;
use json_deserializer::{Number, Value};
@@ -70,7 +70,6 @@ fn infer_array(values: &[Value]) -> Result<DataType> {
.collect::<Result<HashSet<_>>>()?;

let dt = if !types.is_empty() {
let types = types.into_iter().collect::<Vec<_>>();
coerce_data_type(types)
} else {
DataType::Null
@@ -99,9 +98,8 @@ pub(crate) fn column_types_map_to_fields(
column_types
.into_iter()
.map(|(name, dtype_set)| {
let dtypes = dtype_set.into_iter().collect::<Vec<_>>();
// Get consolidated dtype for column.
let dtype = coerce_data_type(dtypes);
let dtype = coerce_data_type(dtype_set);
arrow2::datatypes::Field::new(name, dtype, true)
})
.collect::<Vec<_>>()
@@ -113,20 +111,16 @@ pub(crate) fn column_types_map_to_fields(
/// * Lists and scalars are coerced to a list of a compatible scalar
/// * Structs contain the union of all fields
/// * All other types are coerced to `Utf8`
pub(crate) fn coerce_data_type(datatypes: Vec<DataType>) -> DataType {
pub(crate) fn coerce_data_type(mut datatypes: HashSet<DataType>) -> DataType {
// Drop null dtype from the dtype set.
let datatypes = datatypes
.into_iter()
.filter(|dt| !matches!((*dt).borrow(), DataType::Null))
.collect::<Vec<_>>();
datatypes.remove(&DataType::Null);

if datatypes.is_empty() {
return DataType::Null;
}

let are_all_equal = datatypes.windows(2).all(|w| w[0] == w[1]);

if are_all_equal {
// All equal.
if datatypes.len() == 1 {
return datatypes.into_iter().next().unwrap();
}

@@ -162,10 +156,7 @@ pub(crate) fn coerce_data_type(datatypes: Vec<DataType>) -> DataType {
// Coerce dtype set for each field.
let fields = fields
.into_iter()
.map(|(name, dts)| {
let dts = dts.into_iter().collect::<Vec<_>>();
Field::new(name, coerce_data_type(dts), true)
})
.map(|(name, dts)| Field::new(name, coerce_data_type(dts), true))
.collect();
return DataType::Struct(fields);
}
@@ -177,11 +168,11 @@ pub(crate) fn coerce_data_type(datatypes: Vec<DataType>) -> DataType {
(DataType::Utf8, _) | (_, DataType::Utf8) => DataType::Utf8,
(DataType::List(lhs), DataType::List(rhs)) => {
let inner =
coerce_data_type(vec![lhs.data_type().clone(), rhs.data_type().clone()]);
coerce_data_type([lhs.data_type().clone(), rhs.data_type().clone()].into());
DataType::List(Box::new(Field::new(ITEM_NAME, inner, true)))
}
(scalar, DataType::List(list)) | (DataType::List(list), scalar) => {
let inner = coerce_data_type(vec![scalar, list.data_type().clone()]);
let inner = coerce_data_type([scalar, list.data_type().clone()].into());
DataType::List(Box::new(Field::new(ITEM_NAME, inner, true)))
}
(DataType::Float64, DataType::Int64) | (DataType::Int64, DataType::Float64) => {
@@ -191,9 +182,9 @@ pub(crate) fn coerce_data_type(datatypes: Vec<DataType>) -> DataType {
DataType::Int64
}
(DataType::Time32(left_tu), DataType::Time32(right_tu)) => {
// Set unified time unit to the highest granularity time unit.
// Set unified time unit to the lowest granularity time unit.
let unified_tu = if left_tu == right_tu
|| time_unit_to_ordinal(&left_tu) > time_unit_to_ordinal(&right_tu)
|| time_unit_to_ordinal(&left_tu) < time_unit_to_ordinal(&right_tu)
{
left_tu
} else {
@@ -205,31 +196,34 @@ pub(crate) fn coerce_data_type(datatypes: Vec<DataType>) -> DataType {
DataType::Timestamp(left_tu, left_tz),
DataType::Timestamp(right_tu, right_tz),
) => {
// Set unified time unit to the highest granularity time unit.
// Set unified time unit to the lowest granularity time unit.
let unified_tu = if left_tu == right_tu
|| time_unit_to_ordinal(&left_tu) > time_unit_to_ordinal(&right_tu)
|| time_unit_to_ordinal(&left_tu) < time_unit_to_ordinal(&right_tu)
{
left_tu
} else {
right_tu
};
// Set unified time zone to UTC.
let unified_tz = if left_tz == right_tz {
left_tz.clone()
} else {
Some("Z".to_string())
let unified_tz = match (&left_tz, &right_tz) {
(None, None) => None,
(None, _) | (_, None) => return DataType::Utf8,
(Some(l), Some(r)) if l == r => left_tz,
(Some(_), Some(_)) => Some("Z".to_string()),
};
DataType::Timestamp(unified_tu, unified_tz)
}
(DataType::Timestamp(_, None), DataType::Date32)
| (DataType::Date32, DataType::Timestamp(_, None)) => {
DataType::Timestamp(TimeUnit::Second, None)
}
(_, _) => DataType::Utf8,
}
})
.unwrap()
}

fn time_unit_to_ordinal(tu: &arrow2::datatypes::TimeUnit) -> usize {
use arrow2::datatypes::TimeUnit;

fn time_unit_to_ordinal(tu: &TimeUnit) -> usize {
match tu {
TimeUnit::Second => 0,
TimeUnit::Millisecond => 1,
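The new timestamp coercion rules are easiest to read off with concrete inputs. Below is a minimal sketch, not part of this commit, of a unit test that could sit alongside coerce_data_type in src/daft-json/src/inference.rs; it exercises only the behavior visible in the diff above: the coarsest time unit wins, differing non-null time zones collapse to "Z", and Date32 mixed with a naive timestamp widens to a naive second-resolution timestamp.

#[cfg(test)]
mod coerce_sketch_tests {
    use super::coerce_data_type;
    use arrow2::datatypes::{DataType, TimeUnit};

    #[test]
    fn unifies_timestamp_variants() {
        // Differing time units pick the coarsest unit; differing non-null time zones become "Z".
        assert_eq!(
            coerce_data_type(
                [
                    DataType::Timestamp(TimeUnit::Second, Some("+02:00".to_string())),
                    DataType::Timestamp(TimeUnit::Millisecond, Some("-07:00".to_string())),
                ]
                .into()
            ),
            DataType::Timestamp(TimeUnit::Second, Some("Z".to_string()))
        );
        // A date mixed with a naive timestamp widens to a naive second-resolution timestamp.
        assert_eq!(
            coerce_data_type(
                [DataType::Date32, DataType::Timestamp(TimeUnit::Second, None)].into()
            ),
            DataType::Timestamp(TimeUnit::Second, None)
        );
    }
}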