From 1df6b0322e910a7bddffb1425598bbacc08d62e2 Mon Sep 17 00:00:00 2001 From: Evan Lim Date: Tue, 16 Jul 2024 14:26:37 -0500 Subject: [PATCH 1/3] Added validation graph option and check for overwriting previous runs --- mist/main.py | 3 +++ mist/runtime/args.py | 1 + mist/runtime/run.py | 27 +++++++++++++++++++++++++++ 3 files changed, 31 insertions(+) diff --git a/mist/main.py b/mist/main.py index 6ab9fbe..e581130 100755 --- a/mist/main.py +++ b/mist/main.py @@ -91,6 +91,9 @@ def main(args): if __name__ == "__main__": set_warning_levels() args = get_main_args() + + assert not os.path.exists(os.path.join(args.results, "results.csv")), \ + "Results folder already contains a previous run" if args.loss in ["bl", "hdl", "gsl"]: args.use_dtm = True diff --git a/mist/runtime/args.py b/mist/runtime/args.py index 1be9379..9d7692d 100755 --- a/mist/runtime/args.py +++ b/mist/runtime/args.py @@ -166,6 +166,7 @@ def get_main_args(): p.arg("--steps-per-epoch", type=positive_int, help="Steps per epoch. By default ceil(training_dataset_size / (batch_size * gpus)") + p.arg("--val-graph", type=bool, default=False, help="Output convergence graph for validation loss") # Evaluation p.arg("--metrics", diff --git a/mist/runtime/run.py b/mist/runtime/run.py index d1f85c7..88501a4 100755 --- a/mist/runtime/run.py +++ b/mist/runtime/run.py @@ -4,6 +4,7 @@ import pandas as pd import numpy as np from sklearn.model_selection import train_test_split +import matplotlib.pyplot as plt # Rich progress bar from rich.console import Console @@ -42,6 +43,7 @@ create_pretrained_config_file, get_progress_bar, AlphaSchedule, + create_empty_dir ) console = Console() @@ -361,6 +363,9 @@ def val_step(image, label): return self.val_loss(label, pred) + # Tabulate loss data for convergence plot + all_loss_data = [] + for epoch in range(self.args.epochs): # Make sure gradient tracking is on, and do a pass over the data model.train(True) @@ -438,6 +443,10 @@ def val_step(image, label): else: text = Text(f"Validation loss did NOT improve from {best_loss:.4}\n") console.print(text) + + # Collect validation loss data + if(self.args.val_graph): + all_loss_data.append(best_loss) else: for i in range(val_steps): # Get data from validation loader @@ -461,6 +470,24 @@ def val_step(image, label): # Reset running losses for new epoch running_loss_train.reset_states() running_loss_validation.reset_states() + + # Writes loss data to file and graph + if(self.args.val_graph): + + val_data_dir = os.path.join(self.args.results, "convergence_data") + create_empty_dir(val_data_dir) + val_data_loc = os.path.join(val_data_dir, "fold_{}.txt".format(fold)) + conv_map = open(val_data_loc, "w") + conv_map.write(str(all_loss_data)) + conv_map.close() + + # Output graph from validation loss data + plt.plot(list(range(len(all_loss_data))), all_loss_data, label="fold_{}".format(fold)) + plt.title("Validation Loss by Fold".format(fold)) + plt.xlabel("Epoch") + plt.ylabel("Loss Validation") + plt.legend() + plt.savefig(os.path.join(val_data_dir, "loss_val_graph.png".format(fold))) dist.barrier() if rank == 0: From acb00357b9d368a1390c6049d8cf31a5733c3939 Mon Sep 17 00:00:00 2001 From: Evan Lim Date: Wed, 17 Jul 2024 14:33:41 -0500 Subject: [PATCH 2/3] Added --overwrite flag for overwriting previous runs --- mist/main.py | 6 +++--- mist/runtime/args.py | 2 +- mist/runtime/run.py | 27 --------------------------- 3 files changed, 4 insertions(+), 31 deletions(-) diff --git a/mist/main.py b/mist/main.py index e581130..7a08df1 100755 --- a/mist/main.py +++ b/mist/main.py @@ -91,9 +91,9 @@ def main(args): if __name__ == "__main__": set_warning_levels() args = get_main_args() - - assert not os.path.exists(os.path.join(args.results, "results.csv")), \ - "Results folder already contains a previous run" + if not args.overwrite: + assert not os.path.exists(os.path.join(args.results, "results.csv")), \ + "Results folder already contains a previous run" if args.loss in ["bl", "hdl", "gsl"]: args.use_dtm = True diff --git a/mist/runtime/args.py b/mist/runtime/args.py index 9d7692d..8c140b9 100755 --- a/mist/runtime/args.py +++ b/mist/runtime/args.py @@ -60,6 +60,7 @@ def get_main_args(): p.arg("--master-port", type=str, default="12355", help="Master port for multi-gpu training") p.arg("--seed_val", type=non_negative_int, default=42, help="Random seed") p.boolean_flag("--tta", default=False, help="Enable test time augmentation") + p.boolean_flag("--overwrite", default=False, help="Overwrites previous run at specified results folder") # Output p.arg("--results", type=str, help="Path to output of MIST pipeline") @@ -166,7 +167,6 @@ def get_main_args(): p.arg("--steps-per-epoch", type=positive_int, help="Steps per epoch. By default ceil(training_dataset_size / (batch_size * gpus)") - p.arg("--val-graph", type=bool, default=False, help="Output convergence graph for validation loss") # Evaluation p.arg("--metrics", diff --git a/mist/runtime/run.py b/mist/runtime/run.py index 88501a4..d1f85c7 100755 --- a/mist/runtime/run.py +++ b/mist/runtime/run.py @@ -4,7 +4,6 @@ import pandas as pd import numpy as np from sklearn.model_selection import train_test_split -import matplotlib.pyplot as plt # Rich progress bar from rich.console import Console @@ -43,7 +42,6 @@ create_pretrained_config_file, get_progress_bar, AlphaSchedule, - create_empty_dir ) console = Console() @@ -363,9 +361,6 @@ def val_step(image, label): return self.val_loss(label, pred) - # Tabulate loss data for convergence plot - all_loss_data = [] - for epoch in range(self.args.epochs): # Make sure gradient tracking is on, and do a pass over the data model.train(True) @@ -443,10 +438,6 @@ def val_step(image, label): else: text = Text(f"Validation loss did NOT improve from {best_loss:.4}\n") console.print(text) - - # Collect validation loss data - if(self.args.val_graph): - all_loss_data.append(best_loss) else: for i in range(val_steps): # Get data from validation loader @@ -470,24 +461,6 @@ def val_step(image, label): # Reset running losses for new epoch running_loss_train.reset_states() running_loss_validation.reset_states() - - # Writes loss data to file and graph - if(self.args.val_graph): - - val_data_dir = os.path.join(self.args.results, "convergence_data") - create_empty_dir(val_data_dir) - val_data_loc = os.path.join(val_data_dir, "fold_{}.txt".format(fold)) - conv_map = open(val_data_loc, "w") - conv_map.write(str(all_loss_data)) - conv_map.close() - - # Output graph from validation loss data - plt.plot(list(range(len(all_loss_data))), all_loss_data, label="fold_{}".format(fold)) - plt.title("Validation Loss by Fold".format(fold)) - plt.xlabel("Epoch") - plt.ylabel("Loss Validation") - plt.legend() - plt.savefig(os.path.join(val_data_dir, "loss_val_graph.png".format(fold))) dist.barrier() if rank == 0: From a317ebc3227e6f3c2d796b6651802ded8e645cca Mon Sep 17 00:00:00 2001 From: Evan Lim Date: Wed, 17 Jul 2024 14:38:33 -0500 Subject: [PATCH 3/3] Small change to no overwrite message --- mist/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mist/main.py b/mist/main.py index 7a08df1..d3966dc 100755 --- a/mist/main.py +++ b/mist/main.py @@ -93,7 +93,7 @@ def main(args): args = get_main_args() if not args.overwrite: assert not os.path.exists(os.path.join(args.results, "results.csv")), \ - "Results folder already contains a previous run" + "Results folder already contains a previous run. Enable --overwrite to overwrite the previous run" if args.loss in ["bl", "hdl", "gsl"]: args.use_dtm = True