Skip to content

Commit

Permalink
minimal updates
Browse files Browse the repository at this point in the history
  • Loading branch information
dfulu committed Jun 18, 2024
1 parent 085586c commit 944787f
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 6 deletions.
7 changes: 7 additions & 0 deletions pvnet_summation/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ def __init__(
)
else:
self.pvnet_output_shape = (317, self.pvnet_model.forecast_len)

self.use_weighted_loss = False

def predict_pvnet_batch(self, batch):
"""Use PVNet model to create predictions for batch"""
Expand Down Expand Up @@ -184,6 +186,11 @@ def validation_step(self, batch: dict, batch_idx):

losses = self._calculate_common_losses(y, y_hat)
losses.update(self._calculate_val_losses(y, y_hat))

# Store these to make horizon accuracy plot
self._horizon_maes.append(
{i: losses[f"MAE_horizon/step_{i:03}"].cpu().numpy() for i in range(self.forecast_len)}
)

logged_losses = {f"{k}/val": v for k, v in losses.items()}

Expand Down
2 changes: 2 additions & 0 deletions pvnet_summation/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ def train(config: DictConfig) -> Optional[float]:
for callback in callbacks:
log.info(f"{callback}")
if isinstance(callback, ModelCheckpoint):
# Need to call the .experiment property to initialise the logger
wandb_logger.experiment
callback.dirpath = "/".join(
callback.dirpath.split("/")[:-1] + [wandb_logger.version]
)
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
ocf_datapipes>=3.3.19
pvnet>=3.0.25
ocf_datapipes>=3.3.33
pvnet>=3.0.45
numpy
pandas
matplotlib
Expand Down
9 changes: 5 additions & 4 deletions scripts/checkpoint_to_huggingface.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Command line tool to push locally save model checkpoints to huggingface
use:
python checkpoint_to_huggingface.py "path/to/model/checkpoints" \
--local-path="~/tmp/this_model" \
Expand Down Expand Up @@ -56,9 +56,9 @@ def push_to_huggingface(
# Only one epoch (best) saved per model
files = glob.glob(f"{checkpoint_dir_path}/epoch*.ckpt")
assert len(files) == 1
checkpoint = torch.load(files[0])
checkpoint = torch.load(files[0], map_location="cpu")
else:
checkpoint = torch.load(f"{checkpoint_dir_path}/last.ckpt")
checkpoint = torch.load(f"{checkpoint_dir_path}/last.ckpt", map_location="cpu")

model.load_state_dict(state_dict=checkpoint["state_dict"])

Expand All @@ -72,7 +72,8 @@ def push_to_huggingface(
model.save_pretrained(
model_output_dir,
config=model_config,
wandb_model_code=wandb_id,
data_config=None,
wandb_ids=wandb_id,
push_to_hub=push_to_hub,
repo_id="openclimatefix/pvnet_v2_summation" if push_to_hub else None,
card_template_path=(
Expand Down

0 comments on commit 944787f

Please sign in to comment.