Skip to content

Commit

Permalink
Merge pull request #4295 from chickadee-engineering/mpflanzer-fix-pg-…
Browse files Browse the repository at this point in the history
…empty-range

Fix deserialization of empty PostgreSQL ranges
  • Loading branch information
weiznich authored Oct 17, 2024
2 parents 0d171eb + d285e58 commit b170af7
Show file tree
Hide file tree
Showing 9 changed files with 154 additions and 14 deletions.
23 changes: 23 additions & 0 deletions diesel/src/deserialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -578,3 +578,26 @@ where
{
const FIELD_COUNT: usize = <ST as crate::util::TupleSize>::SIZE;
}

/// A helper trait for giving a type a useful default value.
///
/// This is needed for types that can be used as range to represent the empty range as
/// (Bound::Excluded(DEFAULT), Bound::Excluded(DEFAULT)).
/// When possible, implementations of this trait should fall back to using std::default::Default.
#[allow(dead_code)]
pub(crate) trait Defaultable {
/// Returns the "default value" for a type.
fn default_value() -> Self;
}

// We cannot have this impl because rustc
// then complains in third party crates that
// diesel may implement `Default`in the future.
// If we get negative trait impls at some point in time
// it should be possible to make this work.
//// Defaultable for types that has Default
//impl<T: Default> Defaultable for T {
// fn default_value() -> Self {
// T::default()
// }
//}
30 changes: 29 additions & 1 deletion diesel/src/pg/types/date_and_time/chrono.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ extern crate chrono;
use self::chrono::{DateTime, Duration, Local, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc};

use super::{PgDate, PgInterval, PgTime, PgTimestamp};
use crate::deserialize::{self, FromSql};
use crate::deserialize::{self, Defaultable, FromSql};
use crate::pg::{Pg, PgValue};
use crate::serialize::{self, Output, ToSql};
use crate::sql_types::{Date, Interval, Time, Timestamp, Timestamptz};
Expand Down Expand Up @@ -61,6 +61,13 @@ impl ToSql<Timestamptz, Pg> for NaiveDateTime {
}
}

#[cfg(all(feature = "chrono", feature = "postgres_backend"))]
impl Defaultable for NaiveDateTime {
fn default_value() -> Self {
Self::default()
}
}

