Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improved selection of configs for multihead pretrained head #570

Open
wants to merge 2 commits into
base: multihead-merge
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 123 additions & 106 deletions mace/cli/fine_tuning_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import logging
from typing import List

from tqdm import tqdm

import ase.data
import ase.io
import numpy as np
Expand Down Expand Up @@ -82,8 +84,8 @@ def parse_args() -> argparse.Namespace:
"--filtering_type",
help="filtering type",
type=str,
choices=[None, "combinations", "exclusive", "inclusive"],
default=None,
choices=["none", "subset", "exact", "superset", "any_overlap"],
default="subset",
)
parser.add_argument(
"--weight_ft",
Expand Down Expand Up @@ -114,40 +116,36 @@ def calculate_descriptors(atoms: List[ase.Atoms], calc: MACECalculator) -> None:


def filter_atoms(
    atoms: "ase.Atoms", selected_elements: List[str], filtering_type: str
) -> bool:
    """
    Decide whether a configuration passes an element-based filter.

    The set of chemical symbols present in `atoms` is compared with the set
    of `selected_elements` according to `filtering_type`.

    Parameters:
        atoms (ase.Atoms): The atoms object to filter. Only its `symbols`
            attribute is read.
        selected_elements (list): The list of elements to compare against.
        filtering_type (str): The type of filtering to apply. One of:
            'none' - No filtering is applied; always returns True.
            'exact' - Return true if `atoms` is composed of exactly the same
                elements as `selected_elements`, false otherwise.
            'subset' - Return true if the elements of `atoms` are a subset of
                `selected_elements` (no extra elements, but not all selected
                elements need to be present), false otherwise.
            'superset' - Return true if the elements of `atoms` are a superset
                of `selected_elements` (all selected elements present, extra
                elements allowed), false otherwise.
            'any_overlap' - Return true if `atoms` contains at least one of
                the elements in `selected_elements`, false otherwise.

    Returns:
        bool: True if the atoms pass the filter, False otherwise.

    Raises:
        ValueError: If `filtering_type` is not one of the values listed above.
    """
    if filtering_type == "none":
        return True
    # Reject unknown modes before touching `atoms`, so a bad mode always
    # surfaces as ValueError regardless of what `atoms` is.
    if filtering_type not in ("exact", "subset", "superset", "any_overlap"):
        raise ValueError(
            f"Filtering type {filtering_type} not recognised. Must be one of 'none', 'subset', 'exact', 'superset', or 'any_overlap'"
        )

    atom_symbols = set(atoms.symbols)
    selected = set(selected_elements)
    if filtering_type == "exact":
        return atom_symbols == selected
    if filtering_type == "subset":
        return atom_symbols <= selected
    if filtering_type == "superset":
        return selected <= atom_symbols
    # Only 'any_overlap' remains: at least one shared element.
    return not atom_symbols.isdisjoint(selected)


Expand Down Expand Up @@ -204,108 +202,127 @@ def assemble_descriptors(self) -> np.ndarray:
def select_samples(
args: argparse.Namespace,
) -> None:
# setup
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.model in ["small", "medium", "large"]:
calc = mace_mp(args.model, device=args.device, default_dtype=args.default_dtype)
else:
calc = MACECalculator(
model_paths=args.model, device=args.device, default_dtype=args.default_dtype
)

# read finetuning set
if isinstance(args.configs_ft, str):
atoms_list_ft = ase.io.read(args.configs_ft, index=":")
atoms_list_ft = list(tqdm(ase.io.iread(args.configs_ft, index=":"), desc=f"reading configs_ft {args.configs_ft}"))
else:
atoms_list_ft = []
for path in args.configs_ft:
atoms_list_ft += ase.io.read(path, index=":")
atoms_list_ft += list(tqdm(ase.io.iread(path, index=":"), desc=f"reading configs_ft item {path}"))

# read pretrained set
atoms_list_pt = list(tqdm(ase.io.iread(args.configs_pt, index=":"), desc="reading configs_pt"))

if args.filtering_type is not None:
all_species_ft = np.unique([x.symbol for atoms in atoms_list_ft for x in atoms])
indices_pt_filtered = []
atoms_list_pt_filtered = []

# do filtering by elements
if args.filtering_type != "none":
all_species_ft = {atom.symbol for atoms in atoms_list_ft for atom in atoms}
logging.info(
"Filtering configurations based on the finetuning set, "
f"filtering type: combinations, elements: {all_species_ft}"
f"filtering type: {args.filtering_type}, elements: {all_species_ft}"
)
if args.descriptors is not None:
logging.info("Loading descriptors")
descriptors = np.load(args.descriptors, allow_pickle=True)
atoms_list_pt = ase.io.read(args.configs_pt, index=":")
for i, atoms in enumerate(atoms_list_pt):
atoms.info["mace_descriptors"] = descriptors[i]
atoms_list_pt_filtered = [
x
for x in atoms_list_pt
if filter_atoms(x, all_species_ft, "combinations")
]
else:
atoms_list_pt = ase.io.read(args.configs_pt, index=":")
atoms_list_pt_filtered = [
x
for x in atoms_list_pt
if filter_atoms(x, all_species_ft, "combinations")
]
if len(atoms_list_pt_filtered) <= args.num_samples:
logging.info(
f"Number of configurations after filtering {len(atoms_list_pt_filtered)} "
f"is less than the number of samples {args.num_samples}, "
"selecting random configurations for the rest."
)
atoms_list_pt_minus_filtered = [
x for x in atoms_list_pt if x not in atoms_list_pt_filtered
]
atoms_list_pt_random_inds = np.random.choice(
list(range(len(atoms_list_pt_minus_filtered))),
args.num_samples - len(atoms_list_pt_filtered),
replace=False,
)
atoms_list_pt = atoms_list_pt_filtered + [
atoms_list_pt_minus_filtered[ind] for ind in atoms_list_pt_random_inds
]

# select by requested strategy
pt_filter = [filter_atoms(atoms, all_species_ft, args.filtering_type) for atoms in atoms_list_pt]
if sum(pt_filter) <= args.num_samples:
# few enough to include all, will be supplemented by FPS/random later
logging.info(f"Found few enough to include all {sum(pt_filter)} filtered by elements")
indices_pt_filtered = np.where(pt_filter)[0]
else:
atoms_list_pt = atoms_list_pt_filtered
# too many, select by increasingly generous strategy and within each one, match in composition
# [NB: should the exponential base relating overlap to selection probability (currently 10.0) be configurable?]
logging.info(f"Found too many filtered by elements {sum(pt_filter)}, choosing based on composition match")
# try increasingly generous matching strategies
indices_pt_filtered_orig = set(np.where(pt_filter)[0])
indices_pt_filtered = set()
for strategy in ("exact", "subset", "any_overlap"):
strategy_filter = [filter_atoms(atoms, all_species_ft, strategy) for atoms in atoms_list_pt]
if sum(strategy_filter) == 0:
logging.info(f"Nothing selected by {strategy}")
continue
indices_pt_strategy = set(np.where(strategy_filter)[0]) & indices_pt_filtered_orig
indices_pt_strategy -= indices_pt_filtered
if len(indices_pt_filtered) + len(indices_pt_strategy) <= args.num_samples:
# can include all of these
indices_pt_filtered |= indices_pt_strategy
logging.info(f"Adding all {len(indices_pt_strategy)} selected by {strategy}")
else:
# pick a subset with weights, penalizing missing and extra elements
# first term is number of elements that are missing from each config
# second term is number of elements that are extra in each config
#
# for exact distances should all be 0
# for subset should only have missing elements, no extra (first term only)
# for any_overlap could have either/both, add them up (both terms)
indices_pt_strategy = list(indices_pt_strategy)
d = np.asarray([len(all_species_ft - set(atoms_list_pt[ind].symbols)) +
len(set(atoms_list_pt[ind].symbols) - all_species_ft) for ind in indices_pt_strategy])
p = 10.0 ** (-d)
p /= np.sum(p)
inds = np.random.choice(len(indices_pt_strategy), args.num_samples - len(indices_pt_filtered), replace=False, p=p)
logging.info(f"Adding subset len {len(inds)} randomly chosen from those selected by {strategy}")
indices_pt_filtered |= {indices_pt_strategy[ind] for ind in inds}
# we already had too many, don't check more generous strategies
break

else:
atoms_list_pt = ase.io.read(args.configs_pt, index=":")
if args.descriptors is not None:
logging.info(
f"Loading descriptors for the pretraining set from {args.descriptors}"
)
descriptors = np.load(args.descriptors, allow_pickle=True)
for i, atoms in enumerate(atoms_list_pt):
atoms.info["mace_descriptors"] = descriptors[i]
# apply the element/composition-based selection made so far
atoms_list_pt_filtered = [atoms_list_pt[ind] for ind in indices_pt_filtered]

# get additional configs from across DB
# [NB: should we be able to control this size separately from size set chosen by filtering?]
atoms_list_pt_extra = []
if len(atoms_list_pt_filtered) < args.num_samples:
logging.info(
f"Number of configurations after filtering {len(atoms_list_pt_filtered)} "
f"< {args.num_samples} number of samples, "
f"selecting the rest with {args.subselect}"
)

indices_pt_avail = set(list(range(len(atoms_list_pt)))) - set(indices_pt_filtered)
atoms_list_pt_avail = [atoms_list_pt[ind] for ind in indices_pt_avail]

if args.num_samples is not None and args.num_samples < len(atoms_list_pt):
if args.subselect == "fps":
if args.descriptors is None:
if args.subselect == "random":
logging.info("Selecting configurations randomly")
idx_pt = np.random.choice(len(atoms_list_pt_avail), args.num_samples - len(atoms_list_pt_filtered), replace=False)
elif args.subselect == "fps":
if args.descriptors is not None:
logging.info(f"Loading descriptors from {args.descriptors}")
descriptors = np.load(args.descriptors, allow_pickle=True)
for descriptor, atoms in zip(descriptors, atoms_list_pt):
atoms.info["mace_descriptors"] = descriptor
else:
logging.info("Calculating descriptors for the pretraining set")
# [NB Not great that this parsing of args.model happens here as well as other places. Refactor?]
if args.model in ["small", "medium", "large"]:
calc = mace_mp(args.model, device=args.device, default_dtype=args.default_dtype)
else:
calc = MACECalculator(
model_paths=args.model, device=args.device, default_dtype=args.default_dtype
)
calculate_descriptors(atoms_list_pt, calc)
descriptors_list = [
atoms.info["mace_descriptors"] for atoms in atoms_list_pt
]
logging.info(
f"Saving descriptors at {args.output.replace('.xyz', '_descriptors.npy')}"
)
np.save(
args.output.replace(".xyz", "_descriptors.npy"), descriptors_list
)
descriptors_list = [atoms.info["mace_descriptors"] for atoms in atoms_list_pt]

descriptors_file = args.output.replace(".xyz", "descriptors.npy")
logging.info(f"Saving descriptors at {descriptors_file}")
np.save(descriptors_file, descriptors_list)

logging.info("Selecting configurations using Farthest Point Sampling")
try:
fps_pt = FPS(atoms_list_pt, args.num_samples)
idx_pt = fps_pt.run()
logging.info(f"Selected {len(idx_pt)} configurations")
except Exception as e: # pylint: disable=W0703
logging.error(
f"FPS failed, selecting random configurations instead: {e}"
)
idx_pt = np.random.choice(
list(range(len(atoms_list_pt))), args.num_samples, replace=False
)
atoms_list_pt = [atoms_list_pt[i] for i in idx_pt]
fps_pt = FPS(atoms_list_pt_avail, args.num_samples - len(atoms_list_pt_filtered))
idx_pt = fps_pt.run()
else:
logging.info("Selecting random configurations")
idx_pt = np.random.choice(
list(range(len(atoms_list_pt))), args.num_samples, replace=False
)
atoms_list_pt = [atoms_list_pt[i] for i in idx_pt]
raise ValueError(f"subselect type {args.subselect} not 'random' or 'fps'")

logging.info(f"Selected {len(idx_pt)} configurations")
atoms_list_pt_extra = [atoms_list_pt_avail[i] for i in idx_pt]

atoms_list_pt = atoms_list_pt_filtered + atoms_list_pt_extra

for atoms in atoms_list_pt:
# del atoms.info["mace_descriptors"]
atoms.info["pretrained"] = True
Expand Down
6 changes: 6 additions & 0 deletions mace/tools/arg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,12 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
type=int,
default=1000,
)
parser.add_argument(
"--filtering_type_pt",
help="strategy for filtering of configurations for pretrained head",
choices=["none", "subset", "exact", "superset", "any_overlap"],
default="subset"
)
parser.add_argument(
"--subselect_pt",
help="Method to subselect the configurations of the pretraining set",
Expand Down
4 changes: 2 additions & 2 deletions mace/tools/multihead_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,9 @@ def assemble_mp_data(
"head_ft": "Default",
"weight_pt": args.weight_pt_head,
"weight_ft": 1.0,
"filtering_type": "combination",
"output": f"mp_finetuning-{tag}.xyz",
"descriptors": descriptors_mp,
"filtering_type": args.filtering_type_pt,
"subselect": args.subselect_pt,
"device": args.device,
"default_dtype": args.default_dtype,
Expand All @@ -179,4 +179,4 @@ def assemble_mp_data(
)
return collections_mp
except Exception as exc:
raise RuntimeError("Model download failed and no local model found") from exc
raise RuntimeError("Failed to assemble pretrained data") from exc
Loading