diff --git a/benches/benches/bevy_reflect/function.rs b/benches/benches/bevy_reflect/function.rs index 03673d3a9a2eb..b82466884f7d7 100644 --- a/benches/benches/bevy_reflect/function.rs +++ b/benches/benches/bevy_reflect/function.rs @@ -1,7 +1,7 @@ use bevy_reflect::func::{ArgList, IntoFunction, IntoFunctionMut, TypedFunction}; use criterion::{criterion_group, criterion_main, BatchSize, Criterion}; -criterion_group!(benches, typed, into, call, clone); +criterion_group!(benches, typed, into, call, overload, clone); criterion_main!(benches); fn add(a: i32, b: i32) -> i32 { @@ -79,6 +79,307 @@ fn call(c: &mut Criterion) { }); } +fn overload(c: &mut Criterion) { + fn add>(a: T, b: T) -> T { + a + b + } + + fn complex( + _: T0, + _: T1, + _: T2, + _: T3, + _: T4, + _: T5, + _: T6, + _: T7, + _: T8, + _: T9, + ) { + } + + c.benchmark_group("with_overload") + .bench_function("01_simple_overload", |b| { + b.iter_batched( + || add::.into_function(), + |func| func.with_overload(add::), + BatchSize::SmallInput, + ); + }) + .bench_function("01_complex_overload", |b| { + b.iter_batched( + || complex::.into_function(), + |func| { + func.with_overload(complex::) + }, + BatchSize::SmallInput, + ); + }) + .bench_function("03_simple_overload", |b| { + b.iter_batched( + || add::.into_function(), + |func| { + func.with_overload(add::) + .with_overload(add::) + .with_overload(add::) + }, + BatchSize::SmallInput, + ); + }) + .bench_function("03_complex_overload", |b| { + b.iter_batched( + || complex::.into_function(), + |func| { + func.with_overload(complex::) + .with_overload(complex::) + .with_overload(complex::) + }, + BatchSize::SmallInput, + ); + }) + .bench_function("10_simple_overload", |b| { + b.iter_batched( + || add::.into_function(), + |func| { + func.with_overload(add::) + .with_overload(add::) + .with_overload(add::) + .with_overload(add::) + .with_overload(add::) + .with_overload(add::) + .with_overload(add::) + .with_overload(add::) + .with_overload(add::) + }, + BatchSize::SmallInput, + ); + }) + .bench_function("10_complex_overload", |b| { + b.iter_batched( + || complex::.into_function(), + |func| { + func.with_overload(complex::) + .with_overload(complex::) + .with_overload(complex::) + .with_overload(complex::) + .with_overload(complex::) + .with_overload(complex::) + .with_overload(complex::) + .with_overload(complex::) + .with_overload(complex::) + }, + BatchSize::SmallInput, + ); + }) + .bench_function("01_nested_simple_overload", |b| { + b.iter_batched( + || add::.into_function(), + |func| func.with_overload(add::), + BatchSize::SmallInput, + ); + }) + .bench_function("03_nested_simple_overload", |b| { + b.iter_batched( + || add::.into_function(), + |func| { + func.with_overload( + add:: + .into_function() + .with_overload(add::.into_function().with_overload(add::)), + ) + }, + BatchSize::SmallInput, + ); + }) + .bench_function("10_nested_simple_overload", |b| { + b.iter_batched( + || add::.into_function(), + |func| { + func.with_overload( + add::.into_function().with_overload( + add::.into_function().with_overload( + add::.into_function().with_overload( + add::.into_function().with_overload( + add::.into_function().with_overload( + add::.into_function().with_overload( + add::.into_function().with_overload( + add:: + .into_function() + .with_overload(add::), + ), + ), + ), + ), + ), + ), + ), + ) + }, + BatchSize::SmallInput, + ); + }); + + c.benchmark_group("call_overload") + .bench_function("01_simple_overload", |b| { + b.iter_batched( + || { + ( + add::.into_function().with_overload(add::), + ArgList::new().push_owned(75_i8).push_owned(25_i8), + ) + }, + |(func, args)| func.call(args), + BatchSize::SmallInput, + ); + }) + .bench_function("01_complex_overload", |b| { + b.iter_batched( + || { + ( + complex:: + .into_function() + .with_overload( + complex::, + ), + ArgList::new() + .push_owned(1_i8) + .push_owned(2_i16) + .push_owned(3_i32) + .push_owned(4_i64) + .push_owned(5_i128) + .push_owned(6_u8) + .push_owned(7_u16) + .push_owned(8_u32) + .push_owned(9_u64) + .push_owned(10_u128), + ) + }, + |(func, args)| func.call(args), + BatchSize::SmallInput, + ); + }) + .bench_function("03_simple_overload", |b| { + b.iter_batched( + || { + ( + add:: + .into_function() + .with_overload(add::) + .with_overload(add::) + .with_overload(add::), + ArgList::new().push_owned(75_i32).push_owned(25_i32), + ) + }, + |(func, args)| func.call(args), + BatchSize::SmallInput, + ); + }) + .bench_function("03_complex_overload", |b| { + b.iter_batched( + || { + ( + complex:: + .into_function() + .with_overload( + complex::, + ) + .with_overload( + complex::, + ) + .with_overload( + complex::, + ), + ArgList::new() + .push_owned(1_i32) + .push_owned(2_i64) + .push_owned(3_i128) + .push_owned(4_u8) + .push_owned(5_u16) + .push_owned(6_u32) + .push_owned(7_u64) + .push_owned(8_u128) + .push_owned(9_i8) + .push_owned(10_i16), + ) + }, + |(func, args)| func.call(args), + BatchSize::SmallInput, + ); + }) + .bench_function("10_simple_overload", |b| { + b.iter_batched( + || { + ( + add:: + .into_function() + .with_overload(add::) + .with_overload(add::) + .with_overload(add::) + .with_overload(add::) + .with_overload(add::) + .with_overload(add::) + .with_overload(add::) + .with_overload(add::) + .with_overload(add::), + ArgList::new().push_owned(75_u8).push_owned(25_u8), + ) + }, + |(func, args)| func.call(args), + BatchSize::SmallInput, + ); + }) + .bench_function("10_complex_overload", |b| { + b.iter_batched( + || { + ( + complex:: + .into_function() + .with_overload( + complex::, + ) + .with_overload( + complex::, + ) + .with_overload( + complex::, + ) + .with_overload( + complex::, + ) + .with_overload( + complex::, + ) + .with_overload( + complex::, + ) + .with_overload( + complex::, + ) + .with_overload( + complex::, + ) + .with_overload( + complex::, + ), + ArgList::new() + .push_owned(1_u8) + .push_owned(2_u16) + .push_owned(3_u32) + .push_owned(4_u64) + .push_owned(5_u128) + .push_owned(6_i8) + .push_owned(7_i16) + .push_owned(8_i32) + .push_owned(9_i64) + .push_owned(10_i128), + ) + }, + |(func, args)| func.call(args), + BatchSize::SmallInput, + ); + }); +} + fn clone(c: &mut Criterion) { c.benchmark_group("clone").bench_function("function", |b| { let add = add.into_function(); diff --git a/crates/bevy_reflect/src/func/args/arg.rs b/crates/bevy_reflect/src/func/args/arg.rs index d614f073f2cc8..60698a3d7e0a7 100644 --- a/crates/bevy_reflect/src/func/args/arg.rs +++ b/crates/bevy_reflect/src/func/args/arg.rs @@ -183,6 +183,14 @@ impl<'a> Arg<'a> { } } } + + /// Returns `true` if the argument is of type `T`. + pub fn is(&self) -> bool { + self.value + .try_as_reflect() + .map(::is::) + .unwrap_or_default() + } } /// Represents an argument that can be passed to a [`DynamicFunction`] or [`DynamicFunctionMut`]. diff --git a/crates/bevy_reflect/src/func/args/count.rs b/crates/bevy_reflect/src/func/args/count.rs new file mode 100644 index 0000000000000..d5f410f88dfaf --- /dev/null +++ b/crates/bevy_reflect/src/func/args/count.rs @@ -0,0 +1,311 @@ +use crate::func::args::ArgCountOutOfBoundsError; +use core::fmt::{Debug, Formatter}; + +/// A container for zero or more argument counts for a function. +/// +/// For most functions, this will contain a single count, +/// however, overloaded functions may contain more. +/// +/// # Maximum Argument Count +/// +/// The maximum number of arguments that can be represented by this struct is 63, +/// as given by [`ArgCount::MAX_COUNT`]. +/// The reason for this is that all counts are stored internally as a single `u64` +/// with each bit representing a specific count based on its bit index. +/// +/// This allows for a smaller memory footprint and faster lookups compared to a +/// `HashSet` or `Vec` of possible counts. +/// It's also more appropriate for representing the argument counts of a function +/// given that most functions will not have more than a few arguments. +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub struct ArgCount { + /// The bits representing the argument counts. + /// + /// Each bit represents a specific count based on its bit index. + bits: u64, + /// The total number of argument counts. + len: u8, +} + +impl ArgCount { + /// The maximum number of arguments that can be represented by this struct. + pub const MAX_COUNT: usize = u64::BITS as usize - 1; + + /// Create a new [`ArgCount`] with the given count. + /// + /// # Errors + /// + /// Returns an error if the count is greater than [`Self::MAX_COUNT`]. + pub fn new(count: usize) -> Result { + Ok(Self { + bits: 1 << Self::try_to_u8(count)?, + len: 1, + }) + } + + /// Adds the given count to this [`ArgCount`]. + /// + /// # Panics + /// + /// Panics if the count is greater than [`Self::MAX_COUNT`]. + pub fn add(&mut self, count: usize) { + self.try_add(count).unwrap(); + } + + /// Attempts to add the given count to this [`ArgCount`]. + /// + /// # Errors + /// + /// Returns an error if the count is greater than [`Self::MAX_COUNT`]. + pub fn try_add(&mut self, count: usize) -> Result<(), ArgCountOutOfBoundsError> { + let count = Self::try_to_u8(count)?; + + if !self.contains_unchecked(count) { + self.len += 1; + self.bits |= 1 << count; + } + + Ok(()) + } + + /// Removes the given count from this [`ArgCount`]. + pub fn remove(&mut self, count: usize) { + self.try_remove(count).unwrap(); + } + + /// Attempts to remove the given count from this [`ArgCount`]. + /// + /// # Errors + /// + /// Returns an error if the count is greater than [`Self::MAX_COUNT`]. + pub fn try_remove(&mut self, count: usize) -> Result<(), ArgCountOutOfBoundsError> { + let count = Self::try_to_u8(count)?; + + if self.contains_unchecked(count) { + self.len -= 1; + self.bits &= !(1 << count); + } + + Ok(()) + } + + /// Checks if this [`ArgCount`] contains the given count. + pub fn contains(&self, count: usize) -> bool { + count < usize::BITS as usize && (self.bits >> count) & 1 == 1 + } + + /// Returns the total number of argument counts that this [`ArgCount`] contains. + pub fn len(&self) -> usize { + self.len as usize + } + + /// Returns true if this [`ArgCount`] contains no argument counts. + pub fn is_empty(&self) -> bool { + self.len == 0 + } + + /// Returns an iterator over the argument counts in this [`ArgCount`]. + pub fn iter(&self) -> ArgCountIter { + ArgCountIter { + count: *self, + index: 0, + found: 0, + } + } + + /// Checks if this [`ArgCount`] contains the given count without any bounds checking. + /// + /// # Panics + /// + /// Panics if the count is greater than [`Self::MAX_COUNT`]. + fn contains_unchecked(&self, count: u8) -> bool { + (self.bits >> count) & 1 == 1 + } + + /// Attempts to convert the given count to a `u8` within the bounds of the [maximum count]. + /// + /// [maximum count]: Self::MAX_COUNT + fn try_to_u8(count: usize) -> Result { + if count > Self::MAX_COUNT { + Err(ArgCountOutOfBoundsError(count)) + } else { + Ok(count as u8) + } + } +} + +/// Defaults this [`ArgCount`] to empty. +/// +/// This means that it contains no argument counts, including zero. +impl Default for ArgCount { + fn default() -> Self { + Self { bits: 0, len: 0 } + } +} + +impl Debug for ArgCount { + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + f.debug_set().entries(self.iter()).finish() + } +} + +/// An iterator for the argument counts in an [`ArgCount`]. +pub struct ArgCountIter { + count: ArgCount, + index: u8, + found: u8, +} + +impl Iterator for ArgCountIter { + type Item = usize; + + fn next(&mut self) -> Option { + loop { + if self.index as usize > ArgCount::MAX_COUNT { + return None; + } + + if self.found == self.count.len { + // All counts have been found + return None; + } + + if self.count.contains_unchecked(self.index) { + self.index += 1; + self.found += 1; + return Some(self.index as usize - 1); + } + + self.index += 1; + } + } + + fn size_hint(&self) -> (usize, Option) { + (self.count.len(), Some(self.count.len())) + } +} + +impl ExactSizeIterator for ArgCountIter {} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn should_default_to_empty() { + let count = ArgCount::default(); + + assert_eq!(count.len(), 0); + assert!(count.is_empty()); + + assert!(!count.contains(0)); + } + + #[test] + fn should_construct_with_count() { + let count = ArgCount::new(3).unwrap(); + + assert_eq!(count.len(), 1); + assert!(!count.is_empty()); + + assert!(count.contains(3)); + } + + #[test] + fn should_add_count() { + let mut count = ArgCount::default(); + count.add(3); + + assert_eq!(count.len(), 1); + + assert!(count.contains(3)); + } + + #[test] + fn should_add_multiple_counts() { + let mut count = ArgCount::default(); + count.add(3); + count.add(5); + count.add(7); + + assert_eq!(count.len(), 3); + + assert!(!count.contains(0)); + assert!(!count.contains(1)); + assert!(!count.contains(2)); + + assert!(count.contains(3)); + assert!(count.contains(5)); + assert!(count.contains(7)); + } + + #[test] + fn should_add_idempotently() { + let mut count = ArgCount::default(); + count.add(3); + count.add(3); + + assert_eq!(count.len(), 1); + assert!(count.contains(3)); + } + + #[test] + fn should_remove_count() { + let mut count = ArgCount::default(); + count.add(3); + + assert_eq!(count.len(), 1); + assert!(count.contains(3)); + + count.remove(3); + + assert_eq!(count.len(), 0); + assert!(!count.contains(3)); + } + + #[test] + fn should_allow_removeting_nonexistent_count() { + let mut count = ArgCount::default(); + + assert_eq!(count.len(), 0); + assert!(!count.contains(3)); + + count.remove(3); + + assert_eq!(count.len(), 0); + assert!(!count.contains(3)); + } + + #[test] + fn should_iterate_over_counts() { + let mut count = ArgCount::default(); + count.add(3); + count.add(5); + count.add(7); + + let mut iter = count.iter(); + + assert_eq!(iter.len(), 3); + + assert_eq!(iter.next(), Some(3)); + assert_eq!(iter.next(), Some(5)); + assert_eq!(iter.next(), Some(7)); + assert_eq!(iter.next(), None); + } + + #[test] + fn should_return_error_for_out_of_bounds_count() { + let count = ArgCount::new(64); + assert_eq!(count, Err(ArgCountOutOfBoundsError(64))); + + let mut count = ArgCount::default(); + assert_eq!(count.try_add(64), Err(ArgCountOutOfBoundsError(64))); + assert_eq!(count.try_remove(64), Err(ArgCountOutOfBoundsError(64))); + } + + #[test] + fn should_return_false_for_out_of_bounds_contains() { + let count = ArgCount::default(); + assert!(!count.contains(64)); + } +} diff --git a/crates/bevy_reflect/src/func/args/error.rs b/crates/bevy_reflect/src/func/args/error.rs index 9d66c7039354b..65c4caa6e8449 100644 --- a/crates/bevy_reflect/src/func/args/error.rs +++ b/crates/bevy_reflect/src/func/args/error.rs @@ -32,3 +32,8 @@ pub enum ArgError { #[error("expected an argument but received none")] EmptyArgList, } + +/// The given argument count is out of bounds. +#[derive(Debug, Error, PartialEq)] +#[error("argument count out of bounds: {0}")] +pub struct ArgCountOutOfBoundsError(pub usize); diff --git a/crates/bevy_reflect/src/func/args/list.rs b/crates/bevy_reflect/src/func/args/list.rs index 6ed7eace98c2e..145414424f4b1 100644 --- a/crates/bevy_reflect/src/func/args/list.rs +++ b/crates/bevy_reflect/src/func/args/list.rs @@ -5,7 +5,10 @@ use crate::{ }, PartialReflect, Reflect, TypePath, }; -use alloc::{boxed::Box, collections::VecDeque}; +use alloc::{ + boxed::Box, + collections::vec_deque::{Iter, VecDeque}, +}; #[cfg(not(feature = "std"))] use alloc::{boxed::Box, format, vec}; @@ -286,6 +289,11 @@ impl<'a> ArgList<'a> { self.pop_arg()?.take_mut() } + /// Returns an iterator over the arguments in the list. + pub fn iter(&self) -> Iter<'_, Arg<'a>> { + self.list.iter() + } + /// Returns the number of arguments in the list. pub fn len(&self) -> usize { self.list.len() diff --git a/crates/bevy_reflect/src/func/args/mod.rs b/crates/bevy_reflect/src/func/args/mod.rs index da0ea00bb1abd..3b167bd2f0b32 100644 --- a/crates/bevy_reflect/src/func/args/mod.rs +++ b/crates/bevy_reflect/src/func/args/mod.rs @@ -4,6 +4,7 @@ //! [`DynamicFunctionMut`]: crate::func::DynamicFunctionMut pub use arg::*; +pub use count::*; pub use error::*; pub use from_arg::*; pub use info::*; @@ -11,6 +12,7 @@ pub use list::*; pub use ownership::*; mod arg; +mod count; mod error; mod from_arg; mod info; diff --git a/crates/bevy_reflect/src/func/dynamic_function.rs b/crates/bevy_reflect/src/func/dynamic_function.rs index 36a2f22a8030c..863bffce77bfc 100644 --- a/crates/bevy_reflect/src/func/dynamic_function.rs +++ b/crates/bevy_reflect/src/func/dynamic_function.rs @@ -2,8 +2,11 @@ use crate::{ self as bevy_reflect, __macro_exports::RegisterForReflection, func::{ - args::ArgList, info::FunctionInfo, DynamicFunctionMut, Function, FunctionError, - FunctionResult, IntoFunction, IntoFunctionMut, + args::{ArgCount, ArgList}, + dynamic_function_internal::DynamicFunctionInternal, + info::FunctionInfo, + DynamicFunctionMut, Function, FunctionOverloadError, FunctionResult, IntoFunction, + IntoFunctionMut, }, serde::Serializable, ApplyError, MaybeTyped, PartialReflect, Reflect, ReflectKind, ReflectMut, ReflectOwned, @@ -16,6 +19,16 @@ use core::fmt::{Debug, Formatter}; #[cfg(not(feature = "std"))] use alloc::{boxed::Box, format, vec}; +/// An [`Arc`] containing a callback to a reflected function. +/// +/// The `Arc` is used to both ensure that it is `Send + Sync` +/// and to allow for the callback to be cloned. +/// +/// Note that cloning is okay since we only ever need an immutable reference +/// to call a `dyn Fn` function. +/// If we were to contain a `dyn FnMut` instead, cloning would be a lot more complicated. +type ArcFn<'env> = Arc Fn(ArgList<'a>) -> FunctionResult<'a> + Send + Sync + 'env>; + /// A dynamic representation of a function. /// /// This type can be used to represent any callable that satisfies [`Fn`] @@ -35,7 +48,7 @@ use alloc::{boxed::Box, format, vec}; /// Most of the time, a [`DynamicFunction`] can be created using the [`IntoFunction`] trait: /// /// ``` -/// # use bevy_reflect::func::{ArgList, DynamicFunction, FunctionInfo, IntoFunction}; +/// # use bevy_reflect::func::{ArgList, DynamicFunction, IntoFunction}; /// # /// fn add(a: i32, b: i32) -> i32 { /// a + b @@ -54,9 +67,9 @@ use alloc::{boxed::Box, format, vec}; /// /// [`ReflectFn`]: crate::func::ReflectFn /// [module-level documentation]: crate::func +#[derive(Clone)] pub struct DynamicFunction<'env> { - pub(super) info: FunctionInfo, - pub(super) func: Arc Fn(ArgList<'a>) -> FunctionResult<'a> + Send + Sync + 'env>, + pub(super) internal: DynamicFunctionInternal>, } impl<'env> DynamicFunction<'env> { @@ -65,17 +78,26 @@ impl<'env> DynamicFunction<'env> { /// The given function can be used to call out to any other callable, /// including functions, closures, or methods. /// - /// It's important that the function signature matches the provided [`FunctionInfo`] + /// It's important that the function signature matches the provided [`FunctionInfo`]. /// as this will be used to validate arguments when [calling] the function. + /// This is also required in order for [function overloading] to work correctly. + /// + /// # Panics + /// + /// This function may panic for any of the following reasons: + /// - No [`SignatureInfo`] is provided. + /// - A provided [`SignatureInfo`] has more arguments than [`ArgCount::MAX_COUNT`]. + /// - The conversion to [`FunctionInfo`] fails. /// - /// [calling]: DynamicFunction::call + /// [calling]: crate::func::dynamic_function::DynamicFunction::call + /// [`SignatureInfo`]: crate::func::SignatureInfo + /// [function overloading]: Self::with_overload pub fn new Fn(ArgList<'a>) -> FunctionResult<'a> + Send + Sync + 'env>( func: F, - info: FunctionInfo, + info: impl TryInto, ) -> Self { Self { - info, - func: Arc::new(func), + internal: DynamicFunctionInternal::new(Arc::new(func), info.try_into().unwrap()), } } @@ -88,10 +110,140 @@ impl<'env> DynamicFunction<'env> { /// /// [`DynamicFunctions`]: DynamicFunction pub fn with_name(mut self, name: impl Into>) -> Self { - self.info = self.info.with_name(name); + self.internal = self.internal.with_name(name); self } + /// Add an overload to this function. + /// + /// Overloads allow a single [`DynamicFunction`] to represent multiple functions of different signatures. + /// + /// This can be used to handle multiple monomorphizations of a generic function + /// or to allow functions with a variable number of arguments. + /// + /// Any functions with the same [argument signature] will be overwritten by the one from the new function, `F`. + /// For example, if the existing function had the signature `(i32, i32) -> i32`, + /// and the new function, `F`, also had the signature `(i32, i32) -> i32`, + /// the one from `F` would replace the one from the existing function. + /// + /// Overloaded functions retain the [name] of the original function. + /// + /// # Panics + /// + /// Panics if the function, `F`, contains a signature already found in this function. + /// + /// For a non-panicking version, see [`try_with_overload`]. + /// + /// # Examples + /// + /// ``` + /// # use std::ops::Add; + /// # use bevy_reflect::func::{ArgList, IntoFunction}; + /// # + /// fn add>(a: T, b: T) -> T { + /// a + b + /// } + /// + /// // Currently, the only generic type `func` supports is `i32`: + /// let mut func = add::.into_function(); + /// + /// // However, we can add an overload to handle `f32` as well: + /// func = func.with_overload(add::); + /// + /// // Test `i32`: + /// let args = ArgList::default().push_owned(25_i32).push_owned(75_i32); + /// let result = func.call(args).unwrap().unwrap_owned(); + /// assert_eq!(result.try_take::().unwrap(), 100); + /// + /// // Test `f32`: + /// let args = ArgList::default().push_owned(25.0_f32).push_owned(75.0_f32); + /// let result = func.call(args).unwrap().unwrap_owned(); + /// assert_eq!(result.try_take::().unwrap(), 100.0); + ///``` + /// + /// ``` + /// # use bevy_reflect::func::{ArgList, IntoFunction}; + /// # + /// fn add_2(a: i32, b: i32) -> i32 { + /// a + b + /// } + /// + /// fn add_3(a: i32, b: i32, c: i32) -> i32 { + /// a + b + c + /// } + /// + /// // Currently, `func` only supports two arguments. + /// let mut func = add_2.into_function(); + /// + /// // However, we can add an overload to handle three arguments as well. + /// func = func.with_overload(add_3); + /// + /// // Test two arguments: + /// let args = ArgList::default().push_owned(25_i32).push_owned(75_i32); + /// let result = func.call(args).unwrap().unwrap_owned(); + /// assert_eq!(result.try_take::().unwrap(), 100); + /// + /// // Test three arguments: + /// let args = ArgList::default() + /// .push_owned(25_i32) + /// .push_owned(75_i32) + /// .push_owned(100_i32); + /// let result = func.call(args).unwrap().unwrap_owned(); + /// assert_eq!(result.try_take::().unwrap(), 200); + /// ``` + /// + ///```should_panic + /// # use bevy_reflect::func::IntoFunction; + /// + /// fn add(a: i32, b: i32) -> i32 { + /// a + b + /// } + /// + /// fn sub(a: i32, b: i32) -> i32 { + /// a - b + /// } + /// + /// let mut func = add.into_function(); + /// + /// // This will panic because the function already has an argument signature for `(i32, i32)`: + /// func = func.with_overload(sub); + /// ``` + /// + /// [argument signature]: crate::func::signature::ArgumentSignature + /// [name]: Self::name + /// [`try_with_overload`]: Self::try_with_overload + pub fn with_overload<'a, F: IntoFunction<'a, Marker>, Marker>( + self, + function: F, + ) -> DynamicFunction<'a> + where + 'env: 'a, + { + self.try_with_overload(function).unwrap_or_else(|(_, err)| { + panic!("{}", err); + }) + } + + /// Attempt to add an overload to this function. + /// + /// If the function, `F`, contains a signature already found in this function, + /// an error will be returned along with the original function. + /// + /// For a panicking version, see [`with_overload`]. + /// + /// [`with_overload`]: Self::with_overload + pub fn try_with_overload, Marker>( + mut self, + function: F, + ) -> Result, FunctionOverloadError)> { + let function = function.into_function(); + + match self.internal.merge(function.internal) { + Ok(_) => Ok(self), + Err(err) => Err((Box::new(self), err)), + } + } + /// Call the function with the given arguments. /// /// # Example @@ -116,25 +268,17 @@ impl<'env> DynamicFunction<'env> { /// /// The function itself may also return any errors it needs to. pub fn call<'a>(&self, args: ArgList<'a>) -> FunctionResult<'a> { - let expected_arg_count = self.info.arg_count(); - let received_arg_count = args.len(); - - if expected_arg_count != received_arg_count { - Err(FunctionError::ArgCountMismatch { - expected: expected_arg_count, - received: received_arg_count, - }) - } else { - (self.func)(args) - } + self.internal.validate_args(&args)?; + let func = self.internal.get(&args)?; + func(args) } /// Returns the function info. pub fn info(&self) -> &FunctionInfo { - &self.info + self.internal.info() } - /// The [name] of the function. + /// The name of the function. /// /// For [`DynamicFunctions`] created using [`IntoFunction`], /// the default name will always be the full path to the function as returned by [`core::any::type_name`], @@ -143,17 +287,62 @@ impl<'env> DynamicFunction<'env> { /// /// This can be overridden using [`with_name`]. /// - /// [name]: FunctionInfo::name + /// If the function was [overloaded], it will retain its original name if it had one. + /// /// [`DynamicFunctions`]: DynamicFunction /// [`with_name`]: Self::with_name + /// [overloaded]: Self::with_overload pub fn name(&self) -> Option<&Cow<'static, str>> { - self.info.name() + self.internal.name() + } + + /// Returns `true` if the function is [overloaded]. + /// + /// # Example + /// + /// ``` + /// # use bevy_reflect::func::IntoFunction; + /// let add = (|a: i32, b: i32| a + b).into_function(); + /// assert!(!add.is_overloaded()); + /// + /// let add = add.with_overload(|a: f32, b: f32| a + b); + /// assert!(add.is_overloaded()); + /// ``` + /// + /// [overloaded]: Self::with_overload + pub fn is_overloaded(&self) -> bool { + self.internal.is_overloaded() + } + /// Returns the number of arguments the function expects. + /// + /// For [overloaded] functions that can have a variable number of arguments, + /// this will contain the full set of counts for all signatures. + /// + /// # Example + /// + /// ``` + /// # use bevy_reflect::func::IntoFunction; + /// let add = (|a: i32, b: i32| a + b).into_function(); + /// assert!(add.arg_count().contains(2)); + /// + /// let add = add.with_overload(|a: f32, b: f32, c: f32| a + b + c); + /// assert!(add.arg_count().contains(2)); + /// assert!(add.arg_count().contains(3)); + /// ``` + /// + /// [overloaded]: Self::with_overload + pub fn arg_count(&self) -> ArgCount { + self.internal.arg_count() } } impl Function for DynamicFunction<'static> { + fn name(&self) -> Option<&Cow<'static, str>> { + self.internal.name() + } + fn info(&self) -> &FunctionInfo { - self.info() + self.internal.info() } fn reflect_call<'a>(&self, args: ArgList<'a>) -> FunctionResult<'a> { @@ -258,32 +447,14 @@ impl_type_path!((in bevy_reflect) DynamicFunction<'env>); /// This takes the format: `DynamicFunction(fn {name}({arg1}: {type1}, {arg2}: {type2}, ...) -> {return_type})`. /// /// Names for arguments and the function itself are optional and will default to `_` if not provided. +/// +/// If the function is [overloaded], the output will include the signatures of all overloads as a set. +/// For example, `DynamicFunction(fn add{(_: i32, _: i32) -> i32, (_: f32, _: f32) -> f32})`. +/// +/// [overloaded]: DynamicFunction::with_overload impl<'env> Debug for DynamicFunction<'env> { fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { - let name = self.info.name().unwrap_or(&Cow::Borrowed("_")); - write!(f, "DynamicFunction(fn {name}(")?; - - for (index, arg) in self.info.args().iter().enumerate() { - let name = arg.name().unwrap_or("_"); - let ty = arg.type_path(); - write!(f, "{name}: {ty}")?; - - if index + 1 < self.info.args().len() { - write!(f, ", ")?; - } - } - - let ret = self.info.return_info().type_path(); - write!(f, ") -> {ret})") - } -} - -impl<'env> Clone for DynamicFunction<'env> { - fn clone(&self) -> Self { - Self { - info: self.info.clone(), - func: Arc::clone(&self.func), - } + write!(f, "DynamicFunction({:?})", &self.internal) } } @@ -304,15 +475,20 @@ impl<'env> IntoFunctionMut<'env, ()> for DynamicFunction<'env> { #[cfg(test)] mod tests { use super::*; - use crate::func::IntoReturn; + use crate::func::signature::ArgumentSignature; + use crate::func::{FunctionError, IntoReturn, SignatureInfo}; + use crate::Type; + use bevy_utils::HashSet; + use core::ops::Add; #[test] fn should_overwrite_function_name() { let c = 23; - let func = (|a: i32, b: i32| a + b + c) - .into_function() - .with_name("my_function"); - assert_eq!(func.info().name().unwrap(), "my_function"); + let func = (|a: i32, b: i32| a + b + c).into_function(); + assert!(func.name().is_none()); + + let func = func.with_name("my_function"); + assert_eq!(func.name().unwrap(), "my_function"); } #[test] @@ -332,13 +508,40 @@ mod tests { let args = ArgList::default().push_owned(25_i32); let error = func.call(args).unwrap_err(); - assert!(matches!( + + assert_eq!( error, FunctionError::ArgCountMismatch { - expected: 2, + expected: ArgCount::new(2).unwrap(), received: 1 } - )); + ); + } + + #[test] + fn should_return_error_on_arg_count_mismatch_overloaded() { + let func = (|a: i32, b: i32| a + b) + .into_function() + .with_overload(|a: i32, b: i32, c: i32| a + b + c); + + let args = ArgList::default() + .push_owned(1_i32) + .push_owned(2_i32) + .push_owned(3_i32) + .push_owned(4_i32); + + let error = func.call(args).unwrap_err(); + + let mut expected_count = ArgCount::new(2).unwrap(); + expected_count.add(3); + + assert_eq!( + error, + FunctionError::ArgCountMismatch { + expected: expected_count, + received: 4 + } + ); } #[test] @@ -400,7 +603,7 @@ mod tests { }, // The `FunctionInfo` doesn't really matter for this test // so we can just give it dummy information. - FunctionInfo::anonymous() + SignatureInfo::anonymous() .with_arg::("curr") .with_arg::<()>("this"), ); @@ -409,4 +612,190 @@ mod tests { let value = factorial.call(args).unwrap().unwrap_owned(); assert_eq!(value.try_take::().unwrap(), 120); } + + #[test] + fn should_allow_creating_manual_generic_dynamic_function() { + let func = DynamicFunction::new( + |mut args| { + let a = args.take_arg()?; + let b = args.take_arg()?; + + if a.is::() { + let a = a.take::()?; + let b = b.take::()?; + Ok((a + b).into_return()) + } else { + let a = a.take::()?; + let b = b.take::()?; + Ok((a + b).into_return()) + } + }, + vec![ + SignatureInfo::named("add::") + .with_arg::("a") + .with_arg::("b") + .with_return::(), + SignatureInfo::named("add::") + .with_arg::("a") + .with_arg::("b") + .with_return::(), + ], + ); + + assert_eq!(func.name().unwrap(), "add::"); + let func = func.with_name("add"); + assert_eq!(func.name().unwrap(), "add"); + + let args = ArgList::default().push_owned(25_i32).push_owned(75_i32); + let result = func.call(args).unwrap().unwrap_owned(); + assert_eq!(result.try_take::().unwrap(), 100); + + let args = ArgList::default().push_owned(25.0_f32).push_owned(75.0_f32); + let result = func.call(args).unwrap().unwrap_owned(); + assert_eq!(result.try_take::().unwrap(), 100.0); + } + + #[test] + #[should_panic(expected = "called `Result::unwrap()` on an `Err` value: MissingSignature")] + fn should_panic_on_missing_function_info() { + let _ = DynamicFunction::new(|_| Ok(().into_return()), Vec::new()); + } + + #[test] + fn should_allow_function_overloading() { + fn add>(a: T, b: T) -> T { + a + b + } + + let func = add::.into_function().with_overload(add::); + + let args = ArgList::default().push_owned(25_i32).push_owned(75_i32); + let result = func.call(args).unwrap().unwrap_owned(); + assert_eq!(result.try_take::().unwrap(), 100); + + let args = ArgList::default().push_owned(25.0_f32).push_owned(75.0_f32); + let result = func.call(args).unwrap().unwrap_owned(); + assert_eq!(result.try_take::().unwrap(), 100.0); + } + + #[test] + fn should_allow_variable_arguments_via_overloading() { + fn add_2(a: i32, b: i32) -> i32 { + a + b + } + + fn add_3(a: i32, b: i32, c: i32) -> i32 { + a + b + c + } + + let func = add_2.into_function().with_overload(add_3); + + let args = ArgList::default().push_owned(25_i32).push_owned(75_i32); + let result = func.call(args).unwrap().unwrap_owned(); + assert_eq!(result.try_take::().unwrap(), 100); + + let args = ArgList::default() + .push_owned(25_i32) + .push_owned(75_i32) + .push_owned(100_i32); + let result = func.call(args).unwrap().unwrap_owned(); + assert_eq!(result.try_take::().unwrap(), 200); + } + + #[test] + fn should_allow_function_overloading_with_manual_overload() { + let manual = DynamicFunction::new( + |mut args| { + let a = args.take_arg()?; + let b = args.take_arg()?; + + if a.is::() { + let a = a.take::()?; + let b = b.take::()?; + Ok((a + b).into_return()) + } else { + let a = a.take::()?; + let b = b.take::()?; + Ok((a + b).into_return()) + } + }, + vec![ + SignatureInfo::named("add::") + .with_arg::("a") + .with_arg::("b") + .with_return::(), + SignatureInfo::named("add::") + .with_arg::("a") + .with_arg::("b") + .with_return::(), + ], + ); + + let func = manual.with_overload(|a: u32, b: u32| a + b); + + let args = ArgList::default().push_owned(25_i32).push_owned(75_i32); + let result = func.call(args).unwrap().unwrap_owned(); + assert_eq!(result.try_take::().unwrap(), 100); + + let args = ArgList::default().push_owned(25_u32).push_owned(75_u32); + let result = func.call(args).unwrap().unwrap_owned(); + assert_eq!(result.try_take::().unwrap(), 100); + } + + #[test] + fn should_return_error_on_unknown_overload() { + fn add>(a: T, b: T) -> T { + a + b + } + + let func = add::.into_function().with_overload(add::); + + let args = ArgList::default().push_owned(25_u32).push_owned(75_u32); + let result = func.call(args); + assert_eq!( + result.unwrap_err(), + FunctionError::NoOverload { + expected: HashSet::from([ + ArgumentSignature::from_iter(vec![Type::of::(), Type::of::()]), + ArgumentSignature::from_iter(vec![Type::of::(), Type::of::()]) + ]), + received: ArgumentSignature::from_iter(vec![Type::of::(), Type::of::()]), + } + ); + } + + #[test] + fn should_debug_dynamic_function() { + fn greet(name: &String) -> String { + format!("Hello, {}!", name) + } + + let function = greet.into_function(); + let debug = format!("{:?}", function); + assert_eq!(debug, "DynamicFunction(fn bevy_reflect::func::dynamic_function::tests::should_debug_dynamic_function::greet(_: &alloc::string::String) -> alloc::string::String)"); + } + + #[test] + fn should_debug_anonymous_dynamic_function() { + let function = (|a: i32, b: i32| a + b).into_function(); + let debug = format!("{:?}", function); + assert_eq!(debug, "DynamicFunction(fn _(_: i32, _: i32) -> i32)"); + } + + #[test] + fn should_debug_overloaded_dynamic_function() { + fn add>(a: T, b: T) -> T { + a + b + } + + let func = add:: + .into_function() + .with_overload(add::) + .with_name("add"); + let debug = format!("{:?}", func); + assert_eq!( + debug, + "DynamicFunction(fn add{(_: i32, _: i32) -> i32, (_: f32, _: f32) -> f32})" + ); + } } diff --git a/crates/bevy_reflect/src/func/dynamic_function_internal.rs b/crates/bevy_reflect/src/func/dynamic_function_internal.rs new file mode 100644 index 0000000000000..33907ad799ccd --- /dev/null +++ b/crates/bevy_reflect/src/func/dynamic_function_internal.rs @@ -0,0 +1,378 @@ +use crate::func::args::ArgCount; +use crate::func::signature::{ArgListSignature, ArgumentSignature}; +use crate::func::{ArgList, FunctionError, FunctionInfo, FunctionOverloadError}; +use alloc::borrow::Cow; +use bevy_utils::hashbrown::HashMap; +use core::fmt::{Debug, Formatter}; + +/// An internal structure for storing a function and its corresponding [function information]. +/// +/// This is used to facilitate the sharing of functionality between [`DynamicFunction`] +/// and [`DynamicFunctionMut`]. +/// +/// [function information]: FunctionInfo +/// [`DynamicFunction`]: crate::func::DynamicFunction +/// [`DynamicFunctionMut`]: crate::func::DynamicFunctionMut +#[derive(Clone)] +pub(super) struct DynamicFunctionInternal { + functions: Vec, + info: FunctionInfo, + arg_map: HashMap, +} + +impl DynamicFunctionInternal { + /// Create a new instance of [`DynamicFunctionInternal`] with the given function + /// and its corresponding information. + pub fn new(func: F, info: FunctionInfo) -> Self { + let arg_map = info + .signatures() + .iter() + .map(|sig| (ArgumentSignature::from(sig), 0)) + .collect(); + + Self { + functions: vec![func], + info, + arg_map, + } + } + pub fn with_name(mut self, name: impl Into>) -> Self { + self.info = self.info.with_name(Some(name.into())); + self + } + + /// The name of the function. + pub fn name(&self) -> Option<&Cow<'static, str>> { + self.info.name() + } + + /// Returns `true` if the function is overloaded. + pub fn is_overloaded(&self) -> bool { + self.info.is_overloaded() + } + + /// Get an immutable reference to the function. + /// + /// If the function is not overloaded, it will always be returned regardless of the arguments. + /// Otherwise, the function will be selected based on the arguments provided. + /// + /// If no overload matches the provided arguments, returns [`FunctionError::NoOverload`]. + pub fn get(&self, args: &ArgList) -> Result<&F, FunctionError> { + if !self.info.is_overloaded() { + return Ok(&self.functions[0]); + } + + let signature = ArgListSignature::from(args); + self.arg_map + .get(&signature) + .map(|index| &self.functions[*index]) + .ok_or_else(|| FunctionError::NoOverload { + expected: self.arg_map.keys().cloned().collect(), + received: ArgumentSignature::from(args), + }) + } + + /// Get a mutable reference to the function. + /// + /// If the function is not overloaded, it will always be returned regardless of the arguments. + /// Otherwise, the function will be selected based on the arguments provided. + /// + /// If no overload matches the provided arguments, returns [`FunctionError::NoOverload`]. + pub fn get_mut(&mut self, args: &ArgList) -> Result<&mut F, FunctionError> { + if !self.info.is_overloaded() { + return Ok(&mut self.functions[0]); + } + + let signature = ArgListSignature::from(args); + self.arg_map + .get(&signature) + .map(|index| &mut self.functions[*index]) + .ok_or_else(|| FunctionError::NoOverload { + expected: self.arg_map.keys().cloned().collect(), + received: ArgumentSignature::from(args), + }) + } + + /// Returns the function information contained in the map. + #[inline] + pub fn info(&self) -> &FunctionInfo { + &self.info + } + + /// Returns the number of arguments the function expects. + /// + /// For overloaded functions that can have a variable number of arguments, + /// this will contain the full set of counts for all signatures. + pub fn arg_count(&self) -> ArgCount { + self.info.arg_count() + } + + /// Helper method for validating that a given set of arguments are _potentially_ valid for this function. + /// + /// Currently, this validates: + /// - The number of arguments is within the expected range + pub fn validate_args(&self, args: &ArgList) -> Result<(), FunctionError> { + let expected_arg_count = self.arg_count(); + let received_arg_count = args.len(); + + if !expected_arg_count.contains(received_arg_count) { + Err(FunctionError::ArgCountMismatch { + expected: expected_arg_count, + received: received_arg_count, + }) + } else { + Ok(()) + } + } + + /// Merge another [`DynamicFunctionInternal`] into this one. + /// + /// If `other` contains any functions with the same signature as this one, + /// an error will be returned along with the original, unchanged instance. + /// + /// Therefore, this method should always return an overloaded function if the merge is successful. + /// + /// Additionally, if the merge succeeds, it should be guaranteed that the order + /// of the functions in the map will be preserved. + /// For example, merging `[func_a, func_b]` (self) with `[func_c, func_d]` (other) should result in + /// `[func_a, func_b, func_c, func_d]`. + /// And merging `[func_c, func_d]` (self) with `[func_a, func_b]` (other) should result in + /// `[func_c, func_d, func_a, func_b]`. + pub fn merge(&mut self, mut other: Self) -> Result<(), FunctionOverloadError> { + // Keep a separate map of the new indices to avoid mutating the existing one + // until we can be sure the merge will be successful. + let mut new_signatures = HashMap::new(); + + for (sig, index) in other.arg_map { + if self.arg_map.contains_key(&sig) { + return Err(FunctionOverloadError::DuplicateSignature(sig)); + } + + new_signatures.insert_unique_unchecked(sig, self.functions.len() + index); + } + + self.arg_map.reserve(new_signatures.len()); + for (sig, index) in new_signatures { + self.arg_map.insert_unique_unchecked(sig, index); + } + + self.functions.append(&mut other.functions); + self.info.extend_unchecked(other.info); + + Ok(()) + } + + /// Maps the internally stored function(s) from type `F` to type `G`. + pub fn map_functions(self, f: fn(F) -> G) -> DynamicFunctionInternal { + DynamicFunctionInternal { + functions: self.functions.into_iter().map(f).collect(), + info: self.info, + arg_map: self.arg_map, + } + } +} + +impl Debug for DynamicFunctionInternal { + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + self.info + .pretty_printer() + .include_fn_token() + .include_name() + .fmt(f) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::func::{FunctionInfo, SignatureInfo}; + use crate::Type; + + #[test] + fn should_merge_single_into_single() { + let mut func_a = DynamicFunctionInternal::new( + 'a', + FunctionInfo::new(SignatureInfo::anonymous().with_arg::("arg0")), + ); + + let func_b = DynamicFunctionInternal::new( + 'b', + FunctionInfo::new(SignatureInfo::anonymous().with_arg::("arg0")), + ); + + func_a.merge(func_b).unwrap(); + + assert_eq!(func_a.functions, vec!['a', 'b']); + assert_eq!(func_a.info.signatures().len(), 2); + assert_eq!( + func_a.arg_map, + HashMap::from_iter([ + (ArgumentSignature::from_iter([Type::of::()]), 0), + (ArgumentSignature::from_iter([Type::of::()]), 1), + ]) + ); + } + + #[test] + fn should_merge_single_into_overloaded() { + let mut func_a = DynamicFunctionInternal::new( + 'a', + FunctionInfo::new(SignatureInfo::anonymous().with_arg::("arg0")), + ); + + let func_b = DynamicFunctionInternal { + functions: vec!['b', 'c'], + info: FunctionInfo::new(SignatureInfo::anonymous().with_arg::("arg0")) + .with_overload(SignatureInfo::anonymous().with_arg::("arg0")) + .unwrap(), + arg_map: HashMap::from_iter([ + (ArgumentSignature::from_iter([Type::of::()]), 0), + (ArgumentSignature::from_iter([Type::of::()]), 1), + ]), + }; + + func_a.merge(func_b).unwrap(); + + assert_eq!(func_a.functions, vec!['a', 'b', 'c']); + assert_eq!(func_a.info.signatures().len(), 3); + assert_eq!( + func_a.arg_map, + HashMap::from_iter([ + (ArgumentSignature::from_iter([Type::of::()]), 0), + (ArgumentSignature::from_iter([Type::of::()]), 1), + (ArgumentSignature::from_iter([Type::of::()]), 2), + ]) + ); + } + + #[test] + fn should_merge_overloaed_into_single() { + let mut func_a = DynamicFunctionInternal { + functions: vec!['a', 'b'], + info: FunctionInfo::new(SignatureInfo::anonymous().with_arg::("arg0")) + .with_overload(SignatureInfo::anonymous().with_arg::("arg0")) + .unwrap(), + arg_map: HashMap::from_iter([ + (ArgumentSignature::from_iter([Type::of::()]), 0), + (ArgumentSignature::from_iter([Type::of::()]), 1), + ]), + }; + + let func_b = DynamicFunctionInternal::new( + 'c', + FunctionInfo::new(SignatureInfo::anonymous().with_arg::("arg0")), + ); + + func_a.merge(func_b).unwrap(); + + assert_eq!(func_a.functions, vec!['a', 'b', 'c']); + assert_eq!(func_a.info.signatures().len(), 3); + assert_eq!( + func_a.arg_map, + HashMap::from_iter([ + (ArgumentSignature::from_iter([Type::of::()]), 0), + (ArgumentSignature::from_iter([Type::of::()]), 1), + (ArgumentSignature::from_iter([Type::of::()]), 2), + ]) + ); + } + + #[test] + fn should_merge_overloaded_into_overloaded() { + let mut func_a = DynamicFunctionInternal { + functions: vec!['a', 'b'], + info: FunctionInfo::new(SignatureInfo::anonymous().with_arg::("arg0")) + .with_overload(SignatureInfo::anonymous().with_arg::("arg0")) + .unwrap(), + arg_map: HashMap::from_iter([ + (ArgumentSignature::from_iter([Type::of::()]), 0), + (ArgumentSignature::from_iter([Type::of::()]), 1), + ]), + }; + + let func_b = DynamicFunctionInternal { + functions: vec!['c', 'd'], + info: FunctionInfo::new(SignatureInfo::anonymous().with_arg::("arg0")) + .with_overload(SignatureInfo::anonymous().with_arg::("arg0")) + .unwrap(), + arg_map: HashMap::from_iter([ + (ArgumentSignature::from_iter([Type::of::()]), 0), + (ArgumentSignature::from_iter([Type::of::()]), 1), + ]), + }; + + func_a.merge(func_b).unwrap(); + + assert_eq!(func_a.functions, vec!['a', 'b', 'c', 'd']); + assert_eq!(func_a.info.signatures().len(), 4); + assert_eq!( + func_a.arg_map, + HashMap::from_iter([ + (ArgumentSignature::from_iter([Type::of::()]), 0), + (ArgumentSignature::from_iter([Type::of::()]), 1), + (ArgumentSignature::from_iter([Type::of::()]), 2), + (ArgumentSignature::from_iter([Type::of::()]), 3), + ]) + ); + } + + #[test] + fn should_return_error_on_duplicate_signature() { + let mut func_a = DynamicFunctionInternal::new( + 'a', + FunctionInfo::new( + SignatureInfo::anonymous() + .with_arg::("arg0") + .with_arg::("arg1"), + ), + ); + + let func_b = DynamicFunctionInternal { + functions: vec!['b', 'c'], + info: FunctionInfo::new( + SignatureInfo::anonymous() + .with_arg::("arg0") + .with_arg::("arg1"), + ) + .with_overload( + SignatureInfo::anonymous() + .with_arg::("arg0") + .with_arg::("arg1"), + ) + .unwrap(), + arg_map: HashMap::from_iter([ + ( + ArgumentSignature::from_iter([Type::of::(), Type::of::()]), + 0, + ), + ( + ArgumentSignature::from_iter([Type::of::(), Type::of::()]), + 1, + ), + ]), + }; + + let FunctionOverloadError::DuplicateSignature(duplicate) = + func_a.merge(func_b).unwrap_err() + else { + panic!("Expected `FunctionOverloadError::DuplicateSignature`"); + }; + + assert_eq!( + duplicate, + ArgumentSignature::from_iter([Type::of::(), Type::of::()]) + ); + + // Assert the original remains unchanged: + assert!(!func_a.is_overloaded()); + assert_eq!(func_a.functions, vec!['a']); + assert_eq!(func_a.info.signatures().len(), 1); + assert_eq!( + func_a.arg_map, + HashMap::from_iter([( + ArgumentSignature::from_iter([Type::of::(), Type::of::()]), + 0 + ),]) + ); + } +} diff --git a/crates/bevy_reflect/src/func/dynamic_function_mut.rs b/crates/bevy_reflect/src/func/dynamic_function_mut.rs index 46341ee553124..bcbf6a96337ab 100644 --- a/crates/bevy_reflect/src/func/dynamic_function_mut.rs +++ b/crates/bevy_reflect/src/func/dynamic_function_mut.rs @@ -1,14 +1,18 @@ -use alloc::{borrow::Cow, boxed::Box}; +use alloc::{borrow::Cow, boxed::Box, sync::Arc}; use core::fmt::{Debug, Formatter}; #[cfg(not(feature = "std"))] use alloc::{boxed::Box, format, vec}; use crate::func::{ - args::ArgList, info::FunctionInfo, DynamicFunction, FunctionError, FunctionResult, - IntoFunctionMut, + args::{ArgCount, ArgList}, + dynamic_function_internal::DynamicFunctionInternal, + DynamicFunction, FunctionInfo, FunctionOverloadError, FunctionResult, IntoFunctionMut, }; +/// A [`Box`] containing a callback to a reflected function. +type BoxFnMut<'env> = Box FnMut(ArgList<'a>) -> FunctionResult<'a> + 'env>; + /// A dynamic representation of a function. /// /// This type can be used to represent any callable that satisfies [`FnMut`] @@ -66,8 +70,7 @@ use crate::func::{ /// [`ReflectFnMut`]: crate::func::ReflectFnMut /// [module-level documentation]: crate::func pub struct DynamicFunctionMut<'env> { - info: FunctionInfo, - func: Box FnMut(ArgList<'a>) -> FunctionResult<'a> + 'env>, + internal: DynamicFunctionInternal>, } impl<'env> DynamicFunctionMut<'env> { @@ -76,17 +79,26 @@ impl<'env> DynamicFunctionMut<'env> { /// The given function can be used to call out to any other callable, /// including functions, closures, or methods. /// - /// It's important that the function signature matches the provided [`FunctionInfo`] + /// It's important that the function signature matches the provided [`FunctionInfo`]. /// as this will be used to validate arguments when [calling] the function. + /// This is also required in order for [function overloading] to work correctly. + /// + /// # Panics + /// + /// This function may panic for any of the following reasons: + /// - No [`SignatureInfo`] is provided. + /// - A provided [`SignatureInfo`] has more arguments than [`ArgCount::MAX_COUNT`]. + /// - The conversion to [`FunctionInfo`] fails. /// - /// [calling]: DynamicFunctionMut::call + /// [calling]: crate::func::dynamic_function_mut::DynamicFunctionMut::call + /// [`SignatureInfo`]: crate::func::SignatureInfo + /// [function overloading]: Self::with_overload pub fn new FnMut(ArgList<'a>) -> FunctionResult<'a> + 'env>( func: F, - info: FunctionInfo, + info: impl TryInto, ) -> Self { Self { - info, - func: Box::new(func), + internal: DynamicFunctionInternal::new(Box::new(func), info.try_into().unwrap()), } } @@ -99,10 +111,99 @@ impl<'env> DynamicFunctionMut<'env> { /// /// [`DynamicFunctionMuts`]: DynamicFunctionMut pub fn with_name(mut self, name: impl Into>) -> Self { - self.info = self.info.with_name(name); + self.internal = self.internal.with_name(name); self } + /// Add an overload to this function. + /// + /// Overloads allow a single [`DynamicFunctionMut`] to represent multiple functions of different signatures. + /// + /// This can be used to handle multiple monomorphizations of a generic function + /// or to allow functions with a variable number of arguments. + /// + /// Any functions with the same [argument signature] will be overwritten by the one from the new function, `F`. + /// For example, if the existing function had the signature `(i32, i32) -> i32`, + /// and the new function, `F`, also had the signature `(i32, i32) -> i32`, + /// the one from `F` would replace the one from the existing function. + /// + /// Overloaded functions retain the [name] of the original function. + /// + /// Note that it may be impossible to overload closures that mutably borrow from their environment + /// due to Rust's borrowing rules. + /// However, it's still possible to overload functions that do not capture their environment mutably, + /// or those that maintain mutually exclusive mutable references to their environment. + /// + /// # Panics + /// + /// Panics if the function, `F`, contains a signature already found in this function. + /// + /// For a non-panicking version, see [`try_with_overload`]. + /// + /// # Example + /// + /// ``` + /// # use bevy_reflect::func::IntoFunctionMut; + /// let mut total_i32 = 0; + /// let mut add_i32 = |a: i32| total_i32 += a; + /// + /// let mut total_f32 = 0.0; + /// let mut add_f32 = |a: f32| total_f32 += a; + /// + /// // Currently, the only generic type `func` supports is `i32`. + /// let mut func = add_i32.into_function_mut(); + /// + /// // However, we can add an overload to handle `f32` as well: + /// func = func.with_overload(add_f32); + /// + /// // Test `i32`: + /// let args = bevy_reflect::func::ArgList::new().push_owned(123_i32); + /// func.call(args).unwrap(); + /// + /// // Test `f32`: + /// let args = bevy_reflect::func::ArgList::new().push_owned(1.23_f32); + /// func.call(args).unwrap(); + /// + /// drop(func); + /// assert_eq!(total_i32, 123); + /// assert_eq!(total_f32, 1.23); + /// ``` + /// + /// [argument signature]: crate::func::signature::ArgumentSignature + /// [name]: Self::name + /// [`try_with_overload`]: Self::try_with_overload + pub fn with_overload<'a, F: IntoFunctionMut<'a, Marker>, Marker>( + self, + function: F, + ) -> DynamicFunctionMut<'a> + where + 'env: 'a, + { + self.try_with_overload(function).unwrap_or_else(|(_, err)| { + panic!("{}", err); + }) + } + + /// Attempt to add an overload to this function. + /// + /// If the function, `F`, contains a signature already found in this function, + /// an error will be returned along with the original function. + /// + /// For a panicking version, see [`with_overload`]. + /// + /// [`with_overload`]: Self::with_overload + pub fn try_with_overload, Marker>( + mut self, + function: F, + ) -> Result, FunctionOverloadError)> { + let function = function.into_function_mut(); + + match self.internal.merge(function.internal) { + Ok(_) => Ok(self), + Err(err) => Err((Box::new(self), err)), + } + } + /// Call the function with the given arguments. /// /// Variables that are captured mutably by this function @@ -135,17 +236,9 @@ impl<'env> DynamicFunctionMut<'env> { /// /// [`call_once`]: DynamicFunctionMut::call_once pub fn call<'a>(&mut self, args: ArgList<'a>) -> FunctionResult<'a> { - let expected_arg_count = self.info.arg_count(); - let received_arg_count = args.len(); - - if expected_arg_count != received_arg_count { - Err(FunctionError::ArgCountMismatch { - expected: expected_arg_count, - received: received_arg_count, - }) - } else { - (self.func)(args) - } + self.internal.validate_args(&args)?; + let func = self.internal.get_mut(&args)?; + func(args) } /// Call the function with the given arguments and consume it. @@ -177,25 +270,15 @@ impl<'env> DynamicFunctionMut<'env> { /// /// The function itself may also return any errors it needs to. pub fn call_once(mut self, args: ArgList) -> FunctionResult { - let expected_arg_count = self.info.arg_count(); - let received_arg_count = args.len(); - - if expected_arg_count != received_arg_count { - Err(FunctionError::ArgCountMismatch { - expected: expected_arg_count, - received: received_arg_count, - }) - } else { - (self.func)(args) - } + self.call(args) } /// Returns the function info. pub fn info(&self) -> &FunctionInfo { - &self.info + self.internal.info() } - /// The [name] of the function. + /// The name of the function. /// /// For [`DynamicFunctionMuts`] created using [`IntoFunctionMut`], /// the default name will always be the full path to the function as returned by [`core::any::type_name`], @@ -204,11 +287,52 @@ impl<'env> DynamicFunctionMut<'env> { /// /// This can be overridden using [`with_name`]. /// - /// [name]: FunctionInfo::name /// [`DynamicFunctionMuts`]: DynamicFunctionMut /// [`with_name`]: Self::with_name pub fn name(&self) -> Option<&Cow<'static, str>> { - self.info.name() + self.internal.name() + } + + /// Returns `true` if the function is [overloaded]. + /// + /// # Example + /// + /// ``` + /// # use bevy_reflect::func::IntoFunctionMut; + /// let mut total_i32 = 0; + /// let increment = (|value: i32| total_i32 += value).into_function_mut(); + /// assert!(!increment.is_overloaded()); + /// + /// let mut total_f32 = 0.0; + /// let increment = increment.with_overload(|value: f32| total_f32 += value); + /// assert!(increment.is_overloaded()); + /// ``` + /// + /// [overloaded]: Self::with_overload + pub fn is_overloaded(&self) -> bool { + self.internal.is_overloaded() + } + + /// Returns the number of arguments the function expects. + /// + /// For [overloaded] functions that can have a variable number of arguments, + /// this will contain the full set of counts for all signatures. + /// + /// # Example + /// + /// ``` + /// # use bevy_reflect::func::IntoFunctionMut; + /// let add = (|a: i32, b: i32| a + b).into_function_mut(); + /// assert!(add.arg_count().contains(2)); + /// + /// let add = add.with_overload(|a: f32, b: f32, c: f32| a + b + c); + /// assert!(add.arg_count().contains(2)); + /// assert!(add.arg_count().contains(3)); + /// ``` + /// + /// [overloaded]: Self::with_overload + pub fn arg_count(&self) -> ArgCount { + self.internal.arg_count() } } @@ -217,23 +341,14 @@ impl<'env> DynamicFunctionMut<'env> { /// This takes the format: `DynamicFunctionMut(fn {name}({arg1}: {type1}, {arg2}: {type2}, ...) -> {return_type})`. /// /// Names for arguments and the function itself are optional and will default to `_` if not provided. +/// +/// If the function is [overloaded], the output will include the signatures of all overloads as a set. +/// For example, `DynamicFunctionMut(fn add{(_: i32, _: i32) -> i32, (_: f32, _: f32) -> f32})`. +/// +/// [overloaded]: DynamicFunctionMut::with_overload impl<'env> Debug for DynamicFunctionMut<'env> { fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { - let name = self.info.name().unwrap_or(&Cow::Borrowed("_")); - write!(f, "DynamicFunctionMut(fn {name}(")?; - - for (index, arg) in self.info.args().iter().enumerate() { - let name = arg.name().unwrap_or("_"); - let ty = arg.type_path(); - write!(f, "{name}: {ty}")?; - - if index + 1 < self.info.args().len() { - write!(f, ", ")?; - } - } - - let ret = self.info.return_info().type_path(); - write!(f, ") -> {ret})") + write!(f, "DynamicFunctionMut({:?})", &self.internal) } } @@ -241,12 +356,20 @@ impl<'env> From> for DynamicFunctionMut<'env> { #[inline] fn from(function: DynamicFunction<'env>) -> Self { Self { - info: function.info, - func: Box::new(move |args| (function.func)(args)), + internal: function.internal.map_functions(arc_to_box), } } } +/// Helper function from converting an [`Arc`] function to a [`Box`] function. +/// +/// This is needed to help the compiler infer the correct types. +fn arc_to_box<'env>( + f: Arc Fn(ArgList<'a>) -> FunctionResult<'a> + Send + Sync + 'env>, +) -> BoxFnMut<'env> { + Box::new(move |args| f(args)) +} + impl<'env> IntoFunctionMut<'env, ()> for DynamicFunctionMut<'env> { #[inline] fn into_function_mut(self) -> DynamicFunctionMut<'env> { @@ -257,14 +380,17 @@ impl<'env> IntoFunctionMut<'env, ()> for DynamicFunctionMut<'env> { #[cfg(test)] mod tests { use super::*; + use crate::func::{FunctionError, IntoReturn, SignatureInfo}; + use core::ops::Add; #[test] fn should_overwrite_function_name() { let mut total = 0; - let func = (|a: i32, b: i32| total = a + b) - .into_function_mut() - .with_name("my_function"); - assert_eq!(func.info().name().unwrap(), "my_function"); + let func = (|a: i32, b: i32| total = a + b).into_function_mut(); + assert!(func.name().is_none()); + + let func = func.with_name("my_function"); + assert_eq!(func.name().unwrap(), "my_function"); } #[test] @@ -285,22 +411,79 @@ mod tests { let args = ArgList::default().push_owned(25_i32); let error = func.call(args).unwrap_err(); - assert!(matches!( + assert_eq!( error, FunctionError::ArgCountMismatch { - expected: 2, + expected: ArgCount::new(2).unwrap(), received: 1 } - )); + ); let args = ArgList::default().push_owned(25_i32); let error = func.call_once(args).unwrap_err(); - assert!(matches!( + assert_eq!( error, FunctionError::ArgCountMismatch { - expected: 2, + expected: ArgCount::new(2).unwrap(), received: 1 } - )); + ); + } + + #[test] + fn should_allow_creating_manual_generic_dynamic_function_mut() { + let mut total = 0_i32; + let func = DynamicFunctionMut::new( + |mut args| { + let value = args.take_arg()?; + + if value.is::() { + let value = value.take::()?; + total += value; + } else { + let value = value.take::()?; + total += value as i32; + } + + Ok(().into_return()) + }, + vec![ + SignatureInfo::named("add::").with_arg::("value"), + SignatureInfo::named("add::").with_arg::("value"), + ], + ); + + assert_eq!(func.name().unwrap(), "add::"); + let mut func = func.with_name("add"); + assert_eq!(func.name().unwrap(), "add"); + + let args = ArgList::default().push_owned(25_i32); + func.call(args).unwrap(); + let args = ArgList::default().push_owned(75_i16); + func.call(args).unwrap(); + + drop(func); + assert_eq!(total, 100); + } + + // Closures that mutably borrow from their environment cannot realistically + // be overloaded since that would break Rust's borrowing rules. + // However, we still need to verify overloaded functions work since a + // `DynamicFunctionMut` can also be made from a non-mutably borrowing closure/function. + #[test] + fn should_allow_function_overloading() { + fn add>(a: T, b: T) -> T { + a + b + } + + let mut func = add::.into_function_mut().with_overload(add::); + + let args = ArgList::default().push_owned(25_i32).push_owned(75_i32); + let result = func.call(args).unwrap().unwrap_owned(); + assert_eq!(result.try_take::().unwrap(), 100); + + let args = ArgList::default().push_owned(25.0_f32).push_owned(75.0_f32); + let result = func.call(args).unwrap().unwrap_owned(); + assert_eq!(result.try_take::().unwrap(), 100.0); } } diff --git a/crates/bevy_reflect/src/func/error.rs b/crates/bevy_reflect/src/func/error.rs index 234cbb9e6a779..ad6796be19f1e 100644 --- a/crates/bevy_reflect/src/func/error.rs +++ b/crates/bevy_reflect/src/func/error.rs @@ -1,5 +1,10 @@ -use crate::func::{args::ArgError, Return}; +use crate::func::signature::ArgumentSignature; +use crate::func::{ + args::{ArgCount, ArgError}, + Return, +}; use alloc::borrow::Cow; +use bevy_utils::HashSet; use thiserror::Error; #[cfg(not(feature = "std"))] @@ -15,8 +20,14 @@ pub enum FunctionError { #[error(transparent)] ArgError(#[from] ArgError), /// The number of arguments provided does not match the expected number. - #[error("expected {expected} arguments but received {received}")] - ArgCountMismatch { expected: usize, received: usize }, + #[error("received {received} arguments but expected one of {expected:?}")] + ArgCountMismatch { expected: ArgCount, received: usize }, + /// No overload was found for the given set of arguments. + #[error("no overload found for arguments with signature `{received:?}`, expected one of `{expected:?}`")] + NoOverload { + expected: HashSet, + received: ArgumentSignature, + }, } /// The result of calling a [`DynamicFunction`] or [`DynamicFunctionMut`]. @@ -28,6 +39,25 @@ pub enum FunctionError { /// [`DynamicFunctionMut`]: crate::func::DynamicFunctionMut pub type FunctionResult<'a> = Result, FunctionError>; +/// An error that occurs when attempting to add a function overload. +#[derive(Debug, Error, PartialEq)] +pub enum FunctionOverloadError { + /// A [`SignatureInfo`] was expected, but none was found. + /// + /// [`SignatureInfo`]: crate::func::info::SignatureInfo + #[error("expected at least one `SignatureInfo` but found none")] + MissingSignature, + /// An error that occurs when attempting to add a function overload with a duplicate signature. + #[error("could not add function overload: duplicate found for signature `{0:?}`")] + DuplicateSignature(ArgumentSignature), + #[error( + "argument signature `{:?}` has too many arguments (max {})", + 0, + ArgCount::MAX_COUNT + )] + TooManyArguments(ArgumentSignature), +} + /// An error that occurs when registering a function into a [`FunctionRegistry`]. /// /// [`FunctionRegistry`]: crate::func::FunctionRegistry diff --git a/crates/bevy_reflect/src/func/function.rs b/crates/bevy_reflect/src/func/function.rs index face9f6466601..0d8e94ca95aff 100644 --- a/crates/bevy_reflect/src/func/function.rs +++ b/crates/bevy_reflect/src/func/function.rs @@ -1,5 +1,8 @@ use crate::{ - func::{ArgList, DynamicFunction, FunctionInfo, FunctionResult}, + func::{ + args::{ArgCount, ArgList}, + DynamicFunction, FunctionInfo, FunctionResult, + }, PartialReflect, }; use alloc::borrow::Cow; @@ -45,12 +48,15 @@ pub trait Function: PartialReflect + Debug { /// /// [`DynamicFunctions`]: crate::func::DynamicFunction /// [`IntoFunction`]: crate::func::IntoFunction - fn name(&self) -> Option<&Cow<'static, str>> { - self.info().name() - } + fn name(&self) -> Option<&Cow<'static, str>>; - /// The number of arguments this function accepts. - fn arg_count(&self) -> usize { + /// Returns the number of arguments the function expects. + /// + /// For [overloaded] functions that can have a variable number of arguments, + /// this will contain the full set of counts for all signatures. + /// + /// [overloaded]: crate::func#overloading-functions + fn arg_count(&self) -> ArgCount { self.info().arg_count() } diff --git a/crates/bevy_reflect/src/func/info.rs b/crates/bevy_reflect/src/func/info.rs index d8325c582d8a4..797b97e7e1bac 100644 --- a/crates/bevy_reflect/src/func/info.rs +++ b/crates/bevy_reflect/src/func/info.rs @@ -1,62 +1,274 @@ use alloc::{borrow::Cow, vec}; +use core::fmt::{Debug, Formatter}; #[cfg(not(feature = "std"))] use alloc::{boxed::Box, format, vec}; -use variadics_please::all_tuples; - use crate::{ - func::args::{ArgInfo, GetOwnership, Ownership}, + func::args::{ArgCount, ArgCountOutOfBoundsError, ArgInfo, GetOwnership, Ownership}, + func::signature::ArgumentSignature, + func::FunctionOverloadError, type_info::impl_type_methods, Type, TypePath, }; +use variadics_please::all_tuples; + /// Type information for a [`DynamicFunction`] or [`DynamicFunctionMut`]. /// /// This information can be retrieved directly from certain functions and closures /// using the [`TypedFunction`] trait, and manually constructed otherwise. /// +/// It is compromised of one or more [`SignatureInfo`] structs, +/// allowing it to represent functions with multiple sets of arguments (i.e. "overloaded functions"). +/// /// [`DynamicFunction`]: crate::func::DynamicFunction /// [`DynamicFunctionMut`]: crate::func::DynamicFunctionMut #[derive(Debug, Clone)] pub struct FunctionInfo { name: Option>, - args: Vec, - return_info: ReturnInfo, + arg_count: ArgCount, + signatures: Box<[SignatureInfo]>, } impl FunctionInfo { - /// Create a new [`FunctionInfo`] for a function with the given name. + /// Create a new [`FunctionInfo`] for a function with the given signature. + /// + /// # Panics + /// + /// Panics if the given signature has more than the maximum number of arguments + /// as specified by [`ArgCount::MAX_COUNT`]. + pub fn new(signature: SignatureInfo) -> Self { + Self { + name: signature.name.clone(), + arg_count: ArgCount::new(signature.arg_count()).unwrap(), + signatures: vec![signature].into(), + } + } + + /// Create a new [`FunctionInfo`] from a set of signatures. + /// + /// Returns an error if the given iterator is empty or contains duplicate signatures. + pub fn try_from_iter( + signatures: impl IntoIterator, + ) -> Result { + let mut iter = signatures.into_iter(); + + let base = iter.next().ok_or(FunctionOverloadError::MissingSignature)?; + + if base.arg_count() > ArgCount::MAX_COUNT { + return Err(FunctionOverloadError::TooManyArguments( + ArgumentSignature::from(&base), + )); + } + + let mut info = Self::new(base); + + for signature in iter { + if signature.arg_count() > ArgCount::MAX_COUNT { + return Err(FunctionOverloadError::TooManyArguments( + ArgumentSignature::from(&signature), + )); + } + + info = info.with_overload(signature).map_err(|sig| { + FunctionOverloadError::DuplicateSignature(ArgumentSignature::from(&sig)) + })?; + } + + Ok(info) + } + + /// The base signature for this function. + /// + /// All functions—including overloaded functions—are guaranteed to have at least one signature. + /// The first signature used to define the [`FunctionInfo`] is considered the base signature. + pub fn base(&self) -> &SignatureInfo { + &self.signatures[0] + } + + /// Whether this function is overloaded. + /// + /// This is determined by the existence of multiple signatures. + pub fn is_overloaded(&self) -> bool { + self.signatures.len() > 1 + } + + /// Set the name of the function. + pub fn with_name(mut self, name: Option>>) -> Self { + self.name = name.map(Into::into); + self + } + + /// The name of the function. + /// + /// For [`DynamicFunctions`] created using [`IntoFunction`] or [`DynamicFunctionMuts`] created using [`IntoFunctionMut`], + /// the default name will always be the full path to the function as returned by [`std::any::type_name`], + /// unless the function is a closure, anonymous function, or function pointer, + /// in which case the name will be `None`. + /// + /// For overloaded functions, this will be the name of the base signature, + /// unless manually overwritten using [`Self::with_name`]. + /// + /// [`DynamicFunctions`]: crate::func::DynamicFunction + /// [`IntoFunction`]: crate::func::IntoFunction + /// [`DynamicFunctionMuts`]: crate::func::DynamicFunctionMut + /// [`IntoFunctionMut`]: crate::func::IntoFunctionMut + pub fn name(&self) -> Option<&Cow<'static, str>> { + self.name.as_ref() + } + + /// Add a signature to this function. + /// + /// If a signature with the same [`ArgumentSignature`] already exists, + /// an error is returned with the given signature. + /// + /// # Panics + /// + /// Panics if the given signature has more than the maximum number of arguments + /// as specified by [`ArgCount::MAX_COUNT`]. + pub fn with_overload(mut self, signature: SignatureInfo) -> Result { + let is_duplicate = self.signatures.iter().any(|s| { + s.arg_count() == signature.arg_count() + && ArgumentSignature::from(s) == ArgumentSignature::from(&signature) + }); + + if is_duplicate { + return Err(signature); + } + + self.arg_count.add(signature.arg_count()); + self.signatures = IntoIterator::into_iter(self.signatures) + .chain(Some(signature)) + .collect(); + Ok(self) + } + + /// Returns the number of arguments the function expects. + /// + /// For [overloaded] functions that can have a variable number of arguments, + /// this will contain the full set of counts for all signatures. + /// + /// [overloaded]: crate::func#overloading-functions + pub fn arg_count(&self) -> ArgCount { + self.arg_count + } + + /// The signatures of the function. + /// + /// This is guaranteed to always contain at least one signature. + /// Overloaded functions will contain two or more. + pub fn signatures(&self) -> &[SignatureInfo] { + &self.signatures + } + + /// Returns a wrapper around this info that implements [`Debug`] for pretty-printing the function. + /// + /// This can be useful for more readable debugging and logging. + /// + /// # Example + /// + /// ``` + /// # use bevy_reflect::func::{FunctionInfo, TypedFunction}; + /// # + /// fn add(a: i32, b: i32) -> i32 { + /// a + b + /// } + /// + /// let info = add.get_function_info(); + /// + /// let pretty = info.pretty_printer(); + /// assert_eq!(format!("{:?}", pretty), "(_: i32, _: i32) -> i32"); + /// ``` + pub fn pretty_printer(&self) -> PrettyPrintFunctionInfo { + PrettyPrintFunctionInfo::new(self) + } + + /// Extend this [`FunctionInfo`] with another without checking for duplicates. + /// + /// # Panics + /// + /// Panics if the given signature has more than the maximum number of arguments + /// as specified by [`ArgCount::MAX_COUNT`]. + pub(super) fn extend_unchecked(&mut self, other: FunctionInfo) { + if self.name.is_none() { + self.name = other.name; + } + + let signatures = core::mem::take(&mut self.signatures); + self.signatures = IntoIterator::into_iter(signatures) + .chain(IntoIterator::into_iter(other.signatures)) + .collect(); + self.arg_count = self + .signatures + .iter() + .fold(ArgCount::default(), |mut count, sig| { + count.add(sig.arg_count()); + count + }); + } +} + +impl TryFrom for FunctionInfo { + type Error = ArgCountOutOfBoundsError; + + fn try_from(signature: SignatureInfo) -> Result { + let count = signature.arg_count(); + if count > ArgCount::MAX_COUNT { + return Err(ArgCountOutOfBoundsError(count)); + } + + Ok(Self::new(signature)) + } +} + +impl TryFrom> for FunctionInfo { + type Error = FunctionOverloadError; + + fn try_from(signatures: Vec) -> Result { + Self::try_from_iter(signatures) + } +} + +impl TryFrom<[SignatureInfo; N]> for FunctionInfo { + type Error = FunctionOverloadError; + + fn try_from(signatures: [SignatureInfo; N]) -> Result { + Self::try_from_iter(signatures) + } +} + +#[derive(Debug, Clone)] +pub struct SignatureInfo { + name: Option>, + args: Box<[ArgInfo]>, + return_info: ReturnInfo, +} + +impl SignatureInfo { + /// Create a new [`SignatureInfo`] for a function with the given name. pub fn named(name: impl Into>) -> Self { Self { name: Some(name.into()), - args: Vec::new(), + args: Box::new([]), return_info: ReturnInfo::new::<()>(), } } - /// Create a new [`FunctionInfo`] with no name. + /// Create a new [`SignatureInfo`] with no name. /// /// For the purposes of debugging and [registration], - /// it's recommended to use [`FunctionInfo::named`] instead. + /// it's recommended to use [`Self::named`] instead. /// /// [registration]: crate::func::FunctionRegistry pub fn anonymous() -> Self { Self { name: None, - args: Vec::new(), + args: Box::new([]), return_info: ReturnInfo::new::<()>(), } } - /// Create a new [`FunctionInfo`] from the given function. - pub fn from(function: &F) -> Self - where - F: TypedFunction, - { - function.get_function_info() - } - /// Set the name of the function. pub fn with_name(mut self, name: impl Into>) -> Self { self.name = Some(name.into()); @@ -72,7 +284,9 @@ impl FunctionInfo { name: impl Into>, ) -> Self { let index = self.args.len(); - self.args.push(ArgInfo::new::(index).with_name(name)); + self.args = IntoIterator::into_iter(self.args) + .chain(Some(ArgInfo::new::(index).with_name(name))) + .collect(); self } @@ -83,7 +297,7 @@ impl FunctionInfo { /// It's preferable to use [`Self::with_arg`] to add arguments to the function /// as it will automatically set the index of the argument. pub fn with_args(mut self, args: Vec) -> Self { - self.args = args; + self.args = IntoIterator::into_iter(self.args).chain(args).collect(); self } @@ -167,6 +381,161 @@ impl ReturnInfo { } } +/// A wrapper around [`FunctionInfo`] that implements [`Debug`] for pretty-printing function information. +/// +/// # Example +/// +/// ``` +/// # use bevy_reflect::func::{FunctionInfo, PrettyPrintFunctionInfo, TypedFunction}; +/// # +/// fn add(a: i32, b: i32) -> i32 { +/// a + b +/// } +/// +/// let info = add.get_function_info(); +/// +/// let pretty = PrettyPrintFunctionInfo::new(&info); +/// assert_eq!(format!("{:?}", pretty), "(_: i32, _: i32) -> i32"); +/// ``` +pub struct PrettyPrintFunctionInfo<'a> { + info: &'a FunctionInfo, + include_fn_token: bool, + include_name: bool, +} + +impl<'a> PrettyPrintFunctionInfo<'a> { + /// Create a new pretty-printer for the given [`FunctionInfo`]. + pub fn new(info: &'a FunctionInfo) -> Self { + Self { + info, + include_fn_token: false, + include_name: false, + } + } + + /// Include the function name in the pretty-printed output. + pub fn include_name(mut self) -> Self { + self.include_name = true; + self + } + + /// Include the `fn` token in the pretty-printed output. + pub fn include_fn_token(mut self) -> Self { + self.include_fn_token = true; + self + } +} + +impl<'a> Debug for PrettyPrintFunctionInfo<'a> { + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + if self.include_fn_token { + write!(f, "fn")?; + + if self.include_name { + write!(f, " ")?; + } + } + + match (self.include_name, self.info.name()) { + (true, Some(name)) => write!(f, "{}", name)?, + (true, None) => write!(f, "_")?, + _ => {} + } + + if self.info.is_overloaded() { + // `{(arg0: i32, arg1: i32) -> (), (arg0: f32, arg1: f32) -> ()}` + let mut set = f.debug_set(); + for signature in self.info.signatures() { + set.entry(&PrettyPrintSignatureInfo::new(signature)); + } + set.finish() + } else { + // `(arg0: i32, arg1: i32) -> ()` + PrettyPrintSignatureInfo::new(self.info.base()).fmt(f) + } + } +} + +/// A wrapper around [`SignatureInfo`] that implements [`Debug`] for pretty-printing function signature information. +/// +/// # Example +/// +/// ``` +/// # use bevy_reflect::func::{FunctionInfo, PrettyPrintSignatureInfo, TypedFunction}; +/// # +/// fn add(a: i32, b: i32) -> i32 { +/// a + b +/// } +/// +/// let info = add.get_function_info(); +/// +/// let pretty = PrettyPrintSignatureInfo::new(info.base()); +/// assert_eq!(format!("{:?}", pretty), "(_: i32, _: i32) -> i32"); +/// ``` +pub struct PrettyPrintSignatureInfo<'a> { + info: &'a SignatureInfo, + include_fn_token: bool, + include_name: bool, +} + +impl<'a> PrettyPrintSignatureInfo<'a> { + /// Create a new pretty-printer for the given [`SignatureInfo`]. + pub fn new(info: &'a SignatureInfo) -> Self { + Self { + info, + include_fn_token: false, + include_name: false, + } + } + + /// Include the function name in the pretty-printed output. + pub fn include_name(mut self) -> Self { + self.include_name = true; + self + } + + /// Include the `fn` token in the pretty-printed output. + pub fn include_fn_token(mut self) -> Self { + self.include_fn_token = true; + self + } +} + +impl<'a> Debug for PrettyPrintSignatureInfo<'a> { + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + if self.include_fn_token { + write!(f, "fn")?; + + if self.include_name { + write!(f, " ")?; + } + } + + match (self.include_name, self.info.name()) { + (true, Some(name)) => write!(f, "{}", name)?, + (true, None) => write!(f, "_")?, + _ => {} + } + + write!(f, "(")?; + + // We manually write the args instead of using `DebugTuple` to avoid trailing commas + // and (when used with `{:#?}`) unnecessary newlines + for (index, arg) in self.info.args().iter().enumerate() { + if index > 0 { + write!(f, ", ")?; + } + + let name = arg.name().unwrap_or("_"); + let ty = arg.type_path(); + write!(f, "{name}: {ty}")?; + } + + let ret = self.info.return_info().type_path(); + write!(f, ") -> {ret}") + } +} + /// A static accessor to compile-time type information for functions. /// /// This is the equivalent of [`Typed`], but for function. @@ -194,7 +563,7 @@ impl ReturnInfo { /// # Example /// /// ``` -/// # use bevy_reflect::func::{ArgList, FunctionInfo, ReflectFnMut, TypedFunction}; +/// # use bevy_reflect::func::{ArgList, ReflectFnMut, TypedFunction}; /// # /// fn print(value: String) { /// println!("{}", value); @@ -202,9 +571,9 @@ impl ReturnInfo { /// /// let info = print.get_function_info(); /// assert!(info.name().unwrap().ends_with("print")); -/// assert_eq!(info.arg_count(), 1); -/// assert_eq!(info.args()[0].type_path(), "alloc::string::String"); -/// assert_eq!(info.return_info().type_path(), "()"); +/// assert!(info.arg_count().contains(1)); +/// assert_eq!(info.base().args()[0].type_path(), "alloc::string::String"); +/// assert_eq!(info.base().return_info().type_path(), "()"); /// ``` /// /// # Trait Parameters @@ -243,18 +612,20 @@ macro_rules! impl_typed_function { Function: FnMut($($Arg),*) -> ReturnType, { fn function_info() -> FunctionInfo { - create_info::() - .with_args({ - #[allow(unused_mut)] - let mut _index = 0; - vec![ - $(ArgInfo::new::<$Arg>({ - _index += 1; - _index - 1 - }),)* - ] - }) - .with_return_info(ReturnInfo::new::()) + FunctionInfo::new( + create_info::() + .with_args({ + #[allow(unused_mut)] + let mut _index = 0; + vec![ + $(ArgInfo::new::<$Arg>({ + _index += 1; + _index - 1 + }),)* + ] + }) + .with_return_info(ReturnInfo::new::()) + ) } } @@ -266,20 +637,22 @@ macro_rules! impl_typed_function { for<'a> &'a ReturnType: TypePath + GetOwnership, Function: for<'a> FnMut(&'a Receiver, $($Arg),*) -> &'a ReturnType, { - fn function_info() -> $crate::func::FunctionInfo { - create_info::() - .with_args({ - #[allow(unused_mut)] - let mut _index = 1; - vec![ - ArgInfo::new::<&Receiver>(0), - $($crate::func::args::ArgInfo::new::<$Arg>({ - _index += 1; - _index - 1 - }),)* - ] - }) - .with_return_info(ReturnInfo::new::<&ReturnType>()) + fn function_info() -> FunctionInfo { + FunctionInfo::new( + create_info::() + .with_args({ + #[allow(unused_mut)] + let mut _index = 1; + vec![ + ArgInfo::new::<&Receiver>(0), + $($crate::func::args::ArgInfo::new::<$Arg>({ + _index += 1; + _index - 1 + }),)* + ] + }) + .with_return_info(ReturnInfo::new::<&ReturnType>()) + ) } } @@ -292,19 +665,21 @@ macro_rules! impl_typed_function { Function: for<'a> FnMut(&'a mut Receiver, $($Arg),*) -> &'a mut ReturnType, { fn function_info() -> FunctionInfo { - create_info::() - .with_args({ - #[allow(unused_mut)] - let mut _index = 1; - vec![ - ArgInfo::new::<&mut Receiver>(0), - $(ArgInfo::new::<$Arg>({ - _index += 1; - _index - 1 - }),)* - ] - }) - .with_return_info(ReturnInfo::new::<&mut ReturnType>()) + FunctionInfo::new( + create_info::() + .with_args({ + #[allow(unused_mut)] + let mut _index = 1; + vec![ + ArgInfo::new::<&mut Receiver>(0), + $(ArgInfo::new::<$Arg>({ + _index += 1; + _index - 1 + }),)* + ] + }) + .with_return_info(ReturnInfo::new::<&mut ReturnType>()) + ) } } @@ -317,19 +692,21 @@ macro_rules! impl_typed_function { Function: for<'a> FnMut(&'a mut Receiver, $($Arg),*) -> &'a ReturnType, { fn function_info() -> FunctionInfo { - create_info::() - .with_args({ - #[allow(unused_mut)] - let mut _index = 1; - vec![ - ArgInfo::new::<&mut Receiver>(0), - $(ArgInfo::new::<$Arg>({ - _index += 1; - _index - 1 - }),)* - ] - }) - .with_return_info(ReturnInfo::new::<&ReturnType>()) + FunctionInfo::new( + create_info::() + .with_args({ + #[allow(unused_mut)] + let mut _index = 1; + vec![ + ArgInfo::new::<&mut Receiver>(0), + $(ArgInfo::new::<$Arg>({ + _index += 1; + _index - 1 + }),)* + ] + }) + .with_return_info(ReturnInfo::new::<&ReturnType>()) + ) } } }; @@ -355,13 +732,13 @@ all_tuples!(impl_typed_function, 0, 15, Arg, arg); /// | Function pointer | `fn() -> String` | `None` | /// /// [`type_name`]: core::any::type_name -fn create_info() -> FunctionInfo { +fn create_info() -> SignatureInfo { let name = core::any::type_name::(); if name.ends_with("{{closure}}") || name.starts_with("fn(") { - FunctionInfo::anonymous() + SignatureInfo::anonymous() } else { - FunctionInfo::named(name) + SignatureInfo::named(name) } } @@ -386,10 +763,10 @@ mod tests { info.name().unwrap(), "bevy_reflect::func::info::tests::should_create_function_info::add" ); - assert_eq!(info.arg_count(), 2); - assert_eq!(info.args()[0].type_path(), "i32"); - assert_eq!(info.args()[1].type_path(), "i32"); - assert_eq!(info.return_info().type_path(), "i32"); + assert_eq!(info.base().arg_count(), 2); + assert_eq!(info.base().args()[0].type_path(), "i32"); + assert_eq!(info.base().args()[1].type_path(), "i32"); + assert_eq!(info.base().return_info().type_path(), "i32"); } #[test] @@ -405,10 +782,10 @@ mod tests { let info = add.get_function_info(); assert!(info.name().is_none()); - assert_eq!(info.arg_count(), 2); - assert_eq!(info.args()[0].type_path(), "i32"); - assert_eq!(info.args()[1].type_path(), "i32"); - assert_eq!(info.return_info().type_path(), "i32"); + assert_eq!(info.base().arg_count(), 2); + assert_eq!(info.base().args()[0].type_path(), "i32"); + assert_eq!(info.base().args()[1].type_path(), "i32"); + assert_eq!(info.base().return_info().type_path(), "i32"); } #[test] @@ -423,10 +800,10 @@ mod tests { let info = add.get_function_info(); assert!(info.name().is_none()); - assert_eq!(info.arg_count(), 2); - assert_eq!(info.args()[0].type_path(), "i32"); - assert_eq!(info.args()[1].type_path(), "i32"); - assert_eq!(info.return_info().type_path(), "i32"); + assert_eq!(info.base().arg_count(), 2); + assert_eq!(info.base().args()[0].type_path(), "i32"); + assert_eq!(info.base().args()[1].type_path(), "i32"); + assert_eq!(info.base().return_info().type_path(), "i32"); } #[test] @@ -442,9 +819,30 @@ mod tests { let info = add.get_function_info(); assert!(info.name().is_none()); - assert_eq!(info.arg_count(), 2); - assert_eq!(info.args()[0].type_path(), "i32"); - assert_eq!(info.args()[1].type_path(), "i32"); - assert_eq!(info.return_info().type_path(), "()"); + assert_eq!(info.base().arg_count(), 2); + assert_eq!(info.base().args()[0].type_path(), "i32"); + assert_eq!(info.base().args()[1].type_path(), "i32"); + assert_eq!(info.base().return_info().type_path(), "()"); + } + + #[test] + fn should_pretty_print_info() { + // fn add(a: i32, b: i32) -> i32 { + // a + b + // } + // + // let info = add.get_function_info().with_name("add"); + // + // let pretty = info.pretty_printer(); + // assert_eq!(format!("{:?}", pretty), "(_: i32, _: i32) -> i32"); + // + // let pretty = info.pretty_printer().include_fn_token(); + // assert_eq!(format!("{:?}", pretty), "fn(_: i32, _: i32) -> i32"); + // + // let pretty = info.pretty_printer().include_name(); + // assert_eq!(format!("{:?}", pretty), "add(_: i32, _: i32) -> i32"); + // + // let pretty = info.pretty_printer().include_fn_token().include_name(); + // assert_eq!(format!("{:?}", pretty), "fn add(_: i32, _: i32) -> i32"); } } diff --git a/crates/bevy_reflect/src/func/into_function.rs b/crates/bevy_reflect/src/func/into_function.rs index 78a92b05f5953..e913045f8cc2a 100644 --- a/crates/bevy_reflect/src/func/into_function.rs +++ b/crates/bevy_reflect/src/func/into_function.rs @@ -66,6 +66,6 @@ mod tests { fn should_default_closure_name_to_none() { let c = 23; let func = (|a: i32, b: i32| a + b + c).into_function(); - assert_eq!(func.info().name(), None); + assert!(func.name().is_none()); } } diff --git a/crates/bevy_reflect/src/func/into_function_mut.rs b/crates/bevy_reflect/src/func/into_function_mut.rs index a33b840bfc971..8f7f1b0a6dd1a 100644 --- a/crates/bevy_reflect/src/func/into_function_mut.rs +++ b/crates/bevy_reflect/src/func/into_function_mut.rs @@ -81,6 +81,6 @@ mod tests { fn should_default_closure_name_to_none() { let mut total = 0; let func = (|a: i32, b: i32| total = a + b).into_function_mut(); - assert_eq!(func.info().name(), None); + assert!(func.name().is_none()); } } diff --git a/crates/bevy_reflect/src/func/mod.rs b/crates/bevy_reflect/src/func/mod.rs index 2811f5c22663a..f990e135f017d 100644 --- a/crates/bevy_reflect/src/func/mod.rs +++ b/crates/bevy_reflect/src/func/mod.rs @@ -94,6 +94,32 @@ //! For other functions that don't conform to one of the above signatures, //! [`DynamicFunction`] and [`DynamicFunctionMut`] can instead be created manually. //! +//! # Generic Functions +//! +//! In Rust, generic functions are [monomophized] by the compiler, +//! which means that a separate copy of the function is generated for each concrete set of type parameters. +//! +//! When converting a generic function to a [`DynamicFunction`] or [`DynamicFunctionMut`], +//! the function must be manually monomorphized with concrete types. +//! In other words, you cannot write `add.into_function()`. +//! Instead, you will need to write `add::.into_function()`. +//! +//! This means that reflected functions cannot be generic themselves. +//! To get around this limitation, you can consider [overloading] your function with multiple concrete types. +//! +//! # Overloading Functions +//! +//! Both [`DynamicFunction`] and [`DynamicFunctionMut`] support [function overloading]. +//! +//! Function overloading allows one function to handle multiple types of arguments. +//! This is useful for simulating generic functions by having an overload for each known concrete type. +//! Additionally, it can also simulate [variadic functions]: functions that can be called with a variable number of arguments. +//! +//! Internally, this works by storing multiple functions in a map, +//! where each function is associated with a specific argument signature. +//! +//! To learn more, see the docs on [`DynamicFunction::with_overload`]. +//! //! # Function Registration //! //! This module also provides a [`FunctionRegistry`] that can be used to register functions and closures @@ -127,6 +153,10 @@ //! [`Reflect`]: crate::Reflect //! [lack of variadic generics]: https://poignardazur.github.io/2024/05/25/report-on-rustnl-variadics/ //! [coherence issues]: https://doc.rust-lang.org/rustc/lints/listing/warn-by-default.html#coherence-leak-check +//! [monomophized]: https://en.wikipedia.org/wiki/Monomorphization +//! [overloading]: #overloading-functions +//! [function overloading]: https://en.wikipedia.org/wiki/Function_overloading +//! [variadic functions]: https://en.wikipedia.org/wiki/Variadic_function pub use args::{ArgError, ArgList, ArgValue}; pub use dynamic_function::*; @@ -143,6 +173,7 @@ pub use return_type::*; pub mod args; mod dynamic_function; +mod dynamic_function_internal; mod dynamic_function_mut; mod error; mod function; @@ -154,18 +185,19 @@ mod reflect_fn; mod reflect_fn_mut; mod registry; mod return_type; +pub mod signature; #[cfg(test)] mod tests { use alloc::borrow::Cow; + use super::*; + use crate::func::args::ArgCount; use crate::{ func::args::{ArgError, ArgList, Ownership}, TypePath, }; - use super::*; - #[test] fn should_error_on_missing_args() { fn foo(_: i32) {} @@ -176,7 +208,7 @@ mod tests { assert_eq!( result.unwrap_err(), FunctionError::ArgCountMismatch { - expected: 1, + expected: ArgCount::new(1).unwrap(), received: 0 } ); @@ -192,7 +224,7 @@ mod tests { assert_eq!( result.unwrap_err(), FunctionError::ArgCountMismatch { - expected: 0, + expected: ArgCount::new(0).unwrap(), received: 1 } ); diff --git a/crates/bevy_reflect/src/func/reflect_fn.rs b/crates/bevy_reflect/src/func/reflect_fn.rs index 486fa452aa0ef..38a18141fcf43 100644 --- a/crates/bevy_reflect/src/func/reflect_fn.rs +++ b/crates/bevy_reflect/src/func/reflect_fn.rs @@ -5,8 +5,9 @@ use alloc::{boxed::Box, format, vec}; use crate::{ func::{ - args::FromArg, macros::count_tokens, ArgList, FunctionError, FunctionResult, IntoReturn, - ReflectFnMut, + args::{ArgCount, FromArg}, + macros::count_tokens, + ArgList, FunctionError, FunctionResult, IntoReturn, ReflectFnMut, }, Reflect, TypePath, }; @@ -96,7 +97,7 @@ macro_rules! impl_reflect_fn { if args.len() != COUNT { return Err(FunctionError::ArgCountMismatch { - expected: COUNT, + expected: ArgCount::new(COUNT).unwrap(), received: args.len(), }); } @@ -125,7 +126,7 @@ macro_rules! impl_reflect_fn { if args.len() != COUNT { return Err(FunctionError::ArgCountMismatch { - expected: COUNT, + expected: ArgCount::new(COUNT).unwrap(), received: args.len(), }); } @@ -155,7 +156,7 @@ macro_rules! impl_reflect_fn { if args.len() != COUNT { return Err(FunctionError::ArgCountMismatch { - expected: COUNT, + expected: ArgCount::new(COUNT).unwrap(), received: args.len(), }); } @@ -185,7 +186,7 @@ macro_rules! impl_reflect_fn { if args.len() != COUNT { return Err(FunctionError::ArgCountMismatch { - expected: COUNT, + expected: ArgCount::new(COUNT).unwrap(), received: args.len(), }); } diff --git a/crates/bevy_reflect/src/func/reflect_fn_mut.rs b/crates/bevy_reflect/src/func/reflect_fn_mut.rs index 6a8b9a6d73264..760e657037c5b 100644 --- a/crates/bevy_reflect/src/func/reflect_fn_mut.rs +++ b/crates/bevy_reflect/src/func/reflect_fn_mut.rs @@ -5,7 +5,9 @@ use alloc::{boxed::Box, format, vec}; use crate::{ func::{ - args::FromArg, macros::count_tokens, ArgList, FunctionError, FunctionResult, IntoReturn, + args::{ArgCount, FromArg}, + macros::count_tokens, + ArgList, FunctionError, FunctionResult, IntoReturn, }, Reflect, TypePath, }; @@ -102,7 +104,7 @@ macro_rules! impl_reflect_fn_mut { if args.len() != COUNT { return Err(FunctionError::ArgCountMismatch { - expected: COUNT, + expected: ArgCount::new(COUNT).unwrap(), received: args.len(), }); } @@ -131,7 +133,7 @@ macro_rules! impl_reflect_fn_mut { if args.len() != COUNT { return Err(FunctionError::ArgCountMismatch { - expected: COUNT, + expected: ArgCount::new(COUNT).unwrap(), received: args.len(), }); } @@ -161,7 +163,7 @@ macro_rules! impl_reflect_fn_mut { if args.len() != COUNT { return Err(FunctionError::ArgCountMismatch { - expected: COUNT, + expected: ArgCount::new(COUNT).unwrap(), received: args.len(), }); } @@ -191,7 +193,7 @@ macro_rules! impl_reflect_fn_mut { if args.len() != COUNT { return Err(FunctionError::ArgCountMismatch { - expected: COUNT, + expected: ArgCount::new(COUNT).unwrap(), received: args.len(), }); } diff --git a/crates/bevy_reflect/src/func/signature.rs b/crates/bevy_reflect/src/func/signature.rs new file mode 100644 index 0000000000000..965ff401e00b8 --- /dev/null +++ b/crates/bevy_reflect/src/func/signature.rs @@ -0,0 +1,234 @@ +//! Function signature types. +//! +//! Function signatures differ from [`FunctionInfo`] and [`SignatureInfo`] in that they +//! are only concerned about the types and order of the arguments and return type of a function. +//! +//! The names of arguments do not matter, +//! nor does any other information about the function such as its name or other attributes. +//! +//! This makes signatures useful for comparing or hashing functions strictly based on their +//! arguments and return type. +//! +//! [`FunctionInfo`]: crate::func::info::FunctionInfo + +use crate::func::args::ArgInfo; +use crate::func::{ArgList, SignatureInfo}; +use crate::Type; +use bevy_utils::hashbrown::Equivalent; +use core::borrow::Borrow; +use core::fmt::{Debug, Formatter}; +use core::hash::{Hash, Hasher}; +use core::ops::{Deref, DerefMut}; + +/// The signature of a function. +/// +/// This can be used as a way to compare or hash functions based on their arguments and return type. +#[derive(Clone, PartialEq, Eq, Hash)] +pub struct Signature { + args: ArgumentSignature, + ret: Type, +} + +impl Signature { + /// Create a new function signature with the given argument signature and return type. + pub fn new(args: ArgumentSignature, ret: Type) -> Self { + Self { args, ret } + } + + /// Get the argument signature of the function. + pub fn args(&self) -> &ArgumentSignature { + &self.args + } + + /// Get the return type of the function. + pub fn return_type(&self) -> &Type { + &self.ret + } +} + +impl Debug for Signature { + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + write!(f, "{:?} -> {:?}", self.args, self.ret) + } +} + +impl> From for Signature { + fn from(info: T) -> Self { + let info = info.borrow(); + Self::new(ArgumentSignature::from(info), *info.return_info().ty()) + } +} + +/// A wrapper around a borrowed [`ArgList`] that can be used as an +/// [equivalent] of an [`ArgumentSignature`]. +/// +/// [equivalent]: Equivalent +pub(super) struct ArgListSignature<'a, 'b>(&'a ArgList<'b>); + +impl Equivalent for ArgListSignature<'_, '_> { + fn equivalent(&self, key: &ArgumentSignature) -> bool { + self.len() == key.len() && self.iter().eq(key.iter()) + } +} + +impl<'a, 'b> ArgListSignature<'a, 'b> { + pub fn iter(&self) -> impl ExactSizeIterator { + self.0.iter().map(|arg| { + arg.value() + .get_represented_type_info() + .unwrap_or_else(|| { + panic!("no `TypeInfo` found for argument: {:?}", arg); + }) + .ty() + }) + } + + pub fn len(&self) -> usize { + self.0.len() + } +} + +impl Eq for ArgListSignature<'_, '_> {} +impl PartialEq for ArgListSignature<'_, '_> { + fn eq(&self, other: &Self) -> bool { + self.len() == other.len() && self.iter().eq(other.iter()) + } +} + +impl Hash for ArgListSignature<'_, '_> { + fn hash(&self, state: &mut H) { + self.0.iter().for_each(|arg| { + arg.value() + .get_represented_type_info() + .unwrap_or_else(|| { + panic!("no `TypeInfo` found for argument: {:?}", arg); + }) + .ty() + .hash(state); + }); + } +} + +impl<'a, 'b> From<&'a ArgList<'b>> for ArgListSignature<'a, 'b> { + fn from(args: &'a ArgList<'b>) -> Self { + Self(args) + } +} + +/// The argument-portion of a function signature. +/// +/// For example, given a function signature `(a: i32, b: f32) -> u32`, +/// the argument signature would be `(i32, f32)`. +/// +/// This can be used as a way to compare or hash functions based on their arguments. +#[derive(Clone)] +pub struct ArgumentSignature(Box<[Type]>); + +impl Debug for ArgumentSignature { + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + let mut tuple = f.debug_tuple(""); + for ty in self.0.iter() { + tuple.field(ty); + } + tuple.finish() + } +} + +impl Deref for ArgumentSignature { + type Target = [Type]; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for ArgumentSignature { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl Eq for ArgumentSignature {} + +impl PartialEq for ArgumentSignature { + fn eq(&self, other: &Self) -> bool { + self.0.len() == other.0.len() && self.0.iter().eq(other.0.iter()) + } +} + +impl Hash for ArgumentSignature { + fn hash(&self, state: &mut H) { + self.0.iter().for_each(|ty| ty.hash(state)); + } +} + +impl FromIterator for ArgumentSignature { + fn from_iter>(iter: T) -> Self { + Self(iter.into_iter().collect()) + } +} + +impl> From for ArgumentSignature { + fn from(info: T) -> Self { + Self( + info.borrow() + .args() + .iter() + .map(ArgInfo::ty) + .copied() + .collect(), + ) + } +} + +impl From<&ArgList<'_>> for ArgumentSignature { + fn from(args: &ArgList) -> Self { + Self( + args.iter() + .map(|arg| { + arg.value() + .get_represented_type_info() + .unwrap_or_else(|| { + panic!("no `TypeInfo` found for argument: {:?}", arg); + }) + .ty() + }) + .copied() + .collect(), + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::func::TypedFunction; + + #[test] + fn should_generate_signature_from_function_info() { + fn add(a: i32, b: f32) -> u32 { + (a as f32 + b).round() as u32 + } + + let info = add.get_function_info(); + let signature = Signature::from(info.base()); + + assert_eq!(signature.args().0.len(), 2); + assert_eq!(signature.args().0[0], Type::of::()); + assert_eq!(signature.args().0[1], Type::of::()); + assert_eq!(*signature.return_type(), Type::of::()); + } + + #[test] + fn should_debug_signature() { + let signature = Signature::new( + ArgumentSignature::from_iter(vec![Type::of::<&mut String>(), Type::of::()]), + Type::of::<()>(), + ); + + assert_eq!( + format!("{:?}", signature), + "(&mut alloc::string::String, i32) -> ()" + ); + } +} diff --git a/examples/reflection/function_reflection.rs b/examples/reflection/function_reflection.rs index c37db7941a0d5..1a09cfa2b0d00 100644 --- a/examples/reflection/function_reflection.rs +++ b/examples/reflection/function_reflection.rs @@ -8,8 +8,8 @@ use bevy::reflect::{ func::{ - ArgList, DynamicFunction, DynamicFunctionMut, FunctionInfo, FunctionResult, IntoFunction, - IntoFunctionMut, Return, + ArgList, DynamicFunction, DynamicFunctionMut, FunctionResult, IntoFunction, + IntoFunctionMut, Return, SignatureInfo, }, PartialReflect, Reflect, }; @@ -83,7 +83,50 @@ fn main() { dbg!(closure.call_once(args).unwrap()); assert_eq!(count, 5); - // As stated before, this works for many kinds of simple functions. + // Generic functions can also be converted into a `DynamicFunction`, + // however, they will need to be manually monomorphized first. + fn stringify(value: T) -> String { + value.to_string() + } + + // We have to manually specify the concrete generic type we want to use. + let function = stringify::.into_function(); + + let args = ArgList::new().push_owned(123_i32); + let return_value = function.call(args).unwrap(); + let value: Box = return_value.unwrap_owned(); + assert_eq!(value.try_take::().unwrap(), "123"); + + // To make things a little easier, we can also "overload" functions. + // This makes it so that a single `DynamicFunction` can represent multiple functions, + // and the correct one is chosen based on the types of the arguments. + // Each function overload must have a unique argument signature. + let function = stringify:: + .into_function() + .with_overload(stringify::); + + // Now our `function` accepts both `i32` and `f32` arguments. + let args = ArgList::new().push_owned(1.23_f32); + let return_value = function.call(args).unwrap(); + let value: Box = return_value.unwrap_owned(); + assert_eq!(value.try_take::().unwrap(), "1.23"); + + // Function overloading even allows us to have a variable number of arguments. + let function = (|| 0) + .into_function() + .with_overload(|a: i32| a) + .with_overload(|a: i32, b: i32| a + b) + .with_overload(|a: i32, b: i32, c: i32| a + b + c); + + let args = ArgList::new() + .push_owned(1_i32) + .push_owned(2_i32) + .push_owned(3_i32); + let return_value = function.call(args).unwrap(); + let value: Box = return_value.unwrap_owned(); + assert_eq!(value.try_take::().unwrap(), 6); + + // As stated earlier, `IntoFunction` works for many kinds of simple functions. // Functions with non-reflectable arguments or return values may not be able to be converted. // Generic functions are also not supported (unless manually monomorphized like `foo::.into_function()`). // Additionally, the lifetime of the return value is tied to the lifetime of the first argument. @@ -118,7 +161,7 @@ fn main() { let value: &dyn PartialReflect = return_value.unwrap_ref(); assert_eq!(value.try_downcast_ref::().unwrap(), "Hello, world!"); - // Lastly, for more complex use cases, you can always create a custom `DynamicFunction` manually. + // For more complex use cases, you can always create a custom `DynamicFunction` manually. // This is useful for functions that can't be converted via the `IntoFunction` trait. // For example, this function doesn't implement `IntoFunction` due to the fact that // the lifetime of the return value is not tied to the lifetime of the first argument. @@ -150,7 +193,7 @@ fn main() { // This makes it easier to debug and is also required for function registration. // We can either give it a custom name or use the function's type name as // derived from `std::any::type_name_of_val`. - FunctionInfo::named(std::any::type_name_of_val(&get_or_insert)) + SignatureInfo::named(std::any::type_name_of_val(&get_or_insert)) // We can always change the name if needed. // It's a good idea to also ensure that the name is unique, // such as by using its type name or by prefixing it with your crate name.