Skip to content

Commit

Permalink
moved show_transmitted probe
Browse files Browse the repository at this point in the history
  • Loading branch information
gvarnavi committed Jan 2, 2024
1 parent 14c1e66 commit 4fcc778
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 148 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numpy as np
from matplotlib.gridspec import GridSpec
from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable
from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg, show_complex
from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg

try:
import cupy as cp
Expand Down Expand Up @@ -1944,79 +1944,6 @@ def visualize(
)
return self

def show_transmitted_probe(
self,
plot_fourier_probe: bool = False,
remove_initial_probe_aberrations=False,
**kwargs,
):
"""
Plots the min, max, and mean transmitted probe after propagation and transmission.
Parameters
----------
plot_fourier_probe: boolean, optional
If True, the transmitted probes are also plotted in Fourier space
kwargs:
Passed to show_complex
"""

xp = self._xp
asnumpy = self._asnumpy

transmitted_probe_intensities = xp.sum(
xp.abs(self._transmitted_probes[:, 0]) ** 2, axis=(-2, -1)
)
min_intensity_transmitted = self._transmitted_probes[
xp.argmin(transmitted_probe_intensities), 0
]
max_intensity_transmitted = self._transmitted_probes[
xp.argmax(transmitted_probe_intensities), 0
]
mean_transmitted = self._transmitted_probes[:, 0].mean(0)
probes = [
asnumpy(self._return_centered_probe(probe))
for probe in [
mean_transmitted,
min_intensity_transmitted,
max_intensity_transmitted,
]
]
title = [
"Mean Transmitted Probe",
"Min Intensity Transmitted Probe",
"Max Intensity Transmitted Probe",
]

if plot_fourier_probe:
bottom_row = [
asnumpy(
self._return_fourier_probe(
probe,
remove_initial_probe_aberrations=remove_initial_probe_aberrations,
)
)
for probe in [
mean_transmitted,
min_intensity_transmitted,
max_intensity_transmitted,
]
]
probes = [probes, bottom_row]

title += [
"Mean Transmitted Fourier Probe",
"Min Intensity Transmitted Fourier Probe",
"Max Intensity Transmitted Fourier Probe",
]

title = kwargs.get("title", title)
show_complex(
probes,
title=title,
**kwargs,
)

