From 386108cde46e8577958e04b92dcba813e37c3ef8 Mon Sep 17 00:00:00 2001 From: Ian Czekala Date: Wed, 27 Dec 2023 15:01:19 +0000 Subject: [PATCH] moved GriddedDataset export to from_numpy and tests pass. --- src/mpol/gridding.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/mpol/gridding.py b/src/mpol/gridding.py index 0c7498fd..4b157ea7 100644 --- a/src/mpol/gridding.py +++ b/src/mpol/gridding.py @@ -12,6 +12,8 @@ import numpy.typing as npt from fast_histogram import histogram as fast_hist +import torch + from mpol.coordinates import GridCoords from mpol.exceptions import DataError, ThresholdExceededError, WrongDimensionError from mpol.datasets import GriddedDataset @@ -657,9 +659,9 @@ def to_pytorch_dataset( return GriddedDataset( coords=self.coords, nchan=self.nchan, - vis_gridded=self.vis_gridded, - weight_gridded=self.weight_gridded, - mask=self.mask, + vis_gridded=torch.from_numpy(self.vis_gridded), + weight_gridded=torch.from_numpy(self.weight_gridded), + mask=torch.from_numpy(self.mask), )