diff --git a/src/oxonfair/learners/fair.py b/src/oxonfair/learners/fair.py index 16ad4a9..342435e 100644 --- a/src/oxonfair/learners/fair.py +++ b/src/oxonfair/learners/fair.py @@ -8,7 +8,7 @@ from sklearn.preprocessing import OneHotEncoder from ..utils import group_metrics from .. utils.scipy_metrics_cont_wrapper import ScorerRequiresContPred -from ..utils.group_metric_classes import BaseGroupMetric +from ..utils.group_metric_classes import BaseGroupMetric, Overall from ..utils import performance as perf from . import efficient_compute, fair_frontier @@ -1093,6 +1093,9 @@ def fix_groups(metric, groups): groups = groups_to_masks(groups) + if isinstance(metric, Overall): # Performance hack. If metric is of type overall, groups don't matter -- assign all groups to 1. + groups = np.ones(groups.shape[0]) + def new_metric(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray: return metric(y_true, y_pred, groups) return new_metric @@ -1146,6 +1149,10 @@ def fix_groups_and_conditioning(metric, groups, conditioning_factor, y_true): weights = metric.cond_weights(conditioning_factor, groups, y_true) groups = groups_to_masks(groups) + if isinstance(metric, Overall): # Performance hack. If metric is of type overall, groups don't matter -- assign all groups to 1. + groups = np.ones(groups.shape[0]) + + def new_metric(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray: return metric(y_true, y_pred, groups, weights) return new_metric