From a6986edcacf09d8db67feaceb843d1eda974df11 Mon Sep 17 00:00:00 2001 From: Brian Healy <42810347+bfhealy@users.noreply.github.com> Date: Fri, 12 Apr 2024 10:06:44 -0500 Subject: [PATCH] Add separate plot_model arg to training (#582) --- scope/scope_class.py | 40 ++++++++-------------------------------- 1 file changed, 8 insertions(+), 32 deletions(-) diff --git a/scope/scope_class.py b/scope/scope_class.py index b9ac4fa6..61de9c39 100755 --- a/scope/scope_class.py +++ b/scope/scope_class.py @@ -674,6 +674,11 @@ def parse_run_train(self): action="store_true", help="if set, generate/save diagnostic training plots", ) + parser.add_argument( + "--plot-model", + action="store_true", + help="if set, plot model architecture", + ) parser.add_argument( "--weights-only", action="store_true", @@ -688,37 +693,6 @@ def parse_run_train(self): args, _ = parser.parse_known_args() self.train(**vars(args)) - # args to add for ds.make (override config-specified values) - # threshold - # balance - # weight_per_class (test this to make sure it works as intended) - # scale_features - # test_size - # val_size - # random_state - # feature_stats - # batch_size - # shuffle_buffer_size - # epochs - # float_convert_types - - # Args to add with descriptions (or references to tf docs) - # lr - # beta_1 - # beta_2 - # epsilon - # decay - # amsgrad - # momentum - # monitor - # patience - # callbacks - # run_eagerly - # pre_trained_model - # save - # plot - # weights_only - def train( self, tag: str, @@ -756,6 +730,7 @@ def train( pre_trained_model: str = None, save: bool = False, plot: bool = False, + plot_model: bool = False, weights_only: bool = False, skip_cv: bool = False, **kwargs, @@ -797,6 +772,7 @@ def train( :param pre_trained_model: name of dnn pre-trained model to load, if any (str) :param save: if set, save trained model (bool) :param plot: if set, generate/save diagnostic training plots (bool) + :param plot_model: if set, plot model architecture (bool) :param weights_only: if set and pre-trained model specified, load only weights (bool) :param skip_cv: if set, skip XGB cross-validation (bool) @@ -1121,7 +1097,7 @@ def train( amsgrad=amsgrad, ) - if plot: + if plot_model: tf.keras.utils.plot_model( classifier.model, to_file=self.base_path / "DNN.pdf",