diff --git a/dgmr/common.py b/dgmr/common.py
index 2cf5e70..13a66f9 100644
--- a/dgmr/common.py
+++ b/dgmr/common.py
@@ -291,7 +291,6 @@ def __init__(
         output_channels: int = 768,
         num_context_steps: int = 4,
         conv_type: str = "standard",
-        **kwargs
     ):
         """
         Conditioning Stack using the context images from Skillful Nowcasting, , see https://arxiv.org/pdf/2104.00954.pdf
@@ -302,14 +301,6 @@ def __init__(
             conv_type: Type of 2D convolution to use, see satflow/models/utils.py for options
         """
         super().__init__()
-        config = locals()
-        config.pop("__class__")
-        config.pop("self")
-        self.config = kwargs.get("config", config)
-        input_channels = self.config["input_channels"]
-        output_channels = self.config["output_channels"]
-        num_context_steps = self.config["num_context_steps"]
-        conv_type = self.config["conv_type"]
 
         conv2d = get_conv_layer(conv_type)
         self.space2depth = PixelUnshuffle(downscale_factor=2)
@@ -417,7 +408,6 @@ def __init__(
         shape: (int, int, int) = (8, 8, 8),
         output_channels: int = 768,
         use_attention: bool = True,
-        **kwargs
     ):
         """
         Latent conditioning stack from Skillful Nowcasting, see https://arxiv.org/pdf/2104.00954.pdf
@@ -428,13 +418,6 @@ def __init__(
             use_attention: Whether to have a self-attention block or not
         """
         super().__init__()
-        config = locals()
-        config.pop("__class__")
-        config.pop("self")
-        self.config = kwargs.get("config", config)
-        shape = self.config["shape"]
-        output_channels = self.config["output_channels"]
-        use_attention = self.config["use_attention"]
 
         self.shape = shape
         self.use_attention = use_attention
diff --git a/dgmr/dgmr.py b/dgmr/dgmr.py
index d1a71af..def45e3 100644
--- a/dgmr/dgmr.py
+++ b/dgmr/dgmr.py
@@ -1,11 +1,11 @@
 import pytorch_lightning as pl
 import torch
 import torchvision
+from huggingface_hub import PyTorchModelHubMixin
 
 from dgmr.common import ContextConditioningStack, LatentConditioningStack
 from dgmr.discriminators import Discriminator
 from dgmr.generators import Generator, Sampler
-from dgmr.hub import NowcastingModelHubMixin
 from dgmr.losses import (
     GridCellLoss,
     NowcastingLoss,
@@ -15,7 +15,13 @@
 )
 
 
-class DGMR(pl.LightningModule, NowcastingModelHubMixin):
+class DGMR(
+    pl.LightningModule,
+    PyTorchModelHubMixin,
+    library_name="DGMR",
+    tags=["nowcasting", "forecasting", "timeseries", "remote-sensing", "gan"],
+    repo_url="https://github.com/openclimatefix/skillful_nowcasting",
+):
     """Deep Generative Model of Radar"""
 
     def __init__(
@@ -34,7 +40,6 @@ def __init__(
         latent_channels: int = 768,
         context_channels: int = 384,
         generation_steps: int = 6,
-        **kwargs,
     ):
         """
         Nowcasting GAN is an attempt to recreate DeepMind's Skillful Nowcasting GAN from https://arxiv.org/abs/2104.00954
@@ -59,23 +64,6 @@ def __init__(
             pretrained:
         """
         super().__init__()
-        config = locals()
-        config.pop("__class__")
-        config.pop("self")
-        self.config = kwargs.get("config", config)
-        input_channels = self.config["input_channels"]
-        forecast_steps = self.config["forecast_steps"]
-        output_shape = self.config["output_shape"]
-        gen_lr = self.config["gen_lr"]
-        disc_lr = self.config["disc_lr"]
-        conv_type = self.config["conv_type"]
-        num_samples = self.config["num_samples"]
-        grid_lambda = self.config["grid_lambda"]
-        beta1 = self.config["beta1"]
-        beta2 = self.config["beta2"]
-        latent_channels = self.config["latent_channels"]
-        context_channels = self.config["context_channels"]
-        visualize = self.config["visualize"]
         self.gen_lr = gen_lr
         self.disc_lr = disc_lr
         self.beta1 = beta1
diff --git a/dgmr/discriminators.py b/dgmr/discriminators.py
index a2711e3..cb18b3e 100644
--- a/dgmr/discriminators.py
+++ b/dgmr/discriminators.py
@@ -13,16 +13,8 @@ def __init__(
         input_channels: int = 12,
         num_spatial_frames: int = 8,
         conv_type: str = "standard",
-        **kwargs
     ):
         super().__init__()
-        config = locals()
-        config.pop("__class__")
-        config.pop("self")
-        self.config = kwargs.get("config", config)
-        input_channels = self.config["input_channels"]
-        num_spatial_frames = self.config["num_spatial_frames"]
-        conv_type = self.config["conv_type"]
 
         self.spatial_discriminator = SpatialDiscriminator(
             input_channels=input_channels, num_timesteps=num_spatial_frames, conv_type=conv_type
@@ -40,7 +32,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 class TemporalDiscriminator(torch.nn.Module, PyTorchModelHubMixin):
     def __init__(
-        self, input_channels: int = 12, num_layers: int = 3, conv_type: str = "standard", **kwargs
+        self,
+        input_channels: int = 12,
+        num_layers: int = 3,
+        conv_type: str = "standard",
     ):
         """
         Temporal Discriminator from the Skillful Nowcasting, see https://arxiv.org/pdf/2104.00954.pdf
@@ -52,13 +47,6 @@ def __init__(
             conv_type: Type of 2d convolutions to use, see satflow/models/utils.py for options
         """
         super().__init__()
-        config = locals()
-        config.pop("__class__")
-        config.pop("self")
-        self.config = kwargs.get("config", config)
-        input_channels = self.config["input_channels"]
-        num_layers = self.config["num_layers"]
-        conv_type = self.config["conv_type"]
 
         self.downsample = torch.nn.AvgPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))
         self.space2depth = PixelUnshuffle(downscale_factor=2)
@@ -139,7 +127,6 @@ def __init__(
         num_timesteps: int = 8,
         num_layers: int = 4,
         conv_type: str = "standard",
-        **kwargs
     ):
         """
         Spatial discriminator from Skillful Nowcasting, see https://arxiv.org/pdf/2104.00954.pdf
@@ -151,14 +138,6 @@ def __init__(
             conv_type: Type of 2d convolutions to use, see satflow/models/utils.py for options
         """
         super().__init__()
-        config = locals()
-        config.pop("__class__")
-        config.pop("self")
-        self.config = kwargs.get("config", config)
-        input_channels = self.config["input_channels"]
-        num_timesteps = self.config["num_timesteps"]
-        num_layers = self.config["num_layers"]
-        conv_type = self.config["conv_type"]
         # Randomly, uniformly, select 8 timesteps to do this on from the input
         self.num_timesteps = num_timesteps
         # First step is mean pooling 2x2 to reduce input by half
diff --git a/dgmr/generators.py b/dgmr/generators.py
index c4b744e..cc92a12 100644
--- a/dgmr/generators.py
+++ b/dgmr/generators.py
@@ -22,7 +22,6 @@ def __init__(
         latent_channels: int = 768,
         context_channels: int = 384,
         output_channels: int = 1,
-        **kwargs
     ):
         """
         Sampler from the Skillful Nowcasting, see https://arxiv.org/pdf/2104.00954.pdf
@@ -35,14 +34,8 @@ def __init__(
             latent_channels: Number of input channels to the lowest ConvGRU layer
         """
         super().__init__()
-        config = locals()
-        config.pop("__class__")
-        config.pop("self")
-        self.config = kwargs.get("config", config)
-        self.forecast_steps = self.config["forecast_steps"]
-        latent_channels = self.config["latent_channels"]
-        context_channels = self.config["context_channels"]
-        output_channels = self.config["output_channels"]
+
+        self.forecast_steps = forecast_steps
 
         self.convGRU1 = ConvGRU(
             input_channels=latent_channels + context_channels,
diff --git a/dgmr/hub.py b/dgmr/hub.py
deleted file mode 100644
index 821fbf2..0000000
--- a/dgmr/hub.py
+++ /dev/null
@@ -1,160 +0,0 @@
-"""
-Originally Taken from https://github.com/rwightman/
-
-https://github.com/rwightman/pytorch-image-models/
-blob/acd6c687fd1c0507128f0ce091829b233c8560b9/timm/models/hub.py
-"""
-
-import json
-import logging
-import os
-from functools import partial
-
-import torch
-
-try:
-    from huggingface_hub import cached_download, hf_hub_url
-
-    cached_download = partial(cached_download, library_name="dgmr")
-except ImportError:
-    hf_hub_url = None
-    cached_download = None
-
-from huggingface_hub import CONFIG_NAME, PYTORCH_WEIGHTS_NAME, ModelHubMixin, hf_hub_download
-
-MODEL_CARD_MARKDOWN = """---
-license: mit
-tags:
-- nowcasting
-- forecasting
-- timeseries
-- remote-sensing
-- gan
----
-
-# {model_name}
-
-## Model description
-
-[More information needed]
-
-## Intended uses & limitations
-
-[More information needed]
-
-## How to use
-
-[More information needed]
-
-## Limitations and bias
-
-[More information needed]
-
-## Training data
-
-[More information needed]
-
-## Training procedure
-
-[More information needed]
-
-## Evaluation results
-
-[More information needed]
-
-"""
-
-_logger = logging.getLogger(__name__)
-
-
-class NowcastingModelHubMixin(ModelHubMixin):
-    """
-    HuggingFace ModelHubMixin containing specific adaptions for Nowcasting models
-    """
-
-    def __init__(self, *args, **kwargs):
-        """
-        Mixin for pl.LightningModule and Hugging Face
-
-        Mix this class with your pl.LightningModule class to easily push / download
-        the model via the Hugging Face Hub
-
-        Example::
-
-            >>> from dgmr.hub import NowcastingModelHubMixin
-
-            >>> class MyModel(nn.Module, NowcastingModelHubMixin):
-            ...    def __init__(self, **kwargs):
-            ...        super().__init__()
-            ...        self.layer = ...
-            ...    def forward(self, ...)
-            ...        return ...
-
-            >>> model = MyModel()
-            >>> model.push_to_hub("mymodel") # Pushing model-weights to hf-hub
-
-            >>> # Downloading weights from hf-hub & model will be initialized from those weights
-            >>> model = MyModel.from_pretrained("username/mymodel")
-        """
-
-    def _create_model_card(self, path):
-        model_card = MODEL_CARD_MARKDOWN.format(model_name=type(self).__name__)
-        with open(os.path.join(path, "README.md"), "w") as f:
-            f.write(model_card)
-
-    def _save_config(self, module, save_directory):
-        config = dict(module.hparams)
-        path = os.path.join(save_directory, CONFIG_NAME)
-        with open(path, "w") as f:
-            json.dump(config, f)
-
-    def _save_pretrained(self, save_directory: str, save_config: bool = True):
-        # Save model weights
-        path = os.path.join(save_directory, PYTORCH_WEIGHTS_NAME)
-        model_to_save = self.module if hasattr(self, "module") else self
-        torch.save(model_to_save.state_dict(), path)
-        # Save model config
-        if save_config and model_to_save.hparams:
-            self._save_config(model_to_save, save_directory)
-        # Save model card
-        self._create_model_card(save_directory)
-
-    @classmethod
-    def _from_pretrained(
-        cls,
-        model_id,
-        revision,
-        cache_dir,
-        force_download,
-        proxies,
-        resume_download,
-        local_files_only,
-        use_auth_token=False,
-        map_location="cpu",
-        strict=False,
-        **model_kwargs,
-    ):
-        map_location = torch.device(map_location)
-
-        if os.path.isdir(model_id):
-            print("Loading weights from local directory")
-            model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
-        else:
-            model_file = hf_hub_download(
-                repo_id=model_id,
-                filename=PYTORCH_WEIGHTS_NAME,
-                revision=revision,
-                cache_dir=cache_dir,
-                force_download=force_download,
-                proxies=proxies,
-                resume_download=resume_download,
-                token=use_auth_token,
-                local_files_only=local_files_only,
-            )
-        model = cls(**model_kwargs["config"])
-
-        state_dict = torch.load(model_file, map_location=map_location)
-        model.load_state_dict(state_dict, strict=strict)
-        model.eval()
-
-        return model
diff --git a/requirements.txt b/requirements.txt
index bcf1180..fa713f0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,4 +5,5 @@ numpy
 torchvision>=0.11.0
 pytorch_lightning
 einops
-huggingface_hub==0.21.4
+huggingface_hub>=0.23.3
+safetensors
diff --git a/tests/test_model.py b/tests/test_model.py
index 34b8e3e..b7c956e 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -16,6 +16,14 @@
 import einops
 import pytest
 from pytorch_lightning import Trainer
+from torch.testing import assert_close
+
+
+def assert_model_equal(actual, expected):
+    assert actual.state_dict().keys() == expected.state_dict().keys()
+
+    for x, y in zip(actual.state_dict().values(), expected.state_dict().values()):
+        assert_close(x, y)
 
 
 def test_dblock():
@@ -328,3 +336,63 @@ def __getitem__(self, idx):
     model = DGMR(forecast_steps=forecast_steps)
 
     trainer.fit(model, train_loader, val_loader)
+
+
+def test_model_serialization(tmp_path):
+    model = DGMR(
+        forecast_steps=1,
+        input_channels=1,
+        output_shape=128,
+        gen_lr=1e-5,
+        disc_lr=1e-4,
+        visualize=True,
+        conv_type="standard",
+        num_samples=1,
+        grid_lambda=16.0,
+        beta1=1.0,
+        beta2=0.995,
+        latent_channels=512,
+        context_channels=256,
+        generation_steps=1,
+    )
+
+    model.save_pretrained(tmp_path / "dgmr")
+    model_copy = DGMR.from_pretrained(tmp_path / "dgmr")
+    assert model.hparams == model_copy.hparams
+    assert_model_equal(model, model_copy)
+
+
+def test_discriminator_serialization(tmp_path):
+    discriminator = Discriminator(input_channels=1, num_spatial_frames=1, conv_type="standard")
+
+    discriminator.save_pretrained(tmp_path / "discriminator")
+    discriminator_copy = Discriminator.from_pretrained(tmp_path / "discriminator")
+    assert_model_equal(discriminator, discriminator_copy)
+
+
+def test_sampler_serialization(tmp_path):
+    sampler = Sampler(
+        forecast_steps=1, latent_channels=256, context_channels=256, output_channels=1
+    )
+
+    sampler.save_pretrained(tmp_path / "sampler")
+    sampler_copy = Sampler.from_pretrained(tmp_path / "sampler")
+    assert_model_equal(sampler, sampler_copy)
+
+
+def test_context_conditioning_stack_serialization(tmp_path):
+    ctz = ContextConditioningStack(
+        input_channels=2, output_channels=256, num_context_steps=1, conv_type="standard"
+    )
+
+    ctz.save_pretrained(tmp_path / "context-conditioning-stack")
+    ctz_copy = ContextConditioningStack.from_pretrained(tmp_path / "context-conditioning-stack")
+    assert_model_equal(ctz, ctz_copy)
+
+
+def test_latent_conditioning_stack_serialization(tmp_path):
+    lat = LatentConditioningStack(shape=(4, 4, 4), output_channels=256, use_attention=True)
+
+    lat.save_pretrained(tmp_path / "latent-conditioning-stack")
+    lat_copy = LatentConditioningStack.from_pretrained(tmp_path / "latent-conditioning-stack")
+    assert_model_equal(lat, lat_copy)
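Usage sketch (not part of the patch): with PyTorchModelHubMixin mixed into DGMR, serialization goes through the standard huggingface_hub API, as the new tests above exercise. The checkpoint directory and Hub repo id below are placeholders.

    from dgmr import DGMR

    # Build a model and write config.json plus safetensors weights to a local directory
    model = DGMR(forecast_steps=4)
    model.save_pretrained("dgmr-checkpoint")

    # Reload; the init arguments are restored from the saved config
    restored = DGMR.from_pretrained("dgmr-checkpoint")

    # Optionally push weights and config to the Hugging Face Hub (placeholder repo id)
    # model.push_to_hub("your-username/dgmr")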