Skip to content

Commit

Permalink
speed up slow pathway
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisMRuss committed Nov 23, 2024
1 parent 8418272 commit e611a40
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion src/oxonfair/learners/fair.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from sklearn.preprocessing import OneHotEncoder
from ..utils import group_metrics
from .. utils.scipy_metrics_cont_wrapper import ScorerRequiresContPred
from ..utils.group_metric_classes import BaseGroupMetric
from ..utils.group_metric_classes import BaseGroupMetric, Overall

from ..utils import performance as perf
from . import efficient_compute, fair_frontier
Expand Down Expand Up @@ -1093,6 +1093,9 @@ def fix_groups(metric, groups):

groups = groups_to_masks(groups)

if isinstance(metric, Overall): # Performance hack. If metric is of type overall, groups don't matter -- assign all groups to 1.
groups = np.ones(groups.shape[0])

def new_metric(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
return metric(y_true, y_pred, groups)
return new_metric
Expand Down Expand Up @@ -1146,6 +1149,10 @@ def fix_groups_and_conditioning(metric, groups, conditioning_factor, y_true):
weights = metric.cond_weights(conditioning_factor, groups, y_true)
groups = groups_to_masks(groups)

if isinstance(metric, Overall): # Performance hack. If metric is of type overall, groups don't matter -- assign all groups to 1.
groups = np.ones(groups.shape[0])


def new_metric(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
return metric(y_true, y_pred, groups, weights)
return new_metric
Expand Down

0 comments on commit e611a40

Please sign in to comment.