Skip to content

Commit

Permalink
【PPSCI Export&Infer No.35】nowcastnet (PaddlePaddle#895)
Browse files Browse the repository at this point in the history
* nowcastnet.py

* nowcastnet.py

* nowcastnet.py

* Update examples/nowcastnet/nowcastnet.py

Co-authored-by: HydrogenSulfate <[email protected]>

* Update examples/nowcastnet/nowcastnet.py

Co-authored-by: HydrogenSulfate <[email protected]>

* nowcastnet

* nowcastnet

---------

Co-authored-by: HydrogenSulfate <[email protected]>
  • Loading branch information
smallpoxscattered and HydrogenSulfate authored May 21, 2024
1 parent a1830a1 commit c57a97a
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 8 deletions.
12 changes: 12 additions & 0 deletions docs/zh/examples/nowcastnet.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,18 @@
python nowcastnet.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/nowcastnet/nowcastnet_pretrained.pdparams
```

=== "模型导出命令"

``` sh
python nowcastnet.py mode=export
```

=== "模型推理命令"

``` sh
python nowcastnet.py mode=infer
```

## 1. 背景简介

近年来,深度学习方法已被应用于天气预报,尤其是雷达观测的降水预报。这些方法利用大量雷达复合观测数据来训练神经网络模型,以端到端的方式进行训练,无需明确参考降水过程的物理定律。
Expand Down
20 changes: 20 additions & 0 deletions examples/nowcastnet/conf/nowcastnet.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ hydra:
- TRAIN.checkpoint_path
- TRAIN.pretrained_model_path
- EVAL.pretrained_model_path
- INFER.pretrained_model_path
- INFER.export_path
- mode
- output_dir
- log_freq
Expand All @@ -22,6 +24,7 @@ hydra:
# general settings
mode: eval # running mode: train/eval
seed: 42
log_freq: 20
output_dir: ${hydra:run.dir}
NORMAL_DATASET_PATH: datasets/mrms/figure
LARGE_DATASET_PATH: datasets/mrms/large_figure
Expand Down Expand Up @@ -55,3 +58,20 @@ MODEL:
# evaluation settings
EVAL:
pretrained_model_path: checkpoints/paddle_mrms_model

INFER:
pretrained_model_path: https://paddle-org.bj.bcebos.com/paddlescience/models/nowcastnet/nowcastnet_pretrained.pdparams
export_path: ./inference/nowcastnet
pdmodel_path: ${INFER.export_path}.pdmodel
pdiparams_path: ${INFER.export_path}.pdiparams
device: gpu
engine: native
precision: fp32
onnx_path: ${INFER.export_path}.onnx
ir_optim: true
min_subgraph_size: 10
gpu_mem: 4000
gpu_id: 0
max_batch_size: 16
num_cpu_threads: 4
batch_size: 1
100 changes: 99 additions & 1 deletion examples/nowcastnet/nowcastnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,112 @@ def evaluate(cfg: DictConfig):
solver.visualize(batch_id)


def export(cfg: DictConfig):
from paddle.static import InputSpec

# set models
if cfg.CASE_TYPE == "large":
model_cfg = cfg.MODEL.large
elif cfg.CASE_TYPE == "normal":
model_cfg = cfg.MODEL.normal
else:
raise ValueError(
f"cfg.CASE_TYPE should in ['normal', 'large'], but got '{cfg.mode}'"
)
model = ppsci.arch.NowcastNet(**model_cfg)

# load pretrained model
solver = ppsci.solver.Solver(
model=model, pretrained_model_path=cfg.INFER.pretrained_model_path
)
# export models
input_spec = [
{
key: InputSpec(
[None, 29, model_cfg.image_width, model_cfg.image_height, 2],
"float32",
name=key,
)
for key in model_cfg.input_keys
},
]
solver.export(input_spec, cfg.INFER.export_path)


def inference(cfg: DictConfig):
import os.path as osp

from deploy.python_infer import pinn_predictor

# set model predictor
predictor = pinn_predictor.PINNPredictor(cfg)

if cfg.CASE_TYPE == "large":
dataset_path = cfg.LARGE_DATASET_PATH
model_cfg = cfg.MODEL.large
output_dir = osp.join(cfg.output_dir, "large")
elif cfg.CASE_TYPE == "normal":
dataset_path = cfg.NORMAL_DATASET_PATH
model_cfg = cfg.MODEL.normal
output_dir = osp.join(cfg.output_dir, "normal")
else:
raise ValueError(
f"cfg.CASE_TYPE should in ['normal', 'large'], but got '{cfg.mode}'"
)

input_keys = ("radar_frames",)
dataset_param = {
"input_keys": input_keys,
"label_keys": (),
"image_width": model_cfg.image_width,
"image_height": model_cfg.image_height,
"total_length": model_cfg.total_length,
"dataset_path": dataset_path,
}
test_data_loader = paddle.io.DataLoader(
ppsci.data.dataset.RadarDataset(**dataset_param),
batch_size=cfg.INFER.batch_size,
num_workers=cfg.CPU_WORKER,
drop_last=True,
)
for batch_id, test_ims in enumerate(test_data_loader):
if batch_id > cfg.NUM_SAVE_SAMPLES:
break
test_ims = {"input": test_ims[0][input_keys[0]].numpy()}
output_dict = predictor.predict(test_ims, cfg.INFER.batch_size)
# mapping data to model_cfg.output_keys
output_dict = {
store_key: output_dict[infer_key]
for store_key, infer_key in zip(model_cfg.output_keys, output_dict.keys())
}

visualizer = ppsci.visualize.VisualizerRadar(
test_ims,
{
"output": lambda out: out["output"],
},
prefix="v_nowcastnet",
case_type=cfg.CASE_TYPE,
total_length=model_cfg.total_length,
)
test_ims.update(output_dict)
visualizer.save(osp.join(output_dir, f"epoch_{batch_id}"), test_ims)


@hydra.main(version_base=None, config_path="./conf", config_name="nowcastnet.yaml")
def main(cfg: DictConfig):
if cfg.mode == "train":
train(cfg)
elif cfg.mode == "eval":
evaluate(cfg)
elif cfg.mode == "export":
export(cfg)
elif cfg.mode == "infer":
inference(cfg)
else:
raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")
raise ValueError(
f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
)


if __name__ == "__main__":
Expand Down
15 changes: 8 additions & 7 deletions ppsci/arch/nowcastnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,12 @@ def forward_tensor(self, x):
# Generative Network
evo_feature = self.gen_enc(paddle.concat(x=[input_frames, evo_result], axis=1))
noise = paddle.randn(shape=[batch, self.ngf, height // 32, width // 32])
noise = self.proj(noise)
ngf = noise.shape[1]
noise_feature = (
self.proj(noise)
.reshape((batch, -1, 4, 4, 8, 8))
noise.reshape((batch, -1, 4, 4, 8, 8))
.transpose(perm=[0, 1, 4, 5, 2, 3])
.reshape((batch, -1, height // 8, width // 8))
.reshape((batch, ngf // 16, height // 8, width // 8))
)
feature = paddle.concat(x=[evo_feature, noise_feature], axis=1)
gen_result = self.gen_dec(feature, evo_result)
Expand Down Expand Up @@ -461,7 +462,7 @@ class Noise_Projector(paddle.nn.Layer):
def __init__(self, input_length):
super().__init__()
self.input_length = input_length
self.conv_first = spectral_norm(
self.conv_first = paddle.nn.utils.spectral_norm(
paddle.nn.Conv2D(
in_channels=self.input_length,
out_channels=self.input_length * 2,
Expand All @@ -486,7 +487,7 @@ def forward(self, x):
class ProjBlock(paddle.nn.Layer):
def __init__(self, in_channel, out_channel):
super().__init__()
self.one_conv = spectral_norm(
self.one_conv = paddle.nn.utils.spectral_norm(
paddle.nn.Conv2D(
in_channels=in_channel,
out_channels=out_channel - in_channel,
Expand All @@ -495,7 +496,7 @@ def __init__(self, in_channel, out_channel):
)
)
self.double_conv = paddle.nn.Sequential(
spectral_norm(
paddle.nn.utils.spectral_norm(
paddle.nn.Conv2D(
in_channels=in_channel,
out_channels=out_channel,
Expand All @@ -504,7 +505,7 @@ def __init__(self, in_channel, out_channel):
)
),
paddle.nn.ReLU(),
spectral_norm(
paddle.nn.utils.spectral_norm(
paddle.nn.Conv2D(
in_channels=out_channel,
out_channels=out_channel,
Expand Down

0 comments on commit c57a97a

Please sign in to comment.