From e320fa0738e3f4582b0d61db85a00d7a70b676eb Mon Sep 17 00:00:00 2001 From: Giacomo Stevanato Date: Tue, 27 Aug 2024 02:58:40 +0200 Subject: [PATCH] Fix query transmute from table to archetype iteration unsoundness (#14615) # Objective - Fixes #14348 - Fixes #14528 - Less complex (but also likely less performant) alternative to #14611 ## Solution - Add a `is_dense` field flag to `QueryIter` indicating whether it is dense or not, that is whether it can perform dense iteration or not; - Check this flag any time iteration over a query is performed. --- It would be nice if someone could try benching this change to see if it actually matters. ~Note that this not 100% ready for mergin, since there are a bunch of safety comments on the use of the various `IS_DENSE` for checks that still need to be updated.~ This is ready modulo benchmarks --------- Co-authored-by: Alice Cecile --- crates/bevy_ecs/src/query/builder.rs | 44 +++++++++++++++ crates/bevy_ecs/src/query/iter.rs | 64 +++++++++++++--------- crates/bevy_ecs/src/query/par_iter.rs | 2 +- crates/bevy_ecs/src/query/state.rs | 79 ++++++++++++++++++++++++--- 4 files changed, 154 insertions(+), 35 deletions(-) diff --git a/crates/bevy_ecs/src/query/builder.rs b/crates/bevy_ecs/src/query/builder.rs index 3076b5d54e090..ed4e2cd4ba71d 100644 --- a/crates/bevy_ecs/src/query/builder.rs +++ b/crates/bevy_ecs/src/query/builder.rs @@ -1,5 +1,6 @@ use std::marker::PhantomData; +use crate::component::StorageType; use crate::{component::ComponentId, prelude::*}; use super::{FilteredAccess, QueryData, QueryFilter}; @@ -68,6 +69,26 @@ impl<'w, D: QueryData, F: QueryFilter> QueryBuilder<'w, D, F> { } } + pub(super) fn is_dense(&self) -> bool { + // Note: `component_id` comes from the user in safe code, so we cannot trust it to + // exist. If it doesn't exist we pessimistically assume it's sparse. 
+ let is_dense = |component_id| { + self.world() + .components() + .get_info(component_id) + .map_or(false, |info| info.storage_type() == StorageType::Table) + }; + + self.access + .access() + .component_reads_and_writes() + .all(is_dense) + && self.access.access().archetypal().all(is_dense) + && !self.access.access().has_read_all_components() + && self.access.with_filters().all(is_dense) + && self.access.without_filters().all(is_dense) + } + /// Returns a reference to the world passed to [`Self::new`]. pub fn world(&self) -> &World { self.world @@ -396,4 +417,27 @@ mod tests { assert_eq!(1, b.deref::().0); } } + + /// Regression test for issue #14348 + #[test] + fn builder_static_dense_dynamic_sparse() { + #[derive(Component)] + struct Dense; + + #[derive(Component)] + #[component(storage = "SparseSet")] + struct Sparse; + + let mut world = World::new(); + + world.spawn(Dense); + world.spawn((Dense, Sparse)); + + let mut query = QueryBuilder::<&Dense>::new(&mut world) + .with::() + .build(); + + let matched = query.iter(&world).count(); + assert_eq!(matched, 1); + } } diff --git a/crates/bevy_ecs/src/query/iter.rs b/crates/bevy_ecs/src/query/iter.rs index f8564025aa15c..5c1912b7bcc18 100644 --- a/crates/bevy_ecs/src/query/iter.rs +++ b/crates/bevy_ecs/src/query/iter.rs @@ -128,7 +128,7 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryIter<'w, 's, D, F> { /// # Safety /// - all `rows` must be in `[0, table.entity_count)`. /// - `table` must match D and F - /// - Both `D::IS_DENSE` and `F::IS_DENSE` must be true. + /// - The query iteration must be dense (i.e. `self.query_state.is_dense` must be true). #[inline] pub(super) unsafe fn fold_over_table_range( &mut self, @@ -183,7 +183,7 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryIter<'w, 's, D, F> { /// # Safety /// - all `indices` must be in `[0, archetype.len())`. /// - `archetype` must match D and F - /// - Either `D::IS_DENSE` or `F::IS_DENSE` must be false. 
+ /// - The query iteration must not be dense (i.e. `self.query_state.is_dense` must be false). #[inline] pub(super) unsafe fn fold_over_archetype_range( &mut self, @@ -252,7 +252,7 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryIter<'w, 's, D, F> { /// - all `indices` must be in `[0, archetype.len())`. /// - `archetype` must match D and F /// - `archetype` must have the same length with it's table. - /// - Either `D::IS_DENSE` or `F::IS_DENSE` must be false. + /// - The query iteration must not be dense (i.e. `self.query_state.is_dense` must be false). #[inline] pub(super) unsafe fn fold_over_dense_archetype_range( &mut self, @@ -1031,20 +1031,27 @@ impl<'w, 's, D: QueryData, F: QueryFilter> Iterator for QueryIter<'w, 's, D, F> let Some(item) = self.next() else { break }; accum = func(accum, item); } - for id in self.cursor.storage_id_iter.clone() { - if D::IS_DENSE && F::IS_DENSE { + + if self.cursor.is_dense { + for id in self.cursor.storage_id_iter.clone() { + // SAFETY: `self.cursor.is_dense` is true, so storage ids are guaranteed to be table ids. + let table_id = unsafe { id.table_id }; // SAFETY: Matched table IDs are guaranteed to still exist. - let table = unsafe { self.tables.get(id.table_id).debug_checked_unwrap() }; + let table = unsafe { self.tables.get(table_id).debug_checked_unwrap() }; + accum = // SAFETY: // - The fetched table matches both D and F // - The provided range is equivalent to [0, table.entity_count) - // - The if block ensures that D::IS_DENSE and F::IS_DENSE are both true + // - The if block ensures that the query iteration is dense unsafe { self.fold_over_table_range(accum, &mut func, table, 0..table.entity_count()) }; - } else { - let archetype = - // SAFETY: Matched archetype IDs are guaranteed to still exist. 
- unsafe { self.archetypes.get(id.archetype_id).debug_checked_unwrap() }; + } + } else { + for id in self.cursor.storage_id_iter.clone() { + // SAFETY: `self.cursor.is_dense` is false, so storage ids are guaranteed to be archetype ids. + let archetype_id = unsafe { id.archetype_id }; + // SAFETY: Matched archetype IDs are guaranteed to still exist. + let archetype = unsafe { self.archetypes.get(archetype_id).debug_checked_unwrap() }; // SAFETY: Matched table IDs are guaranteed to still exist. let table = unsafe { self.tables.get(archetype.table_id()).debug_checked_unwrap() }; @@ -1052,19 +1059,19 @@ impl<'w, 's, D: QueryData, F: QueryFilter> Iterator for QueryIter<'w, 's, D, F> // this leverages cache locality to optimize performance. if table.entity_count() == archetype.len() { accum = - // SAFETY: - // - The fetched archetype matches both D and F - // - The provided archetype and its' table have the same length. - // - The provided range is equivalent to [0, archetype.len) - // - The if block ensures that ether D::IS_DENSE or F::IS_DENSE are false - unsafe { self.fold_over_dense_archetype_range(accum, &mut func, archetype,0..archetype.len()) }; + // SAFETY: + // - The fetched archetype matches both D and F + // - The provided archetype and its table have the same length. + // - The provided range is equivalent to [0, archetype.len) + // - The if block ensures that the query iteration is not dense. 
+ unsafe { self.fold_over_dense_archetype_range(accum, &mut func, archetype, 0..archetype.len()) }; } else { accum = - // SAFETY: - // - The fetched archetype matches both D and F - // - The provided range is equivalent to [0, archetype.len) - // - The if block ensures that ether D::IS_DENSE or F::IS_DENSE are false - unsafe { self.fold_over_archetype_range(accum, &mut func, archetype,0..archetype.len()) }; + // SAFETY: + // - The fetched archetype matches both D and F + // - The provided range is equivalent to [0, archetype.len) + // - The if block ensures that the query iteration is not dense. + unsafe { self.fold_over_archetype_range(accum, &mut func, archetype, 0..archetype.len()) }; } } } @@ -1675,6 +1682,8 @@ impl<'w, 's, D: QueryData, F: QueryFilter, const K: usize> Debug } struct QueryIterationCursor<'w, 's, D: QueryData, F: QueryFilter> { + // whether the query iteration is dense or not. Mirrors QueryState's `is_dense` field. + is_dense: bool, storage_id_iter: std::slice::Iter<'s, StorageId>, table_entities: &'w [Entity], archetype_entities: &'w [ArchetypeEntity], @@ -1689,6 +1698,7 @@ struct QueryIterationCursor<'w, 's, D: QueryData, F: QueryFilter> { impl Clone for QueryIterationCursor<'_, '_, D, F> { fn clone(&self) -> Self { Self { + is_dense: self.is_dense, storage_id_iter: self.storage_id_iter.clone(), table_entities: self.table_entities, archetype_entities: self.archetype_entities, @@ -1701,8 +1711,6 @@ impl Clone for QueryIterationCursor<'_, '_, D, F> } impl<'w, 's, D: QueryData, F: QueryFilter> QueryIterationCursor<'w, 's, D, F> { - const IS_DENSE: bool = D::IS_DENSE && F::IS_DENSE; - unsafe fn init_empty( world: UnsafeWorldCell<'w>, query_state: &'s QueryState, @@ -1732,6 +1740,7 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryIterationCursor<'w, 's, D, F> { table_entities: &[], archetype_entities: &[], storage_id_iter: query_state.matched_storage_ids.iter(), + is_dense: query_state.is_dense, current_len: 0, current_row: 0, } @@ -1739,6 
+1748,7 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryIterationCursor<'w, 's, D, F> { fn reborrow(&mut self) -> QueryIterationCursor<'_, 's, D, F> { QueryIterationCursor { + is_dense: self.is_dense, fetch: D::shrink_fetch(self.fetch.clone()), filter: F::shrink_fetch(self.filter.clone()), table_entities: self.table_entities, @@ -1754,7 +1764,7 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryIterationCursor<'w, 's, D, F> { unsafe fn peek_last(&mut self) -> Option> { if self.current_row > 0 { let index = self.current_row - 1; - if Self::IS_DENSE { + if self.is_dense { let entity = self.table_entities.get_unchecked(index); Some(D::fetch( &mut self.fetch, @@ -1780,7 +1790,7 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryIterationCursor<'w, 's, D, F> { /// will be **the exact count of remaining values**. fn max_remaining(&self, tables: &'w Tables, archetypes: &'w Archetypes) -> usize { let ids = self.storage_id_iter.clone(); - let remaining_matched: usize = if Self::IS_DENSE { + let remaining_matched: usize = if self.is_dense { // SAFETY: The if check ensures that storage_id_iter stores TableIds unsafe { ids.map(|id| tables[id.table_id].entity_count()).sum() } } else { @@ -1803,7 +1813,7 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryIterationCursor<'w, 's, D, F> { archetypes: &'w Archetypes, query_state: &'s QueryState, ) -> Option> { - if Self::IS_DENSE { + if self.is_dense { loop { // we are on the beginning of the query, or finished processing a table, so skip to the next if self.current_row == self.current_len { diff --git a/crates/bevy_ecs/src/query/par_iter.rs b/crates/bevy_ecs/src/query/par_iter.rs index 7889228ba7af6..b3ea93fbbc514 100644 --- a/crates/bevy_ecs/src/query/par_iter.rs +++ b/crates/bevy_ecs/src/query/par_iter.rs @@ -126,7 +126,7 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryParIter<'w, 's, D, F> { fn get_batch_size(&self, thread_count: usize) -> usize { let max_items = || { let id_iter = self.state.matched_storage_ids.iter(); - 
if D::IS_DENSE && F::IS_DENSE { + if self.state.is_dense { // SAFETY: We only access table metadata. let tables = unsafe { &self.world.world_metadata().storages().tables }; id_iter diff --git a/crates/bevy_ecs/src/query/state.rs b/crates/bevy_ecs/src/query/state.rs index 70cf722daff3e..b94cb4aa90454 100644 --- a/crates/bevy_ecs/src/query/state.rs +++ b/crates/bevy_ecs/src/query/state.rs @@ -24,7 +24,10 @@ use super::{ /// An ID for either a table or an archetype. Used for Query iteration. /// /// Query iteration is exclusively dense (over tables) or archetypal (over archetypes) based on whether -/// both `D::IS_DENSE` and `F::IS_DENSE` are true or not. +/// the query filters are dense or not. This is represented by the [`QueryState::is_dense`] field. +/// +/// Note that `D::IS_DENSE` and `F::IS_DENSE` have no relationship with `QueryState::is_dense` and +/// any combination of their values can happen. /// /// This is a union instead of an enum as the usage is determined at compile time, as all [`StorageId`]s for /// a [`QueryState`] will be all [`TableId`]s or all [`ArchetypeId`]s, and not a mixture of both. This @@ -68,6 +71,9 @@ pub struct QueryState { pub(crate) component_access: FilteredAccess, // NOTE: we maintain both a bitset and a vec because iterating the vec is faster pub(super) matched_storage_ids: Vec, + // Represents whether this query iteration is dense or not. When this is true + // `matched_storage_ids` stores `TableId`s, otherwise it stores `ArchetypeId`s. + pub(super) is_dense: bool, pub(crate) fetch_state: D::State, pub(crate) filter_state: F::State, #[cfg(feature = "trace")] @@ -194,10 +200,15 @@ impl QueryState { // properly considered in a global "cross-query" context (both within systems and across systems). component_access.extend(&filter_component_access); + // For queries without dynamic filters the dense-ness of the query is equal to the dense-ness + // of its static type parameters. 
+ let is_dense = D::IS_DENSE && F::IS_DENSE; + Self { world_id: world.id(), archetype_generation: ArchetypeGeneration::initial(), matched_storage_ids: Vec::new(), + is_dense, fetch_state, filter_state, component_access, @@ -222,6 +233,8 @@ impl QueryState { world_id: builder.world().id(), archetype_generation: ArchetypeGeneration::initial(), matched_storage_ids: Vec::new(), + // For dynamic queries the dense-ness is given by the query builder. + is_dense: builder.is_dense(), fetch_state, filter_state, component_access: builder.access().clone(), @@ -450,7 +463,7 @@ impl QueryState { let archetype_index = archetype.id().index(); if !self.matched_archetypes.contains(archetype_index) { self.matched_archetypes.grow_and_insert(archetype_index); - if !D::IS_DENSE || !F::IS_DENSE { + if !self.is_dense { self.matched_storage_ids.push(StorageId { archetype_id: archetype.id(), }); @@ -459,7 +472,7 @@ impl QueryState { let table_index = archetype.table_id().as_usize(); if !self.matched_tables.contains(table_index) { self.matched_tables.grow_and_insert(table_index); - if D::IS_DENSE && F::IS_DENSE { + if self.is_dense { self.matched_storage_ids.push(StorageId { table_id: archetype.table_id(), }); @@ -560,6 +573,7 @@ impl QueryState { world_id: self.world_id, archetype_generation: self.archetype_generation, matched_storage_ids: self.matched_storage_ids.clone(), + is_dense: self.is_dense, fetch_state, filter_state, component_access: self.component_access.clone(), @@ -653,12 +667,15 @@ impl QueryState { warn!("You have tried to join queries with different archetype_generations. This could lead to unpredictable results."); } + // the join is dense if both the queries were dense. 
+ let is_dense = self.is_dense && other.is_dense; + // take the intersection of the matched ids let mut matched_tables = self.matched_tables.clone(); let mut matched_archetypes = self.matched_archetypes.clone(); matched_tables.intersect_with(&other.matched_tables); matched_archetypes.intersect_with(&other.matched_archetypes); - let matched_storage_ids = if NewD::IS_DENSE && NewF::IS_DENSE { + let matched_storage_ids = if is_dense { matched_tables .ones() .map(|id| StorageId { @@ -678,6 +695,7 @@ impl QueryState { world_id: self.world_id, archetype_generation: self.archetype_generation, matched_storage_ids, + is_dense, fetch_state: new_fetch_state, filter_state: new_filter_state, component_access: joined_component_access, @@ -1487,7 +1505,7 @@ impl QueryState { let mut iter = self.iter_unchecked_manual(world, last_run, this_run); let mut accum = init_accum(); for storage_id in queue { - if D::IS_DENSE && F::IS_DENSE { + if self.is_dense { let id = storage_id.table_id; let table = &world.storages().tables.get(id).debug_checked_unwrap(); accum = iter.fold_over_table_range( @@ -1521,7 +1539,7 @@ impl QueryState { #[cfg(feature = "trace")] let _span = self.par_iter_span.enter(); let accum = init_accum(); - if D::IS_DENSE && F::IS_DENSE { + if self.is_dense { let id = storage_id.table_id; let table = world.storages().tables.get(id).debug_checked_unwrap(); self.iter_unchecked_manual(world, last_run, this_run) @@ -1537,7 +1555,7 @@ impl QueryState { }; let storage_entity_count = |storage_id: StorageId| -> usize { - if D::IS_DENSE && F::IS_DENSE { + if self.is_dense { tables[storage_id.table_id].entity_count() } else { archetypes[storage_id.archetype_id].len() @@ -2042,6 +2060,53 @@ mod tests { world.query::<(&A, &B)>().transmute::<&B>(&world2); } + /// Regression test for issue #14528 + #[test] + fn transmute_from_sparse_to_dense() { + #[derive(Component)] + struct Dense; + + #[derive(Component)] + #[component(storage = "SparseSet")] + struct Sparse; + + let mut world = 
World::new(); + + world.spawn(Dense); + world.spawn((Dense, Sparse)); + + let mut query = world + .query_filtered::<&Dense, With>() + .transmute::<&Dense>(&world); + + let matched = query.iter(&world).count(); + assert_eq!(matched, 1); + } + #[test] + fn transmute_from_dense_to_sparse() { + #[derive(Component)] + struct Dense; + + #[derive(Component)] + #[component(storage = "SparseSet")] + struct Sparse; + + let mut world = World::new(); + + world.spawn(Dense); + world.spawn((Dense, Sparse)); + + let mut query = world + .query::<&Dense>() + .transmute_filtered::<&Dense, With>(&world); + + // Note: `transmute_filtered` is supposed to keep the same matched tables/archetypes, + // so it doesn't actually filter out those entities without `Sparse` and the iteration + // remains dense. + let matched = query.iter(&world).count(); + assert_eq!(matched, 2); + } + #[test] fn join() { let mut world = World::new();