Skip to content

Commit

Permalink
Added torch DNN trainer and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ipcamit committed Jul 5, 2024
1 parent f919d89 commit 0025474
Show file tree
Hide file tree
Showing 8 changed files with 356 additions and 71 deletions.
8 changes: 3 additions & 5 deletions kliff/trainer/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,6 @@ def __init__(self, training_manifest: dict, model=None):
"learning_rate": None,
"kwargs": None,
"epochs": 10000,
"stop_condition": None,
"num_workers": None,
"batch_size": 1,
}
Expand Down Expand Up @@ -281,9 +280,6 @@ def parse_manifest(self, manifest: dict):

self.optimizer_manifest |= self.training_manifest.get("optimizer")
self.optimizer_manifest["epochs"] = self.training_manifest.get("epochs", 10000)
self.optimizer_manifest["stop_condition"] = self.training_manifest.get(
"stop_condition", None
)
self.optimizer_manifest["num_workers"] = self.training_manifest.get(
"num_workers", None
)
Expand Down Expand Up @@ -728,6 +724,7 @@ def write_training_env_edn(self, path: str):
pip freeze.
"""
env_file = f"{path}/training_env.edn"
hash = self.get_trainer_hash()
with open(env_file, "w") as f:
try:
from pip._internal.operations.freeze import freeze
Expand All @@ -746,7 +743,8 @@ def write_training_env_edn(self, path: str):

f.write("{\n")
f.write(f'"kliff-version" "{__version__}"\n')
f.write(f'"manifest-hash" "{self.current["run_hash"]}"\n')
f.write(f'"trainer-used" "{type(self).__name__}"\n')
f.write(f'"manifest-hash" "{hash}"\n')
f.write(f'"python-dependencies" [\n')
for module in python_env:
f.write(f' "{module}"\n')
Expand Down
3 changes: 3 additions & 0 deletions kliff/trainer/lightning_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ def __init__(self, manifest, model=None):
self.callbacks = self._get_callbacks()

# setup lightning trainer
        self.setup_model()  # call setup_model explicitly as it converts the torch model to a lightning model
self.pl_trainer = self._get_pl_trainer()

def setup_model(self):
Expand All @@ -293,6 +294,7 @@ def setup_model(self):
ema = True if self.optimizer_manifest.get("ema", False) else False
if ema:
ema_decay = self.optimizer_manifest.get("ema_decay", 0.99)
logger.info(f"Using Exponential Moving Average with decay rate {ema_decay}")
else:
ema_decay = None

Expand All @@ -312,6 +314,7 @@ def setup_model(self):
lr_scheduler=scheduler.get("name", None),
lr_scheduler_args=scheduler.get("args", None),
)
logger.info("Lightning Model setup complete.")

def train(self):
"""
Expand Down
Loading

0 comments on commit 0025474

Please sign in to comment.