diff --git a/.github/scripts/update_version.py b/.github/scripts/update_version.py index 2aaaa07af..635cf8268 100644 --- a/.github/scripts/update_version.py +++ b/.github/scripts/update_version.py @@ -8,7 +8,7 @@ lines = f.readlines() line_split = lines[0].split(".") -patch_number = line_split[2].split("'")[0] +patch_number = line_split[2].split("'")[0].split('"')[0] # Increment patch number patch_number = str(int(patch_number) + 1) + "'" diff --git a/.github/workflows/check_install_dev.yml b/.github/workflows/check_install_dev.yml index 6e44a6334..82701d50d 100644 --- a/.github/workflows/check_install_dev.yml +++ b/.github/workflows/check_install_dev.yml @@ -17,10 +17,10 @@ jobs: runs-on: [ubuntu-latest] architecture: [x86_64] python-version: ["3.9", "3.10", "3.11",] - include: - - python-version: "3.12.0-beta.4" - runs-on: ubuntu-latest - allow_failure: true + # include: + # - python-version: "3.12.0-beta.4" + # runs-on: ubuntu-latest + # allow_failure: true # Currently no public runners available for this but this or arm64 should work next time # include: # - python-version: "3.10" diff --git a/README.md b/README.md index 0561f098a..aa102542a 100644 --- a/README.md +++ b/README.md @@ -46,42 +46,50 @@ First, download and install Anaconda: www.anaconda.com/download. If you prefer a more lightweight conda client, you can instead install Miniconda: https://docs.conda.io/en/latest/miniconda.html. Then open a conda terminal and run one of the following sets of commands to ensure everything is up-to-date and create a new environment for your py4DSTEM installation: - ``` conda update conda conda create -n py4dstem conda activate py4dstem +conda install -c conda-forge py4dstem pymatgen jupyterlab ``` -Next, install py4DSTEM. To simultaneously install py4DSTEM with `pymatgen` (used in some crystal structure workflows) and `jupyterlab` (providing an interface for running Python notebooks like those provided in the [py4DSTEM tutorials repository](https://github.com/py4dstem/py4DSTEM_tutorials)) run: +In order, these commands +- ensure your installation of anaconda is up-to-date +- make a virtual environment (see below) +- enter the environment +- install py4DSTEM, as well as pymatgen (used for crystal structure calculations) and JupyterLab (an interface for running Python notebooks like those in the [py4DSTEM tutorials repository](https://github.com/py4dstem/py4DSTEM_tutorials)) + + +We've had some recent reports install of `conda` getting stuck trying to solve the environment using the above installation. If you run into this problem, you can install py4DSTEM using `pip` instead of `conda` by running: ``` -conda install -c conda-forge py4dstem pymatgen jupyterlab +conda update conda +conda create -n py4dstem python=3.10 +conda activate py4dstem +pip install py4dstem pymatgen ``` -Or if you would prefer to install only the base modules of **py4DSTEM**, you can instead run: +Both `conda` and `pip` are programs which manage package installations, i.e. make sure different codes you're installing which depend on one another are using mutually compatible versions. Each has advantages and disadvantages; `pip` is a little more bare-bones, and we've seen this install work when `conda` doesn't. If you also want to use Jupyterlab you can then use either `pip install jupyterlab` or `conda install jupyterlab`. + +If you would prefer to install only the base modules of **py4DSTEM**, and skip pymategen and Jupterlab, you can instead run: ``` conda install -c conda-forge py4dstem ``` -In Windows you should then also run: +Finally, regardless of which of the above approaches you used, in Windows you should then also run: ``` conda install pywin32 ``` -In order, these commands -- ensure your installation of anaconda is up-to-date -- make a virtual environment (see below) -- enter the environment -- install py4DSTEM, and optionally also pymatgen and JupyterLab -- on Windows, enable python to talk to the windows API +which enables Python to talk to the Windows API. Please note that virtual environments are used in the instructions above in order to make sure packages that have different dependencies don't conflict with one another. Because these directions install py4DSTEM to its own virtual environment, each time you want to use py4DSTEM you'll need to activate this environment. You can do this in the command line by running `conda activate py4dstem`, or, if you're using the Anaconda Navigator, by clicking on the Environments tab and then clicking on `py4dstem`. +Last - as of the version 0.14.4 update, we've had a few reports of problems upgrading to the newest version. We're not sure what's causing the issue yet, but have found the new version can be installed successfully in these cases using a fresh Anaconda installation. diff --git a/docs/requirements.txt b/docs/requirements.txt index 43dbc0817..03ecc7e26 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,2 +1,3 @@ emdfile -# py4dstem \ No newline at end of file +sphinx_rtd_theme +# py4dstem diff --git a/docs/source/conf.py b/docs/source/conf.py index 30ee084fe..6da66611e 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -36,7 +36,12 @@ # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. -extensions = ["sphinx.ext.autodoc", "sphinx.ext.napoleon", "sphinx.ext.intersphinx"] +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.napoleon", + "sphinx.ext.intersphinx", + "sphinx_rtd_theme", +] # Other useful extensions # sphinx_copybutton diff --git a/py4DSTEM/braggvectors/diskdetection_aiml_cuda.py b/py4DSTEM/braggvectors/diskdetection_aiml_cuda.py index d0f550dcc..c5f89b9fd 100644 --- a/py4DSTEM/braggvectors/diskdetection_aiml_cuda.py +++ b/py4DSTEM/braggvectors/diskdetection_aiml_cuda.py @@ -17,8 +17,8 @@ try: import cupy as cp -except: - raise ImportError("Import Error: Please install cupy before proceeding") +except ModuleNotFoundError: + raise ImportError("AIML CUDA Requires cupy") try: import tensorflow as tf diff --git a/py4DSTEM/preprocess/utils.py b/py4DSTEM/preprocess/utils.py index 0c76f35a7..752e2f81c 100644 --- a/py4DSTEM/preprocess/utils.py +++ b/py4DSTEM/preprocess/utils.py @@ -5,8 +5,8 @@ try: import cupy as cp -except ImportError: - cp = None +except ModuleNotFoundError: + cp = np def bin2D(array, factor, dtype=np.float64): diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index fa735438b..b508d589e 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -2,23 +2,15 @@ import numpy as np import matplotlib.pyplot as plt +from matplotlib.patches import Circle from fractions import Fraction from typing import Union, Optional -from scipy.optimize import curve_fit import sys -from emdfile import tqdmnd, PointList, PointListArray +from emdfile import PointList from py4DSTEM.process.utils import single_atom_scatter, electron_wavelength_angstrom -from py4DSTEM.process.diffraction.crystal_viz import plot_diffraction_pattern -from py4DSTEM.process.diffraction.crystal_viz import plot_ring_pattern -from py4DSTEM.process.diffraction.utils import Orientation, calc_1D_profile - -try: - from pymatgen.symmetry.analyzer import SpacegroupAnalyzer - from pymatgen.core.structure import Structure -except ImportError: - pass +from py4DSTEM.process.diffraction.utils import Orientation class Crystal: @@ -36,6 +28,8 @@ class Crystal: orientation_plan, match_orientations, match_single_pattern, + cluster_grains, + cluster_orientation_map, calculate_strain, save_ang_file, symmetry_reduce_directions, @@ -51,6 +45,8 @@ class Crystal: plot_orientation_plan, plot_orientation_maps, plot_fiber_orientation_maps, + plot_clusters, + plot_cluster_size, ) from py4DSTEM.process.diffraction.crystal_calibrate import ( @@ -1074,3 +1070,517 @@ def calculate_bragg_peak_histogram( int_exp = (int_exp**bragg_intensity_power) * (k**bragg_k_power) int_exp /= np.max(int_exp) return k, int_exp + + +def generate_moire_diffraction_pattern( + bragg_peaks_0, + bragg_peaks_1, + thresh_0=0.0002, + thresh_1=0.0002, + exx_1=0.0, + eyy_1=0.0, + exy_1=0.0, + phi_1=0.0, + power=2.0, +): + """ + Calculate a Moire lattice from 2 parent diffraction patterns. The second lattice can be rotated + and strained with respect to the original lattice. Note that this strain is applied in real space, + and so the inverse of the calculated infinitestimal strain tensor is applied. + + Parameters + -------- + bragg_peaks_0: BraggVector + Bragg vectors for parent lattice 0. + bragg_peaks_1: BraggVector + Bragg vectors for parent lattice 1. + thresh_0: float + Intensity threshold for structure factors from lattice 0. + thresh_1: float + Intensity threshold for structure factors from lattice 1. + exx_1: float + Strain of lattice 1 in x direction (vertical) in real space. + eyy_1: float + Strain of lattice 1 in y direction (horizontal) in real space. + exy_1: float + Shear strain of lattice 1 in (x,y) direction (diagonal) in real space. + phi_1: float + Rotation of lattice 1 in real space. + power: float + Plotting power law (default is amplitude**2.0, i.e. intensity). + + Returns + -------- + parent_peaks_0, parent_peaks_1, moire_peaks: BraggVectors + Bragg vectors for the rotated & strained parent lattices + and the moire lattice + + """ + + # get intenties of all peaks + int0 = bragg_peaks_0["intensity"] ** (power / 2.0) + int1 = bragg_peaks_1["intensity"] ** (power / 2.0) + + # peaks above threshold + sub0 = int0 >= thresh_0 + sub1 = int1 >= thresh_1 + + # Remove origin (assuming brightest peak) + ind0_or = np.argmax(bragg_peaks_0["intensity"]) + ind1_or = np.argmax(bragg_peaks_1["intensity"]) + sub0[ind0_or] = False + sub1[ind1_or] = False + int0_sub = int0[sub0] + int1_sub = int1[sub1] + + # Get peaks + qx0 = bragg_peaks_0["qx"][sub0] + qy0 = bragg_peaks_0["qy"][sub0] + qx1_init = bragg_peaks_1["qx"][sub1] + qy1_init = bragg_peaks_1["qy"][sub1] + + # peak labels + h0 = bragg_peaks_0["h"][sub0] + k0 = bragg_peaks_0["k"][sub0] + l0 = bragg_peaks_0["l"][sub0] + h1 = bragg_peaks_1["h"][sub1] + k1 = bragg_peaks_1["k"][sub1] + l1 = bragg_peaks_1["l"][sub1] + + # apply strain tensor to lattice 1 + m = np.array( + [ + [np.cos(phi_1), -np.sin(phi_1)], + [np.sin(phi_1), np.cos(phi_1)], + ] + ) @ np.linalg.inv( + np.array( + [ + [1 + exx_1, exy_1 * 0.5], + [exy_1 * 0.5, 1 + eyy_1], + ] + ) + ) + qx1 = m[0, 0] * qx1_init + m[0, 1] * qy1_init + qy1 = m[1, 0] * qx1_init + m[1, 1] * qy1_init + + # Generate moire lattice + ind0, ind1 = np.meshgrid( + np.arange(np.sum(sub0)), + np.arange(np.sum(sub1)), + indexing="ij", + ) + qx = qx0[ind0] + qx1[ind1] + qy = qy0[ind0] + qy1[ind1] + int_moire = (int0_sub[ind0] * int1_sub[ind1]) ** 0.5 + + # moire labels + m_h0 = h0[ind0] + m_k0 = k0[ind0] + m_l0 = l0[ind0] + m_h1 = h1[ind1] + m_k1 = k1[ind1] + m_l1 = l1[ind1] + + # Convert thresholded and moire peaks to BraggVector class + + pl_dtype_parent = np.dtype( + [ + ("qx", "float"), + ("qy", "float"), + ("intensity", "float"), + ("h", "int"), + ("k", "int"), + ("l", "int"), + ] + ) + + bragg_parent_0 = PointList(np.array([], dtype=pl_dtype_parent)) + bragg_parent_0.add_data_by_field( + [ + qx0.ravel(), + qy0.ravel(), + int0_sub.ravel(), + h0.ravel(), + k0.ravel(), + l0.ravel(), + ] + ) + + bragg_parent_1 = PointList(np.array([], dtype=pl_dtype_parent)) + bragg_parent_1.add_data_by_field( + [ + qx1.ravel(), + qy1.ravel(), + int1_sub.ravel(), + h1.ravel(), + k1.ravel(), + l1.ravel(), + ] + ) + + pl_dtype = np.dtype( + [ + ("qx", "float"), + ("qy", "float"), + ("intensity", "float"), + ("h0", "int"), + ("k0", "int"), + ("l0", "int"), + ("h1", "int"), + ("k1", "int"), + ("l1", "int"), + ] + ) + bragg_moire = PointList(np.array([], dtype=pl_dtype)) + bragg_moire.add_data_by_field( + [ + qx.ravel(), + qy.ravel(), + int_moire.ravel(), + m_h0.ravel(), + m_k0.ravel(), + m_l0.ravel(), + m_h1.ravel(), + m_k1.ravel(), + m_l1.ravel(), + ] + ) + + return bragg_parent_0, bragg_parent_1, bragg_moire + + +def plot_moire_diffraction_pattern( + bragg_parent_0, + bragg_parent_1, + bragg_moire, + int_range=(0, 5e-3), + k_max=1.0, + plot_subpixel=True, + labels=None, + marker_size_parent=16, + marker_size_moire=4, + text_size_parent=10, + text_size_moire=6, + add_labels_parent=False, + add_labels_moire=False, + dist_labels=0.03, + dist_check=0.06, + sep_labels=0.03, + figsize=(8, 6), + returnfig=False, +): + """ + Plot Moire lattice and parent lattices. + + Parameters + -------- + bragg_peaks_0: BraggVector + Bragg vectors for parent lattice 0. + bragg_peaks_1: BraggVector + Bragg vectors for parent lattice 1. + bragg_moire: BraggVector + Bragg vectors for moire lattice. + int_range: (float, float) + Plotting intensity range for the Moire peaks. + k_max: float + Max k value of the plotted Moire lattice. + plot_subpixel: bool + Apply subpixel corrections to the Bragg spot positions. + Matplotlib default scatter plot rounds to the nearest pixel. + labels: list + List of text labels for parent lattices + marker_size_parent: float + Size of plot markers for the two parent lattices. + marker_size_moire: float + Size of plot markers for the Moire lattice. + text_size_parent: float + Label text size for parent lattice. + text_size_moire: float + Label text size for Moire lattice. + add_labels_parent: bool + Plot the parent lattice index labels. + add_labels_moire: bool + Plot the parent lattice index labels for the Moire spots. + dist_labels: float + Distance to move the labels off the spots. + dist_check: float + Set to some distance to "push" the labels away from each other if they are within this distance. + sep_labels: float + Separation distance for labels which are "pushed" apart. + figsize: (float,float) + Size of output figure. + returnfig: bool + Return the (fix,ax) handles of the plot. + + Returns + -------- + fig, ax: matplotlib handles (optional) + Figure and axes handles for the moire plot. + """ + + # peak labels + + if labels is None: + labels = ("crystal 0", "crystal 1") + + def overline(x): + return str(x) if x >= 0 else (r"\overline{" + str(np.abs(x)) + "}") + + # parent 1 + qx0 = bragg_parent_0["qx"] + qy0 = bragg_parent_0["qy"] + h0 = bragg_parent_0["h"] + k0 = bragg_parent_0["k"] + l0 = bragg_parent_0["l"] + + # parent 2 + qx1 = bragg_parent_1["qx"] + qy1 = bragg_parent_1["qy"] + h1 = bragg_parent_1["h"] + k1 = bragg_parent_1["k"] + l1 = bragg_parent_1["l"] + + # moire + qx = bragg_moire["qx"] + qy = bragg_moire["qy"] + m_h0 = bragg_moire["h0"] + m_k0 = bragg_moire["k0"] + m_l0 = bragg_moire["l0"] + m_h1 = bragg_moire["h1"] + m_k1 = bragg_moire["k1"] + m_l1 = bragg_moire["l1"] + int_moire = bragg_moire["intensity"] + + fig = plt.figure(figsize=figsize) + ax = fig.add_axes([0.09, 0.09, 0.65, 0.9]) + ax_labels = fig.add_axes([0.75, 0, 0.25, 1]) + + text_params_parent = { + "ha": "center", + "va": "center", + "family": "sans-serif", + "fontweight": "normal", + "size": text_size_parent, + } + text_params_moire = { + "ha": "center", + "va": "center", + "family": "sans-serif", + "fontweight": "normal", + "size": text_size_moire, + } + + if plot_subpixel is False: + # moire + ax.scatter( + qy, + qx, + # color = (0,0,0,1), + c=int_moire, + s=marker_size_moire, + cmap="gray_r", + vmin=int_range[0], + vmax=int_range[1], + antialiased=True, + ) + + # parent lattices + ax.scatter( + qy0, + qx0, + color=(1, 0, 0, 1), + s=marker_size_parent, + antialiased=True, + ) + ax.scatter( + qy1, + qx1, + color=(0, 0.7, 1, 1), + s=marker_size_parent, + antialiased=True, + ) + + # origin + ax.scatter( + 0, + 0, + color=(0, 0, 0, 1), + s=marker_size_parent, + antialiased=True, + ) + + else: + # moire peaks + int_all = np.clip( + (int_moire - int_range[0]) / (int_range[1] - int_range[0]), 0, 1 + ) + keep = np.logical_and.reduce( + (qx >= -k_max, qx <= k_max, qy >= -k_max, qy <= k_max) + ) + for x, y, int_marker in zip(qx[keep], qy[keep], int_all[keep]): + ax.add_artist( + Circle( + xy=(y, x), + radius=np.sqrt(marker_size_moire) / 800.0, + color=(1 - int_marker, 1 - int_marker, 1 - int_marker), + ) + ) + if add_labels_moire: + for a0 in range(qx.size): + if keep.ravel()[a0]: + x0 = qx.ravel()[a0] + y0 = qy.ravel()[a0] + d2 = (qx.ravel() - x0) ** 2 + (qy.ravel() - y0) ** 2 + sub = d2 < dist_check**2 + xc = np.mean(qx.ravel()[sub]) + yc = np.mean(qy.ravel()[sub]) + xp = x0 - xc + yp = y0 - yc + if xp == 0 and yp == 0.0: + xp = x0 - dist_labels + yp = y0 + else: + leng = np.linalg.norm((xp, yp)) + xp = x0 + xp * dist_labels / leng + yp = y0 + yp * dist_labels / leng + + ax.text( + yp, + xp - sep_labels, + "$" + + overline(m_h0.ravel()[a0]) + + overline(m_k0.ravel()[a0]) + + overline(m_l0.ravel()[a0]) + + "$", + c="r", + **text_params_moire, + ) + ax.text( + yp, + xp, + "$" + + overline(m_h1.ravel()[a0]) + + overline(m_k1.ravel()[a0]) + + overline(m_l1.ravel()[a0]) + + "$", + c=(0, 0.7, 1.0), + **text_params_moire, + ) + + keep = np.logical_and.reduce( + (qx0 >= -k_max, qx0 <= k_max, qy0 >= -k_max, qy0 <= k_max) + ) + for x, y in zip(qx0[keep], qy0[keep]): + ax.add_artist( + Circle( + xy=(y, x), + radius=np.sqrt(marker_size_parent) / 800.0, + color=(1, 0, 0), + ) + ) + if add_labels_parent: + for a0 in range(qx0.size): + if keep.ravel()[a0]: + xp = qx0.ravel()[a0] - dist_labels + yp = qy0.ravel()[a0] + ax.text( + yp, + xp, + "$" + + overline(h0.ravel()[a0]) + + overline(k0.ravel()[a0]) + + overline(l0.ravel()[a0]) + + "$", + c="k", + **text_params_parent, + ) + + keep = np.logical_and.reduce( + (qx1 >= -k_max, qx1 <= k_max, qy1 >= -k_max, qy1 <= k_max) + ) + for x, y in zip(qx1[keep], qy1[keep]): + ax.add_artist( + Circle( + xy=(y, x), + radius=np.sqrt(marker_size_parent) / 800.0, + color=(0, 0.7, 1), + ) + ) + if add_labels_parent: + for a0 in range(qx1.size): + if keep.ravel()[a0]: + xp = qx1.ravel()[a0] - dist_labels + yp = qy1.ravel()[a0] + ax.text( + yp, + xp, + "$" + + overline(h1.ravel()[a0]) + + overline(k1.ravel()[a0]) + + overline(l1.ravel()[a0]) + + "$", + c="k", + **text_params_parent, + ) + + # origin + ax.add_artist( + Circle( + xy=(0, 0), + radius=np.sqrt(marker_size_parent) / 800.0, + color=(0, 0, 0), + ) + ) + + ax.set_xlim((-k_max, k_max)) + ax.set_ylim((-k_max, k_max)) + ax.set_ylabel("$q_x$ (1/A)") + ax.set_xlabel("$q_y$ (1/A)") + ax.invert_yaxis() + + # labels + ax_labels.scatter( + 0, + 0, + color=(1, 0, 0, 1), + s=marker_size_parent, + ) + ax_labels.scatter( + 0, + -1, + color=(0, 0.7, 1, 1), + s=marker_size_parent, + ) + ax_labels.scatter( + 0, + -2, + color=(0, 0, 0, 1), + s=marker_size_moire, + ) + ax_labels.text( + 0.4, + -0.2, + labels[0], + fontsize=14, + ) + ax_labels.text( + 0.4, + -1.2, + labels[1], + fontsize=14, + ) + ax_labels.text( + 0.4, + -2.2, + "Moiré lattice", + fontsize=14, + ) + + ax_labels.set_xlim((-1, 4)) + ax_labels.set_ylim((-21, 1)) + + ax_labels.axis("off") + + if returnfig: + return fig, ax diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index 2346db598..5722f3f38 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -1,7 +1,7 @@ import numpy as np import matplotlib.pyplot as plt -import os from typing import Union, Optional +from tqdm import tqdm from emdfile import tqdmnd, PointList, PointListArray from py4DSTEM.data import RealSlice @@ -14,7 +14,7 @@ try: import cupy as cp -except: +except ModuleNotFoundError: cp = None @@ -762,14 +762,34 @@ def match_orientations( self, bragg_peaks_array: PointListArray, num_matches_return: int = 1, - min_number_peaks=3, - inversion_symmetry=True, - multiple_corr_reset=True, - progress_bar: bool = True, + min_angle_between_matches_deg=None, + min_number_peaks: int = 3, + inversion_symmetry: bool = True, + multiple_corr_reset: bool = True, return_orientation: bool = True, + progress_bar: bool = True, ): """ - This function computes the orientation of any number of PointLists stored in a PointListArray, and returns an OrienationMap. + Parameters + -------- + bragg_peaks_array: PointListArray + PointListArray containing the Bragg peaks and intensities, with calibrations applied + num_matches_return: int + return these many matches as 3th dim of orient (matrix) + min_angle_between_matches_deg: int + Minimum angle between zone axis of multiple matches, in degrees. + Note that I haven't thought how to handle in-plane rotations, since multiple matches are possible. + min_number_peaks: int + Minimum number of peaks required to perform ACOM matching + inversion_symmetry: bool + check for inversion symmetry in the matches + multiple_corr_reset: bool + keep original correlation score for multiple matches + return_orientation: bool + Return orientation map from function for inspection. + The map is always stored in the Crystal object. + progress_bar: bool + Show or hide the progress bar """ orientation_map = OrientationMap( @@ -808,6 +828,7 @@ def match_orientations( orientation = self.match_single_pattern( bragg_peaks=vectors, num_matches_return=num_matches_return, + min_angle_between_matches_deg=min_angle_between_matches_deg, min_number_peaks=min_number_peaks, inversion_symmetry=inversion_symmetry, multiple_corr_reset=multiple_corr_reset, @@ -816,6 +837,8 @@ def match_orientations( ) orientation_map.set_orientation(orientation, rx, ry) + + # assign and return self.orientation_map = orientation_map if return_orientation: @@ -828,6 +851,7 @@ def match_single_pattern( self, bragg_peaks: PointList, num_matches_return: int = 1, + min_angle_between_matches_deg=None, min_number_peaks=3, inversion_symmetry=True, multiple_corr_reset=True, @@ -841,23 +865,42 @@ def match_single_pattern( """ Solve for the best fit orientation of a single diffraction pattern. - Args: - bragg_peaks (PointList): numpy array containing the Bragg positions and intensities ('qx', 'qy', 'intensity') - num_matches_return (int): return these many matches as 3th dim of orient (matrix) - min_number_peaks (int): Minimum number of peaks required to perform ACOM matching - inversion_symmetry (bool): check for inversion symmetry in the matches - multiple_corr_reset (bool): keep original correlation score for multiple matches - subpixel_tilt (bool): set to false for faster matching, returning the nearest corr point - plot_polar (bool): set to true to plot the polar transform of the diffraction pattern - plot_corr (bool): set to true to plot the resulting correlogram - returnfig (bool): Return figure handles - figsize (list): size of figure - verbose (bool): Print the fitted zone axes, correlation scores - CUDA (bool): Enable CUDA for the FFT steps - - Returns: - orientation (Orientation): Orientation class containing all outputs - fig, ax (handles): Figure handles for the plotting output + Parameters + -------- + bragg_peaks: PointList + numpy array containing the Bragg positions and intensities ('qx', 'qy', 'intensity') + num_matches_return: int + return these many matches as 3th dim of orient (matrix) + min_angle_between_matches_deg: int + Minimum angle between zone axis of multiple matches, in degrees. + Note that I haven't thought how to handle in-plane rotations, since multiple matches are possible. + min_number_peaks: int + Minimum number of peaks required to perform ACOM matching + inversion_symmetry bool + check for inversion symmetry in the matches + multiple_corr_reset bool + keep original correlation score for multiple matches + subpixel_tilt: bool + set to false for faster matching, returning the nearest corr point + plot_polar: bool + set to true to plot the polar transform of the diffraction pattern + plot_corr: bool + set to true to plot the resulting correlogram + returnfig: bool + return figure handles + figsize: list + size of figure + verbose: bool + Print the fitted zone axes, correlation scores + CUDA: bool + Enable CUDA for the FFT steps + + Returns + -------- + orientation: Orientation + Orientation class containing all outputs + fig, ax: handles + Figure handles for the plotting output """ # init orientation output @@ -1028,6 +1071,25 @@ def match_single_pattern( 0, ) + # If minimum angle is specified and we're on a match later than the first, + # we zero correlation values within the given range. + if min_angle_between_matches_deg is not None: + if match_ind > 0: + inds_previous = orientation.inds[:match_ind, 0] + for a0 in range(inds_previous.size): + mask_zero = np.arccos( + np.clip( + np.sum( + self.orientation_vecs + * self.orientation_vecs[inds_previous[a0], :], + axis=1, + ), + -1, + 1, + ) + ) < np.deg2rad(min_angle_between_matches_deg) + corr_full[mask_zero, :] = 0.0 + # Get maximum (non inverted) correlation value ind_phi = np.argmax(corr_full, axis=1) @@ -1095,6 +1157,26 @@ def match_single_pattern( ), 0, ) + + # If minimum angle is specified and we're on a match later than the first, + # we zero correlation values within the given range. + if min_angle_between_matches_deg is not None: + if match_ind > 0: + inds_previous = orientation.inds[:match_ind, 0] + for a0 in range(inds_previous.size): + mask_zero = np.arccos( + np.clip( + np.sum( + self.orientation_vecs + * self.orientation_vecs[inds_previous[a0], :], + axis=1, + ), + -1, + 1, + ) + ) < np.deg2rad(min_angle_between_matches_deg) + corr_full_inv[mask_zero, :] = 0.0 + ind_phi_inv = np.argmax(corr_full_inv, axis=1) corr_inv = np.zeros(self.orientation_num_zones, dtype="bool") @@ -1682,6 +1764,250 @@ def match_single_pattern( return orientation +def cluster_grains( + self, + threshold_add=1.0, + threshold_grow=0.1, + angle_tolerance_deg=5.0, + progress_bar=True, +): + """ + Cluster grains using rotation criterion, and correlation values. + + Parameters + -------- + threshold_add: float + Minimum signal required for a probe position to initialize a cluster. + threshold_grow: float + Minimum signal required for a probe position to be added to a cluster. + angle_tolerance_deg: float + Rotation rolerance for clustering grains. + progress_bar: bool + Turns on the progress bar for the polar transformation + + """ + + # symmetry operators + sym = self.symmetry_operators + + # Get data + # Correlation data = signal to cluster with + sig = self.orientation_map.corr.copy() + sig_init = sig.copy() + mark = sig >= threshold_grow + sig[np.logical_not(mark)] = 0 + # orientation matrix used for angle tolerance + matrix = self.orientation_map.matrix.copy() + + # init + self.cluster_sizes = np.array((), dtype="int") + self.cluster_sig = np.array(()) + self.cluster_inds = [] + self.cluster_orientation = [] + inds_all = np.zeros_like(sig, dtype="int") + inds_all.ravel()[:] = np.arange(inds_all.size) + + # Tolerance + tol = np.deg2rad(angle_tolerance_deg) + + # Main loop + search = True + comp = 0.0 + mark_total = np.sum(np.max(mark, axis=2)) + pbar = tqdm(total=mark_total, disable=not progress_bar) + while search is True: + inds_grain = np.argmax(sig) + + val = sig.ravel()[inds_grain] + + if val < threshold_add: + search = False + + else: + # Start cluster + x, y, z = np.unravel_index(inds_grain, sig.shape) + mark[x, y, z] = False + sig[x, y, z] = 0 + matrix_cluster = matrix[x, y, z] + orientation_cluster = self.orientation_map.get_orientation_single(x, y, z) + + # Neighbors to search + xr = np.clip(x + np.arange(-1, 2, dtype="int"), 0, sig.shape[0] - 1) + yr = np.clip(y + np.arange(-1, 2, dtype="int"), 0, sig.shape[1] - 1) + inds_cand = inds_all[xr[:, None], yr[None], :].ravel() + inds_cand = np.delete(inds_cand, mark.ravel()[inds_cand] == False) + + if inds_cand.size == 0: + grow = False + else: + grow = True + + # grow the cluster + while grow is True: + inds_new = np.array((), dtype="int") + + keep = np.zeros(inds_cand.size, dtype="bool") + for a0 in range(inds_cand.size): + xc, yc, zc = np.unravel_index(inds_cand[a0], sig.shape) + + # Angle test between orientation matrices + dphi = np.min( + np.arccos( + np.clip( + ( + np.trace( + self.symmetry_operators + @ matrix[xc, yc, zc] + @ np.transpose(matrix_cluster), + axis1=1, + axis2=2, + ) + - 1 + ) + / 2, + -1, + 1, + ) + ) + ) + + if np.abs(dphi) < tol: + keep[a0] = True + + sig[xc, yc, zc] = 0 + mark[xc, yc, zc] = False + + xr = np.clip( + xc + np.arange(-1, 2, dtype="int"), 0, sig.shape[0] - 1 + ) + yr = np.clip( + yc + np.arange(-1, 2, dtype="int"), 0, sig.shape[1] - 1 + ) + inds_add = inds_all[xr[:, None], yr[None], :].ravel() + inds_new = np.append(inds_new, inds_add) + + inds_grain = np.append(inds_grain, inds_cand[keep]) + inds_cand = np.unique( + np.delete(inds_new, mark.ravel()[inds_new] == False) + ) + + if inds_cand.size == 0: + grow = False + + # convert grain to x,y coordinates, add = list + xg, yg, zg = np.unravel_index(inds_grain, sig.shape) + xyg = np.unique(np.vstack((xg, yg)), axis=1) + sig_mean = np.mean(sig_init.ravel()[inds_grain]) + self.cluster_sizes = np.append(self.cluster_sizes, xyg.shape[1]) + self.cluster_sig = np.append(self.cluster_sig, sig_mean) + self.cluster_orientation.append(orientation_cluster) + self.cluster_inds.append(xyg) + + # update progressbar + new_marks = mark_total - np.sum(np.max(mark, axis=2)) + pbar.update(new_marks) + mark_total -= new_marks + + pbar.close() + + +def cluster_orientation_map( + self, + stripe_width=(2, 2), + area_min=2, +): + """ + Produce a new orientation map from the clustered grains. + Use a stripe pattern for the overlapping grains. + + Parameters + -------- + stripe_width: (int,int) + Width of stripes for plotting maps with overlapping grains + area_min: (int) + Minimum size of grains to include + + Returns + -------- + + orientation_map + The clustered orientation map + + """ + + # init + orientation_map = OrientationMap( + num_x=self.orientation_map.num_x, + num_y=self.orientation_map.num_y, + num_matches=1, + ) + im_grain = np.zeros( + (self.orientation_map.num_x, self.orientation_map.num_y), dtype="bool" + ) + im_count = np.zeros((self.orientation_map.num_x, self.orientation_map.num_y)) + im_mark = np.zeros((self.orientation_map.num_x, self.orientation_map.num_y)) + + # Loop over grains to determine number in each pixel + for a0 in range(self.cluster_sizes.shape[0]): + if self.cluster_sizes[a0] >= area_min: + im_grain[:] = False + im_grain[ + self.cluster_inds[a0][0, :], + self.cluster_inds[a0][1, :], + ] = True + im_count += im_grain + im_stripe = im_count >= 2 + im_single = np.logical_not(im_stripe) + + # prefactor for stripes + if stripe_width[0] == 0: + dx = 0 + else: + dx = 1 / stripe_width[0] + if stripe_width[1] == 0: + dy = 0 + else: + dy = 1 / stripe_width[1] + + # loop over grains + for a0 in range(self.cluster_sizes.shape[0]): + if self.cluster_sizes[a0] >= area_min: + im_grain[:] = False + im_grain[ + self.cluster_inds[a0][0, :], + self.cluster_inds[a0][1, :], + ] = True + + # non-overlapping grains + sub = np.logical_and(im_grain, im_single) + x, y = np.unravel_index(np.where(sub.ravel()), im_grain.shape) + x = np.atleast_1d(np.squeeze(x)) + y = np.atleast_1d(np.squeeze(y)) + for a1 in range(x.size): + orientation_map.set_orientation( + self.cluster_orientation[a0], x[a1], y[a1] + ) + + # overlapping grains + sub = np.logical_and(im_grain, im_stripe) + x, y = np.unravel_index(np.where(sub.ravel()), im_grain.shape) + x = np.atleast_1d(np.squeeze(x)) + y = np.atleast_1d(np.squeeze(y)) + for a1 in range(x.size): + d = np.mod( + x[a1] * dx + y[a1] * dy + im_mark[x[a1], y[a1]] + +0.5, + im_count[x[a1], y[a1]], + ) + + if d < 1.0: + orientation_map.set_orientation( + self.cluster_orientation[a0], x[a1], y[a1] + ) + im_mark[x[a1], y[a1]] += 1 + + return orientation_map + + def calculate_strain( self, bragg_peaks_array: PointListArray, diff --git a/py4DSTEM/process/diffraction/crystal_viz.py b/py4DSTEM/process/diffraction/crystal_viz.py index 8ffd558e9..9f9336155 100644 --- a/py4DSTEM/process/diffraction/crystal_viz.py +++ b/py4DSTEM/process/diffraction/crystal_viz.py @@ -5,6 +5,8 @@ from mpl_toolkits.mplot3d import Axes3D, art3d from scipy.signal import medfilt from scipy.ndimage import gaussian_filter +from scipy.ndimage.morphology import distance_transform_edt +from skimage.morphology import dilation, erosion import warnings import numpy as np @@ -147,7 +149,7 @@ def plot_structure( zs=xyz[sub, 2], # + d[2], s=size_marker, linewidth=2, - color=atomic_colors(ID_plot), + facecolors=atomic_colors(ID_plot), edgecolor=[0, 0, 0], ) @@ -989,7 +991,7 @@ def overline(x): def plot_orientation_maps( self, - orientation_map, + orientation_map=None, orientation_ind: int = 0, dir_in_plane_degrees: float = 0.0, corr_range: np.ndarray = np.array([0, 5]), @@ -1010,6 +1012,7 @@ def plot_orientation_maps( Args: orientation_map (OrientationMap): Class containing orientation matrices, correlation values, etc. + Optional - can reference internally stored OrientationMap. orientation_ind (int): Which orientation match to plot if num_matches > 1 dir_in_plane_degrees (float): In-plane angle to plot in degrees. Default is 0 / x-axis / vertical down. corr_range (np.ndarray): Correlation intensity range for the plot @@ -1037,6 +1040,9 @@ def plot_orientation_maps( """ # Inputs + if orientation_map is None: + orientation_map = self.orientation_map + # Legend size leg_size = np.array([300, 300], dtype="int") @@ -1720,6 +1726,205 @@ def plot_fiber_orientation_maps( return images_orientation +def plot_clusters( + self, + area_min=2, + outline_grains=True, + outline_thickness=1, + fill_grains=0.25, + smooth_grains=1.0, + cmap="viridis", + figsize=(8, 8), + returnfig=False, +): + """ + Plot the clusters as an image. + + Parameters + -------- + area_min: int (optional) + Min cluster size to include, in units of probe positions. + outline_grains: bool (optional) + Set to True to draw grains with outlines + outline_thickness: int (optional) + Thickenss of the grain outline + fill_grains: float (optional) + Outlined grains are filled with this value in pixels. + smooth_grains: float (optional) + Grain boundaries are smoothed by this value in pixels. + figsize: tuple + Size of the figure panel + returnfig: bool + Setting this to true returns the figure and axis handles + + Returns + -------- + fig, ax (optional) + Figure and axes handles + + """ + + # init + im_plot = np.zeros( + ( + self.orientation_map.num_x, + self.orientation_map.num_y, + ) + ) + im_grain = np.zeros( + ( + self.orientation_map.num_x, + self.orientation_map.num_y, + ), + dtype="bool", + ) + + # make plotting image + + for a0 in range(self.cluster_sizes.shape[0]): + if self.cluster_sizes[a0] >= area_min: + if outline_grains: + im_grain[:] = False + im_grain[ + self.cluster_inds[a0][0, :], + self.cluster_inds[a0][1, :], + ] = True + + im_dist = distance_transform_edt( + erosion( + np.invert(im_grain), footprint=np.ones((3, 3), dtype="bool") + ) + ) - distance_transform_edt(im_grain) + im_dist = gaussian_filter(im_dist, sigma=smooth_grains, mode="nearest") + im_add = np.exp(im_dist**2 / (-0.5 * outline_thickness**2)) + + if fill_grains > 0: + im_dist = distance_transform_edt( + erosion( + np.invert(im_grain), footprint=np.ones((3, 3), dtype="bool") + ) + ) + im_dist = gaussian_filter( + im_dist, sigma=smooth_grains, mode="nearest" + ) + im_add += fill_grains * np.exp( + im_dist**2 / (-0.5 * outline_thickness**2) + ) + + # im_add = 1 - np.exp( + # distance_transform_edt(im_grain)**2 \ + # / (-2*outline_thickness**2)) + im_plot += im_add + # im_plot = np.minimum(im_plot, im_add) + else: + # xg,yg = np.unravel_index(self.cluster_inds[a0], im_plot.shape) + im_grain[:] = False + im_grain[ + self.cluster_inds[a0][0, :], + self.cluster_inds[a0][1, :], + ] = True + im_plot += gaussian_filter( + im_grain.astype("float"), sigma=smooth_grains, mode="nearest" + ) + + # im_plot[ + # self.cluster_inds[a0][0,:], + # self.cluster_inds[a0][1,:], + # ] += 1 + + if outline_grains: + im_plot = np.clip(im_plot, 0, 2) + + # plotting + fig, ax = plt.subplots(figsize=figsize) + ax.imshow( + im_plot, + # vmin = -3, + # vmax = 3, + cmap=cmap, + ) + + +def plot_cluster_size( + self, + area_min=None, + area_max=None, + area_step=1, + weight_intensity=False, + pixel_area=1.0, + pixel_area_units="px^2", + figsize=(8, 6), + returnfig=False, +): + """ + Plot the cluster sizes + + Parameters + -------- + area_min: int (optional) + Min area to include in pixels^2 + area_max: int (optional) + Max area bin in pixels^2 + area_step: int (optional) + Step size of the histogram bin in pixels^2 + weight_intensity: bool + Weight histogram by the peak intensity. + pixel_area: float + Size of pixel area unit square + pixel_area_units: string + Units of the pixel area + figsize: tuple + Size of the figure panel + returnfig: bool + Setting this to true returns the figure and axis handles + + Returns + -------- + fig, ax (optional) + Figure and axes handles + + """ + + if area_max is None: + area_max = np.max(self.cluster_sizes) + area = np.arange(0, area_max, area_step) + if area_min is None: + sub = self.cluster_sizes.astype("int") < area_max + else: + sub = np.logical_and( + self.cluster_sizes.astype("int") >= area_min, + self.cluster_sizes.astype("int") < area_max, + ) + if weight_intensity: + hist = np.bincount( + self.cluster_sizes[sub] // area_step, + weights=self.cluster_sig[sub], + minlength=area.shape[0], + ) + else: + hist = np.bincount( + self.cluster_sizes[sub] // area_step, + minlength=area.shape[0], + ) + + # plotting + fig, ax = plt.subplots(figsize=figsize) + ax.bar( + area * pixel_area, + hist, + width=0.8 * pixel_area * area_step, + ) + ax.set_xlim((0, area_max * pixel_area)) + ax.set_xlabel("Grain Area [" + pixel_area_units + "]") + if weight_intensity: + ax.set_ylabel("Total Signal [arb. units]") + else: + ax.set_ylabel("Number of Grains") + + if returnfig: + return fig, ax + + def axisEqual3D(ax): extents = np.array([getattr(ax, "get_{}lim".format(dim))() for dim in "xyz"]) sz = extents[:, 1] - extents[:, 0] diff --git a/py4DSTEM/process/diffraction/utils.py b/py4DSTEM/process/diffraction/utils.py index 09bd09f7c..cfb11f044 100644 --- a/py4DSTEM/process/diffraction/utils.py +++ b/py4DSTEM/process/diffraction/utils.py @@ -67,6 +67,16 @@ def get_orientation(self, ind_x, ind_y): orientation.angles = self.angles[ind_x, ind_y] return orientation + def get_orientation_single(self, ind_x, ind_y, ind_match): + orientation = Orientation(num_matches=1) + orientation.matrix = self.matrix[ind_x, ind_y, ind_match] + orientation.family = self.family[ind_x, ind_y, ind_match] + orientation.corr = self.corr[ind_x, ind_y, ind_match] + orientation.inds = self.inds[ind_x, ind_y, ind_match] + orientation.mirror = self.mirror[ind_x, ind_y, ind_match] + orientation.angles = self.angles[ind_x, ind_y, ind_match] + return orientation + # def __copy__(self): # return OrientationMap(self.name) # def __deepcopy__(self, memo): diff --git a/py4DSTEM/process/phase/__init__.py b/py4DSTEM/process/phase/__init__.py index 178079349..1005a619d 100644 --- a/py4DSTEM/process/phase/__init__.py +++ b/py4DSTEM/process/phase/__init__.py @@ -3,28 +3,14 @@ _emd_hook = True from py4DSTEM.process.phase.iterative_dpc import DPCReconstruction -from py4DSTEM.process.phase.iterative_mixedstate_ptychography import ( - MixedstatePtychographicReconstruction, -) -from py4DSTEM.process.phase.iterative_multislice_ptychography import ( - MultislicePtychographicReconstruction, -) -from py4DSTEM.process.phase.iterative_overlap_magnetic_tomography import ( - OverlapMagneticTomographicReconstruction, -) -from py4DSTEM.process.phase.iterative_overlap_tomography import ( - OverlapTomographicReconstruction, -) +from py4DSTEM.process.phase.iterative_mixedstate_multislice_ptychography import MixedstateMultislicePtychographicReconstruction +from py4DSTEM.process.phase.iterative_mixedstate_ptychography import MixedstatePtychographicReconstruction +from py4DSTEM.process.phase.iterative_multislice_ptychography import MultislicePtychographicReconstruction +from py4DSTEM.process.phase.iterative_overlap_magnetic_tomography import OverlapMagneticTomographicReconstruction +from py4DSTEM.process.phase.iterative_overlap_tomography import OverlapTomographicReconstruction from py4DSTEM.process.phase.iterative_parallax import ParallaxReconstruction -from py4DSTEM.process.phase.iterative_simultaneous_ptychography import ( - SimultaneousPtychographicReconstruction, -) -from py4DSTEM.process.phase.iterative_singleslice_ptychography import ( - SingleslicePtychographicReconstruction, -) -from py4DSTEM.process.phase.parameter_optimize import ( - OptimizationParameter, - PtychographyOptimizer, -) +from py4DSTEM.process.phase.iterative_simultaneous_ptychography import SimultaneousPtychographicReconstruction +from py4DSTEM.process.phase.iterative_singleslice_ptychography import SingleslicePtychographicReconstruction +from py4DSTEM.process.phase.parameter_optimize import OptimizationParameter, PtychographyOptimizer # fmt: on diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index ae4c92d4b..04cfd6a60 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -13,8 +13,8 @@ try: import cupy as cp -except ImportError: - cp = None +except ModuleNotFoundError: + cp = np from emdfile import Array, Custom, Metadata, _read_metadata, tqdmnd from py4DSTEM.data import Calibration @@ -56,6 +56,53 @@ def attach_datacube(self, datacube: DataCube): self._datacube = datacube return self + def reinitialize_parameters(self, device: str = None, verbose: bool = None): + """ + Reinitializes common parameters. This is useful when loading a previously-saved + reconstruction (which set device='cpu' and verbose=True for compatibility) , + using different initialization parameters. + + Parameters + ---------- + device: str, optional + If not None, imports and assigns appropriate device modules + verbose: bool, optional + If not None, sets the verbosity to verbose + + Returns + -------- + self: PhaseReconstruction + Self to enable chaining + """ + + if device is not None: + if device == "cpu": + self._xp = np + self._asnumpy = np.asarray + from scipy.ndimage import gaussian_filter + + self._gaussian_filter = gaussian_filter + from scipy.special import erf + + self._erf = erf + elif device == "gpu": + self._xp = cp + self._asnumpy = cp.asnumpy + from cupyx.scipy.ndimage import gaussian_filter + + self._gaussian_filter = gaussian_filter + from cupyx.scipy.special import erf + + self._erf = erf + else: + raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") + self._device = device + + if verbose is not None: + self._verbose = verbose + + return self + def set_save_defaults( self, save_datacube: bool = False, @@ -278,7 +325,9 @@ def _extract_intensities_and_calibrations_from_datacube( """ # Copies intensities to device casting to float32 - intensities = datacube.data + xp = self._xp + + intensities = xp.asarray(datacube.data, dtype=xp.float32) self._grid_scan_shape = intensities.shape[:2] # Extracts calibrations @@ -295,13 +344,14 @@ def _extract_intensities_and_calibrations_from_datacube( if require_calibrations: raise ValueError("Real-space calibrations must be given in 'A'") - warnings.warn( - ( - "Iterative reconstruction will not be quantitative unless you specify " - "real-space calibrations in 'A'" - ), - UserWarning, - ) + if self._verbose: + warnings.warn( + ( + "Iterative reconstruction will not be quantitative unless you specify " + "real-space calibrations in 'A'" + ), + UserWarning, + ) self._scan_sampling = (1.0, 1.0) self._scan_units = ("pixels",) * 2 @@ -359,13 +409,14 @@ def _extract_intensities_and_calibrations_from_datacube( "Reciprocal-space calibrations must be given in in 'A^-1' or 'mrad'" ) - warnings.warn( - ( - "Iterative reconstruction will not be quantitative unless you specify " - "appropriate reciprocal-space calibrations" - ), - UserWarning, - ) + if self._verbose: + warnings.warn( + ( + "Iterative reconstruction will not be quantitative unless you specify " + "appropriate reciprocal-space calibrations" + ), + UserWarning, + ) self._angular_sampling = (1.0, 1.0) self._angular_units = ("pixels",) * 2 @@ -448,8 +499,6 @@ def _calculate_intensities_center_of_mass( xp = self._xp asnumpy = self._asnumpy - intensities = xp.asarray(intensities, dtype=xp.float32) - # for ptycho if com_measured: com_measured_x, com_measured_y = com_measured @@ -484,9 +533,14 @@ def _calculate_intensities_center_of_mass( ) if com_shifts is None: + com_measured_x_np = asnumpy(com_measured_x) + com_measured_y_np = asnumpy(com_measured_y) + finite_mask = np.isfinite(com_measured_x_np) + com_shifts = fit_origin( - (asnumpy(com_measured_x), asnumpy(com_measured_y)), + (com_measured_x_np, com_measured_y_np), fitfunction=fit_function, + mask=finite_mask, ) # Fit function to center of mass @@ -494,12 +548,12 @@ def _calculate_intensities_center_of_mass( com_fitted_y = xp.asarray(com_shifts[1], dtype=xp.float32) # fix CoM units - com_normalized_x = (com_measured_x - com_fitted_x) * self._reciprocal_sampling[ - 0 - ] - com_normalized_y = (com_measured_y - com_fitted_y) * self._reciprocal_sampling[ - 1 - ] + com_normalized_x = ( + xp.nan_to_num(com_measured_x - com_fitted_x) * self._reciprocal_sampling[0] + ) + com_normalized_y = ( + xp.nan_to_num(com_measured_y - com_fitted_y) * self._reciprocal_sampling[1] + ) return ( com_measured_x, @@ -1077,6 +1131,7 @@ def _normalize_diffraction_intensities( diffraction_intensities, com_fitted_x, com_fitted_y, + crop_patterns, ): """ Fix diffraction intensities CoM, shift to origin, and take square root @@ -1089,6 +1144,9 @@ def _normalize_diffraction_intensities( Best fit horizontal center of mass gradient com_fitted_y: (Rx,Ry) xp.ndarray Best fit vertical center of mass gradient + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns + when centering Returns ------- @@ -1101,13 +1159,46 @@ def _normalize_diffraction_intensities( xp = self._xp mean_intensity = 0 - amplitudes = xp.zeros_like(diffraction_intensities) - region_of_interest_shape = diffraction_intensities.shape[-2:] + diffraction_intensities = self._asnumpy(diffraction_intensities) + if crop_patterns: + crop_x = int( + np.minimum( + diffraction_intensities.shape[2] - com_fitted_x.max(), + com_fitted_x.min(), + ) + ) + crop_y = int( + np.minimum( + diffraction_intensities.shape[3] - com_fitted_y.max(), + com_fitted_y.min(), + ) + ) + + crop_w = np.minimum(crop_y, crop_x) + region_of_interest_shape = (crop_w * 2, crop_w * 2) + amplitudes = np.zeros( + ( + diffraction_intensities.shape[0], + diffraction_intensities.shape[1], + crop_w * 2, + crop_w * 2, + ), + dtype=np.float32, + ) + + crop_mask = np.zeros(diffraction_intensities.shape[-2:], dtype=np.bool_) + crop_mask[:crop_w, :crop_w] = True + crop_mask[-crop_w:, :crop_w] = True + crop_mask[:crop_w:, -crop_w:] = True + crop_mask[-crop_w:, -crop_w:] = True + self._crop_mask = crop_mask + + else: + region_of_interest_shape = diffraction_intensities.shape[-2:] + amplitudes = np.zeros(diffraction_intensities.shape, dtype=np.float32) com_fitted_x = self._asnumpy(com_fitted_x) com_fitted_y = self._asnumpy(com_fitted_y) - diffraction_intensities = self._asnumpy(diffraction_intensities) - amplitudes = self._asnumpy(amplitudes) for rx in range(diffraction_intensities.shape[0]): for ry in range(diffraction_intensities.shape[1]): @@ -1119,16 +1210,71 @@ def _normalize_diffraction_intensities( device="cpu", ) + if crop_patterns: + intensities = intensities[crop_mask].reshape( + region_of_interest_shape + ) + mean_intensity += np.sum(intensities) amplitudes[rx, ry] = np.sqrt(np.maximum(intensities, 0)) - amplitudes = xp.asarray(amplitudes, dtype=xp.float32) - amplitudes = xp.reshape(amplitudes, (-1,) + region_of_interest_shape) + amplitudes = xp.asarray(amplitudes) mean_intensity /= amplitudes.shape[0] return amplitudes, mean_intensity + def show_complex_CoM( + self, + com=None, + cbar=True, + scalebar=True, + pixelsize=None, + pixelunits=None, + **kwargs, + ): + """ + Plot complex-valued CoM image + + Parameters + ---------- + + com = (CoM_x, CoM_y) tuple + If None is specified, uses (self.com_x, self.com_y) instead + cbar: bool, optional + if True, adds colorbar + scalebar: bool, optional + if True, adds scalebar to probe + pixelunits: str, optional + units for scalebar, default is A + pixelsize: float, optional + default is scan sampling + """ + + if com is None: + com = (self.com_x, self.com_y) + + if pixelsize is None: + pixelsize = self._scan_sampling[0] + if pixelunits is None: + pixelunits = r"$\AA$" + + figsize = kwargs.pop("figsize", (6, 6)) + fig, ax = plt.subplots(figsize=figsize) + + complex_com = com[0] + 1j * com[1] + + show_complex( + complex_com, + cbar=cbar, + figax=(fig, ax), + scalebar=scalebar, + pixelsize=pixelsize, + pixelunits=pixelunits, + ticks=False, + **kwargs, + ) + class PtychographicReconstruction(PhaseReconstruction, PtychographicConstraints): """ @@ -1309,10 +1455,10 @@ def _get_constructor_args(cls, group): "object_type": instance_md["object_type"], "semiangle_cutoff": instance_md["semiangle_cutoff"], "rolloff": instance_md["rolloff"], - "verbose": instance_md["verbose"], "name": instance_md["name"], - "device": instance_md["device"], "polar_parameters": polar_params, + "verbose": True, # for compatibility + "device": "cpu", # for compatibility } class_specific_kwargs = {} @@ -2109,6 +2255,7 @@ def show_fourier_probe( pixelunits = r"$\AA^{-1}$" figsize = kwargs.pop("figsize", (6, 6)) + chroma_boost = kwargs.pop("chroma_boost", 2) fig, ax = plt.subplots(figsize=figsize) show_complex( @@ -2119,6 +2266,7 @@ def show_fourier_probe( pixelsize=pixelsize, pixelunits=pixelunits, ticks=False, + chroma_boost=chroma_boost, **kwargs, ) @@ -2142,7 +2290,7 @@ def show_object_fft(self, obj=None, **kwargs): vmax = kwargs.pop("vmax", 1) power = kwargs.pop("power", 0.2) - pixelsize = 1 / (object_fft.shape[0] * self.sampling[0]) + pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) show( object_fft, figsize=figsize, diff --git a/py4DSTEM/process/phase/iterative_dpc.py b/py4DSTEM/process/phase/iterative_dpc.py index 4c80ed177..af3cbbb45 100644 --- a/py4DSTEM/process/phase/iterative_dpc.py +++ b/py4DSTEM/process/phase/iterative_dpc.py @@ -13,8 +13,8 @@ try: import cupy as cp -except ImportError: - cp = None +except ModuleNotFoundError: + cp = np from emdfile import Array, Custom, Metadata, _read_metadata, tqdmnd from py4DSTEM.data import Calibration @@ -195,9 +195,9 @@ def _get_constructor_args(cls, group): "datacube": dc, "initial_object_guess": np.asarray(obj), "energy": instance_md["energy"], - "verbose": instance_md["verbose"], "name": instance_md["name"], - "device": instance_md["device"], + "verbose": True, # for compatibility + "device": "cpu", # for compatibility } return kwargs @@ -718,24 +718,26 @@ def reconstruct( xp = self._xp asnumpy = self._asnumpy - if reset is None and hasattr(self, "error"): - warnings.warn( - ( - "Continuing reconstruction from previous result. " - "Use reset=True for a fresh start." - ), - UserWarning, - ) - # Restart if store_iterations and (not hasattr(self, "object_phase_iterations") or reset): self.object_phase_iterations = [] - self.error_iterations = [] if reset: self.error = np.inf + self.error_iterations = [] self._step_size = step_size if step_size is not None else 0.5 self._padded_object_phase = self._padded_object_phase_initial.copy() + elif reset is None: + if hasattr(self, "error"): + warnings.warn( + ( + "Continuing reconstruction from previous result. " + "Use reset=True for a fresh start." + ), + UserWarning, + ) + else: + self.error_iterations = [] self.error = getattr(self, "error", np.inf) @@ -770,7 +772,8 @@ def reconstruct( if (new_error > self.error) and backtrack: self._padded_object_phase = previous_iteration self._step_size /= 2 - print(f"Iteration {a0}, step reduced to {self._step_size}") + if self._verbose: + print(f"Iteration {a0}, step reduced to {self._step_size}") continue self.error = new_error @@ -807,10 +810,11 @@ def reconstruct( self.error_iterations.append(self.error.item()) if self._step_size < stopping_criterion: - warnings.warn( - f"Step-size has decreased below stopping criterion {stopping_criterion}.", - UserWarning, - ) + if self._verbose: + warnings.warn( + f"Step-size has decreased below stopping criterion {stopping_criterion}.", + UserWarning, + ) # crop result self._object_phase = self._padded_object_phase[ @@ -840,7 +844,7 @@ def _visualize_last_iteration( If true, the NMSE error plot is displayed """ - figsize = kwargs.pop("figsize", (8, 8)) + figsize = kwargs.pop("figsize", (5, 6)) cmap = kwargs.pop("cmap", "magma") if plot_convergence: @@ -862,7 +866,7 @@ def _visualize_last_iteration( im = ax1.imshow(self.object_phase, extent=extent, cmap=cmap, **kwargs) ax1.set_ylabel(f"x [{self._scan_units[0]}]") ax1.set_xlabel(f"y [{self._scan_units[1]}]") - ax1.set_title(f"DPC Phase Reconstruction - NMSE error: {self.error:.3e}") + ax1.set_title(f"DPC phase reconstruction - NMSE error: {self.error:.3e}") if cbar: divider = make_axes_locatable(ax1) @@ -870,11 +874,11 @@ def _visualize_last_iteration( fig.add_axes(ax_cb) fig.colorbar(im, cax=ax_cb) - if plot_convergence and hasattr(self, "_error_iterations"): - errors = self._error_iterations + if plot_convergence: + errors = self.error_iterations ax2 = fig.add_subplot(spec[1]) ax2.semilogy(np.arange(len(errors)), errors, **kwargs) - ax2.set_xlabel("Iteration Number") + ax2.set_xlabel("Iteration number") ax2.set_ylabel("Log NMSE error") ax2.yaxis.tick_right() @@ -979,7 +983,7 @@ def _visualize_all_iterations( if plot_convergence: ax2 = fig.add_subplot(spec[1]) ax2.semilogy(np.arange(len(errors)), errors, **kwargs) - ax2.set_xlabel("Iteration Number") + ax2.set_xlabel("Iteration number") ax2.set_ylabel("Log NMSE error") ax2.yaxis.tick_right() @@ -990,7 +994,7 @@ def visualize( fig=None, iterations_grid: Tuple[int, int] = None, plot_convergence: bool = True, - cbar: bool = False, + cbar: bool = True, **kwargs, ): """ diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py new file mode 100644 index 000000000..3eeb07814 --- /dev/null +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -0,0 +1,3511 @@ +""" +Module for reconstructing phase objects from 4DSTEM datasets using iterative methods, +namely multislice ptychography. +""" + +import warnings +from typing import Mapping, Sequence, Tuple, Union + +import matplotlib.pyplot as plt +import numpy as np +import pylops +from matplotlib.gridspec import GridSpec +from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable +from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg, show_complex + +try: + import cupy as cp +except ImportError: + cp = None + +from emdfile import Custom, tqdmnd +from py4DSTEM import DataCube +from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.utils import ( + ComplexProbe, + fft_shift, + generate_batches, + polar_aliases, + polar_symbols, + spatial_frequencies, +) +from py4DSTEM.process.utils import electron_wavelength_angstrom, get_CoM, get_shifted_ar +from scipy.ndimage import rotate + +warnings.simplefilter(action="always", category=UserWarning) + + +class MixedstateMultislicePtychographicReconstruction(PtychographicReconstruction): + """ + Mixed-State Multislice Ptychographic Reconstruction Class. + + Diffraction intensities dimensions : (Rx,Ry,Qx,Qy) + Reconstructed probe dimensions : (N,Sx,Sy) + Reconstructed object dimensions : (T,Px,Py) + + such that (Sx,Sy) is the region-of-interest (ROI) size of our N probes + and (Px,Py) is the padded-object size we position our ROI around in + each of the T slices. + + Parameters + ---------- + energy: float + The electron energy of the wave functions in eV + num_probes: int, optional + Number of mixed-state probes + num_slices: int + Number of slices to use in the forward model + slice_thicknesses: float or Sequence[float] + Slice thicknesses in angstroms. If float, all slices are assigned the same thickness + datacube: DataCube, optional + Input 4D diffraction pattern intensities + semiangle_cutoff: float, optional + Semiangle cutoff for the initial probe guess in mrad + semiangle_cutoff_pixels: float, optional + Semiangle cutoff for the initial probe guess in pixels + rolloff: float, optional + Semiangle rolloff for the initial probe guess + vacuum_probe_intensity: np.ndarray, optional + Vacuum probe to use as intensity aperture for initial probe guess + polar_parameters: dict, optional + Mapping from aberration symbols to their corresponding values. All aberration + magnitudes should be given in Å and angles should be given in radians. + object_padding_px: Tuple[int,int], optional + Pixel dimensions to pad object with + If None, the padding is set to half the probe ROI dimensions + initial_object_guess: np.ndarray, optional + Initial guess for complex-valued object of dimensions (Px,Py) + If None, initialized to 1.0j + initial_probe_guess: np.ndarray, optional + Initial guess for complex-valued probe of dimensions (Sx,Sy). If None, + initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations + initial_scan_positions: np.ndarray, optional + Probe positions in Å for each diffraction intensity + If None, initialized to a grid scan + object_type: str, optional + The object can be reconstructed as a real potential ('potential') or a complex + object ('complex') + verbose: bool, optional + If True, class methods will inherit this and print additional information + device: str, optional + Calculation device will be perfomed on. Must be 'cpu' or 'gpu' + name: str, optional + Class name + kwargs: + Provide the aberration coefficients as keyword arguments. + """ + + # Class-specific Metadata + _class_specific_metadata = ("_num_probes", "_num_slices", "_slice_thicknesses") + + def __init__( + self, + energy: float, + num_slices: int, + slice_thicknesses: Union[float, Sequence[float]], + num_probes: int = None, + datacube: DataCube = None, + semiangle_cutoff: float = None, + semiangle_cutoff_pixels: float = None, + rolloff: float = 2.0, + vacuum_probe_intensity: np.ndarray = None, + polar_parameters: Mapping[str, float] = None, + object_padding_px: Tuple[int, int] = None, + initial_object_guess: np.ndarray = None, + initial_probe_guess: np.ndarray = None, + initial_scan_positions: np.ndarray = None, + object_type: str = "complex", + verbose: bool = True, + device: str = "cpu", + name: str = "multi-slice_ptychographic_reconstruction", + **kwargs, + ): + Custom.__init__(self, name=name) + + if initial_probe_guess is None or isinstance(initial_probe_guess, ComplexProbe): + if num_probes is None: + raise ValueError( + ( + "If initial_probe_guess is None, or a ComplexProbe object, " + "num_probes must be specified." + ) + ) + else: + if len(initial_probe_guess.shape) != 3: + raise ValueError( + "Specified initial_probe_guess must have dimensions (N,Sx,Sy)." + ) + num_probes = initial_probe_guess.shape[0] + + if device == "cpu": + self._xp = np + self._asnumpy = np.asarray + from scipy.ndimage import gaussian_filter + + self._gaussian_filter = gaussian_filter + from scipy.special import erf + + self._erf = erf + elif device == "gpu": + self._xp = cp + self._asnumpy = cp.asnumpy + from cupyx.scipy.ndimage import gaussian_filter + + self._gaussian_filter = gaussian_filter + from cupyx.scipy.special import erf + + self._erf = erf + else: + raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") + + for key in kwargs.keys(): + if (key not in polar_symbols) and (key not in polar_aliases.keys()): + raise ValueError("{} not a recognized parameter".format(key)) + + self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) + + if polar_parameters is None: + polar_parameters = {} + + polar_parameters.update(kwargs) + self._set_polar_parameters(polar_parameters) + + slice_thicknesses = np.array(slice_thicknesses) + if slice_thicknesses.shape == (): + slice_thicknesses = np.tile(slice_thicknesses, num_slices - 1) + elif slice_thicknesses.shape[0] != (num_slices - 1): + raise ValueError( + ( + f"slice_thicknesses must have length {num_slices - 1}, " + f"not {slice_thicknesses.shape[0]}." + ) + ) + + if object_type != "potential" and object_type != "complex": + raise ValueError( + f"object_type must be either 'potential' or 'complex', not {object_type}" + ) + + self.set_save_defaults() + + # Data + self._datacube = datacube + self._object = initial_object_guess + self._probe = initial_probe_guess + + # Common Metadata + self._vacuum_probe_intensity = vacuum_probe_intensity + self._scan_positions = initial_scan_positions + self._energy = energy + self._semiangle_cutoff = semiangle_cutoff + self._semiangle_cutoff_pixels = semiangle_cutoff_pixels + self._rolloff = rolloff + self._object_type = object_type + self._object_padding_px = object_padding_px + self._verbose = verbose + self._device = device + self._preprocessed = False + + # Class-specific Metadata + self._num_probes = num_probes + self._num_slices = num_slices + self._slice_thicknesses = slice_thicknesses + + def _precompute_propagator_arrays( + self, + gpts: Tuple[int, int], + sampling: Tuple[float, float], + energy: float, + slice_thicknesses: Sequence[float], + ): + """ + Precomputes propagator arrays complex wave-function will be convolved by, + for all slice thicknesses. + + Parameters + ---------- + gpts: Tuple[int,int] + Wavefunction pixel dimensions + sampling: Tuple[float,float] + Wavefunction sampling in A + energy: float + The electron energy of the wave functions in eV + slice_thicknesses: Sequence[float] + Array of slice thicknesses in A + + Returns + ------- + propagator_arrays: np.ndarray + (T,Sx,Sy) shape array storing propagator arrays + """ + xp = self._xp + + # Frequencies + kx, ky = spatial_frequencies(gpts, sampling) + kx = xp.asarray(kx, dtype=xp.float32) + ky = xp.asarray(ky, dtype=xp.float32) + + # Propagators + wavelength = electron_wavelength_angstrom(energy) + num_slices = slice_thicknesses.shape[0] + propagators = xp.empty( + (num_slices, kx.shape[0], ky.shape[0]), dtype=xp.complex64 + ) + for i, dz in enumerate(slice_thicknesses): + propagators[i] = xp.exp( + 1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz) + ) + propagators[i] *= xp.exp( + 1.0j * (-(ky**2)[None] * np.pi * wavelength * dz) + ) + + return propagators + + def _propagate_array(self, array: np.ndarray, propagator_array: np.ndarray): + """ + Propagates array by Fourier convolving array with propagator_array. + + Parameters + ---------- + array: np.ndarray + Wavefunction array to be convolved + propagator_array: np.ndarray + Propagator array to convolve array with + + Returns + ------- + propagated_array: np.ndarray + Fourier-convolved array + """ + xp = self._xp + + return xp.fft.ifft2(xp.fft.fft2(array) * propagator_array) + + def preprocess( + self, + diffraction_intensities_shape: Tuple[int, int] = None, + reshaping_method: str = "fourier", + probe_roi_shape: Tuple[int, int] = None, + dp_mask: np.ndarray = None, + fit_function: str = "plane", + plot_center_of_mass: str = "default", + plot_rotation: bool = True, + maximize_divergence: bool = False, + rotation_angles_deg: np.ndarray = np.arange(-89.0, 90.0, 1.0), + plot_probe_overlaps: bool = True, + force_com_rotation: float = None, + force_com_transpose: float = None, + force_com_shifts: float = None, + force_scan_sampling: float = None, + force_angular_sampling: float = None, + force_reciprocal_sampling: float = None, + object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, + **kwargs, + ): + """ + Ptychographic preprocessing step. + Calls the base class methods: + + _extract_intensities_and_calibrations_from_datacube, + _compute_center_of_mass(), + _solve_CoM_rotation(), + _normalize_diffraction_intensities() + _calculate_scan_positions_in_px() + + Additionally, it initializes an (T,Px,Py) array of 1.0j + and a complex probe using the specified polar parameters. + + Parameters + ---------- + diffraction_intensities_shape: Tuple[int,int], optional + Pixel dimensions (Qx',Qy') of the resampled diffraction intensities + If None, no resampling of diffraction intenstities is performed + reshaping_method: str, optional + Method to use for reshaping, either 'bin, 'bilinear', or 'fourier' (default) + probe_roi_shape, (int,int), optional + Padded diffraction intensities shape. + If None, no padding is performed + dp_mask: ndarray, optional + Mask for datacube intensities (Qx,Qy) + fit_function: str, optional + 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' + plot_center_of_mass: str, optional + If 'default', the corrected CoM arrays will be displayed + If 'all', the computed and fitted CoM arrays will be displayed + plot_rotation: bool, optional + If True, the CoM curl minimization search result will be displayed + maximize_divergence: bool, optional + If True, the divergence of the CoM gradient vector field is maximized + rotation_angles_deg: np.darray, optional + Array of angles in degrees to perform curl minimization over + plot_probe_overlaps: bool, optional + If True, initial probe overlaps scanned over the object will be displayed + force_com_rotation: float (degrees), optional + Force relative rotation angle between real and reciprocal space + force_com_transpose: bool, optional + Force whether diffraction intensities need to be transposed. + force_com_shifts: tuple of ndarrays (CoMx, CoMy) + Amplitudes come from diffraction patterns shifted with + the CoM in the upper left corner for each probe unless + shift is overwritten. + force_scan_sampling: float, optional + Override DataCube real space scan pixel size calibrations, in Angstrom + force_angular_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in mrad + force_reciprocal_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in A^-1 + object_fov_mask: np.ndarray (boolean) + Boolean mask of FOV. Used to calculate additional shrinkage of object + If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering + + Returns + -------- + self: MixedstateMultislicePtychographicReconstruction + Self to accommodate chaining + """ + xp = self._xp + asnumpy = self._asnumpy + + # set additional metadata + self._diffraction_intensities_shape = diffraction_intensities_shape + self._reshaping_method = reshaping_method + self._probe_roi_shape = probe_roi_shape + self._dp_mask = dp_mask + + if self._datacube is None: + raise ValueError( + ( + "The preprocess() method requires a DataCube. " + "Please run ptycho.attach_datacube(DataCube) first." + ) + ) + + ( + self._datacube, + self._vacuum_probe_intensity, + self._dp_mask, + force_com_shifts, + ) = self._preprocess_datacube_and_vacuum_probe( + self._datacube, + diffraction_intensities_shape=self._diffraction_intensities_shape, + reshaping_method=self._reshaping_method, + probe_roi_shape=self._probe_roi_shape, + vacuum_probe_intensity=self._vacuum_probe_intensity, + dp_mask=self._dp_mask, + com_shifts=force_com_shifts, + ) + + self._intensities = self._extract_intensities_and_calibrations_from_datacube( + self._datacube, + require_calibrations=True, + force_scan_sampling=force_scan_sampling, + force_angular_sampling=force_angular_sampling, + force_reciprocal_sampling=force_reciprocal_sampling, + ) + + ( + self._com_measured_x, + self._com_measured_y, + self._com_fitted_x, + self._com_fitted_y, + self._com_normalized_x, + self._com_normalized_y, + ) = self._calculate_intensities_center_of_mass( + self._intensities, + dp_mask=self._dp_mask, + fit_function=fit_function, + com_shifts=force_com_shifts, + ) + + ( + self._rotation_best_rad, + self._rotation_best_transpose, + self._com_x, + self._com_y, + self.com_x, + self.com_y, + ) = self._solve_for_center_of_mass_relative_rotation( + self._com_measured_x, + self._com_measured_y, + self._com_normalized_x, + self._com_normalized_y, + rotation_angles_deg=rotation_angles_deg, + plot_rotation=plot_rotation, + plot_center_of_mass=plot_center_of_mass, + maximize_divergence=maximize_divergence, + force_com_rotation=force_com_rotation, + force_com_transpose=force_com_transpose, + **kwargs, + ) + + ( + self._amplitudes, + self._mean_diffraction_intensity, + ) = self._normalize_diffraction_intensities( + self._intensities, self._com_fitted_x, self._com_fitted_y, crop_patterns + ) + + # explicitly delete namespace + self._num_diffraction_patterns = self._amplitudes.shape[0] + self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) + del self._intensities + + self._positions_px = self._calculate_scan_positions_in_pixels( + self._scan_positions + ) + + # handle semiangle specified in pixels + if self._semiangle_cutoff_pixels: + self._semiangle_cutoff = ( + self._semiangle_cutoff_pixels * self._angular_sampling[0] + ) + + # Object Initialization + if self._object is None: + pad_x = self._object_padding_px[0][1] + pad_y = self._object_padding_px[1][1] + p, q = np.round(np.max(self._positions_px, axis=0)) + p = np.max([np.round(p + pad_x), self._region_of_interest_shape[0]]).astype( + "int" + ) + q = np.max([np.round(q + pad_y), self._region_of_interest_shape[1]]).astype( + "int" + ) + if self._object_type == "potential": + self._object = xp.zeros((self._num_slices, p, q), dtype=xp.float32) + elif self._object_type == "complex": + self._object = xp.ones((self._num_slices, p, q), dtype=xp.complex64) + else: + if self._object_type == "potential": + self._object = xp.asarray(self._object, dtype=xp.float32) + elif self._object_type == "complex": + self._object = xp.asarray(self._object, dtype=xp.complex64) + + self._object_initial = self._object.copy() + self._object_type_initial = self._object_type + self._object_shape = self._object.shape[-2:] + + self._positions_px = xp.asarray(self._positions_px, dtype=xp.float32) + self._positions_px_com = xp.mean(self._positions_px, axis=0) + self._positions_px -= self._positions_px_com - xp.array(self._object_shape) / 2 + self._positions_px_com = xp.mean(self._positions_px, axis=0) + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + + self._positions_px_initial = self._positions_px.copy() + self._positions_initial = self._positions_px_initial.copy() + self._positions_initial[:, 0] *= self.sampling[0] + self._positions_initial[:, 1] *= self.sampling[1] + + # Vectorized Patches + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + + # Probe Initialization + if self._probe is None or isinstance(self._probe, ComplexProbe): + if self._probe is None: + if self._vacuum_probe_intensity is not None: + self._semiangle_cutoff = np.inf + self._vacuum_probe_intensity = xp.asarray( + self._vacuum_probe_intensity, dtype=xp.float32 + ) + probe_x0, probe_y0 = get_CoM( + self._vacuum_probe_intensity, + device=self._device, + ) + self._vacuum_probe_intensity = get_shifted_ar( + self._vacuum_probe_intensity, + -probe_x0, + -probe_y0, + bilinear=True, + device=self._device, + ) + if crop_patterns: + self._vacuum_probe_intensity = self._vacuum_probe_intensity[ + self._crop_mask + ].reshape(self._region_of_interest_shape) + _probe = ( + ComplexProbe( + gpts=self._region_of_interest_shape, + sampling=self.sampling, + energy=self._energy, + semiangle_cutoff=self._semiangle_cutoff, + rolloff=self._rolloff, + vacuum_probe_intensity=self._vacuum_probe_intensity, + parameters=self._polar_parameters, + device=self._device, + ) + .build() + ._array + ) + + else: + if self._probe._gpts != self._region_of_interest_shape: + raise ValueError() + if hasattr(self._probe, "_array"): + _probe = self._probe._array + else: + self._probe._xp = xp + _probe = self._probe.build()._array + + self._probe = xp.zeros( + (self._num_probes,) + tuple(self._region_of_interest_shape), + dtype=xp.complex64, + ) + sx, sy = self._region_of_interest_shape + self._probe[0] = _probe + + # Randomly shift phase of other probes + for i_probe in range(1, self._num_probes): + shift_x = xp.exp( + -2j * np.pi * (xp.random.rand() - 0.5) * xp.fft.fftfreq(sx) + ) + shift_y = xp.exp( + -2j * np.pi * (xp.random.rand() - 0.5) * xp.fft.fftfreq(sy) + ) + self._probe[i_probe] = ( + self._probe[i_probe - 1] * shift_x[:, None] * shift_y[None] + ) + + # Normalize probe to match mean diffraction intensity + probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe[0])) ** 2) + self._probe *= xp.sqrt(self._mean_diffraction_intensity / probe_intensity) + + else: + self._probe = xp.asarray(self._probe, dtype=xp.complex64) + + self._probe_initial = self._probe.copy() + self._probe_initial_aperture = None # Doesn't really make sense for mixed-state + + self._known_aberrations_array = ComplexProbe( + energy=self._energy, + gpts=self._region_of_interest_shape, + sampling=self.sampling, + parameters=self._polar_parameters, + device=self._device, + )._evaluate_ctf() + + # Precomputed propagator arrays + self._propagator_arrays = self._precompute_propagator_arrays( + self._region_of_interest_shape, + self.sampling, + self._energy, + self._slice_thicknesses, + ) + + # overlaps + shifted_probes = fft_shift(self._probe[0], self._positions_px_fractional, xp) + probe_intensities = xp.abs(shifted_probes) ** 2 + probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities) + probe_overlap = self._gaussian_filter(probe_overlap, 1.0) + + if object_fov_mask is None: + self._object_fov_mask = asnumpy(probe_overlap > 0.25 * probe_overlap.max()) + else: + self._object_fov_mask = np.asarray(object_fov_mask) + self._object_fov_mask_inverse = np.invert(self._object_fov_mask) + + if plot_probe_overlaps: + figsize = kwargs.pop("figsize", (13, 4)) + chroma_boost = kwargs.pop("chroma_boost", 1) + + # initial probe + complex_probe_rgb = Complex2RGB( + self.probe_centered[0], + power=2, + chroma_boost=chroma_boost, + ) + + # propagated + propagated_probe = self._probe[0].copy() + + for s in range(self._num_slices - 1): + propagated_probe = self._propagate_array( + propagated_probe, self._propagator_arrays[s] + ) + complex_propagated_rgb = Complex2RGB( + asnumpy(self._return_centered_probe(propagated_probe)), + power=2, + chroma_boost=chroma_boost, + ) + + extent = [ + 0, + self.sampling[1] * self._object_shape[1], + self.sampling[0] * self._object_shape[0], + 0, + ] + + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=figsize) + + ax1.imshow( + complex_probe_rgb, + extent=probe_extent, + ) + + divider = make_axes_locatable(ax1) + cax1 = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg( + cax1, + chroma_boost=chroma_boost, + ) + ax1.set_ylabel("x [A]") + ax1.set_xlabel("y [A]") + ax1.set_title("Initial probe[0] intensity") + + ax2.imshow( + complex_propagated_rgb, + extent=probe_extent, + ) + + divider = make_axes_locatable(ax2) + cax2 = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(cax2, chroma_boost=chroma_boost) + ax2.set_ylabel("x [A]") + ax2.set_xlabel("y [A]") + ax2.set_title("Propagated probe[0] intensity") + + ax3.imshow( + asnumpy(probe_overlap), + extent=extent, + cmap="Greys_r", + ) + ax3.scatter( + self.positions[:, 1], + self.positions[:, 0], + s=2.5, + color=(1, 0, 0, 1), + ) + ax3.set_ylabel("x [A]") + ax3.set_xlabel("y [A]") + ax3.set_xlim((extent[0], extent[1])) + ax3.set_ylim((extent[2], extent[3])) + ax3.set_title("Object field of view") + + fig.tight_layout() + + self._preprocessed = True + + if self._device == "gpu": + xp._default_memory_pool.free_all_blocks() + xp.clear_memo() + + return self + + def _overlap_projection(self, current_object, current_probe): + """ + Ptychographic overlap projection method. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + + Returns + -------- + propagated_probes: np.ndarray + Shifted probes at each layer + object_patches: np.ndarray + Patched object view + transmitted_probes: np.ndarray + Transmitted probes after N-1 propagations and N transmissions + """ + + xp = self._xp + + if self._object_type == "potential": + complex_object = xp.exp(1j * current_object) + else: + complex_object = current_object + + object_patches = complex_object[ + :, + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ] + + num_probe_positions = object_patches.shape[1] + + propagated_shape = ( + self._num_slices, + num_probe_positions, + self._num_probes, + self._region_of_interest_shape[0], + self._region_of_interest_shape[1], + ) + propagated_probes = xp.empty(propagated_shape, dtype=object_patches.dtype) + propagated_probes[0] = fft_shift( + current_probe, self._positions_px_fractional, xp + ) + + for s in range(self._num_slices): + # transmit + transmitted_probes = ( + xp.expand_dims(object_patches[s], axis=1) * propagated_probes[s] + ) + + # propagate + if s + 1 < self._num_slices: + propagated_probes[s + 1] = self._propagate_array( + transmitted_probes, self._propagator_arrays[s] + ) + + return propagated_probes, object_patches, transmitted_probes + + def _gradient_descent_fourier_projection(self, amplitudes, transmitted_probes): + """ + Ptychographic fourier projection method for GD method. + + Parameters + -------- + amplitudes: np.ndarray + Normalized measured amplitudes + transmitted_probes: np.ndarray + Transmitted probes after N-1 propagations and N transmissions + + Returns + -------- + exit_waves:np.ndarray + Exit wave difference + error: float + Reconstruction error + """ + + xp = self._xp + fourier_exit_waves = xp.fft.fft2(transmitted_probes) + intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_exit_waves) ** 2, axis=1)) + error = xp.sum(xp.abs(amplitudes - intensity_norm) ** 2) + + intensity_norm[intensity_norm == 0.0] = np.inf + amplitude_modification = amplitudes / intensity_norm + + fourier_modified_overlap = amplitude_modification[:, None] * fourier_exit_waves + modified_exit_wave = xp.fft.ifft2(fourier_modified_overlap) + + exit_waves = modified_exit_wave - transmitted_probes + + return exit_waves, error + + def _projection_sets_fourier_projection( + self, + amplitudes, + transmitted_probes, + exit_waves, + projection_a, + projection_b, + projection_c, + ): + """ + Ptychographic fourier projection method for DM_AP and RAAR methods. + Generalized projection using three parameters: a,b,c + + DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha + DM: DM_AP(1.0), AP: DM_AP(0.0) + + RAAR(\\beta) : a = 1-2\\beta, b = \\beta, c = 2 + DM : RAAR(1.0) + + RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2 + DM: RRR(1.0) + + SUPERFLIP : a = 0, b = 1, c = 2 + + Parameters + -------- + amplitudes: np.ndarray + Normalized measured amplitudes + transmitted_probes: np.ndarray + Transmitted probes after N-1 propagations and N transmissions + exit_waves: np.ndarray + previously estimated exit waves + projection_a: float + projection_b: float + projection_c: float + + Returns + -------- + exit_waves:np.ndarray + Updated exit wave difference + error: float + Reconstruction error + """ + + xp = self._xp + projection_x = 1 - projection_a - projection_b + projection_y = 1 - projection_c + + if exit_waves is None: + exit_waves = transmitted_probes.copy() + + fourier_exit_waves = xp.fft.fft2(transmitted_probes) + intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_exit_waves) ** 2, axis=1)) + error = xp.sum(xp.abs(amplitudes - intensity_norm) ** 2) + + factor_to_be_projected = ( + projection_c * transmitted_probes + projection_y * exit_waves + ) + fourier_projected_factor = xp.fft.fft2(factor_to_be_projected) + + intensity_norm_projected = xp.sqrt( + xp.sum(xp.abs(fourier_projected_factor) ** 2, axis=1) + ) + intensity_norm_projected[intensity_norm_projected == 0.0] = np.inf + + amplitude_modification = amplitudes / intensity_norm_projected + fourier_projected_factor *= amplitude_modification[:, None] + + projected_factor = xp.fft.ifft2(fourier_projected_factor) + + exit_waves = ( + projection_x * exit_waves + + projection_a * transmitted_probes + + projection_b * projected_factor + ) + + return exit_waves, error + + def _forward( + self, + current_object, + current_probe, + amplitudes, + exit_waves, + use_projection_scheme, + projection_a, + projection_b, + projection_c, + ): + """ + Ptychographic forward operator. + Calls _overlap_projection() and the appropriate _fourier_projection(). + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + amplitudes: np.ndarray + Normalized measured amplitudes + exit_waves: np.ndarray + previously estimated exit waves + use_projection_scheme: bool, + If True, use generalized projection update + projection_a: float + projection_b: float + projection_c: float + + Returns + -------- + propagated_probes: np.ndarray + Shifted probes at each layer + object_patches: np.ndarray + Patched object view + transmitted_probes: np.ndarray + Transmitted probes after N-1 propagations and N transmissions + exit_waves:np.ndarray + Updated exit_waves + error: float + Reconstruction error + """ + + ( + propagated_probes, + object_patches, + transmitted_probes, + ) = self._overlap_projection(current_object, current_probe) + + if use_projection_scheme: + exit_waves, error = self._projection_sets_fourier_projection( + amplitudes, + transmitted_probes, + exit_waves, + projection_a, + projection_b, + projection_c, + ) + + else: + exit_waves, error = self._gradient_descent_fourier_projection( + amplitudes, transmitted_probes + ) + + return propagated_probes, object_patches, transmitted_probes, exit_waves, error + + def _gradient_descent_adjoint( + self, + current_object, + current_probe, + object_patches, + propagated_probes, + exit_waves, + step_size, + normalization_min, + fix_probe, + ): + """ + Ptychographic adjoint operator for GD method. + Computes object and probe update steps. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + propagated_probes: np.ndarray + Shifted probes at each layer + exit_waves:np.ndarray + Updated exit_waves + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + xp = self._xp + + for s in reversed(range(self._num_slices)): + probe = propagated_probes[s] + obj = object_patches[s] + + # object-update + probe_normalization = xp.zeros_like(current_object[s]) + object_update = xp.zeros_like(current_object[s]) + + for i_probe in range(self._num_probes): + probe_normalization += self._sum_overlapping_patches_bincounts( + xp.abs(probe[:, i_probe]) ** 2 + ) + + if self._object_type == "potential": + object_update += ( + step_size + * self._sum_overlapping_patches_bincounts( + xp.real( + -1j + * xp.conj(obj) + * xp.conj(probe[:, i_probe]) + * exit_waves[:, i_probe] + ) + ) + ) + else: + object_update += ( + step_size + * self._sum_overlapping_patches_bincounts( + xp.conj(probe[:, i_probe]) * exit_waves[:, i_probe] + ) + ) + + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + current_object[s] += object_update * probe_normalization + + # back-transmit + exit_waves *= xp.expand_dims(xp.conj(obj), axis=1) # / xp.abs(obj) ** 2 + + if s > 0: + # back-propagate + exit_waves = self._propagate_array( + exit_waves, xp.conj(self._propagator_arrays[s - 1]) + ) + elif not fix_probe: + # probe-update + object_normalization = xp.sum( + (xp.abs(obj) ** 2), + axis=0, + ) + object_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * object_normalization) ** 2 + + (normalization_min * xp.max(object_normalization)) ** 2 + ) + + current_probe += ( + step_size + * xp.sum( + exit_waves, + axis=0, + ) + * object_normalization[None] + ) + + return current_object, current_probe + + def _projection_sets_adjoint( + self, + current_object, + current_probe, + object_patches, + propagated_probes, + exit_waves, + normalization_min, + fix_probe, + ): + """ + Ptychographic adjoint operator for DM_AP and RAAR methods. + Computes object and probe update steps. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + propagated_probes: np.ndarray + Shifted probes at each layer + exit_waves:np.ndarray + Updated exit_waves + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + xp = self._xp + + # careful not to modify exit_waves in-place for projection set methods + exit_waves_copy = exit_waves.copy() + for s in reversed(range(self._num_slices)): + probe = propagated_probes[s] + obj = object_patches[s] + + # object-update + probe_normalization = xp.zeros_like(current_object[s]) + object_update = xp.zeros_like(current_object[s]) + + for i_probe in range(self._num_probes): + probe_normalization += self._sum_overlapping_patches_bincounts( + xp.abs(probe[:, i_probe]) ** 2 + ) + + if self._object_type == "potential": + object_update += self._sum_overlapping_patches_bincounts( + xp.real( + -1j + * xp.conj(obj) + * xp.conj(probe[:, i_probe]) + * exit_waves_copy[:, i_probe] + ) + ) + else: + object_update += self._sum_overlapping_patches_bincounts( + xp.conj(probe[:, i_probe]) * exit_waves_copy[:, i_probe] + ) + + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + current_object[s] = object_update * probe_normalization + + # back-transmit + exit_waves_copy *= xp.expand_dims( + xp.conj(obj), axis=1 + ) # / xp.abs(obj) ** 2 + + if s > 0: + # back-propagate + exit_waves_copy = self._propagate_array( + exit_waves_copy, xp.conj(self._propagator_arrays[s - 1]) + ) + + elif not fix_probe: + # probe-update + object_normalization = xp.sum( + (xp.abs(obj) ** 2), + axis=0, + ) + object_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * object_normalization) ** 2 + + (normalization_min * xp.max(object_normalization)) ** 2 + ) + + current_probe = ( + xp.sum( + exit_waves_copy, + axis=0, + ) + * object_normalization[None] + ) + + return current_object, current_probe + + def _adjoint( + self, + current_object, + current_probe, + object_patches, + propagated_probes, + exit_waves, + use_projection_scheme: bool, + step_size: float, + normalization_min: float, + fix_probe: bool, + ): + """ + Ptychographic adjoint operator for GD method. + Computes object and probe update steps. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + propagated_probes: np.ndarray + Shifted probes at each layer + exit_waves:np.ndarray + Updated exit_waves + use_projection_scheme: bool, + If True, use generalized projection update + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + + if use_projection_scheme: + current_object, current_probe = self._projection_sets_adjoint( + current_object, + current_probe, + object_patches, + propagated_probes, + exit_waves, + normalization_min, + fix_probe, + ) + else: + current_object, current_probe = self._gradient_descent_adjoint( + current_object, + current_probe, + object_patches, + propagated_probes, + exit_waves, + step_size, + normalization_min, + fix_probe, + ) + + return current_object, current_probe + + def _position_correction( + self, + current_object, + current_probe, + transmitted_probes, + amplitudes, + current_positions, + positions_step_size, + constrain_position_distance, + ): + """ + Position correction using estimated intensity gradient. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe:np.ndarray + fractionally-shifted probes + transmitted_probes: np.ndarray + Transmitted probes after N-1 propagations and N transmissions + amplitudes: np.ndarray + Measured amplitudes + current_positions: np.ndarray + Current positions estimate + positions_step_size: float + Positions step size + constrain_position_distance: float + Distance to constrain position correction within original + field of view in A + + Returns + -------- + updated_positions: np.ndarray + Updated positions estimate + """ + + xp = self._xp + + # Intensity gradient + exit_waves_fft = xp.fft.fft2(transmitted_probes) + exit_waves_fft_conj = xp.conj(exit_waves_fft) + estimated_intensity = xp.abs(exit_waves_fft) ** 2 + measured_intensity = amplitudes**2 + + flat_shape = (transmitted_probes.shape[0], -1) + difference_intensity = (measured_intensity - estimated_intensity).reshape( + flat_shape + ) + + # Computing perturbed exit waves one at a time to save on memory + + if self._object_type == "potential": + complex_object = xp.exp(1j * current_object) + else: + complex_object = current_object + + # dx + obj_rolled_patches = complex_object[ + :, + (self._vectorized_patch_indices_row + 1) % self._object_shape[0], + self._vectorized_patch_indices_col, + ] + + propagated_probes_perturbed = xp.empty_like(obj_rolled_patches) + propagated_probes_perturbed[0] = fft_shift( + current_probe, self._positions_px_fractional, xp + ) + + for s in range(self._num_slices): + # transmit + transmitted_probes_perturbed = ( + obj_rolled_patches[s] * propagated_probes_perturbed[s] + ) + + # propagate + if s + 1 < self._num_slices: + propagated_probes_perturbed[s + 1] = self._propagate_array( + transmitted_probes_perturbed, self._propagator_arrays[s] + ) + + exit_waves_dx_fft = exit_waves_fft - xp.fft.fft2(transmitted_probes_perturbed) + + # dy + obj_rolled_patches = complex_object[ + :, + self._vectorized_patch_indices_row, + (self._vectorized_patch_indices_col + 1) % self._object_shape[1], + ] + + propagated_probes_perturbed = xp.empty_like(obj_rolled_patches) + propagated_probes_perturbed[0] = fft_shift( + current_probe, self._positions_px_fractional, xp + ) + + for s in range(self._num_slices): + # transmit + transmitted_probes_perturbed = ( + obj_rolled_patches[s] * propagated_probes_perturbed[s] + ) + + # propagate + if s + 1 < self._num_slices: + propagated_probes_perturbed[s + 1] = self._propagate_array( + transmitted_probes_perturbed, self._propagator_arrays[s] + ) + + exit_waves_dy_fft = exit_waves_fft - xp.fft.fft2(transmitted_probes_perturbed) + + partial_intensity_dx = 2 * xp.real( + exit_waves_dx_fft * exit_waves_fft_conj + ).reshape(flat_shape) + partial_intensity_dy = 2 * xp.real( + exit_waves_dy_fft * exit_waves_fft_conj + ).reshape(flat_shape) + + coefficients_matrix = xp.dstack((partial_intensity_dx, partial_intensity_dy)) + + # positions_update = xp.einsum( + # "idk,ik->id", xp.linalg.pinv(coefficients_matrix), difference_intensity + # ) + + coefficients_matrix_T = coefficients_matrix.conj().swapaxes(-1, -2) + positions_update = ( + xp.linalg.inv(coefficients_matrix_T @ coefficients_matrix) + @ coefficients_matrix_T + @ difference_intensity[..., None] + ) + + if constrain_position_distance is not None: + constrain_position_distance /= xp.sqrt( + self.sampling[0] ** 2 + self.sampling[1] ** 2 + ) + x1 = (current_positions - positions_step_size * positions_update[..., 0])[ + :, 0 + ] + y1 = (current_positions - positions_step_size * positions_update[..., 0])[ + :, 1 + ] + x0 = self._positions_px_initial[:, 0] + y0 = self._positions_px_initial[:, 1] + if self._rotation_best_transpose: + x0, y0 = xp.array([y0, x0]) + x1, y1 = xp.array([y1, x1]) + + if self._rotation_best_rad is not None: + rotation_angle = self._rotation_best_rad + x0, y0 = x0 * xp.cos(-rotation_angle) + y0 * xp.sin( + -rotation_angle + ), -x0 * xp.sin(-rotation_angle) + y0 * xp.cos(-rotation_angle) + x1, y1 = x1 * xp.cos(-rotation_angle) + y1 * xp.sin( + -rotation_angle + ), -x1 * xp.sin(-rotation_angle) + y1 * xp.cos(-rotation_angle) + + outlier_ind = (x1 > (xp.max(x0) + constrain_position_distance)) + ( + x1 < (xp.min(x0) - constrain_position_distance) + ) + (y1 > (xp.max(y0) + constrain_position_distance)) + ( + y1 < (xp.min(y0) - constrain_position_distance) + ) > 0 + + positions_update[..., 0][outlier_ind] = 0 + + current_positions -= positions_step_size * positions_update[..., 0] + + return current_positions + + def _probe_center_of_mass_constraint(self, current_probe): + """ + Ptychographic center of mass constraint. + Used for centering corner-centered probe intensity. + + Parameters + -------- + current_probe: np.ndarray + Current probe estimate + + Returns + -------- + constrained_probe: np.ndarray + Constrained probe estimate + """ + xp = self._xp + probe_intensity = xp.abs(current_probe[0]) ** 2 + + probe_x0, probe_y0 = get_CoM( + probe_intensity, device=self._device, corner_centered=True + ) + shifted_probe = fft_shift(current_probe, -xp.array([probe_x0, probe_y0]), xp) + + return shifted_probe + + def _probe_orthogonalization_constraint(self, current_probe): + """ + Ptychographic probe-orthogonalization constraint. + Used to ensure mixed states are orthogonal to each other. + Adapted from https://github.com/AdvancedPhotonSource/tike/blob/main/src/tike/ptycho/probe.py#L690 + + Parameters + -------- + current_probe: np.ndarray + Current probe estimate + + Returns + -------- + constrained_probe: np.ndarray + Orthogonalized probe estimate + """ + xp = self._xp + n_probes = self._num_probes + + # compute upper half of P* @ P + pairwise_dot_product = xp.empty((n_probes, n_probes), dtype=current_probe.dtype) + + for i in range(n_probes): + for j in range(i, n_probes): + pairwise_dot_product[i, j] = xp.sum( + current_probe[i].conj() * current_probe[j] + ) + + # compute eigenvectors (effectively cheaper way of computing V* from SVD) + _, evecs = xp.linalg.eigh(pairwise_dot_product, UPLO="U") + current_probe = xp.tensordot(evecs.T, current_probe, axes=1) + + # sort by real-space intensity + intensities = xp.sum(xp.abs(current_probe) ** 2, axis=(-2, -1)) + intensities_order = xp.argsort(intensities, axis=None)[::-1] + return current_probe[intensities_order] + + def _object_butterworth_constraint( + self, current_object, q_lowpass, q_highpass, butterworth_order + ): + """ + 2D Butterworth filter + Used for low/high-pass filtering object. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + q_lowpass: float + Cut-off frequency in A^-1 for low-pass butterworth filter + q_highpass: float + Cut-off frequency in A^-1 for high-pass butterworth filter + butterworth_order: float + Butterworth filter order. Smaller gives a smoother filter + + Returns + -------- + constrained_object: np.ndarray + Constrained object estimate + """ + xp = self._xp + qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0]) + qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1]) + qya, qxa = xp.meshgrid(qy, qx) + qra = xp.sqrt(qxa**2 + qya**2) + + env = xp.ones_like(qra) + if q_highpass: + env *= 1 - 1 / (1 + (qra / q_highpass) ** (2 * butterworth_order)) + if q_lowpass: + env *= 1 / (1 + (qra / q_lowpass) ** (2 * butterworth_order)) + + current_object_mean = xp.mean(current_object) + current_object -= current_object_mean + current_object = xp.fft.ifft2(xp.fft.fft2(current_object) * env[None]) + current_object += current_object_mean + + if self._object_type == "potential": + current_object = xp.real(current_object) + + return current_object + + def _object_kz_regularization_constraint( + self, current_object, kz_regularization_gamma + ): + """ + Arctan regularization filter + + Parameters + -------- + current_object: np.ndarray + Current object estimate + kz_regularization_gamma: float + Slice regularization strength + + Returns + -------- + constrained_object: np.ndarray + Constrained object estimate + """ + xp = self._xp + + current_object = xp.pad( + current_object, pad_width=((1, 0), (0, 0), (0, 0)), mode="constant" + ) + + qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0]) + qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1]) + qz = xp.fft.fftfreq(current_object.shape[0], self._slice_thicknesses[0]) + + kz_regularization_gamma *= self._slice_thicknesses[0] / self.sampling[0] + + qza, qxa, qya = xp.meshgrid(qz, qx, qy, indexing="ij") + qz2 = qza**2 * kz_regularization_gamma**2 + qr2 = qxa**2 + qya**2 + + w = 1 - 2 / np.pi * xp.arctan2(qz2, qr2) + + current_object = xp.fft.ifftn(xp.fft.fftn(current_object) * w) + current_object = current_object[1:] + + if self._object_type == "potential": + current_object = xp.real(current_object) + + return current_object + + def _object_identical_slices_constraint(self, current_object): + """ + Strong regularization forcing all slices to be identical + + Parameters + -------- + current_object: np.ndarray + Current object estimate + + Returns + -------- + constrained_object: np.ndarray + Constrained object estimate + """ + object_mean = current_object.mean(0, keepdims=True) + current_object[:] = object_mean + + return current_object + + def _object_denoise_tv_pylops(self, current_object, weights, iterations): + """ + Performs second order TV denoising along x and y + + Parameters + ---------- + current_object: np.ndarray + Current object estimate + weights : [float, float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + iterations: float + Number of iterations to run in denoising algorithm. + `niter_out` in pylops + + Returns + ------- + constrained_object: np.ndarray + Constrained object estimate + + """ + xp = self._xp + + if xp.iscomplexobj(current_object): + current_object_tv = current_object + warnings.warn( + ("TV denoising is currently only supported for potential objects."), + UserWarning, + ) + + else: + # zero pad at top and bottom slice + pad_width = ((1, 1), (0, 0), (0, 0)) + current_object = xp.pad( + current_object, pad_width=pad_width, mode="constant" + ) + + # run tv denoising + nz, nx, ny = current_object.shape + niter_out = iterations + niter_in = 1 + Iop = pylops.Identity(nx * ny * nz) + + if weights[0] == 0: + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[1]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + elif weights[1] == 0: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + l1_regs = [z_gradient] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[0]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + else: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [z_gradient, xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=weights, + tol=1e-4, + tau=1.0, + show=False, + )[0] + + # remove padding + current_object_tv = current_object_tv.reshape(current_object.shape)[1:-1] + + return current_object_tv + + def _constraints( + self, + current_object, + current_probe, + current_positions, + fix_com, + fit_probe_aberrations, + fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order, + constrain_probe_amplitude, + constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude, + constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity, + fix_probe_aperture, + initial_probe_aperture, + fix_positions, + global_affine_transformation, + gaussian_filter, + gaussian_filter_sigma, + butterworth_filter, + q_lowpass, + q_highpass, + butterworth_order, + kz_regularization_filter, + kz_regularization_gamma, + identical_slices, + object_positivity, + shrinkage_rad, + object_mask, + pure_phase_object, + tv_denoise_chambolle, + tv_denoise_weight_chambolle, + tv_denoise_pad_chambolle, + tv_denoise, + tv_denoise_weights, + tv_denoise_inner_iter, + orthogonalize_probe, + ): + """ + Ptychographic constraints operator. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + current_positions: np.ndarray + Current positions estimate + fix_com: bool + If True, probe CoM is fixed to the center + fit_probe_aberrations: bool + If True, fits the probe aberrations to a low-order expansion + fit_probe_aberrations_max_angular_order: bool + Max angular order of probe aberrations basis functions + fit_probe_aberrations_max_radial_order: bool + Max radial order of probe aberrations basis functions + constrain_probe_amplitude: bool + If True, probe amplitude is constrained by top hat function + constrain_probe_amplitude_relative_radius: float + Relative location of top-hat inflection point, between 0 and 0.5 + constrain_probe_amplitude_relative_width: float + Relative width of top-hat sigmoid, between 0 and 0.5 + constrain_probe_fourier_amplitude: bool + If True, probe aperture is constrained by fitting a sigmoid for each angular frequency. + constrain_probe_fourier_amplitude_max_width_pixels: float + Maximum pixel width of fitted sigmoid functions. + constrain_probe_fourier_amplitude_constant_intensity: bool + If True, the probe aperture is additionally constrained to a constant intensity. + fix_probe_aperture: bool + If True, probe Fourier amplitude is replaced by initial_probe_aperture + initial_probe_aperture: np.ndarray + Initial probe aperture to use in replacing probe Fourier amplitude + fix_positions: bool + If True, positions are not updated + gaussian_filter: bool + If True, applies real-space gaussian filter in A + gaussian_filter_sigma: float + Standard deviation of gaussian kernel + butterworth_filter: bool + If True, applies fourier-space butterworth filter + q_lowpass: float + Cut-off frequency in A^-1 for low-pass butterworth filter + q_highpass: float + Cut-off frequency in A^-1 for high-pass butterworth filter + butterworth_order: float + Butterworth filter order. Smaller gives a smoother filter + kz_regularization_filter: bool + If True, applies fourier-space arctan regularization filter + kz_regularization_gamma: float + Slice regularization strength + identical_slices: bool + If True, forces all object slices to be identical + object_positivity: bool + If True, forces object to be positive + shrinkage_rad: float + Phase shift in radians to be subtracted from the potential at each iteration + object_mask: np.ndarray (boolean) + If not None, used to calculate additional shrinkage using masked-mean of object + pure_phase_object: bool + If True, object amplitude is set to unity + tv_denoise_chambolle: bool + If True, performs TV denoising along z + tv_denoise_weight_chambolle: float + weight of tv denoising constraint + tv_denoise_pad_chambolle: bool + if True, pads object at top and bottom with zeros before applying denoising + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising + orthogonalize_probe: bool + If True, probe will be orthogonalized + + Returns + -------- + constrained_object: np.ndarray + Constrained object estimate + constrained_probe: np.ndarray + Constrained probe estimate + constrained_positions: np.ndarray + Constrained positions estimate + """ + + if gaussian_filter: + current_object = self._object_gaussian_constraint( + current_object, gaussian_filter_sigma, pure_phase_object + ) + + if butterworth_filter: + current_object = self._object_butterworth_constraint( + current_object, + q_lowpass, + q_highpass, + butterworth_order, + ) + + if identical_slices: + current_object = self._object_identical_slices_constraint(current_object) + elif kz_regularization_filter: + current_object = self._object_kz_regularization_constraint( + current_object, kz_regularization_gamma + ) + elif tv_denoise: + current_object = self._object_denoise_tv_pylops( + current_object, + tv_denoise_weights, + tv_denoise_inner_iter, + ) + elif tv_denoise_chambolle: + current_object = self._object_denoise_tv_chambolle( + current_object, + tv_denoise_weight_chambolle, + axis=0, + pad_object=tv_denoise_pad_chambolle, + ) + + if shrinkage_rad > 0.0 or object_mask is not None: + current_object = self._object_shrinkage_constraint( + current_object, + shrinkage_rad, + object_mask, + ) + + if self._object_type == "complex": + current_object = self._object_threshold_constraint( + current_object, pure_phase_object + ) + elif object_positivity: + current_object = self._object_positivity_constraint(current_object) + + if fix_com: + current_probe = self._probe_center_of_mass_constraint(current_probe) + + # These constraints don't _really_ make sense for mixed-state + if fix_probe_aperture: + raise NotImplementedError() + elif constrain_probe_fourier_amplitude: + raise NotImplementedError() + if fit_probe_aberrations: + raise NotImplementedError() + if constrain_probe_amplitude: + raise NotImplementedError() + + if orthogonalize_probe: + current_probe = self._probe_orthogonalization_constraint(current_probe) + + if not fix_positions: + current_positions = self._positions_center_of_mass_constraint( + current_positions + ) + + if global_affine_transformation: + current_positions = self._positions_affine_transformation_constraint( + self._positions_px_initial, current_positions + ) + + return current_object, current_probe, current_positions + + def reconstruct( + self, + max_iter: int = 64, + reconstruction_method: str = "gradient-descent", + reconstruction_parameter: float = 1.0, + reconstruction_parameter_a: float = None, + reconstruction_parameter_b: float = None, + reconstruction_parameter_c: float = None, + max_batch_size: int = None, + seed_random: int = None, + step_size: float = 0.5, + normalization_min: float = 1, + positions_step_size: float = 0.9, + fix_com: bool = True, + orthogonalize_probe: bool = True, + fix_probe_iter: int = 0, + fix_probe_aperture_iter: int = 0, + constrain_probe_amplitude_iter: int = 0, + constrain_probe_amplitude_relative_radius: float = 0.5, + constrain_probe_amplitude_relative_width: float = 0.05, + constrain_probe_fourier_amplitude_iter: int = 0, + constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, + constrain_probe_fourier_amplitude_constant_intensity: bool = False, + fix_positions_iter: int = np.inf, + constrain_position_distance: float = None, + global_affine_transformation: bool = True, + gaussian_filter_sigma: float = None, + gaussian_filter_iter: int = np.inf, + fit_probe_aberrations_iter: int = 0, + fit_probe_aberrations_max_angular_order: int = 4, + fit_probe_aberrations_max_radial_order: int = 4, + butterworth_filter_iter: int = np.inf, + q_lowpass: float = None, + q_highpass: float = None, + butterworth_order: float = 2, + kz_regularization_filter_iter: int = np.inf, + kz_regularization_gamma: Union[float, np.ndarray] = None, + identical_slices_iter: int = 0, + object_positivity: bool = True, + shrinkage_rad: float = 0.0, + fix_potential_baseline: bool = True, + pure_phase_object_iter: int = 0, + tv_denoise_iter_chambolle=np.inf, + tv_denoise_weight_chambolle=None, + tv_denoise_pad_chambolle=True, + tv_denoise_iter=np.inf, + tv_denoise_weights=None, + tv_denoise_inner_iter=40, + switch_object_iter: int = np.inf, + store_iterations: bool = False, + progress_bar: bool = True, + reset: bool = None, + ): + """ + Ptychographic reconstruction main method. + + Parameters + -------- + max_iter: int, optional + Maximum number of iterations to run + reconstruction_method: str, optional + Specifies which reconstruction algorithm to use, one of: + "generalized-projections", + "DM_AP" (or "difference-map_alternating-projections"), + "RAAR" (or "relaxed-averaged-alternating-reflections"), + "RRR" (or "relax-reflect-reflect"), + "SUPERFLIP" (or "charge-flipping"), or + "GD" (or "gradient_descent") + reconstruction_parameter: float, optional + Reconstruction parameter for various reconstruction methods above. + reconstruction_parameter_a: float, optional + Reconstruction parameter a for reconstruction_method='generalized-projections'. + reconstruction_parameter_b: float, optional + Reconstruction parameter b for reconstruction_method='generalized-projections'. + reconstruction_parameter_c: float, optional + Reconstruction parameter c for reconstruction_method='generalized-projections'. + max_batch_size: int, optional + Max number of probes to update at once + seed_random: int, optional + Seeds the random number generator, only applicable when max_batch_size is not None + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + positions_step_size: float, optional + Positions update step size + fix_com: bool, optional + If True, fixes center of mass of probe + fix_probe_iter: int, optional + Number of iterations to run with a fixed probe before updating probe estimate + fix_probe_aperture_iter: int, optional + Number of iterations to run with a fixed probe Fourier amplitude before updating probe estimate + constrain_probe_amplitude_iter: int, optional + Number of iterations to run while constraining the real-space probe with a top-hat support. + constrain_probe_amplitude_relative_radius: float + Relative location of top-hat inflection point, between 0 and 0.5 + constrain_probe_amplitude_relative_width: float + Relative width of top-hat sigmoid, between 0 and 0.5 + constrain_probe_fourier_amplitude_iter: int, optional + Number of iterations to run while constraining the Fourier-space probe by fitting a sigmoid for each angular frequency. + constrain_probe_fourier_amplitude_max_width_pixels: float + Maximum pixel width of fitted sigmoid functions. + constrain_probe_fourier_amplitude_constant_intensity: bool + If True, the probe aperture is additionally constrained to a constant intensity. + fix_positions_iter: int, optional + Number of iterations to run with fixed positions before updating positions estimate + global_affine_transformation: bool, optional + If True, positions are assumed to be a global affine transform from initial scan + gaussian_filter_sigma: float, optional + Standard deviation of gaussian kernel in A + gaussian_filter_iter: int, optional + Number of iterations to run using object smoothness constraint + fit_probe_aberrations_iter: int, optional + Number of iterations to run while fitting the probe aberrations to a low-order expansion + fit_probe_aberrations_max_angular_order: bool + Max angular order of probe aberrations basis functions + fit_probe_aberrations_max_radial_order: bool + Max radial order of probe aberrations basis functions + butterworth_filter_iter: int, optional + Number of iterations to run using high-pass butteworth filter + q_lowpass: float + Cut-off frequency in A^-1 for low-pass butterworth filter + q_highpass: float + Cut-off frequency in A^-1 for high-pass butterworth filter + butterworth_order: float + Butterworth filter order. Smaller gives a smoother filter + kz_regularization_filter_iter: int, optional + Number of iterations to run using kz regularization filter + kz_regularization_gamma, float, optional + kz regularization strength + identical_slices_iter: int, optional + Number of iterations to run using identical slices + object_positivity: bool, optional + If True, forces object to be positive + shrinkage_rad: float + Phase shift in radians to be subtracted from the potential at each iteration + fix_potential_baseline: bool + If true, the potential mean outside the FOV is forced to zero at each iteration + pure_phase_object_iter: int, optional + Number of iterations where object amplitude is set to unity + tv_denoise_iter_chambolle: bool + Number of iterations with TV denoisining + tv_denoise_weight_chambolle: float + weight of tv denoising constraint + tv_denoise_pad_chambolle: bool + if True, pads object at top and bottom with zeros before applying denoising + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising + switch_object_iter: int, optional + Iteration to switch object type between 'complex' and 'potential' or between + 'potential' and 'complex' + store_iterations: bool, optional + If True, reconstructed objects and probes are stored at each iteration + progress_bar: bool, optional + If True, reconstruction progress is displayed + reset: bool, optional + If True, previous reconstructions are ignored + + Returns + -------- + self: MultislicePtychographicReconstruction + Self to accommodate chaining + """ + asnumpy = self._asnumpy + xp = self._xp + + # Reconstruction method + + if reconstruction_method == "generalized-projections": + if ( + reconstruction_parameter_a is None + or reconstruction_parameter_b is None + or reconstruction_parameter_c is None + ): + raise ValueError( + ( + "reconstruction_parameter_a/b/c must all be specified " + "when using reconstruction_method='generalized-projections'." + ) + ) + + use_projection_scheme = True + projection_a = reconstruction_parameter_a + projection_b = reconstruction_parameter_b + projection_c = reconstruction_parameter_c + step_size = None + elif ( + reconstruction_method == "DM_AP" + or reconstruction_method == "difference-map_alternating-projections" + ): + if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: + raise ValueError("reconstruction_parameter must be between 0-1.") + + use_projection_scheme = True + projection_a = -reconstruction_parameter + projection_b = 1 + projection_c = 1 + reconstruction_parameter + step_size = None + elif ( + reconstruction_method == "RAAR" + or reconstruction_method == "relaxed-averaged-alternating-reflections" + ): + if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: + raise ValueError("reconstruction_parameter must be between 0-1.") + + use_projection_scheme = True + projection_a = 1 - 2 * reconstruction_parameter + projection_b = reconstruction_parameter + projection_c = 2 + step_size = None + elif ( + reconstruction_method == "RRR" + or reconstruction_method == "relax-reflect-reflect" + ): + if reconstruction_parameter < 0.0 or reconstruction_parameter > 2.0: + raise ValueError("reconstruction_parameter must be between 0-2.") + + use_projection_scheme = True + projection_a = -reconstruction_parameter + projection_b = reconstruction_parameter + projection_c = 2 + step_size = None + elif ( + reconstruction_method == "SUPERFLIP" + or reconstruction_method == "charge-flipping" + ): + use_projection_scheme = True + projection_a = 0 + projection_b = 1 + projection_c = 2 + reconstruction_parameter = None + step_size = None + elif ( + reconstruction_method == "GD" or reconstruction_method == "gradient-descent" + ): + use_projection_scheme = False + projection_a = None + projection_b = None + projection_c = None + reconstruction_parameter = None + else: + raise ValueError( + ( + "reconstruction_method must be one of 'generalized-projections', " + "'DM_AP' (or 'difference-map_alternating-projections'), " + "'RAAR' (or 'relaxed-averaged-alternating-reflections'), " + "'RRR' (or 'relax-reflect-reflect'), " + "'SUPERFLIP' (or 'charge-flipping'), " + f"or 'GD' (or 'gradient-descent'), not {reconstruction_method}." + ) + ) + + if self._verbose: + if switch_object_iter > max_iter: + first_line = f"Performing {max_iter} iterations using a {self._object_type} object type, " + else: + switch_object_type = ( + "complex" if self._object_type == "potential" else "potential" + ) + first_line = ( + f"Performing {switch_object_iter} iterations using a {self._object_type} object type and " + f"{max_iter - switch_object_iter} iterations using a {switch_object_type} object type, " + ) + if max_batch_size is not None: + if use_projection_scheme: + raise ValueError( + ( + "Stochastic object/probe updating is inconsistent with 'DM_AP', 'RAAR', 'RRR', and 'SUPERFLIP'. " + "Use reconstruction_method='GD' or set max_batch_size=None." + ) + ) + else: + print( + ( + first_line + f"with the {reconstruction_method} algorithm, " + f"with normalization_min: {normalization_min} and step _size: {step_size}, " + f"in batches of max {max_batch_size} measurements." + ) + ) + + else: + if reconstruction_parameter is not None: + if np.array(reconstruction_parameter).shape == (3,): + print( + ( + first_line + + f"with the {reconstruction_method} algorithm, " + f"with normalization_min: {normalization_min} and (a,b,c): {reconstruction_parameter}." + ) + ) + else: + print( + ( + first_line + + f"with the {reconstruction_method} algorithm, " + f"with normalization_min: {normalization_min} and α: {reconstruction_parameter}." + ) + ) + else: + if step_size is not None: + print( + ( + first_line + + f"with the {reconstruction_method} algorithm, " + f"with normalization_min: {normalization_min}." + ) + ) + else: + print( + ( + first_line + + f"with the {reconstruction_method} algorithm, " + f"with normalization_min: {normalization_min} and step _size: {step_size}." + ) + ) + + # Batching + shuffled_indices = np.arange(self._num_diffraction_patterns) + unshuffled_indices = np.zeros_like(shuffled_indices) + + if max_batch_size is not None: + xp.random.seed(seed_random) + else: + max_batch_size = self._num_diffraction_patterns + + # initialization + if store_iterations and (not hasattr(self, "object_iterations") or reset): + self.object_iterations = [] + self.probe_iterations = [] + + if reset: + self.error_iterations = [] + self._object = self._object_initial.copy() + self._probe = self._probe_initial.copy() + self._positions_px = self._positions_px_initial.copy() + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + self._exit_waves = None + self._object_type = self._object_type_initial + if hasattr(self, "_tf"): + del self._tf + elif reset is None: + if hasattr(self, "error"): + warnings.warn( + ( + "Continuing reconstruction from previous result. " + "Use reset=True for a fresh start." + ), + UserWarning, + ) + else: + self.error_iterations = [] + self._exit_waves = None + + # main loop + for a0 in tqdmnd( + max_iter, + desc="Reconstructing object and probe", + unit=" iter", + disable=not progress_bar, + ): + error = 0.0 + + if a0 == switch_object_iter: + if self._object_type == "potential": + self._object_type = "complex" + self._object = xp.exp(1j * self._object) + elif self._object_type == "complex": + self._object_type = "potential" + self._object = xp.angle(self._object) + + # randomize + if not use_projection_scheme: + np.random.shuffle(shuffled_indices) + unshuffled_indices[shuffled_indices] = np.arange( + self._num_diffraction_patterns + ) + positions_px = self._positions_px.copy()[shuffled_indices] + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + self._positions_px = positions_px[start:end] + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + amplitudes = self._amplitudes[shuffled_indices[start:end]] + + # forward operator + ( + propagated_probes, + object_patches, + self._transmitted_probes, + self._exit_waves, + batch_error, + ) = self._forward( + self._object, + self._probe, + amplitudes, + self._exit_waves, + use_projection_scheme, + projection_a, + projection_b, + projection_c, + ) + + # adjoint operator + self._object, self._probe = self._adjoint( + self._object, + self._probe, + object_patches, + propagated_probes, + self._exit_waves, + use_projection_scheme=use_projection_scheme, + step_size=step_size, + normalization_min=normalization_min, + fix_probe=a0 < fix_probe_iter, + ) + + # position correction + if a0 >= fix_positions_iter: + positions_px[start:end] = self._position_correction( + self._object, + self._probe[0], + self._transmitted_probes[:, 0], + amplitudes, + self._positions_px, + positions_step_size, + constrain_position_distance, + ) + + error += batch_error + + # Normalize Error + error /= self._mean_diffraction_intensity * self._num_diffraction_patterns + + # constraints + self._positions_px = positions_px.copy()[unshuffled_indices] + self._object, self._probe, self._positions_px = self._constraints( + self._object, + self._probe, + self._positions_px, + fix_com=fix_com and a0 >= fix_probe_iter, + constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter + and a0 >= fix_probe_iter, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=a0 + < constrain_probe_fourier_amplitude_iter + and a0 >= fix_probe_iter, + constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=a0 < fit_probe_aberrations_iter + and a0 >= fix_probe_iter, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fix_probe_aperture=a0 < fix_probe_aperture_iter, + initial_probe_aperture=self._probe_initial_aperture, + fix_positions=a0 < fix_positions_iter, + global_affine_transformation=global_affine_transformation, + gaussian_filter=a0 < gaussian_filter_iter + and gaussian_filter_sigma is not None, + gaussian_filter_sigma=gaussian_filter_sigma, + butterworth_filter=a0 < butterworth_filter_iter + and (q_lowpass is not None or q_highpass is not None), + q_lowpass=q_lowpass, + q_highpass=q_highpass, + butterworth_order=butterworth_order, + kz_regularization_filter=a0 < kz_regularization_filter_iter + and kz_regularization_gamma is not None, + kz_regularization_gamma=kz_regularization_gamma[a0] + if kz_regularization_gamma is not None + and isinstance(kz_regularization_gamma, np.ndarray) + else kz_regularization_gamma, + identical_slices=a0 < identical_slices_iter, + object_positivity=object_positivity, + shrinkage_rad=shrinkage_rad, + object_mask=self._object_fov_mask_inverse + if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 + else None, + pure_phase_object=a0 < pure_phase_object_iter + and self._object_type == "complex", + tv_denoise_chambolle=a0 < tv_denoise_iter_chambolle + and tv_denoise_weight_chambolle is not None, + tv_denoise_weight_chambolle=tv_denoise_weight_chambolle, + tv_denoise_pad_chambolle=tv_denoise_pad_chambolle, + tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, + tv_denoise_inner_iter=tv_denoise_inner_iter, + orthogonalize_probe=orthogonalize_probe, + ) + + self.error_iterations.append(error.item()) + if store_iterations: + self.object_iterations.append(asnumpy(self._object.copy())) + self.probe_iterations.append(self.probe_centered) + + # store result + self.object = asnumpy(self._object) + self.probe = self.probe_centered + self.error = error.item() + + if self._device == "gpu": + xp._default_memory_pool.free_all_blocks() + xp.clear_memo() + + return self + + def _visualize_last_iteration_figax( + self, + fig, + object_ax, + convergence_ax, + cbar: bool, + padding: int = 0, + **kwargs, + ): + """ + Displays last reconstructed object on a given fig/ax. + + Parameters + -------- + fig: Figure + Matplotlib figure object_ax lives in + object_ax: Axes + Matplotlib axes to plot reconstructed object in + convergence_ax: Axes, optional + Matplotlib axes to plot convergence plot in + cbar: bool, optional + If true, displays a colorbar + padding : int, optional + Pixels to pad by post rotating-cropping object + """ + cmap = kwargs.pop("cmap", "magma") + + if self._object_type == "complex": + obj = np.angle(self.object) + else: + obj = self.object + + rotated_object = self._crop_rotate_object_fov( + np.sum(obj, axis=0), padding=padding + ) + rotated_shape = rotated_object.shape + + extent = [ + 0, + self.sampling[1] * rotated_shape[1], + self.sampling[0] * rotated_shape[0], + 0, + ] + + im = object_ax.imshow( + rotated_object, + extent=extent, + cmap=cmap, + **kwargs, + ) + + if cbar: + divider = make_axes_locatable(object_ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + if convergence_ax is not None and hasattr(self, "error_iterations"): + errors = np.array(self.error_iterations) + kwargs.pop("vmin", None) + kwargs.pop("vmax", None) + errors = self.error_iterations + + convergence_ax.semilogy(np.arange(len(errors)), errors, **kwargs) + + def _visualize_last_iteration( + self, + fig, + cbar: bool, + plot_convergence: bool, + plot_probe: bool, + plot_fourier_probe: bool, + padding: int, + **kwargs, + ): + """ + Displays last reconstructed object and probe iterations. + + Parameters + -------- + fig: Figure + Matplotlib figure to place Gridspec in + plot_convergence: bool, optional + If true, the normalized mean squared error (NMSE) plot is displayed + cbar: bool, optional + If true, displays a colorbar + plot_probe: bool + If true, the reconstructed probe intensity is also displayed + plot_fourier_probe: bool, optional + If true, the reconstructed complex Fourier probe is displayed + padding : int, optional + Pixels to pad by post rotating-cropping object + + """ + figsize = kwargs.pop("figsize", (8, 5)) + cmap = kwargs.pop("cmap", "magma") + + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) + + if self._object_type == "complex": + obj = np.angle(self.object) + else: + obj = self.object + + rotated_object = self._crop_rotate_object_fov( + np.sum(obj, axis=0), padding=padding + ) + rotated_shape = rotated_object.shape + + extent = [ + 0, + self.sampling[1] * rotated_shape[1], + self.sampling[0] * rotated_shape[0], + 0, + ] + + if plot_fourier_probe: + probe_extent = [ + -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + ] + elif plot_probe: + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + if plot_convergence: + if plot_probe or plot_fourier_probe: + spec = GridSpec( + ncols=2, + nrows=2, + height_ratios=[4, 1], + hspace=0.15, + width_ratios=[ + (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), + 1, + ], + wspace=0.35, + ) + else: + spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0.15) + else: + if plot_probe or plot_fourier_probe: + spec = GridSpec( + ncols=2, + nrows=1, + width_ratios=[ + (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), + 1, + ], + wspace=0.35, + ) + else: + spec = GridSpec(ncols=1, nrows=1) + + if fig is None: + fig = plt.figure(figsize=figsize) + + if plot_probe or plot_fourier_probe: + # Object + ax = fig.add_subplot(spec[0, 0]) + im = ax.imshow( + rotated_object, + extent=extent, + cmap=cmap, + **kwargs, + ) + + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + if self._object_type == "potential": + ax.set_title("Reconstructed object potential") + elif self._object_type == "complex": + ax.set_title("Reconstructed object phase") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + # Probe + kwargs.pop("vmin", None) + kwargs.pop("vmax", None) + + ax = fig.add_subplot(spec[0, 1]) + if plot_fourier_probe: + probe_array = Complex2RGB( + self.probe_fourier[0], chroma_boost=chroma_boost + ) + ax.set_title("Reconstructed Fourier probe[0]") + ax.set_ylabel("kx [mrad]") + ax.set_xlabel("ky [mrad]") + else: + probe_array = Complex2RGB( + self.probe[0], power=2, chroma_boost=chroma_boost + ) + ax.set_title("Reconstructed probe[0] intensity") + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + im = ax.imshow( + probe_array, + extent=probe_extent, + ) + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) + + else: + ax = fig.add_subplot(spec[0]) + im = ax.imshow( + rotated_object, + extent=extent, + cmap=cmap, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + if self._object_type == "potential": + ax.set_title("Reconstructed object potential") + elif self._object_type == "complex": + ax.set_title("Reconstructed object phase") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + if plot_convergence and hasattr(self, "error_iterations"): + kwargs.pop("vmin", None) + kwargs.pop("vmax", None) + errors = np.array(self.error_iterations) + if plot_probe: + ax = fig.add_subplot(spec[1, :]) + else: + ax = fig.add_subplot(spec[1]) + ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) + ax.set_ylabel("NMSE") + ax.set_xlabel("Iteration number") + ax.yaxis.tick_right() + + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") + spec.tight_layout(fig) + + def _visualize_all_iterations( + self, + fig, + cbar: bool, + plot_convergence: bool, + plot_probe: bool, + plot_fourier_probe: bool, + iterations_grid: Tuple[int, int], + padding: int, + **kwargs, + ): + """ + Displays all reconstructed object and probe iterations. + + Parameters + -------- + fig: Figure + Matplotlib figure to place Gridspec in + plot_convergence: bool, optional + If true, the normalized mean squared error (NMSE) plot is displayed + iterations_grid: Tuple[int,int] + Grid dimensions to plot reconstruction iterations + cbar: bool, optional + If true, displays a colorbar + plot_probe: bool + If true, the reconstructed probe intensity is also displayed + plot_fourier_probe: bool, optional + If true, the reconstructed complex Fourier probe is displayed + padding : int, optional + Pixels to pad by post rotating-cropping object + """ + asnumpy = self._asnumpy + + if not hasattr(self, "object_iterations"): + raise ValueError( + ( + "Object and probe iterations were not saved during reconstruction. " + "Please re-run using store_iterations=True." + ) + ) + + if iterations_grid == "auto": + num_iter = len(self.error_iterations) + + if num_iter == 1: + return self._visualize_last_iteration( + fig=fig, + plot_convergence=plot_convergence, + plot_probe=plot_probe, + plot_fourier_probe=plot_fourier_probe, + cbar=cbar, + padding=padding, + **kwargs, + ) + elif plot_probe or plot_fourier_probe: + iterations_grid = (2, 4) if num_iter > 4 else (2, num_iter) + else: + iterations_grid = (2, 4) if num_iter > 8 else (2, num_iter // 2) + else: + if (plot_probe or plot_fourier_probe) and iterations_grid[0] != 2: + raise ValueError() + + auto_figsize = ( + (3 * iterations_grid[1], 3 * iterations_grid[0] + 1) + if plot_convergence + else (3 * iterations_grid[1], 3 * iterations_grid[0]) + ) + figsize = kwargs.pop("figsize", auto_figsize) + cmap = kwargs.pop("cmap", "magma") + + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) + + errors = np.array(self.error_iterations) + + objects = [] + object_type = [] + + for obj in self.object_iterations: + if np.iscomplexobj(obj): + obj = np.angle(obj) + object_type.append("phase") + else: + object_type.append("potential") + objects.append( + self._crop_rotate_object_fov(np.sum(obj, axis=0), padding=padding) + ) + + if plot_probe or plot_fourier_probe: + total_grids = (np.prod(iterations_grid) / 2).astype("int") + probes = self.probe_iterations + else: + total_grids = np.prod(iterations_grid) + max_iter = len(objects) - 1 + grid_range = range(0, max_iter + 1, max_iter // (total_grids - 1)) + + extent = [ + 0, + self.sampling[1] * objects[0].shape[1], + self.sampling[0] * objects[0].shape[0], + 0, + ] + + if plot_fourier_probe: + probe_extent = [ + -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + ] + elif plot_probe: + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + if plot_convergence: + if plot_probe or plot_fourier_probe: + spec = GridSpec(ncols=1, nrows=3, height_ratios=[4, 4, 1], hspace=0) + else: + spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0) + else: + if plot_probe or plot_fourier_probe: + spec = GridSpec(ncols=1, nrows=2) + else: + spec = GridSpec(ncols=1, nrows=1) + + if fig is None: + fig = plt.figure(figsize=figsize) + + grid = ImageGrid( + fig, + spec[0], + nrows_ncols=(1, iterations_grid[1]) + if (plot_probe or plot_fourier_probe) + else iterations_grid, + axes_pad=(0.75, 0.5) if cbar else 0.5, + cbar_mode="each" if cbar else None, + cbar_pad="2.5%" if cbar else None, + ) + + for n, ax in enumerate(grid): + im = ax.imshow( + objects[grid_range[n]], + extent=extent, + cmap=cmap, + **kwargs, + ) + ax.set_title(f"Iter: {grid_range[n]} {object_type[grid_range[n]]}") + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + if cbar: + grid.cbar_axes[n].colorbar(im) + + if plot_probe or plot_fourier_probe: + kwargs.pop("vmin", None) + kwargs.pop("vmax", None) + grid = ImageGrid( + fig, + spec[1], + nrows_ncols=(1, iterations_grid[1]), + axes_pad=(0.75, 0.5) if cbar else 0.5, + cbar_mode="each" if cbar else None, + cbar_pad="2.5%" if cbar else None, + ) + + for n, ax in enumerate(grid): + if plot_fourier_probe: + probe_array = Complex2RGB( + asnumpy( + self._return_fourier_probe_from_centered_probe( + probes[grid_range[n]][0] + ) + ), + chroma_boost=chroma_boost, + ) + ax.set_title(f"Iter: {grid_range[n]} Fourier probe[0]") + ax.set_ylabel("kx [mrad]") + ax.set_xlabel("ky [mrad]") + else: + probe_array = Complex2RGB( + probes[grid_range[n]][0], + power=2, + chroma_boost=chroma_boost, + ) + ax.set_title(f"Iter: {grid_range[n]} probe[0]") + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + im = ax.imshow( + probe_array, + extent=probe_extent, + ) + + if cbar: + add_colorbar_arg( + grid.cbar_axes[n], + chroma_boost=chroma_boost, + ) + + if plot_convergence: + kwargs.pop("vmin", None) + kwargs.pop("vmax", None) + if plot_probe: + ax2 = fig.add_subplot(spec[2]) + else: + ax2 = fig.add_subplot(spec[1]) + ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) + ax2.set_ylabel("NMSE") + ax2.set_xlabel("Iteration number") + ax2.yaxis.tick_right() + + spec.tight_layout(fig) + + def visualize( + self, + fig=None, + iterations_grid: Tuple[int, int] = None, + plot_convergence: bool = True, + plot_probe: bool = True, + plot_fourier_probe: bool = False, + cbar: bool = True, + padding: int = 0, + **kwargs, + ): + """ + Displays reconstructed object and probe. + + Parameters + -------- + fig: Figure + Matplotlib figure to place Gridspec in + plot_convergence: bool, optional + If true, the normalized mean squared error (NMSE) plot is displayed + iterations_grid: Tuple[int,int] + Grid dimensions to plot reconstruction iterations + cbar: bool, optional + If true, displays a colorbar + plot_probe: bool + If true, the reconstructed probe intensity is also displayed + plot_fourier_probe: bool, optional + If true, the reconstructed complex Fourier probe is displayed + padding : int, optional + Pixels to pad by post rotating-cropping object + + Returns + -------- + self: PtychographicReconstruction + Self to accommodate chaining + """ + + if iterations_grid is None: + self._visualize_last_iteration( + fig=fig, + plot_convergence=plot_convergence, + plot_probe=plot_probe, + plot_fourier_probe=plot_fourier_probe, + cbar=cbar, + padding=padding, + **kwargs, + ) + else: + self._visualize_all_iterations( + fig=fig, + plot_convergence=plot_convergence, + iterations_grid=iterations_grid, + plot_probe=plot_probe, + plot_fourier_probe=plot_fourier_probe, + cbar=cbar, + padding=padding, + **kwargs, + ) + return self + + def show_fourier_probe( + self, probe=None, scalebar=True, pixelsize=None, pixelunits=None, **kwargs + ): + """ + Plot probe in fourier space + + Parameters + ---------- + probe: complex array, optional + if None is specified, uses the `probe_fourier` property + scalebar: bool, optional + if True, adds scalebar to probe + pixelunits: str, optional + units for scalebar, default is A^-1 + pixelsize: float, optional + default is probe reciprocal sampling + """ + asnumpy = self._asnumpy + + if probe is None: + probe = list(self.probe_fourier) + else: + if isinstance(probe, np.ndarray) and probe.ndim == 2: + probe = [probe] + probe = [asnumpy(self._return_fourier_probe(pr)) for pr in probe] + + if pixelsize is None: + pixelsize = self._reciprocal_sampling[1] + if pixelunits is None: + pixelunits = r"$\AA^{-1}$" + + chroma_boost = kwargs.pop("chroma_boost", 2) + + show_complex( + probe if len(probe) > 1 else probe[0], + scalebar=scalebar, + pixelsize=pixelsize, + pixelunits=pixelunits, + ticks=False, + chroma_boost=chroma_boost, + **kwargs, + ) + + def show_transmitted_probe( + self, + plot_fourier_probe: bool = False, + **kwargs, + ): + """ + Plots the min, max, and mean transmitted probe after propagation and transmission. + + Parameters + ---------- + plot_fourier_probe: boolean, optional + If True, the transmitted probes are also plotted in Fourier space + kwargs: + Passed to show_complex + """ + + xp = self._xp + asnumpy = self._asnumpy + + transmitted_probe_intensities = xp.sum( + xp.abs(self._transmitted_probes[:, 0]) ** 2, axis=(-2, -1) + ) + min_intensity_transmitted = self._transmitted_probes[ + xp.argmin(transmitted_probe_intensities), 0 + ] + max_intensity_transmitted = self._transmitted_probes[ + xp.argmax(transmitted_probe_intensities), 0 + ] + mean_transmitted = self._transmitted_probes[:, 0].mean(0) + probes = [ + asnumpy(self._return_centered_probe(probe)) + for probe in [ + mean_transmitted, + min_intensity_transmitted, + max_intensity_transmitted, + ] + ] + title = [ + "Mean Transmitted Probe", + "Min Intensity Transmitted Probe", + "Max Intensity Transmitted Probe", + ] + + if plot_fourier_probe: + bottom_row = [ + asnumpy(self._return_fourier_probe(probe)) + for probe in [ + mean_transmitted, + min_intensity_transmitted, + max_intensity_transmitted, + ] + ] + probes = [probes, bottom_row] + + title += [ + "Mean Transmitted Fourier Probe", + "Min Intensity Transmitted Fourier Probe", + "Max Intensity Transmitted Fourier Probe", + ] + + title = kwargs.get("title", title) + show_complex( + probes, + title=title, + **kwargs, + ) + + def show_slices( + self, + ms_object=None, + cbar: bool = True, + common_color_scale: bool = True, + padding: int = 0, + num_cols: int = 3, + **kwargs, + ): + """ + Displays reconstructed slices of object + + Parameters + -------- + ms_object: nd.array, optional + Object to plot slices of. If None, uses current object + cbar: bool, optional + If True, displays a colorbar + padding: int, optional + Padding to leave uncropped + num_cols: int, optional + Number of GridSpec columns + """ + + if ms_object is None: + ms_object = self._object + + rotated_object = self._crop_rotate_object_fov(ms_object, padding=padding) + rotated_shape = rotated_object.shape + + if np.iscomplexobj(rotated_object): + rotated_object = np.angle(rotated_object) + + extent = [ + 0, + self.sampling[1] * rotated_shape[2], + self.sampling[0] * rotated_shape[1], + 0, + ] + + num_rows = np.ceil(self._num_slices / num_cols).astype("int") + wspace = 0.35 if cbar else 0.15 + + axsize = kwargs.pop("axsize", (3, 3)) + cmap = kwargs.pop("cmap", "magma") + vmin = np.min(rotated_object) if common_color_scale else None + vmax = np.max(rotated_object) if common_color_scale else None + vmin = kwargs.pop("vmin", vmin) + vmax = kwargs.pop("vmax", vmax) + + spec = GridSpec( + ncols=num_cols, + nrows=num_rows, + hspace=0.15, + wspace=wspace, + ) + + figsize = (axsize[0] * num_cols, axsize[1] * num_rows) + fig = plt.figure(figsize=figsize) + + for flat_index, obj_slice in enumerate(rotated_object): + row_index, col_index = np.unravel_index(flat_index, (num_rows, num_cols)) + ax = fig.add_subplot(spec[row_index, col_index]) + im = ax.imshow( + obj_slice, + cmap=cmap, + vmin=vmin, + vmax=vmax, + extent=extent, + **kwargs, + ) + + ax.set_title(f"Slice index: {flat_index}") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + if row_index < num_rows - 1: + ax.set_xticks([]) + else: + ax.set_xlabel("y [A]") + + if col_index > 0: + ax.set_yticks([]) + else: + ax.set_ylabel("x [A]") + + spec.tight_layout(fig) + + def show_depth( + self, + x1: float, + x2: float, + y1: float, + y2: float, + specify_calibrated: bool = False, + gaussian_filter_sigma: float = None, + ms_object=None, + cbar: bool = False, + aspect: float = None, + plot_line_profile: bool = False, + **kwargs, + ): + """ + Displays line profile depth section + + Parameters + -------- + x1, x2, y1, y2: floats (pixels) + Line profile for depth section runs from (x1,y1) to (x2,y2) + Specified in pixels unless specify_calibrated is True + specify_calibrated: bool (optional) + If True, specify x1, x2, y1, y2 in A values instead of pixels + gaussian_filter_sigma: float (optional) + Standard deviation of gaussian kernel in A + ms_object: np.array + Object to plot slices of. If None, uses current object + cbar: bool, optional + If True, displays a colorbar + aspect: float, optional + aspect ratio for depth profile plot + plot_line_profile: bool + If True, also plots line profile showing where depth profile is taken + """ + if ms_object is not None: + ms_obj = ms_object + else: + ms_obj = self.object_cropped + + if specify_calibrated: + x1 /= self.sampling[0] + x2 /= self.sampling[0] + y1 /= self.sampling[1] + y2 /= self.sampling[1] + + if x2 == x1: + angle = 0 + elif y2 == y1: + angle = np.pi / 2 + else: + angle = np.arctan((x2 - x1) / (y2 - y1)) + + x0 = ms_obj.shape[1] / 2 + y0 = ms_obj.shape[2] / 2 + + if ( + x1 > ms_obj.shape[1] + or x2 > ms_obj.shape[1] + or y1 > ms_obj.shape[2] + or y2 > ms_obj.shape[2] + ): + raise ValueError("depth section must be in field of view of object") + + from py4DSTEM.process.phase.utils import rotate_point + + x1_0, y1_0 = rotate_point((x0, y0), (x1, y1), angle) + x2_0, y2_0 = rotate_point((x0, y0), (x2, y2), angle) + + rotated_object = np.roll( + rotate(ms_obj, np.rad2deg(angle), reshape=False, axes=(-1, -2)), + int(x1_0), + axis=1, + ) + + if np.iscomplexobj(rotated_object): + rotated_object = np.angle(rotated_object) + if gaussian_filter_sigma is not None: + from scipy.ndimage import gaussian_filter + + gaussian_filter_sigma /= self.sampling[0] + rotated_object = gaussian_filter(rotated_object, gaussian_filter_sigma) + + plot_im = rotated_object[:, 0, int(y1_0) : int(y2_0)] + + extent = [ + 0, + self.sampling[1] * plot_im.shape[1], + self._slice_thicknesses[0] * plot_im.shape[0], + 0, + ] + + figsize = kwargs.pop("figsize", (6, 6)) + if not plot_line_profile: + fig, ax = plt.subplots(figsize=figsize) + im = ax.imshow(plot_im, cmap="magma", extent=extent) + if aspect is not None: + ax.set_aspect(aspect) + ax.set_xlabel("r [A]") + ax.set_ylabel("z [A]") + ax.set_title("Multislice depth profile") + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + else: + extent2 = [ + 0, + self.sampling[1] * ms_obj.shape[2], + self.sampling[0] * ms_obj.shape[1], + 0, + ] + fig, ax = plt.subplots(2, 1, figsize=figsize) + ax[0].imshow(ms_obj.sum(0), cmap="gray", extent=extent2) + ax[0].plot( + [y1 * self.sampling[0], y2 * self.sampling[1]], + [x1 * self.sampling[0], x2 * self.sampling[1]], + color="red", + ) + ax[0].set_xlabel("y [A]") + ax[0].set_ylabel("x [A]") + ax[0].set_title("Multislice depth profile location") + + im = ax[1].imshow(plot_im, cmap="magma", extent=extent) + if aspect is not None: + ax[1].set_aspect(aspect) + ax[1].set_xlabel("r [A]") + ax[1].set_ylabel("z [A]") + ax[1].set_title("Multislice depth profile") + if cbar: + divider = make_axes_locatable(ax[1]) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + plt.tight_layout() + + def tune_num_slices_and_thicknesses( + self, + num_slices_guess=None, + thicknesses_guess=None, + num_slices_step_size=1, + thicknesses_step_size=20, + num_slices_values=3, + num_thicknesses_values=3, + update_defocus=False, + max_iter=5, + plot_reconstructions=True, + plot_convergence=True, + return_values=False, + **kwargs, + ): + """ + Run reconstructions over a parameters space of number of slices + and slice thicknesses. Should be run after the preprocess step. + + Parameters + ---------- + num_slices_guess: float, optional + initial starting guess for number of slices, rounds to nearest integer + if None, uses current initialized values + thicknesses_guess: float (A), optional + initial starting guess for thicknesses of slices assuming same + thickness for each slice + if None, uses current initialized values + num_slices_step_size: float, optional + size of change of number of slices for each step in parameter space + thicknesses_step_size: float (A), optional + size of change of slice thicknesses for each step in parameter space + num_slices_values: int, optional + number of number of slice values to test, must be >= 1 + num_thicknesses_values: int,optional + number of thicknesses values to test, must be >= 1 + update_defocus: bool, optional + if True, updates defocus based on estimated total thickness + max_iter: int, optional + number of iterations to run in ptychographic reconstruction + plot_reconstructions: bool, optional + if True, plot phase of reconstructed objects + plot_convergence: bool, optional + if True, plots error for each iteration for each reconstruction + return_values: bool, optional + if True, returns objects, convergence + + Returns + ------- + objects: list + reconstructed objects + convergence: np.ndarray + array of convergence values from reconstructions + """ + + # calculate number of slices and thicknesses values to test + if num_slices_guess is None: + num_slices_guess = self._num_slices + if thicknesses_guess is None: + thicknesses_guess = np.mean(self._slice_thicknesses) + + if num_slices_values == 1: + num_slices_step_size = 0 + + if num_thicknesses_values == 1: + thicknesses_step_size = 0 + + num_slices = np.linspace( + num_slices_guess - num_slices_step_size * (num_slices_values - 1) / 2, + num_slices_guess + num_slices_step_size * (num_slices_values - 1) / 2, + num_slices_values, + ) + + thicknesses = np.linspace( + thicknesses_guess + - thicknesses_step_size * (num_thicknesses_values - 1) / 2, + thicknesses_guess + + thicknesses_step_size * (num_thicknesses_values - 1) / 2, + num_thicknesses_values, + ) + + if return_values: + convergence = [] + objects = [] + + # current initialized values + current_verbose = self._verbose + current_num_slices = self._num_slices + current_thicknesses = self._slice_thicknesses + current_rotation_deg = self._rotation_best_rad * 180 / np.pi + current_transpose = self._rotation_best_transpose + current_defocus = -self._polar_parameters["C10"] + + # Gridspec to plot on + if plot_reconstructions: + if plot_convergence: + spec = GridSpec( + ncols=num_thicknesses_values, + nrows=num_slices_values * 2, + height_ratios=[1, 1 / 4] * num_slices_values, + hspace=0.15, + wspace=0.35, + ) + figsize = kwargs.get( + "figsize", (4 * num_thicknesses_values, 5 * num_slices_values) + ) + else: + spec = GridSpec( + ncols=num_thicknesses_values, + nrows=num_slices_values, + hspace=0.15, + wspace=0.35, + ) + figsize = kwargs.get( + "figsize", (4 * num_thicknesses_values, 4 * num_slices_values) + ) + + fig = plt.figure(figsize=figsize) + + progress_bar = kwargs.pop("progress_bar", False) + # run loop and plot along the way + self._verbose = False + for flat_index, (slices, thickness) in enumerate( + tqdmnd(num_slices, thicknesses, desc="Tuning angle and defocus") + ): + slices = int(slices) + self._num_slices = slices + self._slice_thicknesses = np.tile(thickness, slices - 1) + self._probe = None + self._object = None + if update_defocus: + defocus = current_defocus + slices / 2 * thickness + self._polar_parameters["C10"] = -defocus + + self.preprocess( + plot_center_of_mass=False, + plot_rotation=False, + plot_probe_overlaps=False, + force_com_rotation=current_rotation_deg, + force_com_transpose=current_transpose, + ) + self.reconstruct( + reset=True, + store_iterations=True if plot_convergence else False, + max_iter=max_iter, + progress_bar=progress_bar, + **kwargs, + ) + + if plot_reconstructions: + row_index, col_index = np.unravel_index( + flat_index, (num_slices_values, num_thicknesses_values) + ) + + if plot_convergence: + object_ax = fig.add_subplot(spec[row_index * 2, col_index]) + convergence_ax = fig.add_subplot(spec[row_index * 2 + 1, col_index]) + self._visualize_last_iteration_figax( + fig, + object_ax=object_ax, + convergence_ax=convergence_ax, + cbar=True, + ) + convergence_ax.yaxis.tick_right() + else: + object_ax = fig.add_subplot(spec[row_index, col_index]) + self._visualize_last_iteration_figax( + fig, + object_ax=object_ax, + convergence_ax=None, + cbar=True, + ) + + object_ax.set_title( + f" num slices = {slices:.0f}, slices thickness = {thickness:.1f} A \n error = {self.error:.3e}" + ) + object_ax.set_xticks([]) + object_ax.set_yticks([]) + + if return_values: + objects.append(self.object) + convergence.append(self.error_iterations.copy()) + + # initialize back to pre-tuning values + self._probe = None + self._object = None + self._num_slices = current_num_slices + self._slice_thicknesses = np.tile(current_thicknesses, current_num_slices - 1) + self._polar_parameters["C10"] = -current_defocus + self.preprocess( + force_com_rotation=current_rotation_deg, + force_com_transpose=current_transpose, + plot_center_of_mass=False, + plot_rotation=False, + plot_probe_overlaps=False, + ) + self._verbose = current_verbose + + if plot_reconstructions: + spec.tight_layout(fig) + + if return_values: + return objects, convergence + + def _return_object_fft( + self, + obj=None, + ): + """ + Returns obj fft shifted to center of array + + Parameters + ---------- + obj: array, optional + if None is specified, uses self._object + """ + asnumpy = self._asnumpy + + if obj is None: + obj = self._object + + obj = asnumpy(obj) + if np.iscomplexobj(obj): + obj = np.angle(obj) + + obj = self._crop_rotate_object_fov(np.sum(obj, axis=0)) + return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj)))) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index 56fec1004..2e9fbd076 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -14,8 +14,8 @@ try: import cupy as cp -except ImportError: - cp = None +except ModuleNotFoundError: + cp = np from emdfile import Custom, tqdmnd from py4DSTEM import DataCube @@ -204,6 +204,7 @@ def preprocess( force_angular_sampling: float = None, force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, **kwargs, ): """ @@ -261,6 +262,8 @@ def preprocess( object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering Returns -------- @@ -346,9 +349,7 @@ def preprocess( self._amplitudes, self._mean_diffraction_intensity, ) = self._normalize_diffraction_intensities( - self._intensities, - self._com_fitted_x, - self._com_fitted_y, + self._intensities, self._com_fitted_x, self._com_fitted_y, crop_patterns ) # explicitly delete namespace @@ -429,6 +430,10 @@ def preprocess( bilinear=True, device=self._device, ) + if crop_patterns: + self._vacuum_probe_intensity = self._vacuum_probe_intensity[ + self._crop_mask + ].reshape(self._region_of_interest_shape) _probe = ( ComplexProbe( @@ -505,19 +510,13 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (4.5 * self._num_probes + 4, 4)) - cmap = kwargs.pop("cmap", "Greys_r") - vmin = kwargs.pop("vmin", None) - vmax = kwargs.pop("vmax", None) - hue_start = kwargs.pop("hue_start", 0) - invert = kwargs.pop("invert", False) + chroma_boost = kwargs.pop("chroma_boost", 1) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, + chroma_boost=chroma_boost, ) extent = [ @@ -540,23 +539,19 @@ def preprocess( axs[i].imshow( complex_probe_rgb[i], extent=probe_extent, - **kwargs, ) axs[i].set_ylabel("x [A]") axs[i].set_xlabel("y [A]") - axs[i].set_title(f"Initial Probe[{i}]") + axs[i].set_title(f"Initial probe[{i}] intensity") divider = make_axes_locatable(axs[i]) cax = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg( - cax, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert - ) + add_colorbar_arg(cax, chroma_boost=chroma_boost) axs[-1].imshow( asnumpy(probe_overlap), extent=extent, - cmap=cmap, - **kwargs, + cmap="Greys_r", ) axs[-1].scatter( self.positions[:, 1], @@ -568,7 +563,7 @@ def preprocess( axs[-1].set_xlabel("y [A]") axs[-1].set_xlim((extent[0], extent[1])) axs[-1].set_ylim((extent[2], extent[3])) - axs[-1].set_title("Object Field of View") + axs[-1].set_title("Object field of view") fig.tight_layout() @@ -1125,6 +1120,9 @@ def _constraints( q_lowpass, q_highpass, butterworth_order, + tv_denoise, + tv_denoise_weight, + tv_denoise_inner_iter, orthogonalize_probe, object_positivity, shrinkage_rad, @@ -1183,6 +1181,12 @@ def _constraints( Butterworth filter order. Smaller gives a smoother filter orthogonalize_probe: bool If True, probe will be orthogonalized + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weight: float + Denoising weight. The greater `weight`, the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising object_positivity: bool If True, clips negative potential values shrinkage_rad: float @@ -1213,6 +1217,11 @@ def _constraints( butterworth_order, ) + if tv_denoise: + current_object = self._object_denoise_tv_pylops( + current_object, tv_denoise_weight, tv_denoise_inner_iter + ) + if shrinkage_rad > 0.0 or object_mask is not None: current_object = self._object_shrinkage_constraint( current_object, @@ -1281,6 +1290,7 @@ def reconstruct( constrain_probe_fourier_amplitude_constant_intensity: bool = False, fix_positions_iter: int = np.inf, global_affine_transformation: bool = True, + constrain_position_distance: float = None, gaussian_filter_sigma: float = None, gaussian_filter_iter: int = np.inf, fit_probe_aberrations_iter: int = 0, @@ -1290,6 +1300,9 @@ def reconstruct( q_lowpass: float = None, q_highpass: float = None, butterworth_order: float = 2, + tv_denoise_iter: int = np.inf, + tv_denoise_weight: float = None, + tv_denoise_inner_iter: float = 40, object_positivity: bool = True, shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, @@ -1353,6 +1366,8 @@ def reconstruct( If True, the probe aperture is additionally constrained to a constant intensity. fix_positions_iter: int, optional Number of iterations to run with fixed positions before updating positions estimate + constrain_position_distance: float + Distance to constrain position correction within original field of view in A global_affine_transformation: bool, optional If True, positions are assumed to be a global affine transform from initial scan gaussian_filter_sigma: float, optional @@ -1373,6 +1388,12 @@ def reconstruct( Cut-off frequency in A^-1 for high-pass butterworth filter butterworth_order: float Butterworth filter order. Smaller gives a smoother filter + tv_denoise_iter: int, optional + Number of iterations to run using tv denoise filter on object + tv_denoise_weight: float + Denoising weight. The greater `weight`, the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising object_positivity: bool, optional If True, forces object to be positive shrinkage_rad: float @@ -1575,6 +1596,8 @@ def reconstruct( ) = self._extract_vectorized_patch_indices() self._exit_waves = None self._object_type = self._object_type_initial + if hasattr(self, "_tf"): + del self._tf elif reset is None: if hasattr(self, "error"): warnings.warn( @@ -1667,6 +1690,7 @@ def reconstruct( amplitudes, self._positions_px, positions_step_size, + constrain_position_distance, ) error += batch_error @@ -1707,6 +1731,9 @@ def reconstruct( q_highpass=q_highpass, butterworth_order=butterworth_order, orthogonalize_probe=orthogonalize_probe, + tv_denoise=a0 < tv_denoise_iter and tv_denoise_weight is not None, + tv_denoise_weight=tv_denoise_weight, + tv_denoise_inner_iter=tv_denoise_inner_iter, object_positivity=object_positivity, shrinkage_rad=shrinkage_rad, object_mask=self._object_fov_mask_inverse @@ -1821,8 +1848,11 @@ def _visualize_last_iteration( """ figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) + + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) if self._object_type == "complex": obj = np.angle(self.object) @@ -1915,29 +1945,31 @@ def _visualize_last_iteration( if plot_fourier_probe: probe_array = Complex2RGB( - self.probe_fourier[0], hue_start=hue_start, invert=invert + self.probe_fourier[0], + chroma_boost=chroma_boost, ) ax.set_title("Reconstructed Fourier probe[0]") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - self.probe[0], hue_start=hue_start, invert=invert + self.probe[0], + power=2, + chroma_boost=chroma_boost, ) - ax.set_title("Reconstructed probe[0]") + ax.set_title("Reconstructed probe[0] intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, hue_start=hue_start, invert=invert) + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) else: ax = fig.add_subplot(spec[0]) @@ -1970,10 +2002,10 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[1]) ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration Number") + ax.set_xlabel("Iteration number") ax.yaxis.tick_right() - fig.suptitle(f"Normalized Mean Squared Error: {self.error:.3e}") + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") spec.tight_layout(fig) def _visualize_all_iterations( @@ -2045,8 +2077,11 @@ def _visualize_all_iterations( ) figsize = kwargs.pop("figsize", auto_figsize) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) + + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) errors = np.array(self.error_iterations) @@ -2149,29 +2184,30 @@ def _visualize_all_iterations( probes[grid_range[n]][0] ) ), - hue_start=hue_start, - invert=invert, + chroma_boost=chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe[0]") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - probes[grid_range[n]][0], hue_start=hue_start, invert=invert + probes[grid_range[n]][0], + power=2, + chroma_boost=chroma_boost, ) - ax.set_title(f"Iter: {grid_range[n]} probe[0]") + ax.set_title(f"Iter: {grid_range[n]} probe[0] intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: add_colorbar_arg( - grid.cbar_axes[n], hue_start=hue_start, invert=invert + grid.cbar_axes[n], + chroma_boost=chroma_boost, ) if plot_convergence: @@ -2183,7 +2219,7 @@ def _visualize_all_iterations( ax2 = fig.add_subplot(spec[1]) ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax2.set_ylabel("NMSE") - ax2.set_xlabel("Iteration Number") + ax2.set_xlabel("Iteration number") ax2.yaxis.tick_right() spec.tight_layout(fig) @@ -2280,11 +2316,14 @@ def show_fourier_probe( if pixelunits is None: pixelunits = r"$\AA^{-1}$" + chroma_boost = kwargs.pop("chroma_boost", 2) + show_complex( probe if len(probe) > 1 else probe[0], scalebar=scalebar, pixelsize=pixelsize, pixelunits=pixelunits, ticks=False, + chroma_boost=chroma_boost, **kwargs, ) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index a352502d0..4515590fe 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -8,14 +8,15 @@ import matplotlib.pyplot as plt import numpy as np +import pylops from matplotlib.gridspec import GridSpec from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg, show_complex try: import cupy as cp -except ImportError: - cp = None +except ModuleNotFoundError: + cp = np from emdfile import Custom, tqdmnd from py4DSTEM import DataCube @@ -29,6 +30,7 @@ spatial_frequencies, ) from py4DSTEM.process.utils import electron_wavelength_angstrom, get_CoM, get_shifted_ar +from scipy.ndimage import rotate warnings.simplefilter(action="always", category=UserWarning) @@ -78,6 +80,10 @@ class MultislicePtychographicReconstruction(PtychographicReconstruction): initial_scan_positions: np.ndarray, optional Probe positions in Å for each diffraction intensity If None, initialized to a grid scan + theta_x: float + x tilt of propagator (in angles) + theta_y: float + y tilt of propagator (in angles) object_type: str, optional The object can be reconstructed as a real potential ('potential') or a complex object ('complex') @@ -109,6 +115,8 @@ def __init__( initial_object_guess: np.ndarray = None, initial_probe_guess: np.ndarray = None, initial_scan_positions: np.ndarray = None, + theta_x: float = 0, + theta_y: float = 0, object_type: str = "complex", verbose: bool = True, device: str = "cpu", @@ -189,6 +197,8 @@ def __init__( # Class-specific Metadata self._num_slices = num_slices self._slice_thicknesses = slice_thicknesses + self._theta_x = theta_x + self._theta_y = theta_y def _precompute_propagator_arrays( self, @@ -196,6 +206,8 @@ def _precompute_propagator_arrays( sampling: Tuple[float, float], energy: float, slice_thicknesses: Sequence[float], + theta_x: float, + theta_y: float, ): """ Precomputes propagator arrays complex wave-function will be convolved by, @@ -211,6 +223,10 @@ def _precompute_propagator_arrays( The electron energy of the wave functions in eV slice_thicknesses: Sequence[float] Array of slice thicknesses in A + theta_x: float + x tilt of propagator (in angles) + theta_y: float + y tilt of propagator (in angles) Returns ------- @@ -230,6 +246,10 @@ def _precompute_propagator_arrays( propagators = xp.empty( (num_slices, kx.shape[0], ky.shape[0]), dtype=xp.complex64 ) + + theta_x = np.deg2rad(theta_x) + theta_y = np.deg2rad(theta_y) + for i, dz in enumerate(slice_thicknesses): propagators[i] = xp.exp( 1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz) @@ -237,6 +257,12 @@ def _precompute_propagator_arrays( propagators[i] *= xp.exp( 1.0j * (-(ky**2)[None] * np.pi * wavelength * dz) ) + propagators[i] *= xp.exp( + 1.0j * (2 * kx[:, None] * np.pi * dz * np.tan(theta_x)) + ) + propagators[i] *= xp.exp( + 1.0j * (2 * ky[None] * np.pi * dz * np.tan(theta_y)) + ) return propagators @@ -279,6 +305,7 @@ def preprocess( force_angular_sampling: float = None, force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, **kwargs, ): """ @@ -336,6 +363,8 @@ def preprocess( object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering Returns -------- @@ -421,9 +450,7 @@ def preprocess( self._amplitudes, self._mean_diffraction_intensity, ) = self._normalize_diffraction_intensities( - self._intensities, - self._com_fitted_x, - self._com_fitted_y, + self._intensities, self._com_fitted_x, self._com_fitted_y, crop_patterns ) # explicitly delete namespace @@ -503,6 +530,10 @@ def preprocess( bilinear=True, device=self._device, ) + if crop_patterns: + self._vacuum_probe_intensity = self._vacuum_probe_intensity[ + self._crop_mask + ].reshape(self._region_of_interest_shape) self._probe = ( ComplexProbe( @@ -559,6 +590,8 @@ def preprocess( self.sampling, self._energy, self._slice_thicknesses, + self._theta_x, + self._theta_y, ) # overlaps @@ -575,19 +608,13 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (13, 4)) - cmap = kwargs.pop("cmap", "Greys_r") - vmin = kwargs.pop("vmin", None) - vmax = kwargs.pop("vmax", None) - hue_start = kwargs.pop("hue_start", 0) - invert = kwargs.pop("invert", False) + chroma_boost = kwargs.pop("chroma_boost", 1) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, + chroma_boost=chroma_boost, ) # propagated @@ -599,10 +626,8 @@ def preprocess( ) complex_propagated_rgb = Complex2RGB( asnumpy(self._return_centered_probe(propagated_probe)), - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, + chroma_boost=chroma_boost, ) extent = [ @@ -624,38 +649,34 @@ def preprocess( ax1.imshow( complex_probe_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax1) cax1 = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg( - cax1, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert - ) + add_colorbar_arg(cax1, chroma_boost=chroma_boost) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") - ax1.set_title("Initial Probe") + ax1.set_title("Initial probe intensity") ax2.imshow( complex_propagated_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax2) cax2 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( - cax2, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert + cax2, + chroma_boost=chroma_boost, ) ax2.set_ylabel("x [A]") ax2.set_xlabel("y [A]") - ax2.set_title("Propagated Probe") + ax2.set_title("Propagated probe intensity") ax3.imshow( asnumpy(probe_overlap), extent=extent, - cmap=cmap, - **kwargs, + cmap="Greys_r", ) ax3.scatter( self.positions[:, 1], @@ -667,7 +688,7 @@ def preprocess( ax3.set_xlabel("y [A]") ax3.set_xlim((extent[0], extent[1])) ax3.set_ylim((extent[2], extent[3])) - ax3.set_title("Object Field of View") + ax3.set_title("Object field of view") fig.tight_layout() @@ -1449,6 +1470,111 @@ def _object_identical_slices_constraint(self, current_object): return current_object + def _object_denoise_tv_pylops(self, current_object, weights, iterations): + """ + Performs second order TV denoising along x and y + + Parameters + ---------- + current_object: np.ndarray + Current object estimate + weights : [float, float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + iterations: float + Number of iterations to run in denoising algorithm. + `niter_out` in pylops + + Returns + ------- + constrained_object: np.ndarray + Constrained object estimate + + """ + xp = self._xp + + if xp.iscomplexobj(current_object): + current_object_tv = current_object + warnings.warn( + ("TV denoising is currently only supported for potential objects."), + UserWarning, + ) + + else: + # zero pad at top and bottom slice + pad_width = ((1, 1), (0, 0), (0, 0)) + current_object = xp.pad( + current_object, pad_width=pad_width, mode="constant" + ) + + # run tv denoising + nz, nx, ny = current_object.shape + niter_out = iterations + niter_in = 1 + Iop = pylops.Identity(nx * ny * nz) + + if weights[0] == 0: + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[1]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + elif weights[1] == 0: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + l1_regs = [z_gradient] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[0]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + else: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [z_gradient, xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=weights, + tol=1e-4, + tau=1.0, + show=False, + )[0] + + # remove padding + current_object_tv = current_object_tv.reshape(current_object.shape)[1:-1] + + return current_object_tv + def _constraints( self, current_object, @@ -1481,9 +1607,12 @@ def _constraints( shrinkage_rad, object_mask, pure_phase_object, + tv_denoise_chambolle, + tv_denoise_weight_chambolle, + tv_denoise_pad_chambolle, tv_denoise, - tv_denoise_weight, - tv_denoise_pad, + tv_denoise_weights, + tv_denoise_inner_iter, ): """ Ptychographic constraints operator. @@ -1548,12 +1677,19 @@ def _constraints( If not None, used to calculate additional shrinkage using masked-mean of object pure_phase_object: bool If True, object amplitude is set to unity - tv_denoise: bool + tv_denoise_chambolle: bool If True, performs TV denoising along z - tv_denoise_weight: float + tv_denoise_weight_chambolle: float weight of tv denoising constraint - tv_denoise_pad: bool + tv_denoise_pad_chambolle: bool if True, pads object at top and bottom with zeros before applying denoising + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising Returns -------- @@ -1585,13 +1721,17 @@ def _constraints( current_object, kz_regularization_gamma ) elif tv_denoise: - if self._object_type == "complex": - raise NotImplementedError() + current_object = self._object_denoise_tv_pylops( + current_object, + tv_denoise_weights, + tv_denoise_inner_iter, + ) + elif tv_denoise_chambolle: current_object = self._object_denoise_tv_chambolle( current_object, - tv_denoise_weight, + tv_denoise_weight_chambolle, axis=0, - pad_object=tv_denoise_pad, + pad_object=tv_denoise_pad_chambolle, ) if shrinkage_rad > 0.0 or object_mask is not None: @@ -1690,9 +1830,12 @@ def reconstruct( shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, pure_phase_object_iter: int = 0, + tv_denoise_iter_chambolle=np.inf, + tv_denoise_weight_chambolle=None, + tv_denoise_pad_chambolle=True, tv_denoise_iter=np.inf, - tv_denoise_weight=None, - tv_denoise_pad=True, + tv_denoise_weights=None, + tv_denoise_inner_iter=40, switch_object_iter: int = np.inf, store_iterations: bool = False, progress_bar: bool = True, @@ -1751,6 +1894,8 @@ def reconstruct( If True, the probe aperture is additionally constrained to a constant intensity. fix_positions_iter: int, optional Number of iterations to run with fixed positions before updating positions estimate + constrain_position_distance: float + Distance to constrain position correction within original field of view in A global_affine_transformation: bool, optional If True, positions are assumed to be a global affine transform from initial scan gaussian_filter_sigma: float, optional @@ -1785,12 +1930,19 @@ def reconstruct( If true, the potential mean outside the FOV is forced to zero at each iteration pure_phase_object_iter: int, optional Number of iterations where object amplitude is set to unity - tv_denoise_iter: bool + tv_denoise_iter_chambolle: bool Number of iterations with TV denoisining - tv_denoise_weight: float + tv_denoise_weight_chambolle: float weight of tv denoising constraint - tv_denoise_pad: bool + tv_denoise_pad_chambolle: bool if True, pads object at top and bottom with zeros before applying denoising + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising switch_object_iter: int, optional Iteration to switch object type between 'complex' and 'potential' or between 'potential' and 'complex' @@ -1987,6 +2139,8 @@ def reconstruct( ) = self._extract_vectorized_patch_indices() self._exit_waves = None self._object_type = self._object_type_initial + if hasattr(self, "_tf"): + del self._tf elif reset is None: if hasattr(self, "error"): warnings.warn( @@ -2123,7 +2277,7 @@ def reconstruct( and kz_regularization_gamma is not None, kz_regularization_gamma=kz_regularization_gamma[a0] if kz_regularization_gamma is not None - and type(kz_regularization_gamma) == np.ndarray + and isinstance(kz_regularization_gamma, np.ndarray) else kz_regularization_gamma, identical_slices=a0 < identical_slices_iter, object_positivity=object_positivity, @@ -2133,9 +2287,13 @@ def reconstruct( else None, pure_phase_object=a0 < pure_phase_object_iter and self._object_type == "complex", - tv_denoise=a0 < tv_denoise_iter and tv_denoise_weight is not None, - tv_denoise_weight=tv_denoise_weight, - tv_denoise_pad=tv_denoise_pad, + tv_denoise_chambolle=a0 < tv_denoise_iter_chambolle + and tv_denoise_weight_chambolle is not None, + tv_denoise_weight_chambolle=tv_denoise_weight_chambolle, + tv_denoise_pad_chambolle=tv_denoise_pad_chambolle, + tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, + tv_denoise_inner_iter=tv_denoise_inner_iter, ) self.error_iterations.append(error.item()) @@ -2250,8 +2408,11 @@ def _visualize_last_iteration( """ figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) + + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) if self._object_type == "complex": obj = np.angle(self.object) @@ -2347,29 +2508,29 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[0, 1]) if plot_fourier_probe: probe_array = Complex2RGB( - self.probe_fourier, hue_start=hue_start, invert=invert + self.probe_fourier, + chroma_boost=chroma_boost, ) ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - self.probe, hue_start=hue_start, invert=invert + self.probe, power=2, chroma_boost=chroma_boost ) - ax.set_title("Reconstructed probe") + ax.set_title("Reconstructed probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, hue_start=hue_start, invert=invert) + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) else: ax = fig.add_subplot(spec[0]) @@ -2402,10 +2563,10 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[1]) ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration Number") + ax.set_xlabel("Iteration number") ax.yaxis.tick_right() - fig.suptitle(f"Normalized Mean Squared Error: {self.error:.3e}") + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") spec.tight_layout(fig) def _visualize_all_iterations( @@ -2477,8 +2638,11 @@ def _visualize_all_iterations( ) figsize = kwargs.pop("figsize", auto_figsize) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) + + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) errors = np.array(self.error_iterations) @@ -2583,29 +2747,28 @@ def _visualize_all_iterations( probes[grid_range[n]] ) ), - hue_start=hue_start, - invert=invert, + chroma_boost=chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - probes[grid_range[n]], hue_start=hue_start, invert=invert + probes[grid_range[n]], power=2, chroma_boost=chroma_boost ) - ax.set_title(f"Iter: {grid_range[n]} probe") + ax.set_title(f"Iter: {grid_range[n]} probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: add_colorbar_arg( - grid.cbar_axes[n], hue_start=hue_start, invert=invert + grid.cbar_axes[n], + chroma_boost=chroma_boost, ) if plot_convergence: @@ -2617,7 +2780,7 @@ def _visualize_all_iterations( ax2 = fig.add_subplot(spec[1]) ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax2.set_ylabel("NMSE") - ax2.set_xlabel("Iteration Number") + ax2.set_xlabel("Iteration number") ax2.yaxis.tick_right() spec.tight_layout(fig) @@ -2841,6 +3004,143 @@ def show_slices( spec.tight_layout(fig) + def show_depth( + self, + x1: float, + x2: float, + y1: float, + y2: float, + specify_calibrated: bool = False, + gaussian_filter_sigma: float = None, + ms_object=None, + cbar: bool = False, + aspect: float = None, + plot_line_profile: bool = False, + **kwargs, + ): + """ + Displays line profile depth section + + Parameters + -------- + x1, x2, y1, y2: floats (pixels) + Line profile for depth section runs from (x1,y1) to (x2,y2) + Specified in pixels unless specify_calibrated is True + specify_calibrated: bool (optional) + If True, specify x1, x2, y1, y2 in A values instead of pixels + gaussian_filter_sigma: float (optional) + Standard deviation of gaussian kernel in A + ms_object: np.array + Object to plot slices of. If None, uses current object + cbar: bool, optional + If True, displays a colorbar + aspect: float, optional + aspect ratio for depth profile plot + plot_line_profile: bool + If True, also plots line profile showing where depth profile is taken + """ + if ms_object is not None: + ms_obj = ms_object + else: + ms_obj = self.object_cropped + + if specify_calibrated: + x1 /= self.sampling[0] + x2 /= self.sampling[0] + y1 /= self.sampling[1] + y2 /= self.sampling[1] + + if x2 == x1: + angle = 0 + elif y2 == y1: + angle = np.pi / 2 + else: + angle = np.arctan((x2 - x1) / (y2 - y1)) + + x0 = ms_obj.shape[1] / 2 + y0 = ms_obj.shape[2] / 2 + + if ( + x1 > ms_obj.shape[1] + or x2 > ms_obj.shape[1] + or y1 > ms_obj.shape[2] + or y2 > ms_obj.shape[2] + ): + raise ValueError("depth section must be in field of view of object") + + from py4DSTEM.process.phase.utils import rotate_point + + x1_0, y1_0 = rotate_point((x0, y0), (x1, y1), angle) + x2_0, y2_0 = rotate_point((x0, y0), (x2, y2), angle) + + rotated_object = np.roll( + rotate(ms_obj, np.rad2deg(angle), reshape=False, axes=(-1, -2)), + int(x1_0), + axis=1, + ) + + if np.iscomplexobj(rotated_object): + rotated_object = np.angle(rotated_object) + if gaussian_filter_sigma is not None: + from scipy.ndimage import gaussian_filter + + gaussian_filter_sigma /= self.sampling[0] + rotated_object = gaussian_filter(rotated_object, gaussian_filter_sigma) + + plot_im = rotated_object[:, 0, int(y1_0) : int(y2_0)] + + extent = [ + 0, + self.sampling[1] * plot_im.shape[1], + self._slice_thicknesses[0] * plot_im.shape[0], + 0, + ] + figsize = kwargs.pop("figsize", (6, 6)) + if not plot_line_profile: + fig, ax = plt.subplots(figsize=figsize) + im = ax.imshow(plot_im, cmap="magma", extent=extent) + if aspect is not None: + ax.set_aspect(aspect) + ax.set_xlabel("r [A]") + ax.set_ylabel("z [A]") + ax.set_title("Multislice depth profile") + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + else: + extent2 = [ + 0, + self.sampling[1] * ms_obj.shape[2], + self.sampling[0] * ms_obj.shape[1], + 0, + ] + + fig, ax = plt.subplots(2, 1, figsize=figsize) + ax[0].imshow(ms_obj.sum(0), cmap="gray", extent=extent2) + ax[0].plot( + [y1 * self.sampling[0], y2 * self.sampling[1]], + [x1 * self.sampling[0], x2 * self.sampling[1]], + color="red", + ) + ax[0].set_xlabel("y [A]") + ax[0].set_ylabel("x [A]") + ax[0].set_title("Multislice depth profile location") + + im = ax[1].imshow(plot_im, cmap="magma", extent=extent) + if aspect is not None: + ax[1].set_aspect(aspect) + ax[1].set_xlabel("r [A]") + ax[1].set_ylabel("z [A]") + ax[1].set_title("Multislice depth profile") + if cbar: + divider = make_axes_locatable(ax[1]) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + plt.tight_layout() + def tune_num_slices_and_thicknesses( self, num_slices_guess=None, diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py index 8691a121d..32b0f6fd4 100644 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py @@ -8,6 +8,7 @@ import matplotlib.pyplot as plt import numpy as np +import pylops from matplotlib.gridspec import GridSpec from mpl_toolkits.axes_grid1 import make_axes_locatable from py4DSTEM.visualize import show @@ -16,8 +17,8 @@ try: import cupy as cp -except ImportError: - cp = None +except ModuleNotFoundError: + cp = np from emdfile import Custom, tqdmnd from py4DSTEM import DataCube @@ -430,6 +431,7 @@ def preprocess( force_angular_sampling: float = None, force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, **kwargs, ): """ @@ -474,6 +476,8 @@ def preprocess( object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering Returns -------- @@ -591,9 +595,7 @@ def preprocess( ], mean_diffraction_intensity_temp, ) = self._normalize_diffraction_intensities( - intensities, - com_fitted_x, - com_fitted_y, + intensities, com_fitted_x, com_fitted_y, crop_patterns ) self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) @@ -684,6 +686,10 @@ def preprocess( bilinear=True, device=self._device, ) + if crop_patterns: + self._vacuum_probe_intensity = self._vacuum_probe_intensity[ + self._crop_mask + ].reshape(self._region_of_interest_shape) self._probe = ( ComplexProbe( @@ -807,19 +813,13 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (13, 4)) - cmap = kwargs.pop("cmap", "Greys_r") - vmin = kwargs.pop("vmin", None) - vmax = kwargs.pop("vmax", None) - hue_start = kwargs.pop("hue_start", 0) - invert = kwargs.pop("invert", False) + chroma_boost = kwargs.pop("chroma_boost", 1) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, + chroma_boost=chroma_boost, ) # propagated @@ -831,10 +831,8 @@ def preprocess( ) complex_propagated_rgb = Complex2RGB( asnumpy(self._return_centered_probe(propagated_probe)), - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, + chroma_boost=chroma_boost, ) extent = [ @@ -856,38 +854,37 @@ def preprocess( ax1.imshow( complex_probe_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax1) cax1 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( - cax1, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert + cax1, + chroma_boost=chroma_boost, ) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") - ax1.set_title("Initial Probe") + ax1.set_title("Initial probe intensity") ax2.imshow( complex_propagated_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax2) cax2 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( - cax2, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert + cax2, + chroma_boost=chroma_boost, ) ax2.set_ylabel("x [A]") ax2.set_xlabel("y [A]") - ax2.set_title("Propagated Probe") + ax2.set_title("Propagated probe intensity") ax3.imshow( asnumpy(probe_overlap), extent=extent, - cmap=cmap, - **kwargs, + cmap="Greys_r", ) ax3.scatter( self.positions[0, :, 1], @@ -899,7 +896,7 @@ def preprocess( ax3.set_xlabel("y [A]") ax3.set_xlim((extent[0], extent[1])) ax3.set_ylim((extent[2], extent[3])) - ax3.set_title("Object Field of View") + ax3.set_title("Object field of view") fig.tight_layout() @@ -1679,6 +1676,111 @@ def _divergence_free_constraint(self, vector_field): return vector_field + def _object_denoise_tv_pylops(self, current_object, weights, iterations): + """ + Performs second order TV denoising along x and y + + Parameters + ---------- + current_object: np.ndarray + Current object estimate + weights : [float, float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + iterations: float + Number of iterations to run in denoising algorithm. + `niter_out` in pylops + + Returns + ------- + constrained_object: np.ndarray + Constrained object estimate + + """ + xp = self._xp + + if xp.iscomplexobj(current_object): + current_object_tv = current_object + warnings.warn( + ("TV denoising is currently only supported for potential objects."), + UserWarning, + ) + + else: + # zero pad at top and bottom slice + pad_width = ((1, 1), (0, 0), (0, 0)) + current_object = xp.pad( + current_object, pad_width=pad_width, mode="constant" + ) + + # run tv denoising + nz, nx, ny = current_object.shape + niter_out = iterations + niter_in = 1 + Iop = pylops.Identity(nx * ny * nz) + + if weights[0] == 0: + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[1]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + elif weights[1] == 0: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + l1_regs = [z_gradient] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[0]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + else: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [z_gradient, xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=weights, + tol=1e-4, + tau=1.0, + show=False, + )[0] + + # remove padding + current_object_tv = current_object_tv.reshape(current_object.shape)[1:-1] + + return current_object_tv + def _constraints( self, current_object, @@ -1710,6 +1812,9 @@ def _constraints( object_positivity, shrinkage_rad, object_mask, + tv_denoise, + tv_denoise_weights, + tv_denoise_inner_iter, ): """ Ptychographic constraints operator. @@ -1771,6 +1876,15 @@ def _constraints( If True, forces object to be positive shrinkage_rad: float Phase shift in radians to be subtracted from the potential at each iteration + object_mask: np.ndarray (boolean) + If not None, used to calculate additional shrinkage using masked-mean of object + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising Returns -------- @@ -1822,6 +1936,31 @@ def _constraints( butterworth_order, ) + elif tv_denoise: + current_object[0] = self._object_denoise_tv_pylops( + current_object[0], + tv_denoise_weights, + tv_denoise_inner_iter, + ) + + current_object[1] = self._object_denoise_tv_pylops( + current_object[1], + tv_denoise_weights, + tv_denoise_inner_iter, + ) + + current_object[2] = self._object_denoise_tv_pylops( + current_object[2], + tv_denoise_weights, + tv_denoise_inner_iter, + ) + + current_object[3] = self._object_denoise_tv_pylops( + current_object[3], + tv_denoise_weights, + tv_denoise_inner_iter, + ) + if shrinkage_rad > 0.0 or object_mask is not None: current_object[0] = self._object_shrinkage_constraint( current_object[0], @@ -1913,6 +2052,9 @@ def reconstruct( object_positivity: bool = True, shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, + tv_denoise_iter=np.inf, + tv_denoise_weights=None, + tv_denoise_inner_iter=40, collective_tilt_updates: bool = False, store_iterations: bool = False, progress_bar: bool = True, @@ -1998,6 +2140,15 @@ def reconstruct( Butterworth filter order. Smaller gives a smoother filter object_positivity: bool, optional If True, forces object to be positive + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising + collective_tilt_updates: bool + if True perform collective tilt updates shrinkage_rad: float Phase shift in radians to be subtracted from the potential at each iteration store_iterations: bool, optional @@ -2171,12 +2322,13 @@ def reconstruct( self.error_iterations = [] self._probe = self._probe_initial.copy() self._positions_px_all = self._positions_px_initial_all.copy() + if hasattr(self, "_tf"): + del self._tf if use_projection_scheme: self._exit_waves = [None] * self._num_tilts else: self._exit_waves = None - elif reset is None: if hasattr(self, "error"): warnings.warn( @@ -2477,6 +2629,10 @@ def reconstruct( if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 else None, + tv_denoise=a0 < tv_denoise_iter + and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, + tv_denoise_inner_iter=tv_denoise_inner_iter, ) # Normalize Error Over Tilts @@ -2530,6 +2686,9 @@ def reconstruct( if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 else None, + tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, + tv_denoise_inner_iter=tv_denoise_inner_iter, ) self.error_iterations.append(error.item()) @@ -2929,7 +3088,7 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[-1, :]) ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration Number") + ax.set_xlabel("Iteration number") ax.yaxis.tick_right() spec.tight_layout(fig) @@ -3133,7 +3292,7 @@ def show_object_fft( vmax = kwargs.pop("vmax", 1) power = kwargs.pop("power", 0.2) - pixelsize = 1 / (object_fft.shape[0] * self.sampling[0]) + pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) show( object_fft, figsize=figsize, diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index d6bee12fd..66cf46487 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -8,6 +8,7 @@ import matplotlib.pyplot as plt import numpy as np +import pylops from matplotlib.gridspec import GridSpec from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable from py4DSTEM.visualize import show @@ -16,8 +17,8 @@ try: import cupy as cp -except ImportError: - cp = None +except ModuleNotFoundError: + cp = np from emdfile import Custom, tqdmnd from py4DSTEM import DataCube @@ -55,8 +56,8 @@ class OverlapTomographicReconstruction(PtychographicReconstruction): The electron energy of the wave functions in eV num_slices: int Number of slices to use in the forward model - tilt_angles_deg: Sequence[float] - List of tilt angles in degrees, + tilt_orientation_matrices: Sequence[np.ndarray] + List of orientation matrices for each tilt semiangle_cutoff: float, optional Semiangle cutoff for the initial probe guess in mrad semiangle_cutoff_pixels: float, optional @@ -94,13 +95,14 @@ class OverlapTomographicReconstruction(PtychographicReconstruction): """ # Class-specific Metadata - _class_specific_metadata = ("_num_slices", "_tilt_angles_deg") + _class_specific_metadata = ("_num_slices", "_tilt_orientation_matrices") + _swap_zxy_to_xyz = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]]) def __init__( self, energy: float, num_slices: int, - tilt_angles_deg: Sequence[float], + tilt_orientation_matrices: Sequence[np.ndarray], datacube: Sequence[DataCube] = None, semiangle_cutoff: float = None, semiangle_cutoff_pixels: float = None, @@ -122,22 +124,29 @@ def __init__( if device == "cpu": self._xp = np self._asnumpy = np.asarray - from scipy.ndimage import gaussian_filter, rotate, zoom + from scipy.ndimage import affine_transform, gaussian_filter, rotate, zoom self._gaussian_filter = gaussian_filter self._zoom = zoom self._rotate = rotate + self._affine_transform = affine_transform from scipy.special import erf self._erf = erf elif device == "gpu": self._xp = cp self._asnumpy = cp.asnumpy - from cupyx.scipy.ndimage import gaussian_filter, rotate, zoom + from cupyx.scipy.ndimage import ( + affine_transform, + gaussian_filter, + rotate, + zoom, + ) self._gaussian_filter = gaussian_filter self._zoom = zoom self._rotate = rotate + self._affine_transform = affine_transform from cupyx.scipy.special import erf self._erf = erf @@ -156,7 +165,7 @@ def __init__( polar_parameters.update(kwargs) self._set_polar_parameters(polar_parameters) - num_tilts = len(tilt_angles_deg) + num_tilts = len(tilt_orientation_matrices) if initial_scan_positions is None: initial_scan_positions = [None] * num_tilts @@ -185,7 +194,7 @@ def __init__( # Class-specific Metadata self._num_slices = num_slices - self._tilt_angles_deg = tuple(tilt_angles_deg) + self._tilt_orientation_matrices = tuple(tilt_orientation_matrices) self._num_tilts = num_tilts def _precompute_propagator_arrays( @@ -323,6 +332,29 @@ def _expand_sliced_object(self, array: np.ndarray, output_z): normalized_array = array / xp.asarray(voxels_in_slice)[:, None, None] return xp.repeat(normalized_array, voxels_per_slice, axis=0)[:output_z] + def _rotate_zxy_volume( + self, + volume_array, + rot_matrix, + ): + """ """ + + xp = self._xp + affine_transform = self._affine_transform + swap_zxy_to_xyz = self._swap_zxy_to_xyz + + volume = volume_array.copy() + volume_shape = xp.asarray(volume.shape) + tf = xp.asarray(swap_zxy_to_xyz.T @ rot_matrix.T @ swap_zxy_to_xyz) + + in_center = (volume_shape - 1) / 2 + out_center = tf @ in_center + offset = in_center - out_center + + volume = affine_transform(volume, tf, offset=offset, order=3) + + return volume + def preprocess( self, diffraction_intensities_shape: Tuple[int, int] = None, @@ -340,6 +372,7 @@ def preprocess( force_reciprocal_sampling: float = None, progress_bar: bool = True, object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, **kwargs, ): """ @@ -384,6 +417,8 @@ def preprocess( object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering Returns -------- @@ -500,9 +535,7 @@ def preprocess( ], mean_diffraction_intensity_temp, ) = self._normalize_diffraction_intensities( - intensities, - com_fitted_x, - com_fitted_y, + intensities, com_fitted_x, com_fitted_y, crop_patterns ) self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) @@ -593,6 +626,10 @@ def preprocess( bilinear=True, device=self._device, ) + if crop_patterns: + self._vacuum_probe_intensity = self._vacuum_probe_intensity[ + self._crop_mask + ].reshape(self._region_of_interest_shape) self._probe = ( ComplexProbe( @@ -663,15 +700,14 @@ def preprocess( # overlaps if object_fov_mask is None: probe_overlap_3D = xp.zeros_like(self._object) + old_rot_matrix = np.eye(3) # identity for tilt_index in np.arange(self._num_tilts): - current_angle_deg = self._tilt_angles_deg[tilt_index] - probe_overlap_3D = self._rotate( + rot_matrix = self._tilt_orientation_matrices[tilt_index] + + probe_overlap_3D = self._rotate_zxy_volume( probe_overlap_3D, - current_angle_deg, - axes=(0, 2), - reshape=False, - order=2, + rot_matrix @ old_rot_matrix.T, ) self._positions_px = self._positions_px_all[ @@ -691,14 +727,12 @@ def preprocess( ) probe_overlap_3D += probe_overlap[None] + old_rot_matrix = rot_matrix - probe_overlap_3D = self._rotate( - probe_overlap_3D, - -current_angle_deg, - axes=(0, 2), - reshape=False, - order=2, - ) + probe_overlap_3D = self._rotate_zxy_volume( + probe_overlap_3D, + old_rot_matrix.T, + ) probe_overlap_3D = self._gaussian_filter(probe_overlap_3D, 1.0) self._object_fov_mask = asnumpy( @@ -719,19 +753,13 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (13, 4)) - cmap = kwargs.pop("cmap", "Greys_r") - vmin = kwargs.pop("vmin", None) - vmax = kwargs.pop("vmax", None) - hue_start = kwargs.pop("hue_start", 0) - invert = kwargs.pop("invert", False) + chroma_boost = kwargs.pop("chroma_boost", 1) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, + chroma_boost=chroma_boost, ) # propagated @@ -743,10 +771,8 @@ def preprocess( ) complex_propagated_rgb = Complex2RGB( asnumpy(self._return_centered_probe(propagated_probe)), - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, + chroma_boost=chroma_boost, ) extent = [ @@ -768,38 +794,37 @@ def preprocess( ax1.imshow( complex_probe_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax1) cax1 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( - cax1, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert + cax1, + chroma_boost=chroma_boost, ) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") - ax1.set_title("Initial Probe") + ax1.set_title("Initial probe intensity") ax2.imshow( complex_propagated_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax2) cax2 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( - cax2, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert + cax2, + chroma_boost=chroma_boost, ) ax2.set_ylabel("x [A]") ax2.set_xlabel("y [A]") - ax2.set_title("Propagated Probe") + ax2.set_title("Propagated probe intensity") ax3.imshow( asnumpy(probe_overlap), extent=extent, - cmap=cmap, - **kwargs, + cmap="Greys_r", ) ax3.scatter( self.positions[0, :, 1], @@ -811,7 +836,7 @@ def preprocess( ax3.set_xlabel("y [A]") ax3.set_xlim((extent[0], extent[1])) ax3.set_ylim((extent[2], extent[3])) - ax3.set_title("Object Field of View") + ax3.set_title("Object field of view") fig.tight_layout() @@ -1527,6 +1552,111 @@ def _object_butterworth_constraint( current_object += current_object_mean return xp.real(current_object) + def _object_denoise_tv_pylops(self, current_object, weights, iterations): + """ + Performs second order TV denoising along x and y + + Parameters + ---------- + current_object: np.ndarray + Current object estimate + weights : [float, float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + iterations: float + Number of iterations to run in denoising algorithm. + `niter_out` in pylops + + Returns + ------- + constrained_object: np.ndarray + Constrained object estimate + + """ + xp = self._xp + + if xp.iscomplexobj(current_object): + current_object_tv = current_object + warnings.warn( + ("TV denoising is currently only supported for potential objects."), + UserWarning, + ) + + else: + # zero pad at top and bottom slice + pad_width = ((1, 1), (0, 0), (0, 0)) + current_object = xp.pad( + current_object, pad_width=pad_width, mode="constant" + ) + + # run tv denoising + nz, nx, ny = current_object.shape + niter_out = iterations + niter_in = 1 + Iop = pylops.Identity(nx * ny * nz) + + if weights[0] == 0: + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[1]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + elif weights[1] == 0: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + l1_regs = [z_gradient] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[0]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + else: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [z_gradient, xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=weights, + tol=1e-4, + tau=1.0, + show=False, + )[0] + + # remove padding + current_object_tv = current_object_tv.reshape(current_object.shape)[1:-1] + + return current_object_tv + def _constraints( self, current_object, @@ -1555,6 +1685,9 @@ def _constraints( object_positivity, shrinkage_rad, object_mask, + tv_denoise, + tv_denoise_weights, + tv_denoise_inner_iter, ): """ Ptychographic constraints operator. @@ -1611,6 +1744,13 @@ def _constraints( Phase shift in radians to be subtracted from the potential at each iteration object_mask: np.ndarray (boolean) If not None, used to calculate additional shrinkage using masked-mean of object + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising Returns -------- @@ -1634,6 +1774,12 @@ def _constraints( q_highpass, butterworth_order, ) + if tv_denoise: + current_object = self._object_denoise_tv_pylops( + current_object, + tv_denoise_weights, + tv_denoise_inner_iter, + ) if shrinkage_rad > 0.0 or object_mask is not None: current_object = self._object_shrinkage_constraint( @@ -1723,6 +1869,9 @@ def reconstruct( object_positivity: bool = True, shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, + tv_denoise_iter=np.inf, + tv_denoise_weights=None, + tv_denoise_inner_iter=40, collective_tilt_updates: bool = False, store_iterations: bool = False, progress_bar: bool = True, @@ -1806,6 +1955,15 @@ def reconstruct( Butterworth filter order. Smaller gives a smoother filter object_positivity: bool, optional If True, forces object to be positive + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising + collective_tilt_updates: bool + if True perform collective tilt updates shrinkage_rad: float Phase shift in radians to be subtracted from the potential at each iteration store_iterations: bool, optional @@ -1981,12 +2139,13 @@ def reconstruct( self.error_iterations = [] self._probe = self._probe_initial.copy() self._positions_px_all = self._positions_px_initial_all.copy() + if hasattr(self, "_tf"): + del self._tf if use_projection_scheme: self._exit_waves = [None] * self._num_tilts else: self._exit_waves = None - elif reset is None: if hasattr(self, "error"): warnings.warn( @@ -2018,17 +2177,17 @@ def reconstruct( tilt_indices = np.arange(self._num_tilts) np.random.shuffle(tilt_indices) + old_rot_matrix = np.eye(3) # identity + for tilt_index in tilt_indices: self._active_tilt_index = tilt_index tilt_error = 0.0 - self._object = self._rotate( + rot_matrix = self._tilt_orientation_matrices[self._active_tilt_index] + self._object = self._rotate_zxy_volume( self._object, - self._tilt_angles_deg[self._active_tilt_index], - axes=(0, 2), - reshape=False, - order=3, + rot_matrix @ old_rot_matrix.T, ) object_sliced = self._project_sliced_object( @@ -2132,23 +2291,13 @@ def reconstruct( ) if collective_tilt_updates: - collective_object += self._rotate( - object_update, - -self._tilt_angles_deg[self._active_tilt_index], - axes=(0, 2), - reshape=False, - order=3, + collective_object += self._rotate_zxy_volume( + object_update, rot_matrix.T ) else: self._object += object_update - self._object = self._rotate( - self._object, - -self._tilt_angles_deg[self._active_tilt_index], - axes=(0, 2), - reshape=False, - order=3, - ) + old_rot_matrix = rot_matrix # Normalize Error tilt_error /= ( @@ -2203,8 +2352,14 @@ def reconstruct( if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 else None, + tv_denoise=a0 < tv_denoise_iter + and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, + tv_denoise_inner_iter=tv_denoise_inner_iter, ) + self._object = self._rotate_zxy_volume(self._object, old_rot_matrix.T) + # Normalize Error Over Tilts error /= self._num_tilts @@ -2251,6 +2406,9 @@ def reconstruct( if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 else None, + tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, + tv_denoise_inner_iter=tv_denoise_inner_iter, ) self.error_iterations.append(error.item()) @@ -2431,8 +2589,11 @@ def _visualize_last_iteration( """ figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) + + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) asnumpy = self._asnumpy @@ -2534,16 +2695,19 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[0, 1]) if plot_fourier_probe: probe_array = Complex2RGB( - self.probe_fourier, hue_start=hue_start, invert=invert + self.probe_fourier, + chroma_boost=chroma_boost, ) ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - self.probe, hue_start=hue_start, invert=invert + self.probe, + power=2, + chroma_boost=chroma_boost, ) - ax.set_title("Reconstructed probe") + ax.set_title("Reconstructed probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") @@ -2556,7 +2720,10 @@ def _visualize_last_iteration( if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, hue_start=hue_start, invert=invert) + add_colorbar_arg( + ax_cb, + chroma_boost=chroma_boost, + ) else: ax = fig.add_subplot(spec[0]) im = ax.imshow( @@ -2585,10 +2752,10 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[1]) ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration Number") + ax.set_xlabel("Iteration number") ax.yaxis.tick_right() - fig.suptitle(f"Normalized Mean Squared Error: {self.error:.3e}") + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") spec.tight_layout(fig) def _visualize_all_iterations( @@ -2672,8 +2839,11 @@ def _visualize_all_iterations( ) figsize = kwargs.pop("figsize", auto_figsize) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) + + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) errors = np.array(self.error_iterations) @@ -2788,29 +2958,30 @@ def _visualize_all_iterations( probes[grid_range[n]] ) ), - hue_start=hue_start, - invert=invert, + chroma_boost=chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - probes[grid_range[n]], hue_start=hue_start, invert=invert + probes[grid_range[n]], + power=2, + chroma_boost=chroma_boost, ) - ax.set_title(f"Iter: {grid_range[n]} probe") + ax.set_title(f"Iter: {grid_range[n]} probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: add_colorbar_arg( - grid.cbar_axes[n], hue_start=hue_start, invert=invert + grid.cbar_axes[n], + chroma_boost=chroma_boost, ) if plot_convergence: @@ -2822,7 +2993,7 @@ def _visualize_all_iterations( ax2 = fig.add_subplot(spec[1]) ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax2.set_ylabel("NMSE") - ax2.set_xlabel("Iteration Number") + ax2.set_xlabel("Iteration number") ax2.yaxis.tick_right() spec.tight_layout(fig) @@ -3001,7 +3172,7 @@ def show_object_fft( vmax = kwargs.pop("vmax", 1) power = kwargs.pop("power", 0.2) - pixelsize = 1 / (object_fft.shape[0] * self.sampling[0]) + pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) show( object_fft, figsize=figsize, diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 80cdd8cd8..74688fa0b 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -8,22 +8,44 @@ import matplotlib.pyplot as plt import numpy as np -from emdfile import Custom, tqdmnd +from emdfile import Array, Custom, Metadata, _read_metadata, tqdmnd from matplotlib.gridspec import GridSpec -from py4DSTEM import DataCube +from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable +from py4DSTEM import Calibration, DataCube +from py4DSTEM.preprocess.utils import get_shifted_ar from py4DSTEM.process.phase.iterative_base_class import PhaseReconstruction +from py4DSTEM.process.phase.utils import AffineTransform from py4DSTEM.process.utils.cross_correlate import align_images_fourier from py4DSTEM.process.utils.utils import electron_wavelength_angstrom +from py4DSTEM.visualize import show from scipy.linalg import polar +from scipy.optimize import minimize from scipy.special import comb try: import cupy as cp -except ImportError: - cp = None +except ModuleNotFoundError: + cp = np warnings.simplefilter(action="always", category=UserWarning) +_aberration_names = { + (1, 0): "C1 ", + (1, 2): "stig ", + (2, 1): "coma ", + (2, 3): "trefoil ", + (3, 0): "C3 ", + (3, 2): "stig2 ", + (3, 4): "quadfoil ", + (4, 1): "coma2 ", + (4, 3): "trefoil2 ", + (4, 5): "pentafoil ", + (5, 0): "C5 ", + (5, 2): "stig3 ", + (5, 4): "quadfoil2 ", + (5, 6): "hexafoil ", +} + class ParallaxReconstruction(PhaseReconstruction): """ @@ -35,9 +57,6 @@ class ParallaxReconstruction(PhaseReconstruction): Input 4D diffraction pattern intensities energy: float The electron energy of the wave functions in eV - dp_mean: ndarray, optional - Mean diffraction pattern - If None, get_dp_mean() is used verbose: bool, optional If True, class methods will inherit this and print additional information device: str, optional @@ -73,6 +92,8 @@ def __init__( else: raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") + self.set_save_defaults() + # Data self._datacube = datacube @@ -86,9 +107,78 @@ def __init__( def to_h5(self, group): """ Wraps datasets and metadata to write in emdfile classes, - notably ... + notably the (subpixel-)aligned BF. """ - raise NotImplementedError() + # instantiation metadata + self.metadata = Metadata( + name="instantiation_metadata", + data={ + "energy": self._energy, + "verbose": self._verbose, + "device": self._device, + "object_padding_px": self._object_padding_px, + "name": self.name, + }, + ) + + # preprocessing metadata + self.metadata = Metadata( + name="preprocess_metadata", + data={ + "scan_sampling": self._scan_sampling, + "wavelength": self._wavelength, + }, + ) + + # reconstruction metadata + recon_metadata = {"reconstruction_error": float(self._recon_error)} + + if hasattr(self, "aberration_C1"): + recon_metadata |= { + "aberration_rotation_QR": self.rotation_Q_to_R_rads, + "aberration_transpose": self.transpose_detected, + "aberration_C1": self.aberration_C1, + "aberration_A1x": self.aberration_A1x, + "aberration_A1y": self.aberration_A1y, + } + + if hasattr(self, "_kde_upsample_factor"): + recon_metadata |= { + "kde_upsample_factor": self._kde_upsample_factor, + } + self._subpixel_aligned_BF_emd = Array( + name="subpixel_aligned_BF", + data=self._asnumpy(self._recon_BF_subpixel_aligned), + ) + + if hasattr(self, "aberration_dict"): + self.metadata = Metadata( + name="aberrations_metadata", + data={ + v["aberration name"]: v["value [Ang]"] + for k, v in self.aberration_dict.items() + }, + ) + + self.metadata = Metadata( + name="reconstruction_metadata", + data=recon_metadata, + ) + + self._aligned_BF_emd = Array( + name="aligned_BF", + data=self._asnumpy(self._recon_BF), + ) + + # datacube + if self._save_datacube: + self.metadata = self._datacube.calibration + Custom.to_h5(self, group) + else: + dc = self._datacube + self._datacube = None + Custom.to_h5(self, group) + self._datacube = dc @classmethod def _get_constructor_args(cls, group): @@ -96,14 +186,68 @@ def _get_constructor_args(cls, group): Returns a dictionary of arguments/values to pass to the class' __init__ function """ - raise NotImplementedError() + # Get data + dict_data = cls._get_emd_attr_data(cls, group) + + # Get metadata dictionaries + instance_md = _read_metadata(group, "instantiation_metadata") + + # Fix calibrations bug + if "_datacube" in dict_data: + calibrations_dict = _read_metadata(group, "calibration")._params + cal = Calibration() + cal._params.update(calibrations_dict) + dc = dict_data["_datacube"] + dc.calibration = cal + else: + dc = None + + # Populate args and return + kwargs = { + "datacube": dc, + "energy": instance_md["energy"], + "object_padding_px": instance_md["object_padding_px"], + "name": instance_md["name"], + "verbose": True, # for compatibility + "device": "cpu", # for compatibility + } + + return kwargs def _populate_instance(self, group): """ Sets post-initialization properties, notably some preprocessing meta optional; during read, this method is run after object instantiation. """ - raise NotImplementedError() + + xp = self._xp + + # Preprocess metadata + preprocess_md = _read_metadata(group, "preprocess_metadata") + self._scan_sampling = preprocess_md["scan_sampling"] + self._wavelength = preprocess_md["wavelength"] + + # Reconstruction metadata + reconstruction_md = _read_metadata(group, "reconstruction_metadata") + self._recon_error = reconstruction_md["reconstruction_error"] + + # Data + dict_data = Custom._get_emd_attr_data(Custom, group) + + if "aberration_C1" in reconstruction_md.keys: + self.rotation_Q_to_R_rads = reconstruction_md["aberration_rotation_QR"] + self.transpose_detected = reconstruction_md["aberration_transpose"] + self.aberration_C1 = reconstruction_md["aberration_C1"] + self.aberration_A1x = reconstruction_md["aberration_A1x"] + self.aberration_A1y = reconstruction_md["aberration_A1y"] + + if "kde_upsample_factor" in reconstruction_md.keys: + self._kde_upsample_factor = reconstruction_md["kde_upsample_factor"] + self._recon_BF_subpixel_aligned = xp.asarray( + dict_data["_subpixel_aligned_BF_emd"].data, dtype=xp.float32 + ) + + self._recon_BF = xp.asarray(dict_data["_aligned_BF_emd"].data, dtype=xp.float32) def preprocess( self, @@ -111,6 +255,7 @@ def preprocess( threshold_intensity: float = 0.8, normalize_images: bool = True, normalize_order=0, + descan_correct: bool = True, defocus_guess: float = None, rotation_guess: float = None, plot_average_bf: bool = True, @@ -133,6 +278,8 @@ def preprocess( defocus_guess: float, optional Initial guess of defocus value (defocus dF) in A If None, first iteration is assumed to be in-focus + descan_correct: float, optional + If True, aligns bright field stack based on measured descan rotation_guess: float, optional Initial guess of defocus value in degrees If None, first iteration assumed to be 0 @@ -171,7 +318,10 @@ def preprocess( self._datacube, require_calibrations=True, ) - self._intensities = xp.asarray(self._intensities, dtype=xp.float32) + + self._region_of_interest_shape = np.array(self._intensities.shape[-2:]) + self._scan_shape = np.array(self._intensities.shape[:2]) + # make sure mean diffraction pattern is shaped correctly if (self._dp_mean.shape[0] != self._intensities.shape[2]) or ( self._dp_mean.shape[1] != self._intensities.shape[3] @@ -180,6 +330,45 @@ def preprocess( "dp_mean must match the datacube shape. Try setting dp_mean = None." ) + # descan correct + if descan_correct: + ( + _, + _, + com_fitted_x, + com_fitted_y, + _, + _, + ) = self._calculate_intensities_center_of_mass( + self._intensities, + dp_mask=None, + fit_function="plane", + com_shifts=None, + com_measured=None, + ) + + com_fitted_x = asnumpy(com_fitted_x) + com_fitted_y = asnumpy(com_fitted_y) + intensities = asnumpy(self._intensities) + intensities_shifted = np.zeros_like(intensities) + + center_x, center_y = self._region_of_interest_shape / 2 + + for rx in range(intensities_shifted.shape[0]): + for ry in range(intensities_shifted.shape[1]): + intensity_shifted = get_shifted_ar( + intensities[rx, ry], + -com_fitted_x[rx, ry] + center_x, + -com_fitted_y[rx, ry] + center_y, + bilinear=True, + device="cpu", + ) + + intensities_shifted[rx, ry] = intensity_shifted + + self._intensities = xp.asarray(intensities_shifted, xp.float32) + self._dp_mean = self._intensities.mean((0, 1)) + # select virtual detector pixels self._dp_mask = self._dp_mean >= (xp.max(self._dp_mean) * threshold_intensity) self._num_bf_images = int(xp.count_nonzero(self._dp_mask)) @@ -187,14 +376,16 @@ def preprocess( # diffraction space coordinates self._xy_inds = np.argwhere(self._dp_mask) - self._kxy = (self._xy_inds - xp.mean(self._xy_inds, axis=0)[None]) * xp.array( - self._reciprocal_sampling - )[None] + self._kxy = xp.asarray( + (self._xy_inds - xp.mean(self._xy_inds, axis=0)[None]) + * xp.array(self._reciprocal_sampling)[None], + dtype=xp.float32, + ) self._probe_angles = self._kxy * self._wavelength self._kr = xp.sqrt(xp.sum(self._kxy**2, axis=1)) # Window function - x = xp.linspace(-1, 1, self._grid_scan_shape[0] + 1)[1:] + x = xp.linspace(-1, 1, self._grid_scan_shape[0] + 1, dtype=xp.float32)[1:] x -= (x[1] - x[0]) / 2 wx = ( xp.sin( @@ -205,7 +396,7 @@ def preprocess( ) ** 2 ) - y = xp.linspace(-1, 1, self._grid_scan_shape[1] + 1)[1:] + y = xp.linspace(-1, 1, self._grid_scan_shape[1] + 1, dtype=xp.float32)[1:] y -= (y[1] - y[0]) / 2 wy = ( xp.sin( @@ -222,7 +413,8 @@ def preprocess( ( self._grid_scan_shape[0] + self._object_padding_px[0], self._grid_scan_shape[1] + self._object_padding_px[1], - ) + ), + dtype=xp.float32, ) self._window_pad[ self._object_padding_px[0] // 2 : self._grid_scan_shape[0] @@ -245,7 +437,8 @@ def preprocess( self._grid_scan_shape[1] + self._object_padding_px[1], ) if normalize_images: - self._stack_BF = xp.ones(stack_shape) + self._stack_BF = xp.ones(stack_shape, dtype=xp.float32) + self._stack_BF_no_window = xp.ones(stack_shape, xp.float32) if normalize_order == 0: all_bfs /= xp.mean(all_bfs, axis=(1, 2))[:, None, None] @@ -259,13 +452,21 @@ def preprocess( self._window_inv[None] + self._window_edge[None] * all_bfs ) + self._stack_BF_no_window[ + :, + self._object_padding_px[0] // 2 : self._grid_scan_shape[0] + + self._object_padding_px[0] // 2, + self._object_padding_px[1] // 2 : self._grid_scan_shape[1] + + self._object_padding_px[1] // 2, + ] = all_bfs + elif normalize_order == 1: - x = xp.linspace(-0.5, 0.5, all_bfs.shape[1]) - y = xp.linspace(-0.5, 0.5, all_bfs.shape[2]) + x = xp.linspace(-0.5, 0.5, all_bfs.shape[1], xp.float32) + y = xp.linspace(-0.5, 0.5, all_bfs.shape[2], xp.float32) ya, xa = xp.meshgrid(y, x) basis = np.vstack( ( - xp.ones(xa.size), + xp.ones_like(xa), xa.ravel(), ya.ravel(), ) @@ -285,9 +486,18 @@ def preprocess( basis @ coefs[0], all_bfs.shape[1:3] ) + self._stack_BF_no_window[ + a0, + self._object_padding_px[0] // 2 : self._grid_scan_shape[0] + + self._object_padding_px[0] // 2, + self._object_padding_px[1] // 2 : self._grid_scan_shape[1] + + self._object_padding_px[1] // 2, + ] = all_bfs[a0] / xp.reshape(basis @ coefs[0], all_bfs.shape[1:3]) + else: all_means = xp.mean(all_bfs, axis=(1, 2)) self._stack_BF = xp.full(stack_shape, all_means[:, None, None]) + self._stack_BF_no_window = xp.full(stack_shape, all_means[:, None, None]) self._stack_BF[ :, self._object_padding_px[0] // 2 : self._grid_scan_shape[0] @@ -299,9 +509,21 @@ def preprocess( + self._window_edge[None] * all_bfs ) + self._stack_BF_no_window[ + :, + self._object_padding_px[0] // 2 : self._grid_scan_shape[0] + + self._object_padding_px[0] // 2, + self._object_padding_px[1] // 2 : self._grid_scan_shape[1] + + self._object_padding_px[1] // 2, + ] = all_bfs + # Fourier space operators for image shifts qx = xp.fft.fftfreq(self._stack_BF.shape[1], d=1) + qx = xp.asarray(qx, dtype=xp.float32) + qy = xp.fft.fftfreq(self._stack_BF.shape[2], d=1) + qy = xp.asarray(qy, dtype=xp.float32) + qxa, qya = xp.meshgrid(qx, qy, indexing="ij") self._qx_shift = -2j * xp.pi * qxa self._qy_shift = -2j * xp.pi * qya @@ -336,7 +558,7 @@ def preprocess( del Gs else: - self._xy_shifts = xp.zeros((self._num_bf_images, 2)) + self._xy_shifts = xp.zeros((self._num_bf_images, 2), dtype=xp.float32) self._stack_mean = xp.mean(self._stack_BF) self._mask_sum = xp.sum(self._window_edge) * self._num_bf_images @@ -506,8 +728,6 @@ def tune_angle_and_defocus( convergence.append(asnumpy(self._recon_error[0])) if plot_convergence: - from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable - fig, ax = plt.subplots() ax.set_title("convergence") im = ax.imshow( @@ -533,9 +753,9 @@ def tune_angle_and_defocus( divider = make_axes_locatable(ax) cax = divider.append_axes("right", size="5%", pad=0.05) - plt.colorbar(im, cax=cax) + fig.colorbar(im, cax=cax) - plt.tight_layout() + fig.tight_layout() if return_values: convergence = np.array(convergence).reshape( @@ -548,7 +768,7 @@ def reconstruct( max_alignment_bin: int = None, min_alignment_bin: int = 1, max_iter_at_min_bin: int = 2, - upsample_factor: int = 8, + cross_correlation_upsample_factor: int = 8, regularizer_matrix_size: Tuple[int, int] = (1, 1), regularize_shifts: bool = True, running_average: bool = True, @@ -570,7 +790,7 @@ def reconstruct( Minimum bin size for bright field alignment max_iter_at_min_bin: int, optional Number of iterations to run at the smallest bin size - upsample_factor: int, optional + cross_correlation_upsample_factor: int, optional DFT upsample factor for subpixel alignment regularizer_matrix_size: Tuple[int,int], optional Bernstein basis degree used for regularizing shifts @@ -623,7 +843,8 @@ def reconstruct( ( self._num_bf_images, (regularizer_matrix_size[0] + 1) * (regularizer_matrix_size[1] + 1), - ) + ), + dtype=xp.float32, ) for ii in np.arange(regularizer_matrix_size[0] + 1): Bi = ( @@ -708,7 +929,7 @@ def reconstruct( # Sort by radial order, from center to outer edge inds_order = xp.argsort(xp.sum(xy_vals**2, axis=1)) - shifts_update = xp.zeros((self._num_bf_images, 2)) + shifts_update = xp.zeros((self._num_bf_images, 2), dtype=xp.float32) for a1 in tqdmnd( xy_vals.shape[0], @@ -730,7 +951,7 @@ def reconstruct( xy_shift = align_images_fourier( G_ref, G, - upsample_factor=upsample_factor, + upsample_factor=cross_correlation_upsample_factor, device=self._device, ) @@ -777,11 +998,19 @@ def reconstruct( self._qx_shift[None] * dx[:, None, None] + self._qy_shift[None] * dy[:, None, None] ) + self._stack_BF = xp.real(xp.fft.ifft2(Gs * shift_op)) self._stack_mask = xp.real( xp.fft.ifft2(xp.fft.fft2(self._stack_mask) * shift_op) ) + self._stack_BF = xp.asarray( + self._stack_BF, dtype=xp.float32 + ) # numpy fft upcasts? + self._stack_mask = xp.asarray( + self._stack_mask, dtype=xp.float32 + ) # numpy fft upcasts? + del Gs # Center the shifts @@ -837,31 +1066,293 @@ def reconstruct( return self + def subpixel_alignment( + self, + kde_upsample_factor=None, + kde_sigma=0.125, + plot_upsampled_BF_comparison: bool = True, + plot_upsampled_FFT_comparison: bool = False, + **kwargs, + ): + """ + Upsample and subpixel-align BFs using the measured image shifts. + Uses kernel density estimation (KDE) to align upsampled BFs. + + Parameters + ---------- + kde_upsample_factor: int, optional + Real-space upsampling factor + kde_sigma: float, optional + KDE gaussian kernel bandwidth + plot_upsampled_BF_comparison: bool, optional + If True, the pre/post alignment BF images are plotted for comparison + plot_upsampled_FFT_comparison: bool, optional + If True, the pre/post alignment BF FFTs are plotted for comparison + + """ + xp = self._xp + asnumpy = self._asnumpy + gaussian_filter = self._gaussian_filter + + xy_shifts = self._xy_shifts + BF_size = np.array(self._stack_BF_no_window.shape[-2:]) + + self._DF_upsample_limit = np.max( + self._region_of_interest_shape / self._scan_shape + ) + self._BF_upsample_limit = ( + 2 * self._kr.max() / self._reciprocal_sampling[0] + ) / self._scan_shape.max() + if self._device == "gpu": + self._BF_upsample_limit = self._BF_upsample_limit.item() + + if kde_upsample_factor is None: + kde_upsample_factor = np.minimum( + self._BF_upsample_limit * 3 / 2, self._DF_upsample_limit + ) + + warnings.warn( + ( + f"Upsampling factor set to {kde_upsample_factor:.2f} (1.5 times the " + f"bright-field upsampling limit of {self._BF_upsample_limit:.2f})." + ), + UserWarning, + ) + + if kde_upsample_factor < 1: + raise ValueError("kde_upsample_factor must be larger than 1") + + if kde_upsample_factor > self._DF_upsample_limit: + warnings.warn( + ( + "Requested upsampling factor exceeds " + f"dark-field upsampling limit of {self._DF_upsample_limit:.2f}." + ), + UserWarning, + ) + + self._kde_upsample_factor = kde_upsample_factor + pixel_output = np.round(BF_size * self._kde_upsample_factor).astype("int") + pixel_size = pixel_output.prod() + + # shifted coordinates + x = xp.arange(BF_size[0]) + y = xp.arange(BF_size[1]) + + xa, ya = xp.meshgrid(x, y, indexing="ij") + xa = ((xa + xy_shifts[:, 0, None, None]) * self._kde_upsample_factor).ravel() + ya = ((ya + xy_shifts[:, 1, None, None]) * self._kde_upsample_factor).ravel() + + # bilinear sampling + xF = xp.floor(xa).astype("int") + yF = xp.floor(ya).astype("int") + dx = xa - xF + dy = ya - yF + + # resampling + inds_1D = xp.ravel_multi_index( + xp.hstack( + [ + [xF, yF], + [xF + 1, yF], + [xF, yF + 1], + [xF + 1, yF + 1], + ] + ), + pixel_output, + mode=["wrap", "wrap"], + ) + + weights = xp.hstack( + ( + (1 - dx) * (1 - dy), + (dx) * (1 - dy), + (1 - dx) * (dy), + (dx) * (dy), + ) + ) + + pix_count = xp.reshape( + xp.bincount(inds_1D, weights=weights, minlength=pixel_size), pixel_output + ) + pix_output = xp.reshape( + xp.bincount( + inds_1D, + weights=weights * xp.tile(self._stack_BF_no_window.ravel(), 4), + minlength=pixel_size, + ), + pixel_output, + ) + + # kernel density estimate + sigma = kde_sigma * self._kde_upsample_factor + pix_count = gaussian_filter(pix_count, sigma) + pix_count[pix_output == 0.0] = np.inf + pix_output = gaussian_filter(pix_output, sigma) + pix_output /= pix_count + + self._recon_BF_subpixel_aligned = pix_output + self.recon_BF_subpixel_aligned = asnumpy(self._recon_BF_subpixel_aligned) + + # plotting + if plot_upsampled_BF_comparison: + if plot_upsampled_FFT_comparison: + figsize = kwargs.pop("figsize", (8, 8)) + fig, axs = plt.subplots(2, 2, figsize=figsize) + else: + figsize = kwargs.pop("figsize", (8, 4)) + fig, axs = plt.subplots(1, 2, figsize=figsize) + + axs = axs.flat + cmap = kwargs.pop("cmap", "magma") + + cropped_object = self._crop_padded_object(self._recon_BF) + cropped_object_aligned = self._crop_padded_object( + self._recon_BF_subpixel_aligned, upsampled=True + ) + + extent = [ + 0, + self._scan_sampling[1] * cropped_object.shape[1], + self._scan_sampling[0] * cropped_object.shape[0], + 0, + ] + + axs[0].imshow( + cropped_object, + extent=extent, + cmap=cmap, + **kwargs, + ) + axs[0].set_title("Aligned Bright Field") + + axs[1].imshow( + cropped_object_aligned, + extent=extent, + cmap=cmap, + **kwargs, + ) + axs[1].set_title("Upsampled Bright Field") + + for ax in axs[:2]: + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + if plot_upsampled_FFT_comparison: + recon_fft = xp.fft.fftshift(xp.abs(xp.fft.fft2(self._recon_BF))) + pad_x = np.round( + BF_size[0] * (self._kde_upsample_factor - 1) / 2 + ).astype("int") + pad_y = np.round( + BF_size[1] * (self._kde_upsample_factor - 1) / 2 + ).astype("int") + pad_recon_fft = asnumpy( + xp.pad(recon_fft, ((pad_x, pad_x), (pad_y, pad_y))) + ) + + upsampled_fft = asnumpy( + xp.fft.fftshift( + xp.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned)) + ) + ) + + reciprocal_extent = [ + -0.5 / (self._scan_sampling[1] / self._kde_upsample_factor), + 0.5 / (self._scan_sampling[1] / self._kde_upsample_factor), + 0.5 / (self._scan_sampling[0] / self._kde_upsample_factor), + -0.5 / (self._scan_sampling[0] / self._kde_upsample_factor), + ] + + show( + pad_recon_fft, + figax=(fig, axs[2]), + extent=reciprocal_extent, + cmap="gray", + title="Aligned Bright Field FFT", + **kwargs, + ) + + show( + upsampled_fft, + figax=(fig, axs[3]), + extent=reciprocal_extent, + cmap="gray", + title="Upsampled Bright Field FFT", + **kwargs, + ) + + for ax in axs[2:]: + ax.set_ylabel(r"$k_x$ [$A^{-1}$]") + ax.set_xlabel(r"$k_y$ [$A^{-1}$]") + ax.xaxis.set_ticks_position("bottom") + + fig.tight_layout() + def aberration_fit( self, - plot_CTF_compare: bool = False, - plot_dk: float = 0.005, - plot_k_sigma: float = 0.02, + fit_BF_shifts: bool = False, + fit_CTF_FFT: bool = False, + fit_aberrations_max_radial_order: int = 3, + fit_aberrations_max_angular_order: int = 4, + fit_aberrations_min_radial_order: int = 2, + fit_aberrations_min_angular_order: int = 0, + fit_max_thon_rings: int = 6, + fit_power_alpha: float = 2.0, + plot_CTF_comparison: bool = None, + plot_BF_shifts_comparison: bool = None, + upsampled: bool = True, + force_transpose: bool = None, ): """ Fit aberrations to the measured image shifts. Parameters ---------- - plot_CTF_compare: bool, optional - If True, the fitted CTF is plotted against the reconstructed frequencies - plot_dk: float, optional - Reciprocal bin-size for polar-averaged FFT - plot_k_sigma: float, optional - sigma to gaussian blur polar-averaged FFT by + fit_BF_shifts: bool + Set to True to fit aberrations to the measured BF shifts directly. + fit_CTF_FFT: bool + Set to True to fit aberrations in the FFT of the (upsampled) BF + image. Note that this method relies on visible zero crossings in the FFT. + fit_aberrations_max_radial_order: int + Max radial order for fitting of aberrations. + fit_aberrations_max_angular_order: int + Max angular order for fitting of aberrations. + fit_aberrations_min_radial_order: int + Min radial order for fitting of aberrations. + fit_aberrations_min_angular_order: int + Min angular order for fitting of aberrations. + fit_max_thon_rings: int + Max number of Thon rings to search for during CTF FFT fitting. + fit_power_alpha: int + Power to raise FFT alpha weighting during CTF FFT fitting. + plot_CTF_comparison: bool, optional + If True, the fitted CTF is plotted against the reconstructed frequencies. + plot_BF_shifts_comparison: bool, optional + If True, the measured vs fitted BF shifts are plotted. + upsampled: bool + If True, and upsampled BF is available, uses that for CTF FFT fitting. + force_transpose: bool + If True, and fit_BF_shifts is True, flips the measured x and y shifts """ xp = self._xp asnumpy = self._asnumpy - gaussian_filter = self._gaussian_filter + + ### First pass # Convert real space shifts to Angstroms - self._xy_shifts_Ang = self._xy_shifts * xp.array(self._scan_sampling) + + if force_transpose is None: + self.transpose_detected = False + else: + self.transpose_detected = force_transpose + + if force_transpose is True: + self._xy_shifts_Ang = xp.flip(self._xy_shifts, axis=1) * xp.array( + self._scan_sampling + ) + else: + self._xy_shifts_Ang = self._xy_shifts * xp.array(self._scan_sampling) # Solve affine transformation m = asnumpy( @@ -879,123 +1370,518 @@ def aberration_fit( ) m_aberration = -1.0 * m_aberration self.aberration_C1 = (m_aberration[0, 0] + m_aberration[1, 1]) / 2.0 - self.aberration_A1x = ( - m_aberration[0, 0] - m_aberration[1, 1] - ) / 2.0 # factor /2 for A1 astigmatism? /4? + self.aberration_A1x = (m_aberration[0, 0] - m_aberration[1, 1]) / 2.0 self.aberration_A1y = (m_aberration[1, 0] + m_aberration[0, 1]) / 2.0 - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() + ### Second pass - # Print results - if self._verbose: - print( - ( - "Rotation of Q w.r.t. R = " - f"{np.rad2deg(self.rotation_Q_to_R_rads):.3f} deg" + # Aberration coefs + mn = [] + + for m in range( + fit_aberrations_min_radial_order - 1, fit_aberrations_max_radial_order + ): + n_max = np.minimum(fit_aberrations_max_angular_order, m + 1) + for n in range(fit_aberrations_min_angular_order, n_max + 1): + if (m + n) % 2: + mn.append([m, n, 0]) + if n > 0: + mn.append([m, n, 1]) + + self._aberrations_mn = np.array(mn) + self._aberrations_mn = self._aberrations_mn[ + np.argsort(self._aberrations_mn[:, 1]), : + ] + + sub = self._aberrations_mn[:, 1] > 0 + self._aberrations_mn[sub, :] = self._aberrations_mn[sub, :][ + np.argsort(self._aberrations_mn[sub, 0]), : + ] + self._aberrations_mn[~sub, :] = self._aberrations_mn[~sub, :][ + np.argsort(self._aberrations_mn[~sub, 0]), : + ] + self._aberrations_num = self._aberrations_mn.shape[0] + + if plot_CTF_comparison is None: + if fit_CTF_FFT: + plot_CTF_comparison = True + + if plot_BF_shifts_comparison is None: + if fit_BF_shifts: + plot_BF_shifts_comparison = True + + # Thon Rings Fitting + if fit_CTF_FFT or plot_CTF_comparison: + if upsampled and hasattr(self, "_kde_upsample_factor"): + im_FFT = xp.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned)) + sx = self._scan_sampling[0] / self._kde_upsample_factor + sy = self._scan_sampling[1] / self._kde_upsample_factor + + else: + im_FFT = xp.abs(xp.fft.fft2(self._recon_BF)) + sx = self._scan_sampling[0] + sy = self._scan_sampling[1] + upsampled = False + + # FFT coordinates + qx = xp.fft.fftfreq(im_FFT.shape[0], sx) + qy = xp.fft.fftfreq(im_FFT.shape[1], sy) + qr2 = qx[:, None] ** 2 + qy[None, :] ** 2 + + alpha_FFT = xp.sqrt(qr2) * self._wavelength + theta_FFT = xp.arctan2(qy[None, :], qx[:, None]) + + # Aberration basis + self._aberrations_basis_FFT = xp.zeros( + (alpha_FFT.size, self._aberrations_num) + ) + for a0 in range(self._aberrations_num): + m, n, a = self._aberrations_mn[a0] + if n == 0: + # Radially symmetric basis + self._aberrations_basis_FFT[:, a0] = ( + alpha_FFT ** (m + 1) / (m + 1) + ).ravel() + + elif a == 0: + # cos coef + self._aberrations_basis_FFT[:, a0] = ( + alpha_FFT ** (m + 1) * xp.cos(n * theta_FFT) / (m + 1) + ).ravel() + else: + # sin coef + self._aberrations_basis_FFT[:, a0] = ( + alpha_FFT ** (m + 1) * xp.sin(n * theta_FFT) / (m + 1) + ).ravel() + + # global scaling + self._aberrations_basis_FFT *= 2 * np.pi / self._wavelength + self._aberrations_surface_shape_FFT = alpha_FFT.shape + plot_mask = qr2 > np.pi**2 / 4 / np.abs(self.aberration_C1) + angular_mask = np.cos(8.0 * theta_FFT) ** 2 < 0.25 + + # CTF function + def calculate_CTF_FFT(alpha_shape, *coefs): + chi = xp.zeros_like(self._aberrations_basis_FFT[:, 0]) + for a0 in range(len(coefs)): + chi += coefs[a0] * self._aberrations_basis_FFT[:, a0] + return xp.reshape(chi, alpha_shape) + + # Direct Shifts Fitting + if fit_BF_shifts: + # FFT coordinates + sx = 1 / (self._reciprocal_sampling[0] * self._region_of_interest_shape[0]) + sy = 1 / (self._reciprocal_sampling[1] * self._region_of_interest_shape[1]) + qx = xp.fft.fftfreq(self._region_of_interest_shape[0], sx) + qy = xp.fft.fftfreq(self._region_of_interest_shape[1], sy) + qr2 = qx[:, None] ** 2 + qy[None, :] ** 2 + + u = qx[:, None] * self._wavelength + v = qy[None, :] * self._wavelength + alpha = xp.sqrt(qr2) * self._wavelength + theta = xp.arctan2(qy[None, :], qx[:, None]) + + # Aberration basis + self._aberrations_basis = xp.zeros((alpha.size, self._aberrations_num)) + self._aberrations_basis_du = xp.zeros((alpha.size, self._aberrations_num)) + self._aberrations_basis_dv = xp.zeros((alpha.size, self._aberrations_num)) + for a0 in range(self._aberrations_num): + m, n, a = self._aberrations_mn[a0] + + if n == 0: + # Radially symmetric basis + self._aberrations_basis[:, a0] = ( + alpha ** (m + 1) / (m + 1) + ).ravel() + self._aberrations_basis_du[:, a0] = (u * alpha ** (m - 1)).ravel() + self._aberrations_basis_dv[:, a0] = (v * alpha ** (m - 1)).ravel() + + elif a == 0: + # cos coef + self._aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.cos(n * theta) / (m + 1) + ).ravel() + self._aberrations_basis_du[:, a0] = ( + alpha ** (m - 1) + * ((m + 1) * u * xp.cos(n * theta) + n * v * xp.sin(n * theta)) + / (m + 1) + ).ravel() + self._aberrations_basis_dv[:, a0] = ( + alpha ** (m - 1) + * ((m + 1) * v * xp.cos(n * theta) - n * u * xp.sin(n * theta)) + / (m + 1) + ).ravel() + + else: + # sin coef + self._aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.sin(n * theta) / (m + 1) + ).ravel() + self._aberrations_basis_du[:, a0] = ( + alpha ** (m - 1) + * ((m + 1) * u * xp.sin(n * theta) - n * v * xp.cos(n * theta)) + / (m + 1) + ).ravel() + self._aberrations_basis_dv[:, a0] = ( + alpha ** (m - 1) + * ((m + 1) * v * xp.sin(n * theta) + n * u * xp.cos(n * theta)) + / (m + 1) + ).ravel() + + # global scaling + self._aberrations_basis *= 2 * np.pi / self._wavelength + self._aberrations_surface_shape = alpha.shape + + # CTF function + def calculate_CTF(alpha_shape, *coefs): + chi = xp.zeros_like(self._aberrations_basis[:, 0]) + for a0 in range(len(coefs)): + chi += coefs[a0] * self._aberrations_basis[:, a0] + return xp.reshape(chi, alpha_shape) + + # initial coefficients and plotting intensity range mask + self._aberrations_coefs = np.zeros(self._aberrations_num) + ind = np.argmin( + np.abs(self._aberrations_mn[:, 0] - 1.0) + self._aberrations_mn[:, 1] + ) + self._aberrations_coefs[ind] = self.aberration_C1 + + # Refinement using CTF fitting / Thon rings + if fit_CTF_FFT: + # scoring function to minimize - mean value of zero crossing regions of FFT + def score_CTF(coefs): + im_CTF = xp.abs( + calculate_CTF_FFT(self._aberrations_surface_shape_FFT, *coefs) + ) + mask = xp.logical_and( + im_CTF > 0.5 * np.pi, + im_CTF < (max_num_rings + 0.5) * np.pi, ) + if np.any(mask): + weights = xp.cos(im_CTF[mask]) ** 4 + return asnumpy( + xp.sum( + weights * im_FFT[mask] * alpha_FFT[mask] ** fit_power_alpha + ) + / xp.sum(weights) + ) + else: + return np.inf + + for max_num_rings in range(1, fit_max_thon_rings + 1): + # minimization + res = minimize( + score_CTF, + self._aberrations_coefs, + # method = 'Nelder-Mead', + # method = 'CG', + method="BFGS", + tol=1e-8, + ) + self._aberrations_coefs = res.x + + # Refinement using CTF fitting / Thon rings + elif fit_BF_shifts: + # Gradient basis + corner_indices = self._xy_inds - xp.asarray( + self._region_of_interest_shape // 2 ) - print( + raveled_indices = np.ravel_multi_index( + corner_indices.T, self._region_of_interest_shape, mode="wrap" + ) + gradients = xp.vstack( ( - "Astigmatism (A1x,A1y) = (" - f"{self.aberration_A1x:.0f}," - f"{self.aberration_A1y:.0f}) Ang" + self._aberrations_basis_du[raveled_indices, :], + self._aberrations_basis_dv[raveled_indices, :], ) ) - if self.aberration_C1 > 0: - print(f"Aberration C1 = {self.aberration_C1:.0f} Ang") - print(f"Defocus dF = {-1*self.aberration_C1:.0f} Ang") + + # (Relative) untransposed fit + tf = AffineTransform(angle=self.rotation_Q_to_R_rads) + rotated_shifts = tf(self._xy_shifts_Ang, xp=xp).T.ravel() + aberrations_coefs, res = xp.linalg.lstsq( + gradients, rotated_shifts, rcond=None + )[:2] + + if force_transpose is None: + # (Relative) transposed fit + transposed_shifts = xp.flip(self._xy_shifts_Ang, axis=1) + m_T = asnumpy( + xp.linalg.lstsq(self._probe_angles, transposed_shifts, rcond=None)[ + 0 + ] + ) + m_rotation_T, _ = polar(m_T, side="right") + rotation_Q_to_R_rads_T = -1 * np.arctan2( + m_rotation_T[1, 0], m_rotation_T[0, 0] + ) + if np.abs( + np.mod(rotation_Q_to_R_rads_T + np.pi, 2.0 * np.pi) - np.pi + ) > (np.pi * 0.5): + rotation_Q_to_R_rads_T = ( + np.mod(rotation_Q_to_R_rads_T, 2.0 * np.pi) - np.pi + ) + + tf_T = AffineTransform(angle=rotation_Q_to_R_rads_T) + rotated_shifts_T = tf_T(transposed_shifts, xp=xp).T.ravel() + aberrations_coefs_T, res_T = xp.linalg.lstsq( + gradients, rotated_shifts_T, rcond=None + )[:2] + + # Compare fits + if res_T.sum() < res.sum(): + self.rotation_Q_to_R_rads = rotation_Q_to_R_rads_T + self.transpose_detected = not self.transpose_detected + self._aberrations_coefs = asnumpy(aberrations_coefs_T) + self._rotated_shifts = rotated_shifts_T + + warnings.warn( + ( + "Data transpose detected. " + f"Overwriting rotation value to {np.rad2deg(rotation_Q_to_R_rads_T):.3f}" + ), + UserWarning, + ) else: - print(f"Aberration C1 = {self.aberration_C1:.0f} Ang") - print(f"Defocus dF = {-1*self.aberration_C1:.0f} Ang") + self._aberrations_coefs = asnumpy(aberrations_coefs) + self._rotated_shifts = rotated_shifts # Plot the CTF comparison between experiment and fit - if plot_CTF_compare: - # Get polar mean from FFT of BF reconstruction - im_fft = xp.abs(xp.fft.fft2(self._recon_BF)) - - # coordinates - kx = xp.fft.fftfreq(self._recon_BF.shape[0], self._scan_sampling[0]) - ky = xp.fft.fftfreq(self._recon_BF.shape[1], self._scan_sampling[1]) - kra = xp.sqrt(kx[:, None] ** 2 + ky[None, :] ** 2) - k_max = xp.max(kra) / np.sqrt(2.0) - k_num_bins = int(xp.ceil(k_max / plot_dk)) - k_bins = xp.arange(k_num_bins + 1) * plot_dk - - # histogram - k_ind = kra / plot_dk - kf = np.floor(k_ind).astype("int") - dk = k_ind - kf - sub = kf <= k_num_bins - hist_exp = xp.bincount( - kf[sub], weights=im_fft[sub] * (1 - dk[sub]), minlength=k_num_bins + if plot_CTF_comparison: + # Generate FFT plotting image + im_scale = asnumpy(im_FFT * alpha_FFT**fit_power_alpha) + int_vals = np.sort(im_scale.ravel()) + int_range = ( + int_vals[np.round(0.02 * im_scale.size).astype("int")], + int_vals[np.round(0.98 * im_scale.size).astype("int")], ) - hist_norm = xp.bincount( - kf[sub], weights=(1 - dk[sub]), minlength=k_num_bins + int_range = ( + int_range[0], + (int_range[1] - int_range[0]) * 1.0 + int_range[0], + ) + im_scale = np.clip( + (np.fft.fftshift(im_scale) - int_range[0]) + / (int_range[1] - int_range[0]), + 0, + 1, ) - sub = kf <= k_num_bins - 1 + im_plot = np.tile(im_scale[:, :, None], (1, 1, 3)) - hist_exp += xp.bincount( - kf[sub] + 1, weights=im_fft[sub] * (dk[sub]), minlength=k_num_bins + # Add CTF zero crossings + im_CTF = calculate_CTF_FFT( + self._aberrations_surface_shape_FFT, *self._aberrations_coefs ) - hist_norm += xp.bincount( - kf[sub] + 1, weights=(dk[sub]), minlength=k_num_bins + im_CTF_cos = xp.cos(xp.abs(im_CTF)) ** 4 + im_CTF[xp.abs(im_CTF) > (fit_max_thon_rings + 0.5) * np.pi] = np.pi / 2 + im_CTF = xp.abs(xp.sin(im_CTF)) < 0.15 + im_CTF[xp.logical_not(plot_mask)] = 0 + + im_CTF = np.fft.fftshift(asnumpy(im_CTF * angular_mask)) + im_plot[:, :, 0] += im_CTF + im_plot[:, :, 1] -= im_CTF + im_plot[:, :, 2] -= im_CTF + im_plot = np.clip(im_plot, 0, 1) + + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6)) + ax1.imshow(im_plot, vmin=int_range[0], vmax=int_range[1]) + + ax2.imshow(np.fft.fftshift(asnumpy(im_CTF_cos)), cmap="gray") + + fig.tight_layout() + + # Plot the measured/fitted shifts comparison + if plot_BF_shifts_comparison: + if not fit_BF_shifts: + raise ValueError() + + measured_shifts_sx = xp.zeros( + self._region_of_interest_shape, dtype=xp.float32 ) + measured_shifts_sx[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = self._rotated_shifts[: self._xy_inds.shape[0]] - # KDE and normalizing - k_sigma = plot_dk / plot_k_sigma - hist_exp[0] = 0.0 - hist_exp = gaussian_filter(hist_exp, sigma=k_sigma, mode="nearest") - hist_norm = gaussian_filter(hist_norm, sigma=k_sigma, mode="nearest") - hist_exp /= hist_norm + measured_shifts_sy = xp.zeros( + self._region_of_interest_shape, dtype=xp.float32 + ) + measured_shifts_sy[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = self._rotated_shifts[self._xy_inds.shape[0] :] - # CTF comparison - CTF_fit = xp.sin( - (-np.pi * self._wavelength * self.aberration_C1) * k_bins**2 + fitted_shifts = xp.tensordot( + gradients, xp.array(self._aberrations_coefs), axes=1 ) - # plotting input - log scale - min_hist_val = xp.max(hist_exp) * 1e-3 - hist_plot = xp.log(np.maximum(hist_exp, min_hist_val)) - hist_plot -= xp.min(hist_plot) - hist_plot /= xp.max(hist_plot) + fitted_shifts_sx = xp.zeros( + self._region_of_interest_shape, dtype=xp.float32 + ) + fitted_shifts_sx[self._xy_inds[:, 0], self._xy_inds[:, 1]] = fitted_shifts[ + : self._xy_inds.shape[0] + ] - hist_plot = asnumpy(hist_plot) - k_bins = asnumpy(k_bins) - CTF_fit = asnumpy(CTF_fit) + fitted_shifts_sy = xp.zeros( + self._region_of_interest_shape, dtype=xp.float32 + ) + fitted_shifts_sy[self._xy_inds[:, 0], self._xy_inds[:, 1]] = fitted_shifts[ + self._xy_inds.shape[0] : + ] - fig, ax = plt.subplots(figsize=(8, 4)) + max_shift = xp.max( + xp.array( + [ + xp.abs(measured_shifts_sx).max(), + xp.abs(measured_shifts_sy).max(), + xp.abs(fitted_shifts_sx).max(), + xp.abs(fitted_shifts_sy).max(), + ] + ) + ) - ax.fill_between( - k_bins, - hist_plot, - color=(0.7, 0.7, 0.7, 1), + show( + [ + [asnumpy(measured_shifts_sx), asnumpy(measured_shifts_sy)], + [asnumpy(fitted_shifts_sx), asnumpy(fitted_shifts_sy)], + ], + cmap="PiYG", + vmin=-max_shift, + vmax=max_shift, + intensity_range="absolute", + axsize=(4, 4), + ticks=False, + title=[ + "Measured Vertical Shifts", + "Measured Horizontal Shifts", + "Fitted Vertical Shifts", + "Fitted Horizontal Shifts", + ], ) - ax.plot( - k_bins, - np.clip(CTF_fit, 0.0, np.inf), - color=(1, 0, 0, 1), - linewidth=2, + self.aberration_dict = { + tuple(self._aberrations_mn[a0]): { + "aberration name": _aberration_names.get( + tuple(self._aberrations_mn[a0, :2]), "-" + ).strip(), + "value [Ang]": self._aberrations_coefs[a0], + } + for a0 in range(self._aberrations_num) + } + + # Print results + if self._verbose: + if fit_CTF_FFT or fit_BF_shifts: + print("Initial Aberration coefficients") + print("-------------------------------") + print( + ( + "Rotation of Q w.r.t. R = " + f"{np.rad2deg(self.rotation_Q_to_R_rads):.3f} deg" + ) ) - ax.plot( - k_bins, - np.clip(-CTF_fit, 0.0, np.inf), - color=(0, 0.5, 1, 1), - linewidth=2, + print( + ( + "Astigmatism (A1x,A1y) = (" + f"{self.aberration_A1x:.0f}," + f"{self.aberration_A1y:.0f}) Ang" + ) ) - ax.set_xlim([0, k_bins[-1]]) - ax.set_ylim([0, 1.05]) + print(f"Aberration C1 = {self.aberration_C1:.0f} Ang") + print(f"Defocus dF = {-1*self.aberration_C1:.0f} Ang") + + if fit_CTF_FFT or fit_BF_shifts: + print() + print("Refined Aberration coefficients") + print("-------------------------------") + print("aberration radial angular dir. coefs") + print("name order order Ang ") + print("---------- ------- ------- ---- -----") + + for a0 in range(self._aberrations_mn.shape[0]): + m, n, a = self._aberrations_mn[a0] + name = _aberration_names.get((m, n), " -- ") + if n == 0: + print( + name + + " " + + str(m + 1) + + " 0 - " + + str(np.round(self._aberrations_coefs[a0]).astype("int")) + ) + elif a == 0: + print( + name + + " " + + str(m + 1) + + " " + + str(n) + + " x " + + str(np.round(self._aberrations_coefs[a0]).astype("int")) + ) + else: + print( + name + + " " + + str(m + 1) + + " " + + str(n) + + " y " + + str(np.round(self._aberrations_coefs[a0]).astype("int")) + ) + + if self._device == "gpu": + xp._default_memory_pool.free_all_blocks() + xp.clear_memo() + + def _calculate_CTF(self, alpha_shape, sampling, *coefs): + xp = self._xp + + # FFT coordinates + sx, sy = sampling + qx = xp.fft.fftfreq(alpha_shape[0], sx) + qy = xp.fft.fftfreq(alpha_shape[1], sy) + qr2 = qx[:, None] ** 2 + qy[None, :] ** 2 + + alpha = xp.sqrt(qr2) * self._wavelength + theta = xp.arctan2(qy[None, :], qx[:, None]) + + # Aberration basis + aberrations_basis = xp.zeros((alpha.size, self._aberrations_num)) + for a0 in range(self._aberrations_num): + m, n, a = self._aberrations_mn[a0] + if n == 0: + # Radially symmetric basis + aberrations_basis[:, a0] = (alpha ** (m + 1) / (m + 1)).ravel() + + elif a == 0: + # cos coef + aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.cos(n * theta) / (m + 1) + ).ravel() + else: + # sin coef + aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.sin(n * theta) / (m + 1) + ).ravel() + + # global scaling + aberrations_basis *= 2 * np.pi / self._wavelength + + chi = xp.zeros_like(aberrations_basis[:, 0]) + + for a0 in range(len(coefs)): + chi += coefs[a0] * aberrations_basis[:, a0] + + return xp.reshape(chi, alpha_shape) def aberration_correct( self, + use_CTF_fit=None, plot_corrected_phase: bool = True, k_info_limit: float = None, k_info_power: float = 1.0, Wiener_filter=False, - Wiener_signal_noise_ratio=1.0, - Wiener_filter_low_only=False, + Wiener_signal_noise_ratio: float = 1.0, + Wiener_filter_low_only: bool = False, + upsampled: bool = True, **kwargs, ): """ @@ -1003,6 +1889,9 @@ def aberration_correct( Parameters ---------- + use_FFT_fit: bool + Use the CTF fitted to the zero crossings of the FFT. + Default is True plot_corrected_phase: bool, optional If True, the CTF-corrected phase is plotted k_info_limit: float, optional @@ -1028,46 +1917,79 @@ def aberration_correct( ) ) + if upsampled and hasattr(self, "_kde_upsample_factor"): + im = self._recon_BF_subpixel_aligned + sx = self._scan_sampling[0] / self._kde_upsample_factor + sy = self._scan_sampling[1] / self._kde_upsample_factor + else: + upsampled = False + im = self._recon_BF + sx = self._scan_sampling[0] + sy = self._scan_sampling[1] + # Fourier coordinates - kx = xp.fft.fftfreq(self._recon_BF.shape[0], self._scan_sampling[0]) - ky = xp.fft.fftfreq(self._recon_BF.shape[1], self._scan_sampling[1]) + kx = xp.fft.fftfreq(im.shape[0], sx) + ky = xp.fft.fftfreq(im.shape[1], sy) kra2 = (kx[:, None]) ** 2 + (ky[None, :]) ** 2 - # CTF - sin_chi = xp.sin((xp.pi * self._wavelength * self.aberration_C1) * kra2) + if use_CTF_fit is None: + if hasattr(self, "_aberrations_surface_shape"): + use_CTF_fit = True - if Wiener_filter: - SNR_inv = ( - xp.sqrt( - 1 + (kra2**k_info_power) / ((k_info_limit) ** (2 * k_info_power)) - ) - / Wiener_signal_noise_ratio + if use_CTF_fit: + sin_chi = np.sin( + self._calculate_CTF(im.shape, (sx, sy), *self._aberrations_coefs) ) - CTF_corr = xp.sign(sin_chi) / (sin_chi**2 + SNR_inv) - if Wiener_filter_low_only: - # limit Wiener filter to only the part of the CTF before 1st maxima - k_thresh = 1 / xp.sqrt( - 2.0 * self._wavelength * xp.abs(self.aberration_C1) - ) - k_mask = kra2 >= k_thresh**2 - CTF_corr[k_mask] = xp.sign(sin_chi[k_mask]) - # apply correction to mean reconstructed BF image - im_fft_corr = xp.fft.fft2(self._recon_BF) * CTF_corr - - else: - # CTF without tilt correction (beyond the parallax operator) CTF_corr = xp.sign(sin_chi) CTF_corr[0, 0] = 0 # apply correction to mean reconstructed BF image - im_fft_corr = xp.fft.fft2(self._recon_BF) * CTF_corr + im_fft_corr = xp.fft.fft2(im) * CTF_corr # if needed, add low pass filter output image if k_info_limit is not None: im_fft_corr /= 1 + (kra2**k_info_power) / ( (k_info_limit) ** (2 * k_info_power) ) + else: + # CTF + sin_chi = xp.sin((xp.pi * self._wavelength * self.aberration_C1) * kra2) + + if Wiener_filter: + SNR_inv = ( + xp.sqrt( + 1 + + (kra2**k_info_power) + / ((k_info_limit) ** (2 * k_info_power)) + ) + / Wiener_signal_noise_ratio + ) + CTF_corr = xp.sign(sin_chi) / (sin_chi**2 + SNR_inv) + if Wiener_filter_low_only: + # limit Wiener filter to only the part of the CTF before 1st maxima + k_thresh = 1 / xp.sqrt( + 2.0 * self._wavelength * xp.abs(self.aberration_C1) + ) + k_mask = kra2 >= k_thresh**2 + CTF_corr[k_mask] = xp.sign(sin_chi[k_mask]) + + # apply correction to mean reconstructed BF image + im_fft_corr = xp.fft.fft2(im) * CTF_corr + + else: + # CTF without tilt correction (beyond the parallax operator) + CTF_corr = xp.sign(sin_chi) + CTF_corr[0, 0] = 0 + + # apply correction to mean reconstructed BF image + im_fft_corr = xp.fft.fft2(im) * CTF_corr + + # if needed, add low pass filter output image + if k_info_limit is not None: + im_fft_corr /= 1 + (kra2**k_info_power) / ( + (k_info_limit) ** (2 * k_info_power) + ) # Output phase image self._recon_phase_corrected = xp.real(xp.fft.ifft2(im_fft_corr)) @@ -1084,12 +2006,14 @@ def aberration_correct( fig, ax = plt.subplots(figsize=figsize) - cropped_object = self._crop_padded_object(self._recon_phase_corrected) + cropped_object = self._crop_padded_object( + self._recon_phase_corrected, upsampled=upsampled + ) extent = [ 0, - self._scan_sampling[1] * cropped_object.shape[1], - self._scan_sampling[0] * cropped_object.shape[0], + sy * cropped_object.shape[1], + sx * cropped_object.shape[0], 0, ] @@ -1246,6 +2170,7 @@ def _crop_padded_object( self, padded_object: np.ndarray, remaining_padding: int = 0, + upsampled: bool = False, ): """ Utility function to crop padded object @@ -1266,8 +2191,19 @@ def _crop_padded_object( asnumpy = self._asnumpy - pad_x = self._object_padding_px[0] // 2 - remaining_padding - pad_y = self._object_padding_px[1] // 2 - remaining_padding + if upsampled: + pad_x = np.round( + self._object_padding_px[0] / 2 * self._kde_upsample_factor + ).astype("int") + pad_y = np.round( + self._object_padding_px[1] / 2 * self._kde_upsample_factor + ).astype("int") + else: + pad_x = self._object_padding_px[0] // 2 + pad_y = self._object_padding_px[1] // 2 + + pad_x -= remaining_padding + pad_y -= remaining_padding return asnumpy(padded_object[pad_x:-pad_x, pad_y:-pad_y]) @@ -1276,6 +2212,7 @@ def _visualize_figax( fig, ax, remaining_padding: int = 0, + upsampled: bool = False, **kwargs, ): """ @@ -1294,14 +2231,31 @@ def _visualize_figax( cmap = kwargs.pop("cmap", "magma") - cropped_object = self._crop_padded_object(self._recon_BF, remaining_padding) + if upsampled: + cropped_object = self._crop_padded_object( + self._recon_BF_subpixel_aligned, remaining_padding, upsampled + ) - extent = [ - 0, - self._scan_sampling[1] * cropped_object.shape[1], - self._scan_sampling[0] * cropped_object.shape[0], - 0, - ] + extent = [ + 0, + self._scan_sampling[1] + * cropped_object.shape[1] + / self._kde_upsample_factor, + self._scan_sampling[0] + * cropped_object.shape[0] + / self._kde_upsample_factor, + 0, + ] + + else: + cropped_object = self._crop_padded_object(self._recon_BF, remaining_padding) + + extent = [ + 0, + self._scan_sampling[1] * cropped_object.shape[1], + self._scan_sampling[0] * cropped_object.shape[0], + 0, + ] ax.imshow( cropped_object, @@ -1310,7 +2264,7 @@ def _visualize_figax( **kwargs, ) - def _visualize_shifts( + def show_shifts( self, scale_arrows=1, plot_arrow_freq=1, diff --git a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py index 67dba6115..3eebdb068 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py @@ -1,4 +1,7 @@ +import warnings + import numpy as np +import pylops from py4DSTEM.process.phase.utils import ( array_slice, estimate_global_transformation_ransac, @@ -183,6 +186,63 @@ def _object_butterworth_constraint( return current_object + def _object_denoise_tv_pylops(self, current_object, weight, iterations): + """ + Performs second order TV denoising along x and y + + Parameters + ---------- + current_object: np.ndarray + Current object estimate + weight : float + Denoising weight. The greater `weight`, the more denoising (at + the expense of fidelity to `input`). + iterations: float + Number of iterations to run in denoising algorithm. + `niter_out` in pylops + + Returns + ------- + constrained_object: np.ndarray + Constrained object estimate + + """ + xp = self._xp + + if xp.iscomplexobj(current_object): + current_object_tv = current_object + warnings.warn( + ("TV denoising is currently only supported for potential objects."), + UserWarning, + ) + + else: + nx, ny = current_object.shape + niter_out = iterations + niter_in = 1 + Iop = pylops.Identity(nx * ny) + xy_laplacian = pylops.Laplacian( + (nx, ny), axes=(0, 1), edge=False, kind="backward" + ) + + l1_regs = [xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weight], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + current_object_tv = current_object_tv.reshape(current_object.shape) + + return current_object_tv + def _object_denoise_tv_chambolle( self, current_object, @@ -229,90 +289,100 @@ def _object_denoise_tv_chambolle( Adapted skimage.restoration.denoise_tv_chambolle. """ xp = self._xp - - current_object_sum = xp.sum(current_object) - if axis is None: - ndim = xp.arange(current_object.ndim).tolist() - elif isinstance(axis, tuple): - ndim = list(axis) + if xp.iscomplexobj(current_object): + updated_object = current_object + warnings.warn( + ("TV denoising is currently only supported for potential objects."), + UserWarning, + ) else: - ndim = [axis] - - if pad_object: - pad_width = ((0, 0),) * current_object.ndim - pad_width = list(pad_width) - for ax in range(len(ndim)): - pad_width[ndim[ax]] = (1, 1) - current_object = xp.pad( - current_object, pad_width=pad_width, mode="constant" + current_object_sum = xp.sum(current_object) + if axis is None: + ndim = xp.arange(current_object.ndim).tolist() + elif isinstance(axis, tuple): + ndim = list(axis) + else: + ndim = [axis] + + if pad_object: + pad_width = ((0, 0),) * current_object.ndim + pad_width = list(pad_width) + for ax in range(len(ndim)): + pad_width[ndim[ax]] = (1, 1) + current_object = xp.pad( + current_object, pad_width=pad_width, mode="constant" + ) + + p = xp.zeros( + (current_object.ndim,) + current_object.shape, + dtype=current_object.dtype, ) + g = xp.zeros_like(p) + d = xp.zeros_like(current_object) + + i = 0 + while i < max_num_iter: + if i > 0: + # d will be the (negative) divergence of p + d = -p.sum(0) + slices_d = [ + slice(None), + ] * current_object.ndim + slices_p = [ + slice(None), + ] * (current_object.ndim + 1) + for ax in range(len(ndim)): + slices_d[ndim[ax]] = slice(1, None) + slices_p[ndim[ax] + 1] = slice(0, -1) + slices_p[0] = ndim[ax] + d[tuple(slices_d)] += p[tuple(slices_p)] + slices_d[ndim[ax]] = slice(None) + slices_p[ndim[ax] + 1] = slice(None) + updated_object = current_object + d + else: + updated_object = current_object + E = (d**2).sum() - p = xp.zeros( - (current_object.ndim,) + current_object.shape, dtype=current_object.dtype - ) - g = xp.zeros_like(p) - d = xp.zeros_like(current_object) - - i = 0 - while i < max_num_iter: - if i > 0: - # d will be the (negative) divergence of p - d = -p.sum(0) - slices_d = [ - slice(None), - ] * current_object.ndim - slices_p = [ + # g stores the gradients of updated_object along each axis + # e.g. g[0] is the first order finite difference along axis 0 + slices_g = [ slice(None), ] * (current_object.ndim + 1) for ax in range(len(ndim)): - slices_d[ndim[ax]] = slice(1, None) - slices_p[ndim[ax] + 1] = slice(0, -1) - slices_p[0] = ndim[ax] - d[tuple(slices_d)] += p[tuple(slices_p)] - slices_d[ndim[ax]] = slice(None) - slices_p[ndim[ax] + 1] = slice(None) - updated_object = current_object + d - else: - updated_object = current_object - E = (d**2).sum() - - # g stores the gradients of updated_object along each axis - # e.g. g[0] is the first order finite difference along axis 0 - slices_g = [ - slice(None), - ] * (current_object.ndim + 1) - for ax in range(len(ndim)): - slices_g[ndim[ax] + 1] = slice(0, -1) - slices_g[0] = ndim[ax] - g[tuple(slices_g)] = xp.diff(updated_object, axis=ndim[ax]) - slices_g[ndim[ax] + 1] = slice(None) - if scaling is not None: - scaling /= xp.max(scaling) - g *= xp.array(scaling)[:, xp.newaxis, xp.newaxis] - norm = xp.sqrt((g**2).sum(axis=0))[xp.newaxis, ...] - E += weight * norm.sum() - tau = 1.0 / (2.0 * len(ndim)) - norm *= tau / weight - norm += 1.0 - p -= tau * g - p /= norm - E /= float(current_object.size) - if i == 0: - E_init = E - E_previous = E - else: - if xp.abs(E_previous - E) < eps * E_init: - break - else: + slices_g[ndim[ax] + 1] = slice(0, -1) + slices_g[0] = ndim[ax] + g[tuple(slices_g)] = xp.diff(updated_object, axis=ndim[ax]) + slices_g[ndim[ax] + 1] = slice(None) + if scaling is not None: + scaling /= xp.max(scaling) + g *= xp.array(scaling)[:, xp.newaxis, xp.newaxis] + norm = xp.sqrt((g**2).sum(axis=0))[xp.newaxis, ...] + E += weight * norm.sum() + tau = 1.0 / (2.0 * len(ndim)) + norm *= tau / weight + norm += 1.0 + p -= tau * g + p /= norm + E /= float(current_object.size) + if i == 0: + E_init = E E_previous = E - i += 1 + else: + if xp.abs(E_previous - E) < eps * E_init: + break + else: + E_previous = E + i += 1 - if pad_object: - for ax in range(len(ndim)): - slices = array_slice(ndim[ax], current_object.ndim, 1, -1) - updated_object = updated_object[slices] + if pad_object: + for ax in range(len(ndim)): + slices = array_slice(ndim[ax], current_object.ndim, 1, -1) + updated_object = updated_object[slices] + updated_object = ( + updated_object / xp.sum(updated_object) * current_object_sum + ) - return updated_object / xp.sum(updated_object) * current_object_sum + return updated_object def _probe_center_of_mass_constraint(self, current_probe): """ @@ -363,7 +433,7 @@ def _probe_amplitude_constraint( xp = self._xp erf = self._erf - probe_intensity = xp.abs(current_probe) ** 2 + # probe_intensity = xp.abs(current_probe) ** 2 # current_probe_sum = xp.sum(probe_intensity) X = xp.fft.fftfreq(current_probe.shape[0])[:, None] @@ -485,10 +555,12 @@ def _probe_aberration_fitting_constraint( fourier_probe = xp.fft.fft2(current_probe) fourier_probe_abs = xp.abs(fourier_probe) sampling = self.sampling + energy = self._energy fitted_angle, _ = fit_aberration_surface( fourier_probe, sampling, + energy, max_angular_order, max_radial_order, xp=xp, diff --git a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py index 8881d021c..37438852f 100644 --- a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py +++ b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py @@ -14,8 +14,8 @@ try: import cupy as cp -except ImportError: - cp = None +except ModuleNotFoundError: + cp = np from emdfile import Custom, tqdmnd from py4DSTEM import DataCube @@ -192,6 +192,7 @@ def preprocess( force_angular_sampling: float = None, force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, **kwargs, ): """ @@ -246,6 +247,8 @@ def preprocess( object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering Returns -------- @@ -401,9 +404,7 @@ def preprocess( amplitudes_0, mean_diffraction_intensity_0, ) = self._normalize_diffraction_intensities( - intensities_0, - com_fitted_x_0, - com_fitted_y_0, + intensities_0, com_fitted_x_0, com_fitted_y_0, crop_patterns ) # explicitly delete namescapes @@ -484,9 +485,7 @@ def preprocess( amplitudes_1, mean_diffraction_intensity_1, ) = self._normalize_diffraction_intensities( - intensities_1, - com_fitted_x_1, - com_fitted_y_1, + intensities_1, com_fitted_x_1, com_fitted_y_1, crop_patterns ) # explicitly delete namescapes @@ -568,9 +567,7 @@ def preprocess( amplitudes_2, mean_diffraction_intensity_2, ) = self._normalize_diffraction_intensities( - intensities_2, - com_fitted_x_2, - com_fitted_y_2, + intensities_2, com_fitted_x_2, com_fitted_y_2, crop_patterns ) # explicitly delete namescapes @@ -683,6 +680,10 @@ def preprocess( bilinear=True, device=self._device, ) + if crop_patterns: + self._vacuum_probe_intensity = self._vacuum_probe_intensity[ + self._crop_mask + ].reshape(self._region_of_interest_shape) self._probe = ( ComplexProbe( @@ -746,19 +747,13 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (9, 4)) - cmap = kwargs.pop("cmap", "Greys_r") - vmin = kwargs.pop("vmin", None) - vmax = kwargs.pop("vmax", None) - hue_start = kwargs.pop("hue_start", 0) - invert = kwargs.pop("invert", False) + chroma_boost = kwargs.pop("chroma_boost", 1) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, + chroma_boost=chroma_boost, ) extent = [ @@ -780,23 +775,22 @@ def preprocess( ax1.imshow( complex_probe_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax1) cax1 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( - cax1, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert + cax1, + chroma_boost=chroma_boost, ) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") - ax1.set_title("Initial Probe") + ax1.set_title("Initial probe intensity") ax2.imshow( asnumpy(probe_overlap), extent=extent, - cmap=cmap, - **kwargs, + cmap="Greys_r", ) ax2.scatter( self.positions[:, 1], @@ -808,7 +802,7 @@ def preprocess( ax2.set_xlabel("y [A]") ax2.set_xlim((extent[0], extent[1])) ax2.set_ylim((extent[2], extent[3])) - ax2.set_title("Object Field of View") + ax2.set_title("Object field of view") fig.tight_layout() @@ -2232,6 +2226,9 @@ def _constraints( q_highpass_e, q_highpass_m, butterworth_order, + tv_denoise, + tv_denoise_weight, + tv_denoise_inner_iter, warmup_iteration, object_positivity, shrinkage_rad, @@ -2300,6 +2297,12 @@ def _constraints( Cut-off frequency in A^-1 for high-pass filtering magnetic object butterworth_order: float Butterworth filter order. Smaller gives a smoother filter + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weight: float + Denoising weight. The greater `weight`, the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising warmup_iteration: bool If True, constraints electrostatic object only object_positivity: bool @@ -2349,6 +2352,15 @@ def _constraints( if self._object_type == "complex": magnetic_obj = magnetic_obj.real + if tv_denoise: + electrostatic_obj = self._object_denoise_tv_pylops( + electrostatic_obj, tv_denoise_weight, tv_denoise_inner_iter + ) + + if not warmup_iteration: + magnetic_obj = self._object_denoise_tv_pylops( + magnetic_obj, tv_denoise_weight, tv_denoise_inner_iter + ) if shrinkage_rad > 0.0 or object_mask is not None: electrostatic_obj = self._object_shrinkage_constraint( @@ -2446,6 +2458,9 @@ def reconstruct( q_highpass_e: float = None, q_highpass_m: float = None, butterworth_order: float = 2, + tv_denoise_iter: int = np.inf, + tv_denoise_weight: float = None, + tv_denoise_inner_iter: float = 40, object_positivity: bool = True, shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, @@ -2538,6 +2553,12 @@ def reconstruct( Cut-off frequency in A^-1 for high-pass filtering magnetic object butterworth_order: float Butterworth filter order. Smaller gives a smoother filter + tv_denoise_iter: int, optional + Number of iterations to run using tv denoise filter on object + tv_denoise_weight: float + Denoising weight. The greater `weight`, the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising object_positivity: bool, optional If True, forces object to be positive shrinkage_rad: float @@ -2748,6 +2769,8 @@ def reconstruct( ) = self._extract_vectorized_patch_indices() self._exit_waves = (None,) * self._num_sim_measurements self._object_type = self._object_type_initial + if hasattr(self, "_tf"): + del self._tf elif reset is None: if hasattr(self, "error"): warnings.warn( @@ -2899,6 +2922,9 @@ def reconstruct( q_highpass_e=q_highpass_e, q_highpass_m=q_highpass_m, butterworth_order=butterworth_order, + tv_denoise=a0 < tv_denoise_iter and tv_denoise_weight is not None, + tv_denoise_weight=tv_denoise_weight, + tv_denoise_inner_iter=tv_denoise_inner_iter, object_positivity=object_positivity, shrinkage_rad=shrinkage_rad, object_mask=self._object_fov_mask_inverse @@ -3029,8 +3055,6 @@ def _visualize_last_iteration( figsize = kwargs.pop("figsize", (12, 5)) cmap_e = kwargs.pop("cmap_e", "magma") cmap_m = kwargs.pop("cmap_m", "PuOr") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) if self._object_type == "complex": obj_e = np.angle(self.object[0]) @@ -3052,6 +3076,11 @@ def _visualize_last_iteration( vmin_m = kwargs.pop("vmin_m", min_m) vmax_m = kwargs.pop("vmax_m", max_m) + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) + extent = [ 0, self.sampling[1] * rotated_shape[1], @@ -3156,29 +3185,29 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[0, 2]) if plot_fourier_probe: probe_array = Complex2RGB( - self.probe_fourier, hue_start=hue_start, invert=invert + self.probe_fourier, + chroma_boost=chroma_boost, ) ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - self.probe, hue_start=hue_start, invert=invert + self.probe, power=2, chroma_boost=chroma_boost ) - ax.set_title("Reconstructed probe") + ax.set_title("Reconstructed probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, hue_start=hue_start, invert=invert) + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) else: # Electrostatic Object @@ -3229,10 +3258,10 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[1, :]) ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration Number") + ax.set_xlabel("Iteration number") ax.yaxis.tick_right() - fig.suptitle(f"Normalized Mean Squared Error: {self.error:.3e}") + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") spec.tight_layout(fig) def _visualize_all_iterations( diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index 0480bae8a..5dd19d7bd 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -14,8 +14,8 @@ try: import cupy as cp -except ImportError: - cp = None +except ModuleNotFoundError: + cp = np from emdfile import Custom, tqdmnd from py4DSTEM.datacube import DataCube @@ -188,6 +188,7 @@ def preprocess( force_angular_sampling: float = None, force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, **kwargs, ): """ @@ -245,6 +246,8 @@ def preprocess( object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering Returns -------- @@ -330,9 +333,7 @@ def preprocess( self._amplitudes, self._mean_diffraction_intensity, ) = self._normalize_diffraction_intensities( - self._intensities, - self._com_fitted_x, - self._com_fitted_y, + self._intensities, self._com_fitted_x, self._com_fitted_y, crop_patterns ) # explicitly delete namespace @@ -412,6 +413,11 @@ def preprocess( bilinear=True, device=self._device, ) + if crop_patterns: + self._vacuum_probe_intensity = self._vacuum_probe_intensity[ + self._crop_mask + ].reshape(self._region_of_interest_shape) + self._probe = ( ComplexProbe( gpts=self._region_of_interest_shape, @@ -474,19 +480,13 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (9, 4)) - cmap = kwargs.pop("cmap", "Greys_r") - vmin = kwargs.pop("vmin", None) - vmax = kwargs.pop("vmax", None) - hue_start = kwargs.pop("hue_start", 0) - invert = kwargs.pop("invert", False) + chroma_boost = kwargs.pop("chroma_boost", 1) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, + chroma_boost=chroma_boost, ) extent = [ @@ -508,23 +508,19 @@ def preprocess( ax1.imshow( complex_probe_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax1) cax1 = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg( - cax1, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert - ) + add_colorbar_arg(cax1, chroma_boost=chroma_boost) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") - ax1.set_title("Initial Probe") + ax1.set_title("Initial probe intensity") ax2.imshow( asnumpy(probe_overlap), extent=extent, - cmap=cmap, - **kwargs, + cmap="gray", ) ax2.scatter( self.positions[:, 1], @@ -536,7 +532,7 @@ def preprocess( ax2.set_xlabel("y [A]") ax2.set_xlim((extent[0], extent[1])) ax2.set_ylim((extent[2], extent[3])) - ax2.set_title("Object Field of View") + ax2.set_title("Object field of view") fig.tight_layout() @@ -1023,6 +1019,9 @@ def _constraints( q_lowpass, q_highpass, butterworth_order, + tv_denoise, + tv_denoise_weight, + tv_denoise_inner_iter, object_positivity, shrinkage_rad, object_mask, @@ -1078,6 +1077,12 @@ def _constraints( Cut-off frequency in A^-1 for high-pass butterworth filter butterworth_order: float Butterworth filter order. Smaller gives a smoother filter + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weight: float + Denoising weight. The greater `weight`, the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising object_positivity: bool If True, clips negative potential values shrinkage_rad: float @@ -1108,6 +1113,11 @@ def _constraints( butterworth_order, ) + if tv_denoise: + current_object = self._object_denoise_tv_pylops( + current_object, tv_denoise_weight, tv_denoise_inner_iter + ) + if shrinkage_rad > 0.0 or object_mask is not None: current_object = self._object_shrinkage_constraint( current_object, @@ -1198,6 +1208,9 @@ def reconstruct( q_lowpass: float = None, q_highpass: float = None, butterworth_order: float = 2, + tv_denoise_iter: int = np.inf, + tv_denoise_weight: float = None, + tv_denoise_inner_iter: float = 40, object_positivity: bool = True, shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, @@ -1284,6 +1297,12 @@ def reconstruct( Cut-off frequency in A^-1 for high-pass butterworth filter butterworth_order: float Butterworth filter order. Smaller gives a smoother filter + tv_denoise_iter: int, optional + Number of iterations to run using tv denoise filter on object + tv_denoise_weight: float + Denoising weight. The greater `weight`, the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising object_positivity: bool, optional If True, forces object to be positive shrinkage_rad: float @@ -1486,6 +1505,8 @@ def reconstruct( ) = self._extract_vectorized_patch_indices() self._exit_waves = None self._object_type = self._object_type_initial + if hasattr(self, "_tf"): + del self._tf elif reset is None: if hasattr(self, "error"): warnings.warn( @@ -1618,6 +1639,9 @@ def reconstruct( q_lowpass=q_lowpass, q_highpass=q_highpass, butterworth_order=butterworth_order, + tv_denoise=a0 < tv_denoise_iter and tv_denoise_weight is not None, + tv_denoise_weight=tv_denoise_weight, + tv_denoise_inner_iter=tv_denoise_inner_iter, object_positivity=object_positivity, shrinkage_rad=shrinkage_rad, object_mask=self._object_fov_mask_inverse @@ -1734,8 +1758,11 @@ def _visualize_last_iteration( """ figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) + + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) if self._object_type == "complex": obj = np.angle(self.object) @@ -1828,29 +1855,31 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[0, 1]) if plot_fourier_probe: probe_array = Complex2RGB( - self.probe_fourier, hue_start=hue_start, invert=invert + self.probe_fourier, + chroma_boost=chroma_boost, ) ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - self.probe, hue_start=hue_start, invert=invert + self.probe, + power=2, + chroma_boost=chroma_boost, ) - ax.set_title("Reconstructed probe") + ax.set_title("Reconstructed probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, hue_start=hue_start, invert=invert) + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) else: ax = fig.add_subplot(spec[0]) @@ -1883,10 +1912,10 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[1]) ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration Number") + ax.set_xlabel("Iteration number") ax.yaxis.tick_right() - fig.suptitle(f"Normalized Mean Squared Error: {self.error:.3e}") + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") spec.tight_layout(fig) def _visualize_all_iterations( @@ -1957,9 +1986,12 @@ def _visualize_all_iterations( else (3 * iterations_grid[1], 3 * iterations_grid[0]) ) figsize = kwargs.pop("figsize", auto_figsize) - cmap = kwargs.pop("cmap", "inferno") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) + cmap = kwargs.pop("cmap", "magma") + + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) errors = np.array(self.error_iterations) @@ -2063,8 +2095,7 @@ def _visualize_all_iterations( probes[grid_range[n]] ) ), - hue_start=hue_start, - invert=invert, + chroma_boost=chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe") ax.set_ylabel("kx [mrad]") @@ -2072,21 +2103,23 @@ def _visualize_all_iterations( else: probe_array = Complex2RGB( - probes[grid_range[n]], hue_start=hue_start, invert=invert + probes[grid_range[n]], + power=2, + chroma_boost=chroma_boost, ) - ax.set_title(f"Iter: {grid_range[n]} probe") + ax.set_title(f"Iter: {grid_range[n]} probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: add_colorbar_arg( - grid.cbar_axes[n], hue_start=hue_start, invert=invert + grid.cbar_axes[n], + chroma_boost=chroma_boost, ) if plot_convergence: @@ -2098,7 +2131,7 @@ def _visualize_all_iterations( ax2 = fig.add_subplot(spec[1]) ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax2.set_ylabel("NMSE") - ax2.set_xlabel("Iteration Number") + ax2.set_xlabel("Iteration number") ax2.yaxis.tick_right() spec.tight_layout(fig) diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index c2e1d3b77..d29765d04 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -1543,39 +1543,77 @@ def step_model(radius, sig_0, rad_0, width): def aberrations_basis_function( probe_size, probe_sampling, + energy, max_angular_order, max_radial_order, xp=np, ): """ """ + + # Add constant phase shift in basis + mn = [[-1, 0, 0]] + + for m in range(1, max_radial_order): + n_max = np.minimum(max_angular_order, m + 1) + for n in range(0, n_max + 1): + if (m + n) % 2: + mn.append([m, n, 0]) + if n > 0: + mn.append([m, n, 1]) + + aberrations_mn = np.array(mn) + aberrations_mn = aberrations_mn[np.argsort(aberrations_mn[:, 1]), :] + + sub = aberrations_mn[:, 1] > 0 + aberrations_mn[sub, :] = aberrations_mn[sub, :][ + np.argsort(aberrations_mn[sub, 0]), : + ] + aberrations_mn[~sub, :] = aberrations_mn[~sub, :][ + np.argsort(aberrations_mn[~sub, 0]), : + ] + aberrations_num = aberrations_mn.shape[0] + sx, sy = probe_size dx, dy = probe_sampling + wavelength = electron_wavelength_angstrom(energy) + qx = xp.fft.fftfreq(sx, dx) qy = xp.fft.fftfreq(sy, dy) + qr2 = qx[:, None] ** 2 + qy[None, :] ** 2 + alpha = xp.sqrt(qr2) * wavelength + theta = xp.arctan2(qy[None, :], qx[:, None]) + + # Aberration basis + aberrations_basis = xp.ones((alpha.size, aberrations_num)) + + # Skip constant to avoid dividing by zero in normalization + for a0 in range(1, aberrations_num): + m, n, a = aberrations_mn[a0] + if n == 0: + # Radially symmetric basis + aberrations_basis[:, a0] = (alpha ** (m + 1) / (m + 1)).ravel() + + elif a == 0: + # cos coef + aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.cos(n * theta) / (m + 1) + ).ravel() + else: + # sin coef + aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.sin(n * theta) / (m + 1) + ).ravel() - qxa, qya = xp.meshgrid(qx, qy, indexing="ij") - q2 = qxa**2 + qya**2 - theta = xp.arctan2(qya, qxa) - - basis = [] - index = [] - - for n in range(max_angular_order + 1): - for m in range((max_radial_order - n) // 2 + 1): - basis.append((q2 ** (m + n / 2) * np.cos(n * theta))) - index.append((m, n, 0)) - if n > 0: - basis.append((q2 ** (m + n / 2) * np.sin(n * theta))) - index.append((m, n, 1)) - - basis = xp.array(basis) + # global scaling + aberrations_basis *= 2 * np.pi / wavelength - return basis, index + return aberrations_basis, aberrations_mn def fit_aberration_surface( complex_probe, probe_sampling, + energy, max_angular_order, max_radial_order, xp=np, @@ -1592,21 +1630,47 @@ def fit_aberration_surface( unwrapped_angle = unwrap_phase(probe_angle, wrap_around=True) unwrapped_angle = xp.asarray(unwrapped_angle).astype(xp.float32) - basis, _ = aberrations_basis_function( + raveled_basis, _ = aberrations_basis_function( complex_probe.shape, probe_sampling, + energy, max_angular_order, max_radial_order, xp=xp, ) - raveled_basis = basis.reshape((basis.shape[0], -1)) raveled_weights = probe_amp.ravel() - Aw = raveled_basis.T * raveled_weights[:, None] + Aw = raveled_basis * raveled_weights[:, None] bw = unwrapped_angle.ravel() * raveled_weights coeff = xp.linalg.lstsq(Aw, bw, rcond=None)[0] - fitted_angle = xp.tensordot(coeff, basis, axes=1) + fitted_angle = xp.tensordot(raveled_basis, coeff, axes=1).reshape(probe_angle.shape) return fitted_angle, coeff + + +def rotate_point(origin, point, angle): + """ + Rotate a point (x1, y1) counterclockwise by a given angle around + a given origin (x0, y0). + + Parameters + -------- + origin: 2-tuple of floats + (x0, y0) + point: 2-tuple of floats + (x1, y1) + angle: float (radians) + + Returns + -------- + rotated points (2-tuple) + + """ + ox, oy = origin + px, py = point + + qx = ox + np.cos(angle) * (px - ox) - np.sin(angle) * (py - oy) + qy = oy + np.sin(angle) * (px - ox) + np.cos(angle) * (py - oy) + return qx, qy diff --git a/py4DSTEM/process/utils/cross_correlate.py b/py4DSTEM/process/utils/cross_correlate.py index f9aac1312..50de91e33 100644 --- a/py4DSTEM/process/utils/cross_correlate.py +++ b/py4DSTEM/process/utils/cross_correlate.py @@ -6,8 +6,8 @@ try: import cupy as cp -except ImportError: - cp = None +except ModuleNotFoundError: + cp = np def get_cross_correlation(ar, template, corrPower=1, _returnval="real"): diff --git a/py4DSTEM/process/utils/multicorr.py b/py4DSTEM/process/utils/multicorr.py index 8523c8e62..bc07390bb 100644 --- a/py4DSTEM/process/utils/multicorr.py +++ b/py4DSTEM/process/utils/multicorr.py @@ -15,8 +15,8 @@ try: import cupy as cp -except ImportError: - cp = None +except ModuleNotFoundError: + cp = np def upsampled_correlation(imageCorr, upsampleFactor, xyShift, device="cpu"): diff --git a/py4DSTEM/process/utils/utils.py b/py4DSTEM/process/utils/utils.py index 03d3d07a0..4ef2e1d8a 100644 --- a/py4DSTEM/process/utils/utils.py +++ b/py4DSTEM/process/utils/utils.py @@ -24,8 +24,8 @@ def clear_output(wait=True): try: import cupy as cp -except ImportError: - cp = None +except ModuleNotFoundError: + cp = np def radial_reduction(ar, x0, y0, binsize=1, fn=np.mean, coords=None): diff --git a/py4DSTEM/version.py b/py4DSTEM/version.py index 224f1fb74..141826d55 100644 --- a/py4DSTEM/version.py +++ b/py4DSTEM/version.py @@ -1 +1 @@ -__version__ = "0.14.4" +__version__ = "0.14.5" diff --git a/py4DSTEM/visualize/vis_special.py b/py4DSTEM/visualize/vis_special.py index 89f09606a..cfa017299 100644 --- a/py4DSTEM/visualize/vis_special.py +++ b/py4DSTEM/visualize/vis_special.py @@ -1,6 +1,5 @@ from matplotlib import cm, colors as mcolors, pyplot as plt import numpy as np -from matplotlib.colors import hsv_to_rgb from matplotlib.patches import Wedge from mpl_toolkits.axes_grid1 import make_axes_locatable from scipy.spatial import Voronoi @@ -17,6 +16,7 @@ ) from py4DSTEM.visualize.vis_grid import show_image_grid from py4DSTEM.visualize.vis_RQ import ax_addaxes, ax_addaxes_QtoR +from colorspacious import cspace_convert def show_elliptical_fit( @@ -937,15 +937,20 @@ def show_selected_dps( ) -def Complex2RGB(complex_data, vmin=None, vmax=None, hue_start=0, invert=False): +def Complex2RGB(complex_data, vmin=None, vmax=None, power=None, chroma_boost=1): """ complex_data (array): complex array to plot vmin (float) : minimum absolute value vmax (float) : maximum absolute value - hue_start (float) : rotational offset for colormap (degrees) - inverse (bool) : if True, uses light color scheme + power (float) : power to raise amplitude to + chroma_boost (float): boosts chroma for higher-contrast (~1-2.5) """ amp = np.abs(complex_data) + phase = np.angle(complex_data) + + if power is not None: + amp = amp**power + if np.isclose(np.max(amp), np.min(amp)): if vmin is None: vmin = 0 @@ -966,36 +971,40 @@ def Complex2RGB(complex_data, vmin=None, vmax=None, hue_start=0, invert=False): amp = np.where(amp < vmin, vmin, amp) amp = np.where(amp > vmax, vmax, amp) + amp = ((amp - vmin) / vmax).clip(1e-16, 1) + + J = amp * 61.5 # Note we restrict luminance to the monotonic chroma cutoff + C = np.minimum(chroma_boost * 98 * J / 123, 110) + h = np.rad2deg(phase) + 180 - phase = np.angle(complex_data) + np.deg2rad(hue_start) - amp /= np.max(amp) - rgb = np.zeros(phase.shape + (3,)) - rgb[..., 0] = 0.5 * (np.sin(phase) + 1) * amp - rgb[..., 1] = 0.5 * (np.sin(phase + np.pi / 2) + 1) * amp - rgb[..., 2] = 0.5 * (-np.sin(phase) + 1) * amp + JCh = np.stack((J, C, h), axis=-1) + rgb = cspace_convert(JCh, "JCh", "sRGB1").clip(0, 1) - return 1 - rgb if invert else rgb + return rgb -def add_colorbar_arg(cax, vmin=None, vmax=None, hue_start=0, invert=False): +def add_colorbar_arg(cax, chroma_boost=1, c=49, j=61.5): """ - cax : axis to add cbar too - vmin (float) : minimum absolute value - vmax (float) : maximum absolute value - hue_start (float) : rotational offset for colormap (degrees) - inverse (bool) : if True, uses light color scheme + cax : axis to add cbar to + chroma_boost (float): boosts chroma for higher-contrast (~1-2.25) + c (float) : constant chroma value + j (float) : constant luminance value """ - z = np.exp(1j * np.linspace(-np.pi, np.pi, 200)) - rgb_vals = Complex2RGB(z, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert) + + h = np.linspace(0, 360, 256, endpoint=False) + J = np.full_like(h, j) + C = np.full_like(h, np.minimum(c * chroma_boost, 110)) + JCh = np.stack((J, C, h), axis=-1) + rgb_vals = cspace_convert(JCh, "JCh", "sRGB1").clip(0, 1) newcmp = mcolors.ListedColormap(rgb_vals) norm = mcolors.Normalize(vmin=-np.pi, vmax=np.pi) - cb1 = plt.colorbar(cm.ScalarMappable(norm=norm, cmap=newcmp), cax=cax) + cb = plt.colorbar(cm.ScalarMappable(norm=norm, cmap=newcmp), cax=cax) - cb1.set_label("arg", rotation=0, ha="center", va="bottom") - cb1.ax.yaxis.set_label_coords(0.5, 1.01) - cb1.set_ticks(np.array([-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi])) - cb1.set_ticklabels( + cb.set_label("arg", rotation=0, ha="center", va="bottom") + cb.ax.yaxis.set_label_coords(0.5, 1.01) + cb.set_ticks(np.array([-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi])) + cb.set_ticklabels( [r"$-\pi$", r"$-\dfrac{\pi}{2}$", "$0$", r"$\dfrac{\pi}{2}$", r"$\pi$"] ) @@ -1004,13 +1013,13 @@ def show_complex( ar_complex, vmin=None, vmax=None, + power=None, + chroma_boost=1, cbar=True, scalebar=False, pixelunits="pixels", pixelsize=1, returnfig=False, - hue_start=0, - invert=False, **kwargs ): """ @@ -1023,13 +1032,13 @@ def show_complex( vmax (float, optional) : maximum absolute value if None, vmin/vmax are set to fractions of the distribution of pixel values in the array, e.g. vmin=0.02 will set the minumum display value to saturate the lower 2% of pixels - cbar (bool, optional) : if True, include color wheel + power (float,optional) : power to raise amplitude to + chroma_boost (float) : boosts chroma for higher-contrast (~1-2.25) + cbar (bool, optional) : if True, include color bar scalebar (bool, optional) : if True, adds scale bar pixelunits (str, optional) : units for scalebar pixelsize (float, optional) : size of one pixel in pixelunits for scalebar returnfig (bool, optional) : if True, the function returns the tuple (figure,axis) - hue_start (float, optional) : rotational offset for colormap (degrees) - inverse (bool) : if True, uses light color scheme Returns: if returnfig==False (default), the figure is plotted and nothing is returned. @@ -1044,7 +1053,7 @@ def show_complex( if isinstance(ar_complex, list): if isinstance(ar_complex[0], list): rgb = [ - Complex2RGB(ar, vmin, vmax, hue_start=hue_start, invert=invert) + Complex2RGB(ar, vmin, vmax, power=power, chroma_boost=chroma_boost) for sublist in ar_complex for ar in sublist ] @@ -1053,7 +1062,7 @@ def show_complex( else: rgb = [ - Complex2RGB(ar, vmin, vmax, hue_start=hue_start, invert=invert) + Complex2RGB(ar, vmin, vmax, power=power, chroma_boost=chroma_boost) for ar in ar_complex ] if len(rgb[0].shape) == 4: @@ -1064,7 +1073,9 @@ def show_complex( W = len(ar_complex) is_grid = True else: - rgb = Complex2RGB(ar_complex, vmin, vmax, hue_start=hue_start, invert=invert) + rgb = Complex2RGB( + ar_complex, vmin, vmax, power=power, chroma_boost=chroma_boost + ) if len(rgb.shape) == 4: is_grid = True H = 1 @@ -1115,37 +1126,18 @@ def show_complex( add_scalebar(ax, scalebar) # add color bar - if cbar == True: - ax0 = fig.add_axes([1, 0.35, 0.3, 0.3]) - - # create wheel - AA = 1000 - kx = np.fft.fftshift(np.fft.fftfreq(AA)) - ky = np.fft.fftshift(np.fft.fftfreq(AA)) - kya, kxa = np.meshgrid(ky, kx) - kra = (kya**2 + kxa**2) ** 0.5 - ktheta = np.arctan2(-kxa, kya) - ktheta = kra * np.exp(1j * ktheta) - - # convert to hsv - rgb = Complex2RGB(ktheta, 0, 0.4, hue_start=hue_start, invert=invert) - ind = kra > 0.4 - rgb[ind] = [1, 1, 1] - - # plot - ax0.imshow(rgb) - - # add axes - ax0.axhline(AA / 2, 0, AA, color="k") - ax0.axvline(AA / 2, 0, AA, color="k") - ax0.axis("off") - - label_size = 16 - - ax0.text(AA, AA / 2, 1, fontsize=label_size) - ax0.text(AA / 2, 0, "i", fontsize=label_size) - ax0.text(AA / 2, AA, "-i", fontsize=label_size) - ax0.text(0, AA / 2, -1, fontsize=label_size) - - if returnfig == True: + if cbar: + if is_grid: + for ax_flat in ax.flatten(): + divider = make_axes_locatable(ax_flat) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(ax_cb) + else: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) + + fig.tight_layout() + + if returnfig: return fig, ax diff --git a/setup.py b/setup.py index 069bf1600..fad5913ea 100644 --- a/setup.py +++ b/setup.py @@ -37,9 +37,11 @@ "gdown >= 4.7.1", "dask >= 2.3.0", "distributed >= 2.3.0", - "emdfile >= 0.0.13", + "emdfile >= 0.0.14", "mpire >= 2.7.1", "threadpoolctl >= 3.1.0", + "pylops >= 2.1.0", + "colorspacious >= 1.1.2", ], extras_require={ "ipyparallel": ["ipyparallel >= 6.2.4", "dill >= 0.3.3"], @@ -57,8 +59,8 @@ package_data={ "py4DSTEM": [ "process/utils/scattering_factors.txt", - "process/diskdetection/multicorr_row_kernel.cu", - "process/diskdetection/multicorr_col_kernel.cu", + "braggvectors/multicorr_row_kernel.cu", + "braggvectors/multicorr_col_kernel.cu", ] }, )