From deec8a4a30976e5ee1fe822f12e2646ba69a550f Mon Sep 17 00:00:00 2001 From: ZJaume Date: Tue, 16 Apr 2024 13:56:32 +0000 Subject: [PATCH] Warn the user if layers do not load correctly --- CHANGELOG.md | 1 + src/bicleaner_ai/models.py | 27 +++++++++++++++++++-------- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ae403c5..c300e99 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Compress noise generation intermediate files. - Multilingual models training documentation. - Show a warning when no GPU/TPU has been detected. +- Show a warning if some layers do not load correctly. ### Changed: - Huge improvements in accuracy multilingual full models. diff --git a/src/bicleaner_ai/models.py b/src/bicleaner_ai/models.py index 99cdeb0..4280f3e 100644 --- a/src/bicleaner_ai/models.py +++ b/src/bicleaner_ai/models.py @@ -474,15 +474,26 @@ def get_generator(self, batch_size, shuffle): batch_size=batch_size, maxlen=self.settings["maxlen"]) - def load_model(self, model_file): + def load_model(self, model_file, train=False): settings = self.settings - tf_model = TFXLMRBicleanerAI.from_pretrained( - model_file, - num_labels=settings["n_classes"], - head_hidden_size=settings["n_hidden"], - head_dropout=settings["dropout"], - head_activation=settings["activation"]) + tf_model, loading_info = TFXLMRBicleanerAI.from_pretrained( + model_file, + num_labels=settings["n_classes"], + head_hidden_size=settings["n_hidden"], + head_dropout=settings["dropout"], + head_activation=settings["activation"], + output_loading_info=True) + + logging.debug(loading_info) + + # Warn if layers do not load correctly (might be a bug) + # only check missing keys when inference, for training is expected + if loading_info["unexpected_keys"] or loading_info["mismatched_keys"] \ + or (loading_info["missing_keys"] and not train): + logging.warning("Some layers were not 
initialized when loading model file. " + "Please check that the model has been downloaded correctly " + "or report this error if persists after checking.") return tf_model @@ -558,7 +569,7 @@ def train(self, train_set, dev_set): strategy = tf.distribute.MirroredStrategy() num_devices = strategy.num_replicas_in_sync with strategy.scope(): - self.model = self.load_model(self.settings["base_model"]) + self.model = self.load_model(self.settings["base_model"], train=True) self.model.compile(optimizer=self.settings["optimizer"], loss=SparseCategoricalCrossentropy( from_logits=True),