Use concurrent batch pipeline for ~30x speed up (#236)
* bug fix

* use concurrent datapipe

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* clean up

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update comment

* update comment

* save as tensor

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
dfulu and pre-commit-ci[bot] authored Oct 7, 2024
1 parent b5fa2d9 commit c93996c
Showing 2 changed files with 28 additions and 116 deletions.
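
In rough terms, the speed-up comes from replacing the per-GSP sample pipeline with the concurrent `pvnet_all_gsp` pipeline, which yields one ready-made batch covering all GSPs per init-time. A minimal before/after sketch, using only calls that appear in the diffs below (config_path, location_pipe and t0_datapipe are placeholders):

# Sketch only: before, one sample per (GSP, init-time) was built, then 317
# samples were re-batched and stacked for the model
from ocf_datapipes.batch import batch_to_tensor, stack_np_examples_into_batch
from ocf_datapipes.training.pvnet import construct_sliced_data_pipeline as per_gsp_pipeline

old_batches = (
    per_gsp_pipeline(config_path, location_pipe, t0_datapipe)
    .batch(317)
    .map(stack_np_examples_into_batch)
    .map(batch_to_tensor)
)

# After: the concurrent pipeline already yields an all-GSP batch per init-time,
# so the location datapipe, the re-batching and the stacking all disappear
from ocf_datapipes.training.pvnet_all_gsp import construct_sliced_data_pipeline

new_batches = construct_sliced_data_pipeline(config_path, t0_datapipe).map(batch_to_tensor)
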
73 changes: 19 additions & 54 deletions scripts/backtest_uk_gsp.py
@@ -38,15 +38,11 @@
NumpyBatch,
batch_to_tensor,
copy_batch_to_device,
stack_np_examples_into_batch,
)
from ocf_datapipes.config.load import load_yaml_configuration
from ocf_datapipes.load import OpenGSP
from ocf_datapipes.training.common import create_t0_and_loc_datapipes
from ocf_datapipes.training.pvnet import (
_get_datapipes_dict,
construct_sliced_data_pipeline,
)
from ocf_datapipes.training.common import _get_datapipes_dict
from ocf_datapipes.training.pvnet_all_gsp import construct_sliced_data_pipeline, create_t0_datapipe
from ocf_datapipes.utils.consts import ELEVATION_MEAN, ELEVATION_STD
from omegaconf import DictConfig

@@ -58,20 +54,19 @@
from tqdm import tqdm

from pvnet.load_model import get_model_from_checkpoints
from pvnet.utils import GSPLocationLookup

# ------------------------------------------------------------------
# USER CONFIGURED VARIABLES
output_dir = "/mnt/disks/backtest/test_backtest"
output_dir = "/mnt/disks/extra_batches/test_backtest"

# Local directory to load the PVNet checkpoint from. By default this should pull the best performing
# checkpoint on the val set
model_chckpoint_dir = "/home/jamesfulton/repos/PVNet/checkpoints/kqaknmuc"
model_chckpoint_dir = "/home/jamesfulton/repos/PVNet/checkpoints/q911tei5"

# Local directory to load the summation model checkpoint from. By default this should pull the best
# performing checkpoint on the val set. If set to None a simple sum is used instead
summation_chckpoint_dir = (
"/home/jamesfulton/repos/PVNet_summation/checkpoints/pvnet_summation/nw673nw2"
"/home/jamesfulton/repos/PVNet_summation/checkpoints/pvnet_summation/73oa4w9t"
)

# Forecasts will be made for all available init times between these
@@ -144,7 +139,7 @@ def get_available_t0_times(start_datetime, end_datetime, config_path):
# Pop out the config file
config = datapipes_dict.pop("config")

# We are going to abuse the `create_t0_and_loc_datapipes()` function to find the init-times in
# We are going to abuse the `create_t0_datapipe()` function to find the init-times in
# potential_init_times which we have input data for. To do this, we will feed in some fake GSP
# data which has the potential_init_times as timestamps. This is a bit hacky but works for now

@@ -172,18 +167,15 @@ def get_available_t0_times(start_datetime, end_datetime, config_path):
# Overwrite the GSP data which is already in the datapipes dict
datapipes_dict["gsp"] = IterableWrapper([ds_fake_gsp])

# Use create_t0_and_loc_datapipes to get datapipe of init-times
location_pipe, t0_datapipe = create_t0_and_loc_datapipes(
# Use create_t0_datapipe to get datapipe of init-times
t0_datapipe = create_t0_datapipe(
datapipes_dict,
configuration=config,
key_for_t0="gsp",
shuffle=False,
)

# Create a full list of available init-times. Note that we need to loop over the t0s AND
# locations to avoid the torch datapipes buffer overflow but we don't actually use the location
available_init_times = [t0 for _, t0 in zip(location_pipe, t0_datapipe)]
available_init_times = pd.to_datetime(available_init_times)
# Create a full list of available init-times
available_init_times = pd.to_datetime([t0 for t0 in t0_datapipe])

logger.info(
f"{len(available_init_times)} out of {len(potential_init_times)} "
@@ -193,22 +185,16 @@ def get_available_t0_times(start_datetime, end_datetime, config_path):
return available_init_times
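
As a reference for the trick described in the comments above, a hedged sketch: a throw-away GSP DataArray whose time axis is exactly potential_init_times is dropped into the datapipes dict, so create_t0_datapipe ends up yielding only the init-times for which every other input source has data. The DataArray layout is an assumption; only the overwrite and the create_t0_datapipe call mirror the diff.

# Hedged sketch; assumes numpy/pandas/xarray plus the pvnet_all_gsp import above.
import numpy as np
import pandas as pd
import xarray as xr
from ocf_datapipes.training.pvnet_all_gsp import create_t0_datapipe
from torch.utils.data.datapipes.iter import IterableWrapper


def filter_to_available_init_times(datapipes_dict, config, potential_init_times):
    # Fake GSP data: the values are irrelevant, only the time_utc coordinate matters
    ds_fake_gsp = xr.DataArray(
        np.zeros((len(potential_init_times), 1)),
        dims=["time_utc", "gsp_id"],
        coords={"time_utc": pd.DatetimeIndex(potential_init_times), "gsp_id": [1]},
    )
    datapipes_dict["gsp"] = IterableWrapper([ds_fake_gsp])

    # Yields only the potential init-times with full input coverage
    t0_datapipe = create_t0_datapipe(datapipes_dict, configuration=config, shuffle=False)
    return pd.to_datetime([t0 for t0 in t0_datapipe])
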


def get_loctimes_datapipes(config_path):
"""Create location and init-time datapipes
def get_times_datapipe(config_path):
"""Create init-time datapipe
Args:
config_path: Path to data config file
Returns:
tuple: A tuple of datapipes
- Datapipe yielding locations
- Datapipe yielding init-times
Datapipe: A Datapipe yielding init-times
"""

# Set up ID location query object
ds_gsp = get_gsp_ds(config_path)
gsp_id_to_loc = GSPLocationLookup(ds_gsp.x_osgb, ds_gsp.y_osgb)

# Filter the init-times to times we have all input data for
available_target_times = get_available_t0_times(
start_datetime,
@@ -222,25 +208,13 @@ def get_loctimes_datapipes(config_path):
# the backtest will end up producing
available_target_times.to_frame().to_csv(f"{output_dir}/t0_times.csv")

# Cycle the GSP locations
location_pipe = IterableWrapper([[gsp_id_to_loc(gsp_id) for gsp_id in ALL_GSP_IDS]]).repeat(
num_t0s
)

# Shard and then unbatch the locations so that each worker will generate all samples for all
# GSPs and for a single init-time
location_pipe = location_pipe.sharding_filter()
location_pipe = location_pipe.unbatch(unbatch_level=1)

# Create times datapipe so each worker receives 317 copies of the same datetime for its batch
t0_datapipe = IterableWrapper([[t0 for gsp_id in ALL_GSP_IDS] for t0 in available_target_times])
t0_datapipe = IterableWrapper(available_target_times)
t0_datapipe = t0_datapipe.sharding_filter()
t0_datapipe = t0_datapipe.unbatch(unbatch_level=1)

t0_datapipe = t0_datapipe.set_length(num_t0s * len(ALL_GSP_IDS))
location_pipe = location_pipe.set_length(num_t0s * len(ALL_GSP_IDS))
t0_datapipe = t0_datapipe.set_length(num_t0s)

return location_pipe, t0_datapipe
return t0_datapipe
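
A side note on how the datapipe above behaves under multiple workers: sharding_filter() gives each DataLoader worker a disjoint slice of the init-times, so every worker builds complete all-GSP batches for its own t0s. A standalone toy example (strings stand in for timestamps):

# Toy sketch of sharding_filter semantics, independent of ocf_datapipes.
from torch.utils.data import DataLoader
from torch.utils.data.datapipes.iter import IterableWrapper

t0s = ["t0_00", "t0_01", "t0_02", "t0_03", "t0_04", "t0_05"]
pipe = IterableWrapper(t0s).sharding_filter()

if __name__ == "__main__":
    # With num_workers=2 the elements are split round-robin between the workers;
    # batch_size=None passes each element straight through without collation.
    for t0 in DataLoader(pipe, batch_size=None, num_workers=2):
        print(t0)
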


class ModelPipe:
@@ -375,25 +349,16 @@ def get_datapipe(config_path: str) -> NumpyBatch:
"""

# Construct location and init-time datapipes
location_pipe, t0_datapipe = get_loctimes_datapipes(config_path)

# Get the number of init-times
num_batches = len(t0_datapipe) // len(ALL_GSP_IDS)
t0_datapipe = get_times_datapipe(config_path)

# Construct sample datapipes
data_pipeline = construct_sliced_data_pipeline(
config_path,
location_pipe,
t0_datapipe,
)

# Batch so that each worker returns a batch of all locations for a single init-time
# Also convert to tensor for model
data_pipeline = (
data_pipeline.batch(len(ALL_GSP_IDS)).map(stack_np_examples_into_batch).map(batch_to_tensor)
)

data_pipeline = data_pipeline.set_length(num_batches)
# Convert to tensor for model
data_pipeline = data_pipeline.map(batch_to_tensor).set_length(len(t0_datapipe))

return data_pipeline
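
A hedged sketch of how such a datapipe is typically consumed for inference follows; ModelPipe itself is collapsed in this diff, so the device handling and the predict step are assumptions rather than the script's actual code (config_path is a placeholder).

# Assumed consumption pattern; batch_size=None because each datapipe element is
# already a full all-GSP batch for one init-time.
import torch
from ocf_datapipes.batch import copy_batch_to_device
from torch.utils.data import DataLoader
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataloader = DataLoader(get_datapipe(config_path), batch_size=None, num_workers=2)

for batch in tqdm(dataloader):
    batch = copy_batch_to_device(batch, device)
    # ... run the PVNet and summation models on the batch here
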

71 changes: 9 additions & 62 deletions scripts/save_concurrent_batches.py
@@ -32,19 +32,17 @@
import hydra
import numpy as np
import torch
from ocf_datapipes.batch import BatchKey, batch_to_tensor, stack_np_examples_into_batch
from ocf_datapipes.training.common import (
open_and_return_datapipes,
from ocf_datapipes.batch import BatchKey, batch_to_tensor
from ocf_datapipes.training.pvnet_all_gsp import (
construct_sliced_data_pipeline,
construct_time_pipeline,
)
from ocf_datapipes.training.pvnet import construct_loctime_pipelines, construct_sliced_data_pipeline
from omegaconf import DictConfig, OmegaConf
from sqlalchemy import exc as sa_exc
from torch.utils.data import DataLoader
from torch.utils.data.datapipes.iter import IterableWrapper
from tqdm import tqdm

from pvnet.utils import GSPLocationLookup

warnings.filterwarnings("ignore", category=sa_exc.SAWarning)

logger = logging.getLogger(__name__)
@@ -61,73 +59,22 @@ def __call__(self, input):
torch.save(batch, f"{self.batch_dir}/{i:06}.pt")


def select_first(x):
"""Select zeroth element from indexable object"""
return x[0]


def _get_loctimes_datapipes(config_path, start_time, end_time, n_batches):
# Set up ID location query object
ds_gsp = next(
iter(
open_and_return_datapipes(
config_path,
use_gsp=True,
use_nwp=False,
use_pv=False,
use_sat=False,
use_hrv=False,
use_topo=False,
)["gsp"]
)
)
gsp_id_to_loc = GSPLocationLookup(ds_gsp.x_osgb, ds_gsp.y_osgb)

# Cycle the GSP locations
location_pipe = IterableWrapper([[gsp_id_to_loc(gsp_id) for gsp_id in range(1, 318)]]).repeat(
n_batches
)

# Shard and unbatch so each worker goes through GSP 1-317 for each batch
location_pipe = location_pipe.sharding_filter()
location_pipe = location_pipe.unbatch(unbatch_level=1)

# These two datapipes come from an earlier fork and must be iterated through together
# despite the fact that we don't want these random locations here
random_location_datapipe, t0_datapipe = construct_loctime_pipelines(
def _get_datapipe(config_path, start_time, end_time, n_batches):
t0_datapipe = construct_time_pipeline(
config_path,
start_time,
end_time,
)

# Iterate through both but select only time
t0_datapipe = t0_datapipe.zip(random_location_datapipe).map(select_first)

# Create times datapipe so we'll get the same time over each batch
t0_datapipe = t0_datapipe.header(n_batches)
t0_datapipe = IterableWrapper([[t0 for gsp_id in range(1, 318)] for t0 in t0_datapipe])
t0_datapipe = t0_datapipe.sharding_filter()
t0_datapipe = t0_datapipe.unbatch(unbatch_level=1)

return location_pipe, t0_datapipe


def _get_datapipe(config_path, start_time, end_time, n_batches):
# Open datasets from the config and filter to useable location-time pairs

location_pipe, t0_datapipe = _get_loctimes_datapipes(
config_path, start_time, end_time, n_batches
)

data_pipeline = construct_sliced_data_pipeline(
datapipe = construct_sliced_data_pipeline(
config_path,
location_pipe,
t0_datapipe,
)

data_pipeline = data_pipeline.batch(317).map(stack_np_examples_into_batch).map(batch_to_tensor)
).map(batch_to_tensor)

return data_pipeline
return datapipe
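
The collapsed _save_batches_with_dataloader below presumably drains this datapipe through a DataLoader and writes each batch to disk. Its real body is not shown in this diff, so the following is only an assumed sketch; dataloader_kwargs would typically include num_workers and batch_size=None, since the datapipe already yields whole batches.

# Assumed sketch, not the script's actual implementation.
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm


def save_batches(batch_pipe, batch_dir, num_batches, dataloader_kwargs):
    dataloader = DataLoader(batch_pipe, **dataloader_kwargs)
    for i, batch in enumerate(tqdm(dataloader, total=num_batches)):
        if i >= num_batches:
            break
        torch.save(batch, f"{batch_dir}/{i:06}.pt")
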


def _save_batches_with_dataloader(batch_pipe, batch_dir, num_batches, dataloader_kwargs):
