From e681508bdac866d011c5953842cc218146b8e865 Mon Sep 17 00:00:00 2001
From: James Fulton
Date: Fri, 2 Aug 2024 12:11:59 +0000
Subject: [PATCH] minor tweaks

---
 ocf_dataset_alpha/datasets/pvnet.py    | 107 ++++++++++--------
 .../numpy_batch/add_sun_position.py    |   1 -
 2 files changed, 61 insertions(+), 47 deletions(-)

diff --git a/ocf_dataset_alpha/datasets/pvnet.py b/ocf_dataset_alpha/datasets/pvnet.py
index 48e9069..b2eed4a 100644
--- a/ocf_dataset_alpha/datasets/pvnet.py
+++ b/ocf_dataset_alpha/datasets/pvnet.py
@@ -35,8 +35,7 @@
 from ocf_datapipes.config.load import load_yaml_configuration
 from ocf_datapipes.utils.location import Location
 
-
-from ocf_datapipes.batch import BatchKey
+from ocf_datapipes.batch import BatchKey, NumpyBatch
 from ocf_datapipes.utils.consts import (
     NWP_MEANS,
@@ -53,25 +52,23 @@
 xr.set_options(keep_attrs=True)
 
-xarray_object = xr.DataArray | xr.Dataset
-
 # TODO: This is messy. Remove the need for this in NumpyBatch conversion
-def add_t0_idx_and_sample_period_duration(ds: xarray_object, source_config):
+def add_t0_idx_and_sample_period_duration(da: xr.DataArray, source_config) -> xr.DataArray:
     """Add attributes to the xarray dataset needed for numpy batch
 
     Args:
-        ds: Xarray object
+        da: xarray DataArray
         source_config: Input data source config
     """
-    ds.attrs["t0_idx"] = int(
+    da.attrs["t0_idx"] = int(
         source_config.history_minutes / source_config.time_resolution_minutes
     )
-    ds.attrs["sample_period_duration"] = minutes(source_config.time_resolution_minutes)
-    return ds
+    da.attrs["sample_period_duration"] = minutes(source_config.time_resolution_minutes)
+    return da
 
 
-def minutes(m: float):
+def minutes(m: float) -> timedelta:
     """Timedelta minutes
 
     Args:
@@ -80,7 +77,7 @@
     return timedelta(minutes=m)
 
 
-def get_dataset_dict(config: Configuration):
+def get_dataset_dict(config: Configuration) -> dict[str, xr.DataArray | dict[str, xr.DataArray]]:
     """Construct dictionary of all of the input data sources
 
     Args:
@@ -88,6 +85,8 @@
     """
     # Check which modalities to use
     conf_in = config.input_data
+
+    # TODO: Clean these up
 
     use_nwp = (
         (conf_in.nwp is not None)
@@ -103,10 +102,10 @@
 
     # We always assume GSP will be included
     gsp_config = config.input_data.gsps
 
-    ds_gsp = open_gsp(zarr_path=gsp_config.gsp_zarr_path)
-    ds_gsp = add_t0_idx_and_sample_period_duration(ds_gsp, gsp_config)
+    da_gsp = open_gsp(zarr_path=gsp_config.gsp_zarr_path)
+    da_gsp = add_t0_idx_and_sample_period_duration(da_gsp, gsp_config)
 
-    datasets["gsp"] = ds_gsp
+    datasets["gsp"] = da_gsp
 
     # Load NWP data if in config
     if use_nwp:
@@ -114,25 +113,25 @@
         datasets["nwp"] = {}
 
         for nwp_source, nwp_config in conf_in.nwp.items():
 
-            ds_nwp = open_nwp(nwp_config.nwp_zarr_path, provider=nwp_config.nwp_provider)
+            da_nwp = open_nwp(nwp_config.nwp_zarr_path, provider=nwp_config.nwp_provider)
 
-            ds_nwp = ds_nwp.sel(channel=list(nwp_config.nwp_channels))
+            da_nwp = da_nwp.sel(channel=list(nwp_config.nwp_channels))
 
-            ds_nwp = add_t0_idx_and_sample_period_duration(ds_nwp, nwp_config)
+            da_nwp = add_t0_idx_and_sample_period_duration(da_nwp, nwp_config)
 
-            datasets["nwp"][nwp_source] = ds_nwp
+            datasets["nwp"][nwp_source] = da_nwp
 
     # Load satellite data if in config
     if use_sat:
         sat_config = config.input_data.satellite
 
-        ds_sat = open_sat_data(sat_config.satellite_zarr_path)
+        da_sat = open_sat_data(sat_config.satellite_zarr_path)
 
-        ds_sat.sel(channel=list(sat_config.satellite_channels))
+        da_sat = da_sat.sel(channel=list(sat_config.satellite_channels))
 
-        ds_sat = add_t0_idx_and_sample_period_duration(ds_sat, sat_config)
+        da_sat = add_t0_idx_and_sample_period_duration(da_sat, sat_config)
 
-        datasets["sat"] = ds_sat
+        datasets["sat"] = da_sat
 
     return datasets
 
@@ -157,7 +156,7 @@ def find_valid_t0_times(
 
     if key == "nwp":
         for nwp_key, nwp_conf in config.input_data.nwp.items():
 
-            ds = datasets_dict["nwp"][nwp_key]
+            da = datasets_dict["nwp"][nwp_key]
 
             if nwp_conf.dropout_timedeltas_minutes is None:
                 max_dropout = minutes(0)
@@ -178,7 +177,7 @@
 
             # This is the max staleness we can use considering the max step of the input data
             max_possible_staleness = (
-                pd.Timedelta(ds["step"].max().item())
+                pd.Timedelta(da["step"].max().item())
                 - minutes(nwp_conf.forecast_minutes)
                 - end_buffer
             )
@@ -192,7 +191,7 @@
                 max_staleness = max_staleness
 
             time_periods = find_contiguous_t0_periods_nwp(
-                datetimes=pd.DatetimeIndex(ds["init_time_utc"]),
+                datetimes=pd.DatetimeIndex(da["init_time_utc"]),
                 history_duration=minutes(nwp_conf.history_minutes),
                 max_staleness=max_staleness,
                 max_dropout=max_dropout,
@@ -299,14 +298,14 @@
 
         sliced_datasets_dict["nwp"] = {}
 
-        for nwp_key, ds_nwp in datasets_dict["nwp"].items():
+        for nwp_key, da_nwp in datasets_dict["nwp"].items():
 
             dropout_timedeltas = minutes_list_to_timedeltas(
                 conf_in.nwp[nwp_key].dropout_timedeltas_minutes
             )
 
             sliced_datasets_dict["nwp"][nwp_key] = select_time_slice_nwp(
-                ds_nwp,
+                da_nwp,
                 t0,
                 sample_period_duration=minutes(conf_in.nwp[nwp_key].time_resolution_minutes),
                 history_duration=minutes(conf_in.nwp[nwp_key].history_minutes),
@@ -377,7 +376,7 @@
 
     return sliced_datasets_dict
 
 
-def merge_dicts(list_of_dicts: list[dict]):
+def merge_dicts(list_of_dicts: list[dict]) -> dict:
     """Merge a list of dictionaries into a single dictionary"""
     # TODO: This doesn't account for duplicate keys
     all_d = {}
@@ -387,7 +386,7 @@
 
 
-def process_and_combine_datasets(dataset_dict: dict, config: Configuration):
+def process_and_combine_datasets(dataset_dict: dict, config: Configuration) -> NumpyBatch:
     """Normalize and convert data to numpy arrays"""
 
     numpy_modalities = []
@@ -397,27 +396,27 @@
         conf_nwp = config.input_data.nwp
         nwp_numpy_modalities = dict()
 
-        for nwp_key, ds_nwp in dataset_dict["nwp"].items():
+        for nwp_key, da_nwp in dataset_dict["nwp"].items():
             # Normalise
             provider = conf_nwp[nwp_key].nwp_provider
-            ds_nwp = (ds_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
+            da_nwp = (da_nwp - NWP_MEANS[provider]) / NWP_STDS[provider]
             # Convert to NumpyBatch
-            nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_batch(ds_nwp)
+            nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_batch(da_nwp)
 
         # Combine the NWPs into NumpyBatch
         numpy_modalities.append({BatchKey.nwp: nwp_numpy_modalities})
 
     if "sat" in dataset_dict:
         # Normalise
-        ds_sat = (dataset_dict["sat"] - RSS_MEAN) / RSS_STD
+        da_sat = (dataset_dict["sat"] - RSS_MEAN) / RSS_STD
         # Convert to NumpyBatch
-        numpy_modalities.append(convert_satellite_to_numpy_batch(ds_sat))
+        numpy_modalities.append(convert_satellite_to_numpy_batch(da_sat))
 
     # GSP always assumed to be in data
-    ds_gsp = concat_xr_time_utc([dataset_dict["gsp"], dataset_dict["gsp_future"]])
-    ds_gsp = normalize_gsp(ds_gsp)
-    numpy_modalities.append(convert_gsp_to_numpy_batch(ds_gsp))
+    da_gsp = concat_xr_time_utc([dataset_dict["gsp"], dataset_dict["gsp_future"]])
dataset_dict["gsp_future"]]) + da_gsp = normalize_gsp(da_gsp) + numpy_modalities.append(convert_gsp_to_numpy_batch(da_gsp)) # Combine all the modalities combined_sample = merge_dicts(numpy_modalities) @@ -428,8 +427,8 @@ def process_and_combine_datasets(dataset_dict: dict, config: Configuration): return combined_sample -def compute(d): - """Eagerly load a nested dictionary of xarray data""" +def compute(xarray_dict: dict) -> dict: + """Eagerly load a nested dictionary of xarray DataArrays""" for k, v in d.items(): if isinstance(v, dict): d[k] = compute(v) @@ -438,16 +437,16 @@ def compute(d): return d -def get_locations(gs_gsp: xr.Dataset) -> list[Location]: +def get_locations(gs_gsp: xr.DataArray) -> list[Location]: """Get list of locations of GSP""" locations = [] for gsp_id in gs_gsp.gsp_id.values: - ds_ = gs_gsp.sel(gsp_id=gsp_id) + da_ = gs_gsp.sel(gsp_id=gsp_id) locations.append( Location( coordinate_system = "osgb", - x=ds_.x_osgb.item(), - y=ds_.y_osgb.item(), + x=da_.x_osgb.item(), + y=da_.y_osgb.item(), id=gsp_id, ) ) @@ -483,6 +482,9 @@ def __init__( # Construct list of locations to sample from locations = get_locations(datasets_dict["gsp"]) + + # Construct a lookup for locations - useful for users to construct sample by GSP ID + location_lookup = {loc.id: loc for loc in locations} # Construct indices for sampling t_index, loc_index = np.meshgrid( @@ -490,11 +492,13 @@ def __init__( np.arange(len(locations)), ) + # Make array of all possible (t0, location) coordinates. Each row is a single coordinate index_pairs = np.stack((t_index.ravel(), loc_index.ravel())).T # Assign coords and indices to self self.valid_t0_times = valid_t0_times self.locations = locations + self.location_lookup = location_lookup self.index_pairs = index_pairs # Assign config and input data to self @@ -505,8 +509,8 @@ def __init__( def __len__(self): return len(self.index_pairs) - # TODO: Would this be better if we could pass in GSP ID int instead of Location object? - def get_sample(self, t0: pd.Timestamp, location: Location): + + def _get_sample(self, t0: pd.Timestamp, location: Location) -> NumpyBatch: """Generate the PVNet sample for given coordinates Args: @@ -530,7 +534,18 @@ def __getitem__(self, idx): t0 = self.valid_t0_times[t_index] # Generate the sample - return self.get_sample(t0, location) + return self._get_sample(t0, location) + + + def get_sample(self, t0: pd.Timestamp, gsp_id: int) -> NumpyBatch: + """Generate the PVNet sample for given coordinates + + Args: + t0: init-time for sample + gsp_id: GSP ID + """ + location = self.location_lookup[gsp_id] + return self._get_sample(t0, location) if __name__=="__main__": diff --git a/ocf_dataset_alpha/numpy_batch/add_sun_position.py b/ocf_dataset_alpha/numpy_batch/add_sun_position.py index 187f204..2939737 100644 --- a/ocf_dataset_alpha/numpy_batch/add_sun_position.py +++ b/ocf_dataset_alpha/numpy_batch/add_sun_position.py @@ -1,7 +1,6 @@ import pvlib import numpy as np -import warnings from ocf_datapipes.batch import BatchKey from ocf_datapipes.utils.consts import ( AZIMUTH_MEAN,