Skip to content

Commit

Permalink
linter
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyes319 committed Aug 13, 2024
1 parent 7c52013 commit 4a1beb9
Showing 1 changed file with 24 additions and 26 deletions.
50 changes: 24 additions & 26 deletions mace/cli/fine_tuning_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,34 +160,36 @@ def run(
"""
Run the farthest point sampling algorithm.
"""
logging.info(self.descriptors_dataset.reshape(len(self.atoms_list), -1).shape)
logging.info("n_samples", self.n_samples)
descriptor_dataset_reshaped = (
self.descriptors_dataset.reshape( # pylint: disable=E1121
(len(self.atoms_list), -1)
)
)
logging.info(f"{descriptor_dataset_reshaped.shape}")
logging.info(f"n_samples: {self.n_samples}")
self.list_index = fpsample.fps_npdu_kdtree_sampling(
self.descriptors_dataset.reshape(len(self.atoms_list), -1), self.n_samples
descriptor_dataset_reshaped,
self.n_samples,
)
return self.list_index

def assemble_descriptors(self) -> np.ndarray:
"""
Assemble the descriptors for all the configurations.
"""
self.descriptors_dataset = np.float32(
10e10
* np.ones(
(
len(self.atoms_list),
len(self.species),
len(list(self.atoms_list[0].info["mace_descriptors"].values())[0]),
),
dtype=np.float32,
)
)
self.descriptors_dataset: np.ndarray = 10e10 * np.ones(
(
len(self.atoms_list),
len(self.species),
len(list(self.atoms_list[0].info["mace_descriptors"].values())[0]),
),
dtype=np.float32,
).astype(np.float32)

for i, atoms in enumerate(self.atoms_list):
descriptors = atoms.info["mace_descriptors"]
descriptors = np.array(atoms.info["mace_descriptors"]).astype(np.float32)
for z in descriptors:
self.descriptors_dataset[i, self.species_dict[z]] = np.float32(
descriptors[z]
)
self.descriptors_dataset[i, self.species_dict[z]] = descriptors[z]


def select_samples(
Expand Down Expand Up @@ -251,9 +253,7 @@ def select_samples(
atoms_list_pt = ase.io.read(args.configs_pt, index=":")
if args.descriptors is not None:
logging.info(
"Loading descriptors for the pretraining set from {}".format(
args.descriptors
)
f"Loading descriptors for the pretraining set from {args.descriptors}"
)
descriptors = np.load(args.descriptors, allow_pickle=True)
for i, atoms in enumerate(atoms_list_pt):
Expand All @@ -267,17 +267,15 @@ def select_samples(
atoms.info["mace_descriptors"] for atoms in atoms_list_pt
]
logging.info(
"Saving descriptors at {}".format(
args.output.replace(".xyz", "descriptors.npy")
)
f"Saving descriptors at {args.output.replace('.xyz', '_descriptors.npy')}"
)
np.save(args.output.replace(".xyz", "descriptors.npy"), descriptors_list)
np.save(args.output.replace(".xyz", "_descriptors.npy"), descriptors_list)
logging.info("Selecting configurations using Farthest Point Sampling")
try:
fps_pt = FPS(atoms_list_pt, args.num_samples)
idx_pt = fps_pt.run()
logging.info(f"Selected {len(idx_pt)} configurations")
except Exception as e:
except Exception as e: # pylint: disable=W0703
logging.error(f"FPS failed, selecting random configurations instead: {e}")
idx_pt = np.random.choice(
list(range(len(atoms_list_pt)), args.num_samples, replace=False)
Expand Down

0 comments on commit 4a1beb9

Please sign in to comment.