Skip to content

Commit

Permalink
Refined strategy for selecting configs for pretrained multihead
Browse files Browse the repository at this point in the history
  • Loading branch information
bernstei committed Aug 26, 2024
1 parent 5e70bfe commit 6d4daa7
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 108 deletions.
229 changes: 123 additions & 106 deletions mace/cli/fine_tuning_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import logging
from typing import List

from tqdm import tqdm

import ase.data
import ase.io
import numpy as np
Expand Down Expand Up @@ -82,8 +84,8 @@ def parse_args() -> argparse.Namespace:
"--filtering_type",
help="filtering type",
type=str,
choices=[None, "combinations", "exclusive", "inclusive"],
default=None,
choices=["none", "subset", "exact", "superset", "any_overlap"],
default="subset",
)
parser.add_argument(
"--weight_ft",
Expand Down Expand Up @@ -114,40 +116,36 @@ def calculate_descriptors(atoms: List[ase.Atoms], calc: MACECalculator) -> None:


def filter_atoms(
atoms: ase.Atoms, element_subset: List[str], filtering_type: str
atoms: ase.Atoms, selected_elements: List[str], filtering_type: str
) -> bool:
"""
Filters atoms based on the provided filtering type and element subset.
Parameters:
atoms (ase.Atoms): The atoms object to filter.
element_subset (list): The list of elements to consider during filtering.
selected_elements (list): The list of elements to consider during filtering.
filtering_type (str): The type of filtering to apply. Can be 'none', 'exclusive', or 'inclusive'.
'none' - No filtering is applied.
'combinations' - Return true if `atoms` is composed of combinations of elements in the subset, false otherwise. I.e. does not require all of the specified elements to be present.
'exclusive' - Return true if `atoms` contains *only* elements in the subset, false otherwise.
'inclusive' - Return true if `atoms` contains all elements in the subset, false otherwise. I.e. allows additional elements.
'exact' - Return true if `atoms` is composed of exactly the same elements as the `seleted_elements`, false otherwise
'subset' - Return true if `atoms` is composed of a subset of elements in `selected_elements`, false otherwise
'superset' - Return true if `atoms` is composed of a superset of elements in `selected_elements`, false otherwise
`any_overlap` - Return true if `atoms` contains any of the elements in `selected_elements`
Returns:
bool: True if the atoms pass the filter, False otherwise.
"""
if filtering_type == "none":
return True
if filtering_type == "combinations":
atom_symbols = np.unique(atoms.symbols)
return all(
x in element_subset for x in atom_symbols
) # atoms must *only* contain elements in the subset
if filtering_type == "exclusive":
atom_symbols = set(list(atoms.symbols))
return atom_symbols == set(element_subset)
if filtering_type == "inclusive":
atom_symbols = np.unique(atoms.symbols)
return all(
x in atom_symbols for x in element_subset
) # atoms must *at least* contain elements in the subset
if filtering_type == "exact":
return set(atoms.symbols) == set(selected_elements)
if filtering_type == "subset":
return set(atoms.symbols).issubset(selected_elements)
if filtering_type == "superset":
return set(selected_elements).issubset(atoms.symbols)
if filtering_type == "any_overlap":
return len(set(selected_elements) & set(atoms.symbols)) >= 1
raise ValueError(
f"Filtering type {filtering_type} not recognised. Must be one of 'none', 'exclusive', or 'inclusive'."
f"Filtering type {filtering_type} not recognised. Must be one of 'none', 'subset', 'exact', 'superset', or 'any_overlap'"
)


Expand Down Expand Up @@ -204,108 +202,127 @@ def assemble_descriptors(self) -> np.ndarray:
def select_samples(
args: argparse.Namespace,
) -> None:
# setup
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.model in ["small", "medium", "large"]:
calc = mace_mp(args.model, device=args.device, default_dtype=args.default_dtype)
else:
calc = MACECalculator(
model_paths=args.model, device=args.device, default_dtype=args.default_dtype
)

# read finetuning set
if isinstance(args.configs_ft, str):
atoms_list_ft = ase.io.read(args.configs_ft, index=":")
atoms_list_ft = list(tqdm(ase.io.iread(args.configs_ft, index=":"), desc=f"reading configs_ft {args.configs_ft}"))
else:
atoms_list_ft = []
for path in args.configs_ft:
atoms_list_ft += ase.io.read(path, index=":")
atoms_list_ft += list(tqdm(ase.io.iread(path, index=":"), desc=f"reading configs_ft item {path}"))

# read pretrained set
atoms_list_pt = list(tqdm(ase.io.iread(args.configs_pt, index=":"), desc="reading configs_pt"))

if args.filtering_type is not None:
all_species_ft = np.unique([x.symbol for atoms in atoms_list_ft for x in atoms])
indices_pt_filtered = []
atoms_list_pt_filtered = []

# do filtering by elements
if args.filtering_type != "none":
all_species_ft = {atom.symbol for atoms in atoms_list_ft for atom in atoms}
logging.info(
"Filtering configurations based on the finetuning set, "
f"filtering type: combinations, elements: {all_species_ft}"
f"filtering type: {args.filtering_type}, elements: {all_species_ft}"
)
if args.descriptors is not None:
logging.info("Loading descriptors")
descriptors = np.load(args.descriptors, allow_pickle=True)
atoms_list_pt = ase.io.read(args.configs_pt, index=":")
for i, atoms in enumerate(atoms_list_pt):
atoms.info["mace_descriptors"] = descriptors[i]
atoms_list_pt_filtered = [
x
for x in atoms_list_pt
if filter_atoms(x, all_species_ft, "combinations")
]
else:
atoms_list_pt = ase.io.read(args.configs_pt, index=":")
atoms_list_pt_filtered = [
x
for x in atoms_list_pt
if filter_atoms(x, all_species_ft, "combinations")
]
if len(atoms_list_pt_filtered) <= args.num_samples:
logging.info(
f"Number of configurations after filtering {len(atoms_list_pt_filtered)} "
f"is less than the number of samples {args.num_samples}, "
"selecting random configurations for the rest."
)
atoms_list_pt_minus_filtered = [
x for x in atoms_list_pt if x not in atoms_list_pt_filtered
]
atoms_list_pt_random_inds = np.random.choice(
list(range(len(atoms_list_pt_minus_filtered))),
args.num_samples - len(atoms_list_pt_filtered),
replace=False,
)
atoms_list_pt = atoms_list_pt_filtered + [
atoms_list_pt_minus_filtered[ind] for ind in atoms_list_pt_random_inds
]

# select by requested strategy
pt_filter = [filter_atoms(atoms, all_species_ft, args.filtering_type) for atoms in atoms_list_pt]
if sum(pt_filter) <= args.num_samples:
# few enough to include all, will be supplemented by FPS/random later
logging.info(f"Found few enough to include all {sum(pt_filter)} filtered by elements")
indices_pt_filtered = np.where(pt_filter)[0]
else:
atoms_list_pt = atoms_list_pt_filtered
# too many, select by increasingly generous strategy and within each one, match in composition
# [NB should we allow setting of exponential base relating overlap and probability, currently 10.0 ?]]
logging.info(f"Found too many filtered by elements {sum(pt_filter)}, choosing based on composition match")
# try increasingly generous matching strategies
indices_pt_filtered_orig = set(np.where(pt_filter)[0])
indices_pt_filtered = set()
for strategy in ("exact", "subset", "any_overlap"):
strategy_filter = [filter_atoms(atoms, all_species_ft, strategy) for atoms in atoms_list_pt]
if sum(strategy_filter) == 0:
logging.info(f"Nothing selected by {strategy}")
continue
indices_pt_strategy = set(np.where(strategy_filter)[0]) & indices_pt_filtered_orig
indices_pt_strategy -= indices_pt_filtered
if len(indices_pt_filtered) + len(indices_pt_strategy) <= args.num_samples:
# can include all of these
indices_pt_filtered |= indices_pt_strategy
logging.info(f"Adding all {len(indices_pt_strategy)} selected by {strategy}")
else:
# pick a subset with weights, penalizing missing and extra elements
# first term is number of elements that are missing from each config
# second term is number of elements that are extra in each config
#
# for exact distances should all be 0
# for subset should only have missing elements, no extra (first term only)
# for any_overlap could have either/both, add them up (both terms)
indices_pt_strategy = list(indices_pt_strategy)
d = np.asarray([len(all_species_ft - set(atoms_list_pt[ind].symbols)) +
len(set(atoms_list_pt[ind].symbols) - all_species_ft) for ind in indices_pt_strategy])
p = 10.0 ** (-d)
p /= np.sum(p)
inds = np.random.choice(len(indices_pt_strategy), args.num_samples - len(indices_pt_filtered), replace=False, p=p)
logging.info(f"Adding subset len {len(inds)} randomly chosen from those selected by {strategy}")
indices_pt_filtered |= {indices_pt_strategy[ind] for ind in inds}
# we already had too many, don't check more generous strategies
break

else:
atoms_list_pt = ase.io.read(args.configs_pt, index=":")
if args.descriptors is not None:
logging.info(
f"Loading descriptors for the pretraining set from {args.descriptors}"
)
descriptors = np.load(args.descriptors, allow_pickle=True)
for i, atoms in enumerate(atoms_list_pt):
atoms.info["mace_descriptors"] = descriptors[i]
# actually do filtering by composition done so far
atoms_list_pt_filtered = [atoms_list_pt[ind] for ind in indices_pt_filtered]

# get additional configs from across DB
# [NB: should we be able to control this size separately from size set chosen by filtering?]
atoms_list_pt_extra = []
if len(atoms_list_pt_filtered) < args.num_samples:
logging.info(
f"Number of configurations after filtering {len(atoms_list_pt_filtered)} "
f"< {args.num_samples} number of samples, "
f"selecting the rest with {args.subselect}"
)

indices_pt_avail = set(list(range(len(atoms_list_pt)))) - set(indices_pt_filtered)
atoms_list_pt_avail = [atoms_list_pt[ind] for ind in indices_pt_avail]

if args.num_samples is not None and args.num_samples < len(atoms_list_pt):
if args.subselect == "fps":
if args.descriptors is None:
if args.subselect == "random":
logging.info("Selecting configurations randomly")
idx_pt = np.random.choice(len(atoms_list_pt_avail), args.num_samples - len(atoms_list_pt_filtered), replace=False)
elif args.subselect == "fps":
if args.descriptors is not None:
logging.info(f"Loading descriptors from {args.descriptors}")
descriptors = np.load(args.descriptors, allow_pickle=True)
for descriptor, atoms in zip(descriptors, atoms_list_pt):
atoms.info["mace_descriptors"] = descriptor
else:
logging.info("Calculating descriptors for the pretraining set")
# [NB Not great that this parsing of args.model happens here as well as other places. Refactor?]
if args.model in ["small", "medium", "large"]:
calc = mace_mp(args.model, device=args.device, default_dtype=args.default_dtype)
else:
calc = MACECalculator(
model_paths=args.model, device=args.device, default_dtype=args.default_dtype
)
calculate_descriptors(atoms_list_pt, calc)
descriptors_list = [
atoms.info["mace_descriptors"] for atoms in atoms_list_pt
]
logging.info(
f"Saving descriptors at {args.output.replace('.xyz', '_descriptors.npy')}"
)
np.save(
args.output.replace(".xyz", "_descriptors.npy"), descriptors_list
)
descriptors_list = [atoms.info["mace_descriptors"] for atoms in atoms_list_pt]

descriptors_file = args.output.replace(".xyz", "descriptors.npy")
logging.info(f"Saving descriptors at {descriptors_file}")
np.save(descriptors_file, descriptors_list)

logging.info("Selecting configurations using Farthest Point Sampling")
try:
fps_pt = FPS(atoms_list_pt, args.num_samples)
idx_pt = fps_pt.run()
logging.info(f"Selected {len(idx_pt)} configurations")
except Exception as e: # pylint: disable=W0703
logging.error(
f"FPS failed, selecting random configurations instead: {e}"
)
idx_pt = np.random.choice(
list(range(len(atoms_list_pt))), args.num_samples, replace=False
)
atoms_list_pt = [atoms_list_pt[i] for i in idx_pt]
fps_pt = FPS(atoms_list_pt_avail, args.num_samples - len(atoms_list_pt_filtered))
idx_pt = fps_pt.run()
else:
logging.info("Selecting random configurations")
idx_pt = np.random.choice(
list(range(len(atoms_list_pt))), args.num_samples, replace=False
)
atoms_list_pt = [atoms_list_pt[i] for i in idx_pt]
raise ValueError(f"subselect type {args.subselect} not 'random' or 'fps'")

logging.info(f"Selected {len(idx_pt)} configurations")
atoms_list_pt_extra = [atoms_list_pt_avail[i] for i in idx_pt]

atoms_list_pt = atoms_list_pt_filtered + atoms_list_pt_extra

for atoms in atoms_list_pt:
# del atoms.info["mace_descriptors"]
atoms.info["pretrained"] = True
Expand Down
6 changes: 6 additions & 0 deletions mace/tools/arg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,12 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
type=int,
default=1000,
)
parser.add_argument(
"--filtering_type_pt",
help="strategy for filtering of configurations for pretrained head",
choices=["none", "subset", "exact", "superset", "any_overlap"],
default="subset"
)
parser.add_argument(
"--subselect_pt",
help="Method to subselect the configurations of the pretraining set",
Expand Down
4 changes: 2 additions & 2 deletions mace/tools/multihead_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,9 @@ def assemble_mp_data(
"head_ft": "Default",
"weight_pt": args.weight_pt_head,
"weight_ft": 1.0,
"filtering_type": "combination",
"output": f"mp_finetuning-{tag}.xyz",
"descriptors": descriptors_mp,
"filtering_type": args.filtering_type_pt,
"subselect": args.subselect_pt,
"device": args.device,
"default_dtype": args.default_dtype,
Expand All @@ -179,4 +179,4 @@ def assemble_mp_data(
)
return collections_mp
except Exception as exc:
raise RuntimeError("Model download failed and no local model found") from exc
raise RuntimeError("Failed to assemble pretrained data") from exc

0 comments on commit 6d4daa7

Please sign in to comment.