def _return_self_consistency_errors(
self,
max_batch_size=None,
Expand Down
75 changes: 1 addition & 74 deletions py4DSTEM/process/phase/iterative_multislice_ptychography.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numpy as np
from matplotlib.gridspec import GridSpec
from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable
from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg, show_complex
from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg

try:
import cupy as cp
Expand Down Expand Up @@ -1923,76 +1923,3 @@ def visualize(
**kwargs,
)
return self

def show_transmitted_probe(
self,
plot_fourier_probe: bool = False,
remove_initial_probe_aberrations=False,
**kwargs,
):
"""
Plots the min, max, and mean transmitted probe after propagation and transmission.
Parameters
----------
plot_fourier_probe: boolean, optional
If True, the transmitted probes are also plotted in Fourier space
kwargs:
Passed to show_complex
"""

xp = self._xp
asnumpy = self._asnumpy

transmitted_probe_intensities = xp.sum(
xp.abs(self._transmitted_probes) ** 2, axis=(-2, -1)
)
min_intensity_transmitted = self._transmitted_probes[
xp.argmin(transmitted_probe_intensities)
]
max_intensity_transmitted = self._transmitted_probes[
xp.argmax(transmitted_probe_intensities)
]
mean_transmitted = self._transmitted_probes.mean(0)
probes = [
asnumpy(self._return_centered_probe(probe))
for probe in [
mean_transmitted,
min_intensity_transmitted,
max_intensity_transmitted,
]
]
title = [
"Mean Transmitted Probe",
"Min Intensity Transmitted Probe",
"Max Intensity Transmitted Probe",
]

if plot_fourier_probe:
bottom_row = [
asnumpy(
self._return_fourier_probe(
probe,
remove_initial_probe_aberrations=remove_initial_probe_aberrations,
)
)
for probe in [
mean_transmitted,
min_intensity_transmitted,
max_intensity_transmitted,
]
]
probes = [probes, bottom_row]

title += [
"Mean Transmitted Fourier Probe",
"Min Intensity Transmitted Fourier Probe",
"Max Intensity Transmitted Fourier Probe",
]

title = kwargs.get("title", title)
show_complex(
probes,
title=title,
**kwargs,
)
121 changes: 121 additions & 0 deletions py4DSTEM/process/phase/iterative_ptychographic_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
AffineTransform,
ComplexProbe,
fft_shift,
generate_batches,
rotate_point,
spatial_frequencies,
)
Expand Down Expand Up @@ -2106,6 +2107,120 @@ def _projection_sets_adjoint(

return current_object, current_probe

def show_transmitted_probe(
self,
max_batch_size=None,
plot_fourier_probe: bool = False,
remove_initial_probe_aberrations=False,
**kwargs,
):
"""
Plots the min, max, and mean transmitted probe after propagation and transmission.
Parameters
----------
max_batch_size: int, optional
Max number of probes to calculate at once
plot_fourier_probe: boolean, optional
If True, the transmitted probes are also plotted in Fourier space
remove_initial_probe_aberrations: bool, optional
If true, when plotting fourier probe, removes initial probe
kwargs:
Passed to show_complex
"""

xp = self._xp
asnumpy = self._asnumpy

if max_batch_size is None:
max_batch_size = self._num_diffraction_patterns

positions_px = self._positions_px.copy()

mean_transmitted = xp.zeros_like(self._probe)
intensities_compare = [np.inf, 0]

for start, end in generate_batches(
self._num_diffraction_patterns, max_batch=max_batch_size
):
# batch indices
self._positions_px = positions_px[start:end]
self._positions_px_fractional = self._positions_px - xp.round(
self._positions_px
)
(
self._vectorized_patch_indices_row,
self._vectorized_patch_indices_col,
) = self._extract_vectorized_patch_indices()

# overlaps
_, _, overlap = self._overlap_projection(self._object, self._probe)

# store relevant arrays
mean_transmitted += overlap.sum(0)

intensities = xp.sum(xp.abs(overlap) ** 2, axis=(-2, -1))
min_intensity = intensities.min()
max_intensity = intensities.max()

if min_intensity < intensities_compare[0]:
min_intensity_transmitted = overlap[xp.argmin(intensities)]
intensities_compare[0] = min_intensity

if max_intensity > intensities_compare[1]:
max_intensity_transmitted = overlap[xp.argmax(intensities)]
intensities_compare[1] = max_intensity

mean_transmitted /= self._num_diffraction_patterns

probes = [
asnumpy(self._return_centered_probe(probe))
for probe in [
mean_transmitted,
min_intensity_transmitted,
max_intensity_transmitted,
]
]
title = [
"Mean Transmitted Probe",
"Min Intensity Transmitted Probe",
"Max Intensity Transmitted Probe",
]

if plot_fourier_probe:
bottom_row = [
asnumpy(
self._return_fourier_probe(
probe,
remove_initial_probe_aberrations=remove_initial_probe_aberrations,
)
)
for probe in [
mean_transmitted,
min_intensity_transmitted,
max_intensity_transmitted,
]
]
probes = [probes, bottom_row]

title += [
"Mean Transmitted Fourier Probe",
"Min Intensity Transmitted Fourier Probe",
"Max Intensity Transmitted Fourier Probe",
]

title = kwargs.get("title", title)
ticks = kwargs.get("ticks", False)
axsize = kwargs.get("axsize", (4.5, 4.5))

show_complex(
probes,
title=title,
ticks=ticks,
axsize=axsize,
**kwargs,
)


class ObjectNDProbeMixedMethodsMixin:
"""
Expand Down Expand Up @@ -2726,3 +2841,9 @@ def _projection_sets_adjoint(
)

return current_object, current_probe

def show_transmitted_probe(
self,
**kwargs,
):
raise NotImplementedError()

0 comments on commit 4fcc778

Please sign in to comment.