Skip to content

Commit

Permalink
Fix unsoundness in FilteredEntity{Ref,Mut} various get methods (#…
Browse files Browse the repository at this point in the history
…13554)

# Objective

- `FilteredEntity{Ref,Mut}` various `get` methods never checked that the
given component was present on the entity, only the access allowed
reading/writing them, which is always the case when it is constructed
from a `EntityRef`/`EntityMut`/`EntityWorldMut` (and I guess can also
happen with queries containing `Option<T>` that get transmuted).
- In those cases the various `get` methods were calling
`debug_checked_unwrap` on `None`s, which is UB when debug assertions are
not enabled;
- The goal is thus to fix this soundness issue.

## Solution

- Don't call `debug_checked_unwrap` on those `None` and instead
`flatten` them.

## Testing

- This PR includes regression tests for each combination of
`FilteredEntityRef`/`FilteredEntityMut` and component
present/not-present. The two tests for the not-present cases fail on
`main` but success with this PR changes.
  • Loading branch information
SkiFire13 authored May 28, 2024
1 parent bc102d4 commit d98d6d8
Showing 1 changed file with 87 additions and 23 deletions.
110 changes: 87 additions & 23 deletions crates/bevy_ecs/src/world/entity_ref.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::{
change_detection::MutUntyped,
component::{Component, ComponentId, ComponentTicks, Components, StorageType},
entity::{Entities, Entity, EntityLocation},
query::{Access, DebugCheckedUnwrap},
query::Access,
removal_detection::RemovedComponentEvents,
storage::Storages,
world::{Mut, World},
Expand Down Expand Up @@ -1824,8 +1824,9 @@ impl<'w> FilteredEntityRef<'w> {
let id = self.entity.world().components().get_id(TypeId::of::<T>())?;
self.access
.has_read(id)
// SAFETY: We have read access so we must have the component
.then(|| unsafe { self.entity.get().debug_checked_unwrap() })
// SAFETY: We have read access
.then(|| unsafe { self.entity.get() })
.flatten()
}

/// Gets access to the component of type `T` for the current entity,
Expand All @@ -1837,8 +1838,9 @@ impl<'w> FilteredEntityRef<'w> {
let id = self.entity.world().components().get_id(TypeId::of::<T>())?;
self.access
.has_read(id)
// SAFETY: We have read access so we must have the component
.then(|| unsafe { self.entity.get_ref().debug_checked_unwrap() })
// SAFETY: We have read access
.then(|| unsafe { self.entity.get_ref() })
.flatten()
}

/// Retrieves the change ticks for the given component. This can be useful for implementing change
Expand All @@ -1848,8 +1850,9 @@ impl<'w> FilteredEntityRef<'w> {
let id = self.entity.world().components().get_id(TypeId::of::<T>())?;
self.access
.has_read(id)
// SAFETY: We have read access so we must have the component
.then(|| unsafe { self.entity.get_change_ticks::<T>().debug_checked_unwrap() })
// SAFETY: We have read access
.then(|| unsafe { self.entity.get_change_ticks::<T>() })
.flatten()
}

/// Retrieves the change ticks for the given [`ComponentId`]. This can be useful for implementing change
Expand All @@ -1860,12 +1863,11 @@ impl<'w> FilteredEntityRef<'w> {
/// compile time.**
#[inline]
pub fn get_change_ticks_by_id(&self, component_id: ComponentId) -> Option<ComponentTicks> {
// SAFETY: We have read access so we must have the component
self.access.has_read(component_id).then(|| unsafe {
self.entity
.get_change_ticks_by_id(component_id)
.debug_checked_unwrap()
})
self.access
.has_read(component_id)
// SAFETY: We have read access
.then(|| unsafe { self.entity.get_change_ticks_by_id(component_id) })
.flatten()
}

/// Gets the component of the given [`ComponentId`] from the entity.
Expand All @@ -1880,8 +1882,9 @@ impl<'w> FilteredEntityRef<'w> {
pub fn get_by_id(&self, component_id: ComponentId) -> Option<Ptr<'w>> {
self.access
.has_read(component_id)
// SAFETY: We have read access so we must have the component
.then(|| unsafe { self.entity.get_by_id(component_id).debug_checked_unwrap() })
// SAFETY: We have read access
.then(|| unsafe { self.entity.get_by_id(component_id) })
.flatten()
}
}

Expand Down Expand Up @@ -2094,8 +2097,9 @@ impl<'w> FilteredEntityMut<'w> {
let id = self.entity.world().components().get_id(TypeId::of::<T>())?;
self.access
.has_write(id)
// SAFETY: We have write access so we must have the component
.then(|| unsafe { self.entity.get_mut().debug_checked_unwrap() })
// SAFETY: We have write access
.then(|| unsafe { self.entity.get_mut() })
.flatten()
}

/// Retrieves the change ticks for the given component. This can be useful for implementing change
Expand Down Expand Up @@ -2139,12 +2143,11 @@ impl<'w> FilteredEntityMut<'w> {
/// which is only valid while the [`FilteredEntityMut`] is alive.
#[inline]
pub fn get_mut_by_id(&mut self, component_id: ComponentId) -> Option<MutUntyped<'_>> {
// SAFETY: We have write access so we must have the component
self.access.has_write(component_id).then(|| unsafe {
self.entity
.get_mut_by_id(component_id)
.debug_checked_unwrap()
})
self.access
.has_write(component_id)
// SAFETY: We have write access
.then(|| unsafe { self.entity.get_mut_by_id(component_id) })
.flatten()
}
}

