Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

6 code improvements #50

Merged
merged 2 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions rs_tools/_src/geoprocessing/msg/geoprocessor_msg.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,8 +313,8 @@ def preprocess_files(self):
logger.error(f"Skipping {itime} due to error loading")
continue

# remove crs from dataset
ds = ds.drop_vars('msg_seviri_fes_3km')
# remove crs from dataset
# ds = ds.drop_vars('msg_seviri_fes_3km') # NOTE: Uncommented to keep coordinate reference system

# remove attrs that cause netcdf error
for var in ds.data_vars:
Expand Down
152 changes: 113 additions & 39 deletions rs_tools/_src/preprocessing/prepatcher.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,29 @@
import autoroot
import numpy as np
from xrpatcher._src.base import XRDAPatcher
import rioxarray
from __future__ import annotations

import gc
import os
from pathlib import Path
from dataclasses import dataclass
from typing import Optional, List, Union, Tuple
from tqdm import tqdm
from rs_tools._src.utils.io import get_list_filenames
from pathlib import Path

import numpy as np
import typer
from loguru import logger
import xarray as xr
from loguru import logger
from rs_tools._src.utils.io import get_list_filenames
from tqdm import tqdm
from xrpatcher._src.base import XRDAPatcher


def _check_filetype(file_type: str) -> bool:
"""checks instrument for GOES data."""
if file_type in ["nc", "np"]:
if file_type in ["nc", "np", "tif"]:
return True
else:
msg = "Unrecognized file type"
msg += f"\nNeeds to be 'nc' or 'np'. Others are not yet tested"
msg += f"\nNeeds to be 'nc', 'np', or 'tif'. Others are not yet tested"
raise ValueError(msg)



def _check_nan_count(arr: np.array, nan_cutoff: float) -> bool:
"""
Check if the number of NaN values in the given array is below a specified cutoff.
Expand All @@ -37,11 +40,15 @@ def _check_nan_count(arr: np.array, nan_cutoff: float) -> bool:
# get total pixel count
total_count = int(arr.size)
# check if nan_count is within allowed cutoff
if nan_count/total_count <= nan_cutoff:

pct_nan = nan_count / total_count

if pct_nan <= nan_cutoff:
return True
else:
return False


@dataclass(frozen=True)
class PrePatcher:
"""
Expand All @@ -53,22 +60,22 @@ class PrePatcher:
patch_size (int): The size of each patch.
stride_size (int): The stride size for generating patches.
nan_cutoff (float): The cutoff value for allowed NaN count in a patch.
save_filetype (str): The file type to save patches as. Options are [nc, np].
save_filetype (str): The file type to save patches as. Options are [nc, np, tif].

Methods:
nc_files(self) -> List[str]: Returns a list of all NetCDF filenames in the read_path directory.
save_patches(self): Preprocesses and saves patches from the NetCDF files.
"""

read_path: str
save_path: str
save_path: str
patch_size: int
stride_size: int
stride_size: int
nan_cutoff: float
save_filetype: str

@property
def nc_files(self) -> List[str]:
def nc_files(self) -> list[str]:
"""
Returns a list of all NetCDF filenames in the read_path directory.

Expand All @@ -91,6 +98,25 @@ def save_patches(self):
pbar.set_description(f"Processing: {itime}")
# open dataset
ds = xr.open_dataset(ifile, engine="netcdf4")

if self.save_filetype == "tif":
# concatenate variables
ds_temp = xr.concat(
[ds.cloud_mask, ds.latitude, ds.longitude], dim="band"
)
# name data variables "Rad"
ds_temp = ds_temp.to_dataset(name="Rad")
ds_temp = ds_temp.drop(["cloud_mask", "latitude", "longitude"])
ds_temp = ds_temp.assign_coords(
band=["cloud_mask", "latitude", "longitude"]
)
# merge with original dataset
ds = xr.merge([ds_temp.Rad, ds.Rad])
# store band names to be attached to da later
band_names = [str(i) for i in ds.band.values]
del ds_temp
gc.collect()

# extract radiance data array
da = ds.Rad
# define patch parameters
Expand All @@ -104,36 +130,82 @@ def save_patches(self):
os.makedirs(self.save_path)

