diff --git a/ocf_datapipes/select/select_spatial_slice.py b/ocf_datapipes/select/select_spatial_slice.py index cfa7e0bfb..272cb9203 100644 --- a/ocf_datapipes/select/select_spatial_slice.py +++ b/ocf_datapipes/select/select_spatial_slice.py @@ -1,12 +1,13 @@ """Select spatial slices""" import logging -from typing import Union, Optional +from typing import Optional, Union import numpy as np import xarray as xr +from scipy.spatial import KDTree from torchdata.datapipes import functional_datapipe from torchdata.datapipes.iter import IterDataPipe -from scipy.spatial import KDTree + from ocf_datapipes.utils.consts import Location from ocf_datapipes.utils.geospatial import ( lat_lon_to_osgb, @@ -95,9 +96,7 @@ def __iter__(self) -> Union[xr.DataArray, xr.Dataset]: # Sanity check! assert left_idx >= 0, f"{left_idx=} must be >= 0!" data_width_pixels = len(xr_data[self.x_dim_name]) - assert ( - right_idx <= data_width_pixels - ), f"{right_idx=} must be <= {data_width_pixels=}" + assert right_idx <= data_width_pixels, f"{right_idx=} must be <= {data_width_pixels=}" assert top_idx >= 0, f"{top_idx=} must be >= 0!" data_height_pixels = len(xr_data[self.y_dim_name]) assert ( @@ -175,15 +174,11 @@ def __iter__(self) -> Union[xr.DataArray, xr.Dataset]: location, xr_data, left, bottom, right, top ) - x_mask = (left <= xr_data.x_geostationary) & ( - xr_data.x_geostationary <= right - ) + x_mask = (left <= xr_data.x_geostationary) & (xr_data.x_geostationary <= right) y_mask = (xr_data.y_geostationary <= top) & ( # Y is flipped bottom <= xr_data.y_geostationary ) - selected = xr_data.isel( - x_geostationary=x_mask, y_geostationary=y_mask - ) + selected = xr_data.isel(x_geostationary=x_mask, y_geostationary=y_mask) elif "longitude" == self.x_dim_name: if location.coordinate_system == "osgb": # Convert to geostationary edges @@ -219,17 +214,11 @@ def __iter__(self) -> Union[xr.DataArray, xr.Dataset]: # Select data in the region of interest and ID: # This also works for unstructured grids # Need to check coordinate systems match - if ( - location.coordinate_system == "osgb" - and "longitude" in self.x_dim_name - ): + if location.coordinate_system == "osgb" and "longitude" in self.x_dim_name: # Convert to lat_lon edges left, bottom = osgb_to_lat_lon(x=left, y=bottom) right, top = osgb_to_lat_lon(x=right, y=top) - elif ( - location.coordinate_system == "lat_lon" - and "osgb" in self.x_dim_name - ): + elif location.coordinate_system == "lat_lon" and "osgb" in self.x_dim_name: left, bottom = lat_lon_to_osgb(longitude=left, latitude=bottom) right, top = lat_lon_to_osgb(longitude=right, latitude=top) id_mask = ( @@ -246,16 +235,12 @@ def __iter__(self) -> Union[xr.DataArray, xr.Dataset]: def _convert_to_geostationary(location, xr_data, left, bottom, right, top): if location.coordinate_system == "osgb": # Convert to geostationary edges - _osgb_to_geostationary = load_geostationary_area_definition_and_transform_osgb( - xr_data - ) + _osgb_to_geostationary = load_geostationary_area_definition_and_transform_osgb(xr_data) left, bottom = _osgb_to_geostationary(xx=left, yy=bottom) right, top = _osgb_to_geostationary(xx=right, yy=top) elif location.coordinate_system == "lat_lon": # Convert to geostationary edges - _lat_lon_to_geostationary = ( - load_geostationary_area_definition_and_transform_latlon(xr_data) - ) + _lat_lon_to_geostationary = load_geostationary_area_definition_and_transform_latlon(xr_data) left, bottom = _lat_lon_to_geostationary(xx=left, yy=bottom) right, top = _lat_lon_to_geostationary(xx=right, yy=top) return left, bottom, right, top @@ -333,12 +318,8 @@ def _get_idx_of_pixel_closest_to_poi_geostationary( Returns: Location for the center pixel in geostationary coordinates """ - _osgb_to_geostationary = load_geostationary_area_definition_and_transform_osgb( - xr_data - ) - center_geostationary_tuple = _osgb_to_geostationary( - xx=center_osgb.x, yy=center_osgb.y - ) + _osgb_to_geostationary = load_geostationary_area_definition_and_transform_osgb(xr_data) + center_geostationary_tuple = _osgb_to_geostationary(xx=center_osgb.x, yy=center_osgb.y) center_geostationary = Location( x=center_geostationary_tuple[0], y=center_geostationary_tuple[1], @@ -346,9 +327,7 @@ def _get_idx_of_pixel_closest_to_poi_geostationary( ) # Get the index into x and y nearest to x_center_geostationary and y_center_geostationary: - x_index_at_center = ( - np.searchsorted(xr_data[x_dim_name].values, center_geostationary.x) - 1 - ) + x_index_at_center = np.searchsorted(xr_data[x_dim_name].values, center_geostationary.x) - 1 # y_geostationary is in descending order: y_index_at_center = len(xr_data[y_dim_name]) - ( np.searchsorted(xr_data[y_dim_name].values[::-1], center_geostationary.y) - 1