Skip to content

Commit

Permalink
Fix #4349
Browse files Browse the repository at this point in the history
This commit fixes an issue that allowed to trigger a runtime error by
passing an array of arrays to `.eq_any()`. We rewrite queries containing
`IN` expressions to `= ANY()` on postgresql as that more
efficient (allows caching + allows binding all values at once). That's
not possible for arrays of arrays as we do not support nested arrays
yet.

The fix introduces an associated constant for the `SqlType` trait that
tracks if a SQL type is a `Array<T>` or not. This information is then
used to conditionally generate the "right" SQL. By defaulting to `false`
while adding this constant we do not break existing code. We use the
derive to set the right value based on the type name.
  • Loading branch information
weiznich committed Nov 19, 2024
1 parent d6c93fd commit 84f6918
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 31 deletions.
84 changes: 62 additions & 22 deletions diesel/src/expression/array_comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,46 @@ impl<T, U> In<T, U> {
pub(crate) fn new(left: T, values: U) -> Self {
In { left, values }
}

pub(crate) fn walk_ansi_ast<'b, DB>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()>
where
DB: Backend,
T: QueryFragment<DB>,
U: QueryFragment<DB> + MaybeEmpty,
{
if self.values.is_empty() {
out.push_sql("1=0");
} else {
self.left.walk_ast(out.reborrow())?;
out.push_sql(" IN (");
self.values.walk_ast(out.reborrow())?;
out.push_sql(")");
}
Ok(())
}
}

impl<T, U> NotIn<T, U> {
pub(crate) fn new(left: T, values: U) -> Self {
NotIn { left, values }
}

pub(crate) fn walk_ansi_ast<'b, DB>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()>
where
DB: Backend,
T: QueryFragment<DB>,
U: QueryFragment<DB> + MaybeEmpty,
{
if self.values.is_empty() {
out.push_sql("1=1");
} else {
self.left.walk_ast(out.reborrow())?;
out.push_sql(" NOT IN (");
self.values.walk_ast(out.reborrow())?;
out.push_sql(")");
}
Ok(())
}
}

impl<T, U> Expression for In<T, U>
Expand Down Expand Up @@ -104,16 +138,8 @@ where
T: QueryFragment<DB>,
U: QueryFragment<DB> + MaybeEmpty,
{
fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()> {
if self.values.is_empty() {
out.push_sql("1=0");
} else {
self.left.walk_ast(out.reborrow())?;
out.push_sql(" IN (");
self.values.walk_ast(out.reborrow())?;
out.push_sql(")");
}
Ok(())
fn walk_ast<'b>(&'b self, out: AstPass<'_, 'b, DB>) -> QueryResult<()> {
self.walk_ansi_ast(out)
}
}

Expand All @@ -135,16 +161,8 @@ where
T: QueryFragment<DB>,
U: QueryFragment<DB> + MaybeEmpty,
{
fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()> {
if self.values.is_empty() {
out.push_sql("1=1");
} else {
self.left.walk_ast(out.reborrow())?;
out.push_sql(" NOT IN (");
self.values.walk_ast(out.reborrow())?;
out.push_sql(")");
}
Ok(())
fn walk_ast<'b>(&'b self, out: AstPass<'_, 'b, DB>) -> QueryResult<()> {
self.walk_ansi_ast(out)
}
}

Expand Down Expand Up @@ -201,6 +219,10 @@ pub trait MaybeEmpty {
/// Returns `true` if self represents an empty collection
/// Otherwise `false` is returned.
fn is_empty(&self) -> bool;

/// Returns `true` if the values clause represents
/// bind values and each bind value is a postgres array type
fn is_array(&self) -> bool;
}

impl<ST, F, S, D, W, O, LOf, G, H, LC> AsInExpression<ST>
Expand Down Expand Up @@ -278,10 +300,17 @@ where
type SqlType = ST;
}

