Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes #4349 #4350

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 62 additions & 22 deletions diesel/src/expression/array_comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,46 @@ impl<T, U> In<T, U> {
pub(crate) fn new(left: T, values: U) -> Self {
In { left, values }
}

pub(crate) fn walk_ansi_ast<'b, DB>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()>
where
DB: Backend,
T: QueryFragment<DB>,
U: QueryFragment<DB> + MaybeEmpty,
{
if self.values.is_empty() {
out.push_sql("1=0");
} else {
self.left.walk_ast(out.reborrow())?;
out.push_sql(" IN (");
self.values.walk_ast(out.reborrow())?;
out.push_sql(")");
}
Ok(())
}
}

impl<T, U> NotIn<T, U> {
pub(crate) fn new(left: T, values: U) -> Self {
NotIn { left, values }
}

pub(crate) fn walk_ansi_ast<'b, DB>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()>
where
DB: Backend,
T: QueryFragment<DB>,
U: QueryFragment<DB> + MaybeEmpty,
{
if self.values.is_empty() {
out.push_sql("1=1");
} else {
self.left.walk_ast(out.reborrow())?;
out.push_sql(" NOT IN (");
self.values.walk_ast(out.reborrow())?;
out.push_sql(")");
}
Ok(())
}
}

impl<T, U> Expression for In<T, U>
Expand Down Expand Up @@ -104,16 +138,8 @@ where
T: QueryFragment<DB>,
U: QueryFragment<DB> + MaybeEmpty,
{
fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()> {
if self.values.is_empty() {
out.push_sql("1=0");
} else {
self.left.walk_ast(out.reborrow())?;
out.push_sql(" IN (");
self.values.walk_ast(out.reborrow())?;
out.push_sql(")");
}
Ok(())
fn walk_ast<'b>(&'b self, out: AstPass<'_, 'b, DB>) -> QueryResult<()> {
self.walk_ansi_ast(out)
}
}

Expand All @@ -135,16 +161,8 @@ where
T: QueryFragment<DB>,
U: QueryFragment<DB> + MaybeEmpty,
{
fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()> {
if self.values.is_empty() {
out.push_sql("1=1");
} else {
self.left.walk_ast(out.reborrow())?;
out.push_sql(" NOT IN (");
self.values.walk_ast(out.reborrow())?;
out.push_sql(")");
}
Ok(())
fn walk_ast<'b>(&'b self, out: AstPass<'_, 'b, DB>) -> QueryResult<()> {
self.walk_ansi_ast(out)
}
}

Expand Down Expand Up @@ -201,6 +219,10 @@ pub trait MaybeEmpty {
/// Returns `true` if self represents an empty collection
/// Otherwise `false` is returned.
fn is_empty(&self) -> bool;

/// Returns `true` if the values clause represents
/// bind values and each bind value is a postgres array type
fn is_array(&self) -> bool;
}

impl<ST, F, S, D, W, O, LOf, G, H, LC> AsInExpression<ST>
Expand Down Expand Up @@ -278,10 +300,17 @@ where
type SqlType = ST;
}

