diff --git a/src/postprocess.py b/src/postprocess.py index 05a1c0f..ce2712d 100644 --- a/src/postprocess.py +++ b/src/postprocess.py @@ -44,6 +44,14 @@ parser.add_argument( "--val", default=None, help="List of ids to use for validation." ) + parser.add_argument( + "mitigated_model_path", + type=str, + help="Path to the saved and mitigated model.", + ) + parser.add_argument( + "--test", default=None, help="List of ids to use for validation." + ) parser.add_argument( "--attribute", default="gender", @@ -111,8 +119,6 @@ model.fit(x_train, y_train, sensitive_features=metadata) - plot_threshold_optimizer(model) - # Save unbiased model to disk log_dir = os.path.dirname(args.model_path) @@ -121,7 +127,9 @@ ) as f: pickle.dump(model, f) - print("Evaluating on validation data...") + plot_threshold_optimizer(model) + + print("Evaluating on validation/test data...") validation_set = MIMIC4Dataset( data_path, "val", @@ -160,5 +168,6 @@ print(cf) elif model_type == "fusion": + # TODO: Implement technique for post-training debiasing of neural network # Use adversarial learning to debias model - pass + raise NotImplementedError