diff --git a/pyproject.toml b/pyproject.toml index 810cd7a..d80e83d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ torchvision = {version = ">=0.19.0", optional = true} packaging = {version = ">=24.1", optional = true} pycolmap = {version = ">=3.10.0", optional = true} + [tool.poetry.group.dev.dependencies] pytest = "^7.2.0" pytest-cov = "^4.0.0" diff --git a/satalign/ecc.py b/satalign/ecc.py index d7e8819..1ace6f8 100644 --- a/satalign/ecc.py +++ b/satalign/ecc.py @@ -83,4 +83,4 @@ def find_warp( warnings.warn(f"Could not calculate the warp matrix: {cv2err}") warp_matrix = np.eye(*self.warp_matrix_size, dtype=np.float32) self.warning_status = True - return warp_matrix + return warp_matrix, None diff --git a/satalign/lgm.py b/satalign/lgm.py index 26f70fc..a3670b7 100644 --- a/satalign/lgm.py +++ b/satalign/lgm.py @@ -123,7 +123,7 @@ def find_warp( warp_matrix = np.eye(*self.warp_matrix_size, dtype=np.float32) warp_matrix[:2, 2] = [translation_x.item(), translation_y.item()] - return warp_matrix + return warp_matrix, None def spatial_setup_model( self, diff --git a/satalign/main.py b/satalign/main.py index 072968d..fbfc6db 100644 --- a/satalign/main.py +++ b/satalign/main.py @@ -82,6 +82,7 @@ def __init__( # Set the class attributes (REQUIRED) self.datacube = datacube self.reference = reference + self.reference_np = self.reference if isinstance(self.reference, np.ndarray) else np.array(self.reference.values) # Set the class attributes (OPTIONAL) self.channel = channel @@ -103,7 +104,7 @@ def find_warp( self, reference_image: np.ndarray, moving_image: np.ndarray, - ) -> np.ndarray: + ) -> Union[np.ndarray, float]: """ Find the warp matrix that aligns the source and destination image. @@ -230,7 +231,7 @@ def get_warped_image( """ # Obtain the warp matrix - warp_matrix = self.find_warp( + warp_matrix, warp_error = self.find_warp( reference_image=reference_image_feature, moving_image=self.create_layer(moving_image), ) @@ -244,14 +245,14 @@ def get_warped_image( # Warp the image using the estimated warp matrix warped_image = self.warp_feature(img=moving_image, warp_matrix=warp_matrix) - return warped_image, warp_matrix + return warped_image, warp_matrix, warp_error def run_xarray(self) -> xr.Dataset: """ Run sequantially the get_warped_image method; input is xarray """ # Create the reference feature using the reference image - reference_layer = self.create_layer(self.reference.values) + reference_layer = self.create_layer(self.reference_np) # Run iteratively the get_warped_image method warp_matrices = [] @@ -283,7 +284,7 @@ def run_numpy(self) -> np.ndarray: """ # Create the reference feature using the reference image - reference_layer = self.create_layer(self.reference) + reference_layer = self.create_layer(self.reference_np) # Run iteratively the get_warped_image method warp_matrices = [] @@ -308,7 +309,7 @@ def run_multicore_numpy(self) -> np.ndarray: """ # Create the reference feature using the reference image - reference_layer = self.create_layer(self.reference) + reference_layer = self.create_layer(self.reference_np) # Create the executor with concurrent.futures.ThreadPoolExecutor( @@ -327,13 +328,14 @@ def run_multicore_numpy(self) -> np.ndarray: # Save the results in the final list warped_cube = np.zeros_like(self.datacube, dtype=self.datacube.dtype) - warp_matrices = [] + warp_matrices, warp_errors = [], [] for index, future in enumerate(futures): - warped_image, warp_matrix = future.result() + warped_image, warp_matrix, warp_error = future.result() warped_cube[index] = warped_image warp_matrices.append(warp_matrix) + warp_errors.append(warp_error) - return warped_cube, warp_matrices + return warped_cube, warp_matrices, warp_error def run_multicore_xarray(self) -> xr.Dataset: """ @@ -341,7 +343,7 @@ def run_multicore_xarray(self) -> xr.Dataset: """ # Create the reference feature using the reference image - reference_layer = self.create_layer(self.reference.values) + reference_layer = self.create_layer(self.reference_np) # Create the executor with concurrent.futures.ThreadPoolExecutor( @@ -360,11 +362,12 @@ def run_multicore_xarray(self) -> xr.Dataset: # Save the results in the final list warped_cube = np.zeros_like(self.datacube.values, dtype=self.datacube.dtype) - warp_matrices = [] + warp_matrices, warp_errors = [], [] for index, future in enumerate(futures): - warped_image, warp_matrix = future.result() + warped_image, warp_matrix, warp_error = future.result() warped_cube[index] = warped_image warp_matrices.append(warp_matrix) + warp_errors.append(warp_error) # Create the xarray dataset return xr.DataArray( @@ -372,24 +375,43 @@ def run_multicore_xarray(self) -> xr.Dataset: coords=self.datacube.coords, dims=self.datacube.dims, attrs=self.datacube.attrs, - ), warp_matrices + ), warp_matrices, warp_errors - def run(self) -> Union[xr.Dataset, np.ndarray]: + def run(self, multicore=False) -> Union[xr.Dataset, np.ndarray]: """ Run the alignment method """ - if isinstance(self.datacube, xr.DataArray): - return self.run_xarray() + error_before= get_error(self.datacube.values, self.reference_np) + if multicore: + data_warped, warps, errors= self.run_multicore_xarray() + else: + data_warped, warps, errors= self.run_xarray() + error_after= get_error(data_warped.values, self.reference_np) else: - return self.run_numpy() + error_before= get_error(self.datacube, self.reference_np) + if multicore: + data_warped, warps, errors= self.run_multicore_numpy() + else: + data_warped, warps, errors= self.run_numpy() + error_after= get_error(data_warped, self.reference_np) + + return data_warped, warps, errors, error_before, error_after def run_multicore(self) -> Union[xr.Dataset, np.ndarray]: """ Run the alignment method using multiple threads """ - - if isinstance(self.datacube, xr.DataArray): - return self.run_multicore_xarray() - else: - return self.run_multicore_numpy() \ No newline at end of file + return self.run(multicore=True) + +def get_error(warped, reference, error='correlation'): + assert error in ['correlation'], f'error must be in ["correlation"]' + + warped, reference= np.array(warped), np.array(reference) + + reference_repeated = np.repeat(reference[None], warped.shape[0], axis=0) + warped_flat = warped.flatten() + reference_flat = reference_repeated.flatten() + + if error == 'correlation': + return 1. - np.corrcoef(warped_flat, reference_flat)[0,1] # Get off-diagonal cross-corr \ No newline at end of file diff --git a/satalign/pcc.py b/satalign/pcc.py index 1d8332d..de52642 100644 --- a/satalign/pcc.py +++ b/satalign/pcc.py @@ -96,4 +96,4 @@ def find_warp( warp_matrix = np.eye(*self.warp_matrix_size, dtype=np.float32) self.warning_status = True - return warp_matrix + return warp_matrix, error