diff --git a/py4DSTEM/__init__.py b/py4DSTEM/__init__.py index dcb6a861d..adf757d1b 100644 --- a/py4DSTEM/__init__.py +++ b/py4DSTEM/__init__.py @@ -53,7 +53,7 @@ ) # strain -from py4DSTEM.process import StrainMap +from py4DSTEM.process.strain.strain import StrainMap # TODO - crystal # TODO - ptycho diff --git a/py4DSTEM/process/__init__.py b/py4DSTEM/process/__init__.py index e711e907d..0509d181e 100644 --- a/py4DSTEM/process/__init__.py +++ b/py4DSTEM/process/__init__.py @@ -1,11 +1,9 @@ from py4DSTEM.process.polar import PolarDatacube -from py4DSTEM.process.strain import StrainMap +from py4DSTEM.process.strain.strain import StrainMap -from py4DSTEM.process import latticevectors from py4DSTEM.process import phase from py4DSTEM.process import calibration from py4DSTEM.process import utils from py4DSTEM.process import classification -from py4DSTEM.process import latticevectors from py4DSTEM.process import diffraction from py4DSTEM.process import wholepatternfit diff --git a/py4DSTEM/process/latticevectors/__init__.py b/py4DSTEM/process/latticevectors/__init__.py deleted file mode 100644 index 560a3b7e6..000000000 --- a/py4DSTEM/process/latticevectors/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from py4DSTEM.process.latticevectors.initialguess import * -from py4DSTEM.process.latticevectors.index import * -from py4DSTEM.process.latticevectors.fit import * -from py4DSTEM.process.latticevectors.strain import * diff --git a/py4DSTEM/process/latticevectors/fit.py b/py4DSTEM/process/latticevectors/fit.py deleted file mode 100644 index 659bc8940..000000000 --- a/py4DSTEM/process/latticevectors/fit.py +++ /dev/null @@ -1,200 +0,0 @@ -# Functions for fitting lattice vectors to measured Bragg peak positions - -import numpy as np -from numpy.linalg import lstsq - -from emdfile import tqdmnd, PointList, PointListArray -from py4DSTEM.data import RealSlice - - -def fit_lattice_vectors(braggpeaks, x0=0, y0=0, minNumPeaks=5): - """ - Fits lattice vectors g1,g2 to braggpeaks given some known (h,k) indexing. - - Args: - braggpeaks (PointList): A 6 coordinate PointList containing the data to fit. - Coords are 'qx','qy' (the bragg peak positions), 'intensity' (used as a - weighting factor when fitting), 'h','k' (indexing). May optionally also - contain 'index_mask' (bool), indicating which peaks have been successfully - indixed and should be used. - x0 (float): x-coord of the origin - y0 (float): y-coord of the origin - minNumPeaks (int): if there are fewer than minNumPeaks peaks found in braggpeaks - which can be indexed, return None for all return parameters - - Returns: - (7-tuple) A 7-tuple containing: - - * **x0**: *(float)* the x-coord of the origin of the best-fit lattice. - * **y0**: *(float)* the y-coord of the origin - * **g1x**: *(float)* x-coord of the first lattice vector - * **g1y**: *(float)* y-coord of the first lattice vector - * **g2x**: *(float)* x-coord of the second lattice vector - * **g2y**: *(float)* y-coord of the second lattice vector - * **error**: *(float)* the fit error - """ - assert isinstance(braggpeaks, PointList) - assert np.all( - [name in braggpeaks.dtype.names for name in ("qx", "qy", "intensity", "h", "k")] - ) - braggpeaks = braggpeaks.copy() - - # Remove unindexed peaks - if "index_mask" in braggpeaks.dtype.names: - deletemask = braggpeaks.data["index_mask"] == False - braggpeaks.remove(deletemask) - - # Check to ensure enough peaks are present - if braggpeaks.length < minNumPeaks: - return None, None, None, None, None, None, None - - # Get M, the matrix of (h,k) indices - h, k = braggpeaks.data["h"], braggpeaks.data["k"] - M = np.vstack((np.ones_like(h, dtype=int), h, k)).T - - # Get alpha, the matrix of measured Bragg peak positions - alpha = np.vstack((braggpeaks.data["qx"] - x0, braggpeaks.data["qy"] - y0)).T - - # Get weighted matrices - weights = braggpeaks.data["intensity"] - weighted_M = M * weights[:, np.newaxis] - weighted_alpha = alpha * weights[:, np.newaxis] - - # Solve for lattice vectors - beta = lstsq(weighted_M, weighted_alpha, rcond=None)[0] - x0, y0 = beta[0, 0], beta[0, 1] - g1x, g1y = beta[1, 0], beta[1, 1] - g2x, g2y = beta[2, 0], beta[2, 1] - - # Calculate the error - alpha_calculated = np.matmul(M, beta) - error = np.sqrt(np.sum((alpha - alpha_calculated) ** 2, axis=1)) - error = np.sum(error * weights) / np.sum(weights) - - return x0, y0, g1x, g1y, g2x, g2y, error - - -def fit_lattice_vectors_all_DPs(braggpeaks, x0=0, y0=0, minNumPeaks=5): - """ - Fits lattice vectors g1,g2 to each diffraction pattern in braggpeaks, given some - known (h,k) indexing. - - Args: - braggpeaks (PointList): A 6 coordinate PointList containing the data to fit. - Coords are 'qx','qy' (the bragg peak positions), 'intensity' (used as a - weighting factor when fitting), 'h','k' (indexing). May optionally also - contain 'index_mask' (bool), indicating which peaks have been successfully - indixed and should be used. - x0 (float): x-coord of the origin - y0 (float): y-coord of the origin - minNumPeaks (int): if there are fewer than minNumPeaks peaks found in braggpeaks - which can be indexed, return None for all return parameters - - Returns: - (RealSlice): A RealSlice ``g1g2map`` containing the following 8 arrays: - - * ``g1g2_map.get_slice('x0')`` x-coord of the origin of the best fit lattice - * ``g1g2_map.get_slice('y0')`` y-coord of the origin - * ``g1g2_map.get_slice('g1x')`` x-coord of the first lattice vector - * ``g1g2_map.get_slice('g1y')`` y-coord of the first lattice vector - * ``g1g2_map.get_slice('g2x')`` x-coord of the second lattice vector - * ``g1g2_map.get_slice('g2y')`` y-coord of the second lattice vector - * ``g1g2_map.get_slice('error')`` the fit error - * ``g1g2_map.get_slice('mask')`` 1 for successful fits, 0 for unsuccessful - fits - """ - assert isinstance(braggpeaks, PointListArray) - assert np.all( - [name in braggpeaks.dtype.names for name in ("qx", "qy", "intensity", "h", "k")] - ) - - # Make RealSlice to contain outputs - slicelabels = ("x0", "y0", "g1x", "g1y", "g2x", "g2y", "error", "mask") - g1g2_map = RealSlice( - data=np.zeros((8, braggpeaks.shape[0], braggpeaks.shape[1])), - slicelabels=slicelabels, - name="g1g2_map", - ) - - # Fit lattice vectors - for Rx, Ry in tqdmnd(braggpeaks.shape[0], braggpeaks.shape[1]): - braggpeaks_curr = braggpeaks.get_pointlist(Rx, Ry) - qx0, qy0, g1x, g1y, g2x, g2y, error = fit_lattice_vectors( - braggpeaks_curr, x0, y0, minNumPeaks - ) - # Store data - if g1x is not None: - g1g2_map.get_slice("x0").data[Rx, Ry] = qx0 - g1g2_map.get_slice("y0").data[Rx, Ry] = qx0 - g1g2_map.get_slice("g1x").data[Rx, Ry] = g1x - g1g2_map.get_slice("g1y").data[Rx, Ry] = g1y - g1g2_map.get_slice("g2x").data[Rx, Ry] = g2x - g1g2_map.get_slice("g2y").data[Rx, Ry] = g2y - g1g2_map.get_slice("error").data[Rx, Ry] = error - g1g2_map.get_slice("mask").data[Rx, Ry] = 1 - - return g1g2_map - - -def fit_lattice_vectors_masked(braggpeaks, mask, x0=0, y0=0, minNumPeaks=5): - """ - Fits lattice vectors g1,g2 to each diffraction pattern in braggpeaks corresponding - to a scan position for which mask==True. - - Args: - braggpeaks (PointList): A 6 coordinate PointList containing the data to fit. - Coords are 'qx','qy' (the bragg peak positions), 'intensity' (used as a - weighting factor when fitting), 'h','k' (indexing). May optionally also - contain 'index_mask' (bool), indicating which peaks have been successfully - indixed and should be used. - mask (boolean array): real space shaped (R_Nx,R_Ny); fit lattice vectors where - mask is True - x0 (float): x-coord of the origin - y0 (float): y-coord of the origin - minNumPeaks (int): if there are fewer than minNumPeaks peaks found in braggpeaks - which can be indexed, return None for all return parameters - - Returns: - (RealSlice): A RealSlice ``g1g2map`` containing the following 8 arrays: - - * ``g1g2_map.get_slice('x0')`` x-coord of the origin of the best fit lattice - * ``g1g2_map.get_slice('y0')`` y-coord of the origin - * ``g1g2_map.get_slice('g1x')`` x-coord of the first lattice vector - * ``g1g2_map.get_slice('g1y')`` y-coord of the first lattice vector - * ``g1g2_map.get_slice('g2x')`` x-coord of the second lattice vector - * ``g1g2_map.get_slice('g2y')`` y-coord of the second lattice vector - * ``g1g2_map.get_slice('error')`` the fit error - * ``g1g2_map.get_slice('mask')`` 1 for successful fits, 0 for unsuccessful - fits - """ - assert isinstance(braggpeaks, PointListArray) - assert np.all( - [name in braggpeaks.dtype.names for name in ("qx", "qy", "intensity")] - ) - - # Make RealSlice to contain outputs - slicelabels = ("x0", "y0", "g1x", "g1y", "g2x", "g2y", "error", "mask") - g1g2_map = RealSlice( - data=np.zeros((braggpeaks.shape[0], braggpeaks.shape[1], 8)), - slicelabels=slicelabels, - name="g1g2_map", - ) - - # Fit lattice vectors - for Rx, Ry in tqdmnd(braggpeaks.shape[0], braggpeaks.shape[1]): - if mask[Rx, Ry]: - braggpeaks_curr = braggpeaks.get_pointlist(Rx, Ry) - qx0, qy0, g1x, g1y, g2x, g2y, error = fit_lattice_vectors( - braggpeaks_curr, x0, y0, minNumPeaks - ) - # Store data - if g1x is not None: - g1g2_map.get_slice("x0").data[Rx, Ry] = qx0 - g1g2_map.get_slice("y0").data[Rx, Ry] = qx0 - g1g2_map.get_slice("g1x").data[Rx, Ry] = g1x - g1g2_map.get_slice("g1y").data[Rx, Ry] = g1y - g1g2_map.get_slice("g2x").data[Rx, Ry] = g2x - g1g2_map.get_slice("g2y").data[Rx, Ry] = g2y - g1g2_map.get_slice("error").data[Rx, Ry] = error - g1g2_map.get_slice("mask").data[Rx, Ry] = 1 - return g1g2_map diff --git a/py4DSTEM/process/latticevectors/index.py b/py4DSTEM/process/latticevectors/index.py deleted file mode 100644 index 4ac7939e7..000000000 --- a/py4DSTEM/process/latticevectors/index.py +++ /dev/null @@ -1,280 +0,0 @@ -# Functions for indexing the Bragg directions - -import numpy as np -from numpy.linalg import lstsq - -from emdfile import tqdmnd, PointList, PointListArray - - -def get_selected_lattice_vectors(gx, gy, i0, i1, i2): - """ - From a set of reciprocal lattice points (gx,gy), and indices in those arrays which - specify the center beam, the first basis lattice vector, and the second basis lattice - vector, computes and returns the lattice vectors g1 and g2. - - Args: - gx (1d array): the reciprocal lattice points x-coords - gy (1d array): the reciprocal lattice points y-coords - i0 (int): index in the (gx,gy) arrays specifying the center beam - i1 (int): index in the (gx,gy) arrays specifying the first basis lattice vector - i2 (int): index in the (gx,gy) arrays specifying the second basis lattice vector - - Returns: - (2-tuple of 2-tuples) A 2-tuple containing - - * **g1**: *(2-tuple)* the first lattice vector, (g1x,g1y) - * **g2**: *(2-tuple)* the second lattice vector, (g2x,g2y) - """ - for i in (i0, i1, i2): - assert isinstance(i, (int, np.integer)) - g1x = gx[i1] - gx[i0] - g1y = gy[i1] - gy[i0] - g2x = gx[i2] - gx[i0] - g2y = gy[i2] - gy[i0] - return (g1x, g1y), (g2x, g2y) - - -def index_bragg_directions(x0, y0, gx, gy, g1, g2): - """ - From an origin (x0,y0), a set of reciprocal lattice vectors gx,gy, and an pair of - lattice vectors g1=(g1x,g1y), g2=(g2x,g2y), find the indices (h,k) of all the - reciprocal lattice directions. - - The approach is to solve the matrix equation - ``alpha = beta * M`` - where alpha is the 2xN array of the (x,y) coordinates of N measured bragg directions, - beta is the 2x2 array of the two lattice vectors u,v, and M is the 2xN array of the - h,k indices. - - Args: - x0 (float): x-coord of origin - y0 (float): y-coord of origin - gx (1d array): x-coord of the reciprocal lattice vectors - gy (1d array): y-coord of the reciprocal lattice vectors - g1 (2-tuple of floats): g1x,g1y - g2 (2-tuple of floats): g2x,g2y - - Returns: - (3-tuple) A 3-tuple containing: - - * **h**: *(ndarray of ints)* first index of the bragg directions - * **k**: *(ndarray of ints)* second index of the bragg directions - * **bragg_directions**: *(PointList)* a 4-coordinate PointList with the - indexed bragg directions; coords 'qx' and 'qy' contain bragg_x and bragg_y - coords 'h' and 'k' contain h and k. - """ - # Get beta, the matrix of lattice vectors - beta = np.array([[g1[0], g2[0]], [g1[1], g2[1]]]) - - # Get alpha, the matrix of measured bragg angles - alpha = np.vstack([gx - x0, gy - y0]) - - # Calculate M, the matrix of peak positions - M = lstsq(beta, alpha, rcond=None)[0].T - M = np.round(M).astype(int) - - # Get h,k - h = M[:, 0] - k = M[:, 1] - - # Store in a PointList - coords = [("qx", float), ("qy", float), ("h", int), ("k", int)] - temp_array = np.zeros([], dtype=coords) - bragg_directions = PointList(data=temp_array) - bragg_directions.add_data_by_field((gx, gy, h, k)) - mask = np.zeros(bragg_directions["qx"].shape[0]) - mask[0] = 1 - bragg_directions.remove(mask) - - return h, k, bragg_directions - - -def generate_lattice(ux, uy, vx, vy, x0, y0, Q_Nx, Q_Ny, h_max=None, k_max=None): - """ - Returns a full reciprocal lattice stretching to the limits of the diffraction pattern - by making linear combinations of the lattice vectors up to (±h_max,±k_max). - - This can be useful when there are false peaks or missing peaks in the braggvectormap, - which can cause errors in the strain finding routines that rely on those peaks for - indexing. This allows us to create a reference lattice that has all combinations of - the lattice vectors all the way out to the edges of the frame, and excluding any - erroneous intermediate peaks. - - Args: - ux (float): x-coord of the u lattice vector - uy (float): y-coord of the u lattice vector - vx (float): x-coord of the v lattice vector - vy (float): y-coord of the v lattice vector - x0 (float): x-coord of the lattice origin - y0 (float): y-coord of the lattice origin - Q_Nx (int): diffraction pattern size in the x-direction - Q_Ny (int): diffraction pattern size in the y-direction - h_max, k_max (int): maximal indices for generating the lattice (the lattive is - always trimmed to fit inside the pattern so you can overestimate these, or - leave unspecified and they will be automatically found) - - Returns: - (PointList): A 4-coordinate PointList, ('qx','qy','h','k'), containing points - corresponding to linear combinations of the u and v vectors, with associated - indices - """ - - # Matrix of lattice vectors - beta = np.array([[ux, uy], [vx, vy]]) - - # If no max index is specified, (over)estimate based on image size - if (h_max is None) or (k_max is None): - (y, x) = np.mgrid[0:Q_Ny, 0:Q_Nx] - x = x - x0 - y = y - y0 - h_max = np.max(np.ceil(np.abs((x / ux, y / uy)))) - k_max = np.max(np.ceil(np.abs((x / vx, y / vy)))) - - (hlist, klist) = np.meshgrid( - np.arange(-h_max, h_max + 1), np.arange(-k_max, k_max + 1) - ) - - M_ideal = np.vstack((hlist.ravel(), klist.ravel())).T - ideal_peaks = np.matmul(M_ideal, beta) - - coords = [("qx", float), ("qy", float), ("h", int), ("k", int)] - - ideal_data = np.zeros(len(ideal_peaks[:, 0]), dtype=coords) - ideal_data["qx"] = ideal_peaks[:, 0] - ideal_data["qy"] = ideal_peaks[:, 1] - ideal_data["h"] = M_ideal[:, 0] - ideal_data["k"] = M_ideal[:, 1] - - ideal_lattice = PointList(data=ideal_data) - - # shift to the DP center - ideal_lattice.data["qx"] += x0 - ideal_lattice.data["qy"] += y0 - - # trim peaks outside the image - deletePeaks = ( - (ideal_lattice.data["qx"] > Q_Nx) - | (ideal_lattice.data["qx"] < 0) - | (ideal_lattice.data["qy"] > Q_Ny) - | (ideal_lattice.data["qy"] < 0) - ) - ideal_lattice.remove(deletePeaks) - - return ideal_lattice - - -def add_indices_to_braggvectors( - braggpeaks, lattice, maxPeakSpacing, qx_shift=0, qy_shift=0, mask=None -): - """ - Using the peak positions (qx,qy) and indices (h,k) in the PointList lattice, - identify the indices for each peak in the PointListArray braggpeaks. - Return a new braggpeaks_indexed PointListArray, containing a copy of braggpeaks plus - three additional data columns -- 'h','k', and 'index_mask' -- specifying the peak - indices with the ints (h,k) and indicating whether the peak was successfully indexed - or not with the bool index_mask. If `mask` is specified, only the locations where - mask is True are indexed. - - Args: - braggpeaks (PointListArray): the braggpeaks to index. Must contain - the coordinates 'qx', 'qy', and 'intensity' - lattice (PointList): the positions (qx,qy) of the (h,k) lattice points. - Must contain the coordinates 'qx', 'qy', 'h', and 'k' - maxPeakSpacing (float): Maximum distance from the ideal lattice points - to include a peak for indexing - qx_shift,qy_shift (number): the shift of the origin in the `lattice` PointList - relative to the `braggpeaks` PointListArray - mask (bool): Boolean mask, same shape as the pointlistarray, indicating which - locations should be indexed. This can be used to index different regions of - the scan with different lattices - - Returns: - (PointListArray): The original braggpeaks pointlistarray, with new coordinates - 'h', 'k', containing the indices of each indexable peak. - """ - - # assert isinstance(braggpeaks,BraggVectors) - # assert isinstance(lattice, PointList) - # assert np.all([name in lattice.dtype.names for name in ('qx','qy','h','k')]) - - if mask is None: - mask = np.ones(braggpeaks.Rshape, dtype=bool) - - assert ( - mask.shape == braggpeaks.Rshape - ), "mask must have same shape as pointlistarray" - assert mask.dtype == bool, "mask must be boolean" - - coords = [ - ("qx", float), - ("qy", float), - ("intensity", float), - ("h", int), - ("k", int), - ] - - indexed_braggpeaks = PointListArray( - dtype=coords, - shape=braggpeaks.Rshape, - ) - - # loop over all the scan positions - for Rx, Ry in tqdmnd(mask.shape[0], mask.shape[1]): - if mask[Rx, Ry]: - pl = braggpeaks.cal[Rx, Ry] - for i in range(pl.data.shape[0]): - r2 = (pl.data["qx"][i] - lattice.data["qx"] + qx_shift) ** 2 + ( - pl.data["qy"][i] - lattice.data["qy"] + qy_shift - ) ** 2 - ind = np.argmin(r2) - if r2[ind] <= maxPeakSpacing**2: - indexed_braggpeaks[Rx, Ry].add_data_by_field( - ( - pl.data["qx"][i], - pl.data["qy"][i], - pl.data["intensity"][i], - lattice.data["h"][ind], - lattice.data["k"][ind], - ) - ) - - return indexed_braggpeaks - - -def bragg_vector_intensity_map_by_index(braggpeaks, h, k, symmetric=False): - """ - Returns a correlation intensity map for an indexed (h,k) Bragg vector - Used to obtain a darkfield image corresponding to the (h,k) reflection - or a bightfield image when h=k=0 - - Args: - braggpeaks (PointListArray): must contain the coordinates 'h','k', and - 'intensity' - h, k (int): indices for the reflection to generate an intensity map from - symmetric (bool): if set to true, returns sum of intensity of (h,k), (-h,k), - (h,-k), (-h,-k) - - Returns: - (numpy array): a map of the intensity of the (h,k) Bragg vector correlation. - Same shape as the pointlistarray. - """ - assert isinstance(braggpeaks, PointListArray), "braggpeaks must be a PointListArray" - assert np.all([name in braggpeaks.dtype.names for name in ("h", "k", "intensity")]) - intensity_map = np.zeros(braggpeaks.shape, dtype=float) - - for Rx in range(braggpeaks.shape[0]): - for Ry in range(braggpeaks.shape[1]): - pl = braggpeaks.get_pointlist(Rx, Ry) - if pl.length > 0: - if symmetric: - matches = np.logical_and( - np.abs(pl.data["h"]) == np.abs(h), - np.abs(pl.data["k"]) == np.abs(k), - ) - else: - matches = np.logical_and(pl.data["h"] == h, pl.data["k"] == k) - - if len(matches) > 0: - intensity_map[Rx, Ry] = np.sum(pl.data["intensity"][matches]) - - return intensity_map diff --git a/py4DSTEM/process/latticevectors/initialguess.py b/py4DSTEM/process/latticevectors/initialguess.py deleted file mode 100644 index d8054143f..000000000 --- a/py4DSTEM/process/latticevectors/initialguess.py +++ /dev/null @@ -1,229 +0,0 @@ -# Obtain an initial guess at the lattice vectors - -import numpy as np -from scipy.ndimage import gaussian_filter -from skimage.transform import radon - -from py4DSTEM.process.utils import get_maxima_1D - - -def get_radon_scores( - braggvectormap, - mask=None, - N_angles=200, - sigma=2, - minSpacing=2, - minRelativeIntensity=0.05, -): - """ - Calculates a score function, score(angle), representing the likelihood that angle is - a principle lattice direction of the lattice in braggvectormap. - - The procedure is as follows: - If mask is not None, ignore any data in braggvectormap where mask is False. Useful - for removing the unscattered beam, which can dominate the results. - Take the Radon transform of the (masked) Bragg vector map. - For each angle, get the corresponding slice of the sinogram, and calculate its score. - If we let R_theta(r) be the sinogram slice at angle theta, and where r is the - sinogram position coordinate, then the score of the slice is given by - score(theta) = sum_i(R_theta(r_i)) / N_i - Here, r_i are the positions r of all local maxima in R_theta(r), and N_i is the - number of such maxima. Thus the score is large when there are few maxima which are - high intensity. - - Args: - braggvectormap (ndarray): the Bragg vector map - mask (ndarray of bools): ignore data in braggvectormap wherever mask==False - N_angles (int): the number of angles at which to calculate the score - sigma (float): smoothing parameter for local maximum identification - minSpacing (float): if two maxima are found in a radon slice closer than - minSpacing, the dimmer of the two is removed - minRelativeIntensity (float): maxima in each radon slice dimmer than - minRelativeIntensity compared to the most intense maximum are removed - - Returns: - (3-tuple) A 3-tuple containing: - - * **scores**: *(ndarray, len N_angles, floats)* the scores for each angle - * **thetas**: *(ndarray, len N_angles, floats)* the angles, in radians - * **sinogram**: *(ndarray)* the radon transform of braggvectormap*mask - """ - # Get sinogram - thetas = np.linspace(0, 180, N_angles) - if mask is not None: - sinogram = radon(braggvectormap * mask, theta=thetas, circle=False) - else: - sinogram = radon(braggvectormap, theta=thetas, circle=False) - - # Get scores - N_maxima = np.empty_like(thetas) - total_intensity = np.empty_like(thetas) - for i in range(len(thetas)): - theta = thetas[i] - - # Get radon transform slice - ind = np.argmin(np.abs(thetas - theta)) - sinogram_theta = sinogram[:, ind] - sinogram_theta = gaussian_filter(sinogram_theta, 2) - - # Get maxima - maxima = get_maxima_1D(sinogram_theta, sigma, minSpacing, minRelativeIntensity) - - # Calculate metrics - N_maxima[i] = len(maxima) - total_intensity[i] = np.sum(sinogram_theta[maxima]) - scores = total_intensity / N_maxima - - return scores, np.radians(thetas), sinogram - - -def get_lattice_directions_from_scores( - thetas, scores, sigma=2, minSpacing=2, minRelativeIntensity=0.05, index1=0, index2=0 -): - """ - Get the lattice directions from the scores of the radon transform slices. - - Args: - thetas (ndarray): the angles, in radians - scores (ndarray): the scores - sigma (float): gaussian blur for local maxima identification - minSpacing (float): minimum spacing for local maxima identification - minRelativeIntensity (float): minumum intensity, relative to the brightest - maximum, for local maxima identification - index1 (int): specifies which local maximum to use for the first lattice - direction, in order of maximum intensity - index2 (int): specifies the local maximum for the second lattice direction - - Returns: - (2-tuple) A 2-tuple containing: - - * **theta1**: *(float)* the first lattice direction, in radians - * **theta2**: *(float)* the second lattice direction, in radians - """ - assert len(thetas) == len(scores), "Size of thetas and scores must match" - - # Get first lattice direction - maxima1 = get_maxima_1D( - scores, sigma, minSpacing, minRelativeIntensity - ) # Get maxima - thetas_max1 = thetas[maxima1] - scores_max1 = scores[maxima1] - dtype = np.dtype( - [("thetas", thetas.dtype), ("scores", scores.dtype)] - ) # Sort by intensity - ar_structured = np.empty(len(thetas_max1), dtype=dtype) - ar_structured["thetas"] = thetas_max1 - ar_structured["scores"] = scores_max1 - ar_structured = np.sort(ar_structured, order="scores")[::-1] - theta1 = ar_structured["thetas"][index1] # Get direction 1 - - # Apply sin**2 damping - scores_damped = scores * np.sin(thetas - theta1) ** 2 - - # Get second lattice direction - maxima2 = get_maxima_1D( - scores_damped, sigma, minSpacing, minRelativeIntensity - ) # Get maxima - thetas_max2 = thetas[maxima2] - scores_max2 = scores[maxima2] - dtype = np.dtype( - [("thetas", thetas.dtype), ("scores", scores.dtype)] - ) # Sort by intensity - ar_structured = np.empty(len(thetas_max2), dtype=dtype) - ar_structured["thetas"] = thetas_max2 - ar_structured["scores"] = scores_max2 - ar_structured = np.sort(ar_structured, order="scores")[::-1] - theta2 = ar_structured["thetas"][index2] # Get direction 2 - - return theta1, theta2 - - -def get_lattice_vector_lengths( - u_theta, - v_theta, - thetas, - sinogram, - spacing_thresh=1.5, - sigma=1, - minSpacing=2, - minRelativeIntensity=0.1, -): - """ - Gets the lengths of the two lattice vectors from their angles and the sinogram. - - First, finds the spacing between peaks in the sinogram slices projected down the u- - and v- directions, u_proj and v_proj. Then, finds the lengths by taking:: - - |u| = v_proj/sin(u_theta-v_theta) - |v| = u_proj/sin(u_theta-v_theta) - - The most important thresholds for this function are spacing_thresh, which discards - any detected spacing between adjacent radon projection peaks which deviate from the - median spacing by more than this fraction, and minRelativeIntensity, which discards - detected maxima (from which spacings are then calculated) below this threshold - relative to the brightest maximum. - - Args: - u_theta (float): the angle of u, in radians - v_theta (float): the angle of v, in radians - thetas (ndarray): the angles corresponding to the sinogram - sinogram (ndarray): the sinogram - spacing_thresh (float): ignores spacings which are greater than spacing_thresh - times the median spacing - sigma (float): gaussian blur for local maxima identification - minSpacing (float): minimum spacing for local maxima identification - minRelativeIntensity (float): minumum intensity, relative to the brightest - maximum, for local maxima identification - - Returns: - (2-tuple) A 2-tuple containing: - - * **u_length**: *(float)* the length of u, in pixels - * **v_length**: *(float)* the length of v, in pixels - """ - assert ( - len(thetas) == sinogram.shape[1] - ), "thetas must corresponding to the number of sinogram projection directions." - - # Get u projected spacing - ind = np.argmin(np.abs(thetas - u_theta)) - sinogram_slice = sinogram[:, ind] - maxima = get_maxima_1D(sinogram_slice, sigma, minSpacing, minRelativeIntensity) - spacings = np.sort(np.arange(sinogram_slice.shape[0])[maxima]) - spacings = spacings[1:] - spacings[:-1] - mask = ( - np.array( - [ - max(i, np.median(spacings)) / min(i, np.median(spacings)) - for i in spacings - ] - ) - < spacing_thresh - ) - spacings = spacings[mask] - u_projected_spacing = np.mean(spacings) - - # Get v projected spacing - ind = np.argmin(np.abs(thetas - v_theta)) - sinogram_slice = sinogram[:, ind] - maxima = get_maxima_1D(sinogram_slice, sigma, minSpacing, minRelativeIntensity) - spacings = np.sort(np.arange(sinogram_slice.shape[0])[maxima]) - spacings = spacings[1:] - spacings[:-1] - mask = ( - np.array( - [ - max(i, np.median(spacings)) / min(i, np.median(spacings)) - for i in spacings - ] - ) - < spacing_thresh - ) - spacings = spacings[mask] - v_projected_spacing = np.mean(spacings) - - # Get u and v lengths - sin_uv = np.sin(np.abs(u_theta - v_theta)) - u_length = v_projected_spacing / sin_uv - v_length = u_projected_spacing / sin_uv - - return u_length, v_length diff --git a/py4DSTEM/process/latticevectors/strain.py b/py4DSTEM/process/latticevectors/strain.py deleted file mode 100644 index 6f4000449..000000000 --- a/py4DSTEM/process/latticevectors/strain.py +++ /dev/null @@ -1,231 +0,0 @@ -# Functions for calculating strain from lattice vector maps - -import numpy as np -from numpy.linalg import lstsq - -from py4DSTEM.data import RealSlice - - -def get_reference_g1g2(g1g2_map, mask): - """ - Gets a pair of reference lattice vectors from a region of real space specified by - mask. Takes the median of the lattice vectors in g1g2_map within the specified - region. - - Args: - g1g2_map (RealSlice): the lattice vector map; contains 2D arrays in g1g2_map.data - under the keys 'g1x', 'g1y', 'g2x', and 'g2y'. See documentation for - fit_lattice_vectors_all_DPs() for more information. - mask (ndarray of bools): use lattice vectors from g1g2_map scan positions wherever - mask==True - - Returns: - (2-tuple of 2-tuples) A 2-tuple containing: - - * **g1**: *(2-tuple)* first reference lattice vector (x,y) - * **g2**: *(2-tuple)* second reference lattice vector (x,y) - """ - assert isinstance(g1g2_map, RealSlice) - assert np.all( - [name in g1g2_map.slicelabels for name in ("g1x", "g1y", "g2x", "g2y")] - ) - assert mask.dtype == bool - g1x = np.median(g1g2_map.get_slice("g1x").data[mask]) - g1y = np.median(g1g2_map.get_slice("g1y").data[mask]) - g2x = np.median(g1g2_map.get_slice("g2x").data[mask]) - g2y = np.median(g1g2_map.get_slice("g2y").data[mask]) - return (g1x, g1y), (g2x, g2y) - - -def get_strain_from_reference_g1g2(g1g2_map, g1, g2): - """ - Gets a strain map from the reference lattice vectors g1,g2 and lattice vector map - g1g2_map. - - Note that this function will return the strain map oriented with respect to the x/y - axes of diffraction space - to rotate the coordinate system, use - get_rotated_strain_map(). Calibration of the rotational misalignment between real and - diffraction space may also be necessary. - - Args: - g1g2_map (RealSlice): the lattice vector map; contains 2D arrays in g1g2_map.data - under the keys 'g1x', 'g1y', 'g2x', and 'g2y'. See documentation for - fit_lattice_vectors_all_DPs() for more information. - g1 (2-tuple): first reference lattice vector (x,y) - g2 (2-tuple): second reference lattice vector (x,y) - - Returns: - (RealSlice) the strain map; contains the elements of the infinitessimal strain - matrix, in the following 5 arrays: - - * ``strain_map.get_slice('e_xx')``: change in lattice x-components with respect - to x - * ``strain_map.get_slice('e_yy')``: change in lattice y-components with respect - to y - * ``strain_map.get_slice('e_xy')``: change in lattice x-components with respect - to y - * ``strain_map.get_slice('theta')``: rotation of lattice with respect to - reference - * ``strain_map.get_slice('mask')``: 0/False indicates unknown values - - Note 1: the strain matrix has been symmetrized, so e_xy and e_yx are identical - """ - assert isinstance(g1g2_map, RealSlice) - assert np.all( - [name in g1g2_map.slicelabels for name in ("g1x", "g1y", "g2x", "g2y", "mask")] - ) - - # Get RealSlice for output storage - R_Nx, R_Ny = g1g2_map.get_slice("g1x").shape - strain_map = RealSlice( - data=np.zeros((5, R_Nx, R_Ny)), - slicelabels=("e_xx", "e_yy", "e_xy", "theta", "mask"), - name="strain_map", - ) - - # Get reference lattice matrix - g1x, g1y = g1 - g2x, g2y = g2 - M = np.array([[g1x, g1y], [g2x, g2y]]) - - for Rx in range(R_Nx): - for Ry in range(R_Ny): - # Get lattice vectors for DP at Rx,Ry - alpha = np.array( - [ - [ - g1g2_map.get_slice("g1x").data[Rx, Ry], - g1g2_map.get_slice("g1y").data[Rx, Ry], - ], - [ - g1g2_map.get_slice("g2x").data[Rx, Ry], - g1g2_map.get_slice("g2y").data[Rx, Ry], - ], - ] - ) - # Get transformation matrix - beta = lstsq(M, alpha, rcond=None)[0].T - - # Get the infinitesimal strain matrix - strain_map.get_slice("e_xx").data[Rx, Ry] = 1 - beta[0, 0] - strain_map.get_slice("e_yy").data[Rx, Ry] = 1 - beta[1, 1] - strain_map.get_slice("e_xy").data[Rx, Ry] = -(beta[0, 1] + beta[1, 0]) / 2.0 - strain_map.get_slice("theta").data[Rx, Ry] = (beta[0, 1] - beta[1, 0]) / 2.0 - strain_map.get_slice("mask").data[Rx, Ry] = g1g2_map.get_slice("mask").data[ - Rx, Ry - ] - return strain_map - - -def get_strain_from_reference_region(g1g2_map, mask): - """ - Gets a strain map from the reference region of real space specified by mask and the - lattice vector map g1g2_map. - - Note that this function will return the strain map oriented with respect to the x/y - axes of diffraction space - to rotate the coordinate system, use - get_rotated_strain_map(). Calibration of the rotational misalignment between real - and diffraction space may also be necessary. - - Args: - g1g2_map (RealSlice): the lattice vector map; contains 2D arrays in g1g2_map.data - under the keys 'g1x', 'g1y', 'g2x', and 'g2y'. See documentation for - fit_lattice_vectors_all_DPs() for more information. - mask (ndarray of bools): use lattice vectors from g1g2_map scan positions - wherever mask==True - - Returns: - (RealSlice) the strain map; contains the elements of the infinitessimal strain - matrix, in the following 5 arrays: - - * ``strain_map.get_slice('e_xx')``: change in lattice x-components with respect - to x - * ``strain_map.get_slice('e_yy')``: change in lattice y-components with respect - to y - * ``strain_map.get_slice('e_xy')``: change in lattice x-components with respect - to y - * ``strain_map.get_slice('theta')``: rotation of lattice with respect to - reference - * ``strain_map.get_slice('mask')``: 0/False indicates unknown values - - Note 1: the strain matrix has been symmetrized, so e_xy and e_yx are identical - """ - assert isinstance(g1g2_map, RealSlice) - assert np.all( - [name in g1g2_map.slicelabels for name in ("g1x", "g1y", "g2x", "g2y", "mask")] - ) - assert mask.dtype == bool - - g1, g2 = get_reference_g1g2(g1g2_map, mask) - strain_map = get_strain_from_reference_g1g2(g1g2_map, g1, g2) - return strain_map - - -def get_rotated_strain_map(unrotated_strain_map, xaxis_x, xaxis_y, flip_theta): - """ - Starting from a strain map defined with respect to the xy coordinate system of - diffraction space, i.e. where exx and eyy are the compression/tension along the Qx - and Qy directions, respectively, get a strain map defined with respect to some other - right-handed coordinate system, in which the x-axis is oriented along (xaxis_x, - xaxis_y). - - Args: - xaxis_x,xaxis_y (float): diffraction space (x,y) coordinates of a vector - along the new x-axis - unrotated_strain_map (RealSlice): a RealSlice object containing 2D arrays of the - infinitessimal strain matrix elements, stored at - * unrotated_strain_map.get_slice('e_xx') - * unrotated_strain_map.get_slice('e_xy') - * unrotated_strain_map.get_slice('e_yy') - * unrotated_strain_map.get_slice('theta') - - Returns: - (RealSlice) the rotated counterpart to unrotated_strain_map, with the - rotated_strain_map.get_slice('e_xx') element oriented along the new coordinate - system - """ - assert isinstance(unrotated_strain_map, RealSlice) - assert np.all( - [ - key in ["e_xx", "e_xy", "e_yy", "theta", "mask"] - for key in unrotated_strain_map.slicelabels - ] - ) - theta = -np.arctan2(xaxis_y, xaxis_x) - cost = np.cos(theta) - sint = np.sin(theta) - cost2 = cost**2 - sint2 = sint**2 - - Rx, Ry = unrotated_strain_map.get_slice("e_xx").data.shape - rotated_strain_map = RealSlice( - data=np.zeros((5, Rx, Ry)), - slicelabels=["e_xx", "e_xy", "e_yy", "theta", "mask"], - name=unrotated_strain_map.name + "_rotated".format(np.degrees(theta)), - ) - - rotated_strain_map.data[0, :, :] = ( - cost2 * unrotated_strain_map.get_slice("e_xx").data - - 2 * cost * sint * unrotated_strain_map.get_slice("e_xy").data - + sint2 * unrotated_strain_map.get_slice("e_yy").data - ) - rotated_strain_map.data[1, :, :] = ( - cost - * sint - * ( - unrotated_strain_map.get_slice("e_xx").data - - unrotated_strain_map.get_slice("e_yy").data - ) - + (cost2 - sint2) * unrotated_strain_map.get_slice("e_xy").data - ) - rotated_strain_map.data[2, :, :] = ( - sint2 * unrotated_strain_map.get_slice("e_xx").data - + 2 * cost * sint * unrotated_strain_map.get_slice("e_xy").data - + cost2 * unrotated_strain_map.get_slice("e_yy").data - ) - if flip_theta == True: - rotated_strain_map.data[3, :, :] = -unrotated_strain_map.get_slice("theta").data - else: - rotated_strain_map.data[3, :, :] = unrotated_strain_map.get_slice("theta").data - rotated_strain_map.data[4, :, :] = unrotated_strain_map.get_slice("mask").data - return rotated_strain_map diff --git a/py4DSTEM/process/strain.py b/py4DSTEM/process/strain.py deleted file mode 100644 index db252f75b..000000000 --- a/py4DSTEM/process/strain.py +++ /dev/null @@ -1,601 +0,0 @@ -# Defines the Strain class - -from typing import Optional - -import matplotlib.pyplot as plt -import numpy as np -from py4DSTEM import PointList -from py4DSTEM.braggvectors import BraggVectors -from py4DSTEM.data import Data, RealSlice -from py4DSTEM.preprocess.utils import get_maxima_2D -from py4DSTEM.visualize import add_bragg_index_labels, add_pointlabels, add_vector, show - - -class StrainMap(RealSlice, Data): - """ - Stores strain map. - - TODO add docs - - """ - - def __init__(self, braggvectors: BraggVectors, name: Optional[str] = "strainmap"): - """ - TODO - """ - assert isinstance( - braggvectors, BraggVectors - ), f"braggvectors must be BraggVectors, not type {type(braggvectors)}" - - # initialize as a RealSlice - RealSlice.__init__( - self, - name=name, - data=np.empty( - ( - 6, - braggvectors.Rshape[0], - braggvectors.Rshape[1], - ) - ), - slicelabels=["exx", "eyy", "exy", "theta", "mask", "error"], - ) - - # set up braggvectors - # this assigns the bvs, ensures the origin is calibrated, - # and adds the strainmap to the bvs' tree - self.braggvectors = braggvectors - - # initialize as Data - Data.__init__(self) - - # set calstate - # this property is used only to check to make sure that - # the braggvectors being used throughout a workflow are - # the same. The state of calibration of the vectors is noted - # here, and then checked each time the vectors are used - - # if they differ, an error message and instructions for - # re-calibration are issued - self.calstate = self.braggvectors.calstate - assert self.calstate["center"], "braggvectors must be centered" - # get the BVM - # a new BVM using the current calstate is computed - self.bvm = self.braggvectors.histogram(mode="cal") - - # braggvector properties - - @property - def braggvectors(self): - return self._braggvectors - - @braggvectors.setter - def braggvectors(self, x): - assert isinstance( - x, BraggVectors - ), f".braggvectors must be BraggVectors, not type {type(x)}" - assert ( - x.calibration.origin is not None - ), f"braggvectors must have a calibrated origin" - self._braggvectors = x - self._braggvectors.tree(self, force=True) - - def reset_calstate(self): - """ - Resets the calibration state. This recomputes the BVM, and removes any computations - this StrainMap instance has stored, which will need to be recomputed. - """ - for attr in ( - "g0", - "g1", - "g2", - ): - if hasattr(self, attr): - delattr(self, attr) - self.calstate = self.braggvectors.calstate - pass - - # Class methods - - def choose_lattice_vectors( - self, - index_g0, - index_g1, - index_g2, - subpixel="multicorr", - upsample_factor=16, - sigma=0, - minAbsoluteIntensity=0, - minRelativeIntensity=0, - relativeToPeak=0, - minSpacing=0, - edgeBoundary=1, - maxNumPeaks=10, - figsize=(12, 6), - c_indices="lightblue", - c0="g", - c1="r", - c2="r", - c_vectors="r", - c_vectorlabels="w", - size_indices=20, - width_vectors=1, - size_vectorlabels=20, - vis_params={}, - returncalc=False, - returnfig=False, - ): - """ - Choose which lattice vectors to use for strain mapping. - - Overlays the bvm with the points detected via local 2D - maxima detection, plus an index for each point. User selects - 3 points using the overlaid indices, which are identified as - the origin and the termini of the lattice vectors g1 and g2. - - Parameters - ---------- - index_g0 : int - selected index for the origin - index_g1 : int - selected index for g1 - index_g2 :int - selected index for g2 - subpixel : str in ('pixel','poly','multicorr') - See the docstring for py4DSTEM.preprocess.get_maxima_2D - upsample_factor : int - See the py4DSTEM.preprocess.get_maxima_2D docstring - sigma : number - See the py4DSTEM.preprocess.get_maxima_2D docstring - minAbsoluteIntensity : number - See the py4DSTEM.preprocess.get_maxima_2D docstring - minRelativeIntensity : number - See the py4DSTEM.preprocess.get_maxima_2D docstring - relativeToPeak : int - See the py4DSTEM.preprocess.get_maxima_2D docstring - minSpacing : number - See the py4DSTEM.preprocess.get_maxima_2D docstring - edgeBoundary : number - See the py4DSTEM.preprocess.get_maxima_2D docstring - maxNumPeaks : int - See the py4DSTEM.preprocess.get_maxima_2D docstring - figsize : 2-tuple - the size of the figure - c_indices : color - color of the maxima - c0 : color - color of the origin - c1 : color - color of g1 point - c2 : color - color of g2 point - c_vectors : color - color of the g1/g2 vectors - c_vectorlabels : color - color of the vector labels - size_indices : number - size of the indices - width_vectors : number - width of the vectors - size_vectorlabels : number - size of the vector labels - vis_params : dict - additional visualization parameters passed to `show` - returncalc : bool - toggles returning the answer - returnfig : bool - toggles returning the figure - - Returns - ------- - (optional) : None or (g0,g1,g2) or (fig,(ax1,ax2)) or both of the latter - """ - # validate inputs - for i in (index_g0, index_g1, index_g2): - assert isinstance(i, (int, np.integer)), "indices must be integers!" - # check the calstate - assert ( - self.calstate == self.braggvectors.calstate - ), "The calibration state has changed! To resync the calibration state, use `.reset_calstate`." - - # find the maxima - g = get_maxima_2D( - self.bvm.data, - subpixel=subpixel, - upsample_factor=upsample_factor, - sigma=sigma, - minAbsoluteIntensity=minAbsoluteIntensity, - minRelativeIntensity=minRelativeIntensity, - relativeToPeak=relativeToPeak, - minSpacing=minSpacing, - edgeBoundary=edgeBoundary, - maxNumPeaks=maxNumPeaks, - ) - - # get the lattice vectors - gx, gy = g["x"], g["y"] - g0 = gx[index_g0], gy[index_g0] - g1x = gx[index_g1] - g0[0] - g1y = gy[index_g1] - g0[1] - g2x = gx[index_g2] - g0[0] - g2y = gy[index_g2] - g0[1] - g1, g2 = (g1x, g1y), (g2x, g2y) - - # make the figure - fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize) - show(self.bvm.data, figax=(fig, ax1), **vis_params) - show(self.bvm.data, figax=(fig, ax2), **vis_params) - - # Add indices to left panel - d = {"x": gx, "y": gy, "size": size_indices, "color": c_indices} - d0 = { - "x": gx[index_g0], - "y": gy[index_g0], - "size": size_indices, - "color": c0, - "fontweight": "bold", - "labels": [str(index_g0)], - } - d1 = { - "x": gx[index_g1], - "y": gy[index_g1], - "size": size_indices, - "color": c1, - "fontweight": "bold", - "labels": [str(index_g1)], - } - d2 = { - "x": gx[index_g2], - "y": gy[index_g2], - "size": size_indices, - "color": c2, - "fontweight": "bold", - "labels": [str(index_g2)], - } - add_pointlabels(ax1, d) - add_pointlabels(ax1, d0) - add_pointlabels(ax1, d1) - add_pointlabels(ax1, d2) - - # Add vectors to right panel - dg1 = { - "x0": gx[index_g0], - "y0": gy[index_g0], - "vx": g1[0], - "vy": g1[1], - "width": width_vectors, - "color": c_vectors, - "label": r"$g_1$", - "labelsize": size_vectorlabels, - "labelcolor": c_vectorlabels, - } - dg2 = { - "x0": gx[index_g0], - "y0": gy[index_g0], - "vx": g2[0], - "vy": g2[1], - "width": width_vectors, - "color": c_vectors, - "label": r"$g_2$", - "labelsize": size_vectorlabels, - "labelcolor": c_vectorlabels, - } - add_vector(ax2, dg1) - add_vector(ax2, dg2) - - # store vectors - self.g = g - self.g0 = g0 - self.g1 = g1 - self.g2 = g2 - - # return - if returncalc and returnfig: - return (g0, g1, g2), (fig, (ax1, ax2)) - elif returncalc: - return (g0, g1, g2) - elif returnfig: - return (fig, (ax1, ax2)) - else: - return - - def fit_lattice_vectors( - self, - x0=None, - y0=None, - max_peak_spacing=2, - mask=None, - plot=True, - vis_params={}, - returncalc=False, - ): - """ - From an origin (x0,y0), a set of reciprocal lattice vectors gx,gy, and an pair of - lattice vectors g1=(g1x,g1y), g2=(g2x,g2y), find the indices (h,k) of all the - reciprocal lattice directions. - - Args: - x0 : floagt - x-coord of origin - y0 : float - y-coord of origin - max_peak_spacing: float - Maximum distance from the ideal lattice points - to include a peak for indexing - mask: bool - Boolean mask, same shape as the pointlistarray, indicating which - locations should be indexed. This can be used to index different regions of - the scan with different lattices - plot:bool - plot results if tru - vis_params : dict - additional visualization parameters passed to `show` - returncalc : bool - if True, returns bragg_directions, bragg_vectors_indexed, g1g2_map - """ - # check the calstate - assert ( - self.calstate == self.braggvectors.calstate - ), "The calibration state has changed! To resync the calibration state, use `.reset_calstate`." - - if x0 is None: - x0 = self.braggvectors.Qshape[0] / 2 - if y0 is None: - y0 = self.braggvectors.Qshape[0] / 2 - - # index braggvectors - from py4DSTEM.process.latticevectors import index_bragg_directions - - _, _, braggdirections = index_bragg_directions( - x0, y0, self.g["x"], self.g["y"], self.g1, self.g2 - ) - - self.braggdirections = braggdirections - - if plot: - self.show_bragg_indexing( - self.bvm, - bragg_directions=braggdirections, - points=True, - **vis_params, - ) - - # add indicies to braggvectors - from py4DSTEM.process.latticevectors import add_indices_to_braggvectors - - bragg_vectors_indexed = add_indices_to_braggvectors( - self.braggvectors, - self.braggdirections, - maxPeakSpacing=max_peak_spacing, - qx_shift=self.braggvectors.Qshape[0] / 2, - qy_shift=self.braggvectors.Qshape[1] / 2, - mask=mask, - ) - - self.bragg_vectors_indexed = bragg_vectors_indexed - - # fit bragg vectors - from py4DSTEM.process.latticevectors import fit_lattice_vectors_all_DPs - - g1g2_map = fit_lattice_vectors_all_DPs(self.bragg_vectors_indexed) - self.g1g2_map = g1g2_map - - if returncalc: - braggdirections, bragg_vectors_indexed, g1g2_map - - def get_strain( - self, mask=None, g_reference=None, flip_theta=False, returncalc=False, **kwargs - ): - """ - mask: nd.array (bool) - Use lattice vectors from g1g2_map scan positions - wherever mask==True. If mask is None gets median strain - map from entire field of view. If mask is not None, gets - reference g1 and g2 from region and then calculates strain. - g_reference: nd.array of form [x,y] - G_reference (tupe): reference coordinate system for - xaxis_x and xaxis_y - flip_theta: bool - If True, flips rotation coordinate system - returncal: bool - It True, returns rotated map - """ - # check the calstate - assert ( - self.calstate == self.braggvectors.calstate - ), "The calibration state has changed! To resync the calibration state, use `.reset_calstate`." - - if mask is None: - mask = np.ones(self.g1g2_map.shape, dtype="bool") - - from py4DSTEM.process.latticevectors import get_strain_from_reference_region - - strainmap_g1g2 = get_strain_from_reference_region( - self.g1g2_map, - mask=mask, - ) - else: - from py4DSTEM.process.latticevectors import get_reference_g1g2 - - g1_ref, g2_ref = get_reference_g1g2(self.g1g2_map, mask) - - from py4DSTEM.process.latticevectors import get_strain_from_reference_g1g2 - - strainmap_g1g2 = get_strain_from_reference_g1g2( - self.g1g2_map, g1_ref, g2_ref - ) - - self.strainmap_g1g2 = strainmap_g1g2 - - if g_reference is None: - g_reference = np.subtract(self.g1, self.g2) - - from py4DSTEM.process.latticevectors import get_rotated_strain_map - - strainmap_rotated = get_rotated_strain_map( - self.strainmap_g1g2, - xaxis_x=g_reference[0], - xaxis_y=g_reference[1], - flip_theta=flip_theta, - ) - - self.strainmap_rotated = strainmap_rotated - - from py4DSTEM.visualize import show_strain - - figsize = kwargs.pop("figsize", (14, 4)) - vrange_exx = kwargs.pop("vrange_exx", [-2.0, 2.0]) - vrange_theta = kwargs.pop("vrange_theta", [-2.0, 2.0]) - ticknumber = kwargs.pop("ticknumber", 3) - bkgrd = kwargs.pop("bkgrd", False) - axes_plots = kwargs.pop("axes_plots", ()) - - fig, ax = show_strain( - self.strainmap_rotated, - vrange_exx=vrange_exx, - vrange_theta=vrange_theta, - ticknumber=ticknumber, - axes_plots=axes_plots, - bkgrd=bkgrd, - figsize=figsize, - **kwargs, - returnfig=True, - ) - - if not np.all(mask == True): - ax[0][0].imshow(mask, alpha=0.2, cmap="binary") - ax[0][1].imshow(mask, alpha=0.2, cmap="binary") - ax[1][0].imshow(mask, alpha=0.2, cmap="binary") - ax[1][1].imshow(mask, alpha=0.2, cmap="binary") - - if returncalc: - return self.strainmap_rotated - - def show_lattice_vectors( - ar, - x0, - y0, - g1, - g2, - color="r", - width=1, - labelsize=20, - labelcolor="w", - returnfig=False, - **kwargs, - ): - """Adds the vectors g1,g2 to an image, with tail positions at (x0,y0). g1 and g2 are 2-tuples (gx,gy).""" - fig, ax = show(ar, returnfig=True, **kwargs) - - # Add vectors - dg1 = { - "x0": x0, - "y0": y0, - "vx": g1[0], - "vy": g1[1], - "width": width, - "color": color, - "label": r"$g_1$", - "labelsize": labelsize, - "labelcolor": labelcolor, - } - dg2 = { - "x0": x0, - "y0": y0, - "vx": g2[0], - "vy": g2[1], - "width": width, - "color": color, - "label": r"$g_2$", - "labelsize": labelsize, - "labelcolor": labelcolor, - } - add_vector(ax, dg1) - add_vector(ax, dg2) - - if returnfig: - return fig, ax - else: - plt.show() - return - - def show_bragg_indexing( - self, - ar, - bragg_directions, - voffset=5, - hoffset=0, - color="w", - size=20, - points=True, - pointcolor="r", - pointsize=50, - returnfig=False, - **kwargs, - ): - """ - Shows an array with an overlay describing the Bragg directions - - Accepts: - ar (arrray) the image - bragg_directions (PointList) the bragg scattering directions; must have coordinates - 'qx','qy','h', and 'k'. Optionally may also have 'l'. - """ - assert isinstance(bragg_directions, PointList) - for k in ("qx", "qy", "h", "k"): - assert k in bragg_directions.data.dtype.fields - - fig, ax = show(ar, returnfig=True, **kwargs) - d = { - "bragg_directions": bragg_directions, - "voffset": voffset, - "hoffset": hoffset, - "color": color, - "size": size, - "points": points, - "pointsize": pointsize, - "pointcolor": pointcolor, - } - add_bragg_index_labels(ax, d) - - if returnfig: - return fig, ax - else: - plt.show() - return - - def copy(self, name=None): - name = name if name is not None else self.name + "_copy" - strainmap_copy = StrainMap(self.braggvectors) - for attr in ( - "g", - "g0", - "g1", - "g2", - "calstate", - "bragg_directions", - "bragg_vectors_indexed", - "g1g2_map", - "strainmap_g1g2", - "strainmap_rotated", - ): - if hasattr(self, attr): - setattr(strainmap_copy, attr, getattr(self, attr)) - - for k in self.metadata.keys(): - strainmap_copy.metadata = self.metadata[k].copy() - return strainmap_copy - - # IO methods - - # read - @classmethod - def _get_constructor_args(cls, group): - """ - Returns a dictionary of args/values to pass to the class constructor - """ - ar_constr_args = RealSlice._get_constructor_args(group) - args = { - "data": ar_constr_args["data"], - "name": ar_constr_args["name"], - } - return args diff --git a/py4DSTEM/process/strain/__init__.py b/py4DSTEM/process/strain/__init__.py new file mode 100644 index 000000000..b487c916b --- /dev/null +++ b/py4DSTEM/process/strain/__init__.py @@ -0,0 +1,10 @@ +from py4DSTEM.process.strain.strain import StrainMap +from py4DSTEM.process.strain.latticevectors import ( + index_bragg_directions, + add_indices_to_braggvectors, + fit_lattice_vectors, + fit_lattice_vectors_all_DPs, + get_reference_g1g2, + get_strain_from_reference_g1g2, + get_rotated_strain_map, +) diff --git a/py4DSTEM/process/strain/latticevectors.py b/py4DSTEM/process/strain/latticevectors.py new file mode 100644 index 000000000..26c8d66a5 --- /dev/null +++ b/py4DSTEM/process/strain/latticevectors.py @@ -0,0 +1,458 @@ +# Functions for indexing the Bragg directions + +import numpy as np +from emdfile import PointList, PointListArray, tqdmnd +from numpy.linalg import lstsq +from py4DSTEM.data import RealSlice + + +def index_bragg_directions(x0, y0, gx, gy, g1, g2): + """ + From an origin (x0,y0), a set of reciprocal lattice vectors gx,gy, and an pair of + lattice vectors g1=(g1x,g1y), g2=(g2x,g2y), find the indices (h,k) of all the + reciprocal lattice directions. + + The approach is to solve the matrix equation + ``alpha = beta * M`` + where alpha is the 2xN array of the (x,y) coordinates of N measured bragg directions, + beta is the 2x2 array of the two lattice vectors u,v, and M is the 2xN array of the + h,k indices. + + Args: + x0 (float): x-coord of origin + y0 (float): y-coord of origin + gx (1d array): x-coord of the reciprocal lattice vectors + gy (1d array): y-coord of the reciprocal lattice vectors + g1 (2-tuple of floats): g1x,g1y + g2 (2-tuple of floats): g2x,g2y + + Returns: + (3-tuple) A 3-tuple containing: + + * **h**: *(ndarray of ints)* first index of the bragg directions + * **k**: *(ndarray of ints)* second index of the bragg directions + * **bragg_directions**: *(PointList)* a 4-coordinate PointList with the + indexed bragg directions; coords 'qx' and 'qy' contain bragg_x and bragg_y + coords 'h' and 'k' contain h and k. + """ + # Get beta, the matrix of lattice vectors + beta = np.array([[g1[0], g2[0]], [g1[1], g2[1]]]) + + # Get alpha, the matrix of measured bragg angles + alpha = np.vstack([gx - x0, gy - y0]) + + # Calculate M, the matrix of peak positions + M = lstsq(beta, alpha, rcond=None)[0].T + M = np.round(M).astype(int) + + # Get h,k + h = M[:, 0] + k = M[:, 1] + + # Store in a PointList + coords = [("qx", float), ("qy", float), ("h", int), ("k", int)] + temp_array = np.zeros([], dtype=coords) + bragg_directions = PointList(data=temp_array) + bragg_directions.add_data_by_field((gx, gy, h, k)) + mask = np.zeros(bragg_directions["qx"].shape[0]) + mask[0] = 1 + bragg_directions.remove(mask) + + return h, k, bragg_directions + + +def add_indices_to_braggvectors( + braggpeaks, lattice, maxPeakSpacing, qx_shift=0, qy_shift=0, mask=None +): + """ + Using the peak positions (qx,qy) and indices (h,k) in the PointList lattice, + identify the indices for each peak in the PointListArray braggpeaks. + Return a new braggpeaks_indexed PointListArray, containing a copy of braggpeaks plus + three additional data columns -- 'h','k', and 'index_mask' -- specifying the peak + indices with the ints (h,k) and indicating whether the peak was successfully indexed + or not with the bool index_mask. If `mask` is specified, only the locations where + mask is True are indexed. + + Args: + braggpeaks (PointListArray): the braggpeaks to index. Must contain + the coordinates 'qx', 'qy', and 'intensity' + lattice (PointList): the positions (qx,qy) of the (h,k) lattice points. + Must contain the coordinates 'qx', 'qy', 'h', and 'k' + maxPeakSpacing (float): Maximum distance from the ideal lattice points + to include a peak for indexing + qx_shift,qy_shift (number): the shift of the origin in the `lattice` PointList + relative to the `braggpeaks` PointListArray + mask (bool): Boolean mask, same shape as the pointlistarray, indicating which + locations should be indexed. This can be used to index different regions of + the scan with different lattices + + Returns: + (PointListArray): The original braggpeaks pointlistarray, with new coordinates + 'h', 'k', containing the indices of each indexable peak. + """ + + # assert isinstance(braggpeaks,BraggVectors) + # assert isinstance(lattice, PointList) + # assert np.all([name in lattice.dtype.names for name in ('qx','qy','h','k')]) + + if mask is None: + mask = np.ones(braggpeaks.Rshape, dtype=bool) + + assert ( + mask.shape == braggpeaks.Rshape + ), "mask must have same shape as pointlistarray" + assert mask.dtype == bool, "mask must be boolean" + + coords = [ + ("qx", float), + ("qy", float), + ("intensity", float), + ("h", int), + ("k", int), + ] + + indexed_braggpeaks = PointListArray( + dtype=coords, + shape=braggpeaks.Rshape, + ) + + calstate = braggpeaks.calstate + + # loop over all the scan positions + for Rx, Ry in tqdmnd(mask.shape[0], mask.shape[1]): + if mask[Rx, Ry]: + pl = braggpeaks.get_vectors( + Rx, + Ry, + center=True, + ellipse=calstate["ellipse"], + rotate=calstate["rotate"], + pixel=False, + ) + for i in range(pl.data.shape[0]): + r2 = (pl.data["qx"][i] - lattice.data["qx"] + qx_shift) ** 2 + ( + pl.data["qy"][i] - lattice.data["qy"] + qy_shift + ) ** 2 + ind = np.argmin(r2) + if r2[ind] <= maxPeakSpacing**2: + indexed_braggpeaks[Rx, Ry].add_data_by_field( + ( + pl.data["qx"][i], + pl.data["qy"][i], + pl.data["intensity"][i], + lattice.data["h"][ind], + lattice.data["k"][ind], + ) + ) + + return indexed_braggpeaks + + +def fit_lattice_vectors(braggpeaks, x0=0, y0=0, minNumPeaks=5): + """ + Fits lattice vectors g1,g2 to braggpeaks given some known (h,k) indexing. + + Args: + braggpeaks (PointList): A 6 coordinate PointList containing the data to fit. + Coords are 'qx','qy' (the bragg peak positions), 'intensity' (used as a + weighting factor when fitting), 'h','k' (indexing). May optionally also + contain 'index_mask' (bool), indicating which peaks have been successfully + indixed and should be used. + x0 (float): x-coord of the origin + y0 (float): y-coord of the origin + minNumPeaks (int): if there are fewer than minNumPeaks peaks found in braggpeaks + which can be indexed, return None for all return parameters + + Returns: + (7-tuple) A 7-tuple containing: + + * **x0**: *(float)* the x-coord of the origin of the best-fit lattice. + * **y0**: *(float)* the y-coord of the origin + * **g1x**: *(float)* x-coord of the first lattice vector + * **g1y**: *(float)* y-coord of the first lattice vector + * **g2x**: *(float)* x-coord of the second lattice vector + * **g2y**: *(float)* y-coord of the second lattice vector + * **error**: *(float)* the fit error + """ + assert isinstance(braggpeaks, PointList) + assert np.all( + [name in braggpeaks.dtype.names for name in ("qx", "qy", "intensity", "h", "k")] + ) + braggpeaks = braggpeaks.copy() + + # Remove unindexed peaks + if "index_mask" in braggpeaks.dtype.names: + deletemask = braggpeaks.data["index_mask"] == False + braggpeaks.remove(deletemask) + + # Check to ensure enough peaks are present + if braggpeaks.length < minNumPeaks: + return None, None, None, None, None, None, None + + # Get M, the matrix of (h,k) indices + h, k = braggpeaks.data["h"], braggpeaks.data["k"] + M = np.vstack((np.ones_like(h, dtype=int), h, k)).T + + # Get alpha, the matrix of measured Bragg peak positions + alpha = np.vstack((braggpeaks.data["qx"] - x0, braggpeaks.data["qy"] - y0)).T + + # Get weighted matrices + weights = braggpeaks.data["intensity"] + weighted_M = M * weights[:, np.newaxis] + weighted_alpha = alpha * weights[:, np.newaxis] + + # Solve for lattice vectors + beta = lstsq(weighted_M, weighted_alpha, rcond=None)[0] + x0, y0 = beta[0, 0], beta[0, 1] + g1x, g1y = beta[1, 0], beta[1, 1] + g2x, g2y = beta[2, 0], beta[2, 1] + + # Calculate the error + alpha_calculated = np.matmul(M, beta) + error = np.sqrt(np.sum((alpha - alpha_calculated) ** 2, axis=1)) + error = np.sum(error * weights) / np.sum(weights) + + return x0, y0, g1x, g1y, g2x, g2y, error + + +def fit_lattice_vectors_all_DPs(braggpeaks, x0=0, y0=0, minNumPeaks=5): + """ + Fits lattice vectors g1,g2 to each diffraction pattern in braggpeaks, given some + known (h,k) indexing. + + Args: + braggpeaks (PointList): A 6 coordinate PointList containing the data to fit. + Coords are 'qx','qy' (the bragg peak positions), 'intensity' (used as a + weighting factor when fitting), 'h','k' (indexing). May optionally also + contain 'index_mask' (bool), indicating which peaks have been successfully + indixed and should be used. + x0 (float): x-coord of the origin + y0 (float): y-coord of the origin + minNumPeaks (int): if there are fewer than minNumPeaks peaks found in braggpeaks + which can be indexed, return None for all return parameters + + Returns: + (RealSlice): A RealSlice ``g1g2map`` containing the following 8 arrays: + + * ``g1g2_map.get_slice('x0')`` x-coord of the origin of the best fit lattice + * ``g1g2_map.get_slice('y0')`` y-coord of the origin + * ``g1g2_map.get_slice('g1x')`` x-coord of the first lattice vector + * ``g1g2_map.get_slice('g1y')`` y-coord of the first lattice vector + * ``g1g2_map.get_slice('g2x')`` x-coord of the second lattice vector + * ``g1g2_map.get_slice('g2y')`` y-coord of the second lattice vector + * ``g1g2_map.get_slice('error')`` the fit error + * ``g1g2_map.get_slice('mask')`` 1 for successful fits, 0 for unsuccessful + fits + """ + assert isinstance(braggpeaks, PointListArray) + assert np.all( + [name in braggpeaks.dtype.names for name in ("qx", "qy", "intensity", "h", "k")] + ) + + # Make RealSlice to contain outputs + slicelabels = ("x0", "y0", "g1x", "g1y", "g2x", "g2y", "error", "mask") + g1g2_map = RealSlice( + data=np.zeros((8, braggpeaks.shape[0], braggpeaks.shape[1])), + slicelabels=slicelabels, + name="g1g2_map", + ) + + # Fit lattice vectors + for Rx, Ry in tqdmnd(braggpeaks.shape[0], braggpeaks.shape[1]): + braggpeaks_curr = braggpeaks.get_pointlist(Rx, Ry) + qx0, qy0, g1x, g1y, g2x, g2y, error = fit_lattice_vectors( + braggpeaks_curr, x0, y0, minNumPeaks + ) + # Store data + if g1x is not None: + g1g2_map.get_slice("x0").data[Rx, Ry] = qx0 + g1g2_map.get_slice("y0").data[Rx, Ry] = qx0 + g1g2_map.get_slice("g1x").data[Rx, Ry] = g1x + g1g2_map.get_slice("g1y").data[Rx, Ry] = g1y + g1g2_map.get_slice("g2x").data[Rx, Ry] = g2x + g1g2_map.get_slice("g2y").data[Rx, Ry] = g2y + g1g2_map.get_slice("error").data[Rx, Ry] = error + g1g2_map.get_slice("mask").data[Rx, Ry] = 1 + + return g1g2_map + + +def get_reference_g1g2(g1g2_map, mask): + """ + Gets a pair of reference lattice vectors from a region of real space specified by + mask. Takes the median of the lattice vectors in g1g2_map within the specified + region. + + Args: + g1g2_map (RealSlice): the lattice vector map; contains 2D arrays in g1g2_map.data + under the keys 'g1x', 'g1y', 'g2x', and 'g2y'. See documentation for + fit_lattice_vectors_all_DPs() for more information. + mask (ndarray of bools): use lattice vectors from g1g2_map scan positions wherever + mask==True + + Returns: + (2-tuple of 2-tuples) A 2-tuple containing: + + * **g1**: *(2-tuple)* first reference lattice vector (x,y) + * **g2**: *(2-tuple)* second reference lattice vector (x,y) + """ + assert isinstance(g1g2_map, RealSlice) + assert np.all( + [name in g1g2_map.slicelabels for name in ("g1x", "g1y", "g2x", "g2y")] + ) + assert mask.dtype == bool + g1x = np.median(g1g2_map.get_slice("g1x").data[mask]) + g1y = np.median(g1g2_map.get_slice("g1y").data[mask]) + g2x = np.median(g1g2_map.get_slice("g2x").data[mask]) + g2y = np.median(g1g2_map.get_slice("g2y").data[mask]) + return (g1x, g1y), (g2x, g2y) + + +def get_strain_from_reference_g1g2(g1g2_map, g1, g2): + """ + Gets a strain map from the reference lattice vectors g1,g2 and lattice vector map + g1g2_map. + + Note that this function will return the strain map oriented with respect to the x/y + axes of diffraction space - to rotate the coordinate system, use + get_rotated_strain_map(). Calibration of the rotational misalignment between real and + diffraction space may also be necessary. + + Args: + g1g2_map (RealSlice): the lattice vector map; contains 2D arrays in g1g2_map.data + under the keys 'g1x', 'g1y', 'g2x', and 'g2y'. See documentation for + fit_lattice_vectors_all_DPs() for more information. + g1 (2-tuple): first reference lattice vector (x,y) + g2 (2-tuple): second reference lattice vector (x,y) + + Returns: + (RealSlice) the strain map; contains the elements of the infinitessimal strain + matrix, in the following 5 arrays: + + * ``strain_map.get_slice('e_xx')``: change in lattice x-components with respect + to x + * ``strain_map.get_slice('e_yy')``: change in lattice y-components with respect + to y + * ``strain_map.get_slice('e_xy')``: change in lattice x-components with respect + to y + * ``strain_map.get_slice('theta')``: rotation of lattice with respect to + reference + * ``strain_map.get_slice('mask')``: 0/False indicates unknown values + + Note 1: the strain matrix has been symmetrized, so e_xy and e_yx are identical + """ + assert isinstance(g1g2_map, RealSlice) + assert np.all( + [name in g1g2_map.slicelabels for name in ("g1x", "g1y", "g2x", "g2y", "mask")] + ) + + # Get RealSlice for output storage + R_Nx, R_Ny = g1g2_map.get_slice("g1x").shape + strain_map = RealSlice( + data=np.zeros((5, R_Nx, R_Ny)), + slicelabels=("e_xx", "e_yy", "e_xy", "theta", "mask"), + name="strain_map", + ) + + # Get reference lattice matrix + g1x, g1y = g1 + g2x, g2y = g2 + M = np.array([[g1x, g1y], [g2x, g2y]]) + + for Rx in range(R_Nx): + for Ry in range(R_Ny): + # Get lattice vectors for DP at Rx,Ry + alpha = np.array( + [ + [ + g1g2_map.get_slice("g1x").data[Rx, Ry], + g1g2_map.get_slice("g1y").data[Rx, Ry], + ], + [ + g1g2_map.get_slice("g2x").data[Rx, Ry], + g1g2_map.get_slice("g2y").data[Rx, Ry], + ], + ] + ) + # Get transformation matrix + beta = lstsq(M, alpha, rcond=None)[0].T + + # Get the infinitesimal strain matrix + strain_map.get_slice("e_xx").data[Rx, Ry] = 1 - beta[0, 0] + strain_map.get_slice("e_yy").data[Rx, Ry] = 1 - beta[1, 1] + strain_map.get_slice("e_xy").data[Rx, Ry] = -(beta[0, 1] + beta[1, 0]) / 2.0 + strain_map.get_slice("theta").data[Rx, Ry] = (beta[0, 1] - beta[1, 0]) / 2.0 + strain_map.get_slice("mask").data[Rx, Ry] = g1g2_map.get_slice("mask").data[ + Rx, Ry + ] + return strain_map + + +def get_rotated_strain_map(unrotated_strain_map, xaxis_x, xaxis_y, flip_theta): + """ + Starting from a strain map defined with respect to the xy coordinate system of + diffraction space, i.e. where exx and eyy are the compression/tension along the Qx + and Qy directions, respectively, get a strain map defined with respect to some other + right-handed coordinate system, in which the x-axis is oriented along (xaxis_x, + xaxis_y). + + Args: + xaxis_x,xaxis_y (float): diffraction space (x,y) coordinates of a vector + along the new x-axis + unrotated_strain_map (RealSlice): a RealSlice object containing 2D arrays of the + infinitessimal strain matrix elements, stored at + * unrotated_strain_map.get_slice('e_xx') + * unrotated_strain_map.get_slice('e_xy') + * unrotated_strain_map.get_slice('e_yy') + * unrotated_strain_map.get_slice('theta') + + Returns: + (RealSlice) the rotated counterpart to unrotated_strain_map, with the + rotated_strain_map.get_slice('e_xx') element oriented along the new coordinate + system + """ + assert isinstance(unrotated_strain_map, RealSlice) + assert np.all( + [ + key in ["e_xx", "e_xy", "e_yy", "theta", "mask"] + for key in unrotated_strain_map.slicelabels + ] + ) + theta = -np.arctan2(xaxis_y, xaxis_x) + cost = np.cos(theta) + sint = np.sin(theta) + cost2 = cost**2 + sint2 = sint**2 + + Rx, Ry = unrotated_strain_map.get_slice("e_xx").data.shape + rotated_strain_map = RealSlice( + data=np.zeros((5, Rx, Ry)), + slicelabels=["e_xx", "e_xy", "e_yy", "theta", "mask"], + name=unrotated_strain_map.name + "_rotated".format(np.degrees(theta)), + ) + + rotated_strain_map.data[0, :, :] = ( + cost2 * unrotated_strain_map.get_slice("e_xx").data + - 2 * cost * sint * unrotated_strain_map.get_slice("e_xy").data + + sint2 * unrotated_strain_map.get_slice("e_yy").data + ) + rotated_strain_map.data[1, :, :] = ( + cost + * sint + * ( + unrotated_strain_map.get_slice("e_xx").data + - unrotated_strain_map.get_slice("e_yy").data + ) + + (cost2 - sint2) * unrotated_strain_map.get_slice("e_xy").data + ) + rotated_strain_map.data[2, :, :] = ( + sint2 * unrotated_strain_map.get_slice("e_xx").data + + 2 * cost * sint * unrotated_strain_map.get_slice("e_xy").data + + cost2 * unrotated_strain_map.get_slice("e_yy").data + ) + if flip_theta == True: + rotated_strain_map.data[3, :, :] = -unrotated_strain_map.get_slice("theta").data + else: + rotated_strain_map.data[3, :, :] = unrotated_strain_map.get_slice("theta").data + rotated_strain_map.data[4, :, :] = unrotated_strain_map.get_slice("mask").data + return rotated_strain_map diff --git a/py4DSTEM/process/strain/strain.py b/py4DSTEM/process/strain/strain.py new file mode 100644 index 000000000..538c90825 --- /dev/null +++ b/py4DSTEM/process/strain/strain.py @@ -0,0 +1,1011 @@ +# Defines the Strain class + +import warnings +from typing import Optional + +import matplotlib.pyplot as plt +from mpl_toolkits.axes_grid1 import make_axes_locatable +import numpy as np +from py4DSTEM import PointList, PointListArray, tqdmnd +from py4DSTEM.braggvectors import BraggVectors +from py4DSTEM.data import Data, RealSlice +from py4DSTEM.preprocess.utils import get_maxima_2D +from py4DSTEM.process.strain.latticevectors import ( + add_indices_to_braggvectors, + fit_lattice_vectors_all_DPs, + get_reference_g1g2, + get_rotated_strain_map, + get_strain_from_reference_g1g2, + index_bragg_directions, +) +from py4DSTEM.visualize import add_bragg_index_labels, add_pointlabels, add_vector, show +from py4DSTEM.visualize import ax_addaxes, ax_addaxes_QtoR + +warnings.simplefilter(action="always", category=UserWarning) + + +class StrainMap(RealSlice, Data): + """ + Storage and processing methods for 4D-STEM datasets. + + """ + + def __init__(self, braggvectors: BraggVectors, name: Optional[str] = "strainmap"): + """ + Accepts: + braggvectors (BraggVectors): BraggVectors for Strain Map + name (str): the name of the strainmap + Returns: + A new StrainMap instance. + """ + assert isinstance( + braggvectors, BraggVectors + ), f"braggvectors must be BraggVectors, not type {type(braggvectors)}" + + # initialize as a RealSlice + RealSlice.__init__( + self, + name=name, + data=np.empty( + ( + 6, + braggvectors.Rshape[0], + braggvectors.Rshape[1], + ) + ), + slicelabels=["exx", "eyy", "exy", "theta", "mask", "error"], + ) + + # set up braggvectors + # this assigns the bvs, ensures the origin is calibrated, + # and adds the strainmap to the bvs' tree + self.braggvectors = braggvectors + + # initialize as Data + Data.__init__(self) + + # set calstate + # this property is used only to check to make sure that + # the braggvectors being used throughout a workflow are + # the same. The state of calibration of the vectors is noted + # here, and then checked each time the vectors are used - + # if they differ, an error message and instructions for + # re-calibration are issued + self.calstate = self.braggvectors.calstate + assert self.calstate["center"], "braggvectors must be centered" + if self.calstate["rotate"] == False: + warnings.warn( + ("Real to reciprocal space rotation not calibrated"), + UserWarning, + ) + + # get the BVM + # a new BVM using the current calstate is computed + self.bvm = self.braggvectors.histogram(mode="cal") + + # braggvector properties + + @property + def braggvectors(self): + return self._braggvectors + + @braggvectors.setter + def braggvectors(self, x): + assert isinstance( + x, BraggVectors + ), f".braggvectors must be BraggVectors, not type {type(x)}" + assert ( + x.calibration.origin is not None + ), "braggvectors must have a calibrated origin" + self._braggvectors = x + self._braggvectors.tree(self, force=True) + + @property + def rshape(self): + return self._braggvectors.Rshape + + @property + def qshape(self): + return self._braggvectors.Qshape + + @property + def origin(self): + return self.calibration.get_origin_mean() + + @property + def mask(self): + try: + return self.g1g2_map["mask"].data.astype("bool") + except: + return np.ones(self.rshape, dtype=bool) + + def reset_calstate(self): + """ + Resets the calibration state. This recomputes the BVM, and removes any computations + this StrainMap instance has stored, which will need to be recomputed. + """ + for attr in ( + "g0", + "g1", + "g2", + ): + if hasattr(self, attr): + delattr(self, attr) + self.calstate = self.braggvectors.calstate + pass + + # Class methods + + def choose_lattice_vectors( + self, + index_g1=None, + index_g2=None, + index_origin=None, + subpixel="multicorr", + upsample_factor=16, + sigma=0, + minAbsoluteIntensity=0, + minRelativeIntensity=0, + relativeToPeak=0, + minSpacing=0, + edgeBoundary=1, + maxNumPeaks=10, + x0=None, + y0=None, + figsize=(14, 9), + c_indices="lightblue", + c0="g", + c1="r", + c2="r", + c_vectors="r", + c_vectorlabels="w", + size_indices=15, + width_vectors=1, + size_vectorlabels=15, + vis_params={}, + returncalc=False, + returnfig=False, + ): + """ + Choose which lattice vectors to use for strain mapping. + + Overlays the bvm with the points detected via local 2D + maxima detection, plus an index for each point. User selects + 3 points using the overlaid indices, which are identified as + the origin and the termini of the lattice vectors g1 and g2. + + Parameters + ---------- + index_g1 : int + selected index for g1 + index_g2 :int + selected index for g2 + index_origin : int + selected index for the origin + subpixel : str in ('pixel','poly','multicorr') + See the docstring for py4DSTEM.preprocess.get_maxima_2D + upsample_factor : int + See the py4DSTEM.preprocess.get_maxima_2D docstring + sigma : number + See the py4DSTEM.preprocess.get_maxima_2D docstring + minAbsoluteIntensity : number + See the py4DSTEM.preprocess.get_maxima_2D docstring + minRelativeIntensity : number + See the py4DSTEM.preprocess.get_maxima_2D docstring + relativeToPeak : int + See the py4DSTEM.preprocess.get_maxima_2D docstring + minSpacing : number + See the py4DSTEM.preprocess.get_maxima_2D docstring + edgeBoundary : number + See the py4DSTEM.preprocess.get_maxima_2D docstring + maxNumPeaks : int + See the py4DSTEM.preprocess.get_maxima_2D docstring + figsize : 2-tuple + the size of the figure + c_indices : color + color of the maxima + c0 : color + color of the origin + c1 : color + color of g1 point + c2 : color + color of g2 point + c_vectors : color + color of the g1/g2 vectors + c_vectorlabels : color + color of the vector labels + size_indices : number + size of the indices + width_vectors : number + width of the vectors + size_vectorlabels : number + size of the vector labels + vis_params : dict + additional visualization parameters passed to `show` + returncalc : bool + toggles returning the answer + returnfig : bool + toggles returning the figure + + Returns + ------- + (optional) : None or (g0,g1,g2) or (fig,(ax1,ax2)) or both of the latter + """ + # validate inputs + for i in (index_origin, index_g1, index_g2): + assert isinstance(i, (int, np.integer)) or ( + i is None + ), "indices must be integers!" + # check the calstate + assert ( + self.calstate == self.braggvectors.calstate + ), "The calibration state has changed! To resync the calibration state, use `.reset_calstate`." + + # find the maxima + + g = get_maxima_2D( + self.bvm.data, + subpixel=subpixel, + upsample_factor=upsample_factor, + sigma=sigma, + minAbsoluteIntensity=minAbsoluteIntensity, + minRelativeIntensity=minRelativeIntensity, + relativeToPeak=relativeToPeak, + minSpacing=minSpacing, + edgeBoundary=edgeBoundary, + maxNumPeaks=maxNumPeaks, + ) + + # guess the origin and g1 g2 vectors if indices aren't provided + if np.any([x is None for x in (index_g1, index_g2, index_origin)]): + # get distances and angles from calibrated origin + g_dists = np.hypot(g["x"] - self.origin[0], g["y"] - self.origin[1]) + g_angles = np.angle( + g["x"] - self.origin[0] + 1j * (g["y"] - self.origin[1]) + ) + + # guess the origin + if index_origin is None: + index_origin = np.argmin(g_dists) + g_dists[index_origin] = 2 * np.max(g_dists) + + # guess g1 + if index_g1 is None: + index_g1 = np.argmin(g_dists) + g_dists[index_g1] = 2 * np.max(g_dists) + + # guess g2 + if index_g2 is None: + angle_scaling = np.cos(g_angles - g_angles[index_g1]) ** 2 + index_g2 = np.argmin(g_dists * (angle_scaling + 0.1)) + + # get the lattice vectors + gx, gy = g["x"], g["y"] + g0 = gx[index_origin], gy[index_origin] + g1x = gx[index_g1] - g0[0] + g1y = gy[index_g1] - g0[1] + g2x = gx[index_g2] - g0[0] + g2y = gy[index_g2] - g0[1] + g1, g2 = (g1x, g1y), (g2x, g2y) + + # index the lattice vectors + _, _, braggdirections = index_bragg_directions( + g0[0], g0[1], g["x"], g["y"], g1, g2 + ) + + # make the figure + fig, ax = plt.subplots(1, 3, figsize=figsize) + show(self.bvm.data, figax=(fig, ax[0]), **vis_params) + show(self.bvm.data, figax=(fig, ax[1]), **vis_params) + self.show_bragg_indexing( + self.bvm.data, + bragg_directions=braggdirections, + points=True, + figax=(fig, ax[2]), + size=size_indices, + **vis_params, + ) + + # Add indices to left panel + d = {"x": gx, "y": gy, "size": size_indices, "color": c_indices} + d0 = { + "x": gx[index_origin], + "y": gy[index_origin], + "size": size_indices, + "color": c0, + "fontweight": "bold", + "labels": [str(index_origin)], + } + d1 = { + "x": gx[index_g1], + "y": gy[index_g1], + "size": size_indices, + "color": c1, + "fontweight": "bold", + "labels": [str(index_g1)], + } + d2 = { + "x": gx[index_g2], + "y": gy[index_g2], + "size": size_indices, + "color": c2, + "fontweight": "bold", + "labels": [str(index_g2)], + } + add_pointlabels(ax[0], d) + add_pointlabels(ax[0], d0) + add_pointlabels(ax[0], d1) + add_pointlabels(ax[0], d2) + + # Add vectors to right panel + dg1 = { + "x0": gx[index_origin], + "y0": gy[index_origin], + "vx": g1[0], + "vy": g1[1], + "width": width_vectors, + "color": c_vectors, + "label": r"$g_1$", + "labelsize": size_vectorlabels, + "labelcolor": c_vectorlabels, + } + dg2 = { + "x0": gx[index_origin], + "y0": gy[index_origin], + "vx": g2[0], + "vy": g2[1], + "width": width_vectors, + "color": c_vectors, + "label": r"$g_2$", + "labelsize": size_vectorlabels, + "labelcolor": c_vectorlabels, + } + add_vector(ax[1], dg1) + add_vector(ax[1], dg2) + + # store vectors + self.g = g + self.g0 = g0 + self.g1 = g1 + self.g2 = g2 + + # center the bragg directions and store + braggdirections.data["qx"] -= self.origin[0] + braggdirections.data["qy"] -= self.origin[1] + self.braggdirections = braggdirections + + # return + if returncalc and returnfig: + return (self.g0, self.g1, self.g2, self.braggdirections), (fig, ax) + elif returncalc: + return (self.g0, self.g1, self.g2, self.braggdirections) + elif returnfig: + return (fig, ax) + else: + return + + def fit_lattice_vectors( + self, + max_peak_spacing=2, + mask=None, + returncalc=False, + ): + """ + From an origin (x0,y0), a set of reciprocal lattice vectors gx,gy, and an pair of + lattice vectors g1=(g1x,g1y), g2=(g2x,g2y), find the indices (h,k) of all the + reciprocal lattice directions. + + Args: + max_peak_spacing: float + Maximum distance from the ideal lattice points + to include a peak for indexing + mask: bool + Boolean mask, same shape as the pointlistarray, indicating which + locations should be indexed. This can be used to index different regions of + the scan with different lattices + returncalc : bool + if True, returns bragg_directions, bragg_vectors_indexed, g1g2_map + """ + # check the calstate + assert ( + self.calstate == self.braggvectors.calstate + ), "The calibration state has changed! To resync the calibration state, use `.reset_calstate`." + + ### add indices to the bragg vectors + + # validate mask + if mask is None: + mask = np.ones(self.braggvectors.Rshape, dtype=bool) + assert ( + mask.shape == self.braggvectors.Rshape + ), "mask must have same shape as pointlistarray" + assert mask.dtype == bool, "mask must be boolean" + + # set up new braggpeaks PLA + indexed_braggpeaks = PointListArray( + dtype=[ + ("qx", float), + ("qy", float), + ("intensity", float), + ("h", int), + ("k", int), + ], + shape=self.braggvectors.Rshape, + ) + calstate = self.braggvectors.calstate + + # loop over all the scan positions + for Rx, Ry in tqdmnd(mask.shape[0], mask.shape[1]): + if mask[Rx, Ry]: + pl = self.braggvectors.get_vectors( + Rx, + Ry, + center=True, + ellipse=calstate["ellipse"], + rotate=calstate["rotate"], + pixel=False, + ) + for i in range(pl.data.shape[0]): + r = np.hypot( + pl.data["qx"][i] - self.braggdirections.data["qx"], + pl.data["qy"][i] - self.braggdirections.data["qy"], + ) + ind = np.argmin(r) + if r[ind] <= max_peak_spacing: + indexed_braggpeaks[Rx, Ry].add_data_by_field( + ( + pl.data["qx"][i], + pl.data["qy"][i], + pl.data["intensity"][i], + self.braggdirections.data["h"][ind], + self.braggdirections.data["k"][ind], + ) + ) + self.bragg_vectors_indexed = indexed_braggpeaks + + ### fit bragg vectors + g1g2_map = fit_lattice_vectors_all_DPs(self.bragg_vectors_indexed) + self.g1g2_map = g1g2_map + + # return + if returncalc: + return self.bragg_vectors_indexed, self.g1g2_map + + def get_strain( + self, mask=None, g_reference=None, flip_theta=False, returncalc=False, **kwargs + ): + """ + mask: nd.array (bool) + Use lattice vectors from g1g2_map scan positions + wherever mask==True. If mask is None gets median strain + map from entire field of view. If mask is not None, gets + reference g1 and g2 from region and then calculates strain. + g_reference: nd.array of form [x,y] + G_reference (tupe): reference coordinate system for + xaxis_x and xaxis_y + flip_theta: bool + If True, flips rotation coordinate system + returncal: bool + It True, returns rotated map + """ + # check the calstate + assert ( + self.calstate == self.braggvectors.calstate + ), "The calibration state has changed! To resync the calibration state, use `.reset_calstate`." + + if mask is None: + mask = self.mask + # mask = np.ones(self.g1g2_map.shape, dtype="bool") + # strainmap_g1g2 = get_strain_from_reference_region( + # self.g1g2_map, + # mask=mask, + # ) + + # g1_ref, g2_ref = get_reference_g1g2(self.g1g2_map, mask) + # strain_map = get_strain_from_reference_g1g2(self.g1g2_map, g1_ref, g2_ref) + # else: + + g1_ref, g2_ref = get_reference_g1g2(self.g1g2_map, mask) + + strainmap_g1g2 = get_strain_from_reference_g1g2(self.g1g2_map, g1_ref, g2_ref) + + self.strainmap_g1g2 = strainmap_g1g2 + + if g_reference is None: + g_reference = np.subtract(self.g1, self.g2) + + strainmap_rotated = get_rotated_strain_map( + self.strainmap_g1g2, + xaxis_x=g_reference[0], + xaxis_y=g_reference[1], + flip_theta=flip_theta, + ) + + self.data[0] = strainmap_rotated["e_xx"].data + self.data[1] = strainmap_rotated["e_yy"].data + self.data[2] = strainmap_rotated["e_xy"].data + self.data[3] = strainmap_rotated["theta"].data + self.data[4] = strainmap_rotated["mask"].data + self.g_reference = g_reference + + figsize = kwargs.pop("figsize", (14, 4)) + vrange_exx = kwargs.pop("vrange_exx", [-2.0, 2.0]) + vrange_theta = kwargs.pop("vrange_theta", [-2.0, 2.0]) + ticknumber = kwargs.pop("ticknumber", 3) + bkgrd = kwargs.pop("bkgrd", False) + axes_plots = kwargs.pop("axes_plots", ()) + + fig, ax = self.show_strain( + vrange_exx=vrange_exx, + vrange_theta=vrange_theta, + ticknumber=ticknumber, + axes_plots=axes_plots, + bkgrd=bkgrd, + figsize=figsize, + **kwargs, + returnfig=True, + ) + + if not np.all(mask == True): + ax[0][0].imshow(mask, alpha=0.2, cmap="binary") + ax[0][1].imshow(mask, alpha=0.2, cmap="binary") + ax[1][0].imshow(mask, alpha=0.2, cmap="binary") + ax[1][1].imshow(mask, alpha=0.2, cmap="binary") + + if returncalc: + return self.strainmap + + def show_strain( + self, + vrange_exx, + vrange_theta, + vrange_exy=None, + vrange_eyy=None, + flip_theta=False, + bkgrd=True, + show_cbars=("exx", "eyy", "exy", "theta"), + bordercolor="k", + borderwidth=1, + titlesize=24, + ticklabelsize=16, + ticknumber=5, + unitlabelsize=24, + show_axes=False, + axes_position=(0, 0), + axes_length=10, + axes_width=1, + axes_color="w", + xaxis_space="Q", + labelaxes=True, + QR_rotation=0, + axes_labelsize=12, + axes_labelcolor="r", + axes_plots=("exx"), + cmap="RdBu_r", + mask_color="k", + layout=0, + figsize=(12, 12), + returnfig=False, + ): + """ + Display a strain map, showing the 4 strain components (e_xx,e_yy,e_xy,theta), and + masking each image with strainmap.get_slice('mask') + + Args: + vrange_exx (length 2 list or tuple): + vrange_theta (length 2 list or tuple): + vrange_exy (length 2 list or tuple): + vrange_eyy (length 2 list or tuple): + flip_theta (bool): if True, take negative of angle + bkgrd (bool): + show_cbars (tuple of strings): Show colorbars for the specified axes. Must be a + tuple containing any, all, or none of ('exx','eyy','exy','theta'). + bordercolor (color): + borderwidth (number): + titlesize (number): + ticklabelsize (number): + ticknumber (number): number of ticks on colorbars + unitlabelsize (number): + show_axes (bool): + axes_x0 (number): + axes_y0 (number): + xaxis_x (number): + xaxis_y (number): + axes_length (number): + axes_width (number): + axes_color (color): + xaxis_space (string): must be 'Q' or 'R' + labelaxes (bool): + QR_rotation (number): + axes_labelsize (number): + axes_labelcolor (color): + axes_plots (tuple of strings): controls if coordinate axes showing the + orientation of the strain matrices are overlaid over any of the plots. + Must be a tuple of strings containing any, all, or none of + ('exx','eyy','exy','theta'). + cmap (colormap): + layout=0 (int): determines the layout of the grid which the strain components + will be plotted in. Must be in (0,1,2). 0=(2x2), 1=(1x4), 2=(4x1). + figsize (length 2 tuple of numbers): + returnfig (bool): + """ + # Lookup table for different layouts + assert layout in (0, 1, 2) + layout_lookup = { + 0: ["left", "right", "left", "right"], + 1: ["bottom", "bottom", "bottom", "bottom"], + 2: ["right", "right", "right", "right"], + } + layout_p = layout_lookup[layout] + + # Contrast limits + if vrange_exy is None: + vrange_exy = vrange_exx + if vrange_eyy is None: + vrange_eyy = vrange_exx + for vrange in (vrange_exx, vrange_eyy, vrange_exy, vrange_theta): + assert len(vrange) == 2, "vranges must have length 2" + vmin_exx, vmax_exx = vrange_exx[0] / 100.0, vrange_exx[1] / 100.0 + vmin_eyy, vmax_eyy = vrange_eyy[0] / 100.0, vrange_eyy[1] / 100.0 + vmin_exy, vmax_exy = vrange_exy[0] / 100.0, vrange_exy[1] / 100.0 + # theta is plotted in units of degrees + vmin_theta, vmax_theta = vrange_theta[0] / (180.0 / np.pi), vrange_theta[1] / ( + 180.0 / np.pi + ) + + # Get images + e_xx = np.ma.array( + self.get_slice("exx").data, mask=self.get_slice("mask").data == False + ) + e_yy = np.ma.array( + self.get_slice("eyy").data, mask=self.get_slice("mask").data == False + ) + e_xy = np.ma.array( + self.get_slice("exy").data, mask=self.get_slice("mask").data == False + ) + theta = np.ma.array( + self.get_slice("theta").data, + mask=self.get_slice("mask").data == False, + ) + if flip_theta == True: + theta = -theta + + ## Plot + + # modify the figsize according to the image aspect ratio + ratio = np.sqrt(self.rshape[1] / self.rshape[0]) + figsize_mean = np.mean(figsize) + figsize = (figsize_mean * ratio, figsize_mean / ratio) + + # set up layout + if layout == 0: + fig, ((ax11, ax12), (ax21, ax22)) = plt.subplots(2, 2, figsize=figsize) + elif layout == 1: + figsize = (figsize[0] * np.sqrt(2), figsize[1] / np.sqrt(2)) + fig, (ax11, ax12, ax21, ax22) = plt.subplots(1, 4, figsize=figsize) + else: + figsize = (figsize[0] / np.sqrt(2), figsize[1] * np.sqrt(2)) + fig, (ax11, ax12, ax21, ax22) = plt.subplots(4, 1, figsize=figsize) + + # display images, returning cbar axis references + cax11 = show( + e_xx, + figax=(fig, ax11), + vmin=vmin_exx, + vmax=vmax_exx, + intensity_range="absolute", + cmap=cmap, + mask=self.mask, + mask_color=mask_color, + returncax=True, + ) + cax12 = show( + e_yy, + figax=(fig, ax12), + vmin=vmin_eyy, + vmax=vmax_eyy, + intensity_range="absolute", + cmap=cmap, + mask=self.mask, + mask_color=mask_color, + returncax=True, + ) + cax21 = show( + e_xy, + figax=(fig, ax21), + vmin=vmin_exy, + vmax=vmax_exy, + intensity_range="absolute", + cmap=cmap, + mask=self.mask, + mask_color=mask_color, + returncax=True, + ) + cax22 = show( + theta, + figax=(fig, ax22), + vmin=vmin_theta, + vmax=vmax_theta, + intensity_range="absolute", + cmap=cmap, + mask=self.mask, + mask_color=mask_color, + returncax=True, + ) + ax11.set_title(r"$\epsilon_{xx}$", size=titlesize) + ax12.set_title(r"$\epsilon_{yy}$", size=titlesize) + ax21.set_title(r"$\epsilon_{xy}$", size=titlesize) + ax22.set_title(r"$\theta$", size=titlesize) + + # Add black background + if bkgrd: + mask = np.ma.masked_where( + self.get_slice("mask").data.astype(bool), + np.zeros_like(self.get_slice("mask").data), + ) + ax11.matshow(mask, cmap="gray") + ax12.matshow(mask, cmap="gray") + ax21.matshow(mask, cmap="gray") + ax22.matshow(mask, cmap="gray") + + # add colorbars + show_cbars = np.array( + [ + "exx" in show_cbars, + "eyy" in show_cbars, + "exy" in show_cbars, + "theta" in show_cbars, + ] + ) + if np.any(show_cbars): + divider11 = make_axes_locatable(ax11) + divider12 = make_axes_locatable(ax12) + divider21 = make_axes_locatable(ax21) + divider22 = make_axes_locatable(ax22) + cbax11 = divider11.append_axes(layout_p[0], size="4%", pad=0.15) + cbax12 = divider12.append_axes(layout_p[1], size="4%", pad=0.15) + cbax21 = divider21.append_axes(layout_p[2], size="4%", pad=0.15) + cbax22 = divider22.append_axes(layout_p[3], size="4%", pad=0.15) + for ind, show_cbar, cax, cbax, vmin, vmax, tickside, tickunits in zip( + range(4), + show_cbars, + (cax11, cax12, cax21, cax22), + (cbax11, cbax12, cbax21, cbax22), + (vmin_exx, vmin_eyy, vmin_exy, vmin_theta), + (vmax_exx, vmax_eyy, vmax_exy, vmax_theta), + (layout_p[0], layout_p[1], layout_p[2], layout_p[3]), + ("% ", " %", "% ", r" $^\circ$"), + ): + if show_cbar: + ticks = np.linspace(vmin, vmax, ticknumber, endpoint=True) + if ind < 3: + ticklabels = np.round( + np.linspace( + 100 * vmin, 100 * vmax, ticknumber, endpoint=True + ), + decimals=2, + ).astype(str) + else: + ticklabels = np.round( + np.linspace( + (180 / np.pi) * vmin, + (180 / np.pi) * vmax, + ticknumber, + endpoint=True, + ), + decimals=2, + ).astype(str) + + if tickside in ("left", "right"): + cb = plt.colorbar( + cax, cax=cbax, ticks=ticks, orientation="vertical" + ) + cb.ax.set_yticklabels(ticklabels, size=ticklabelsize) + cbax.yaxis.set_ticks_position(tickside) + cbax.set_ylabel(tickunits, size=unitlabelsize, rotation=0) + cbax.yaxis.set_label_position(tickside) + else: + cb = plt.colorbar( + cax, cax=cbax, ticks=ticks, orientation="horizontal" + ) + cb.ax.set_xticklabels(ticklabels, size=ticklabelsize) + cbax.xaxis.set_ticks_position(tickside) + cbax.set_xlabel(tickunits, size=unitlabelsize, rotation=0) + cbax.xaxis.set_label_position(tickside) + else: + cbax.axis("off") + + # Add coordinate axes + if show_axes: + assert xaxis_space in ("R", "Q"), "xaxis_space must be 'R' or 'Q'" + show_which_axes = np.array( + [ + "exx" in axes_plots, + "eyy" in axes_plots, + "exy" in axes_plots, + "theta" in axes_plots, + ] + ) + for _show, _ax in zip(show_which_axes, (ax11, ax12, ax21, ax22)): + if _show: + if xaxis_space == "R": + ax_addaxes( + _ax, + self.g_reference[0], + self.g_reference[1], + axes_length, + axes_position[0], + axes_position[1], + width=axes_width, + color=axes_color, + labelaxes=labelaxes, + labelsize=axes_labelsize, + labelcolor=axes_labelcolor, + ) + else: + ax_addaxes_QtoR( + _ax, + self.g_reference[0], + self.g_reference[1], + axes_length, + axes_position[0], + axes_position[1], + QR_rotation, + width=axes_width, + color=axes_color, + labelaxes=labelaxes, + labelsize=axes_labelsize, + labelcolor=axes_labelcolor, + ) + + # Add borders + if bordercolor is not None: + for ax in (ax11, ax12, ax21, ax22): + for s in ["bottom", "top", "left", "right"]: + ax.spines[s].set_color(bordercolor) + ax.spines[s].set_linewidth(borderwidth) + ax.set_xticks([]) + ax.set_yticks([]) + + if not returnfig: + plt.show() + return + else: + axs = ((ax11, ax12), (ax21, ax22)) + return fig, axs + + def show_lattice_vectors( + ar, + x0, + y0, + g1, + g2, + color="r", + width=1, + labelsize=20, + labelcolor="w", + returnfig=False, + **kwargs, + ): + """Adds the vectors g1,g2 to an image, with tail positions at (x0,y0). g1 and g2 are 2-tuples (gx,gy).""" + fig, ax = show(ar, returnfig=True, **kwargs) + + # Add vectors + dg1 = { + "x0": x0, + "y0": y0, + "vx": g1[0], + "vy": g1[1], + "width": width, + "color": color, + "label": r"$g_1$", + "labelsize": labelsize, + "labelcolor": labelcolor, + } + dg2 = { + "x0": x0, + "y0": y0, + "vx": g2[0], + "vy": g2[1], + "width": width, + "color": color, + "label": r"$g_2$", + "labelsize": labelsize, + "labelcolor": labelcolor, + } + add_vector(ax, dg1) + add_vector(ax, dg2) + + if returnfig: + return fig, ax + else: + plt.show() + return + + def show_bragg_indexing( + self, + ar, + bragg_directions, + voffset=5, + hoffset=0, + color="w", + size=20, + points=True, + pointcolor="r", + pointsize=50, + figax=None, + returnfig=False, + **kwargs, + ): + """ + Shows an array with an overlay describing the Bragg directions + + Accepts: + ar (arrray) the image + bragg_directions (PointList) the bragg scattering directions; must have coordinates + 'qx','qy','h', and 'k'. Optionally may also have 'l'. + """ + assert isinstance(bragg_directions, PointList) + for k in ("qx", "qy", "h", "k"): + assert k in bragg_directions.data.dtype.fields + + if figax is None: + fig, ax = show(ar, returnfig=True, **kwargs) + else: + fig = figax[0] + ax = figax[1] + show(ar, figax=figax, **kwargs) + + d = { + "bragg_directions": bragg_directions, + "voffset": voffset, + "hoffset": hoffset, + "color": color, + "size": size, + "points": points, + "pointsize": pointsize, + "pointcolor": pointcolor, + } + add_bragg_index_labels(ax, d) + + if returnfig: + return fig, ax + else: + return + + def copy(self, name=None): + name = name if name is not None else self.name + "_copy" + strainmap_copy = StrainMap(self.braggvectors) + for attr in ( + "g", + "g0", + "g1", + "g2", + "calstate", + "bragg_directions", + "bragg_vectors_indexed", + "g1g2_map", + "strainmap_g1g2", + "strainmap_rotated", + ): + if hasattr(self, attr): + setattr(strainmap_copy, attr, getattr(self, attr)) + + for k in self.metadata.keys(): + strainmap_copy.metadata = self.metadata[k].copy() + return strainmap_copy + + # TODO IO methods + + # read + @classmethod + def _get_constructor_args(cls, group): + """ + Returns a dictionary of args/values to pass to the class constructor + """ + ar_constr_args = RealSlice._get_constructor_args(group) + args = { + "data": ar_constr_args["data"], + "name": ar_constr_args["name"], + } + return args diff --git a/py4DSTEM/visualize/vis_special.py b/py4DSTEM/visualize/vis_special.py index cfa017299..d1efbd023 100644 --- a/py4DSTEM/visualize/vis_special.py +++ b/py4DSTEM/visualize/vis_special.py @@ -156,16 +156,16 @@ def show_amorphous_ring_fit( mask=np.logical_not(mask), mask_color="empty", returnfig=True, - returnclipvals=True, + return_intensity_range=True, **kwargs, ) show( fit, scaling=scaling, figax=(fig, ax), - clipvals="manual", - min=vmin, - max=vmax, + intensity_range="absolute", + vmin=vmin, + vmax=vmax, cmap=cmap_fit, mask=mask, mask_color="empty", @@ -404,308 +404,6 @@ def show_class_BPs_grid( return fig, axs -def show_strain( - strainmap, - vrange_exx, - vrange_theta, - vrange_exy=None, - vrange_eyy=None, - flip_theta=False, - bkgrd=True, - show_cbars=("exx", "eyy", "exy", "theta"), - bordercolor="k", - borderwidth=1, - titlesize=24, - ticklabelsize=16, - ticknumber=5, - unitlabelsize=24, - show_axes=True, - axes_x0=0, - axes_y0=0, - xaxis_x=1, - xaxis_y=0, - axes_length=10, - axes_width=1, - axes_color="r", - xaxis_space="Q", - labelaxes=True, - QR_rotation=0, - axes_labelsize=12, - axes_labelcolor="r", - axes_plots=("exx"), - cmap="RdBu_r", - layout=0, - figsize=(12, 12), - returnfig=False, -): - """ - Display a strain map, showing the 4 strain components (e_xx,e_yy,e_xy,theta), and - masking each image with strainmap.get_slice('mask') - - Args: - strainmap (RealSlice): - vrange_exx (length 2 list or tuple): - vrange_theta (length 2 list or tuple): - vrange_exy (length 2 list or tuple): - vrange_eyy (length 2 list or tuple): - flip_theta (bool): if True, take negative of angle - bkgrd (bool): - show_cbars (tuple of strings): Show colorbars for the specified axes. Must be a - tuple containing any, all, or none of ('exx','eyy','exy','theta'). - bordercolor (color): - borderwidth (number): - titlesize (number): - ticklabelsize (number): - ticknumber (number): number of ticks on colorbars - unitlabelsize (number): - show_axes (bool): - axes_x0 (number): - axes_y0 (number): - xaxis_x (number): - xaxis_y (number): - axes_length (number): - axes_width (number): - axes_color (color): - xaxis_space (string): must be 'Q' or 'R' - labelaxes (bool): - QR_rotation (number): - axes_labelsize (number): - axes_labelcolor (color): - axes_plots (tuple of strings): controls if coordinate axes showing the - orientation of the strain matrices are overlaid over any of the plots. - Must be a tuple of strings containing any, all, or none of - ('exx','eyy','exy','theta'). - cmap (colormap): - layout=0 (int): determines the layout of the grid which the strain components - will be plotted in. Must be in (0,1,2). 0=(2x2), 1=(1x4), 2=(4x1). - figsize (length 2 tuple of numbers): - returnfig (bool): - """ - # Lookup table for different layouts - assert layout in (0, 1, 2) - layout_lookup = { - 0: ["left", "right", "left", "right"], - 1: ["bottom", "bottom", "bottom", "bottom"], - 2: ["right", "right", "right", "right"], - } - layout_p = layout_lookup[layout] - - # Contrast limits - if vrange_exy is None: - vrange_exy = vrange_exx - if vrange_eyy is None: - vrange_eyy = vrange_exx - for vrange in (vrange_exx, vrange_eyy, vrange_exy, vrange_theta): - assert len(vrange) == 2, "vranges must have length 2" - vmin_exx, vmax_exx = vrange_exx[0] / 100.0, vrange_exx[1] / 100.0 - vmin_eyy, vmax_eyy = vrange_eyy[0] / 100.0, vrange_eyy[1] / 100.0 - vmin_exy, vmax_exy = vrange_exy[0] / 100.0, vrange_exy[1] / 100.0 - # theta is plotted in units of degrees - vmin_theta, vmax_theta = vrange_theta[0] / (180.0 / np.pi), vrange_theta[1] / ( - 180.0 / np.pi - ) - - # Get images - e_xx = np.ma.array( - strainmap.get_slice("e_xx").data, mask=strainmap.get_slice("mask").data == False - ) - e_yy = np.ma.array( - strainmap.get_slice("e_yy").data, mask=strainmap.get_slice("mask").data == False - ) - e_xy = np.ma.array( - strainmap.get_slice("e_xy").data, mask=strainmap.get_slice("mask").data == False - ) - theta = np.ma.array( - strainmap.get_slice("theta").data, - mask=strainmap.get_slice("mask").data == False, - ) - if flip_theta == True: - theta = -theta - - # Plot - if layout == 0: - fig, ((ax11, ax12), (ax21, ax22)) = plt.subplots(2, 2, figsize=figsize) - elif layout == 1: - fig, (ax11, ax12, ax21, ax22) = plt.subplots(1, 4, figsize=figsize) - else: - fig, (ax11, ax12, ax21, ax22) = plt.subplots(4, 1, figsize=figsize) - cax11 = show( - e_xx, - figax=(fig, ax11), - vmin=vmin_exx, - vmax=vmax_exx, - intensity_range="absolute", - cmap=cmap, - returncax=True, - ) - cax12 = show( - e_yy, - figax=(fig, ax12), - vmin=vmin_eyy, - vmax=vmax_eyy, - intensity_range="absolute", - cmap=cmap, - returncax=True, - ) - cax21 = show( - e_xy, - figax=(fig, ax21), - vmin=vmin_exy, - vmax=vmax_exy, - intensity_range="absolute", - cmap=cmap, - returncax=True, - ) - cax22 = show( - theta, - figax=(fig, ax22), - vmin=vmin_theta, - vmax=vmax_theta, - intensity_range="absolute", - cmap=cmap, - returncax=True, - ) - ax11.set_title(r"$\epsilon_{xx}$", size=titlesize) - ax12.set_title(r"$\epsilon_{yy}$", size=titlesize) - ax21.set_title(r"$\epsilon_{xy}$", size=titlesize) - ax22.set_title(r"$\theta$", size=titlesize) - - # Add black background - if bkgrd: - mask = np.ma.masked_where( - strainmap.get_slice("mask").data.astype(bool), - np.zeros_like(strainmap.get_slice("mask").data), - ) - ax11.matshow(mask, cmap="gray") - ax12.matshow(mask, cmap="gray") - ax21.matshow(mask, cmap="gray") - ax22.matshow(mask, cmap="gray") - - # Colorbars - show_cbars = np.array( - [ - "exx" in show_cbars, - "eyy" in show_cbars, - "exy" in show_cbars, - "theta" in show_cbars, - ] - ) - if np.any(show_cbars): - divider11 = make_axes_locatable(ax11) - divider12 = make_axes_locatable(ax12) - divider21 = make_axes_locatable(ax21) - divider22 = make_axes_locatable(ax22) - cbax11 = divider11.append_axes(layout_p[0], size="4%", pad=0.15) - cbax12 = divider12.append_axes(layout_p[1], size="4%", pad=0.15) - cbax21 = divider21.append_axes(layout_p[2], size="4%", pad=0.15) - cbax22 = divider22.append_axes(layout_p[3], size="4%", pad=0.15) - for ind, show_cbar, cax, cbax, vmin, vmax, tickside, tickunits in zip( - range(4), - show_cbars, - (cax11, cax12, cax21, cax22), - (cbax11, cbax12, cbax21, cbax22), - (vmin_exx, vmin_eyy, vmin_exy, vmin_theta), - (vmax_exx, vmax_eyy, vmax_exy, vmax_theta), - (layout_p[0], layout_p[1], layout_p[2], layout_p[3]), - ("% ", " %", "% ", r" $^\circ$"), - ): - if show_cbar: - ticks = np.linspace(vmin, vmax, ticknumber, endpoint=True) - if ind < 3: - ticklabels = np.round( - np.linspace(100 * vmin, 100 * vmax, ticknumber, endpoint=True), - decimals=2, - ).astype(str) - else: - ticklabels = np.round( - np.linspace( - (180 / np.pi) * vmin, - (180 / np.pi) * vmax, - ticknumber, - endpoint=True, - ), - decimals=2, - ).astype(str) - - if tickside in ("left", "right"): - cb = plt.colorbar( - cax, cax=cbax, ticks=ticks, orientation="vertical" - ) - cb.ax.set_yticklabels(ticklabels, size=ticklabelsize) - cbax.yaxis.set_ticks_position(tickside) - cbax.set_ylabel(tickunits, size=unitlabelsize, rotation=0) - cbax.yaxis.set_label_position(tickside) - else: - cb = plt.colorbar( - cax, cax=cbax, ticks=ticks, orientation="horizontal" - ) - cb.ax.set_xticklabels(ticklabels, size=ticklabelsize) - cbax.xaxis.set_ticks_position(tickside) - cbax.set_xlabel(tickunits, size=unitlabelsize, rotation=0) - cbax.xaxis.set_label_position(tickside) - else: - cbax.axis("off") - - # Add coordinate axes - if show_axes: - assert xaxis_space in ("R", "Q"), "xaxis_space must be 'R' or 'Q'" - show_which_axes = np.array( - [ - "exx" in axes_plots, - "eyy" in axes_plots, - "exy" in axes_plots, - "theta" in axes_plots, - ] - ) - for _show, _ax in zip(show_which_axes, (ax11, ax12, ax21, ax22)): - if _show: - if xaxis_space == "R": - ax_addaxes( - _ax, - xaxis_x, - xaxis_y, - axes_length, - axes_x0, - axes_y0, - width=axes_width, - color=axes_color, - labelaxes=labelaxes, - labelsize=axes_labelsize, - labelcolor=axes_labelcolor, - ) - else: - ax_addaxes_QtoR( - _ax, - xaxis_x, - xaxis_y, - axes_length, - axes_x0, - axes_y0, - QR_rotation, - width=axes_width, - color=axes_color, - labelaxes=labelaxes, - labelsize=axes_labelsize, - labelcolor=axes_labelcolor, - ) - - # Add borders - if bordercolor is not None: - for ax in (ax11, ax12, ax21, ax22): - for s in ["bottom", "top", "left", "right"]: - ax.spines[s].set_color(bordercolor) - ax.spines[s].set_linewidth(borderwidth) - ax.set_xticks([]) - ax.set_yticks([]) - - if not returnfig: - plt.show() - return - else: - axs = ((ax11, ax12), (ax21, ax22)) - return fig, axs - - def show_pointlabels( ar, x, y, color="lightblue", size=20, alpha=1, returnfig=False, **kwargs ):