fix linting
ilyes319 committed Aug 29, 2024
1 parent e277d32 commit 85c5d9a
Showing 7 changed files with 149 additions and 132 deletions.
9 changes: 5 additions & 4 deletions mace/cli/create_lammps_model.py
@@ -1,6 +1,8 @@
import argparse

import torch
from e3nn.util import jit

from mace.calculators import LAMMPS_MACE


@@ -42,12 +44,11 @@ def select_head(model):

if selected.isdigit() and 1 <= int(selected) <= len(heads):
return heads[int(selected) - 1]
elif selected == "":
if selected == "":
print("No head selected. Proceeding without specifying a head.")
return None
else:
print(f"No valid selection made. Defaulting to the last head: {heads[-1]}")
return heads[-1]
print(f"No valid selection made. Defaulting to the last head: {heads[-1]}")
return heads[-1]


def main():
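The create_lammps_model.py hunk is the standard no-else-after-return lint fix: once the first branch returns, the trailing `elif`/`else` can be flattened without changing behavior. A minimal sketch of the pattern, using a hypothetical `pick` helper rather than the real `select_head`:

# Hypothetical illustration of the no-else-return pattern applied above.
def pick(selected: str, options: list):
    if selected.isdigit() and 1 <= int(selected) <= len(options):
        return options[int(selected) - 1]
    if selected == "":      # was `elif`; the early return above makes it redundant
        return None
    return options[-1]      # was under `else`; dedented for the same reason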
125 changes: 21 additions & 104 deletions mace/cli/run_train.py
@@ -41,12 +41,16 @@
dict_to_array,
extract_config_mace_model,
get_atomic_energies,
get_avg_num_neighbors,
get_config_type_weights,
get_dataset_from_xyz,
get_files_with_suffix,
get_loss_fn,
get_optimizer,
get_params_options,
get_swa,
print_git_commit,
setup_wandb,
)
from mace.tools.slurm_distributed import DistributedEnvironment
from mace.tools.utils import AtomicNumberTable
@@ -194,6 +198,7 @@ def run(args: argparse.Namespace) -> None:
head_config.config_type_weights
)
collections, atomic_energies_dict = get_dataset_from_xyz(
work_dir=args.work_dir,
train_path=head_config.train_file,
valid_path=head_config.valid_file,
valid_fraction=head_config.valid_fraction,
@@ -219,10 +224,10 @@ def run(args: argparse.Namespace) -> None:

if all(head_config.train_file.endswith(".xyz") for head_config in head_configs):
size_collections_train = sum(
[len(head_config.collections.train) for head_config in head_configs]
len(head_config.collections.train) for head_config in head_configs
)
size_collections_valid = sum(
[len(head_config.collections.valid) for head_config in head_configs]
len(head_config.collections.valid) for head_config in head_configs
)
if size_collections_train < args.batch_size:
logging.error(
@@ -257,6 +262,7 @@ def run(args: argparse.Namespace) -> None:
else:
heads = list(dict.fromkeys(["pt_head"] + heads))
collections, atomic_energies_dict = get_dataset_from_xyz(
work_dir=args.work_dir,
train_path=args.pt_train_file,
valid_path=args.pt_valid_file,
valid_fraction=args.valid_fraction,
@@ -367,8 +373,9 @@ def run(args: argparse.Namespace) -> None:
].item()
for z in z_table.zs
}
atomic_energies_dict_pt = atomic_energies_dict["pt_head"]
logging.info(
f"Using Atomic Energies from foundation model [z, eV]: {', '.join([f'{z}: {atomic_energies_dict[z]}' for z in z_table_foundation.zs])}"
f"Using Atomic Energies from foundation model [z, eV]: {', '.join([f'{z}: {atomic_energies_dict_pt[z]}' for z in z_table_foundation.zs])}"
)

if args.model == "AtomicDipolesMACE":
@@ -394,9 +401,14 @@ def run(args: argparse.Namespace) -> None:
# [atomic_energies_dict[z] for z in z_table.zs]
# )
atomic_energies = dict_to_array(atomic_energies_dict, heads)
logging.info(
f"Atomic Energies used (z: eV): {{{', '.join([f'{z}: {atomic_energies_dict[z]}' for z in z_table.zs])}}}"
result = "\n".join(
[
f"Atomic Energies used (z: eV) for head {head_config.head_name}: " +
"{" + ", ".join([f"{z}: {atomic_energies_dict[head_config.head_name][z]}" for z in head_config.z_table.zs]) + "}"
for head_config in head_configs
]
)
logging.info(result)

valid_sets = {head: [] for head in heads}
train_sets = {head: [] for head in heads}
@@ -488,30 +500,7 @@ def run(args: argparse.Namespace) -> None:
)

loss_fn = get_loss_fn(args, dipole_only, compute_dipole)

if all(head_config.compute_avg_num_neighbors for head_config in head_configs):
logging.info("Computing average number of neighbors")
avg_num_neighbors = modules.compute_avg_num_neighbors(train_loader)
if args.distributed:
num_graphs = torch.tensor(len(train_loader.dataset)).to(device)
num_neighbors = num_graphs * torch.tensor(avg_num_neighbors).to(device)
torch.distributed.all_reduce(num_graphs, op=torch.distributed.ReduceOp.SUM)
torch.distributed.all_reduce(
num_neighbors, op=torch.distributed.ReduceOp.SUM
)
args.avg_num_neighbors = (num_neighbors / num_graphs).item()
else:
args.avg_num_neighbors = avg_num_neighbors
else:
assert any(head_config.avg_num_neighbors is not None for head_config in head_configs), "Average number of neighbors must be provided in the configuration"
args.avg_num_neighbors = max(head_config.avg_num_neighbors for head_config in head_configs if head_config.avg_num_neighbors is not None)

if args.avg_num_neighbors < 2 or args.avg_num_neighbors > 100:
logging.warning(
f"Unusual average number of neighbors: {args.avg_num_neighbors:.1f}"
)
else:
logging.info(f"Average number of neighbors: {args.avg_num_neighbors}")
args.avg_num_neighbors = get_avg_num_neighbors(head_configs, args, train_loader, device)

# Selecting outputs
compute_virials = False
@@ -745,61 +734,9 @@ def run(args: argparse.Namespace) -> None:
logging.info(loss_fn)

# Optimizer
decay_interactions = {}
no_decay_interactions = {}
for name, param in model.interactions.named_parameters():
if "linear.weight" in name or "skip_tp_full.weight" in name:
decay_interactions[name] = param
else:
no_decay_interactions[name] = param

param_options = dict(
params=[
{
"name": "embedding",
"params": model.node_embedding.parameters(),
"weight_decay": 0.0,
},
{
"name": "interactions_decay",
"params": list(decay_interactions.values()),
"weight_decay": args.weight_decay,
},
{
"name": "interactions_no_decay",
"params": list(no_decay_interactions.values()),
"weight_decay": 0.0,
},
{
"name": "products",
"params": model.products.parameters(),
"weight_decay": args.weight_decay,
},
{
"name": "readouts",
"params": model.readouts.parameters(),
"weight_decay": 0.0,
},
],
lr=args.lr,
amsgrad=args.amsgrad,
betas=(args.beta, 0.999),
)

param_options = get_params_options(args, model)
optimizer: torch.optim.Optimizer
if args.optimizer == "adamw":
optimizer = torch.optim.AdamW(**param_options)
elif args.optimizer == "schedulefree":
try:
from schedulefree import adamw_schedulefree
except ImportError as exc:
raise ImportError(
"`schedulefree` is not installed. Please install it via `pip install schedulefree` or `pip install mace-torch[schedulefree]`"
) from exc
_param_options = {k: v for k, v in param_options.items() if k != "amsgrad"}
optimizer = adamw_schedulefree.AdamWScheduleFree(**_param_options)
else:
optimizer = torch.optim.Adam(**param_options)
optimizer = get_optimizer(args, param_options)
if args.device == "xpu":
logging.info("Optimzing model and optimzier for XPU")
model, optimizer = ipex.optimize(model, optimizer=optimizer)
@@ -846,27 +783,7 @@ def run(args: argparse.Namespace) -> None:
group["lr"] = args.lr

if args.wandb:
logging.info("Using Weights and Biases for logging")
import wandb

wandb_config = {}
args_dict = vars(args)

for key, value in args_dict.items():
if isinstance(value, np.ndarray):
args_dict[key] = value.tolist()

args_dict_json = json.dumps(args_dict)
for key in args.wandb_log_hypers:
wandb_config[key] = args_dict[key]
tools.init_wandb(
project=args.wandb_project,
entity=args.wandb_entity,
name=args.wandb_name,
config=wandb_config,
directory=args.wandb_dir,
)
wandb.run.summary["params"] = args_dict_json
setup_wandb(args)

if args.distributed:
distributed_model = DDP(model, device_ids=[local_rank])
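Taken together, the run_train.py changes replace three inline blocks (neighbor averaging, optimizer construction, Weights & Biases setup) with calls to helpers that now live in mace/tools/scripts_utils.py, shown further below. A hedged sketch of the resulting wiring inside run(); it mirrors the hunks above but is illustrative rather than the literal file contents:

# Illustrative wiring only; names and arguments follow the hunks above.
args.avg_num_neighbors = get_avg_num_neighbors(head_configs, args, train_loader, device)

param_options = get_params_options(args, model)   # weight-decay parameter groups
optimizer = get_optimizer(args, param_options)    # adamw, schedulefree, or adam fallback

if args.wandb:
    setup_wandb(args)                             # init the W&B run and log hyperparameters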
2 changes: 0 additions & 2 deletions mace/modules/loss.py
@@ -273,8 +273,6 @@ def __init__(

def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor:
num_atoms = ref.ptr[1:] - ref.ptr[:-1]
configs_weight = ref.weight.view(-1, 1, 1) # [n_graphs, ]
configs_stress_weight = ref.stress_weight.view(-1, 1, 1) # [n_graphs, ]
return (
self.energy_weight
* self.huber_loss(ref["energy"] / num_atoms, pred["energy"] / num_atoms)
2 changes: 0 additions & 2 deletions mace/tools/__init__.py
@@ -28,7 +28,6 @@
compute_rel_rmse,
compute_rmse,
get_atomic_number_table_from_zs,
get_optimizer,
get_tag,
setup_logger,
)
@@ -46,7 +45,6 @@
"setup_logger",
"get_tag",
"count_parameters",
"get_optimizer",
"MetricsLogger",
"get_atomic_number_table_from_zs",
"train",
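Since get_optimizer is no longer re-exported from mace.tools, callers are expected to import it alongside the other helpers added to scripts_utils.py, as the enlarged import block at the top of the run_train.py diff suggests. A one-line sketch; the exact module path is an assumption, since the `from ... import` line itself sits outside the visible hunk:

# Assumed import path (not shown in this diff):
from mace.tools.scripts_utils import get_optimizer, get_params_options, get_avg_num_neighbors, setup_wandb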
1 change: 1 addition & 0 deletions mace/tools/multihead_tools.py
@@ -162,6 +162,7 @@ def assemble_mp_data(
}
select_samples(dict_to_namespace(args_samples))
collections_mp, _ = get_dataset_from_xyz(
work_dir=args.work_dir,
train_path=f"mp_finetuning-{tag}.xyz",
valid_path=None,
valid_fraction=args.valid_fraction,
122 changes: 122 additions & 0 deletions mace/tools/scripts_utils.py
@@ -20,6 +20,7 @@
from torch.optim.swa_utils import SWALR, AveragedModel

from mace import data, modules
from mace import tools
from mace.tools import evaluate
from mace.tools.train import SWAContainer

@@ -349,6 +350,38 @@ def get_atomic_energies(E0s, train_collection, z_table) -> dict:
return atomic_energies_dict


def get_avg_num_neighbors(head_configs, args, train_loader, device):
if all(head_config.compute_avg_num_neighbors for head_config in head_configs):
logging.info("Computing average number of neighbors")
avg_num_neighbors = modules.compute_avg_num_neighbors(train_loader)
if args.distributed:
num_graphs = torch.tensor(len(train_loader.dataset)).to(device)
num_neighbors = num_graphs * torch.tensor(avg_num_neighbors).to(device)
torch.distributed.all_reduce(num_graphs, op=torch.distributed.ReduceOp.SUM)
torch.distributed.all_reduce(
num_neighbors, op=torch.distributed.ReduceOp.SUM
)
avg_num_neighbors_out = (num_neighbors / num_graphs).item()
else:
avg_num_neighbors_out = avg_num_neighbors
else:
assert any(
head_config.avg_num_neighbors is not None for head_config in head_configs
), "Average number of neighbors must be provided in the configuration"
avg_num_neighbors_out = max(
head_config.avg_num_neighbors
for head_config in head_configs
if head_config.avg_num_neighbors is not None
)
if avg_num_neighbors_out < 2 or avg_num_neighbors_out > 100:
logging.warning(
f"Unusual average number of neighbors: {avg_num_neighbors_out:.1f}"
)
else:
logging.info(f"Average number of neighbors: {avg_num_neighbors_out}")
return avg_num_neighbors_out


def get_loss_fn(
args: argparse.Namespace,
dipole_only: bool,
@@ -482,6 +515,95 @@ def get_swa(
return swa, swas


def get_params_options(
args: argparse.Namespace, model: torch.nn.Module
) -> Dict[str, Any]:
decay_interactions = {}
no_decay_interactions = {}
for name, param in model.interactions.named_parameters():
if "linear.weight" in name or "skip_tp_full.weight" in name:
decay_interactions[name] = param
else:
no_decay_interactions[name] = param

param_options = dict(
params=[
{
"name": "embedding",
"params": model.node_embedding.parameters(),
"weight_decay": 0.0,
},
{
"name": "interactions_decay",
"params": list(decay_interactions.values()),
"weight_decay": args.weight_decay,
},
{
"name": "interactions_no_decay",
"params": list(no_decay_interactions.values()),
"weight_decay": 0.0,
},
{
"name": "products",
"params": model.products.parameters(),
"weight_decay": args.weight_decay,
},
{
"name": "readouts",
"params": model.readouts.parameters(),
"weight_decay": 0.0,
},
],
lr=args.lr,
amsgrad=args.amsgrad,
betas=(args.beta, 0.999),
)
return param_options


def get_optimizer(
args: argparse.Namespace, param_options: Dict[str, Any]
) -> torch.optim.Optimizer:
if args.optimizer == "adamw":
optimizer = torch.optim.AdamW(**param_options)
elif args.optimizer == "schedulefree":
try:
from schedulefree import adamw_schedulefree
except ImportError as exc:
raise ImportError(
"`schedulefree` is not installed. Please install it via `pip install schedulefree` or `pip install mace-torch[schedulefree]`"
) from exc
_param_options = {k: v for k, v in param_options.items() if k != "amsgrad"}
optimizer = adamw_schedulefree.AdamWScheduleFree(**_param_options)
else:
optimizer = torch.optim.Adam(**param_options)
return optimizer


def setup_wandb(args: argparse.Namespace):
logging.info("Using Weights and Biases for logging")
import wandb

wandb_config = {}
args_dict = vars(args)

for key, value in args_dict.items():
if isinstance(value, np.ndarray):
args_dict[key] = value.tolist()

args_dict_json = json.dumps(args_dict)
for key in args.wandb_log_hypers:
wandb_config[key] = args_dict[key]
tools.init_wandb(
project=args.wandb_project,
entity=args.wandb_entity,
name=args.wandb_name,
config=wandb_config,
directory=args.wandb_dir,
)
wandb.run.summary["params"] = args_dict_json


def get_files_with_suffix(dir_path: str, suffix: str) -> List[str]:
return [
os.path.join(dir_path, f) for f in os.listdir(dir_path) if f.endswith(suffix)
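In the distributed branch of get_avg_num_neighbors, each rank contributes its local graph count and the product of that count with its local average; after the two all-reduces, the ratio recovers the graph-weighted global mean. A small worked example with hypothetical numbers for two ranks:

# Hypothetical two-rank example of the reduction performed in get_avg_num_neighbors.
rank_graphs = [1000, 3000]             # graphs held by rank 0 and rank 1
rank_avg = [12.0, 16.0]                # local average number of neighbors per rank

num_graphs = sum(rank_graphs)                                        # 4000
num_neighbors = sum(g * a for g, a in zip(rank_graphs, rank_avg))    # 60000
avg_num_neighbors = num_neighbors / num_graphs                       # 15.0, the graph-weighted mean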