diff --git a/CHANGES.rst b/CHANGES.rst index edccbd1..76a9863 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,3 +1,12 @@ +1.1.0 (unreleased) +================== + +new features +------------ + +- `architect.builder.Builder.save_model` uses preferred keras archive format by default [#50] + + 1.0.1 (2024-04-03) ================== diff --git a/spacekit/builder/architect.py b/spacekit/builder/architect.py index 7f0d03d..f8a27c8 100644 --- a/spacekit/builder/architect.py +++ b/spacekit/builder/architect.py @@ -145,7 +145,7 @@ def load_pretrained_network(self, arch=None): self.log.error(err) sys.exit(1) model_src = "spacekit.builder.trained_networks" - archive_file = f"{arch}.zip" # hst_cal.zip | jwt_cal.zip | svm_align.zip + archive_file = f"{arch}.zip" # hst_cal.zip | jwst_cal.zip | svm_align.zip with importlib.resources.path(model_src, archive_file) as mod: self.model_path = mod if self.blueprint is None: @@ -338,7 +338,16 @@ def set_callbacks(self, patience=15): self.callbacks = [checkpoint_cb, early_stopping_cb] return self.callbacks - def save_model(self, weights=True, output_path="."): + def save_keras_model(self, model_path): + dpath = os.path.dirname(model_path) + name = os.path.basename(model_path) + if not name.endswith("keras"): + name += ".keras" + keras_model_path = os.path.join(dpath, name) + self.model.save(keras_model_path) + self.model_path = keras_model_path + + def save_model(self, weights=True, output_path=".", keras_archive=True): """The model architecture, and training configuration (including the optimizer, losses, and metrics) are stored in saved_model.pb. The weights are saved in the variables/ directory. @@ -348,6 +357,8 @@ def save_model(self, weights=True, output_path="."): save weights learned by the model separately also, by default True output_path : str, optional where to save the model files, by default "." + keras_archive : bool, optional + save model using new (preferred) keras archive format, by default True """ if self.name is None: self.name = str(self.model.name_scope().rstrip("/")) @@ -357,6 +368,22 @@ def save_model(self, weights=True, output_path="."): model_name = self.name model_path = os.path.join(output_path, "models", model_name) + + if keras_archive is True: + self.save_keras_model(model_path) + else: + self.model.save(model_path) + if weights is True: + weights_path = f"{model_path}/weights/ckpt" + self.model.save_weights(weights_path) + for root, _, files in os.walk(model_path): + indent = " " * root.count(os.sep) + print("{}{}/".format(indent, os.path.basename(root))) + for filename in files: + print("{}{}".format(indent + " ", filename)) + self.model_path = model_path + + weights_path = f"{model_path}/weights/ckpt" self.model.save(model_path) if weights is True: