diff --git a/src/mpol/images.py b/src/mpol/images.py index 0185766e..b0f9b2dc 100644 --- a/src/mpol/images.py +++ b/src/mpol/images.py @@ -316,3 +316,39 @@ def to_FITS(self, fname="cube.fits", overwrite=False, header_kwargs=None): hdul.writeto(fname, overwrite=overwrite) hdul.close() + + +def np_to_imagecube(image, coords, nchan=1, wrap=False): + """Convenience function for converting a numpy image into an MPoL ImageCube + tensor (see mpol.images.ImageCube) + + Parameters + ---------- + image : array + An image in numpy format + coords : `mpol.coordinates.GridCoords` object + Instance of the `mpol.coordinates.GridCoords` class + nchan : int, default=1 + Number of channels in the image. Default assumes a single 2D image + wrap : bool, default=False + Whether to wrap the numpy image so that index 0 is in the image center + (FFT algorithms typically place index 0 in the image corner) + + Returns + ------- + icube : `mpol.images.ImageCube` object + The image cube tensor + """ + if wrap: + # move the 0 index to the image center + image = utils.center_np_image(image) + + # broadcast image to (nchan, npix, npix) + img_packed_cube = np.broadcast_to(image, + (nchan, coords.npix, coords.npix)).copy() + + # convert to pytorch tensor + img_packed_tensor = torch.from_numpy(img_packed_cube) + + # insert into ImageCube layer + return ImageCube(coords=coords, nchan=nchan, cube=img_packed_tensor) diff --git a/src/mpol/utils.py b/src/mpol/utils.py index 92096b19..65178bb9 100644 --- a/src/mpol/utils.py +++ b/src/mpol/utils.py @@ -10,6 +10,30 @@ def torch2npy(tensor): return tensor.detach().cpu().numpy() +def center_np_image(image): + """Wrap a numpy image array so that the 0 index is in the image center. + Adapted from frank.utilities.make_image (https://github.com/discsim/frank) + + Parameters + ---------- + image : 2D array + The numpy image with index 0 in the corner + + Returns + ------- + image_wrapped : 2D array + The numpy image with index 0 in the center + """ + tmp, image_wrapped = image.copy(), image.copy() + + Nx, Ny = tmp.shape[0], tmp.shape[1] + + tmp[:Nx//2,], tmp[Nx//2:] = image[Nx//2:], image[:Nx//2] + image_wrapped[:, :Ny//2], image_wrapped[:, Ny//2:] = tmp[:, Ny//2:], tmp[:, :Ny//2] + + return image_wrapped + + def ground_cube_to_packed_cube(ground_cube): r""" Converts a Ground Cube to a Packed Visibility Cube for visibility-plane work. See Units and Conventions for more details.