Skip to content

Commit

Permalink
added transpose and upsampling options
Browse files Browse the repository at this point in the history
  • Loading branch information
gvarnavi committed Aug 11, 2024
1 parent e85ac13 commit 9288d1a
Showing 1 changed file with 74 additions and 23 deletions.
97 changes: 74 additions & 23 deletions py4DSTEM/process/phase/parallax.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,9 +775,14 @@ def preprocess(

return self

def guess_common_aberrations_and_rotation(
def guess_common_aberrations(
self,
rotation_angle_deg=0,
transpose=False,
kde_upsample_factor=None,
kde_sigma_px=0.125,
kde_lowpass_filter=False,
lanczos_interpolation_order=None,
defocus=0,
astigmatism=0,
astigmatism_angle_deg=0,
Expand Down Expand Up @@ -825,6 +830,10 @@ def guess_common_aberrations_and_rotation(
]
)

# transpose rotation matrix
if transpose:
rotation_angle_deg *= -1

# aberrations_basis
sampling = 1 / (
np.array(self._reciprocal_sampling) * self._region_of_interest_shape
Expand Down Expand Up @@ -852,33 +861,75 @@ def guess_common_aberrations_and_rotation(
)
)
shifts_ang = xp.tensordot(gradients, aberrations_coefs, axes=1).T

# transpose predicted shifts
if transpose:
shifts_ang = xp.flip(shifts_ang, axis=1)

shifts_px = shifts_ang / xp.array(self._scan_sampling)

# shifted stack
aligned_stack = xp.zeros_like(self._stack_BF_shifted_initial[0])
if max_batch_size is None:
max_batch_size = self._num_bf_images
# upsampled stack
if kde_upsample_factor is not None:
BF_size = np.array(self._stack_BF_unshifted.shape[-2:])
pixel_output_shape = np.round(BF_size * kde_upsample_factor).astype("int")

for start, end in generate_batches(
self._num_bf_images, max_batch=max_batch_size
):
shifted_BFs = self._stack_BF_shifted_initial[start:end]
x = xp.arange(BF_size[0], dtype=xp.float32)
y = xp.arange(BF_size[1], dtype=xp.float32)
xa_init, ya_init = xp.meshgrid(x, y, indexing="ij")

Gs = xp.fft.fft2(shifted_BFs)
# kernel density output the upsampled BF image
xa = (xa_init + shifts_px[:, 0, None, None]) * kde_upsample_factor
ya = (ya_init + shifts_px[:, 1, None, None]) * kde_upsample_factor

dx = shifts_px[start:end, 0]
dy = shifts_px[start:end, 1]
pix_output = self._kernel_density_estimate(
xa,
ya,
self._stack_BF_unshifted,
pixel_output_shape,
kde_sigma_px * kde_upsample_factor,
lanczos_alpha=lanczos_interpolation_order,
lowpass_filter=kde_lowpass_filter,
)

shift_op = xp.exp(
self._qx_shift[None] * dx[:, None, None]
+ self._qy_shift[None] * dy[:, None, None]
# hack since cropping requires "_kde_upsample_factor"
old_upsample_factor = getattr(self, "_kde_upsample_factor", None)
self._kde_upsample_factor = kde_upsample_factor
cropped_image = asnumpy(
self._crop_padded_object(pix_output, upsampled=True)
)
stack_BF_shifted = xp.real(xp.fft.ifft2(Gs * shift_op))
aligned_stack += stack_BF_shifted.sum(0)
if old_upsample_factor is not None:
self._kde_upsample_factor = old_upsample_factor
else:
del self._kde_upsample_factor

cropped_stack = asnumpy(
self._crop_padded_object(aligned_stack, upsampled=False)
)
# shifted stack
else:
kde_upsample_factor = 1
aligned_stack = xp.zeros_like(self._stack_BF_shifted_initial[0])

if max_batch_size is None:
max_batch_size = self._num_bf_images

for start, end in generate_batches(
self._num_bf_images, max_batch=max_batch_size
):
shifted_BFs = self._stack_BF_shifted_initial[start:end]

Gs = xp.fft.fft2(shifted_BFs)

dx = shifts_px[start:end, 0]
dy = shifts_px[start:end, 1]

shift_op = xp.exp(
self._qx_shift[None] * dx[:, None, None]
+ self._qy_shift[None] * dy[:, None, None]
)
stack_BF_shifted = xp.real(xp.fft.ifft2(Gs * shift_op))
aligned_stack += stack_BF_shifted.sum(0)

cropped_image = asnumpy(
self._crop_padded_object(aligned_stack, upsampled=False)
)

figsize = kwargs.pop("figsize", (8, 4))
color = kwargs.pop("color", (1, 0, 0, 1))
Expand All @@ -899,12 +950,12 @@ def guess_common_aberrations_and_rotation(

extent = [
0,
self._scan_sampling[1] * cropped_stack.shape[1],
self._scan_sampling[0] * cropped_stack.shape[0],
self._scan_sampling[1] * cropped_image.shape[1] / kde_upsample_factor,
self._scan_sampling[0] * cropped_image.shape[0] / kde_upsample_factor,
0,
]

axs[1].imshow(cropped_stack, cmap=cmap, extent=extent, **kwargs)
axs[1].imshow(cropped_image, cmap=cmap, extent=extent, **kwargs)
axs[1].set_ylabel("x [A]")
axs[1].set_xlabel("y [A]")
axs[1].set_title("Predicted Aligned BF Image")
Expand Down

0 comments on commit 9288d1a

Please sign in to comment.