Skip to content

Commit

Permalink
Merge pull request #539 from py4dstem/phase_contrast
Browse files Browse the repository at this point in the history
Double double, aberrations, bugs, and trouble
  • Loading branch information
smribet authored Oct 22, 2023
2 parents f18aaf9 + 43220d0 commit 6416804
Show file tree
Hide file tree
Showing 15 changed files with 6,150 additions and 686 deletions.
30 changes: 8 additions & 22 deletions py4DSTEM/process/phase/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,14 @@
_emd_hook = True

from py4DSTEM.process.phase.iterative_dpc import DPCReconstruction
from py4DSTEM.process.phase.iterative_mixedstate_ptychography import (
MixedstatePtychographicReconstruction,
)
from py4DSTEM.process.phase.iterative_multislice_ptychography import (
MultislicePtychographicReconstruction,
)
from py4DSTEM.process.phase.iterative_overlap_magnetic_tomography import (
OverlapMagneticTomographicReconstruction,
)
from py4DSTEM.process.phase.iterative_overlap_tomography import (
OverlapTomographicReconstruction,
)
from py4DSTEM.process.phase.iterative_mixedstate_multislice_ptychography import MixedstateMultislicePtychographicReconstruction
from py4DSTEM.process.phase.iterative_mixedstate_ptychography import MixedstatePtychographicReconstruction
from py4DSTEM.process.phase.iterative_multislice_ptychography import MultislicePtychographicReconstruction
from py4DSTEM.process.phase.iterative_overlap_magnetic_tomography import OverlapMagneticTomographicReconstruction
from py4DSTEM.process.phase.iterative_overlap_tomography import OverlapTomographicReconstruction
from py4DSTEM.process.phase.iterative_parallax import ParallaxReconstruction
from py4DSTEM.process.phase.iterative_simultaneous_ptychography import (
SimultaneousPtychographicReconstruction,
)
from py4DSTEM.process.phase.iterative_singleslice_ptychography import (
SingleslicePtychographicReconstruction,
)
from py4DSTEM.process.phase.parameter_optimize import (
OptimizationParameter,
PtychographyOptimizer,
)
from py4DSTEM.process.phase.iterative_simultaneous_ptychography import SimultaneousPtychographicReconstruction
from py4DSTEM.process.phase.iterative_singleslice_ptychography import SingleslicePtychographicReconstruction
from py4DSTEM.process.phase.parameter_optimize import OptimizationParameter, PtychographyOptimizer

# fmt: on
214 changes: 181 additions & 33 deletions py4DSTEM/process/phase/iterative_base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,53 @@ def attach_datacube(self, datacube: DataCube):
self._datacube = datacube
return self

def reinitialize_parameters(self, device: str = None, verbose: bool = None):
"""
Reinitializes common parameters. This is useful when loading a previously-saved
reconstruction (which set device='cpu' and verbose=True for compatibility) ,
using different initialization parameters.
Parameters
----------
device: str, optional
If not None, imports and assigns appropriate device modules
verbose: bool, optional
If not None, sets the verbosity to verbose
Returns
--------
self: PhaseReconstruction
Self to enable chaining
"""

if device is not None:
if device == "cpu":
self._xp = np
self._asnumpy = np.asarray
from scipy.ndimage import gaussian_filter

self._gaussian_filter = gaussian_filter
from scipy.special import erf

self._erf = erf
elif device == "gpu":
self._xp = cp
self._asnumpy = cp.asnumpy
from cupyx.scipy.ndimage import gaussian_filter

self._gaussian_filter = gaussian_filter
from cupyx.scipy.special import erf

self._erf = erf
else:
raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}")
self._device = device

if verbose is not None:
self._verbose = verbose

return self

