Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【PPSCI Export&Infer No.35】nowcastnet #895

Merged
merged 11 commits into from
May 21, 2024
12 changes: 12 additions & 0 deletions docs/zh/examples/nowcastnet.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,18 @@
python nowcastnet.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/nowcastnet/nowcastnet_pretrained.pdparams
```

=== "模型导出命令"

``` sh
python nowcastnet.py mode=export
```

=== "模型推理命令"

``` sh
python nowcastnet.py mode=infer
```

## 1. 背景简介

近年来,深度学习方法已被应用于天气预报,尤其是雷达观测的降水预报。这些方法利用大量雷达复合观测数据来训练神经网络模型,以端到端的方式进行训练,无需明确参考降水过程的物理定律。
Expand Down
19 changes: 19 additions & 0 deletions examples/nowcastnet/conf/nowcastnet.yaml
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

添加 log_freq 字段,否则infer报错

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的好的

Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ hydra:
- TRAIN.checkpoint_path
- TRAIN.pretrained_model_path
- EVAL.pretrained_model_path
- INFER.pretrained_model_path
- INFER.export_path
- mode
- output_dir
- log_freq
Expand Down Expand Up @@ -55,3 +57,20 @@ MODEL:
# evaluation settings
EVAL:
pretrained_model_path: checkpoints/paddle_mrms_model

INFER:
pretrained_model_path: https://paddle-org.bj.bcebos.com/paddlescience/models/nowcastnet/nowcastnet_pretrained.pdparams
export_path: ./inference/nowcastnet
pdmodel_path: ${INFER.export_path}.pdmodel
pdpiparams_path: ${INFER.export_path}.pdiparams
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

字段名称麻烦修改下

Suggested change
pdpiparams_path: ${INFER.export_path}.pdiparams
pdiparams_path: ${INFER.export_path}.pdiparams

device: gpu
engine: native
precision: fp32
onnx_path: ${INFER.export_path}.onnx
ir_optim: true
min_subgraph_size: 10
gpu_mem: 4000
gpu_id: 0
max_batch_size: 16
num_cpu_threads: 4
batch_size: 1
104 changes: 103 additions & 1 deletion examples/nowcastnet/nowcastnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,116 @@ def evaluate(cfg: DictConfig):
solver.visualize(batch_id)


def export(cfg: DictConfig):
from paddle.static import InputSpec

# set models
if cfg.CASE_TYPE == "large":
model_cfg = cfg.MODEL.large
elif cfg.CASE_TYPE == "normal":
model_cfg = cfg.MODEL.normal
else:
raise ValueError(
f"cfg.CASE_TYPE should in ['normal', 'large'], but got '{cfg.mode}'"
)
model = ppsci.arch.NowcastNet(**model_cfg)

# load pretrained model
solver = ppsci.solver.Solver(
model=model, pretrained_model_path=cfg.INFER.pretrained_model_path
)
# export models
input_spec = [
{
key: InputSpec(
[None, 29, model_cfg.image_width, model_cfg.image_height, 2],
"float32",
name=key,
)
for key in model_cfg.input_keys
},
]
solver.export(input_spec, cfg.INFER.export_path)


def inference(cfg: DictConfig):
import os.path as osp

import numpy as np

from deploy.python_infer import pinn_predictor

# set model predictor
predictor = pinn_predictor.PINNPredictor(cfg)

if cfg.CASE_TYPE == "large":
dataset_path = cfg.LARGE_DATASET_PATH
model_cfg = cfg.MODEL.large
output_dir = osp.join(cfg.output_dir, "large")
elif cfg.CASE_TYPE == "normal":
dataset_path = cfg.NORMAL_DATASET_PATH
model_cfg = cfg.MODEL.normal
output_dir = osp.join(cfg.output_dir, "normal")
else:
raise ValueError(
f"cfg.CASE_TYPE should in ['normal', 'large'], but got '{cfg.mode}'"
)

input_keys = ("radar_frames",)
dataset_param = {
"input_keys": input_keys,
"label_keys": (),
"image_width": model_cfg.image_width,
"image_height": model_cfg.image_height,
"total_length": model_cfg.total_length,
"dataset_path": dataset_path,
"data_type": np.float32(),
smallpoxscattered marked this conversation as resolved.
Show resolved Hide resolved
}
test_data_loader = paddle.io.DataLoader(
ppsci.data.dataset.RadarDataset(**dataset_param),
batch_size=cfg.INFER.batch_size,
shuffle=False,
smallpoxscattered marked this conversation as resolved.
Show resolved Hide resolved
num_workers=cfg.CPU_WORKER,
drop_last=True,
)
for batch_id, test_ims in enumerate(test_data_loader):
if batch_id > cfg.NUM_SAVE_SAMPLES:
break
test_ims = {"input": test_ims[0][input_keys[0]].numpy()}
output_dict = predictor.predict(test_ims, cfg.INFER.batch_size)
# mapping data to model_cfg.output_keys
output_dict = {
store_key: output_dict[infer_key]
for store_key, infer_key in zip(model_cfg.output_keys, output_dict.keys())
}

visualizer = ppsci.visualize.VisualizerRadar(
test_ims,
{
"output": lambda out: out["output"],
},
prefix="v_nowcastnet",
case_type=cfg.CASE_TYPE,
total_length=model_cfg.total_length,
)
test_ims.update(output_dict)
visualizer.save(osp.join(output_dir, f"epoch_{batch_id}"), test_ims)


@hydra.main(version_base=None, config_path="./conf", config_name="nowcastnet.yaml")
def main(cfg: DictConfig):
if cfg.mode == "train":
train(cfg)
elif cfg.mode == "eval":
evaluate(cfg)
elif cfg.mode == "export":
export(cfg)
elif cfg.mode == "infer":
inference(cfg)
else:
raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")
raise ValueError(
f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
)


if __name__ == "__main__":
Expand Down
15 changes: 8 additions & 7 deletions ppsci/arch/nowcastnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,12 @@ def forward_tensor(self, x):
# Generative Network
evo_feature = self.gen_enc(paddle.concat(x=[input_frames, evo_result], axis=1))
noise = paddle.randn(shape=[batch, self.ngf, height // 32, width // 32])
noise = self.proj(noise)
ngf = noise.shape[1]
noise_feature = (
self.proj(noise)
.reshape((batch, -1, 4, 4, 8, 8))
noise.reshape((batch, -1, 4, 4, 8, 8))
.transpose(perm=[0, 1, 4, 5, 2, 3])
.reshape((batch, -1, height // 8, width // 8))
.reshape((batch, ngf // 16, height // 8, width // 8))
)
feature = paddle.concat(x=[evo_feature, noise_feature], axis=1)
gen_result = self.gen_dec(feature, evo_result)
Expand Down Expand Up @@ -461,7 +462,7 @@ class Noise_Projector(paddle.nn.Layer):
def __init__(self, input_length):
super().__init__()
self.input_length = input_length
self.conv_first = spectral_norm(
self.conv_first = paddle.nn.utils.spectral_norm(
paddle.nn.Conv2D(
in_channels=self.input_length,
out_channels=self.input_length * 2,
Expand All @@ -486,7 +487,7 @@ def forward(self, x):
class ProjBlock(paddle.nn.Layer):
def __init__(self, in_channel, out_channel):
super().__init__()
self.one_conv = spectral_norm(
self.one_conv = paddle.nn.utils.spectral_norm(
paddle.nn.Conv2D(
in_channels=in_channel,
out_channels=out_channel - in_channel,
Expand All @@ -495,7 +496,7 @@ def __init__(self, in_channel, out_channel):
)
)
self.double_conv = paddle.nn.Sequential(
spectral_norm(
paddle.nn.utils.spectral_norm(
paddle.nn.Conv2D(
in_channels=in_channel,
out_channels=out_channel,
Expand All @@ -504,7 +505,7 @@ def __init__(self, in_channel, out_channel):
)
),
paddle.nn.ReLU(),
spectral_norm(
paddle.nn.utils.spectral_norm(
paddle.nn.Conv2D(
in_channels=out_channel,
out_channels=out_channel,
Expand Down