diff --git a/mist/main.py b/mist/main.py index 1e32e03..47acc05 100755 --- a/mist/main.py +++ b/mist/main.py @@ -97,6 +97,9 @@ def main(args): if __name__ == "__main__": set_warning_levels() args = get_main_args() + if not args.overwrite: + assert not os.path.exists(os.path.join(args.results, "results.csv")), \ + "Results folder already contains a previous run. Enable --overwrite to overwrite the previous run" if args.loss in ["bl", "hdl", "gsl"]: args.use_dtm = True diff --git a/mist/runtime/args.py b/mist/runtime/args.py index e28b1fd..63321e2 100755 --- a/mist/runtime/args.py +++ b/mist/runtime/args.py @@ -60,6 +60,7 @@ def get_main_args(): p.arg("--master-port", type=str, default="12355", help="Master port for multi-gpu training") p.arg("--seed_val", type=non_negative_int, default=42, help="Random seed") p.boolean_flag("--tta", default=False, help="Enable test time augmentation") + p.boolean_flag("--overwrite", default=False, help="Overwrites previous run at specified results folder") # Output p.arg("--results", type=str, help="Path to output of MIST pipeline")