for i, ipatch in tqdm(enumerate(patcher), total=len(patcher)):
data = ipatch.data # extract data patch
data = ipatch.data # extract data
# logger.info(f'stride size {self.stride_size} ')
if _check_nan_count(data, self.nan_cutoff):
if self.save_filetype == "nc":
# reconvert to dataset to attach band_wavelength and time
ipatch = ipatch.to_dataset(name='Rad')
ipatch = ipatch.assign_coords({'time': ds.time.values})
ipatch = ipatch.assign_coords({'band_wavelength': ds.band_wavelength.values})
ipatch = ipatch.to_dataset(name="Rad")
ipatch = ipatch.assign_coords({"time": ds.time.values})
ipatch = ipatch.assign_coords(
{"band_wavelength": ds.band_wavelength.values}
)
# compile filename
file_path = Path(self.save_path).joinpath(f"{itime}_patch_{i}.nc")
file_path = Path(self.save_path).joinpath(
f"{itime}_patch_{i}.nc"
)
# remove file if it already exists
if os.path.exists(file_path):
os.remove(file_path)
# save patch to netcdf
ipatch.to_netcdf(Path(self.save_path).joinpath(f"{itime}_patch_{i}.nc"), engine="netcdf4")
# save patch to netcdf
ipatch.to_netcdf(
Path(self.save_path).joinpath(f"{itime}_patch_{i}.nc"),
engine="netcdf4",
)
elif self.save_filetype == "tif":
# reconvert to dataset to attach band_wavelength and time
# ds.attrs['band_names'] = [str(i) for i in ds.band.values]
# compile filename
file_path = Path(self.save_path).joinpath(
f"{itime}_patch_{i}.nc"
)
# remove file if it already exists
if os.path.exists(file_path):
os.remove(file_path)
# add band names as attribute
ipatch.attrs["band_names"] = band_names
# save patch to tiff
ipatch.rio.to_raster(
Path(self.save_path).joinpath(f"{itime}_patch_{i}.tif")
)
elif self.save_filetype == "np":
# save as numpy files
np.save(Path(self.save_path).joinpath(f"{itime}_radiance_patch_{i}"), data)
np.save(Path(self.save_path).joinpath(f"{itime}_latitude_patch_{i}"), ipatch.latitude.values)
np.save(Path(self.save_path).joinpath(f"{itime}_longitude_patch_{i}"), ipatch.longitude.values)
np.save(Path(self.save_path).joinpath(f"{itime}_cloudmask_patch_{i}"), ipatch.cloud_mask.values)
np.save(
Path(self.save_path).joinpath(
f"{itime}_radiance_patch_{i}"
),
data,
)
np.save(
Path(self.save_path).joinpath(
f"{itime}_latitude_patch_{i}"
),
ipatch.latitude.values,
)
np.save(
Path(self.save_path).joinpath(
f"{itime}_longitude_patch_{i}"
),
ipatch.longitude.values,
)
np.save(
Path(self.save_path).joinpath(
f"{itime}_cloudmask_patch_{i}"
),
ipatch.cloud_mask.values,
)
else:
logger.info(f'NaN count exceeded for patch {i} of timestamp {itime}.')
pass
# logger.info(f'NaN count exceeded for patch {i} of timestamp {itime}.')


def prepatch(
read_path: str = "./",
save_path: str = "./",
patch_size: int = 256,
stride_size: int = 256,
nan_cutoff: float = 0.5,
save_filetype: str = 'nc'
read_path: str = "./",
save_path: str = "./",
patch_size: int = 256,
stride_size: int = 256,
nan_cutoff: float = 0.5,
save_filetype: str = "nc",
):
"""
Patches satellite data into smaller patches for training.
Expand All @@ -151,21 +223,23 @@ def prepatch(
_check_filetype(file_type=save_filetype)

# Initialize Prepatcher
logger.info(f"Patching Files...: {read_path}")
logger.info(f"Initializing Prepatcher...")
prepatcher = PrePatcher(
read_path=read_path,
read_path=read_path,
save_path=save_path,
patch_size=patch_size,
stride_size=stride_size,
nan_cutoff=nan_cutoff,
save_filetype=save_filetype
)
save_filetype=save_filetype,
)
logger.info(f"Patching Files...: {save_path}")
prepatcher.save_patches()

logger.info(f"Finished Prepatching Script...!")

if __name__ == '__main__':

if __name__ == "__main__":
"""
python scripts/pipeline/prepatch.py --read-path "/path/to/netcdf/file" --save-path /path/to/save/patches
"""
Expand Down
Loading