impl<ST, I> MaybeEmpty for Many<ST, I> {
impl<ST, I> MaybeEmpty for Many<ST, I>
where
ST: SqlType,
{
fn is_empty(&self) -> bool {
self.values.is_empty()
}

fn is_array(&self) -> bool {
ST::IS_ARRAY
}
}

impl<ST, I, QS> SelectableExpression<QS> for Many<ST, I>
Expand Down Expand Up @@ -321,7 +350,18 @@ where
ST: SingleValue,
I: ToSql<ST, DB>,
{
fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()> {
fn walk_ast<'b>(&'b self, out: AstPass<'_, 'b, DB>) -> QueryResult<()> {
self.walk_ansi_ast(out)
}
}

impl<ST, I> Many<ST, I> {
pub(crate) fn walk_ansi_ast<'b, DB>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()>
where
DB: Backend + HasSqlType<ST>,
ST: SingleValue,
I: ToSql<ST, DB>,
{
out.unsafe_to_cache_prepared();
let mut first = true;
for value in &self.values {
Expand Down
3 changes: 3 additions & 0 deletions diesel/src/expression/subselect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ impl<T, ST> MaybeEmpty for Subselect<T, ST> {
fn is_empty(&self) -> bool {
false
}
fn is_array(&self) -> bool {
false
}
}

impl<T, ST, QS> SelectableExpression<QS> for Subselect<T, ST>
Expand Down
31 changes: 22 additions & 9 deletions diesel/src/pg/query_builder/query_fragment_impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,14 @@ where
U: QueryFragment<Pg> + MaybeEmpty,
{
fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> {
self.left.walk_ast(out.reborrow())?;
out.push_sql(" = ANY(");
self.values.walk_ast(out.reborrow())?;
out.push_sql(")");
if self.values.is_array() {
self.walk_ansi_ast(out)?;
} else {
self.left.walk_ast(out.reborrow())?;
out.push_sql(" = ANY(");
self.values.walk_ast(out.reborrow())?;
out.push_sql(")");
}
Ok(())
}
}
Expand All @@ -80,10 +84,14 @@ where
U: QueryFragment<Pg> + MaybeEmpty,
{
fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> {
self.left.walk_ast(out.reborrow())?;
out.push_sql(" != ALL(");
self.values.walk_ast(out.reborrow())?;
out.push_sql(")");
if self.values.is_array() {
self.walk_ansi_ast(out)?;
} else {
self.left.walk_ast(out.reborrow())?;
out.push_sql(" != ALL(");
self.values.walk_ast(out.reborrow())?;
out.push_sql(")");
}
Ok(())
}
}
Expand All @@ -92,10 +100,15 @@ impl<ST, I> QueryFragment<Pg, PgStyleArrayComparison> for Many<ST, I>
where
ST: SingleValue,
Vec<I>: ToSql<Array<ST>, Pg>,
I: ToSql<ST, Pg>,
Pg: HasSqlType<ST>,
{
fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> {
out.push_bind_param::<Array<ST>, Vec<I>>(&self.values)
if ST::IS_ARRAY {
self.walk_ansi_ast(out)
} else {
out.push_bind_param::<Array<ST>, Vec<I>>(&self.values)
Copy link
Member

@Ten0 Ten0 Nov 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this got me thinking (long-term changes thoughts): maybe QueryId should be generic on the backend as well, so that we can properly express that this variant could in fact use static query id. If we added the option for HAS_STATIC_QUERY_ID = false to the QueryId derive maybe it wouldn't be that verbose in general.
But maybe that's not worth. In any case that would be a separate change.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That seems like an reasonable idea for the long run.

Although I remember that the benchmarks have shown at some point that the cost of query construction for looking up statements in the prepared statement cache is only relevant for SQLite. It's not relevant for postgresql due to the network roundtrips.

}
}
}

Expand Down
3 changes: 3 additions & 0 deletions diesel/src/sql_types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,9 @@ pub trait SqlType: 'static {
///
/// ['is_nullable`]: is_nullable
type IsNull: OneIsNullable<is_nullable::IsNullable> + OneIsNullable<is_nullable::NotNull>;

#[doc(hidden)]
const IS_ARRAY: bool = false;
}

/// Is one value of `IsNull` nullable?
Expand Down
5 changes: 5 additions & 0 deletions diesel_derives/src/sql_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,23 @@ pub fn derive(item: DeriveInput) -> Result<TokenStream> {
let model = Model::from_item(&item, true, false)?;

let struct_name = &item.ident;
let generic_count = item.generics.params.len();
let (impl_generics, ty_generics, where_clause) = item.generics.split_for_impl();

let sqlite_tokens = sqlite_tokens(&item, &model);
let mysql_tokens = mysql_tokens(&item, &model);
let pg_tokens = pg_tokens(&item, &model);

let is_array = struct_name == "Array" && generic_count == 1;

Ok(wrap_in_dummy_mod(quote! {
impl #impl_generics diesel::sql_types::SqlType
for #struct_name #ty_generics
#where_clause
{
type IsNull = diesel::sql_types::is_nullable::NotNull;

const IS_ARRAY: bool = #is_array;
}

impl #impl_generics diesel::sql_types::SingleValue
Expand Down
30 changes: 30 additions & 0 deletions diesel_tests/tests/filter_operators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,36 @@ fn filter_by_in() {
);
}

#[test]
#[cfg(feature = "postgres")]
fn filter_array_by_in() {
use crate::schema::posts::dsl::*;

let connection: &mut PgConnection = &mut connection();
let tag_combinations_to_look_for: &[&[&str]] = &[&["foo"], &["foo", "bar"], &["baz"]];
let result: Vec<i32> = posts
.filter(tags.eq_any(tag_combinations_to_look_for))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we probably want to document this behavior in these places:

/// The postgres backend provided a specialized implementation
/// by using `left = ANY(values)` as optimized variant instead.

/// The postgres backend provided a specialized implementation
/// by using `left = ALL(values)` as optimized variant instead.

/// On PostgreSQL, this method automatically performs a `= ANY()`
/// query.

.select(id)
.load(connection)
.unwrap();
assert_eq!(result, &[] as &[i32]);
}

#[test]
#[cfg(feature = "postgres")]
fn filter_array_by_not_in() {
use crate::schema::posts::dsl::*;

let connection: &mut PgConnection = &mut connection();
let tag_combinations_to_look_for: &[&[&str]] = &[&["foo"], &["foo", "bar"], &["baz"]];
let result: Vec<i32> = posts
.filter(tags.ne_all(tag_combinations_to_look_for))
.select(id)
.load(connection)
.unwrap();
assert_eq!(result, &[] as &[i32]);
}

fn connection_with_3_users() -> TestConnection {
let mut connection = connection_with_sean_and_tess_in_users_table();
diesel::sql_query("INSERT INTO users (id, name) VALUES (3, 'Jim')")
Expand Down
Loading