
Commit f328bbd

Adding parameters to optimizer automatically
LarsSchaaf committed Jul 19, 2024
1 parent 1700cae commit f328bbd
Showing 1 changed file with 37 additions and 0 deletions.
mace/cli/run_train.py
@@ -561,6 +561,8 @@ def run(args: argparse.Namespace) -> None:
        betas=(args.beta, 0.999),
    )

    param_options = add_all_parameters_to_optimizer(param_options, model.named_parameters())

    optimizer: torch.optim.Optimizer
    if args.optimizer == "adamw":
        optimizer = torch.optim.AdamW(**param_options)
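
At this point param_options already holds the explicitly configured parameter groups plus top-level optimizer defaults such as betas; the newly added call sweeps up any model parameters missing from those groups before the optimizer is constructed (see the sketch after the diff).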
@@ -859,5 +861,40 @@ def run(args: argparse.Namespace) -> None:
        torch.distributed.destroy_process_group()


def add_all_parameters_to_optimizer(optimizer_params, all_named_params) -> dict:
    """Add parameters from all_named_params that are missing from optimizer_params,
    using default settings.

    Args:
        optimizer_params: dict of optimizer settings, containing a "params" list
            of parameter groups
        all_named_params: dict or iterable of named parameters
            (e.g. model.named_parameters())

    Returns:
        The updated optimizer_params dict.
    """
    if not isinstance(all_named_params, dict):
        all_named_params = dict(all_named_params)

    all_params = set(all_named_params.values())
    explicit_params = set()
    for group in optimizer_params["params"]:
        explicit_params.update(group["params"])
    implicit_params = all_params - explicit_params
    if implicit_params:
        implicit_param_names = [
            name for name, param in all_named_params.items() if param in implicit_params
        ]
        logging.warning(
            f"Adding {len(implicit_params)} parameters to the optimizer that were not explicitly configured:"
        )
        for name in implicit_param_names:
            logging.warning(f"  Adding: {name}")
        optimizer_params["params"].append(
            {
                "name": "default",
                "params": list(implicit_params),
                "weight_decay": 0.0,  # default weight decay
            }
        )

    return optimizer_params


if __name__ == "__main__":
    main()
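
Below is a minimal, self-contained sketch (not part of the commit) of how the helper behaves, assuming add_all_parameters_to_optimizer is in scope as defined above. The two-layer model, the group name "first_layer", and all hyperparameter values are hypothetical, chosen only to illustrate the behavior.

import logging

import torch

logging.basicConfig(level=logging.WARNING)

# Hypothetical two-layer model; names and sizes are illustrative only.
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 1))

# Only the first layer is configured explicitly, mirroring how run_train.py
# builds param_options from named parameter groups.
param_options = dict(
    params=[
        {
            "name": "first_layer",
            "params": list(model[0].parameters()),
            "weight_decay": 5e-7,
        }
    ],
    lr=0.01,
    amsgrad=True,
)

# The helper detects that the second layer was never added, logs a warning per
# missing parameter, and appends them all to a "default" group with zero
# weight decay.
param_options = add_all_parameters_to_optimizer(
    param_options, model.named_parameters()
)

# AdamW now covers every parameter of the model.
optimizer = torch.optim.AdamW(**param_options)
assert len(optimizer.param_groups) == 2

Giving the swept-up group zero weight decay is a conservative default: parameters the user never configured explicitly are still optimized, just not regularized.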
