diff --git a/src/mpol/images.py b/src/mpol/images.py index 0185766e..9dae120c 100644 --- a/src/mpol/images.py +++ b/src/mpol/images.py @@ -316,3 +316,39 @@ def to_FITS(self, fname="cube.fits", overwrite=False, header_kwargs=None): hdul.writeto(fname, overwrite=overwrite) hdul.close() + + +def np_to_imagecube(image, coords, nchan=1, wrap=False): + """Convenience function for converting a numpy image into an MPoL ImageCube + tensor (see mpol.images.ImageCube) + + Parameters + ---------- + image : array + An image in numpy format + coords : `mpol.coordinates.GridCoords` object + Instance of the `mpol.coordinates.GridCoords` class + nchan : int, default=1 + Number of channels in the image. Default assumes a single 2D image + wrap : bool, default=False + Whether to wrap the numpy image so that index 0 is in the image center + (FFT algorithms typically place index 0 in the image corner) + + Returns + ------- + icube : `mpol.images.ImageCube` object + The image cube tensor + """ + if wrap: + # move the 0 index to the image center + image = center_np_image(image) + + # broadcast image to (nchan, npix, npix) + img_packed_cube = np.broadcast_to(image, + (nchan, coords.npix, coords.npix)).copy() + + # convert to pytorch tensor + img_packed_tensor = torch.from_numpy(img_packed_cube) + + # insert into ImageCube layer + return ImageCube(coords=coords, nchan=nchan, cube=img_packed_tensor)