Skip to content

Commit

Permalink
Add more derive_more::Display implementations to remove boilerplate
Browse files Browse the repository at this point in the history
  • Loading branch information
Raunak Bhagat committed Sep 5, 2024
1 parent c87ebb3 commit 3be5292
Show file tree
Hide file tree
Showing 11 changed files with 98 additions and 155 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/daft-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ common-hashable-float-wrapper = {path = "../common/hashable-float-wrapper"}
common-py-serde = {path = "../common/py-serde", default-features = false}
daft-minhash = {path = "../daft-minhash", default-features = false}
daft-sketch = {path = "../daft-sketch", default-features = false}
derive_more = {workspace = true}
fastrand = "2.1.0"
fnv = "1.0.7"
html-escape = {workspace = true}
Expand Down
12 changes: 3 additions & 9 deletions src/daft-core/src/count_mode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ use common_py_serde::impl_bincode_py_state_serialization;
#[cfg(feature = "python")]
use pyo3::{exceptions::PyValueError, prelude::*};
use serde::{Deserialize, Serialize};
use std::fmt::{Display, Formatter, Result};
use std::str::FromStr;

use derive_more::Display;

use common_error::{DaftError, DaftResult};

/// Supported count modes for Daft's count aggregation.
Expand All @@ -13,7 +14,7 @@ use common_error::{DaftError, DaftResult};
/// | Valid - Count only valid values.
/// | Null - Count only null values.
#[allow(clippy::upper_case_acronyms)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)]
#[derive(Clone, Copy, Debug, Display, PartialEq, Eq, Serialize, Deserialize, Hash)]
#[cfg_attr(feature = "python", pyclass(module = "daft.daft"))]
pub enum CountMode {
All = 1,
Expand Down Expand Up @@ -66,10 +67,3 @@ impl FromStr for CountMode {
}
}
}

impl Display for CountMode {
fn fmt(&self, f: &mut Formatter) -> Result {
// Leverage Debug trait implementation, which will already return the enum variant as a string.
write!(f, "{:?}", self)
}
}
116 changes: 65 additions & 51 deletions src/daft-core/src/datatypes/dtype.rs
Original file line number Diff line number Diff line change
@@ -1,49 +1,61 @@
use std::fmt::{Display, Formatter, Result};
use std::fmt::Write;

use arrow2::datatypes::DataType as ArrowType;
use derive_more::Display;

use crate::datatypes::{field::Field, image_mode::ImageMode, time_unit::TimeUnit};

use common_error::{DaftError, DaftResult};

use serde::{Deserialize, Serialize};

// pub type TimeZone = String;

#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)]
#[derive(Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub enum DataType {
// Start ArrowTypes
// ArrowTypes:
/// Null type
Null,

/// `true` and `false`.
Boolean,

/// An [`i8`]
Int8,

/// An [`i16`]
Int16,

/// An [`i32`]
Int32,

/// An [`i64`]
Int64,

/// An [`i128`]
Int128,

/// An [`u8`]
UInt8,

/// An [`u16`]
UInt16,

/// An [`u32`]
UInt32,

/// An [`u64`]
UInt64,
// A 16-bit float (variant currently disabled)
// Float16,

/// A [`f32`]
Float32,

/// A [`f64`]
Float64,

/// Fixed-precision decimal type.
/// TODO: allow negative scale once Arrow2 allows it: https://github.com/jorgecarleitao/arrow2/issues/1518
#[display("{_0}.{_1}")]
Decimal128(usize, usize),

