diff --git a/src/darsia/corrections/shape/curvature.py b/src/darsia/corrections/shape/curvature.py index 2d52d73c..4373d544 100644 --- a/src/darsia/corrections/shape/curvature.py +++ b/src/darsia/corrections/shape/curvature.py @@ -779,8 +779,13 @@ def _transform_image( im_array_as_vector = map_coordinates( in_data, grid, order=self.interpolation_order ) - # Convert to correct shape and data type - corrected_img[:, :, i] = im_array_as_vector.reshape(shape).astype(img.dtype) + # Convert to correct shape and data type (if necessary) + if im_array_as_vector.dtype == img.dtype: + corrected_img[:, :, i] = im_array_as_vector.reshape(shape) + else: + corrected_img[:, :, i] = im_array_as_vector.reshape(shape).astype( + img.dtype + ) return np.squeeze(corrected_img)