Skip to content

Commit

Permalink
Merge pull request #488 from RocketPy-Team/enh/function-reverse-arith
Browse files Browse the repository at this point in the history
ENH: Function Reverse Arithmetic Priority
  • Loading branch information
phmbressan authored Nov 28, 2023
2 parents 214e4c2 + beb22ab commit 191ab9f
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 7 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ straightforward as possible.

### Changed

-
- ENH: Function Reverse Arithmetic Priority [#488](https://github.com/RocketPy-Team/RocketPy/pull/488)

### Fixed

Expand Down
27 changes: 21 additions & 6 deletions rocketpy/mathutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ class Function:
extrapolation, plotting and algebra.
"""

# Arithmetic priority
__array_ufunc__ = None

def __init__(
self,
source,
Expand Down Expand Up @@ -1837,7 +1840,9 @@ def __add__(self, other):
return Function(lambda x: (self.get_value(x) + other(x)))
# If other is Float except...
except AttributeError:
if isinstance(other, (float, int, complex)):
if isinstance(
other, (float, int, complex, np.ndarray, np.integer, np.floating)
):
# Check if Function object source is array or callable
if isinstance(self.source, np.ndarray):
# Operate on grid values
Expand Down Expand Up @@ -1967,7 +1972,9 @@ def __mul__(self, other):
return Function(lambda x: (self.get_value(x) * other(x)))
# If other is Float except...
except AttributeError:
if isinstance(other, (float, int, complex)):
if isinstance(
other, (float, int, complex, np.ndarray, np.integer, np.floating)
):
# Check if Function object source is array or callable
if isinstance(self.source, np.ndarray):
# Operate on grid values
Expand Down Expand Up @@ -2056,7 +2063,9 @@ def __truediv__(self, other):
return Function(lambda x: (self.get_value_opt(x) / other(x)))
# If other is Float except...
except AttributeError:
if isinstance(other, (float, int, complex)):
if isinstance(
other, (float, int, complex, np.ndarray, np.integer, np.floating)
):
# Check if Function object source is array or callable
if isinstance(self.source, np.ndarray):
# Operate on grid values
Expand Down Expand Up @@ -2095,7 +2104,9 @@ def __rtruediv__(self, other):
A Function object which gives the result of other(x)/self(x).
"""
# Check if Function object source is array and other is float
if isinstance(other, (float, int, complex)):
if isinstance(
other, (float, int, complex, np.ndarray, np.integer, np.floating)
):
if isinstance(self.source, np.ndarray):
# Operate on grid values
ys = other / self.y_array
Expand Down Expand Up @@ -2163,7 +2174,9 @@ def __pow__(self, other):
return Function(lambda x: (self.get_value_opt(x) ** other(x)))
# If other is Float except...
except AttributeError:
if isinstance(other, (float, int, complex)):
if isinstance(
other, (float, int, complex, np.ndarray, np.integer, np.floating)
):
# Check if Function object source is array or callable
if isinstance(self.source, np.ndarray):
# Operate on grid values
Expand Down Expand Up @@ -2202,7 +2215,9 @@ def __rpow__(self, other):
A Function object which gives the result of other(x)**self(x).
"""
# Check if Function object source is array and other is float
if isinstance(other, (float, int, complex)):
if isinstance(
other, (float, int, complex, np.ndarray, np.integer, np.floating)
):
if isinstance(self.source, np.ndarray):
# Operate on grid values
ys = other**self.y_array
Expand Down
80 changes: 80 additions & 0 deletions tests/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,3 +390,83 @@ def test_shepard_interpolation(x, y, z_expected):
func = Function(source=source, inputs=["x", "y"], outputs=["z"])
z = func(x, y)
assert np.isclose(z, z_expected, atol=1e-8).all()


@pytest.mark.parametrize("other", [1, 0.1, np.int_(1), np.float_(0.1), np.array([1])])
def test_sum_arithmetic_priority(other):
"""Test the arithmetic priority of the add operation of the Function class,
specially comparing to the numpy array operations.
"""
func_lambda = Function(lambda x: x**2)
func_array = Function([(0, 0), (1, 1), (2, 4)])

assert isinstance(func_lambda + func_array, Function)
assert isinstance(func_array + func_lambda, Function)
assert isinstance(func_lambda + other, Function)
assert isinstance(other + func_lambda, Function)
assert isinstance(func_array + other, Function)
assert isinstance(other + func_array, Function)


@pytest.mark.parametrize("other", [1, 0.1, np.int_(1), np.float_(0.1), np.array([1])])
def test_sub_arithmetic_priority(other):
"""Test the arithmetic priority of the sub operation of the Function class,
specially comparing to the numpy array operations.
"""
func_lambda = Function(lambda x: x**2)
func_array = Function([(0, 0), (1, 1), (2, 4)])

assert isinstance(func_lambda - func_array, Function)
assert isinstance(func_array - func_lambda, Function)
assert isinstance(func_lambda - other, Function)
assert isinstance(other - func_lambda, Function)
assert isinstance(func_array - other, Function)
assert isinstance(other - func_array, Function)


@pytest.mark.parametrize("other", [1, 0.1, np.int_(1), np.float_(0.1), np.array([1])])
def test_mul_arithmetic_priority(other):
"""Test the arithmetic priority of the mul operation of the Function class,
specially comparing to the numpy array operations.
"""
func_lambda = Function(lambda x: x**2)
func_array = Function([(0, 0), (1, 1), (2, 4)])

assert isinstance(func_lambda * func_array, Function)
assert isinstance(func_array * func_lambda, Function)
assert isinstance(func_lambda * other, Function)
assert isinstance(other * func_lambda, Function)
assert isinstance(func_array * other, Function)
assert isinstance(other * func_array, Function)


@pytest.mark.parametrize("other", [1, 0.1, np.int_(1), np.float_(0.1), np.array([1])])
def test_truediv_arithmetic_priority(other):
"""Test the arithmetic priority of the truediv operation of the Function class,
specially comparing to the numpy array operations.
"""
func_lambda = Function(lambda x: x**2)
func_array = Function([(1, 1), (2, 4)])

assert isinstance(func_lambda / func_array, Function)
assert isinstance(func_array / func_lambda, Function)
assert isinstance(func_lambda / other, Function)
assert isinstance(other / func_lambda, Function)
assert isinstance(func_array / other, Function)
assert isinstance(other / func_array, Function)


@pytest.mark.parametrize("other", [1, 0.1, np.int_(1), np.float_(0.1), np.array([1])])
def test_pow_arithmetic_priority(other):
"""Test the arithmetic priority of the pow operation of the Function class,
specially comparing to the numpy array operations.
"""
func_lambda = Function(lambda x: x**2)
func_array = Function([(0, 0), (1, 1), (2, 4)])

assert isinstance(func_lambda**func_array, Function)
assert isinstance(func_array**func_lambda, Function)
assert isinstance(func_lambda**other, Function)
assert isinstance(other**func_lambda, Function)
assert isinstance(func_array**other, Function)
assert isinstance(other**func_array, Function)

0 comments on commit 191ab9f

Please sign in to comment.