Skip to content

Commit

Permalink
Clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
sophmrtn committed Jul 3, 2024
1 parent 07e1bed commit 6e5a0fc
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions src/postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@
parser.add_argument(
"--val", default=None, help="List of ids to use for validation."
)
parser.add_argument(
"mitigated_model_path",
type=str,
help="Path to the saved and mitigated model.",
)
parser.add_argument(
"--test", default=None, help="List of ids to use for validation."
)
parser.add_argument(
"--attribute",
default="gender",
Expand Down Expand Up @@ -111,8 +119,6 @@

model.fit(x_train, y_train, sensitive_features=metadata)

plot_threshold_optimizer(model)

# Save unbiased model to disk
log_dir = os.path.dirname(args.model_path)

Expand All @@ -121,7 +127,9 @@
) as f:
pickle.dump(model, f)

print("Evaluating on validation data...")
plot_threshold_optimizer(model)

print("Evaluating on validation/test data...")
validation_set = MIMIC4Dataset(
data_path,
"val",
Expand Down Expand Up @@ -160,5 +168,6 @@
print(cf)

elif model_type == "fusion":
# TODO: Implement technique for post-training debiasing of neural network
# Use adversarial learning to debias model
pass
raise NotImplementedError

0 comments on commit 6e5a0fc

Please sign in to comment.