Skip to content

Commit

Permalink
save model using keras archive format by default
Browse files Browse the repository at this point in the history
  • Loading branch information
alphasentaurii committed Apr 3, 2024
1 parent 0018bf9 commit 88b2bec
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 2 deletions.
9 changes: 9 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
1.1.0 (unreleased)
==================

new features
------------

- `architect.builder.Builder.save_model` uses preferred keras archive format by default [#50]


1.0.1 (2024-04-03)
==================

Expand Down
31 changes: 29 additions & 2 deletions spacekit/builder/architect.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def load_pretrained_network(self, arch=None):
self.log.error(err)
sys.exit(1)
model_src = "spacekit.builder.trained_networks"
archive_file = f"{arch}.zip" # hst_cal.zip | jwt_cal.zip | svm_align.zip
archive_file = f"{arch}.zip" # hst_cal.zip | jwst_cal.zip | svm_align.zip
with importlib.resources.path(model_src, archive_file) as mod:
self.model_path = mod
if self.blueprint is None:
Expand Down Expand Up @@ -338,7 +338,16 @@ def set_callbacks(self, patience=15):
self.callbacks = [checkpoint_cb, early_stopping_cb]
return self.callbacks

def save_model(self, weights=True, output_path="."):
def save_keras_model(self, model_path):
dpath = os.path.dirname(model_path)
name = os.path.basename(model_path)
if not name.endswith("keras"):
name += ".keras"
keras_model_path = os.path.join(dpath, name)
self.model.save(keras_model_path)
self.model_path = keras_model_path

def save_model(self, weights=True, output_path=".", keras_archive=True):
"""The model architecture, and training configuration (including the optimizer, losses, and metrics)
are stored in saved_model.pb. The weights are saved in the variables/ directory.
Expand All @@ -348,6 +357,8 @@ def save_model(self, weights=True, output_path="."):
save weights learned by the model separately also, by default True
output_path : str, optional
where to save the model files, by default "."
keras_archive : bool, optional
save model using new (preferred) keras archive format, by default True
"""
if self.name is None:
self.name = str(self.model.name_scope().rstrip("/"))
Expand All @@ -357,6 +368,22 @@ def save_model(self, weights=True, output_path="."):
model_name = self.name

model_path = os.path.join(output_path, "models", model_name)

if keras_archive is True:
self.save_keras_model(model_path)
else:
self.model.save(model_path)
if weights is True:
weights_path = f"{model_path}/weights/ckpt"
self.model.save_weights(weights_path)
for root, _, files in os.walk(model_path):
indent = " " * root.count(os.sep)
print("{}{}/".format(indent, os.path.basename(root)))
for filename in files:
print("{}{}".format(indent + " ", filename))
self.model_path = model_path


weights_path = f"{model_path}/weights/ckpt"
self.model.save(model_path)
if weights is True:
Expand Down

0 comments on commit 88b2bec

Please sign in to comment.