Skip to content

Commit

Permalink
Merge pull request #50 from spaceml-org/6-code-improvements
Browse files Browse the repository at this point in the history
6 code improvements
  • Loading branch information
annajungbluth authored Jul 10, 2024
2 parents 17bffb3 + 186aa04 commit c9ca32e
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 41 deletions.
4 changes: 2 additions & 2 deletions rs_tools/_src/geoprocessing/msg/geoprocessor_msg.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,8 +313,8 @@ def preprocess_files(self):
logger.error(f"Skipping {itime} due to error loading")
continue

# remove crs from dataset
ds = ds.drop_vars('msg_seviri_fes_3km')
# remove crs from dataset
# ds = ds.drop_vars('msg_seviri_fes_3km') # NOTE: Uncommented to keep coordinate reference system

# remove attrs that cause netcdf error
for var in ds.data_vars:
Expand Down
152 changes: 113 additions & 39 deletions rs_tools/_src/preprocessing/prepatcher.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,29 @@
import autoroot
import numpy as np
from xrpatcher._src.base import XRDAPatcher
import rioxarray
from __future__ import annotations

import gc
import os
from pathlib import Path
from dataclasses import dataclass
from typing import Optional, List, Union, Tuple
from tqdm import tqdm
from rs_tools._src.utils.io import get_list_filenames
from pathlib import Path

import numpy as np
import typer
from loguru import logger
import xarray as xr
from loguru import logger
from rs_tools._src.utils.io import get_list_filenames
from tqdm import tqdm
from xrpatcher._src.base import XRDAPatcher


def _check_filetype(file_type: str) -> bool:
"""checks instrument for GOES data."""
if file_type in ["nc", "np"]:
if file_type in ["nc", "np", "tif"]:
return True
else:
msg = "Unrecognized file type"
msg += f"\nNeeds to be 'nc' or 'np'. Others are not yet tested"
msg += f"\nNeeds to be 'nc', 'np', or 'tif'. Others are not yet tested"
raise ValueError(msg)



def _check_nan_count(arr: np.array, nan_cutoff: float) -> bool:
"""
Check if the number of NaN values in the given array is below a specified cutoff.
Expand All @@ -37,11 +40,15 @@ def _check_nan_count(arr: np.array, nan_cutoff: float) -> bool:
# get total pixel count
total_count = int(arr.size)
# check if nan_count is within allowed cutoff
if nan_count/total_count <= nan_cutoff:

pct_nan = nan_count / total_count

if pct_nan <= nan_cutoff:
return True
else:
return False


@dataclass(frozen=True)
class PrePatcher:
"""
Expand All @@ -53,22 +60,22 @@ class PrePatcher:
patch_size (int): The size of each patch.
stride_size (int): The stride size for generating patches.
nan_cutoff (float): The cutoff value for allowed NaN count in a patch.
save_filetype (str): The file type to save patches as. Options are [nc, np].
save_filetype (str): The file type to save patches as. Options are [nc, np, tif].
Methods:
nc_files(self) -> List[str]: Returns a list of all NetCDF filenames in the read_path directory.
save_patches(self): Preprocesses and saves patches from the NetCDF files.
"""

read_path: str
save_path: str
save_path: str
patch_size: int
stride_size: int
stride_size: int
nan_cutoff: float
save_filetype: str

@property
def nc_files(self) -> List[str]:
def nc_files(self) -> list[str]:
"""
Returns a list of all NetCDF filenames in the read_path directory.
Expand All @@ -91,6 +98,25 @@ def save_patches(self):
pbar.set_description(f"Processing: {itime}")
# open dataset
ds = xr.open_dataset(ifile, engine="netcdf4")

if self.save_filetype == "tif":
# concatenate variables
ds_temp = xr.concat(
[ds.cloud_mask, ds.latitude, ds.longitude], dim="band"
)
# name data variables "Rad"
ds_temp = ds_temp.to_dataset(name="Rad")
ds_temp = ds_temp.drop(["cloud_mask", "latitude", "longitude"])
ds_temp = ds_temp.assign_coords(
band=["cloud_mask", "latitude", "longitude"]
)
# merge with original dataset
ds = xr.merge([ds_temp.Rad, ds.Rad])
# store band names to be attached to da later
band_names = [str(i) for i in ds.band.values]
del ds_temp
gc.collect()

# extract radiance data array
da = ds.Rad
# define patch parameters
Expand All @@ -104,36 +130,82 @@ def save_patches(self):
os.makedirs(self.save_path)

