Skip to content

Commit

Permalink
cleaned up mixed-multi slice preprocess
Browse files Browse the repository at this point in the history
  • Loading branch information
gvarnavi committed Jan 1, 2024
1 parent d8e4d52 commit db8e4a5
Showing 1 changed file with 50 additions and 202 deletions.
252 changes: 50 additions & 202 deletions py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,7 @@
generate_batches,
polar_aliases,
polar_symbols,
spatial_frequencies,
)
from py4DSTEM.process.utils import electron_wavelength_angstrom, get_CoM, get_shifted_ar

warnings.simplefilter(action="always", category=UserWarning)

Expand Down Expand Up @@ -267,92 +265,6 @@ def __init__(
self._theta_x = theta_x
self._theta_y = theta_y

def _precompute_propagator_arrays(
self,
gpts: Tuple[int, int],
sampling: Tuple[float, float],
energy: float,
slice_thicknesses: Sequence[float],
theta_x: float,
theta_y: float,
):
"""
Precomputes propagator arrays complex wave-function will be convolved by,
for all slice thicknesses.
Parameters
----------
gpts: Tuple[int,int]
Wavefunction pixel dimensions
sampling: Tuple[float,float]
Wavefunction sampling in A
energy: float
The electron energy of the wave functions in eV
slice_thicknesses: Sequence[float]
Array of slice thicknesses in A
theta_x: float
x tilt of propagator (in degrees)
theta_y: float
y tilt of propagator (in degrees)
Returns
-------
propagator_arrays: np.ndarray
(T,Sx,Sy) shape array storing propagator arrays
"""
xp = self._xp

# Frequencies
kx, ky = spatial_frequencies(gpts, sampling)
kx = xp.asarray(kx, dtype=xp.float32)
ky = xp.asarray(ky, dtype=xp.float32)

# Propagators
wavelength = electron_wavelength_angstrom(energy)
num_slices = slice_thicknesses.shape[0]
propagators = xp.empty(
(num_slices, kx.shape[0], ky.shape[0]), dtype=xp.complex64
)

theta_x = np.deg2rad(theta_x)
theta_y = np.deg2rad(theta_y)

for i, dz in enumerate(slice_thicknesses):
propagators[i] = xp.exp(
1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz)
)
propagators[i] *= xp.exp(
1.0j * (-(ky**2)[None] * np.pi * wavelength * dz)
)
propagators[i] *= xp.exp(
1.0j * (2 * kx[:, None] * np.pi * dz * np.tan(theta_x))
)
propagators[i] *= xp.exp(
1.0j * (2 * ky[None] * np.pi * dz * np.tan(theta_y))
)

return propagators

def _propagate_array(self, array: np.ndarray, propagator_array: np.ndarray):
"""
Propagates array by Fourier convolving array with propagator_array.
Parameters
----------
array: np.ndarray
Wavefunction array to be convolved
propagator_array: np.ndarray
Propagator array to convolve array with
Returns
-------
propagated_array: np.ndarray
Fourier-convolved array
"""
xp = self._xp

return xp.fft.ifft2(xp.fft.fft2(array) * propagator_array)

