Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix case with multihead foundation model #687

Merged
merged 2 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
File renamed without changes.
29 changes: 23 additions & 6 deletions mace/cli/run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
LRScheduler,
check_path_ase_read,
convert_to_json_format,
create_error_table,
dict_to_array,
extract_config_mace_model,
get_atomic_energies,
Expand All @@ -49,9 +48,11 @@
get_params_options,
get_swa,
print_git_commit,
remove_pt_head,
setup_wandb,
)
from mace.tools.slurm_distributed import DistributedEnvironment
from mace.tools.tables_utils import create_error_table
from mace.tools.utils import AtomicNumberTable


Expand Down Expand Up @@ -115,10 +116,6 @@ def run(args: argparse.Namespace) -> None:
commit = print_git_commit()
model_foundation: Optional[torch.nn.Module] = None
if args.foundation_model is not None:
if args.multiheads_finetuning:
assert (
args.E0s != "average"
), "average atomic energies cannot be used for multiheads finetuning"
if args.foundation_model in ["small", "medium", "large"]:
logging.info(
f"Using foundation model mace-mp-0 {args.foundation_model} as initial checkpoint."
Expand Down Expand Up @@ -148,6 +145,27 @@ def run(args: argparse.Namespace) -> None:
f"Using foundation model {args.foundation_model} as initial checkpoint."
)
args.r_max = model_foundation.r_max.item()
if (
args.foundation_model not in ["small", "medium", "large"]
and args.pt_train_file is None
):
logging.warning(
"Using multiheads finetuning with a foundation model that is not a Materials Project model, need to provied a path to a pretraining file with --pt_train_file."
)
args.multiheads_finetuning = False
if args.multiheads_finetuning:
assert (
args.E0s != "average"
), "average atomic energies cannot be used for multiheads finetuning"
# check that the foundation model has a single head, if not, use the first head
if hasattr(model_foundation, "heads"):
if len(model_foundation.heads) > 1:
logging.warning(
"Mutlihead finetuning with models with more than one head is not supported, using the first head as foundation head."
)
model_foundation = remove_pt_head(
model_foundation, args.foundation_head
)
else:
args.multiheads_finetuning = False

Expand Down Expand Up @@ -587,7 +605,6 @@ def run(args: argparse.Namespace) -> None:
distributed_model = DDP(model, device_ids=[local_rank])
else:
distributed_model = None

tools.train(
model=model,
loss_fn=loss_fn,
Expand Down
33 changes: 33 additions & 0 deletions mace/cli/select_head.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from argparse import ArgumentParser

import torch

from mace.tools.scripts_utils import remove_pt_head


def main():
    """Extract one head from a saved multihead model and save the result.

    Command-line arguments:
        model_file: path of the input model file (positional).
        --head_name / -n: name of the head to extract. Defaults to None,
            which is forwarded unchanged to ``remove_pt_head``.
        --output_file / -o: path of the output model. Defaults to
            ``<model_file>.<head_name>``, or ``<model_file>.single_head``
            when no head name was given.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--head_name",
        "-n",
        help="name of the head to extract",
        default=None,
    )
    parser.add_argument(
        "--output_file",
        "-o",
        help="name for output model, defaults to model_file.head_name",
    )
    parser.add_argument("model_file", help="input model file path")
    args = parser.parse_args()

    if args.output_file is None:
        # The parser defines no ``target_device`` option, so the default
        # output name is derived from the requested head instead; a fixed
        # suffix is used when no head name was supplied.
        suffix = args.head_name if args.head_name is not None else "single_head"
        args.output_file = f"{args.model_file}.{suffix}"

    # NOTE(review): torch.load unpickles the file; only load model files that
    # come from a trusted source.
    model = torch.load(args.model_file)
    model_single = remove_pt_head(model, args.head_name)
    torch.save(model_single, args.output_file)


if __name__ == "__main__":
main()
7 changes: 7 additions & 0 deletions mace/tools/arg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,13 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
type=str2bool,
default=True,
)
parser.add_argument(
"--foundation_model_head",
help="Name of the head to use for fine-tuning",
type=str,
default=None,
required=False,
)
parser.add_argument(
"--weight_pt_head",
help="Weight of the pretrained head in the loss function",
Expand Down
20 changes: 9 additions & 11 deletions mace/tools/finetuning_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,10 @@ def load_foundations_elements(
model.interactions[i].linear.weight = torch.nn.Parameter(
model_foundations.interactions[i].linear.weight.clone()
)
if (
model.interactions[i].__class__.__name__
in ["RealAgnosticResidualInteractionBlock", "RealAgnosticDensityResidualInteractionBlock"]
):
if model.interactions[i].__class__.__name__ in [
"RealAgnosticResidualInteractionBlock",
"RealAgnosticDensityResidualInteractionBlock",
]:
model.interactions[i].skip_tp.weight = torch.nn.Parameter(
model_foundations.interactions[i]
.skip_tp.weight.reshape(
Expand All @@ -101,19 +101,17 @@ def load_foundations_elements(
.clone()
/ (num_species_foundations / num_species) ** 0.5
)
if (
model.interactions[i].__class__.__name__
in ["RealAgnosticDensityInteractionBlock", "RealAgnosticDensityResidualInteractionBlock"]
):
if model.interactions[i].__class__.__name__ in [
"RealAgnosticDensityInteractionBlock",
"RealAgnosticDensityResidualInteractionBlock",
]:
# Assuming only 1 layer in density_fn
getattr(model.interactions[i].density_fn, "layer0").weight = (
torch.nn.Parameter(
getattr(
model_foundations.interactions[i].density_fn,
"layer0",
)
.weight
.clone()
).weight.clone()
)
)
# Transferring products
Expand Down
1 change: 0 additions & 1 deletion mace/tools/model_script_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ def configure_model(
model_config_foundation["atomic_inter_shift"] = (
_determine_atomic_inter_shift(args.mean, heads)
)

model_config_foundation["atomic_inter_scale"] = [1.0] * len(heads)
args.avg_num_neighbors = model_config_foundation["avg_num_neighbors"]
args.model = "FoundationMACE"
Expand Down
Loading
Loading