Skip to content

Commit

Permalink
feat: reintroduce the invalid_value default value in bisect&compute_c…
Browse files Browse the repository at this point in the history
…onductor_ampacity
  • Loading branch information
amundfr committed Jun 10, 2024
1 parent ae1eee8 commit be5d802
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions linerate/solver.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import partial
from typing import Callable
from typing import Callable, Optional

import numpy as np

Expand All @@ -13,6 +13,7 @@ def bisect(
xmin: FloatOrFloatArray,
xmax: FloatOrFloatArray,
tolerance: float,
invalid_value: Optional[float] = None,
) -> FloatOrFloatArray:
r"""Compute the roots of a function using a vectorized bisection method.
Expand All @@ -31,6 +32,9 @@ def bisect(
bounded within an interval of size :math:`\Delta x` or less. The bisection method will
run for :math:`\left\lceil\frac{x_\max - x_\min}{\Delta x}\right\rceil`
iterations.
invalid_value:
If provided, then the this value is used whenever
:math:`\text{sign}(f(\mathbf{x}_\min)) = \text{sign}(f(\mathbf{x}_\max))`.
Returns
-------
Expand All @@ -47,10 +51,12 @@ def bisect(
f_right = f(xmax)

invalid_mask = np.sign(f_left) == np.sign(f_right)
if np.any(invalid_mask):
if np.any(invalid_mask) and invalid_value is None:
raise ValueError(
"f(xmin) and f(xmax) have the same sign. Consider increasing the search interval."
)
elif isinstance(invalid_mask, bool) and invalid_mask:
return invalid_value # type: ignore

while interval > tolerance:
xmid = 0.5 * (xmax + xmin)
Expand All @@ -63,7 +69,7 @@ def bisect(
f_left = np.where(mask, f_mid, f_left)
f_right = np.where(mask, f_right, f_mid)

out = 0.5 * (xmax + xmin)
out = np.where(invalid_mask, invalid_value, 0.5 * (xmax + xmin)) # type: ignore
return out


Expand Down Expand Up @@ -142,4 +148,4 @@ def compute_conductor_ampacity(
"""
f = partial(heat_balance, max_conductor_temperature)

return bisect(f, min_ampacity, max_ampacity, tolerance)
return bisect(f, min_ampacity, max_ampacity, tolerance, invalid_value=0)

0 comments on commit be5d802

Please sign in to comment.