def preprocess(
self,
diffraction_intensities_shape: Tuple[int, int] = None,
Expand All @@ -363,7 +275,7 @@ def preprocess(
plot_center_of_mass: str = "default",
plot_rotation: bool = True,
maximize_divergence: bool = False,
rotation_angles_deg: np.ndarray = np.arange(-89.0, 90.0, 1.0),
rotation_angles_deg: np.ndarray = None,
plot_probe_overlaps: bool = True,
force_com_rotation: float = None,
force_com_transpose: float = None,
Expand Down Expand Up @@ -455,13 +367,10 @@ def preprocess(
)
)

if self._positions_mask is not None and self._positions_mask.dtype != "bool":
warnings.warn(
("`positions_mask` converted to `bool` array"),
UserWarning,
)
if self._positions_mask is not None:
self._positions_mask = np.asarray(self._positions_mask, dtype="bool")

# preprocess datacube
(
self._datacube,
self._vacuum_probe_intensity,
Expand All @@ -477,6 +386,7 @@ def preprocess(
com_shifts=force_com_shifts,
)

# calibrations
self._intensities = self._extract_intensities_and_calibrations_from_datacube(
self._datacube,
require_calibrations=True,
Expand All @@ -485,6 +395,13 @@ def preprocess(
force_reciprocal_sampling=force_reciprocal_sampling,
)

# handle semiangle specified in pixels
if self._semiangle_cutoff_pixels:
self._semiangle_cutoff = (
self._semiangle_cutoff_pixels * self._angular_sampling[0]
)

# calculate CoM
(
self._com_measured_x,
self._com_measured_y,
Expand All @@ -499,6 +416,7 @@ def preprocess(
com_shifts=force_com_shifts,
)

# estimate rotation / transpose
(
self._rotation_best_rad,
self._rotation_best_transpose,
Expand All @@ -520,57 +438,47 @@ def preprocess(
**kwargs,
)

# corner-center amplitudes
(
self._amplitudes,
self._mean_diffraction_intensity,
self._crop_mask,
) = self._normalize_diffraction_intensities(
self._intensities,
self._com_fitted_x,
self._com_fitted_y,
crop_patterns,
self._positions_mask,
crop_patterns,
)

# explicitly delete namespace
self._num_diffraction_patterns = self._amplitudes.shape[0]
self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:])
del self._intensities

self._positions_px = self._calculate_scan_positions_in_pixels(
self._scan_positions, self._positions_mask
# initialize probe positions
(
self._positions_px,
self._object_padding_px,
) = self._calculate_scan_positions_in_pixels(
self._scan_positions,
self._positions_mask,
self._object_padding_px,
)

# handle semiangle specified in pixels
if self._semiangle_cutoff_pixels:
self._semiangle_cutoff = (
self._semiangle_cutoff_pixels * self._angular_sampling[0]
)

# Object Initialization
if self._object is None:
pad_x = self._object_padding_px[0][1]
pad_y = self._object_padding_px[1][1]
p, q = np.round(np.max(self._positions_px, axis=0))
p = np.max([np.round(p + pad_x), self._region_of_interest_shape[0]]).astype(
"int"
)
q = np.max([np.round(q + pad_y), self._region_of_interest_shape[1]]).astype(
"int"
)
if self._object_type == "potential":
self._object = xp.zeros((self._num_slices, p, q), dtype=xp.float32)
elif self._object_type == "complex":
self._object = xp.ones((self._num_slices, p, q), dtype=xp.complex64)
else:
if self._object_type == "potential":
self._object = xp.asarray(self._object, dtype=xp.float32)
elif self._object_type == "complex":
self._object = xp.asarray(self._object, dtype=xp.complex64)
# initialize object
self._object = self._initialize_object(
self._object,
self._num_slices,
self._positions_px,
self._object_type,
)

self._object_initial = self._object.copy()
self._object_type_initial = self._object_type
self._object_shape = self._object.shape[-2:]

# center probe positions
self._positions_px = xp.asarray(self._positions_px, dtype=xp.float32)
self._positions_px_com = xp.mean(self._positions_px, axis=0)
self._positions_px -= self._positions_px_com - xp.array(self._object_shape) / 2
Expand All @@ -584,88 +492,25 @@ def preprocess(
self._positions_initial[:, 0] *= self.sampling[0]
self._positions_initial[:, 1] *= self.sampling[1]

# Vectorized Patches
# set vectorized patches
(
self._vectorized_patch_indices_row,
self._vectorized_patch_indices_col,
) = self._extract_vectorized_patch_indices()

# Probe Initialization
if self._probe is None or isinstance(self._probe, ComplexProbe):
if self._probe is None:
if self._vacuum_probe_intensity is not None:
self._semiangle_cutoff = np.inf
self._vacuum_probe_intensity = xp.asarray(
self._vacuum_probe_intensity, dtype=xp.float32
)
probe_x0, probe_y0 = get_CoM(
self._vacuum_probe_intensity,
device=self._device,
)
self._vacuum_probe_intensity = get_shifted_ar(
self._vacuum_probe_intensity,
-probe_x0,
-probe_y0,
bilinear=True,
device=self._device,
)
if crop_patterns:
self._vacuum_probe_intensity = self._vacuum_probe_intensity[
self._crop_mask
].reshape(self._region_of_interest_shape)
_probe = (
ComplexProbe(
gpts=self._region_of_interest_shape,
sampling=self.sampling,
energy=self._energy,
semiangle_cutoff=self._semiangle_cutoff,
rolloff=self._rolloff,
vacuum_probe_intensity=self._vacuum_probe_intensity,
parameters=self._polar_parameters,
device=self._device,
)
.build()
._array
)

else:
if self._probe._gpts != self._region_of_interest_shape:
raise ValueError()
if hasattr(self._probe, "_array"):
_probe = self._probe._array
else:
self._probe._xp = xp
_probe = self._probe.build()._array

self._probe = xp.zeros(
(self._num_probes,) + tuple(self._region_of_interest_shape),
dtype=xp.complex64,
)
sx, sy = self._region_of_interest_shape
self._probe[0] = _probe

# Randomly shift phase of other probes
for i_probe in range(1, self._num_probes):
shift_x = xp.exp(
-2j * np.pi * (xp.random.rand() - 0.5) * xp.fft.fftfreq(sx)
)
shift_y = xp.exp(
-2j * np.pi * (xp.random.rand() - 0.5) * xp.fft.fftfreq(sy)
)
self._probe[i_probe] = (
self._probe[i_probe - 1] * shift_x[:, None] * shift_y[None]
)

# Normalize probe to match mean diffraction intensity
probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe[0])) ** 2)
self._probe *= xp.sqrt(self._mean_diffraction_intensity / probe_intensity)

else:
self._probe = xp.asarray(self._probe, dtype=xp.complex64)
# initialize probe
self._probe, self._semiangle_cutoff = self._initialize_probe(
self._probe,
self._vacuum_probe_intensity,
self._mean_diffraction_intensity,
self._semiangle_cutoff,
crop_patterns,
)

self._probe_initial = self._probe.copy()
self._probe_initial_aperture = None # Doesn't really make sense for mixed-state
self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe))