/// A [`i64`] representing a timestamp measured in [`TimeUnit`] with an optional timezone.
///
/// Time is measured as a Unix epoch, counting the seconds from
Expand All @@ -58,47 +70,92 @@ pub enum DataType {
///
/// When the timezone is not specified, the timestamp is considered to have no timezone
/// and is represented _as is_
// Rendered as `Timestamp(<unit>, <timezone:?>)`, e.g. `Timestamp(Nanoseconds, None)`.
// The prefix must differ from the `Time[...]` rendering of the `Time` variant, and a
// missing timezone is shown as `None` rather than "UTC" because a timezone-less
// timestamp is documented as having no timezone at all.
#[display("Timestamp({_0}, {_1:?})")]
Timestamp(TimeUnit, Option<String>),

/// An [`i32`] representing the elapsed time since UNIX epoch (1970-01-01)
/// in days.
Date,

/// A 64-bit time representing the elapsed time since midnight in the unit of `TimeUnit`.
/// Only [`TimeUnit::Microseconds`] and [`TimeUnit::Nanoseconds`] are supported on this variant.
#[display("Time[{_0}]")]
Time(TimeUnit),

/// Measure of elapsed time. This elapsed time is a physical duration (i.e. 1s as defined in S.I.)
#[display("Duration[{_0}]")]
Duration(TimeUnit),

/// Opaque binary data of variable length whose offsets are represented as [`i64`].
Binary,

/// Opaque binary data of fixed size. Enum parameter specifies the number of bytes per value.
#[display("FixedSizeBinary[{_0}]")]
FixedSizeBinary(usize),

/// A variable-length UTF-8 encoded string whose offsets are represented as [`i64`].
Utf8,

/// A list of some logical data type with a fixed number of elements.
#[display("FixedSizeList[{_0}; {_1}]")]
FixedSizeList(Box<DataType>, usize),

/// A list of some logical data type whose offsets are represented as [`i64`].
#[display("List[{_0}]")]
List(Box<DataType>),

/// A nested [`DataType`] with a given number of [`Field`]s.
#[display("{}", format_struct(_0)?)]
Struct(Vec<Field>),

/// A nested [`DataType`] that is represented as List<entries: Struct<key: K, value: V>>.
#[display("Map[{_0}]")]
Map(Box<DataType>),

/// Extension type.
#[display("{_1}")]
Extension(String, Box<DataType>, Option<String>),
// Stop ArrowTypes

// Non-ArrowTypes:
/// A logical type for embeddings.
#[display("Embedding[{_0}; {_1}]")]
Embedding(Box<DataType>, usize),

/// A logical type for images with variable shapes.
#[display("Image[{}]", _0.map_or_else(|| "MIXED".to_string(), |mode| mode.to_string()))]
Image(Option<ImageMode>),

/// A logical type for images with the same size (height x width).
#[display("Image[{_0}; {_1} x {_2}]")]
FixedShapeImage(ImageMode, u32, u32),

/// A logical type for tensors with variable shapes.
#[display("Tensor[{_0}]")]
Tensor(Box<DataType>),

/// A logical type for tensors with the same shape.
#[display("FixedShapeTensor[{_0}; {_1:?}]")]
FixedShapeTensor(Box<DataType>, Vec<u64>),

#[cfg(feature = "python")]
Python,

Unknown,
}

/// Renders the fields of a [`DataType::Struct`] as `Struct[<field>, <field>, ...]`.
///
/// Each field is written through its own `Display` implementation. Placeholder
/// fields (empty name *and* null dtype) are omitted entirely, so they contribute
/// neither text nor a separator.
///
/// # Errors
///
/// Propagates any [`std::fmt::Error`] raised while formatting a field.
fn format_struct(fields: &[Field]) -> std::result::Result<String, std::fmt::Error> {
    let mut out = String::from("Struct[");

    // Filter first, then enumerate: the separator logic must count only the
    // fields that are actually emitted, otherwise a skipped field leaves a
    // dangling ", " behind (e.g. ", a: Int64" or "a: Int64, ").
    let visible = fields
        .iter()
        .filter(|field| !(field.name.is_empty() && field.dtype.is_null()));

    for (index, field) in visible.enumerate() {
        if index != 0 {
            out.push_str(", ");
        }
        write!(out, "{field}")?;
    }

    out.push(']');
    Ok(out)
}

#[derive(Serialize, Deserialize)]
struct DataTypePayload {
datatype: DataType,
Expand Down Expand Up @@ -590,46 +647,3 @@ impl From<&ImageMode> for DataType {
}
}
}

impl Display for DataType {
// `f` is a buffer, and this method must write the formatted string into it
fn fmt(&self, f: &mut Formatter) -> Result {
match self {
DataType::List(nested) => write!(f, "List[{}]", nested),
DataType::FixedSizeList(inner, size) => {
write!(f, "FixedSizeList[{}; {}]", inner, size)
}
DataType::Map(inner, ..) => {
write!(f, "Map[{}]", inner)
}
DataType::Struct(fields) => {
let fields: String = fields
.iter()
.filter_map(|f| {
if f.name.is_empty() && f.dtype == DataType::Null {
None
} else {
Some(format!("{}: {}", f.name, f.dtype))
}
})
.collect::<Vec<String>>()
.join(", ");
write!(f, "Struct[{fields}]")
}
DataType::Embedding(inner, size) => {
write!(f, "Embedding[{}; {}]", inner, size)
}
DataType::Image(mode) => {
write!(
f,
"Image[{}]",
mode.map_or("MIXED".to_string(), |m| m.to_string())
)
}
DataType::FixedShapeImage(mode, height, width) => {
write!(f, "Image[{}; {} x {}]", mode, height, width)
}
_ => write!(f, "{self:?}"),
}
}
}
20 changes: 5 additions & 15 deletions src/daft-core/src/datatypes/field.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
use std::fmt::{Display, Formatter, Result};
use std::hash::Hash;
use std::sync::Arc;

