Refactor serialization via Huggingface Hub #75

Open · wants to merge 6 commits into main
17 changes: 0 additions & 17 deletions dgmr/common.py
@@ -291,7 +291,6 @@ def __init__(
output_channels: int = 768,
num_context_steps: int = 4,
conv_type: str = "standard",
**kwargs
):
"""
Conditioning Stack using the context images from Skillful Nowcasting, see https://arxiv.org/pdf/2104.00954.pdf
@@ -302,14 +301,6 @@ def __init__(
conv_type: Type of 2D convolution to use, see satflow/models/utils.py for options
"""
super().__init__()
config = locals()
config.pop("__class__")
config.pop("self")
self.config = kwargs.get("config", config)
input_channels = self.config["input_channels"]
output_channels = self.config["output_channels"]
num_context_steps = self.config["num_context_steps"]
conv_type = self.config["conv_type"]

conv2d = get_conv_layer(conv_type)
self.space2depth = PixelUnshuffle(downscale_factor=2)
@@ -417,7 +408,6 @@ def __init__(
shape: (int, int, int) = (8, 8, 8),
output_channels: int = 768,
use_attention: bool = True,
**kwargs
):
"""
Latent conditioning stack from Skillful Nowcasting, see https://arxiv.org/pdf/2104.00954.pdf
@@ -428,13 +418,6 @@ def __init__(
use_attention: Whether to have a self-attention block or not
"""
super().__init__()
config = locals()
config.pop("__class__")
config.pop("self")
self.config = kwargs.get("config", config)
shape = self.config["shape"]
output_channels = self.config["output_channels"]
use_attention = self.config["use_attention"]

self.shape = shape
self.use_attention = use_attention
28 changes: 8 additions & 20 deletions dgmr/dgmr.py
@@ -1,11 +1,11 @@
import pytorch_lightning as pl
import torch
import torchvision
from huggingface_hub import PyTorchModelHubMixin

from dgmr.common import ContextConditioningStack, LatentConditioningStack
from dgmr.discriminators import Discriminator
from dgmr.generators import Generator, Sampler
from dgmr.hub import NowcastingModelHubMixin
from dgmr.losses import (
GridCellLoss,
NowcastingLoss,
@@ -15,7 +15,13 @@
)


class DGMR(pl.LightningModule, NowcastingModelHubMixin):
class DGMR(
pl.LightningModule,
PyTorchModelHubMixin,
library_name="DGMR",
tags=["nowcasting", "forecasting", "timeseries", "remote-sensing", "gan"],
repo_url="https://github.com/openclimatefix/skillful_nowcasting",
):
"""Deep Generative Model of Radar"""

def __init__(
@@ -34,7 +40,6 @@ def __init__(
latent_channels: int = 768,
context_channels: int = 384,
generation_steps: int = 6,
**kwargs,
):
"""
Nowcasting GAN is an attempt to recreate DeepMind's Skillful Nowcasting GAN from https://arxiv.org/abs/2104.00954
@@ -59,23 +64,6 @@ def __init__(
pretrained:
"""
super().__init__()
config = locals()
config.pop("__class__")
config.pop("self")
self.config = kwargs.get("config", config)
input_channels = self.config["input_channels"]
forecast_steps = self.config["forecast_steps"]
output_shape = self.config["output_shape"]
gen_lr = self.config["gen_lr"]
disc_lr = self.config["disc_lr"]
conv_type = self.config["conv_type"]
num_samples = self.config["num_samples"]
grid_lambda = self.config["grid_lambda"]
beta1 = self.config["beta1"]
beta2 = self.config["beta2"]
latent_channels = self.config["latent_channels"]
context_channels = self.config["context_channels"]
visualize = self.config["visualize"]
self.gen_lr = gen_lr
self.disc_lr = disc_lr
self.beta1 = beta1
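With DGMR now inheriting from huggingface_hub's built-in PyTorchModelHubMixin, serialization goes through the standard mixin methods. A minimal sketch of the round trip, assuming DGMR is importable from the dgmr package and its default arguments are usable on the host; the local path and the commented-out repo id are placeholders:

from dgmr import DGMR

# Init arguments are captured by PyTorchModelHubMixin and written to config.json,
# which is why the **kwargs/config plumbing removed in this diff is no longer needed.
model = DGMR()

# Writes config.json and model.safetensors into the directory.
model.save_pretrained("./dgmr-checkpoint")

# Re-instantiates DGMR with the arguments stored in config.json.
restored = DGMR.from_pretrained("./dgmr-checkpoint")

# Optionally upload to the Hub (repo id is a placeholder):
# model.push_to_hub("org-name/dgmr-example")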
29 changes: 4 additions & 25 deletions dgmr/discriminators.py
@@ -13,16 +13,8 @@ def __init__(
input_channels: int = 12,
num_spatial_frames: int = 8,
conv_type: str = "standard",
**kwargs
):
super().__init__()
config = locals()
config.pop("__class__")
config.pop("self")
self.config = kwargs.get("config", config)
input_channels = self.config["input_channels"]
num_spatial_frames = self.config["num_spatial_frames"]
conv_type = self.config["conv_type"]

self.spatial_discriminator = SpatialDiscriminator(
input_channels=input_channels, num_timesteps=num_spatial_frames, conv_type=conv_type
@@ -40,7 +32,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

class TemporalDiscriminator(torch.nn.Module, PyTorchModelHubMixin):
def __init__(
self, input_channels: int = 12, num_layers: int = 3, conv_type: str = "standard", **kwargs
self,
input_channels: int = 12,
num_layers: int = 3,
conv_type: str = "standard",
):
"""
Temporal Discriminator from the Skillful Nowcasting, see https://arxiv.org/pdf/2104.00954.pdf
@@ -52,13 +47,6 @@ def __init__(
conv_type: Type of 2d convolutions to use, see satflow/models/utils.py for options
"""
super().__init__()
config = locals()
config.pop("__class__")
config.pop("self")
self.config = kwargs.get("config", config)
input_channels = self.config["input_channels"]
num_layers = self.config["num_layers"]
conv_type = self.config["conv_type"]

self.downsample = torch.nn.AvgPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))
self.space2depth = PixelUnshuffle(downscale_factor=2)
@@ -139,7 +127,6 @@ def __init__(
num_timesteps: int = 8,
num_layers: int = 4,
conv_type: str = "standard",
**kwargs
):
"""
Spatial discriminator from Skillful Nowcasting, see https://arxiv.org/pdf/2104.00954.pdf
@@ -151,14 +138,6 @@ def __init__(
conv_type: Type of 2d convolutions to use, see satflow/models/utils.py for options
"""
super().__init__()
config = locals()
config.pop("__class__")
config.pop("self")
self.config = kwargs.get("config", config)
input_channels = self.config["input_channels"]
num_timesteps = self.config["num_timesteps"]
num_layers = self.config["num_layers"]
conv_type = self.config["conv_type"]
# Randomly, uniformly, select 8 timesteps to do this on from the input
self.num_timesteps = num_timesteps
# First step is mean pooling 2x2 to reduce input by half
11 changes: 2 additions & 9 deletions dgmr/generators.py
@@ -22,7 +22,6 @@ def __init__(
latent_channels: int = 768,
context_channels: int = 384,
output_channels: int = 1,
**kwargs
):
"""
Sampler from the Skillful Nowcasting, see https://arxiv.org/pdf/2104.00954.pdf
@@ -35,14 +34,8 @@ def __init__(
latent_channels: Number of input channels to the lowest ConvGRU layer
"""
super().__init__()
config = locals()
config.pop("__class__")
config.pop("self")
self.config = kwargs.get("config", config)
self.forecast_steps = self.config["forecast_steps"]
latent_channels = self.config["latent_channels"]
context_channels = self.config["context_channels"]
output_channels = self.config["output_channels"]

self.forecast_steps = forecast_steps

self.convGRU1 = ConvGRU(
input_channels=latent_channels + context_channels,
160 changes: 0 additions & 160 deletions dgmr/hub.py

This file was deleted.
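The deleted dgmr/hub.py carried the custom NowcastingModelHubMixin. Recent huggingface_hub releases cover the same ground: PyTorchModelHubMixin records the JSON-serializable __init__ arguments of the subclass and round-trips them through config.json, storing weights as safetensors. A toy sketch of that behaviour (TinyNet is illustrative, not part of this repo):

import torch
from huggingface_hub import PyTorchModelHubMixin

class TinyNet(torch.nn.Module, PyTorchModelHubMixin):
    def __init__(self, hidden: int = 16):
        super().__init__()
        self.layer = torch.nn.Linear(hidden, hidden)

model = TinyNet(hidden=32)
model.save_pretrained("./tiny-net")               # config.json + model.safetensors
reloaded = TinyNet.from_pretrained("./tiny-net")  # rebuilt with hidden=32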

3 changes: 2 additions & 1 deletion requirements.txt
@@ -5,4 +5,5 @@ numpy
torchvision>=0.11.0
pytorch_lightning
einops
huggingface_hub==0.21.4
huggingface_hub>=0.23.3
safetensors
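safetensors is added here because PyTorchModelHubMixin stores weights as model.safetensors rather than a pickled pytorch_model.bin. If a checkpoint ever needs to be read without the mixin, something along these lines works (the path is a placeholder):

from safetensors.torch import load_file

state_dict = load_file("./dgmr-checkpoint/model.safetensors")  # plain dict of tensors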