for i, ipatch in tqdm(enumerate(patcher), total=len(patcher)):
data = ipatch.data # extract data patch
data = ipatch.data # extract data
# logger.info(f'stride size {self.stride_size} ')
if _check_nan_count(data, self.nan_cutoff):
if self.save_filetype == "nc":
# reconvert to dataset to attach band_wavelength and time
ipatch = ipatch.to_dataset(name='Rad')
ipatch = ipatch.assign_coords({'time': ds.time.values})
ipatch = ipatch.assign_coords({'band_wavelength': ds.band_wavelength.values})
ipatch = ipatch.to_dataset(name="Rad")
ipatch = ipatch.assign_coords({"time": ds.time.values})
ipatch = ipatch.assign_coords(
{"band_wavelength": ds.band_wavelength.values}
)
# compile filename
file_path = Path(self.save_path).joinpath(f"{itime}_patch_{i}.nc")
file_path = Path(self.save_path).joinpath(
f"{itime}_patch_{i}.nc"
)
# remove file if it already exists
if os.path.exists(file_path):
os.remove(file_path)
# save patch to netcdf
ipatch.to_netcdf(Path(self.save_path).joinpath(f"{itime}_patch_{i}.nc"), engine="netcdf4")
# save patch to netcdf
ipatch.to_netcdf(
Path(self.save_path).joinpath(f"{itime}_patch_{i}.nc"),
engine="netcdf4",
)
elif self.save_filetype == "tif":
# reconvert to dataset to attach band_wavelength and time
# ds.attrs['band_names'] = [str(i) for i in ds.band.values]
# compile filename
file_path = Path(self.save_path).joinpath(
f"{itime}_patch_{i}.nc"
)
# remove file if it already exists
if os.path.exists(file_path):
os.remove(file_path)
# add band names as attribute
ipatch.attrs["band_names"] = band_names
# save patch to tiff
ipatch.rio.to_raster(
Path(self.save_path).joinpath(f"{itime}_patch_{i}.tif")
)
elif self.save_filetype == "np":
# save as numpy files
np.save(Path(self.save_path).joinpath(f"{itime}_radiance_patch_{i}"), data)
np.save(Path(self.save_path).joinpath(f"{itime}_latitude_patch_{i}"), ipatch.latitude.values)
np.save(Path(self.save_path).joinpath(f"{itime}_longitude_patch_{i}"), ipatch.longitude.values)
np.save(Path(self.save_path).joinpath(f"{itime}_cloudmask_patch_{i}"), ipatch.cloud_mask.values)
np.save(
Path(self.save_path).joinpath(
f"{itime}_radiance_patch_{i}"
),
data,
)
np.save(
Path(self.save_path).joinpath(
f"{itime}_latitude_patch_{i}"
),
ipatch.latitude.values,
)
np.save(
Path(self.save_path).joinpath(
f"{itime}_longitude_patch_{i}"
),
ipatch.longitude.values,
)
np.save(
Path(self.save_path).joinpath(
f"{itime}_cloudmask_patch_{i}"
),
ipatch.cloud_mask.values,
)
else:
logger.info(f'NaN count exceeded for patch {i} of timestamp {itime}.')
pass
# logger.info(f'NaN count exceeded for patch {i} of timestamp {itime}.')


def prepatch(
read_path: str = "./",
save_path: str = "./",
patch_size: int = 256,
stride_size: int = 256,
nan_cutoff: float = 0.5,
save_filetype: str = 'nc'
read_path: str = "./",
save_path: str = "./",
patch_size: int = 256,
stride_size: int = 256,
nan_cutoff: float = 0.5,
save_filetype: str = "nc",
):
"""
Patches satellite data into smaller patches for training.
Expand All @@ -151,21 +223,23 @@ def prepatch(
_check_filetype(file_type=save_filetype)

# Initialize Prepatcher
logger.info(f"Patching Files...: {read_path}")
logger.info(f"Initializing Prepatcher...")
prepatcher = PrePatcher(
read_path=read_path,
read_path=read_path,
save_path=save_path,
patch_size=patch_size,
stride_size=stride_size,
nan_cutoff=nan_cutoff,
save_filetype=save_filetype
)
save_filetype=save_filetype,
)
logger.info(f"Patching Files...: {save_path}")
prepatcher.save_patches()

logger.info(f"Finished Prepatching Script...!")

if __name__ == '__main__':

if __name__ == "__main__":
"""
python scripts/pipeline/prepatch.py --read-path "/path/to/netcdf/file" --save-path /path/to/save/patches
"""
Expand Down

0 comments on commit c9ca32e

Please sign in to comment.