Skip to content

Commit

Permalink
fix model_config lit
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyes319 committed Aug 29, 2024
1 parent 85c5d9a commit 6288b02
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 11 deletions.
12 changes: 3 additions & 9 deletions mace/cli/run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,10 +350,7 @@ def run(args: argparse.Namespace) -> None:
z_table_foundation.z_to_index(z)
].item()
for z in z_table.zs
}
logging.info(
f"Using Atomic Energies from foundation model [z, eV]: {', '.join([f'{z}: {atomic_energies_dict[z]}' for z in z_table_foundation.zs])}"
)
}
else:
atomic_energies_dict[head_config.head_name] = get_atomic_energies(head_config.E0s, None, head_config.z_table)
else:
Expand All @@ -373,11 +370,7 @@ def run(args: argparse.Namespace) -> None:
].item()
for z in z_table.zs
}
atomic_energies_dict_pt = atomic_energies_dict["pt_head"]
logging.info(
f"Using Atomic Energies from foundation model [z, eV]: {', '.join([f'{z}: {atomic_energies_dict_pt[z]}' for z in z_table_foundation.zs])}"
)


if args.model == "AtomicDipolesMACE":
atomic_energies = None
dipole_only = True
Expand Down Expand Up @@ -560,6 +553,7 @@ def run(args: argparse.Namespace) -> None:
args.avg_num_neighbors = model_config_foundation["avg_num_neighbors"]
args.model = "FoundationMACE"
model_config_foundation["heads"] = heads
model_config = model_config_foundation
logging.info("Model configuration extracted from foundation model")
logging.info("Using universal loss function for fine-tuning")
logging.info(
Expand Down
3 changes: 1 addition & 2 deletions mace/tools/scripts_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@
from prettytable import PrettyTable
from torch.optim.swa_utils import SWALR, AveragedModel

from mace import data, modules
from mace import tools
from mace import data, modules, tools
from mace.tools import evaluate
from mace.tools.train import SWAContainer

Expand Down

0 comments on commit 6288b02

Please sign in to comment.