Skip to content

Commit

Permalink
Warn the user if layers do not load correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
ZJaume committed Apr 16, 2024
1 parent 7a214da commit deec8a4
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 8 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Compress noise generation intermediate files.
- Multilingual models training documentation.
- Show a warning when no GPU/TPU has been detected.
- Show a warning if some layers do not load correctly.

### Changed:
- Huge improvements in accuracy multilingual full models.
Expand Down
27 changes: 19 additions & 8 deletions src/bicleaner_ai/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,15 +474,26 @@ def get_generator(self, batch_size, shuffle):
batch_size=batch_size,
maxlen=self.settings["maxlen"])

def load_model(self, model_file):
def load_model(self, model_file, train=False):
settings = self.settings

tf_model = TFXLMRBicleanerAI.from_pretrained(
model_file,
num_labels=settings["n_classes"],
head_hidden_size=settings["n_hidden"],
head_dropout=settings["dropout"],
head_activation=settings["activation"])
tf_model, loading_info = TFXLMRBicleanerAI.from_pretrained(
model_file,
num_labels=settings["n_classes"],
head_hidden_size=settings["n_hidden"],
head_dropout=settings["dropout"],
head_activation=settings["activation"],
output_loading_info=True)

logging.debug(loading_info)

# Warn if layers do not load correctly (might be a bug).
# Only check missing keys at inference time; when training they are expected.
if loading_info["unexpected_keys"] or loading_info["mismatched_keys"] \
or (loading_info["missing_keys"] and not train):
logging.warning("Some layers were not initialized when loading model file. "
"Please check that the model has been downloaded correctly "
"or report this error if persists after checking.")

return tf_model

Expand Down Expand Up @@ -558,7 +569,7 @@ def train(self, train_set, dev_set):
strategy = tf.distribute.MirroredStrategy()
num_devices = strategy.num_replicas_in_sync
with strategy.scope():
self.model = self.load_model(self.settings["base_model"])
self.model = self.load_model(self.settings["base_model"], train=True)
self.model.compile(optimizer=self.settings["optimizer"],
loss=SparseCategoricalCrossentropy(
from_logits=True),
Expand Down

0 comments on commit deec8a4

Please sign in to comment.