Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/jacob/icon-global' into jacob/ic…
Browse files Browse the repository at this point in the history
…on-global
  • Loading branch information
jacobbieker committed Aug 16, 2023
2 parents 4f36f69 + c780b1f commit de92609
Showing 1 changed file with 13 additions and 34 deletions.
47 changes: 13 additions & 34 deletions ocf_datapipes/select/select_spatial_slice.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""Select spatial slices"""
import logging
from typing import Union, Optional
from typing import Optional, Union

import numpy as np
import xarray as xr
from scipy.spatial import KDTree
from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe
from scipy.spatial import KDTree

from ocf_datapipes.utils.consts import Location
from ocf_datapipes.utils.geospatial import (
lat_lon_to_osgb,
Expand Down Expand Up @@ -95,9 +96,7 @@ def __iter__(self) -> Union[xr.DataArray, xr.Dataset]:
# Sanity check!
assert left_idx >= 0, f"{left_idx=} must be >= 0!"
data_width_pixels = len(xr_data[self.x_dim_name])
assert (
right_idx <= data_width_pixels
), f"{right_idx=} must be <= {data_width_pixels=}"
assert right_idx <= data_width_pixels, f"{right_idx=} must be <= {data_width_pixels=}"
assert top_idx >= 0, f"{top_idx=} must be >= 0!"
data_height_pixels = len(xr_data[self.y_dim_name])
assert (
Expand Down Expand Up @@ -175,15 +174,11 @@ def __iter__(self) -> Union[xr.DataArray, xr.Dataset]:
location, xr_data, left, bottom, right, top
)

x_mask = (left <= xr_data.x_geostationary) & (
xr_data.x_geostationary <= right
)
x_mask = (left <= xr_data.x_geostationary) & (xr_data.x_geostationary <= right)
y_mask = (xr_data.y_geostationary <= top) & ( # Y is flipped
bottom <= xr_data.y_geostationary
)
selected = xr_data.isel(
x_geostationary=x_mask, y_geostationary=y_mask
)
selected = xr_data.isel(x_geostationary=x_mask, y_geostationary=y_mask)
elif "longitude" == self.x_dim_name:
if location.coordinate_system == "osgb":
# Convert to geostationary edges
Expand Down Expand Up @@ -219,17 +214,11 @@ def __iter__(self) -> Union[xr.DataArray, xr.Dataset]:
# Select data in the region of interest and ID:
# This also works for unstructured grids
# Need to check coordinate systems match
if (
location.coordinate_system == "osgb"
and "longitude" in self.x_dim_name
):
if location.coordinate_system == "osgb" and "longitude" in self.x_dim_name:
# Convert to lat_lon edges
left, bottom = osgb_to_lat_lon(x=left, y=bottom)
right, top = osgb_to_lat_lon(x=right, y=top)
elif (
location.coordinate_system == "lat_lon"
and "osgb" in self.x_dim_name
):
elif location.coordinate_system == "lat_lon" and "osgb" in self.x_dim_name:
left, bottom = lat_lon_to_osgb(longitude=left, latitude=bottom)
right, top = lat_lon_to_osgb(longitude=right, latitude=top)
id_mask = (
Expand All @@ -246,16 +235,12 @@ def __iter__(self) -> Union[xr.DataArray, xr.Dataset]:
def _convert_to_geostationary(location, xr_data, left, bottom, right, top):
if location.coordinate_system == "osgb":
# Convert to geostationary edges
_osgb_to_geostationary = load_geostationary_area_definition_and_transform_osgb(
xr_data
)
_osgb_to_geostationary = load_geostationary_area_definition_and_transform_osgb(xr_data)
left, bottom = _osgb_to_geostationary(xx=left, yy=bottom)
right, top = _osgb_to_geostationary(xx=right, yy=top)
elif location.coordinate_system == "lat_lon":
# Convert to geostationary edges
_lat_lon_to_geostationary = (
load_geostationary_area_definition_and_transform_latlon(xr_data)
)
_lat_lon_to_geostationary = load_geostationary_area_definition_and_transform_latlon(xr_data)
left, bottom = _lat_lon_to_geostationary(xx=left, yy=bottom)
right, top = _lat_lon_to_geostationary(xx=right, yy=top)
return left, bottom, right, top
Expand Down Expand Up @@ -333,22 +318,16 @@ def _get_idx_of_pixel_closest_to_poi_geostationary(
Returns:
Location for the center pixel in geostationary coordinates
"""
_osgb_to_geostationary = load_geostationary_area_definition_and_transform_osgb(
xr_data
)
center_geostationary_tuple = _osgb_to_geostationary(
xx=center_osgb.x, yy=center_osgb.y
)
_osgb_to_geostationary = load_geostationary_area_definition_and_transform_osgb(xr_data)
center_geostationary_tuple = _osgb_to_geostationary(xx=center_osgb.x, yy=center_osgb.y)
center_geostationary = Location(
x=center_geostationary_tuple[0],
y=center_geostationary_tuple[1],
coordinate_system="geostationary",
)

# Get the index into x and y nearest to x_center_geostationary and y_center_geostationary:
x_index_at_center = (
np.searchsorted(xr_data[x_dim_name].values, center_geostationary.x) - 1
)
x_index_at_center = np.searchsorted(xr_data[x_dim_name].values, center_geostationary.x) - 1
# y_geostationary is in descending order:
y_index_at_center = len(xr_data[y_dim_name]) - (
np.searchsorted(xr_data[y_dim_name].values[::-1], center_geostationary.y) - 1
Expand Down

0 comments on commit de92609

Please sign in to comment.