diff --git a/.gitignore b/.gitignore index b1385cdf2..446a883ad 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,7 @@ -# Byte-compiled / optimized / DLL files +# Custom - ocf_datapipes +ocf_datapipes/utils/eso_metadata.csv + +## Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class diff --git a/ocf_datapipes/select/select_live_t0_time.py b/ocf_datapipes/select/select_live_t0_time.py index d10e5b96b..726d255a7 100644 --- a/ocf_datapipes/select/select_live_t0_time.py +++ b/ocf_datapipes/select/select_live_t0_time.py @@ -24,8 +24,8 @@ def __iter__(self) -> pd.Timestamp: for xr_data in self.source_datapipe: # Get most recent time in data # Select the history that goes back that far - latest_time_idx = pd.DatetimeIndex(xr_data[self.dim_name].values).get_loc( - pd.Timestamp.utcnow(), method="pad" - ) + latest_time_idx = pd.DatetimeIndex(xr_data[self.dim_name].values).get_indexer( + [pd.Timestamp.now(tz=None)], method="pad" + )[0] latest_time = xr_data[self.dim_name].values[latest_time_idx] yield latest_time diff --git a/ocf_datapipes/select/select_loc_and_t0.py b/ocf_datapipes/select/select_loc_and_t0.py index 379c9dac9..434ac0c55 100644 --- a/ocf_datapipes/select/select_loc_and_t0.py +++ b/ocf_datapipes/select/select_loc_and_t0.py @@ -14,18 +14,7 @@ @functional_datapipe("select_loc_and_t0") class LocationT0PickerIterDataPipe(IterDataPipe): - """Datapipe to yield location-time pairs from the input data source. - - Args: - source_datapipe: Datapipe emitting Xarray Dataset - return_all: Whether to return all t0-location pairs, - if True, also returns them in structured order - shuffle: If `return_all` sets whether the pairs are - shuffled before being returned. 
- x_dim_name: x dimension name, defaulted to 'x_osgb' - y_dim_name: y dimension name, defaulted to 'y_osgb' - time_dim_name: time dimension name, defaulted to 'time_utc' - """ + """Datapipe to yield location-time pairs from the input data source.""" def __init__( self, @@ -36,6 +25,19 @@ def __init__( y_dim_name: Optional[str] = "y_osgb", time_dim_name: Optional[str] = "time_utc", ): + """ + Datapipe to yield location-time pairs from the input data source. + + Args: + source_datapipe: Datapipe emitting Xarray Dataset + return_all: Whether to return all t0-location pairs, + if True, also returns them in structured order + shuffle: If `return_all` sets whether the pairs are + shuffled before being returned. + x_dim_name: x dimension name, defaulted to 'x_osgb' + y_dim_name: y dimension name, defaulted to 'y_osgb' + time_dim_name: time dimension name, defaulted to 'time_utc' + """ super().__init__() self.source_datapipe = source_datapipe self.return_all = return_all diff --git a/ocf_datapipes/training/common.py b/ocf_datapipes/training/common.py index ce8c8c4f2..d890f1231 100644 --- a/ocf_datapipes/training/common.py +++ b/ocf_datapipes/training/common.py @@ -319,9 +319,9 @@ def create_t0_and_loc_datapipes( shuffle: bool = True, ): """ - Takes datapipes and returns datapipes of appropriate locations and times for which samples can - be constructed from the the input datapipe sources. The (location, t0) pairs are sampled without - replacement. + Takes source datapipes and returns datapipes of appropriate sample pairs of locations and times. + + The (location, t0) pairs are sampled without replacement. 
Args: datapipes_dict: Dictionary of datapipes of input sources for which we want to select diff --git a/ocf_datapipes/transform/xarray/nwp_dropout.py b/ocf_datapipes/transform/xarray/nwp_dropout.py index cc13e5187..fd4740320 100644 --- a/ocf_datapipes/transform/xarray/nwp_dropout.py +++ b/ocf_datapipes/transform/xarray/nwp_dropout.py @@ -14,19 +14,7 @@ @functional_datapipe("convert_to_nwp_target_time_with_dropout") class ConvertToNWPTargetTimeWithDropoutIterDataPipe(IterDataPipe): - """Convert NWP Xarray dataset to use target time as indexer - - Args: - source_datapipe: Datapipe emitting a Xarray Dataset with step and init_time_utc indexers. - t0_datapipe: Datapipe emitting t0 times for indexing off of choosing the closest previous - init_time_utc. - sample_period_duration: How long the sampling period is. - history_duration: How long the history time should cover. - forecast_duration: How long the forecast time should cover. - dropout_timedeltas: List of timedeltas. We randonly select the delay for each NWP forecast - from this list. These should be negative timedeltas w.r.t time t0. - dropout_frac: Fraction of samples subject to dropout - """ + """Convert NWP Xarray dataset to use target time as indexer""" def __init__( self, @@ -38,6 +26,20 @@ def __init__( dropout_timedeltas: List[timedelta], dropout_frac: Optional[float] = 1, ): + """Convert NWP Xarray dataset to use target time as indexer + + Args: + source_datapipe: Datapipe emitting an Xarray Dataset with step and init_time_utc + indexers. + t0_datapipe: Datapipe emitting t0 times for indexing off of choosing the closest + previous init_time_utc. + sample_period_duration: How long the sampling period is. + history_duration: How long the history time should cover. + forecast_duration: How long the forecast time should cover. + dropout_timedeltas: List of timedeltas. We randomly select the delay for each NWP 
+ dropout_frac: Fraction of samples subject to dropout + """ self.source_datapipe = source_datapipe self.t0_datapipe = t0_datapipe self.sample_period_duration = sample_period_duration diff --git a/ocf_datapipes/transform/xarray/standard_dropout.py b/ocf_datapipes/transform/xarray/standard_dropout.py index 96be097e6..d65dde80e 100644 --- a/ocf_datapipes/transform/xarray/standard_dropout.py +++ b/ocf_datapipes/transform/xarray/standard_dropout.py @@ -11,14 +11,7 @@ @functional_datapipe("select_dropout_time") class SelectDropoutTimeIterDataPipe(IterDataPipe): - """Generates dropout times. The times are absolute values, not timedeltas. - - Args: - source_datapipe: Datapipe of t0 times - dropout_timedeltas: List of timedeltas. We randonly select the delay for each time from this - list. These should be negative timedeltas w.r.t time t0. - dropout_frac: Fraction of samples subject to dropout - """ + """Generates dropout times. The times are absolute values, not timedeltas.""" def __init__( self, @@ -26,6 +19,14 @@ def __init__( dropout_timedeltas: List[timedelta], dropout_frac: Optional[float] = 0, ): + """Generates dropout times. The times are absolute values, not timedeltas. + + Args: + source_datapipe: Datapipe of t0 times + dropout_timedeltas: List of timedeltas. We randonly select the delay for each time from + this list. These should be negative timedeltas w.r.t time t0. + dropout_frac: Fraction of samples subject to dropout + """ self.source_datapipe = source_datapipe self.dropout_timedeltas = dropout_timedeltas self.dropout_frac = dropout_frac @@ -51,18 +52,19 @@ def __iter__(self): @functional_datapipe("apply_dropout_time") class ApplyDropoutTimeIterDataPipe(IterDataPipe): - """Masks an xarray object to replace values that come after the dropout time with NaN. 
- - Args: - source_datapipe: Datapipe of Xarray objects - dropout_time_datapipe: Datapipe of dropout times - """ + """Masks an xarray object to replace values that come after the dropout time with NaN.""" def __init__( self, source_datapipe: IterDataPipe, dropout_time_datapipe: IterDataPipe, ): + """Masks an xarray object to replace values that come after the dropout time with NaN. + + Args: + source_datapipe: Datapipe of Xarray objects + dropout_time_datapipe: Datapipe of dropout times + """ self.source_datapipe = source_datapipe self.dropout_time_datapipe = dropout_time_datapipe diff --git a/tests/config/test.yaml b/tests/config/test.yaml index d4ae5799d..5589ff402 100644 --- a/tests/config/test.yaml +++ b/tests/config/test.yaml @@ -28,8 +28,8 @@ input_data: pv_image_size_meters_height: 10000000 pv_image_size_meters_width: 10000000 n_pv_systems_per_example: 32 - start_datetime: "2010-01-01 00:00:00+00:00" - end_datetime: "2030-01-01 00:00:00+00:00" + start_datetime: "2010-01-01 00:00:00" + end_datetime: "2030-01-01 00:00:00" satellite: satellite_channels: - IR_016 diff --git a/tests/conftest.py b/tests/conftest.py index 8093c9eeb..b3db2784d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -77,8 +77,8 @@ def passiv_datapipe(): ) pv = PV( - start_datetime=datetime(2018, 1, 1, tzinfo=timezone.utc), - end_datetime=datetime(2023, 1, 1, tzinfo=timezone.utc), + start_datetime=datetime(2018, 1, 1), + end_datetime=datetime(2023, 1, 1), ) pv_file = PVFiles( pv_filename=str(filename), @@ -110,8 +110,8 @@ def pvoutput_datapipe(): ) pv = PV( - start_datetime=datetime(2018, 1, 1, tzinfo=timezone.utc), - end_datetime=datetime(2023, 1, 1, tzinfo=timezone.utc), + start_datetime=datetime(2018, 1, 1), + end_datetime=datetime(2023, 1, 1), ) pv_file = PVFiles( pv_filename=str(filename), @@ -213,7 +213,7 @@ def pv_yields_and_systems(db_session): for hour in range(4, 10): for minute in range(0, 60, 5): pv_yield_1 = PVYield( - datetime_utc=datetime(2022, 1, 1, hour, minute, 
tzinfo=timezone.utc), + datetime_utc=datetime(2022, 1, 1, hour, minute), solar_generation_kw=hour + minute / 100, ).to_orm() pv_yield_1.pv_system = pv_system_sql_1 @@ -229,7 +229,7 @@ def pv_yields_and_systems(db_session): # pv system with gaps every 5 mins for minutes in [0, 10, 20, 30]: pv_yield_4 = PVYield( - datetime_utc=datetime(2022, 1, 1, 4, tzinfo=timezone.utc) + timedelta(minutes=minutes), + datetime_utc=datetime(2022, 1, 1, 4) + timedelta(minutes=minutes), solar_generation_kw=4, ).to_orm() pv_yield_4.pv_system = pv_system_sql_2 @@ -237,7 +237,7 @@ def pv_yields_and_systems(db_session): # add a system with only on pv yield pv_yield_5 = PVYield( - datetime_utc=datetime(2022, 1, 1, 4, tzinfo=timezone.utc) + timedelta(minutes=minutes), + datetime_utc=datetime(2022, 1, 1, 4) + timedelta(minutes=minutes), solar_generation_kw=4, ).to_orm() pv_yield_5.pv_system = pv_system_sql_3 @@ -271,7 +271,7 @@ def gsp_yields(db_session): for hour in range(0, 8): for minute in range(0, 60, 30): gsp_yield_1 = GSPYield( - datetime_utc=datetime(2022, 1, 1, hour, minute, tzinfo=timezone.utc), + datetime_utc=datetime(2022, 1, 1, hour, minute), solar_generation_kw=hour + minute, ) gsp_yield_1_sql = gsp_yield_1.to_orm() @@ -298,7 +298,7 @@ def pv_parquet_file(): - generation_wh """ - date = datetime(2022, 9, 1, tzinfo=timezone.utc) + date = datetime(2022, 9, 1) ids = range(0, 10) days = 7 diff --git a/tests/load/pv/test_load_pv.py b/tests/load/pv/test_load_pv.py index 2b8ea9df2..8d8f1dc8b 100644 --- a/tests/load/pv/test_load_pv.py +++ b/tests/load/pv/test_load_pv.py @@ -70,8 +70,8 @@ def test_open_both_from_nc(): def test_load_parquet_file(pv_parquet_file): pv = PV( - start_datetime=datetime(2018, 1, 1, tzinfo=timezone.utc), - end_datetime=datetime(2023, 1, 1, tzinfo=timezone.utc), + start_datetime=datetime(2018, 1, 1), + end_datetime=datetime(2023, 1, 1), ) pv_file = PVFiles( pv_filename=pv_parquet_file, diff --git a/tests/training/test_common.py b/tests/training/test_common.py 
index 6aaa4026b..9b53a0348 100644 --- a/tests/training/test_common.py +++ b/tests/training/test_common.py @@ -15,6 +15,7 @@ from pyaml_env import parse_config import pandas as pd +import numpy as np def test_open_and_return_datapipes(): diff --git a/tests/transform/xarray/gsp/test_remove_northern_gsp.py b/tests/transform/xarray/gsp/test_remove_northern_gsp.py index f38c56628..83d15565f 100644 --- a/tests/transform/xarray/gsp/test_remove_northern_gsp.py +++ b/tests/transform/xarray/gsp/test_remove_northern_gsp.py @@ -16,10 +16,10 @@ def test_remove_northern_gsp_all(gsp_datapipe): def test_remove_northern_gsp_some(gsp_datapipe): - northern_y_osgb_limit = 180000 + northern_y_osgb_limit = 180_000 gsp_datapipe = RemoveNorthernGSP(gsp_datapipe, northern_y_osgb_limit=northern_y_osgb_limit) data = next(iter(gsp_datapipe)) - assert len(data.gsp_id) == 5 + assert len(data.gsp_id) == 6 assert (data.y_osgb < northern_y_osgb_limit).all()