Skip to content

Commit

Permalink
start new parsing of head
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyes319 committed Aug 21, 2024
1 parent d17ad39 commit f2bf7a1
Show file tree
Hide file tree
Showing 11 changed files with 419 additions and 324 deletions.
6 changes: 5 additions & 1 deletion mace/calculators/mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,12 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes):

if self.model_type in ["MACE", "EnergyDipoleMACE"]:
batch = self._clone_batch(batch_base)
print("head", batch["head"])
node_heads = batch["head"][batch["batch"]]
node_e0 = self.models[0].atomic_energies_fn(batch["node_attrs"], node_heads)
num_atoms_arange = torch.arange(batch["positions"].shape[0])
node_e0 = self.models[0].atomic_energies_fn(batch["node_attrs"])[
num_atoms_arange, node_heads
]
compute_stress = not self.use_compile
else:
compute_stress = False
Expand Down
63 changes: 42 additions & 21 deletions mace/cli/fine_tuning_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ def parse_args() -> argparse.Namespace:
required=False,
default=None,
)
parser.add_argument(
"--subselect",
help="method to subselect the configurations of the pretraining set",
type=str,
choices=["fps", "random"],
default="fps",
)
parser.add_argument(
"--model", help="path to model", default="small", required=False
)
Expand Down Expand Up @@ -187,9 +194,11 @@ def assemble_descriptors(self) -> np.ndarray:
).astype(np.float32)

for i, atoms in enumerate(self.atoms_list):
descriptors = np.array(atoms.info["mace_descriptors"]).astype(np.float32)
descriptors = atoms.info["mace_descriptors"]
for z in descriptors:
self.descriptors_dataset[i, self.species_dict[z]] = descriptors[z]
self.descriptors_dataset[i, self.species_dict[z]] = np.array(
descriptors[z]
).astype(np.float32)


def select_samples(
Expand Down Expand Up @@ -260,27 +269,39 @@ def select_samples(
atoms.info["mace_descriptors"] = descriptors[i]

if args.num_samples is not None and args.num_samples < len(atoms_list_pt):
if args.descriptors is None:
logging.info("Calculating descriptors for the pretraining set")
calculate_descriptors(atoms_list_pt, calc)
descriptors_list = [
atoms.info["mace_descriptors"] for atoms in atoms_list_pt
]
logging.info(
f"Saving descriptors at {args.output.replace('.xyz', '_descriptors.npy')}"
)
np.save(args.output.replace(".xyz", "_descriptors.npy"), descriptors_list)
logging.info("Selecting configurations using Farthest Point Sampling")
try:
fps_pt = FPS(atoms_list_pt, args.num_samples)
idx_pt = fps_pt.run()
logging.info(f"Selected {len(idx_pt)} configurations")
except Exception as e: # pylint: disable=W0703
logging.error(f"FPS failed, selecting random configurations instead: {e}")
if args.subselect == "fps":
if args.descriptors is None:
logging.info("Calculating descriptors for the pretraining set")
calculate_descriptors(atoms_list_pt, calc)
descriptors_list = [
atoms.info["mace_descriptors"] for atoms in atoms_list_pt
]
logging.info(
f"Saving descriptors at {args.output.replace('.xyz', '_descriptors.npy')}"
)
np.save(
args.output.replace(".xyz", "_descriptors.npy"), descriptors_list
)
logging.info("Selecting configurations using Farthest Point Sampling")
try:
fps_pt = FPS(atoms_list_pt, args.num_samples)
idx_pt = fps_pt.run()
logging.info(f"Selected {len(idx_pt)} configurations")
except Exception as e: # pylint: disable=W0703
logging.error(
f"FPS failed, selecting random configurations instead: {e}"
)
idx_pt = np.random.choice(
list(range(len(atoms_list_pt))), args.num_samples, replace=False
)
atoms_list_pt = [atoms_list_pt[i] for i in idx_pt]
else:
logging.info("Selecting random configurations")
idx_pt = np.random.choice(
list(range(len(atoms_list_pt)), args.num_samples, replace=False)
list(range(len(atoms_list_pt))), args.num_samples, replace=False
)
atoms_list_pt = [atoms_list_pt[i] for i in idx_pt]
print("idx_pt", idx_pt)
atoms_list_pt = [atoms_list_pt[i] for i in idx_pt]
for atoms in atoms_list_pt:
# del atoms.info["mace_descriptors"]
atoms.info["pretrained"] = True
Expand Down
Loading

0 comments on commit f2bf7a1

Please sign in to comment.