From 38d211a862783c5709f6baec1b4d740ae715c113 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Amund=20Faller=20R=C3=A5heim?= Date: Tue, 11 Jun 2024 11:42:08 +0200 Subject: [PATCH] feat: reintroduce invalid_value in bisect --- linerate/solver.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/linerate/solver.py b/linerate/solver.py index 1581c35..aaf1b3a 100644 --- a/linerate/solver.py +++ b/linerate/solver.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Callable +from typing import Callable, Optional import numpy as np @@ -13,6 +13,7 @@ def bisect( xmin: FloatOrFloatArray, xmax: FloatOrFloatArray, tolerance: float, + invalid_value: Optional[float] = None, ) -> FloatOrFloatArray: r"""Compute the roots of a function using a vectorized bisection method. @@ -31,6 +32,9 @@ def bisect( bounded within an interval of size :math:`\Delta x` or less. The bisection method will run for :math:`\left\lceil\frac{x_\max - x_\min}{\Delta x}\right\rceil` iterations. + invalid_value: + If provided, then the this value is used whenever + :math:`\text{sign}(f(\mathbf{x}_\min)) = \text{sign}(f(\mathbf{x}_\max))`. Returns ------- @@ -47,10 +51,12 @@ def bisect( f_right = f(xmax) invalid_mask = np.sign(f_left) == np.sign(f_right) - if np.any(invalid_mask): + if np.any(invalid_mask) and invalid_value is None: raise ValueError( "f(xmin) and f(xmax) have the same sign. Consider increasing the search interval." ) + elif isinstance(invalid_mask, bool) and invalid_mask: + return invalid_value # type: ignore while interval > tolerance: xmid = 0.5 * (xmax + xmin) @@ -63,7 +69,7 @@ def bisect( f_left = np.where(mask, f_mid, f_left) f_right = np.where(mask, f_right, f_mid) - out = 0.5 * (xmax + xmin) + out = np.where(invalid_mask, invalid_value, 0.5 * (xmax + xmin)) # type: ignore return out @@ -142,4 +148,4 @@ def compute_conductor_ampacity( """ f = partial(heat_balance, max_conductor_temperature) - return bisect(f, min_ampacity, max_ampacity, tolerance) + return bisect(f, min_ampacity, max_ampacity, tolerance, invalid_value=0)