impl<ST, I> MaybeEmpty for Many<ST, I> {
impl<ST, I> MaybeEmpty for Many<ST, I>
where
ST: SqlType,
{
fn is_empty(&self) -> bool {
self.values.is_empty()
}

fn is_array(&self) -> bool {
ST::IS_ARRAY
}
}

impl<ST, I, QS> SelectableExpression<QS> for Many<ST, I>
Expand Down Expand Up @@ -321,7 +350,18 @@ where
ST: SingleValue,
I: ToSql<ST, DB>,
{
fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()> {
fn walk_ast<'b>(&'b self, out: AstPass<'_, 'b, DB>) -> QueryResult<()> {
self.walk_ansi_ast(out)
}
}

impl<ST, I> Many<ST, I> {
pub(crate) fn walk_ansi_ast<'b, DB>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()>
where
DB: Backend + HasSqlType<ST>,
ST: SingleValue,
I: ToSql<ST, DB>,
{
out.unsafe_to_cache_prepared();
let mut first = true;
for value in &self.values {
Expand Down
3 changes: 3 additions & 0 deletions diesel/src/expression/subselect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ impl<T, ST> MaybeEmpty for Subselect<T, ST> {
fn is_empty(&self) -> bool {
false
}
fn is_array(&self) -> bool {
false
}
}

impl<T, ST, QS> SelectableExpression<QS> for Subselect<T, ST>
Expand Down
31 changes: 22 additions & 9 deletions diesel/src/pg/query_builder/query_fragment_impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,14 @@ where
U: QueryFragment<Pg> + MaybeEmpty,
{
fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> {
self.left.walk_ast(out.reborrow())?;
out.push_sql(" = ANY(");
self.values.walk_ast(out.reborrow())?;
out.push_sql(")");
if self.values.is_array() {
self.walk_ansi_ast(out)?;
} else {
self.left.walk_ast(out.reborrow())?;
out.push_sql(" = ANY(");
self.values.walk_ast(out.reborrow())?;
out.push_sql(")");
}
Ok(())
}
}
Expand All @@ -80,10 +84,14 @@ where
U: QueryFragment<Pg> + MaybeEmpty,
{
fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> {
self.left.walk_ast(out.reborrow())?;
out.push_sql(" != ALL(");
self.values.walk_ast(out.reborrow())?;
out.push_sql(")");
if self.values.is_array() {
self.walk_ansi_ast(out)?;
} else {
self.left.walk_ast(out.reborrow())?;
out.push_sql(" != ALL(");
self.values.walk_ast(out.reborrow())?;
out.push_sql(")");
}
Ok(())
}
}
Expand All @@ -92,10 +100,15 @@ impl<ST, I> QueryFragment<Pg, PgStyleArrayComparison> for Many<ST, I>
where
ST: SingleValue,
Vec<I>: ToSql<Array<ST>, Pg>,
I: ToSql<ST, Pg>,
Pg: HasSqlType<ST>,
{
fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> {
out.push_bind_param::<Array<ST>, Vec<I>>(&self.values)
if ST::IS_ARRAY {
self.walk_ansi_ast(out)
} else {
out.push_bind_param::<Array<ST>, Vec<I>>(&self.values)
}
}
}

Expand Down
3 changes: 3 additions & 0 deletions diesel/src/sql_types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,9 @@ pub trait SqlType: 'static {
///
/// ['is_nullable`]: is_nullable
type IsNull: OneIsNullable<is_nullable::IsNullable> + OneIsNullable<is_nullable::NotNull>;

#[doc(hidden)]
const IS_ARRAY: bool = false;
}

/// Is one value of `IsNull` nullable?
Expand Down
5 changes: 5 additions & 0 deletions diesel_derives/src/sql_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,23 @@ pub fn derive(item: DeriveInput) -> Result<TokenStream> {
let model = Model::from_item(&item, true, false)?;

let struct_name = &item.ident;
let generic_count = item.generics.params.len();
let (impl_generics, ty_generics, where_clause) = item.generics.split_for_impl();

let sqlite_tokens = sqlite_tokens(&item, &model);
let mysql_tokens = mysql_tokens(&item, &model);
let pg_tokens = pg_tokens(&item, &model);

let is_array = struct_name == "Array" && generic_count == 1;

Ok(wrap_in_dummy_mod(quote! {
impl #impl_generics diesel::sql_types::SqlType
for #struct_name #ty_generics
#where_clause
{
type IsNull = diesel::sql_types::is_nullable::NotNull;

const IS_ARRAY: bool = #is_array;
}

impl #impl_generics diesel::sql_types::SingleValue
Expand Down
15 changes: 15 additions & 0 deletions diesel_tests/tests/filter_operators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,21 @@ fn filter_array_by_in() {
assert_eq!(result, &[] as &[i32]);
}

#[test]
#[cfg(feature = "postgres")]
fn filter_array_by_not_in() {
use crate::schema::posts::dsl::*;

let connection: &mut PgConnection = &mut connection();
let tag_combinations_to_look_for: &[&[&str]] = &[&["foo"], &["foo", "bar"], &["baz"]];
let result: Vec<i32> = posts
.filter(tags.ne_all(tag_combinations_to_look_for))
.select(id)
.load(connection)
.unwrap();
assert_eq!(result, &[] as &[i32]);
}

fn connection_with_3_users() -> TestConnection {
let mut connection = connection_with_sean_and_tess_in_users_table();
diesel::sql_query("INSERT INTO users (id, name) VALUES (3, 'Jim')")
Expand Down

0 comments on commit 84f6918

Please sign in to comment.