Skip to content

Commit

Permalink
Further vectorise site matching in stenciling as planned
Browse files Browse the repository at this point in the history
  • Loading branch information
kavanase committed Oct 31, 2024
1 parent d174811 commit b394087
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 22 deletions.
2 changes: 1 addition & 1 deletion doped/utils/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,6 @@ def find_nearest_coords(

distance_matrix = lattice.get_all_distances(candidate_frac_coords, target_frac_coords).ravel()
match = distance_matrix.argmin()
# TODO: Want to use this in stenciling?

return candidate_frac_coords[match], match if return_idx else candidate_frac_coords[match]

Expand Down Expand Up @@ -625,6 +624,7 @@ def find_missing_idx(
site_matches = distance_matrix.argmin(axis=0 if len(frac_coords1) > len(frac_coords2) else -1)

# TODO: Trial linear assignment, or matching the other way, if we fail
# TODO: Use linear assignment in stenciling? (for choosing candidate sites) Once tests setup?
# TODO: Depending on speed, could just go to linear assignment from the start. Either way can
# try successive stol increases
if len(np.unique(site_matches)) != len(site_matches):
Expand Down
43 changes: 22 additions & 21 deletions doped/utils/stenciling.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
get_defect_type_and_composition_diff,
get_wigner_seitz_radius,
)
from doped.utils.supercells import min_dist
from doped.utils.supercells import _largest_cube_length_from_matrix, min_dist
from doped.utils.symmetry import (
SymmOp,
apply_symm_op_to_site,
Expand Down Expand Up @@ -969,34 +969,35 @@ def _get_matching_sites_from_s1_then_s2(
single_defect_subcell_sites, element.symbol, frac_coords=True # frac_coords
)[0]

# this could be made faster by vectorising with ``find_idx_of_nearest_coords`` from
# ``doped.utils.parsing`` but it's far from being the bottleneck in this workflow:
struct2_pool_idx_min_dist_dict = {}
struct2_pool_dists_to_template_centre = struct2_pool.lattice.get_all_distances(
struct2_pool.frac_coords,
struct2_pool.lattice.get_fractional_coords(
template_struct.lattice.get_cartesian_coords([0.5, 0.5, 0.5])
),
).ravel() # template centre is defect site in stenciling workflow
template_ws_radius = get_wigner_seitz_radius(template_struct)
for i, super_site in enumerate(struct2_pool):
if struct2_pool_dists_to_template_centre[i] > template_ws_radius * 0.75:
# check that it's outside WS radius, so not defect site itself
struct2_pool_idx_min_dist_dict[i] = np.min( # vectorised for fast computation
struct2_pool.lattice.get_all_distances(
species_coord_dict[super_site.specie.symbol], super_site.frac_coords
)
)
largest_encompassed_cube_length = _largest_cube_length_from_matrix(template_struct.lattice.matrix)
candidate_struct2_pool_species_sites: dict[str, list[PeriodicSite]] = {
super_site.specie.symbol: [] for super_site in struct2_pool
}
for dist_to_template_centre, super_site in zip(struct2_pool_dists_to_template_centre, struct2_pool):
# screen to sites outside defect WS radius, for efficiency:
if dist_to_template_centre > largest_encompassed_cube_length * 0.49: # 2% buffer (cube length / 2)
candidate_struct2_pool_species_sites[super_site.specie.symbol].append(super_site)

struct2_pool_site_min_dist_dict = {}
for species_symbol, species_sites in candidate_struct2_pool_species_sites.items():
dist_matrix = struct2_pool.lattice.get_all_distances( # vectorised for fast computation
species_coord_dict[species_symbol], [site.frac_coords for site in species_sites]
) # M x N
min_dists = np.min(dist_matrix, axis=0) # down columns
struct2_pool_site_min_dist_dict.update(dict(zip(species_sites, min_dists)))

# sort possible_bulk_outer_cell_sites by (largest) min dist to single_defect_subcell_sites:
possible_bulk_outer_cell_sites = [
struct2_pool[i]
for i in sorted(
struct2_pool_idx_min_dist_dict.keys(),
key=lambda x: struct2_pool_idx_min_dist_dict[x],
reverse=True,
)
]
possible_bulk_outer_cell_sites = sorted(
[site for sites in candidate_struct2_pool_species_sites.values() for site in sites],
key=lambda x: struct2_pool_site_min_dist_dict[x],
reverse=True,
)
bulk_outer_cell_sites = possible_bulk_outer_cell_sites[
: int(len(struct2_pool) * (1 - 1 / num_super_supercells))
]
Expand Down

0 comments on commit b394087

Please sign in to comment.