Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Save batch #141

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ This example runs the application and writes the results to stdout
DB_URL={DB_URL} NWP_ZARR_PATH={NWP_ZARR_PATH} poetry run app
```

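To save batches, set the `SAVE_BATCHES_DIR` environment variable to the directory where the batches should be written, for example:
```bash
SAVE_BATCHES_DIR={SAVE_BATCHES_DIR} DB_URL={DB_URL} NWP_ZARR_PATH={NWP_ZARR_PATH} poetry run app
```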

### Starting a local database using docker

```bash
Expand Down
4 changes: 4 additions & 0 deletions india_forecast_app/models/pvnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
download_satellite_data,
populate_data_config_sources,
process_and_cache_nwp,
save_batch,
set_night_time_zeros,
worker_init_fn,
)
Expand Down Expand Up @@ -93,6 +94,9 @@ def predict(self, site_id: str, timestamp: dt.datetime):
for i, batch in enumerate(self.dataloader):
log.info(f"Predicting for batch: {i}")

# save batch
save_batch(batch=batch, i=i, model_name=self.name, site_uuid=self.site_uuid)

# Run batch through model
device_batch = copy_batch_to_device(batch_to_tensor(batch), DEVICE)
preds = self.model(device_batch).detach().cpu().numpy()
Expand Down
35 changes: 32 additions & 3 deletions india_forecast_app/models/pvnet/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""Useful functions for setting up PVNet model"""
import logging
import os
from typing import Optional

import fsspec
import numpy as np
import torch
import xarray as xr
import yaml
from ocf_datapipes.batch import BatchKey
Expand Down Expand Up @@ -53,7 +55,7 @@ def populate_data_config_sources(input_path, output_path):
"wind": {"filename": wind_netcdf_path, "metadata_filename": wind_metadata_path},
"pv": {"filename": pv_netcdf_path, "metadata_filename": pv_metadata_path},
"nwp": {"ecmwf": nwp_ecmwf_path, "gfs": nwp_gfs_path, "mo_global": nwp_mo_global_path},
"satellite": {"filepath": satellite_path}
"satellite": {"filepath": satellite_path},
}

if "nwp" in config["input_data"]:
Expand Down Expand Up @@ -93,8 +95,9 @@ def populate_data_config_sources(input_path, output_path):
def process_and_cache_nwp(source_nwp_path: str, dest_nwp_path: str):
"""Reads zarr file, renames t variable to t2m and saves zarr to new destination"""

log.info(f"Processing and caching NWP data for {source_nwp_path}, "
f"and saving to {dest_nwp_path}")
log.info(
f"Processing and caching NWP data for {source_nwp_path}, " f"and saving to {dest_nwp_path}"
)

if os.path.exists(dest_nwp_path):
log.info(f"File already exists at {dest_nwp_path}")
Expand Down Expand Up @@ -191,3 +194,29 @@ def set_night_time_zeros(batch, preds, sun_elevation_limit=0.0):
preds[sun_elevation < sun_elevation_limit] = 0

return preds


def save_batch(batch, i: int, model_name, site_uuid, save_batches_dir: Optional[str] = None):
    """
    Save batch to SAVE_BATCHES_DIR if set

    The batch is serialised with ``torch.save`` into a temporary local file, which is
    then copied into the target directory using the fsspec filesystem resolved from
    ``save_batches_dir``. The temporary file is removed afterwards, so nothing is left
    behind in the current working directory.

    Args:
        batch: The batch to save
        i: The index of the batch
        model_name: The name of the model
        site_uuid: The site_uuid of the site
        save_batches_dir: The directory to save the batch to,
            defaults to environment variable SAVE_BATCHES_DIR.
            If neither is set (or the value is empty), nothing is saved.
    """
    # Function-scope import: only this function needs it
    import tempfile

    if save_batches_dir is None:
        save_batches_dir = os.getenv("SAVE_BATCHES_DIR", None)

    # An unset or empty directory means batch saving is disabled
    if not save_batches_dir:
        return

    log.info(f"Saving batch {i} to {save_batches_dir}")

    filename = f"batch_{i}_{model_name}_{site_uuid}.pt"
    # Strip any trailing slash so the destination never contains "//"
    destination = f"{save_batches_dir.rstrip('/')}/{filename}"

    # Stage the file in a temporary directory rather than the working directory,
    # so the local copy is cleaned up even if the upload fails
    with tempfile.TemporaryDirectory() as staging_dir:
        local_path = os.path.join(staging_dir, filename)
        torch.save(batch, local_path)

        fs = fsspec.open(save_batches_dir).fs
        fs.put(local_path, destination)
23 changes: 21 additions & 2 deletions tests/models/pvnet/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
""" Tests for utils for pvnet"""
import os
import tempfile

import numpy as np
from ocf_datapipes.batch import BatchKey

from india_forecast_app.models.pvnet.utils import set_night_time_zeros
from india_forecast_app.models.pvnet.utils import save_batch, set_night_time_zeros


def test_set_night_time_zeros():
""" Test for setting night time zeros"""
"""Test for setting night time zeros"""
# set up preds (1,5,7) {example, time, plevels}
preds = np.random.rand(1, 5, 7)

Expand All @@ -26,3 +29,19 @@ def test_set_night_time_zeros():
assert np.all(preds[:, 2:, :] == 0)
# check that all values are positive
assert np.all(preds[:, :2, :] > 0)


def test_save_batch():
    """Check that save_batch writes a batch file into the requested directory"""

    # minimal stand-in for a real batch
    i = 1
    model_name = "test_model_name"
    dummy_batch = {"key": "value"}

    # use a throwaway folder as the save location
    with tempfile.TemporaryDirectory() as temp_dir:
        expected_file = f"{temp_dir}/batch_{i}_{model_name}_fff-fff.pt"

        save_batch(
            dummy_batch,
            i,
            model_name,
            site_uuid="fff-fff",
            save_batches_dir=temp_dir,
        )

        # the file should now exist under the expected name
        assert os.path.exists(expected_file)
Loading