Quality of life improvements #47

Merged 8 commits on Nov 26, 2024
setup.py (2 changes: 1 addition & 1 deletion)
@@ -4,7 +4,7 @@

FAIR = "oxonfair"

version = "0.2.1.8"
version = "0.2.1.9"

PYTHON_REQUIRES = ">=3.8"

src/oxonfair/learners/fair.py (16 changes: 14 additions & 2 deletions)
@@ -8,7 +8,7 @@
from sklearn.preprocessing import OneHotEncoder
from ..utils import group_metrics
from .. utils.scipy_metrics_cont_wrapper import ScorerRequiresContPred
-from ..utils.group_metric_classes import BaseGroupMetric
+from ..utils.group_metric_classes import BaseGroupMetric, Overall

from ..utils import performance as perf
from . import efficient_compute, fair_frontier
@@ -720,6 +720,9 @@ def evaluate_fairness(self, data=None, groups=None, factor=None, *,

collect = pd.concat([collect, new_pd], axis='columns')
collect.columns = ['original', 'updated']
+else:
+    collect = pd.concat([collect,], axis='columns')
+    collect.columns = ['original']

return collect
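
With this branch, `evaluate_fairness` returns a one-column table labelled `original` when no thresholds have been fitted yet, and only gains the `updated` column after `fit`. A minimal usage sketch of that before/after shape; the model and data names are placeholders, and the import paths follow the package's README rather than anything in this PR:

```python
from oxonfair import FairPredictor
from oxonfair.utils import group_metrics as gm

# `model` and `val_data` stand in for a trained classifier and held-out data.
fpred = FairPredictor(model, val_data, "sex")

before = fpred.evaluate_fairness(verbose=False)
assert list(before.columns) == ["original"]            # only the unmitigated scores

fpred.fit(gm.accuracy, gm.recall.diff, 0.01)
after = fpred.evaluate_fairness(verbose=False)
assert list(after.columns) == ["original", "updated"]  # mitigated scores appear after fit()
```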

@@ -822,7 +825,9 @@ def evaluate_groups(self, data=None, groups=None, metrics=None, fact=None, *,
verbose=verbose)

out = updated
-if return_original:
+if self.frontier is None:
+    out = pd.concat([updated, ], keys=['original', ])
+elif return_original:
out = pd.concat([original, updated], keys=['original', 'updated'])
return out
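
The companion change to `evaluate_groups` keys the per-group table with a top-level `original` index even before `fit` is called, which is why the `test_subset` assertions further down switch to tuple indexing. A hedged sketch of how a caller would read the result, reusing the hypothetical `fpred` from the previous example (the group label with its leading space and the "Maximum difference" row name are taken from the tests; everything else is illustrative):

```python
per_group = fpred.evaluate_groups()

# Rows are a MultiIndex of (stage, group), so individual entries are selected with tuples.
white_metrics = per_group.loc[('original', ' White')]
max_diff = per_group.loc[('original', 'Maximum difference')]
```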

@@ -1093,6 +1098,9 @@ def fix_groups(metric, groups):

groups = groups_to_masks(groups)

+if isinstance(metric, Overall): # Performance hack. If metric is of type overall, groups don't matter -- assign all groups to 1.
+    groups = np.ones(groups.shape[0])

def new_metric(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
return metric(y_true, y_pred, groups)
return new_metric
@@ -1146,6 +1154,10 @@ def fix_groups_and_conditioning(metric, groups, conditioning_factor, y_true):
weights = metric.cond_weights(conditioning_factor, groups, y_true)
groups = groups_to_masks(groups)

+if isinstance(metric, Overall): # Performance hack. If metric is of type overall, groups don't matter -- assign all groups to 1.
+    groups = np.ones(groups.shape[0])


def new_metric(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
return metric(y_true, y_pred, groups, weights)
return new_metric
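
The two `Overall` branches above implement the same shortcut: a population-level metric does not depend on how points are split into groups, so every point can be assigned to a single group and the per-group bookkeeping collapses. A plain-numpy illustration of why that substitution is exact; this is not library code, just the underlying identity:

```python
import numpy as np

def overall_accuracy(y_true, y_pred, groups):
    """Accepts `groups` only for interface compatibility; the value is never used."""
    return float(np.mean(y_true == y_pred))

y_true = np.array([1, 0, 1, 1, 0])
y_pred = np.array([1, 0, 0, 1, 1])

real_groups = np.array([0, 0, 1, 1, 1])
collapsed = np.ones(y_true.shape[0])  # what fix_groups now passes for Overall metrics

# A metric that ignores grouping gives the same answer for any grouping.
assert overall_accuracy(y_true, y_pred, real_groups) == overall_accuracy(y_true, y_pred, collapsed)
```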
tests/test_check_style.py (10 changes: 9 additions & 1 deletion)
@@ -55,9 +55,17 @@ def test_check_style_examples():

def test_md_links():
missing_links = lc.check_links('./', ext='.md', recurse=True, use_async=False)
+missing_links_eg = lc.check_links('./examples/', ext='.md', recurse=True)

for link in missing_links:
warnings.warn(link)
-assert missing_links == []

+for link in missing_links_eg:
+    warnings.warn(link)
+
+assert missing_links_eg == []
+assert missing_links == [('README.md', 'https://papers.ssrn.com/sol3/papers.cfm?abstract_id=4331652', 429),] or missing_links == []
+# SSRN thinks we're crawling and blocks exactly one paper.


def test_run_notebooks_without_errors():
tests/unittests/test_ag.py (17 changes: 9 additions & 8 deletions)
@@ -91,19 +91,20 @@ def test_recall_diff(use_fast=True):

fpredictor = fair.FairPredictor(predictor, test_data, "sex", use_fast=use_fast)

-fpredictor.fit(gm.accuracy, gm.recall.diff, 0.025)
+limit = 0.01
+fpredictor.fit(gm.accuracy, gm.recall.diff, limit)

# Evaluate the change in fairness (recall difference corresponds to EO)
measures = fpredictor.evaluate_fairness(verbose=False)

assert measures["updated"]["recall.diff"] < 0.025
assert measures["updated"]["recall.diff"] < limit
measures = fpredictor.evaluate()
acc = measures["updated"]["Accuracy"]
-fpredictor.fit(gm.accuracy, gm.recall.diff, 0.025, greater_is_better_const=True)
+fpredictor.fit(gm.accuracy, gm.recall.diff, limit, greater_is_better_const=True)
measures = fpredictor.evaluate_fairness(verbose=False)
assert measures["original"]["recall.diff"] > 0.025
assert measures["original"]["recall.diff"] > limit

-fpredictor.fit(gm.accuracy, gm.recall.diff, 0.01, greater_is_better_obj=False)
+fpredictor.fit(gm.accuracy, gm.recall.diff, limit/2, greater_is_better_obj=False)
assert acc >= fpredictor.evaluate()["updated"]["Accuracy"]
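
For readers unfamiliar with the metric being constrained in this test: `gm.recall.diff` is the gap in recall (true-positive rate) between groups, and bounding it is the usual equal-opportunity condition the comment above alludes to. A plain-numpy sketch of that quantity, as an illustration rather than the library's implementation:

```python
import numpy as np

def recall(y_true, y_pred):
    """Fraction of actual positives that are predicted positive."""
    positives = y_true == 1
    return float(np.mean(y_pred[positives] == 1))

def recall_diff(y_true, y_pred, groups):
    """Largest gap in recall between any two groups."""
    per_group = [recall(y_true[groups == g], y_pred[groups == g]) for g in np.unique(groups)]
    return max(per_group) - min(per_group)

# fit(gm.accuracy, gm.recall.diff, limit) asks the updated predictor to keep this gap below `limit`.
```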


@@ -117,11 +118,11 @@ def test_subset(use_fast=True):

# Check that metrics computed over a subset of the data is consistent with metrics over all data
for group in (" White", " Black", " Amer-Indian-Eskimo"):
-assert all(full_group_metrics.loc[group] == partial_group_metrics.loc[group])
+assert all(full_group_metrics.loc[('original', group)] == partial_group_metrics.loc[('original', group)])

assert all(
full_group_metrics.loc["Maximum difference"]
>= partial_group_metrics.loc["Maximum difference"]
full_group_metrics.loc[('original', "Maximum difference")]
>= partial_group_metrics.loc[('original',"Maximum difference")]
)


tests/unittests/test_scipy.py (6 changes: 4 additions & 2 deletions)
@@ -102,9 +102,11 @@ def test_conflict_groups():
def test_fit_creates_updated(use_fast=True):
"""eval should return 'updated' iff fit has been called"""
fpredictor = FairPredictor(predictor, val_dict, use_fast=use_fast)
-assert isinstance(fpredictor.evaluate(), pd.Series)
+assert not isinstance(fpredictor.evaluate(), pd.Series)
+assert 'original' in fpredictor.evaluate().columns
fpredictor.fit(gm.accuracy, gm.recall, 0) # constraint is intentionally slack
assert not isinstance(fpredictor.evaluate(), pd.Series)
+assert 'original' in fpredictor.evaluate().columns
assert 'updated' in fpredictor.evaluate().columns


@@ -460,4 +462,4 @@ def test_selection_rate_diff_levelling_up_slow():


def test_selection_rate_diff_levelling_up_hybrid():
-test_selection_rate_diff_levelling_up(use_fast='hybrid')
\ No newline at end of file
+test_selection_rate_diff_levelling_up(use_fast='hybrid')