def set_save_defaults(
self,
save_datacube: bool = False,
Expand Down Expand Up @@ -278,7 +325,9 @@ def _extract_intensities_and_calibrations_from_datacube(
"""

# Copies intensities to device casting to float32
intensities = datacube.data
xp = self._xp

intensities = xp.asarray(datacube.data, dtype=xp.float32)
self._grid_scan_shape = intensities.shape[:2]

# Extracts calibrations
Expand All @@ -295,13 +344,14 @@ def _extract_intensities_and_calibrations_from_datacube(
if require_calibrations:
raise ValueError("Real-space calibrations must be given in 'A'")

warnings.warn(
(
"Iterative reconstruction will not be quantitative unless you specify "
"real-space calibrations in 'A'"
),
UserWarning,
)
if self._verbose:
warnings.warn(
(
"Iterative reconstruction will not be quantitative unless you specify "
"real-space calibrations in 'A'"
),
UserWarning,
)

self._scan_sampling = (1.0, 1.0)
self._scan_units = ("pixels",) * 2
Expand Down Expand Up @@ -359,13 +409,14 @@ def _extract_intensities_and_calibrations_from_datacube(
"Reciprocal-space calibrations must be given in in 'A^-1' or 'mrad'"
)

warnings.warn(
(
"Iterative reconstruction will not be quantitative unless you specify "
"appropriate reciprocal-space calibrations"
),
UserWarning,
)
if self._verbose:
warnings.warn(
(
"Iterative reconstruction will not be quantitative unless you specify "
"appropriate reciprocal-space calibrations"
),
UserWarning,
)

self._angular_sampling = (1.0, 1.0)
self._angular_units = ("pixels",) * 2
Expand Down Expand Up @@ -448,8 +499,6 @@ def _calculate_intensities_center_of_mass(
xp = self._xp
asnumpy = self._asnumpy

intensities = xp.asarray(intensities, dtype=xp.float32)

# for ptycho
if com_measured:
com_measured_x, com_measured_y = com_measured
Expand Down Expand Up @@ -484,22 +533,27 @@ def _calculate_intensities_center_of_mass(
)

if com_shifts is None:
com_measured_x_np = asnumpy(com_measured_x)
com_measured_y_np = asnumpy(com_measured_y)
finite_mask = np.isfinite(com_measured_x_np)

com_shifts = fit_origin(
(asnumpy(com_measured_x), asnumpy(com_measured_y)),
(com_measured_x_np, com_measured_y_np),
fitfunction=fit_function,
mask=finite_mask,
)

# Fit function to center of mass
com_fitted_x = xp.asarray(com_shifts[0], dtype=xp.float32)
com_fitted_y = xp.asarray(com_shifts[1], dtype=xp.float32)

# fix CoM units
com_normalized_x = (com_measured_x - com_fitted_x) * self._reciprocal_sampling[
0
]
com_normalized_y = (com_measured_y - com_fitted_y) * self._reciprocal_sampling[
1
]
com_normalized_x = (
xp.nan_to_num(com_measured_x - com_fitted_x) * self._reciprocal_sampling[0]
)
com_normalized_y = (
xp.nan_to_num(com_measured_y - com_fitted_y) * self._reciprocal_sampling[1]
)

return (
com_measured_x,
Expand Down Expand Up @@ -1077,6 +1131,7 @@ def _normalize_diffraction_intensities(
diffraction_intensities,
com_fitted_x,
com_fitted_y,
crop_patterns,
):
"""
Fix diffraction intensities CoM, shift to origin, and take square root
Expand All @@ -1089,6 +1144,9 @@ def _normalize_diffraction_intensities(
Best fit horizontal center of mass gradient
com_fitted_y: (Rx,Ry) xp.ndarray
Best fit vertical center of mass gradient
crop_patterns: bool
if True, crop patterns to avoid wrap around of patterns
when centering
Returns
-------
Expand All @@ -1101,13 +1159,46 @@ def _normalize_diffraction_intensities(
xp = self._xp
mean_intensity = 0

amplitudes = xp.zeros_like(diffraction_intensities)
region_of_interest_shape = diffraction_intensities.shape[-2:]
diffraction_intensities = self._asnumpy(diffraction_intensities)
if crop_patterns:
crop_x = int(
np.minimum(
diffraction_intensities.shape[2] - com_fitted_x.max(),
com_fitted_x.min(),
)
)
crop_y = int(
np.minimum(
diffraction_intensities.shape[3] - com_fitted_y.max(),
com_fitted_y.min(),
)
)

crop_w = np.minimum(crop_y, crop_x)
region_of_interest_shape = (crop_w * 2, crop_w * 2)
amplitudes = np.zeros(
(
diffraction_intensities.shape[0],
diffraction_intensities.shape[1],
crop_w * 2,
crop_w * 2,
),
dtype=np.float32,
)

crop_mask = np.zeros(diffraction_intensities.shape[-2:], dtype=np.bool_)
crop_mask[:crop_w, :crop_w] = True
crop_mask[-crop_w:, :crop_w] = True
crop_mask[:crop_w:, -crop_w:] = True
crop_mask[-crop_w:, -crop_w:] = True
self._crop_mask = crop_mask

else:
region_of_interest_shape = diffraction_intensities.shape[-2:]
amplitudes = np.zeros(diffraction_intensities.shape, dtype=np.float32)

com_fitted_x = self._asnumpy(com_fitted_x)
com_fitted_y = self._asnumpy(com_fitted_y)
diffraction_intensities = self._asnumpy(diffraction_intensities)
amplitudes = self._asnumpy(amplitudes)

for rx in range(diffraction_intensities.shape[0]):
for ry in range(diffraction_intensities.shape[1]):
Expand All @@ -1119,16 +1210,71 @@ def _normalize_diffraction_intensities(
device="cpu",
)

if crop_patterns:
intensities = intensities[crop_mask].reshape(
region_of_interest_shape
)

mean_intensity += np.sum(intensities)
amplitudes[rx, ry] = np.sqrt(np.maximum(intensities, 0))

amplitudes = xp.asarray(amplitudes, dtype=xp.float32)

amplitudes = xp.reshape(amplitudes, (-1,) + region_of_interest_shape)
amplitudes = xp.asarray(amplitudes)
mean_intensity /= amplitudes.shape[0]

return amplitudes, mean_intensity

def show_complex_CoM(
self,
com=None,
cbar=True,
scalebar=True,
pixelsize=None,
pixelunits=None,
**kwargs,
):
"""
Plot complex-valued CoM image
Parameters
----------
com = (CoM_x, CoM_y) tuple
If None is specified, uses (self.com_x, self.com_y) instead
cbar: bool, optional
if True, adds colorbar
scalebar: bool, optional
if True, adds scalebar to probe
pixelunits: str, optional
units for scalebar, default is A
pixelsize: float, optional
default is scan sampling
"""

if com is None:
com = (self.com_x, self.com_y)

if pixelsize is None:
pixelsize = self._scan_sampling[0]
if pixelunits is None:
pixelunits = r"$\AA$"

figsize = kwargs.pop("figsize", (6, 6))
fig, ax = plt.subplots(figsize=figsize)

complex_com = com[0] + 1j * com[1]

show_complex(
complex_com,
cbar=cbar,
figax=(fig, ax),
scalebar=scalebar,
pixelsize=pixelsize,
pixelunits=pixelunits,
ticks=False,
**kwargs,
)


class PtychographicReconstruction(PhaseReconstruction, PtychographicConstraints):
"""
Expand Down Expand Up @@ -1309,10 +1455,10 @@ def _get_constructor_args(cls, group):
"object_type": instance_md["object_type"],
"semiangle_cutoff": instance_md["semiangle_cutoff"],
"rolloff": instance_md["rolloff"],
"verbose": instance_md["verbose"],
"name": instance_md["name"],
"device": instance_md["device"],
"polar_parameters": polar_params,
"verbose": True, # for compatibility
"device": "cpu", # for compatibility
}

class_specific_kwargs = {}
Expand Down Expand Up @@ -2109,6 +2255,7 @@ def show_fourier_probe(
pixelunits = r"$\AA^{-1}$"

figsize = kwargs.pop("figsize", (6, 6))
chroma_boost = kwargs.pop("chroma_boost", 2)

fig, ax = plt.subplots(figsize=figsize)
show_complex(
Expand All @@ -2119,6 +2266,7 @@ def show_fourier_probe(
pixelsize=pixelsize,
pixelunits=pixelunits,
ticks=False,
chroma_boost=chroma_boost,
**kwargs,
)

Expand All @@ -2142,7 +2290,7 @@ def show_object_fft(self, obj=None, **kwargs):
vmax = kwargs.pop("vmax", 1)
power = kwargs.pop("power", 0.2)

pixelsize = 1 / (object_fft.shape[0] * self.sampling[0])
pixelsize = 1 / (object_fft.shape[1] * self.sampling[1])
show(
object_fft,
figsize=figsize,
Expand Down
Loading

0 comments on commit 6416804

Please sign in to comment.