Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added a few features #3

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ torchvision = {version = ">=0.19.0", optional = true}
packaging = {version = ">=24.1", optional = true}
pycolmap = {version = ">=3.10.0", optional = true}


[tool.poetry.group.dev.dependencies]
pytest = "^7.2.0"
pytest-cov = "^4.0.0"
Expand Down
2 changes: 1 addition & 1 deletion satalign/ecc.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,4 @@ def find_warp(
warnings.warn(f"Could not calculate the warp matrix: {cv2err}")
warp_matrix = np.eye(*self.warp_matrix_size, dtype=np.float32)
self.warning_status = True
return warp_matrix
return warp_matrix, None
2 changes: 1 addition & 1 deletion satalign/lgm.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def find_warp(
warp_matrix = np.eye(*self.warp_matrix_size, dtype=np.float32)
warp_matrix[:2, 2] = [translation_x.item(), translation_y.item()]

return warp_matrix
return warp_matrix, None

def spatial_setup_model(
self,
Expand Down
66 changes: 44 additions & 22 deletions satalign/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def __init__(
# Set the class attributes (REQUIRED)
self.datacube = datacube
self.reference = reference
self.reference_np = self.reference if isinstance(self.reference, np.ndarray) else np.array(self.reference.values)

# Set the class attributes (OPTIONAL)
self.channel = channel
Expand All @@ -103,7 +104,7 @@ def find_warp(
self,
reference_image: np.ndarray,
moving_image: np.ndarray,
) -> np.ndarray:
) -> Union[np.ndarray, float]:
"""
Find the warp matrix that aligns the source and
destination image.
Expand Down Expand Up @@ -230,7 +231,7 @@ def get_warped_image(
"""

# Obtain the warp matrix
warp_matrix = self.find_warp(
warp_matrix, warp_error = self.find_warp(
reference_image=reference_image_feature,
moving_image=self.create_layer(moving_image),
)
Expand All @@ -244,14 +245,14 @@ def get_warped_image(
# Warp the image using the estimated warp matrix
warped_image = self.warp_feature(img=moving_image, warp_matrix=warp_matrix)

return warped_image, warp_matrix
return warped_image, warp_matrix, warp_error

def run_xarray(self) -> xr.Dataset:
"""
Run sequantially the get_warped_image method; input is xarray
"""
# Create the reference feature using the reference image
reference_layer = self.create_layer(self.reference.values)
reference_layer = self.create_layer(self.reference_np)

# Run iteratively the get_warped_image method
warp_matrices = []
Expand Down Expand Up @@ -283,7 +284,7 @@ def run_numpy(self) -> np.ndarray:
"""

# Create the reference feature using the reference image
reference_layer = self.create_layer(self.reference)
reference_layer = self.create_layer(self.reference_np)

# Run iteratively the get_warped_image method
warp_matrices = []
Expand All @@ -308,7 +309,7 @@ def run_multicore_numpy(self) -> np.ndarray:
"""

# Create the reference feature using the reference image
reference_layer = self.create_layer(self.reference)
reference_layer = self.create_layer(self.reference_np)

# Create the executor
with concurrent.futures.ThreadPoolExecutor(
Expand All @@ -327,21 +328,22 @@ def run_multicore_numpy(self) -> np.ndarray:

# Save the results in the final list
warped_cube = np.zeros_like(self.datacube, dtype=self.datacube.dtype)
warp_matrices = []
warp_matrices, warp_errors = [], []
for index, future in enumerate(futures):
warped_image, warp_matrix = future.result()
warped_image, warp_matrix, warp_error = future.result()
warped_cube[index] = warped_image
warp_matrices.append(warp_matrix)
warp_errors.append(warp_error)

return warped_cube, warp_matrices
return warped_cube, warp_matrices, warp_error

def run_multicore_xarray(self) -> xr.Dataset:
"""
Run the get_warped_image method using multiple threads
"""

# Create the reference feature using the reference image
reference_layer = self.create_layer(self.reference.values)
reference_layer = self.create_layer(self.reference_np)

# Create the executor
with concurrent.futures.ThreadPoolExecutor(
Expand All @@ -360,36 +362,56 @@ def run_multicore_xarray(self) -> xr.Dataset:

# Save the results in the final list
warped_cube = np.zeros_like(self.datacube.values, dtype=self.datacube.dtype)
warp_matrices = []
warp_matrices, warp_errors = [], []
for index, future in enumerate(futures):
warped_image, warp_matrix = future.result()
warped_image, warp_matrix, warp_error = future.result()
warped_cube[index] = warped_image
warp_matrices.append(warp_matrix)
warp_errors.append(warp_error)

# Create the xarray dataset
return xr.DataArray(
data=warped_cube,
coords=self.datacube.coords,
dims=self.datacube.dims,
attrs=self.datacube.attrs,
), warp_matrices
), warp_matrices, warp_errors

def run(self) -> Union[xr.Dataset, np.ndarray]:
def run(self, multicore=False) -> Union[xr.Dataset, np.ndarray]:
"""
Run the alignment method
"""

if isinstance(self.datacube, xr.DataArray):
return self.run_xarray()
error_before= get_error(self.datacube.values, self.reference_np)
if multicore:
data_warped, warps, errors= self.run_multicore_xarray()
else:
data_warped, warps, errors= self.run_xarray()
error_after= get_error(data_warped.values, self.reference_np)
else:
return self.run_numpy()
error_before= get_error(self.datacube, self.reference_np)
if multicore:
data_warped, warps, errors= self.run_multicore_numpy()
else:
data_warped, warps, errors= self.run_numpy()
error_after= get_error(data_warped, self.reference_np)

return data_warped, warps, errors, error_before, error_after

def run_multicore(self) -> Union[xr.Dataset, np.ndarray]:
"""
Run the alignment method using multiple threads
"""

if isinstance(self.datacube, xr.DataArray):
return self.run_multicore_xarray()
else:
return self.run_multicore_numpy()
return self.run(multicore=True)

def get_error(warped, reference, error='correlation'):
assert error in ['correlation'], f'error must be in ["correlation"]'

warped, reference= np.array(warped), np.array(reference)

reference_repeated = np.repeat(reference[None], warped.shape[0], axis=0)
warped_flat = warped.flatten()
reference_flat = reference_repeated.flatten()

if error == 'correlation':
return 1. - np.corrcoef(warped_flat, reference_flat)[0,1] # Get off-diagonal cross-corr
2 changes: 1 addition & 1 deletion satalign/pcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,4 +96,4 @@ def find_warp(
warp_matrix = np.eye(*self.warp_matrix_size, dtype=np.float32)
self.warning_status = True

return warp_matrix
return warp_matrix, error
Loading