Merge branch 'aserepo_develop' into dev_wandb
Zekun Lou committed Jun 21, 2024
2 parents 912334d + 3e6eb77 commit 9c259c2
Showing 9 changed files with 270 additions and 100 deletions.
99 changes: 54 additions & 45 deletions mace/cli/preprocess_data.py
@@ -8,6 +8,7 @@
import multiprocessing as mp
import os
import random
from functools import partial
from glob import glob
from typing import List, Tuple

@@ -92,6 +93,27 @@ def get_prime_factors(n: int):
return factors


# Define tasks for multiprocessing
def multi_train_hdf5(process, args, split_train, drop_last):
with h5py.File(args.h5_prefix + "train/train_" + str(process) + ".h5", "w") as f:
f.attrs["drop_last"] = drop_last
save_configurations_as_HDF5(split_train[process], process, f)


def multi_valid_hdf5(process, args, split_valid, drop_last):
with h5py.File(args.h5_prefix + "val/val_" + str(process) + ".h5", "w") as f:
f.attrs["drop_last"] = drop_last
save_configurations_as_HDF5(split_valid[process], process, f)


def multi_test_hdf5(process, name, args, split_test, drop_last):
with h5py.File(
args.h5_prefix + "test/" + name + "_" + str(process) + ".h5", "w"
) as f:
f.attrs["drop_last"] = drop_last
save_configurations_as_HDF5(split_test[process], process, f)


def main() -> None:
"""
This script loads an xyz dataset and prepares
@@ -172,47 +194,42 @@ def run(args: argparse.Namespace):
if len(collections.train) % 2 == 1:
drop_last = True

# Define tasks for multiprocessing
def multi_train_hdf5(process):
with h5py.File(args.h5_prefix + "train/train_" + str(process)+".h5", "w") as f:
f.attrs["drop_last"] = drop_last
save_configurations_as_HDF5(split_train[process], process, f)

multi_train_hdf5_ = partial(multi_train_hdf5, args=args, split_train=split_train, drop_last=drop_last)
processes = []
for i in range(args.num_process):
p = mp.Process(target=multi_train_hdf5, args=[i])
p = mp.Process(target=multi_train_hdf5_, args=[i])
p.start()
processes.append(p)

for i in processes:
i.join()


logging.info("Computing statistics")
if len(atomic_energies_dict) == 0:
atomic_energies_dict = get_atomic_energies(args.E0s, collections.train, z_table)
atomic_energies: np.ndarray = np.array(
[atomic_energies_dict[z] for z in z_table.zs]
)
logging.info(f"Atomic energies: {atomic_energies.tolist()}")
_inputs = [args.h5_prefix+'train', z_table, args.r_max, atomic_energies, args.batch_size, args.num_process]
avg_num_neighbors, mean, std=pool_compute_stats(_inputs)
logging.info(f"Average number of neighbors: {avg_num_neighbors}")
logging.info(f"Mean: {mean}")
logging.info(f"Standard deviation: {std}")

# save the statistics as a json
statistics = {
"atomic_energies": str(atomic_energies_dict),
"avg_num_neighbors": avg_num_neighbors,
"mean": mean,
"std": std,
"atomic_numbers": str(z_table.zs),
"r_max": args.r_max,
}

with open(args.h5_prefix + "statistics.json", "w") as f: # pylint: disable=W1514
json.dump(statistics, f)
if args.compute_statistics:
logging.info("Computing statistics")
if len(atomic_energies_dict) == 0:
atomic_energies_dict = get_atomic_energies(args.E0s, collections.train, z_table)
atomic_energies: np.ndarray = np.array(
[atomic_energies_dict[z] for z in z_table.zs]
)
logging.info(f"Atomic energies: {atomic_energies.tolist()}")
_inputs = [args.h5_prefix+'train', z_table, args.r_max, atomic_energies, args.batch_size, args.num_process]
avg_num_neighbors, mean, std=pool_compute_stats(_inputs)
logging.info(f"Average number of neighbors: {avg_num_neighbors}")
logging.info(f"Mean: {mean}")
logging.info(f"Standard deviation: {std}")

# save the statistics as a json
statistics = {
"atomic_energies": str(atomic_energies_dict),
"avg_num_neighbors": avg_num_neighbors,
"mean": mean,
"std": std,
"atomic_numbers": str(z_table.zs),
"r_max": args.r_max,
}

with open(args.h5_prefix + "statistics.json", "w") as f: # pylint: disable=W1514
json.dump(statistics, f)

logging.info("Preparing validation set")
if args.shuffle:
@@ -222,36 +239,28 @@ def multi_train_hdf5(process):
if len(collections.valid) % 2 == 1:
drop_last = True

def multi_valid_hdf5(process):
with h5py.File(args.h5_prefix + "val/val_" + str(process)+".h5", "w") as f:
f.attrs["drop_last"] = drop_last
save_configurations_as_HDF5(split_valid[process], process, f)

multi_valid_hdf5_ = partial(multi_valid_hdf5, args=args, split_valid=split_valid, drop_last=drop_last)
processes = []
for i in range(args.num_process):
p = mp.Process(target=multi_valid_hdf5, args=[i])
p = mp.Process(target=multi_valid_hdf5_, args=[i])
p.start()
processes.append(p)

for i in processes:
i.join()

if args.test_file is not None:
def multi_test_hdf5(process, name):
with h5py.File(args.h5_prefix + "test/" + name + "_" + str(process) + ".h5", "w") as f:
f.attrs["drop_last"] = drop_last
save_configurations_as_HDF5(split_test[process], process, f)

logging.info("Preparing test sets")
for name, subset in collections.tests:
drop_last = False
if len(subset) % 2 == 1:
drop_last = True
split_test = np.array_split(subset, args.num_process)
multi_test_hdf5_ = partial(multi_test_hdf5, args=args, split_test=split_test, drop_last=drop_last)

processes = []
for i in range(args.num_process):
p = mp.Process(target=multi_test_hdf5, args=[i, name])
p = mp.Process(target=multi_test_hdf5_, args=[i, name])
p.start()
processes.append(p)

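The worker functions above are defined at module level and bound to their run-specific arguments with functools.partial; a common motivation for this layout is that multiprocessing start methods such as "spawn" can only pickle module-level callables, not nested closures. A minimal, self-contained sketch of the same pattern, with hypothetical file names and payloads:

import multiprocessing as mp
from functools import partial


def write_shard(process, prefix, shards):
    # Each worker writes one shard; only the shard index is passed at launch time.
    with open(f"{prefix}_{process}.txt", "w", encoding="utf-8") as f:
        f.write("\n".join(shards[process]))


if __name__ == "__main__":
    shards = [["a", "b"], ["c", "d"]]
    # Bind the shared arguments once, leaving only the shard index free.
    worker = partial(write_shard, prefix="train", shards=shards)
    processes = [mp.Process(target=worker, args=(i,)) for i in range(len(shards))]
    for p in processes:
        p.start()
    for p in processes:
        p.join()
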
20 changes: 16 additions & 4 deletions mace/cli/run_train.py
@@ -105,12 +105,12 @@ def run(args: argparse.Namespace) -> None:
logging.info(
f"Using foundation model mace-off-2023 {model_type} as initial checkpoint. ASL license."
)
model_foundation = mace_off(
calc = mace_off(
model=model_type,
device=args.device,
default_dtype=args.default_dtype,
return_raw_model=True,
)
model_foundation = calc.models[0]
else:
model_foundation = torch.load(args.foundation_model, map_location=device)
logging.info(
@@ -354,7 +354,7 @@ def run(args: argparse.Namespace) -> None:

# Selecting outputs
compute_virials = False
if args.loss in ("stress", "virials", "huber"):
if args.loss in ("stress", "virials", "huber", "universal"):
compute_virials = True
args.compute_stress = True
args.error_table = "PerAtomRMSEstressvirials"
@@ -558,15 +558,27 @@ def run(args: argparse.Namespace) -> None:
],
lr=args.lr,
amsgrad=args.amsgrad,
betas=(args.beta, 0.999),
)

optimizer: torch.optim.Optimizer
if args.optimizer == "adamw":
optimizer = torch.optim.AdamW(**param_options)
elif args.optimizer == "schedulefree":
try:
from schedulefree import adamw_schedulefree
except ImportError as exc:
raise ImportError(
"`schedulefree` is not installed. Please install it via `pip install schedulefree` or `pip install mace-torch[schedulefree]`"
) from exc
_param_options = {k: v for k, v in param_options.items() if k != "amsgrad"}
optimizer = adamw_schedulefree.AdamWScheduleFree(**_param_options)
else:
optimizer = torch.optim.Adam(**param_options)

logger = tools.MetricsLogger(directory=args.results_dir, tag=tag + "_train")
logger = tools.MetricsLogger(
directory=args.results_dir, tag=tag + "_train"
) # pylint: disable=E1123

lr_scheduler = LRScheduler(optimizer, args)

3 changes: 1 addition & 2 deletions mace/modules/loss.py
@@ -31,11 +31,10 @@ def weighted_mean_squared_stress(ref: Batch, pred: TensorDict) -> torch.Tensor:
# energy: [n_graphs, ]
configs_weight = ref.weight.view(-1, 1, 1) # [n_graphs, ]
configs_stress_weight = ref.stress_weight.view(-1, 1, 1) # [n_graphs, ]
num_atoms = (ref.ptr[1:] - ref.ptr[:-1]).view(-1, 1, 1) # [n_graphs,]
return torch.mean(
configs_weight
* configs_stress_weight
* torch.square((ref["stress"] - pred["stress"]) / num_atoms)
* torch.square(ref["stress"] - pred["stress"])
) # []


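For reference, the stress term as written above is a per-configuration weighted mean-squared error, averaged over graphs g and the 3×3 stress components (i, j):

\mathcal{L}_{\text{stress}} = \frac{1}{9\,N_{\text{graphs}}} \sum_{g=1}^{N_{\text{graphs}}} \sum_{i,j=1}^{3} w_{g}\, w^{\text{stress}}_{g} \left( \sigma^{\text{ref}}_{g,ij} - \sigma^{\text{pred}}_{g,ij} \right)^{2}

where w_g and w_g^stress are the configuration and stress weights read from the batch.
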
8 changes: 7 additions & 1 deletion mace/tools/arg_parser.py
@@ -446,7 +446,13 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
help="Optimizer for parameter optimization",
type=str,
default="adam",
choices=["adam", "adamw"],
choices=["adam", "adamw", "schedulefree"],
)
parser.add_argument(
"--beta",
help="Beta parameter for the optimizer",
type=float,
default=0.9,
)
parser.add_argument("--batch_size", help="batch size", type=int, default=10)
parser.add_argument(
21 changes: 17 additions & 4 deletions mace/tools/scripts_utils.py
@@ -167,10 +167,18 @@ def radial_to_transform(radial):
"num_interactions": model.num_interactions.item(),
"num_elements": len(model.atomic_numbers),
"hidden_irreps": o3.Irreps(str(model.products[0].linear.irreps_out)),
"MLP_irreps": o3.Irreps(str(model.readouts[-1].hidden_irreps)),
"gate": model.readouts[-1] # pylint: disable=protected-access
.non_linearity._modules["acts"][0]
.f,
"MLP_irreps": (
o3.Irreps(str(model.readouts[-1].hidden_irreps))
if model.num_interactions.item() > 1
else 1
),
"gate": (
model.readouts[-1] # pylint: disable=protected-access
.non_linearity._modules["acts"][0]
.f
if model.num_interactions.item() > 1
else None
),
"atomic_energies": model.atomic_energies_fn.atomic_energies.cpu().numpy(),
"avg_num_neighbors": model.interactions[0].avg_num_neighbors,
"atomic_numbers": model.atomic_numbers,
@@ -373,6 +381,9 @@ def custom_key(key):
class LRScheduler:
def __init__(self, optimizer, args) -> None:
self.scheduler = args.scheduler
self._optimizer_type = (
args.optimizer
) # ScheduleFree does not need a scheduler, but the checkpoint handler does.
if args.scheduler == "ExponentialLR":
self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
optimizer=optimizer, gamma=args.lr_scheduler_gamma
@@ -387,6 +398,8 @@ def __init__(self, optimizer, args) -> None:
raise RuntimeError(f"Unknown scheduler: '{args.scheduler}'")

def step(self, metrics=None, epoch=None): # pylint: disable=E1123
if self._optimizer_type == "schedulefree":
return # In principle, a schedule-free optimizer can be used with a scheduler, but the paper suggests it is not necessary
if self.scheduler == "ExponentialLR":
self.lr_scheduler.step(epoch=epoch)
elif self.scheduler == "ReduceLROnPlateau":
5 changes: 4 additions & 1 deletion mace/tools/train.py
@@ -175,7 +175,8 @@ def train(
# Train
if distributed:
train_sampler.set_epoch(epoch)

if "ScheduleFree" in type(optimizer).__name__:
optimizer.train()
train_one_epoch(
model=model,
loss_fn=loss_fn,
@@ -201,6 +202,8 @@
param_context = (
ema.average_parameters() if ema is not None else nullcontext()
)
if "ScheduleFree" in type(optimizer).__name__:
optimizer.eval()
with param_context:
valid_loss, eval_metrics = evaluate(
model=model_to_evaluate,
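Schedule-free optimizers maintain separate training and averaged parameter iterates, so they must be switched into train mode before gradient steps and into eval mode before validation or checkpointing, which is what the name-based check above does. A minimal sketch of that toggling, assuming the optional schedulefree package is installed; the linear model and random data are placeholders:

import torch
from schedulefree import AdamWScheduleFree

model = torch.nn.Linear(4, 1)
optimizer = AdamWScheduleFree(model.parameters(), lr=1e-3)

for epoch in range(2):
    if "ScheduleFree" in type(optimizer).__name__:
        optimizer.train()  # use the training iterates for gradient steps
    optimizer.zero_grad()
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()

    if "ScheduleFree" in type(optimizer).__name__:
        optimizer.eval()  # switch to the averaged iterates for evaluation
    with torch.no_grad():
        val_loss = model(torch.randn(8, 4)).pow(2).mean()
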
3 changes: 2 additions & 1 deletion setup.cfg
@@ -16,7 +16,7 @@ python_requires = >=3.7
install_requires =
torch>=1.12
e3nn==0.4.4
numpy
numpy<2.0
opt_einsum
ase
torch-ema
@@ -50,3 +50,4 @@ dev =
pre-commit
pytest
pylint
schedulefree = schedulefree
