diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index c3a7519d0..aa1eb8555 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -6,6 +6,7 @@ from fractions import Fraction from typing import Union, Optional import sys +import warnings from emdfile import PointList from py4DSTEM.process.utils import single_atom_scatter, electron_wavelength_angstrom @@ -66,6 +67,7 @@ def __init__( positions, numbers, cell, + occupancy=None, ): """ Args: @@ -76,7 +78,7 @@ def __init__( 3 numbers: the three lattice parameters for an orthorhombic cell 6 numbers: the a,b,c lattice parameters and ɑ,β,ɣ angles for any cell 3x3 array: row vectors containing the (u,v,w) lattice vectors. - + occupancy (np.array): Partial occupancy values for each atomic site. Must match the length of positions """ # Initialize Crystal self.positions = np.asarray(positions) #: fractional atomic coordinates @@ -131,6 +133,17 @@ def __init__( else: raise Exception("Cell cannot contain " + np.size(cell) + " entries") + # occupancy + if occupancy is not None: + self.occupancy = np.array(occupancy) + # check the occupancy shape makes sense + if self.occupancy.shape[0] != self.positions.shape[0]: + raise Warning( + f"Number of occupancies ({self.occupancy.shape[0]}) and atomic positions ({self.positions.shape[0]}) do not match" + ) + else: + self.occupancy = np.ones(self.positions.shape[0], dtype=np.float32) + # pymatgen flag if "pymatgen" in sys.modules: self.pymatgen_available = True @@ -257,7 +270,70 @@ def get_strained_crystal( else: return crystal_strained - def from_CIF(CIF, conventional_standard_structure=True): + @staticmethod + def from_ase( + atoms, + ): + """ + Create a py4DSTEM Crystal object from an ASE atoms object + + Args: + atoms (ase.Atoms): an ASE atoms object + + """ + # get the occupancies from the atoms object + occupancies = ( + atoms.arrays["occupancies"] + if "occupancies" in atoms.arrays.keys() + else None + ) + + if "occupancy" in atoms.info.keys(): + warnings.warn( + "This Atoms object contains occupancy information but it will be ignored." + ) + + xtal = Crystal( + positions=atoms.get_scaled_positions(), # fractional coords + numbers=atoms.numbers, + cell=atoms.cell.array, + occupancy=occupancies, + ) + return xtal + + @staticmethod + def from_prismatic(filepath): + """ + Create a py4DSTEM Crystal object from an prismatic style xyz co-ordinate file + + Args: + filepath (str|Pathlib.Path): path to the prismatic format xyz file + + """ + + from ase import io + + # read the atoms using ase + atoms = io.read(filepath, format="prismatic") + + # get the occupancies from the atoms object + occupancies = ( + atoms.arrays["occupancies"] + if "occupancies" in atoms.arrays.keys() + else None + ) + xtal = Crystal( + positions=atoms.get_scaled_positions(), # fractional coords + numbers=atoms.numbers, + cell=atoms.cell.array, + occupancy=occupancies, + ) + return xtal + + @staticmethod + def from_CIF( + CIF, primitive: bool = True, conventional_standard_structure: bool = True + ): """ Create a Crystal object from a CIF file, using pymatgen to import the CIF @@ -273,12 +349,13 @@ def from_CIF(CIF, conventional_standard_structure=True): parser = CifParser(CIF) - structure = parser.get_structures(False)[0] + structure = parser.get_structures(primitive=primitive)[0] return Crystal.from_pymatgen_structure( structure, conventional_standard_structure=conventional_standard_structure ) + @staticmethod def from_pymatgen_structure( structure=None, formula=None, @@ -375,8 +452,6 @@ def from_pymatgen_structure( else selected["structure"] ) - positions = structure.frac_coords #: fractional atomic coordinates - cell = np.array( [ structure.lattice.a, @@ -388,10 +463,22 @@ def from_pymatgen_structure( ] ) - numbers = np.array([s.species.elements[0].Z for s in structure]) + site_data = np.array( + [ + (*site.frac_coords, elem.number, comp) + for site in structure + for elem, comp in site.species.items() + ] + ) + positions = site_data[:, :3] + numbers = site_data[:, 3] + occupancies = site_data[:, 4] - return Crystal(positions, numbers, cell) + return Crystal( + positions=positions, numbers=numbers, cell=cell, occupancy=occupancies + ) + @staticmethod def from_unitcell_parameters( latt_params, elements, @@ -575,10 +662,14 @@ def calculate_structure_factors( # Calculate structure factors self.struct_factors = np.zeros(np.size(self.g_vec_leng, 0), dtype="complex64") for a0 in range(self.positions.shape[0]): - self.struct_factors += f_all[:, a0] * np.exp( - (2j * np.pi) - * np.sum( - self.hkl * np.expand_dims(self.positions[a0, :], axis=1), axis=0 + self.struct_factors += ( + f_all[:, a0] + * self.occupancy[a0] + * np.exp( + (2j * np.pi) + * np.sum( + self.hkl * np.expand_dims(self.positions[a0, :], axis=1), axis=0 + ) ) ) diff --git a/py4DSTEM/process/diffraction/crystal_bloch.py b/py4DSTEM/process/diffraction/crystal_bloch.py index 6a3c9b1ac..ce8bb8622 100644 --- a/py4DSTEM/process/diffraction/crystal_bloch.py +++ b/py4DSTEM/process/diffraction/crystal_bloch.py @@ -27,7 +27,6 @@ def calculate_dynamical_structure_factors( tol_structure_factor: float = 0.0, recompute_kinematic_structure_factors=True, g_vec_precision=None, - verbose=True, ): """ Calculate and store the relativistic corrected structure factors used for Bloch computations @@ -92,7 +91,7 @@ def calculate_dynamical_structure_factors( # Calculate the reciprocal lattice points to include based on k_max - k_max = np.asarray(k_max) + k_max: np.ndarray = np.asarray(k_max) if recompute_kinematic_structure_factors: if hasattr(self, "struct_factors"): @@ -215,7 +214,9 @@ def get_f_e(q, Z, thermal_sigma, method): # Calculate structure factors struct_factors = np.sum( - f_e * np.exp(2.0j * np.pi * np.squeeze(self.positions[:, None, :] @ hkl)), + f_e + * self.occupancy[:, None] + * np.exp(2.0j * np.pi * np.squeeze(self.positions[:, None, :] @ hkl)), axis=0, ) diff --git a/py4DSTEM/process/diffraction/crystal_viz.py b/py4DSTEM/process/diffraction/crystal_viz.py index 94cf75b8c..da016b3ed 100644 --- a/py4DSTEM/process/diffraction/crystal_viz.py +++ b/py4DSTEM/process/diffraction/crystal_viz.py @@ -3,6 +3,8 @@ from matplotlib.axes import Axes import matplotlib.tri as mtri from mpl_toolkits.mplot3d import Axes3D, art3d +from mpl_toolkits.mplot3d.art3d import Poly3DCollection + from scipy.signal import medfilt from scipy.ndimage import gaussian_filter from scipy.ndimage import distance_transform_edt @@ -91,18 +93,26 @@ def plot_structure( # Fractional atomic coordinates pos = self.positions + occ = self.occupancy + # x tile sub = pos[:, 0] < tol_distance pos = np.vstack([pos, pos[sub, :] + np.array([1, 0, 0])]) ID = np.hstack([ID, ID[sub]]) + if occ is not None: + occ = np.hstack([occ, occ[sub]]) # y tile sub = pos[:, 1] < tol_distance pos = np.vstack([pos, pos[sub, :] + np.array([0, 1, 0])]) ID = np.hstack([ID, ID[sub]]) + if occ is not None: + occ = np.hstack([occ, occ[sub]]) # z tile sub = pos[:, 2] < tol_distance pos = np.vstack([pos, pos[sub, :] + np.array([0, 0, 1])]) ID = np.hstack([ID, ID[sub]]) + if occ is not None: + occ = np.hstack([occ, occ[sub]]) # Cartesian atomic positions xyz = pos @ self.lat_real @@ -141,17 +151,109 @@ def plot_structure( # atoms ID_all = np.unique(ID) - for ID_plot in ID_all: - sub = ID == ID_plot - ax.scatter( - xs=xyz[sub, 1], # + d[0], - ys=xyz[sub, 0], # + d[1], - zs=xyz[sub, 2], # + d[2], - s=size_marker, - linewidth=2, - facecolors=atomic_colors(ID_plot), - edgecolor=[0, 0, 0], - ) + if occ is None: + for ID_plot in ID_all: + sub = ID == ID_plot + ax.scatter( + xs=xyz[sub, 1], # + d[0], + ys=xyz[sub, 0], # + d[1], + zs=xyz[sub, 2], # + d[2], + s=size_marker, + linewidth=2, + facecolors=atomic_colors(ID_plot), + edgecolor=[0, 0, 0], + ) + else: + # init + tol = 1e-4 + num_seg = 180 + radius = 0.7 + zp = np.zeros(num_seg + 1) + + mark = np.ones(xyz.shape[0], dtype="bool") + for a0 in range(xyz.shape[0]): + if mark[a0]: + xyz_plot = xyz[a0, :] + inds = np.argwhere(np.sum((xyz - xyz_plot) ** 2, axis=1) < tol) + occ_plot = occ[inds] + mark[inds] = False + ID_plot = ID[inds] + + if np.sum(occ_plot) < 1.0: + occ_plot = np.append(occ_plot, 1 - np.sum(occ_plot)) + ID_plot = np.append(ID_plot, -1) + else: + occ_plot = occ_plot[0] + ID_plot = ID_plot[0] + + # Plot site as series of filled arcs + theta0 = 0 + for a1 in range(occ_plot.shape[0]): + theta1 = theta0 + occ_plot[a1] * 2.0 * np.pi + theta = np.linspace(theta0, theta1, num_seg + 1) + xp = np.cos(theta) * radius + yp = np.sin(theta) * radius + + # Rotate towards camera + xyz_rot = np.vstack((xp.ravel(), yp.ravel(), zp.ravel())) + if occ_plot[a1] < 1.0: + xyz_rot = np.append( + xyz_rot, np.array((0, 0, 0))[:, None], axis=1 + ) + xyz_rot = orientation_matrix @ xyz_rot + + # add to plot + verts = [ + list( + zip( + xyz_rot[1, :] + xyz_plot[1], + xyz_rot[0, :] + xyz_plot[0], + xyz_rot[2, :] + xyz_plot[2], + ) + ) + ] + # ax.add_collection3d( + # Poly3DCollection( + # verts + # ) + # ) + collection = Poly3DCollection( + verts, + linewidths=2.0, + alpha=1.0, + edgecolors="k", + ) + face_color = [ + 0.5, + 0.5, + 1, + ] # alternative: matplotlib.colors.rgb2hex([0.5, 0.5, 1]) + if ID_plot[a1] == -1: + collection.set_facecolor((1.0, 1.0, 1.0)) + else: + collection.set_facecolor(atomic_colors(ID_plot[a1])) + ax.add_collection3d(collection) + + # update start point + if a1 < occ_plot.size: + theta0 = theta1 + + # for ID_plot in ID_all: + # sub = ID == ID_plot + # ax.scatter( + # xs=xyz[sub, 1], # + d[0], + # ys=xyz[sub, 0], # + d[1], + # zs=xyz[sub, 2], # + d[2], + # s=size_marker, + # linewidth=2, + # facecolors='none', + # edgecolor=[0, 0, 0], + # ) + # poly = PolyCollection( + # verts, + # facecolors=['r', 'g', 'b', 'y'], + # alpha = 0.6) + # ax.add_collection3d(poly, zs=zs, zdir='y') # plot limit if plot_limit is None: