Skip to content

Commit

Permalink
Merge pull request #193 from openclimatefix/timezone_fix
Browse files Browse the repository at this point in the history
Test fixes - timezones
  • Loading branch information
dfulu authored May 10, 2023
2 parents 03175b1 + c82257f commit 21b148d
Show file tree
Hide file tree
Showing 11 changed files with 71 additions and 61 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
# Byte-compiled / optimized / DLL files
# Custom - ocf_datapipes
ocf_datapipes/utils/eso_metadata.csv

## Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
Expand Down
6 changes: 3 additions & 3 deletions ocf_datapipes/select/select_live_t0_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ def __iter__(self) -> pd.Timestamp:
for xr_data in self.source_datapipe:
# Get most recent time in data
# Select the history that goes back that far
latest_time_idx = pd.DatetimeIndex(xr_data[self.dim_name].values).get_loc(
pd.Timestamp.utcnow(), method="pad"
)
latest_time_idx = pd.DatetimeIndex(xr_data[self.dim_name].values).get_indexer(
[pd.Timestamp.now(tz=None)], method="pad"
)[0]
latest_time = xr_data[self.dim_name].values[latest_time_idx]
yield latest_time
26 changes: 14 additions & 12 deletions ocf_datapipes/select/select_loc_and_t0.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,7 @@

@functional_datapipe("select_loc_and_t0")
class LocationT0PickerIterDataPipe(IterDataPipe):
"""Datapipe to yield location-time pairs from the input data source.
Args:
source_datapipe: Datapipe emitting Xarray Dataset
return_all: Whether to return all t0-location pairs,
if True, also returns them in structured order
shuffle: If `return_all` sets whether the pairs are
shuffled before being returned.
x_dim_name: x dimension name, defaulted to 'x_osgb'
y_dim_name: y dimension name, defaulted to 'y_osgb'
time_dim_name: time dimension name, defaulted to 'time_utc'
"""
"""Datapipe to yield location-time pairs from the input data source."""

def __init__(
self,
Expand All @@ -36,6 +25,19 @@ def __init__(
y_dim_name: Optional[str] = "y_osgb",
time_dim_name: Optional[str] = "time_utc",
):
"""
Datapipe to yield location-time pairs from the input data source.
Args:
source_datapipe: Datapipe emitting Xarray Dataset
return_all: Whether to return all t0-location pairs,
if True, also returns them in structured order
shuffle: If `return_all` is set, whether the pairs are
shuffled before being returned.
x_dim_name: x dimension name, defaulted to 'x_osgb'
y_dim_name: y dimension name, defaulted to 'y_osgb'
time_dim_name: time dimension name, defaulted to 'time_utc'
"""
super().__init__()
self.source_datapipe = source_datapipe
self.return_all = return_all
Expand Down
6 changes: 3 additions & 3 deletions ocf_datapipes/training/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,9 +319,9 @@ def create_t0_and_loc_datapipes(
shuffle: bool = True,
):
"""
Takes datapipes and returns datapipes of appropriate locations and times for which samples can
be constructed from the input datapipe sources. The (location, t0) pairs are sampled without
replacement.
Takes source datapipes and returns datapipes of appropriate sample pairs of locations and times.
The (location, t0) pairs are sampled without replacement.
Args:
datapipes_dict: Dictionary of datapipes of input sources for which we want to select
Expand Down
28 changes: 15 additions & 13 deletions ocf_datapipes/transform/xarray/nwp_dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,7 @@

@functional_datapipe("convert_to_nwp_target_time_with_dropout")
class ConvertToNWPTargetTimeWithDropoutIterDataPipe(IterDataPipe):
"""Convert NWP Xarray dataset to use target time as indexer
Args:
source_datapipe: Datapipe emitting a Xarray Dataset with step and init_time_utc indexers.
t0_datapipe: Datapipe emitting t0 times for indexing off of choosing the closest previous
init_time_utc.
sample_period_duration: How long the sampling period is.
history_duration: How long the history time should cover.
forecast_duration: How long the forecast time should cover.
dropout_timedeltas: List of timedeltas. We randomly select the delay for each NWP forecast
from this list. These should be negative timedeltas w.r.t time t0.
dropout_frac: Fraction of samples subject to dropout
"""
"""Convert NWP Xarray dataset to use target time as indexer"""

def __init__(
self,
Expand All @@ -38,6 +26,20 @@ def __init__(
dropout_timedeltas: List[timedelta],
dropout_frac: Optional[float] = 1,
):
"""Convert NWP Xarray dataset to use target time as indexer
Args:
source_datapipe: Datapipe emitting an Xarray Dataset with step and init_time_utc
indexers.
t0_datapipe: Datapipe emitting t0 times, used to choose the closest
previous init_time_utc.
sample_period_duration: How long the sampling period is.
history_duration: How long the history time should cover.
forecast_duration: How long the forecast time should cover.
dropout_timedeltas: List of timedeltas. We randomly select the delay for each NWP
forecast from this list. These should be negative timedeltas w.r.t time t0.
dropout_frac: Fraction of samples subject to dropout
"""
self.source_datapipe = source_datapipe
self.t0_datapipe = t0_datapipe
self.sample_period_duration = sample_period_duration
Expand Down
30 changes: 16 additions & 14 deletions ocf_datapipes/transform/xarray/standard_dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,22 @@

@functional_datapipe("select_dropout_time")
class SelectDropoutTimeIterDataPipe(IterDataPipe):
"""Generates dropout times. The times are absolute values, not timedeltas.
Args:
source_datapipe: Datapipe of t0 times
dropout_timedeltas: List of timedeltas. We randomly select the delay for each time from this
list. These should be negative timedeltas w.r.t time t0.
dropout_frac: Fraction of samples subject to dropout
"""
"""Generates dropout times. The times are absolute values, not timedeltas."""

def __init__(
self,
source_datapipe: IterDataPipe,
dropout_timedeltas: List[timedelta],
dropout_frac: Optional[float] = 0,
):
"""Generates dropout times. The times are absolute values, not timedeltas.
Args:
source_datapipe: Datapipe of t0 times
dropout_timedeltas: List of timedeltas. We randomly select the delay for each time from
this list. These should be negative timedeltas w.r.t time t0.
dropout_frac: Fraction of samples subject to dropout
"""
self.source_datapipe = source_datapipe
self.dropout_timedeltas = dropout_timedeltas
self.dropout_frac = dropout_frac
Expand All @@ -51,18 +52,19 @@ def __iter__(self):

@functional_datapipe("apply_dropout_time")
class ApplyDropoutTimeIterDataPipe(IterDataPipe):
"""Masks an xarray object to replace values that come after the dropout time with NaN.
Args:
source_datapipe: Datapipe of Xarray objects
dropout_time_datapipe: Datapipe of dropout times
"""
"""Masks an xarray object to replace values that come after the dropout time with NaN."""

def __init__(
self,
source_datapipe: IterDataPipe,
dropout_time_datapipe: IterDataPipe,
):
"""Masks an xarray object to replace values that come after the dropout time with NaN.
Args:
source_datapipe: Datapipe of Xarray objects
dropout_time_datapipe: Datapipe of dropout times
"""
self.source_datapipe = source_datapipe
self.dropout_time_datapipe = dropout_time_datapipe

Expand Down
4 changes: 2 additions & 2 deletions tests/config/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ input_data:
pv_image_size_meters_height: 10000000
pv_image_size_meters_width: 10000000
n_pv_systems_per_example: 32
start_datetime: "2010-01-01 00:00:00+00:00"
end_datetime: "2030-01-01 00:00:00+00:00"
start_datetime: "2010-01-01 00:00:00"
end_datetime: "2030-01-01 00:00:00"
satellite:
satellite_channels:
- IR_016
Expand Down
18 changes: 9 additions & 9 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ def passiv_datapipe():
)

pv = PV(
start_datetime=datetime(2018, 1, 1, tzinfo=timezone.utc),
end_datetime=datetime(2023, 1, 1, tzinfo=timezone.utc),
start_datetime=datetime(2018, 1, 1),
end_datetime=datetime(2023, 1, 1),
)
pv_file = PVFiles(
pv_filename=str(filename),
Expand Down Expand Up @@ -110,8 +110,8 @@ def pvoutput_datapipe():
)

pv = PV(
start_datetime=datetime(2018, 1, 1, tzinfo=timezone.utc),
end_datetime=datetime(2023, 1, 1, tzinfo=timezone.utc),
start_datetime=datetime(2018, 1, 1),
end_datetime=datetime(2023, 1, 1),
)
pv_file = PVFiles(
pv_filename=str(filename),
Expand Down Expand Up @@ -213,7 +213,7 @@ def pv_yields_and_systems(db_session):
for hour in range(4, 10):
for minute in range(0, 60, 5):
pv_yield_1 = PVYield(
datetime_utc=datetime(2022, 1, 1, hour, minute, tzinfo=timezone.utc),
datetime_utc=datetime(2022, 1, 1, hour, minute),
solar_generation_kw=hour + minute / 100,
).to_orm()
pv_yield_1.pv_system = pv_system_sql_1
Expand All @@ -229,15 +229,15 @@ def pv_yields_and_systems(db_session):
# pv system with gaps every 5 mins
for minutes in [0, 10, 20, 30]:
pv_yield_4 = PVYield(
datetime_utc=datetime(2022, 1, 1, 4, tzinfo=timezone.utc) + timedelta(minutes=minutes),
datetime_utc=datetime(2022, 1, 1, 4) + timedelta(minutes=minutes),
solar_generation_kw=4,
).to_orm()
pv_yield_4.pv_system = pv_system_sql_2
pv_yield_sqls.append(pv_yield_4)

# add a system with only one pv yield
pv_yield_5 = PVYield(
datetime_utc=datetime(2022, 1, 1, 4, tzinfo=timezone.utc) + timedelta(minutes=minutes),
datetime_utc=datetime(2022, 1, 1, 4) + timedelta(minutes=minutes),
solar_generation_kw=4,
).to_orm()
pv_yield_5.pv_system = pv_system_sql_3
Expand Down Expand Up @@ -271,7 +271,7 @@ def gsp_yields(db_session):
for hour in range(0, 8):
for minute in range(0, 60, 30):
gsp_yield_1 = GSPYield(
datetime_utc=datetime(2022, 1, 1, hour, minute, tzinfo=timezone.utc),
datetime_utc=datetime(2022, 1, 1, hour, minute),
solar_generation_kw=hour + minute,
)
gsp_yield_1_sql = gsp_yield_1.to_orm()
Expand All @@ -298,7 +298,7 @@ def pv_parquet_file():
- generation_wh
"""

date = datetime(2022, 9, 1, tzinfo=timezone.utc)
date = datetime(2022, 9, 1)
ids = range(0, 10)
days = 7

Expand Down
4 changes: 2 additions & 2 deletions tests/load/pv/test_load_pv.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ def test_open_both_from_nc():

def test_load_parquet_file(pv_parquet_file):
pv = PV(
start_datetime=datetime(2018, 1, 1, tzinfo=timezone.utc),
end_datetime=datetime(2023, 1, 1, tzinfo=timezone.utc),
start_datetime=datetime(2018, 1, 1),
end_datetime=datetime(2023, 1, 1),
)
pv_file = PVFiles(
pv_filename=pv_parquet_file,
Expand Down
1 change: 1 addition & 0 deletions tests/training/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pyaml_env import parse_config

import pandas as pd
import numpy as np


def test_open_and_return_datapipes():
Expand Down
4 changes: 2 additions & 2 deletions tests/transform/xarray/gsp/test_remove_northern_gsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ def test_remove_northern_gsp_all(gsp_datapipe):


def test_remove_northern_gsp_some(gsp_datapipe):
northern_y_osgb_limit = 180000
northern_y_osgb_limit = 180_000

gsp_datapipe = RemoveNorthernGSP(gsp_datapipe, northern_y_osgb_limit=northern_y_osgb_limit)
data = next(iter(gsp_datapipe))

assert len(data.gsp_id) == 5
assert len(data.gsp_id) == 6
assert (data.y_osgb < northern_y_osgb_limit).all()

0 comments on commit 21b148d

Please sign in to comment.