From f5625787f3790d18d8dab6143e13a8bd6f36f4da Mon Sep 17 00:00:00 2001 From: Brandon Ly Date: Mon, 28 Nov 2022 12:42:36 -0600 Subject: [PATCH] Update gitignore and add constrant layer --- .gitignore | 1 + graph_weather/models/forecast.py | 3 -- graph_weather/models/layers/constraint.py | 55 +++++++++++++++++++++++ 3 files changed, 56 insertions(+), 3 deletions(-) create mode 100644 graph_weather/models/layers/constraint.py diff --git a/.gitignore b/.gitignore index e50a6610..3f97390d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ .idea/ *.nc +/venv/ \ No newline at end of file diff --git a/graph_weather/models/forecast.py b/graph_weather/models/forecast.py index 62024c97..89f9506e 100644 --- a/graph_weather/models/forecast.py +++ b/graph_weather/models/forecast.py @@ -2,7 +2,6 @@ import torch from typing import Optional from huggingface_hub import PyTorchModelHubMixin - from graph_weather.models import Decoder, Encoder, Processor @@ -98,10 +97,8 @@ def __init__( def forward(self, features: torch.Tensor) -> torch.Tensor: """ Compute the new state of the forecast - Args: features: The input features, aligned with the order of lat_lons_heights - Returns: The next state in the forecast """ diff --git a/graph_weather/models/layers/constraint.py b/graph_weather/models/layers/constraint.py new file mode 100644 index 00000000..34178e85 --- /dev/null +++ b/graph_weather/models/layers/constraint.py @@ -0,0 +1,55 @@ +""" + +""" +import torch + + +class Constraint(torch.nn.Module): + def __init__( + self, + lat_lons, + resolution: int = 2, + input_dim: int = 256, + output_dim: int = 78, + output_edge_dim: int = 256, + hidden_dim_processor_node: int = 256, + hidden_dim_processor_edge: int = 256, + hidden_layers_processor_node: int = 2, + hidden_layers_processor_edge: int = 2, + mlp_norm_type: str = "LayerNorm", + hidden_dim_decoder: int = 128, + hidden_layers_decoder: int = 2, + use_checkpointing: bool = False, + ): + super().__init__( + lat_lons, + resolution, + input_dim, + output_dim, + output_edge_dim, + hidden_dim_processor_node, + hidden_dim_processor_edge, + hidden_layers_processor_node, + hidden_layers_processor_edge, + mlp_norm_type, + hidden_dim_decoder, + hidden_layers_decoder, + use_checkpointing, + ) + + def forward( + self, processor_features: torch.Tensor, start_features: torch.Tensor + ) -> torch.Tensor: + """ + Constrains output from previous layer + + Args: + processor_features: Processed features in shape [B*Nodes, Features] + start_features: Original input features to the encoder, with shape [B, Nodes, Features] + + Returns: + Updated features for model + """ + out = super().forward(processor_features, start_features.shape[0]) + out = out + start_features # residual connection + return out