# initialize aberrations
self._known_aberrations_array = ComplexProbe(
energy=self._energy,
gpts=self._region_of_interest_shape,
Expand All @@ -674,7 +519,7 @@ def preprocess(
device=self._device,
)._evaluate_ctf()

# Precomputed propagator arrays
# precompute propagator arrays
self._propagator_arrays = self._precompute_propagator_arrays(
self._region_of_interest_shape,
self.sampling,
Expand All @@ -688,22 +533,25 @@ def preprocess(
shifted_probes = fft_shift(self._probe[0], self._positions_px_fractional, xp)
probe_intensities = xp.abs(shifted_probes) ** 2
probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities)
probe_overlap = self._gaussian_filter(probe_overlap, 1.0)

if object_fov_mask is None:
self._object_fov_mask = asnumpy(probe_overlap > 0.25 * probe_overlap.max())
probe_overlap_blurred = self._gaussian_filter(probe_overlap, 1.0)
self._object_fov_mask = asnumpy(
probe_overlap_blurred > 0.25 * probe_overlap_blurred.max()
)
else:
self._object_fov_mask = np.asarray(object_fov_mask)
self._object_fov_mask_inverse = np.invert(self._object_fov_mask)

if plot_probe_overlaps:
figsize = kwargs.pop("figsize", (13, 4))
chroma_boost = kwargs.pop("chroma_boost", 1)
power = kwargs.pop("power", 2)

# initial probe
complex_probe_rgb = Complex2RGB(
self.probe_centered[0],
power=2,
power=power,
chroma_boost=chroma_boost,
)

Expand All @@ -716,7 +564,7 @@ def preprocess(
)
complex_propagated_rgb = Complex2RGB(
asnumpy(self._return_centered_probe(propagated_probe)),
power=2,
power=power,
chroma_boost=chroma_boost,
)

Expand Down

0 comments on commit db8e4a5

Please sign in to comment.