diff --git a/diesel/src/expression/array_comparison.rs b/diesel/src/expression/array_comparison.rs index bade87beea01..342fe645083c 100644 --- a/diesel/src/expression/array_comparison.rs +++ b/diesel/src/expression/array_comparison.rs @@ -63,12 +63,46 @@ impl In { pub(crate) fn new(left: T, values: U) -> Self { In { left, values } } + + pub(crate) fn walk_ansi_ast<'b, DB>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()> + where + DB: Backend, + T: QueryFragment, + U: QueryFragment + MaybeEmpty, + { + if self.values.is_empty() { + out.push_sql("1=0"); + } else { + self.left.walk_ast(out.reborrow())?; + out.push_sql(" IN ("); + self.values.walk_ast(out.reborrow())?; + out.push_sql(")"); + } + Ok(()) + } } impl NotIn { pub(crate) fn new(left: T, values: U) -> Self { NotIn { left, values } } + + pub(crate) fn walk_ansi_ast<'b, DB>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()> + where + DB: Backend, + T: QueryFragment, + U: QueryFragment + MaybeEmpty, + { + if self.values.is_empty() { + out.push_sql("1=1"); + } else { + self.left.walk_ast(out.reborrow())?; + out.push_sql(" NOT IN ("); + self.values.walk_ast(out.reborrow())?; + out.push_sql(")"); + } + Ok(()) + } } impl Expression for In @@ -104,16 +138,8 @@ where T: QueryFragment, U: QueryFragment + MaybeEmpty, { - fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()> { - if self.values.is_empty() { - out.push_sql("1=0"); - } else { - self.left.walk_ast(out.reborrow())?; - out.push_sql(" IN ("); - self.values.walk_ast(out.reborrow())?; - out.push_sql(")"); - } - Ok(()) + fn walk_ast<'b>(&'b self, out: AstPass<'_, 'b, DB>) -> QueryResult<()> { + self.walk_ansi_ast(out) } } @@ -135,16 +161,8 @@ where T: QueryFragment, U: QueryFragment + MaybeEmpty, { - fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()> { - if self.values.is_empty() { - out.push_sql("1=1"); - } else { - self.left.walk_ast(out.reborrow())?; - out.push_sql(" NOT IN ("); - self.values.walk_ast(out.reborrow())?; - out.push_sql(")"); - } - Ok(()) + fn walk_ast<'b>(&'b self, out: AstPass<'_, 'b, DB>) -> QueryResult<()> { + self.walk_ansi_ast(out) } } @@ -201,6 +219,10 @@ pub trait MaybeEmpty { /// Returns `true` if self represents an empty collection /// Otherwise `false` is returned. fn is_empty(&self) -> bool; + + /// Returns `true` if the values clause represents + /// bind values and each bind value is a postgres array type + fn is_array(&self) -> bool; } impl AsInExpression @@ -278,10 +300,17 @@ where type SqlType = ST; } -impl MaybeEmpty for Many { +impl MaybeEmpty for Many +where + ST: SqlType, +{ fn is_empty(&self) -> bool { self.values.is_empty() } + + fn is_array(&self) -> bool { + ST::IS_ARRAY + } } impl SelectableExpression for Many @@ -321,7 +350,18 @@ where ST: SingleValue, I: ToSql, { - fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()> { + fn walk_ast<'b>(&'b self, out: AstPass<'_, 'b, DB>) -> QueryResult<()> { + self.walk_ansi_ast(out) + } +} + +impl Many { + pub(crate) fn walk_ansi_ast<'b, DB>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()> + where + DB: Backend + HasSqlType, + ST: SingleValue, + I: ToSql, + { out.unsafe_to_cache_prepared(); let mut first = true; for value in &self.values { diff --git a/diesel/src/expression/subselect.rs b/diesel/src/expression/subselect.rs index 9bff4424a9f2..872cdc3e0b9b 100644 --- a/diesel/src/expression/subselect.rs +++ b/diesel/src/expression/subselect.rs @@ -31,6 +31,9 @@ impl MaybeEmpty for Subselect { fn is_empty(&self) -> bool { false } + fn is_array(&self) -> bool { + false + } } impl SelectableExpression for Subselect diff --git a/diesel/src/pg/query_builder/query_fragment_impls.rs b/diesel/src/pg/query_builder/query_fragment_impls.rs index e8bc3cd74db9..95b73ad73088 100644 --- a/diesel/src/pg/query_builder/query_fragment_impls.rs +++ b/diesel/src/pg/query_builder/query_fragment_impls.rs @@ -66,10 +66,14 @@ where U: QueryFragment + MaybeEmpty, { fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> { - self.left.walk_ast(out.reborrow())?; - out.push_sql(" = ANY("); - self.values.walk_ast(out.reborrow())?; - out.push_sql(")"); + if self.values.is_array() { + self.walk_ansi_ast(out)?; + } else { + self.left.walk_ast(out.reborrow())?; + out.push_sql(" = ANY("); + self.values.walk_ast(out.reborrow())?; + out.push_sql(")"); + } Ok(()) } } @@ -80,10 +84,14 @@ where U: QueryFragment + MaybeEmpty, { fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> { - self.left.walk_ast(out.reborrow())?; - out.push_sql(" != ALL("); - self.values.walk_ast(out.reborrow())?; - out.push_sql(")"); + if self.values.is_array() { + self.walk_ansi_ast(out)?; + } else { + self.left.walk_ast(out.reborrow())?; + out.push_sql(" != ALL("); + self.values.walk_ast(out.reborrow())?; + out.push_sql(")"); + } Ok(()) } } @@ -92,10 +100,15 @@ impl QueryFragment for Many where ST: SingleValue, Vec: ToSql, Pg>, + I: ToSql, Pg: HasSqlType, { fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> { - out.push_bind_param::, Vec>(&self.values) + if ST::IS_ARRAY { + self.walk_ansi_ast(out) + } else { + out.push_bind_param::, Vec>(&self.values) + } } } diff --git a/diesel/src/sql_types/mod.rs b/diesel/src/sql_types/mod.rs index 1aece629c942..84e330adb75f 100644 --- a/diesel/src/sql_types/mod.rs +++ b/diesel/src/sql_types/mod.rs @@ -676,6 +676,9 @@ pub trait SqlType: 'static { /// /// ['is_nullable`]: is_nullable type IsNull: OneIsNullable + OneIsNullable; + + #[doc(hidden)] + const IS_ARRAY: bool = false; } /// Is one value of `IsNull` nullable? diff --git a/diesel_derives/src/sql_type.rs b/diesel_derives/src/sql_type.rs index 2e6f00d65464..799f703ef1a0 100644 --- a/diesel_derives/src/sql_type.rs +++ b/diesel_derives/src/sql_type.rs @@ -11,18 +11,23 @@ pub fn derive(item: DeriveInput) -> Result { let model = Model::from_item(&item, true, false)?; let struct_name = &item.ident; + let generic_count = item.generics.params.len(); let (impl_generics, ty_generics, where_clause) = item.generics.split_for_impl(); let sqlite_tokens = sqlite_tokens(&item, &model); let mysql_tokens = mysql_tokens(&item, &model); let pg_tokens = pg_tokens(&item, &model); + let is_array = struct_name == "Array" && generic_count == 1; + Ok(wrap_in_dummy_mod(quote! { impl #impl_generics diesel::sql_types::SqlType for #struct_name #ty_generics #where_clause { type IsNull = diesel::sql_types::is_nullable::NotNull; + + const IS_ARRAY: bool = #is_array; } impl #impl_generics diesel::sql_types::SingleValue diff --git a/diesel_tests/tests/filter_operators.rs b/diesel_tests/tests/filter_operators.rs index b4d42a34f31e..9f1e106a2bbd 100644 --- a/diesel_tests/tests/filter_operators.rs +++ b/diesel_tests/tests/filter_operators.rs @@ -279,6 +279,21 @@ fn filter_array_by_in() { assert_eq!(result, &[] as &[i32]); } +#[test] +#[cfg(feature = "postgres")] +fn filter_array_by_not_in() { + use crate::schema::posts::dsl::*; + + let connection: &mut PgConnection = &mut connection(); + let tag_combinations_to_look_for: &[&[&str]] = &[&["foo"], &["foo", "bar"], &["baz"]]; + let result: Vec = posts + .filter(tags.ne_all(tag_combinations_to_look_for)) + .select(id) + .load(connection) + .unwrap(); + assert_eq!(result, &[] as &[i32]); +} + fn connection_with_3_users() -> TestConnection { let mut connection = connection_with_sean_and_tess_in_users_table(); diesel::sql_query("INSERT INTO users (id, name) VALUES (3, 'Jim')")