Expand Down Expand Up @@ -2416,6 +2419,7 @@ mod tests {
use bevy_ptr::OwningPtr;
use std::panic::AssertUnwindSafe;

use crate::world::{FilteredEntityMut, FilteredEntityRef};
use crate::{self as bevy_ecs, component::ComponentId, prelude::*, system::assert_is_system};

#[test]
Expand Down Expand Up @@ -2918,4 +2922,64 @@ mod tests {

assert_is_system(incompatible_system);
}

#[test]
fn filtered_entity_ref_normal() {
let mut world = World::new();
let a_id = world.init_component::<A>();

let e: FilteredEntityRef = world.spawn(A).into();

assert!(e.get::<A>().is_some());
assert!(e.get_ref::<A>().is_some());
assert!(e.get_change_ticks::<A>().is_some());
assert!(e.get_by_id(a_id).is_some());
assert!(e.get_change_ticks_by_id(a_id).is_some());
}

#[test]
fn filtered_entity_ref_missing() {
let mut world = World::new();
let a_id = world.init_component::<A>();

let e: FilteredEntityRef = world.spawn(()).into();

assert!(e.get::<A>().is_none());
assert!(e.get_ref::<A>().is_none());
assert!(e.get_change_ticks::<A>().is_none());
assert!(e.get_by_id(a_id).is_none());
assert!(e.get_change_ticks_by_id(a_id).is_none());
}

#[test]
fn filtered_entity_mut_normal() {
let mut world = World::new();
let a_id = world.init_component::<A>();

let mut e: FilteredEntityMut = world.spawn(A).into();

assert!(e.get::<A>().is_some());
assert!(e.get_ref::<A>().is_some());
assert!(e.get_mut::<A>().is_some());
assert!(e.get_change_ticks::<A>().is_some());
assert!(e.get_by_id(a_id).is_some());
assert!(e.get_mut_by_id(a_id).is_some());
assert!(e.get_change_ticks_by_id(a_id).is_some());
}

#[test]
fn filtered_entity_mut_missing() {
let mut world = World::new();
let a_id = world.init_component::<A>();

let mut e: FilteredEntityMut = world.spawn(()).into();

assert!(e.get::<A>().is_none());
assert!(e.get_ref::<A>().is_none());
assert!(e.get_mut::<A>().is_none());
assert!(e.get_change_ticks::<A>().is_none());
assert!(e.get_by_id(a_id).is_none());
assert!(e.get_mut_by_id(a_id).is_none());
assert!(e.get_change_ticks_by_id(a_id).is_none());
}
}

0 comments on commit d98d6d8

Please sign in to comment.