Skip to content

Commit

Permalink
【PPSCI Export&Infer No.15-16】 (#875)
Browse files Browse the repository at this point in the history
* fix doc bugs

* fix codestyle bugs

* 【PPSCI Export&Infer No.15-16】

* fix codestyle bug for PPSCI Export&Infer No.15-16】

* fix codestyle bugs for 【PPSCI Export&Infer No.15-16】

* fix codestyle bugs for 【PPSCI Export&Infer No.15-16】

* fix codestyle bugs for 【PPSCI Export&Infer No.15-16】

* fix bugs for 【PPSCI Export&Infer No.15-16】

* fix codestyle bugs
  • Loading branch information
wufei2 authored May 7, 2024
1 parent 41914da commit dbd9234
Show file tree
Hide file tree
Showing 6 changed files with 206 additions and 2 deletions.
12 changes: 12 additions & 0 deletions docs/zh/examples/ldc2d_steady.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,18 @@
python ldc2d_steady_Re10.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/ldc2d_steady_Re10/ldc2d_steady_Re10_pretrained.pdparams
```

=== "模型导出命令"

``` sh
python ldc2d_steady_Re10.py mode=export
```

=== "模型推理命令"

``` sh
python ldc2d_steady_Re10.py mode=infer
```

| 预训练模型 | 指标 |
|:--| :--|
| [ldc2d_steady_Re10_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/ldc2d_steady_Re10/ldc2d_steady_Re10_pretrained.pdparams) | loss(Residual): 365.36164<br>MSE.momentum_x(Residual): 0.01435<br>MSE.continuity(Residual): 0.04072<br>MSE.momentum_y(Residual): 0.02471 |
Expand Down
12 changes: 12 additions & 0 deletions docs/zh/examples/ldc2d_unsteady.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,18 @@
python ldc2d_unsteady_Re10.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/ldc2d_unsteady_Re10/ldc2d_unsteady_Re10_pretrained.pdparams
```

=== "模型导出命令"

``` sh
python ldc2d_unsteady_Re10.py mode=export
```

=== "模型推理命令"

``` sh
python ldc2d_unsteady_Re10.py mode=infer
```

| 预训练模型 | 指标 |
|:--| :--|
| [ldc2d_unsteady_Re10_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/ldc2d_unsteady_Re10/ldc2d_unsteady_Re10_pretrained.pdparams) | loss(Residual): 155652.67530<br>MSE.momentum_x(Residual): 6.78030<br>MSE.continuity(Residual): 0.16590<br>MSE.momentum_y(Residual): 12.05981 |
Expand Down
19 changes: 19 additions & 0 deletions examples/ldc/conf/ldc2d_steady_Re10.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ hydra:
mode: train # running mode: train/eval
seed: 42
output_dir: ${hydra:run.dir}
log_freq: 20

# set working condition
NU: 0.01
Expand Down Expand Up @@ -55,3 +56,21 @@ EVAL:
pretrained_model_path: null
batch_size:
residual_validator: 8192

# inference settings
INFER:
pretrained_model_path: https://paddle-org.bj.bcebos.com/paddlescience/models/ldc2d_steady_Re10/ldc2d_steady_Re10_pretrained.pdparams
export_path: ./inference/ldc2d_steady_Re10
pdmodel_path: ${INFER.export_path}.pdmodel
pdpiparams_path: ${INFER.export_path}.pdiparams
onnx_path: ${INFER.export_path}.onnx
device: gpu
engine: native
precision: fp32
ir_optim: true
min_subgraph_size: 5
gpu_mem: 2000
gpu_id: 0
max_batch_size: 8192
num_cpu_threads: 10
batch_size: 8192
19 changes: 19 additions & 0 deletions examples/ldc/conf/ldc2d_unsteady_Re10.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ hydra:
mode: train # running mode: train/eval
seed: 42
output_dir: ${hydra:run.dir}
log_freq: 20

# set working condition
NU: 0.01
Expand Down Expand Up @@ -56,3 +57,21 @@ EVAL:
pretrained_model_path: null
batch_size:
residual_validator: 8192

# inference settings
INFER:
pretrained_model_path: https://paddle-org.bj.bcebos.com/paddlescience/models/ldc2d_unsteady_Re10/ldc2d_unsteady_Re10_pretrained.pdparams
export_path: ./inference/ldc2d_unsteady_Re10
pdmodel_path: ${INFER.export_path}.pdmodel
pdiparams_path: ${INFER.export_path}.pdiparams
onnx_path: ${INFER.export_path}.onnx
device: gpu
engine: native
precision: fp32
ir_optim: true
min_subgraph_size: 5
gpu_mem: 2000
gpu_id: 0
max_batch_size: 8192
num_cpu_threads: 10
batch_size: 8192
59 changes: 58 additions & 1 deletion examples/ldc/ldc2d_steady_Re10.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,57 @@ def evaluate(cfg: DictConfig):
solver.visualize()


def export(cfg: DictConfig):
# set model
model = ppsci.arch.MLP(**cfg.MODEL)

# initialize solver
solver = ppsci.solver.Solver(
model,
pretrained_model_path=cfg.INFER.pretrained_model_path,
)
# export model
from paddle.static import InputSpec

input_spec = [
{key: InputSpec([None, 1], "float32", name=key) for key in model.input_keys},
]
solver.export(input_spec, cfg.INFER.export_path)


def inference(cfg: DictConfig):
from deploy.python_infer import pinn_predictor

predictor = pinn_predictor.PINNPredictor(cfg)

# set geometry
geom = {"rect": ppsci.geometry.Rectangle((-0.05, -0.05), (0.05, 0.05))}
# manually collate input data for inference
NPOINT_PDE = 99**2
NPOINT_TOP = 101
NPOINT_BOTTOM = 101
NPOINT_LEFT = 99
NPOINT_RIGHT = 99
NPOINT_BC = NPOINT_TOP + NPOINT_BOTTOM + NPOINT_LEFT + NPOINT_RIGHT
input_dict = geom["rect"].sample_interior(NPOINT_PDE + NPOINT_BC, evenly=True)
output_dict = predictor.predict(
{key: input_dict[key] for key in cfg.MODEL.input_keys}, cfg.INFER.batch_size
)

# mapping data to cfg.INFER.output_keys
output_dict = {
store_key: output_dict[infer_key]
for store_key, infer_key in zip(cfg.MODEL.output_keys, output_dict.keys())
}

ppsci.visualize.save_vtu_from_dict(
"./ldc2d_steady_Re10.vtu",
{**input_dict, **output_dict},
input_dict.keys(),
cfg.MODEL.output_keys,
)


@hydra.main(
version_base=None, config_path="./conf", config_name="ldc2d_steady_Re10.yaml"
)
Expand All @@ -242,8 +293,14 @@ def main(cfg: DictConfig):
train(cfg)
elif cfg.mode == "eval":
evaluate(cfg)
elif cfg.mode == "export":
export(cfg)
elif cfg.mode == "infer":
inference(cfg)
else:
raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")
raise ValueError(
f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
)


if __name__ == "__main__":
Expand Down
87 changes: 86 additions & 1 deletion examples/ldc/ldc2d_unsteady_Re10.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,85 @@ def evaluate(cfg: DictConfig):
solver.visualize()


def export(cfg: DictConfig):
# set model
model = ppsci.arch.MLP(**cfg.MODEL)

# initialize solver
solver = ppsci.solver.Solver(
model,
pretrained_model_path=cfg.INFER.pretrained_model_path,
)
# export model
from paddle.static import InputSpec

input_spec = [
{key: InputSpec([None, 1], "float32", name=key) for key in model.input_keys},
]
solver.export(input_spec, cfg.INFER.export_path)


def inference(cfg: DictConfig):
from deploy.python_infer import pinn_predictor

predictor = pinn_predictor.PINNPredictor(cfg)

# set timestamps(including initial t0)
timestamps = np.linspace(0.0, 1.5, cfg.NTIME_ALL, endpoint=True)
# set time-geometry
geom = {
"time_rect": ppsci.geometry.TimeXGeometry(
ppsci.geometry.TimeDomain(0.0, 1.5, timestamps=timestamps),
ppsci.geometry.Rectangle((-0.05, -0.05), (0.05, 0.05)),
)
}
# manually collate input data for inference
NPOINT_PDE = 99**2
NPOINT_TOP = 101
NPOINT_DOWN = 101
NPOINT_LEFT = 99
NPOINT_RIGHT = 99
NPOINT_IC = 99**2
NTIME_PDE = cfg.NTIME_ALL - 1
NPOINT_BC = NPOINT_TOP + NPOINT_DOWN + NPOINT_LEFT + NPOINT_RIGHT
input_dict = geom["time_rect"].sample_initial_interior(
(NPOINT_IC + NPOINT_BC), evenly=True
)
input_pde_dict = geom["time_rect"].sample_interior(
(NPOINT_PDE + NPOINT_BC) * NTIME_PDE, evenly=True
)
# (interior+boundary) x all timestamps
for t in range(NTIME_PDE):
for key in geom["time_rect"].dim_keys:
input_dict[key] = np.concatenate(
(
input_dict[key],
input_pde_dict[key][
t
* (NPOINT_PDE + NPOINT_BC) : (t + 1)
* (NPOINT_PDE + NPOINT_BC)
],
)
)
output_dict = predictor.predict(
{key: input_dict[key] for key in cfg.MODEL.input_keys}, cfg.INFER.batch_size
)

# mapping data to cfg.INFER.output_keys
output_dict = {
store_key: output_dict[infer_key]
for store_key, infer_key in zip(cfg.MODEL.output_keys, output_dict.keys())
}

ppsci.visualize.save_vtu_from_dict(
"./ldc2d_unsteady_Re10_pred.vtu",
{**input_dict, **output_dict},
input_dict.keys(),
cfg.MODEL.output_keys,
cfg.NTIME_ALL,
)


@hydra.main(
version_base=None, config_path="./conf", config_name="ldc2d_unsteady_Re10.yaml"
)
Expand All @@ -316,8 +395,14 @@ def main(cfg: DictConfig):
train(cfg)
elif cfg.mode == "eval":
evaluate(cfg)
elif cfg.mode == "export":
export(cfg)
elif cfg.mode == "infer":
inference(cfg)
else:
raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")
raise ValueError(
f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
)


if __name__ == "__main__":
Expand Down

0 comments on commit dbd9234

Please sign in to comment.