diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 73021d8a9..497c7ae1c 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -1163,6 +1163,12 @@ def _normalize_diffraction_intensities( mean_intensity = 0 diffraction_intensities = self._asnumpy(diffraction_intensities) + if positions_mask is not None: + number_of_patterns = np.count_nonzero(self._positions_mask.ravel()) + sx, sy = np.where(~self._positions_mask) + else: + number_of_patterns = np.prod(diffraction_intensities.shape[:2]) + if crop_patterns: crop_x = int( np.minimum( @@ -1181,8 +1187,7 @@ def _normalize_diffraction_intensities( region_of_interest_shape = (crop_w * 2, crop_w * 2) amplitudes = np.zeros( ( - diffraction_intensities.shape[0], - diffraction_intensities.shape[1], + number_of_patterns, crop_w * 2, crop_w * 2, ), @@ -1198,13 +1203,19 @@ def _normalize_diffraction_intensities( else: region_of_interest_shape = diffraction_intensities.shape[-2:] - amplitudes = np.zeros(diffraction_intensities.shape, dtype=np.float32) + amplitudes = np.zeros( + (number_of_patterns,) + region_of_interest_shape, dtype=np.float32 + ) com_fitted_x = self._asnumpy(com_fitted_x) com_fitted_y = self._asnumpy(com_fitted_y) + counter = 0 for rx in range(diffraction_intensities.shape[0]): for ry in range(diffraction_intensities.shape[1]): + if positions_mask is not None: + if rx in sx and ry in sy: + continue intensities = get_shifted_ar( diffraction_intensities[rx, ry], -com_fitted_x[rx, ry], @@ -1219,13 +1230,10 @@ def _normalize_diffraction_intensities( ) mean_intensity += np.sum(intensities) - amplitudes[rx, ry] = np.sqrt(np.maximum(intensities, 0)) + amplitudes[counter] = np.sqrt(np.maximum(intensities, 0)) + counter += 1 - amplitudes = xp.reshape(amplitudes, (-1,) + region_of_interest_shape) amplitudes = xp.asarray(amplitudes) - if positions_mask is not None: - amplitudes = amplitudes[positions_mask.ravel()] - mean_intensity /= amplitudes.shape[0] return amplitudes, mean_intensity diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index 2915acccb..26b0d8cff 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -189,7 +189,7 @@ def __init__( f"object_type must be either 'potential' or 'complex', not {object_type}" ) - if positions_mask.dtype != "bool": + if positions_mask is not None and positions_mask.dtype != "bool": warnings.warn( ("`positions_mask` converted to `bool` array"), UserWarning, diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index 01d70bf71..ebc40928d 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -164,7 +164,7 @@ def __init__( raise ValueError( f"object_type must be either 'potential' or 'complex', not {object_type}" ) - if positions_mask.dtype != "bool": + if positions_mask is not None and positions_mask.dtype != "bool": warnings.warn( ("`positions_mask` converted to `bool` array"), UserWarning, diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index be24f067d..73f83558e 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -198,7 +198,7 @@ def __init__( raise ValueError( f"object_type must be either 'potential' or 'complex', not {object_type}" ) - if positions_mask.dtype != "bool": + if positions_mask is not None and positions_mask.dtype != "bool": warnings.warn( ("`positions_mask` converted to `bool` array"), UserWarning, diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py index 810352ce8..582eea772 100644 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py @@ -166,7 +166,7 @@ def __init__( if object_type != "potential": raise NotImplementedError() - if positions_mask.dtype != "bool": + if positions_mask is not None and positions_mask.dtype != "bool": warnings.warn( ("`positions_mask` converted to `bool` array"), UserWarning, diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index 701267e81..f4dfe5022 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -175,7 +175,7 @@ def __init__( if object_type != "potential": raise NotImplementedError() - if positions_mask.dtype != "bool": + if positions_mask is not None and positions_mask.dtype != "bool": warnings.warn( ("`positions_mask` converted to `bool` array"), UserWarning, diff --git a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py index ae1a3ecac..866ff0a89 100644 --- a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py +++ b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py @@ -153,7 +153,7 @@ def __init__( raise ValueError( f"object_type must be either 'potential' or 'complex', not {object_type}" ) - if positions_mask.dtype != "bool": + if positions_mask is not None and positions_mask.dtype != "bool": warnings.warn( ("`positions_mask` converted to `bool` array"), UserWarning, diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index ab16330da..350d0a3cb 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -150,7 +150,7 @@ def __init__( f"object_type must be either 'potential' or 'complex', not {object_type}" ) - if positions_mask.dtype != "bool": + if positions_mask is not None and positions_mask.dtype != "bool": warnings.warn( ("`positions_mask` converted to `bool` array"), UserWarning,