Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

docs: add najeeb-kazmi as a contributor for question #53

Merged
merged 3 commits into from
Mar 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .all-contributorsrc
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,15 @@
"contributions": [
"question"
]
},
{
"login": "najeeb-kazmi",
"name": "Najeeb Kazmi",
"avatar_url": "https://avatars.githubusercontent.com/u/14131235?v=4",
"profile": "https://github.com/najeeb-kazmi",
"contributions": [
"question"
]
}
],
"contributorsPerLine": 7,
Expand Down
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Skillful Nowcasting with Deep Generative Model of Radar (DGMR)
<!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
[![All Contributors](https://img.shields.io/badge/all_contributors-7-orange.svg?style=flat-square)](#contributors-)
[![All Contributors](https://img.shields.io/badge/all_contributors-8-orange.svg?style=flat-square)](#contributors-)
<!-- ALL-CONTRIBUTORS-BADGE:END -->
Implementation of DeepMind's Skillful Nowcasting GAN Deep Generative Model of Radar (DGMR) (https://arxiv.org/abs/2104.00954) in PyTorch Lightning.

Expand Down Expand Up @@ -117,6 +117,9 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d
<td align="center" valign="top" width="14.28%"><a href="https://github.com/primeoc"><img src="https://avatars.githubusercontent.com/u/75205487?v=4?s=100" width="100px;" alt="cameron"/><br /><sub><b>cameron</b></sub></a><br /><a href="#question-primeoc" title="Answering Questions">💬</a></td>
<td align="center" valign="top" width="14.28%"><a href="https://github.com/zhrli"><img src="https://avatars.githubusercontent.com/u/11074703?v=4?s=100" width="100px;" alt="zhrli"/><br /><sub><b>zhrli</b></sub></a><br /><a href="#question-zhrli" title="Answering Questions">💬</a></td>
</tr>
<tr>
<td align="center" valign="top" width="14.28%"><a href="https://github.com/najeeb-kazmi"><img src="https://avatars.githubusercontent.com/u/14131235?v=4?s=100" width="100px;" alt="Najeeb Kazmi"/><br /><sub><b>Najeeb Kazmi</b></sub></a><br /><a href="#question-najeeb-kazmi" title="Answering Questions">💬</a></td>
</tr>
</tbody>
</table>

Expand Down
6 changes: 3 additions & 3 deletions dgmr/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .common import ContextConditioningStack, LatentConditioningStack
from .dgmr import DGMR
from .generators import Sampler, Generator
from .discriminators import SpatialDiscriminator, TemporalDiscriminator, Discriminator
from .common import LatentConditioningStack, ContextConditioningStack
from .discriminators import Discriminator, SpatialDiscriminator, TemporalDiscriminator
from .generators import Generator, Sampler
7 changes: 4 additions & 3 deletions dgmr/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
import einops
import torch
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin
from torch.distributions import normal
from torch.nn.utils.parametrizations import spectral_norm
from torch.nn.modules.pixelshuffle import PixelUnshuffle
from dgmr.layers.utils import get_conv_layer
from torch.nn.utils.parametrizations import spectral_norm

from dgmr.layers import AttentionLayer
from huggingface_hub import PyTorchModelHubMixin
from dgmr.layers.utils import get_conv_layer


class GBlock(torch.nn.Module):
Expand Down
17 changes: 9 additions & 8 deletions dgmr/dgmr.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import pytorch_lightning as pl
import torch
import torchvision

from dgmr.common import ContextConditioningStack, LatentConditioningStack
from dgmr.discriminators import Discriminator
from dgmr.generators import Generator, Sampler
from dgmr.hub import NowcastingModelHubMixin
from dgmr.losses import (
NowcastingLoss,
GridCellLoss,
NowcastingLoss,
grid_cell_regularizer,
loss_hinge_disc,
loss_hinge_gen,
grid_cell_regularizer,
)
import pytorch_lightning as pl
import torchvision
from dgmr.common import LatentConditioningStack, ContextConditioningStack
from dgmr.generators import Sampler, Generator
from dgmr.discriminators import Discriminator
from dgmr.hub import NowcastingModelHubMixin


class DGMR(pl.LightningModule, NowcastingModelHubMixin):
Expand Down
5 changes: 3 additions & 2 deletions dgmr/discriminators.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import torch
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin
from torch.nn.modules.pixelshuffle import PixelUnshuffle
from torch.nn.utils.parametrizations import spectral_norm
import torch.nn.functional as F

from dgmr.common import DBlock
from huggingface_hub import PyTorchModelHubMixin


class Discriminator(torch.nn.Module, PyTorchModelHubMixin):
Expand Down
9 changes: 6 additions & 3 deletions dgmr/generators.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import logging
from typing import List

import einops
import torch
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin
from torch.nn.modules.pixelshuffle import PixelShuffle
from torch.nn.utils.parametrizations import spectral_norm
from typing import List

from dgmr.common import GBlock, UpsampleGBlock
from dgmr.layers import ConvGRU
from huggingface_hub import PyTorchModelHubMixin
import logging

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARN)
Expand All @@ -27,6 +29,7 @@ def __init__(

The sampler takes the output from the Latent and Context conditioning stacks and
creates one stack of ConvGRU layers per future timestep.

Args:
forecast_steps: Number of forecast steps
latent_channels: Number of input channels to the lowest ConvGRU layer
Expand Down
1 change: 0 additions & 1 deletion dgmr/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

import torch


try:
from huggingface_hub import cached_download, hf_hub_url

Expand Down
2 changes: 1 addition & 1 deletion dgmr/layers/Attention.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import einops
import torch
import torch.nn as nn
from torch.nn import functional as F
import einops


def attention_einsum(q, k, v):
Expand Down
1 change: 1 addition & 0 deletions dgmr/layers/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch

from dgmr.layers import CoordConv


Expand Down
6 changes: 5 additions & 1 deletion dgmr/losses.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import numpy as np
import torch
import torch.nn as nn
from pytorch_msssim import SSIM, MS_SSIM
from pytorch_msssim import MS_SSIM, SSIM
from torch.nn import functional as F


class SSIMLoss(nn.Module):
def __init__(self, convert_range: bool = False, **kwargs):
"""
SSIM Loss, optionally converting input range from [-1,1] to [0,1]

Args:
convert_range:
**kwargs:
Expand All @@ -28,6 +29,7 @@ class MS_SSIMLoss(nn.Module):
def __init__(self, convert_range: bool = False, **kwargs):
"""
Multi-Scale SSIM Loss, optionally converting input range from [-1,1] to [0,1]

Args:
convert_range:
**kwargs:
Expand Down Expand Up @@ -78,6 +80,7 @@ def tv_loss(img, tv_weight):
Inputs:
- img: PyTorch Variable of shape (1, 3, H, W) holding an input image.
- tv_weight: Scalar giving the weight w_t to use for the TV loss.

Returns:
- loss: PyTorch Variable holding a scalar giving the total variation loss
for img weighted by tv_weight.
Expand Down Expand Up @@ -138,6 +141,7 @@ def forward(self, generated_images, targets):
"""
Calculates the grid cell regularizer value, assumes generated images are the mean predictions from
6 calls to the generater (Monte Carlo estimation of the expectations for the latent variable)

Args:
generated_images: Mean generated images from the generator
targets: Ground truth future frames
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from setuptools import setup, find_packages
from pathlib import Path

from setuptools import find_packages, setup

this_directory = Path(__file__).parent
install_requires = (this_directory / "requirements.txt").read_text().splitlines()
long_description = (this_directory / "README.md").read_text()
Expand Down
16 changes: 8 additions & 8 deletions train/run.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
import torch.utils.data.dataset
from dgmr import DGMR
import wandb
from datasets import load_dataset
from torch.utils.data import DataLoader
from pytorch_lightning import (
LightningDataModule,
)
from pytorch_lightning.callbacks import ModelCheckpoint
import wandb
from torch.utils.data import DataLoader

from dgmr import DGMR

wandb.init(project="dgmr")
from numpy.random import default_rng
import os
import numpy as np
from pathlib import Path
import tensorflow as tf

import numpy as np
from numpy.random import default_rng
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.loggers import LoggerCollection, WandbLogger
from pytorch_lightning.utilities import rank_zero_only
Expand Down Expand Up @@ -136,7 +136,7 @@ def __len__(self):
def __getitem__(self, item):
try:
row = next(self.iter_reader)
except Exception as e:
except Exception:
rng = default_rng()
self.iter_reader = iter(
self.reader.shuffle(seed=rng.integers(low=0, high=100000), buffer_size=1000)
Expand Down