use arrow2::datatypes::Field as ArrowField;

use crate::datatypes::dtype::DataType;
use common_error::{DaftError, DaftResult};
use derive_more::Display;

use serde::{Deserialize, Serialize};

pub type Metadata = std::collections::BTreeMap<String, String>;

#[derive(Clone, Debug, Eq, Deserialize, Serialize)]
#[derive(Clone, Display, Debug, Eq, Deserialize, Serialize)]
#[display("{name}: {dtype}")]
pub struct Field {
pub name: String,
pub dtype: DataType,
Expand All @@ -20,7 +21,8 @@ pub struct Field {

pub type FieldRef = Arc<Field>;

#[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize, Hash)]
#[derive(Clone, Display, Debug, PartialEq, Eq, Deserialize, Serialize, Hash)]
#[display("{id}")]
pub struct FieldID {
pub id: Arc<str>,
}
Expand Down Expand Up @@ -62,12 +64,6 @@ impl FieldID {
}
}

impl Display for FieldID {
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
write!(f, "{}", self.id)
}
}

impl Field {
pub fn new<S: Into<String>>(name: S, dtype: DataType) -> Self {
let name: String = name.into();
Expand Down Expand Up @@ -151,9 +147,3 @@ impl From<&ArrowField> for Field {
}
}
}

impl Display for Field {
fn fmt(&self, f: &mut Formatter) -> Result {
write!(f, "{}#{}", self.name, self.dtype)
}
}
11 changes: 2 additions & 9 deletions src/daft-core/src/datatypes/image_format.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::fmt::{Display, Formatter, Result};
use derive_more::Display;
use std::str::FromStr;

#[cfg(feature = "python")]
Expand All @@ -9,7 +9,7 @@ use common_error::{DaftError, DaftResult};

/// Supported image formats for Daft's I/O layer.
#[allow(clippy::upper_case_acronyms)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)]
#[derive(Clone, Copy, Debug, Display, PartialEq, Eq, Serialize, Deserialize, Hash)]
#[cfg_attr(feature = "python", pyclass(module = "daft.daft"))]
pub enum ImageFormat {
PNG,
Expand Down Expand Up @@ -92,10 +92,3 @@ impl From<ImageFormat> for image::ImageFormat {
}
}
}

impl Display for ImageFormat {
fn fmt(&self, f: &mut Formatter) -> Result {
// Leverage Debug trait implementation, which will already return the enum variant as a string.
write!(f, "{:?}", self)
}
}
14 changes: 5 additions & 9 deletions src/daft-core/src/datatypes/image_mode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ use num_derive::FromPrimitive;
#[cfg(feature = "python")]
use pyo3::{exceptions::PyValueError, prelude::*};
use serde::{Deserialize, Serialize};
use std::fmt::{Display, Formatter, Result};
use std::str::FromStr;

use derive_more::Display;

use common_error::{DaftError, DaftResult};

/// Supported image modes for Daft's image type.
Expand All @@ -26,7 +27,9 @@ use common_error::{DaftError, DaftResult};
/// | RGB32F - 32-bit floating RGB
/// | RGBA32F - 32-bit floating RGB + alpha
#[allow(clippy::upper_case_acronyms)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, Hash, FromPrimitive)]
#[derive(
Clone, Copy, Debug, Display, PartialEq, Eq, Serialize, Deserialize, Hash, FromPrimitive,
)]
#[cfg_attr(feature = "python", pyclass(module = "daft.daft"))]
pub enum ImageMode {
L = 1,
Expand Down Expand Up @@ -194,10 +197,3 @@ impl FromStr for ImageMode {
}
}
}

impl Display for ImageMode {
fn fmt(&self, f: &mut Formatter) -> Result {
// Leverage Debug trait implementation, which will already return the enum variant as a string.
write!(f, "{:?}", self)
}
}
12 changes: 4 additions & 8 deletions src/daft-core/src/datatypes/time_unit.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use std::fmt::{Display, Formatter};
use derive_more::Display;

use arrow2::datatypes::TimeUnit as ArrowTimeUnit;

use serde::{Deserialize, Serialize};

#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)]
#[derive(
Copy, Clone, Debug, Display, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize,
)]
pub enum TimeUnit {
Nanoseconds,
Microseconds,
Expand All @@ -23,12 +25,6 @@ impl TimeUnit {
}
}
}
impl Display for TimeUnit {
fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
// Leverage Debug trait implementation, which will already return the enum variant as a string.
write!(f, "{:?}", self)
}
}

impl From<&ArrowTimeUnit> for TimeUnit {
fn from(tu: &ArrowTimeUnit) -> Self {
Expand Down
Loading

0 comments on commit 3be5292

Please sign in to comment.