diff --git a/src/darsia/measure/wasserstein.py b/src/darsia/measure/wasserstein.py index 87635795..79be2c82 100644 --- a/src/darsia/measure/wasserstein.py +++ b/src/darsia/measure/wasserstein.py @@ -1387,6 +1387,7 @@ def __call__( # Return solution return_info = self.options.get("return_info", False) + return_status = self.options.get("return_status", False) if return_info: info.update( { @@ -1403,6 +1404,8 @@ def __call__( } ) return distance, info + elif return_status: + return distance, info["converged"] else: return distance @@ -1693,7 +1696,7 @@ def _solve(self, flat_mass_diff: np.ndarray) -> tuple[float, np.ndarray, dict]: # Define performance metric info = { - "converged": iter < num_iter, + "converged": iter < num_iter - 1, "number_iterations": iter, "convergence_history": convergence_history, "timings": total_timings, @@ -1952,6 +1955,15 @@ def _solve(self, flat_mass_diff: np.ndarray) -> tuple[float, np.ndarray, dict]: # Update distance new_distance = self.l1_dissipation(flux) + # Catch nan values + if np.isnan(new_distance): + info = { + "converged": False, + "number_iterations": iter, + "convergence_history": convergence_history, + } + return new_distance, solution_i, info + # Determine the error in the mass conservation equation mass_conservation_residual = ( self.div.dot(flux) - rhs[self.pressure_slice] @@ -1993,22 +2005,26 @@ def _solve(self, flat_mass_diff: np.ndarray) -> tuple[float, np.ndarray, dict]: # Print status if self.verbose: - distance_increment = ( - convergence_history["distance_increment"][-1] / new_distance - ) - aux_force_increment = ( - convergence_history["aux_force_increment"][-1] - / convergence_history["aux_force_increment"][0] - ) - mass_conservation_residual = convergence_history[ - "mass_conservation_residual" - ][-1] - print( - f"Iter. {iter} \t| {new_distance:.6e} \t| " - "" - f"""{distance_increment:.6e} \t| {aux_force_increment:.6e} \t| """ - f"""{mass_conservation_residual:.6e}""" - ) + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message="overflow encountered" + ) + distance_increment = ( + convergence_history["distance_increment"][-1] / new_distance + ) + aux_force_increment = ( + convergence_history["aux_force_increment"][-1] + / convergence_history["aux_force_increment"][0] + ) + mass_conservation_residual = convergence_history[ + "mass_conservation_residual" + ][-1] + print( + f"Iter. {iter} \t| {new_distance:.6e} \t| " + "" + f"""{distance_increment:.6e} \t| {aux_force_increment:.6e} \t| """ + f"""{mass_conservation_residual:.6e}""" + ) # Base stopping citeria on the different interpretations of the split Bregman # method: @@ -2058,7 +2074,7 @@ def _solve(self, flat_mass_diff: np.ndarray) -> tuple[float, np.ndarray, dict]: # Define performance metric info = { - "converged": iter < num_iter, + "converged": iter < num_iter - 1, "number_iterations": iter, "convergence_history": convergence_history, "timings": total_timings,