diff --git a/fiddy/success.py b/fiddy/success.py index 1c6fdda..3746d4f 100644 --- a/fiddy/success.py +++ b/fiddy/success.py @@ -1,6 +1,6 @@ import abc from collections.abc import Callable -from typing import Any, Union +from typing import Any import numpy as np @@ -97,25 +97,26 @@ def method( equal_nan=self.equal_nan, ).all() - consistent_results = [ - np.nanmean(list(results_by_size[size].values()), axis=0) - for size, success in success_by_size.items() - if success - ] - - success = False - value = np.nanmean(np.array(consistent_results), axis=0) - if consistent_results: - success = ( - np.isclose( - consistent_results, - value, - rtol=self.rtol, - atol=self.atol, - equal_nan=self.equal_nan, - ).all() - and not np.isnan(consistent_results).all() - ) - value = np.average(np.array(consistent_results), axis=0) + consistent_results = np.array( + [ + np.nanmean(list(results_by_size[size].values()), axis=0) + for size, success in success_by_size.items() + if success + ] + ) + + if len(consistent_results) == 0: + return False, np.nan + value = np.nanmean(consistent_results, axis=0) + success = ( + np.isclose( + consistent_results, + value, + rtol=self.rtol, + atol=self.atol, + equal_nan=self.equal_nan, + ).all() + and not np.isnan(consistent_results).all() + ) return success, value