#[cfg(all(feature = "chrono", feature = "postgres_backend"))]
impl FromSql<Timestamptz, Pg> for DateTime<Utc> {
fn from_sql(bytes: PgValue<'_>) -> deserialize::Result<Self> {
Expand All @@ -69,6 +76,13 @@ impl FromSql<Timestamptz, Pg> for DateTime<Utc> {
}
}

#[cfg(all(feature = "chrono", feature = "postgres_backend"))]
impl Defaultable for DateTime<Utc> {
fn default_value() -> Self {
Self::default()
}
}

#[cfg(all(feature = "chrono", feature = "postgres_backend"))]
impl FromSql<Timestamptz, Pg> for DateTime<Local> {
fn from_sql(bytes: PgValue<'_>) -> deserialize::Result<Self> {
Expand All @@ -77,6 +91,13 @@ impl FromSql<Timestamptz, Pg> for DateTime<Local> {
}
}

#[cfg(all(feature = "chrono", feature = "postgres_backend"))]
impl Defaultable for DateTime<Local> {
fn default_value() -> Self {
Self::default()
}
}

#[cfg(all(feature = "chrono", feature = "postgres_backend"))]
impl<TZ: TimeZone> ToSql<Timestamptz, Pg> for DateTime<TZ> {
fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Pg>) -> serialize::Result {
Expand Down Expand Up @@ -139,6 +160,13 @@ impl FromSql<Date, Pg> for NaiveDate {
}
}

#[cfg(all(feature = "chrono", feature = "postgres_backend"))]
impl Defaultable for NaiveDate {
fn default_value() -> Self {
Self::default()
}
}

const DAYS_PER_MONTH: i32 = 30;
const SECONDS_PER_DAY: i64 = 60 * 60 * 24;
const MICROSECONDS_PER_SECOND: i64 = 1_000_000;
Expand Down
23 changes: 22 additions & 1 deletion diesel/src/pg/types/date_and_time/time.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use self::time::{
};

use super::{PgDate, PgTime, PgTimestamp};
use crate::deserialize::{self, FromSql};
use crate::deserialize::{self, Defaultable, FromSql};
use crate::pg::{Pg, PgValue};
use crate::serialize::{self, Output, ToSql};
use crate::sql_types::{Date, Time, Timestamp, Timestamptz};
Expand Down Expand Up @@ -57,6 +57,13 @@ impl ToSql<Timestamptz, Pg> for PrimitiveDateTime {
}
}

#[cfg(all(feature = "time", feature = "postgres_backend"))]
impl Defaultable for PrimitiveDateTime {
fn default_value() -> Self {
PG_EPOCH
}
}

#[cfg(all(feature = "time", feature = "postgres_backend"))]
impl FromSql<Timestamptz, Pg> for OffsetDateTime {
fn from_sql(bytes: PgValue<'_>) -> deserialize::Result<Self> {
Expand All @@ -74,6 +81,13 @@ impl ToSql<Timestamptz, Pg> for OffsetDateTime {
}
}

#[cfg(all(feature = "time", feature = "postgres_backend"))]
impl Defaultable for OffsetDateTime {
fn default_value() -> Self {
datetime!(2000-01-01 0:00:00 UTC)
}
}

#[cfg(all(feature = "time", feature = "postgres_backend"))]
impl ToSql<Time, Pg> for NaiveTime {
fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Pg>) -> serialize::Result {
Expand Down Expand Up @@ -119,6 +133,13 @@ impl FromSql<Date, Pg> for NaiveDate {
}
}

#[cfg(all(feature = "time", feature = "postgres_backend"))]
impl Defaultable for NaiveDate {
fn default_value() -> Self {
PG_EPOCH_DATE
}
}

#[cfg(test)]
mod tests {
extern crate dotenvy;
Expand Down
12 changes: 10 additions & 2 deletions diesel/src/pg/types/floats/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::deserialize::{self, FromSql, FromSqlRow};
use crate::deserialize::{self, Defaultable, FromSql, FromSqlRow};
use crate::expression::AsExpression;
use crate::pg::{Pg, PgValue};
use crate::serialize::{self, IsNull, Output, ToSql};
Expand All @@ -9,7 +9,7 @@ use std::error::Error;
#[cfg(feature = "quickcheck")]
mod quickcheck_impls;

#[derive(Debug, Clone, PartialEq, Eq, AsExpression, FromSqlRow)]
#[derive(Debug, Default, Clone, PartialEq, Eq, AsExpression, FromSqlRow)]
#[diesel(sql_type = sql_types::Numeric)]
/// Represents a NUMERIC value, closely mirroring the PG wire protocol
/// representation
Expand All @@ -33,6 +33,7 @@ pub enum PgNumeric {
digits: Vec<i16>,
},
/// Not a number
#[default]
NaN,
}

Expand Down Expand Up @@ -113,6 +114,13 @@ impl ToSql<sql_types::Numeric, Pg> for PgNumeric {
}
}

#[cfg(feature = "postgres_backend")]
impl Defaultable for PgNumeric {
fn default_value() -> Self {
Self::default()
}
}

#[cfg(feature = "postgres_backend")]
impl FromSql<sql_types::Float, Pg> for f32 {
fn from_sql(value: PgValue<'_>) -> deserialize::Result<Self> {
Expand Down
16 changes: 15 additions & 1 deletion diesel/src/pg/types/integers.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::deserialize::{self, FromSql};
use crate::deserialize::{self, Defaultable, FromSql};
use crate::pg::{Pg, PgValue};
use crate::serialize::{self, IsNull, Output, ToSql};
use crate::sql_types;
Expand Down Expand Up @@ -126,6 +126,20 @@ impl ToSql<sql_types::BigInt, Pg> for i64 {
}
}

#[cfg(feature = "postgres_backend")]
impl Defaultable for i32 {
fn default_value() -> Self {
Self::default()
}
}

#[cfg(feature = "postgres_backend")]
impl Defaultable for i64 {
fn default_value() -> Self {
Self::default()
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
4 changes: 2 additions & 2 deletions diesel/src/pg/types/multirange.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt};
use std::io::Write;
use std::ops::Bound;

use crate::deserialize::{self, FromSql};
use crate::deserialize::{self, Defaultable, FromSql};
use crate::expression::bound::Bound as SqlBound;
use crate::expression::AsExpression;
use crate::pg::{Pg, PgTypeMetadata, PgValue};
Expand Down Expand Up @@ -68,7 +68,7 @@ multirange_as_expressions!(std::ops::RangeTo<T>);
#[cfg(feature = "postgres_backend")]
impl<T, ST> FromSql<Multirange<ST>, Pg> for Vec<(Bound<T>, Bound<T>)>
where
T: FromSql<ST, Pg>,
T: FromSql<ST, Pg> + Defaultable,
{
fn from_sql(value: PgValue<'_>) -> deserialize::Result<Self> {
let mut bytes = value.as_bytes();
Expand Down
9 changes: 8 additions & 1 deletion diesel/src/pg/types/numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ mod bigdecimal {
use self::num_integer::Integer;
use self::num_traits::{Signed, ToPrimitive, Zero};

use crate::deserialize::{self, FromSql};
use crate::deserialize::{self, Defaultable, FromSql};
use crate::pg::data_types::PgNumeric;
use crate::pg::{Pg, PgValue};
use crate::serialize::{self, Output, ToSql};
Expand Down Expand Up @@ -169,6 +169,13 @@ mod bigdecimal {
}
}

#[cfg(all(feature = "postgres_backend", feature = "numeric"))]
impl Defaultable for BigDecimal {
fn default_value() -> Self {
Self::default()
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
14 changes: 9 additions & 5 deletions diesel/src/pg/types/ranges.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::collections::Bound;
use std::error::Error;
use std::io::Write;

use crate::deserialize::{self, FromSql, Queryable};
use crate::deserialize::{self, Defaultable, FromSql, Queryable};
use crate::expression::bound::Bound as SqlBound;
use crate::expression::AsExpression;
use crate::pg::{Pg, PgTypeMetadata, PgValue};
Expand Down Expand Up @@ -74,15 +74,17 @@ range_as_expression!(&'a std::ops::RangeTo<T>; Nullable<Range<ST>>);
#[cfg(feature = "postgres_backend")]
impl<T, ST> FromSql<Range<ST>, Pg> for (Bound<T>, Bound<T>)
where
T: FromSql<ST, Pg>,
T: FromSql<ST, Pg> + Defaultable,
{
fn from_sql(value: PgValue<'_>) -> deserialize::Result<Self> {
let mut bytes = value.as_bytes();
let flags: RangeFlags = RangeFlags::from_bits_truncate(bytes.read_u8()?);
let mut lower_bound = Bound::Unbounded;
let mut upper_bound = Bound::Unbounded;

if !flags.contains(RangeFlags::LB_INF) {
if flags.contains(RangeFlags::EMPTY) {
lower_bound = Bound::Excluded(T::default_value());
} else if !flags.contains(RangeFlags::LB_INF) {
let elem_size = bytes.read_i32::<NetworkEndian>()?;
let (elem_bytes, new_bytes) = bytes.split_at(elem_size.try_into()?);
bytes = new_bytes;
Expand All @@ -95,7 +97,9 @@ where
};
}

if !flags.contains(RangeFlags::UB_INF) {
if flags.contains(RangeFlags::EMPTY) {
upper_bound = Bound::Excluded(T::default_value());
} else if !flags.contains(RangeFlags::UB_INF) {
let _size = bytes.read_i32::<NetworkEndian>()?;
let value = T::from_sql(PgValue::new_internal(bytes, &value))?;

Expand All @@ -113,7 +117,7 @@ where
#[cfg(feature = "postgres_backend")]
impl<T, ST> Queryable<Range<ST>, Pg> for (Bound<T>, Bound<T>)
where
T: FromSql<ST, Pg>,
T: FromSql<ST, Pg> + Defaultable,
{
type Row = Self;

Expand Down
37 changes: 36 additions & 1 deletion diesel_tests/tests/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1422,10 +1422,45 @@ fn test_range_from_sql() {
query_single_value::<Range<Int4>, (Bound<i32>, Bound<i32>)>(query)
);

let query = "SELECT '(1,1]'::int4range";
let query = "SELECT '[2,1)'::int4range";
assert!(sql::<Range<Int4>>(query)
.load::<(Bound<i32>, Bound<i32>)>(connection)
.is_err());

let query = "'empty'::int4range";
let expected_value = (Bound::Excluded(0), Bound::Excluded(0));
assert_eq!(
expected_value,
query_single_value::<Range<Int4>, (Bound<i32>, Bound<i32>)>(query)
);

let query = "'(1,1)'::int4range";
let expected_value = (Bound::Excluded(0), Bound::Excluded(0));
assert_eq!(
expected_value,
query_single_value::<Range<Int4>, (Bound<i32>, Bound<i32>)>(query)
);

let query = "'(1,1]'::int4range";
let expected_value = (Bound::Excluded(0), Bound::Excluded(0));
assert_eq!(
expected_value,
query_single_value::<Range<Int4>, (Bound<i32>, Bound<i32>)>(query)
);

let query = "'[1,1)'::int4range";
let expected_value = (Bound::Excluded(0), Bound::Excluded(0));
assert_eq!(
expected_value,
query_single_value::<Range<Int4>, (Bound<i32>, Bound<i32>)>(query)
);

let query = "'[1,1]'::int4range";
let expected_value = (Bound::Included(1), Bound::Excluded(2));
assert_eq!(
expected_value,
query_single_value::<Range<Int4>, (Bound<i32>, Bound<i32>)>(query)
);
}

#[cfg(feature = "postgres")]
Expand Down

0 comments on commit b170af7

Please sign in to comment.