Skip to content

Commit

Permalink
added visualize_last functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
gvarnavi committed Jan 2, 2024
1 parent 3345b01 commit 9257242
Show file tree
Hide file tree
Showing 7 changed files with 274 additions and 1,043 deletions.
203 changes: 0 additions & 203 deletions py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py
Original file line number Diff line number Diff line change
Expand Up @@ -1241,204 +1241,6 @@ def reconstruct(

return self

def _visualize_last_iteration(
self,
fig,
cbar: bool,
plot_convergence: bool,
plot_probe: bool,
plot_fourier_probe: bool,
remove_initial_probe_aberrations: bool,
padding: int,
**kwargs,
):
"""
Displays last reconstructed object and probe iterations.
Parameters
--------
fig: Figure
Matplotlib figure to place Gridspec in
plot_convergence: bool, optional
If true, the normalized mean squared error (NMSE) plot is displayed
cbar: bool, optional
If true, displays a colorbar
plot_probe: bool
If true, the reconstructed probe intensity is also displayed
plot_fourier_probe: bool, optional
If true, the reconstructed complex Fourier probe is displayed
remove_initial_probe_aberrations: bool, optional
If true, when plotting fourier probe, removes initial probe
padding : int, optional
Pixels to pad by post rotating-cropping object
"""
figsize = kwargs.pop("figsize", (8, 5))
cmap = kwargs.pop("cmap", "magma")

chroma_boost = kwargs.pop("chroma_boost", 1)

if self._object_type == "complex":
obj = np.angle(self.object)
else:
obj = self.object

rotated_object = self._crop_rotate_object_fov(
np.sum(obj, axis=0), padding=padding
)
rotated_shape = rotated_object.shape

extent = [
0,
self.sampling[1] * rotated_shape[1],
self.sampling[0] * rotated_shape[0],
0,
]

if plot_fourier_probe:
probe_extent = [
-self.angular_sampling[1] * self._region_of_interest_shape[1] / 2,
self.angular_sampling[1] * self._region_of_interest_shape[1] / 2,
self.angular_sampling[0] * self._region_of_interest_shape[0] / 2,
-self.angular_sampling[0] * self._region_of_interest_shape[0] / 2,
]
elif plot_probe:
probe_extent = [
0,
self.sampling[1] * self._region_of_interest_shape[1],
self.sampling[0] * self._region_of_interest_shape[0],
0,
]

if plot_convergence:
if plot_probe or plot_fourier_probe:
spec = GridSpec(
ncols=2,
nrows=2,
height_ratios=[4, 1],
hspace=0.15,
width_ratios=[
(extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]),
1,
],
wspace=0.35,
)
else:
spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0.15)
else:
if plot_probe or plot_fourier_probe:
spec = GridSpec(
ncols=2,
nrows=1,
width_ratios=[
(extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]),
1,
],
wspace=0.35,
)
else:
spec = GridSpec(ncols=1, nrows=1)

if fig is None:
fig = plt.figure(figsize=figsize)

if plot_probe or plot_fourier_probe:
# Object
ax = fig.add_subplot(spec[0, 0])
im = ax.imshow(
rotated_object,
extent=extent,
cmap=cmap,
**kwargs,
)

ax.set_ylabel("x [A]")
ax.set_xlabel("y [A]")
if self._object_type == "potential":
ax.set_title("Reconstructed object potential")
elif self._object_type == "complex":
ax.set_title("Reconstructed object phase")

if cbar:
divider = make_axes_locatable(ax)
ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
fig.add_axes(ax_cb)
fig.colorbar(im, cax=ax_cb)

# Probe
kwargs.pop("vmin", None)
kwargs.pop("vmax", None)

ax = fig.add_subplot(spec[0, 1])
if plot_fourier_probe:
if remove_initial_probe_aberrations:
probe_array = self.probe_fourier_residual[0]
else:
probe_array = self.probe_fourier[0]

probe_array = Complex2RGB(
probe_array,
chroma_boost=chroma_boost,
)

ax.set_title("Reconstructed Fourier probe[0]")
ax.set_ylabel("kx [mrad]")
ax.set_xlabel("ky [mrad]")
else:
probe_array = Complex2RGB(
self.probe[0], power=2, chroma_boost=chroma_boost
)
ax.set_title("Reconstructed probe[0] intensity")
ax.set_ylabel("x [A]")
ax.set_xlabel("y [A]")

im = ax.imshow(
probe_array,
extent=probe_extent,
)

if cbar:
divider = make_axes_locatable(ax)
ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
add_colorbar_arg(ax_cb, chroma_boost=chroma_boost)

else:
ax = fig.add_subplot(spec[0])
im = ax.imshow(
rotated_object,
extent=extent,
cmap=cmap,
**kwargs,
)
ax.set_ylabel("x [A]")
ax.set_xlabel("y [A]")
if self._object_type == "potential":
ax.set_title("Reconstructed object potential")
elif self._object_type == "complex":
ax.set_title("Reconstructed object phase")

if cbar:
divider = make_axes_locatable(ax)
ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
fig.add_axes(ax_cb)
fig.colorbar(im, cax=ax_cb)

if plot_convergence and hasattr(self, "error_iterations"):
kwargs.pop("vmin", None)
kwargs.pop("vmax", None)
errors = np.array(self.error_iterations)
if plot_probe:
ax = fig.add_subplot(spec[1, :])
else:
ax = fig.add_subplot(spec[1])
ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs)
ax.set_ylabel("NMSE")
ax.set_xlabel("Iteration number")
ax.yaxis.tick_right()

fig.suptitle(f"Normalized mean squared error: {self.error:.3e}")
spec.tight_layout(fig)

def _visualize_all_iterations(
self,
fig,
Expand Down Expand Up @@ -1668,7 +1470,6 @@ def visualize(
plot_fourier_probe: bool = False,
remove_initial_probe_aberrations: bool = False,
cbar: bool = True,
padding: int = 0,
**kwargs,
):
"""
Expand All @@ -1691,8 +1492,6 @@ def visualize(
remove_initial_probe_aberrations: bool, optional
If true, when plotting fourier probe, removes initial probe
to visualize changes
padding : int, optional
Pixels to pad by post rotating-cropping object
Returns
--------
Expand All @@ -1708,7 +1507,6 @@ def visualize(
plot_fourier_probe=plot_fourier_probe,
remove_initial_probe_aberrations=remove_initial_probe_aberrations,
cbar=cbar,
padding=padding,
**kwargs,
)
else:
Expand All @@ -1720,7 +1518,6 @@ def visualize(
plot_fourier_probe=plot_fourier_probe,
remove_initial_probe_aberrations=remove_initial_probe_aberrations,
cbar=cbar,
padding=padding,
**kwargs,
)
return self
Loading

0 comments on commit 9257242

Please sign in to comment.