diff --git a/docs/zh/examples/bracket.md b/docs/zh/examples/bracket.md
index 14c6bba14..4db059f73 100644
--- a/docs/zh/examples/bracket.md
+++ b/docs/zh/examples/bracket.md
@@ -26,6 +26,24 @@
python bracket.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/bracket/bracket_pretrained.pdparams
```
+=== "模型导出命令"
+
+ ``` sh
+ python bracket.py mode=export
+ ```
+
+=== "模型推理命令"
+
+ ``` sh
+ # linux
+ wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/bracket/bracket_dataset.tar
+ # windows
+ # curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/bracket/bracket_dataset.tar --output bracket_dataset.tar
+ # unzip it
+ tar -xvf bracket_dataset.tar
+ python bracket.py mode=infer
+ ```
+
| 预训练模型 | 指标 |
|:--| :--|
| [bracket_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/bracket/bracket_pretrained.pdparams) | loss(commercial_ref_u_v_w_sigmas): 32.28704
MSE.u(commercial_ref_u_v_w_sigmas): 0.00005
MSE.v(commercial_ref_u_v_w_sigmas): 0.00000
MSE.w(commercial_ref_u_v_w_sigmas): 0.00734
MSE.sigma_xx(commercial_ref_u_v_w_sigmas): 27.64751
MSE.sigma_yy(commercial_ref_u_v_w_sigmas): 1.23101
MSE.sigma_zz(commercial_ref_u_v_w_sigmas): 0.89106
MSE.sigma_xy(commercial_ref_u_v_w_sigmas): 0.84370
MSE.sigma_xz(commercial_ref_u_v_w_sigmas): 1.42126
MSE.sigma_yz(commercial_ref_u_v_w_sigmas): 0.24510 |
diff --git a/examples/bracket/bracket.py b/examples/bracket/bracket.py
index 381e63ce8..5d85b535a 100644
--- a/examples/bracket/bracket.py
+++ b/examples/bracket/bracket.py
@@ -514,12 +514,75 @@ def evaluate(cfg: DictConfig):
solver.visualize()
+def export(cfg: DictConfig):
+ # set model
+ disp_net = ppsci.arch.MLP(**cfg.MODEL.disp_net)
+ stress_net = ppsci.arch.MLP(**cfg.MODEL.stress_net)
+ # wrap to a model_list
+ model = ppsci.arch.ModelList((disp_net, stress_net))
+
+ # initialize solver
+ solver = ppsci.solver.Solver(
+ model,
+ pretrained_model_path=cfg.INFER.pretrained_model_path,
+ )
+
+ # export model
+ from paddle.static import InputSpec
+
+ input_spec = [
+ {key: InputSpec([None, 1], "float32", name=key) for key in model.input_keys},
+ ]
+ solver.export(input_spec, cfg.INFER.export_path)
+
+
+def inference(cfg: DictConfig):
+ from deploy.python_infer import pinn_predictor
+
+ predictor = pinn_predictor.PINNPredictor(cfg)
+ ref_xyzu = ppsci.utils.reader.load_csv_file(
+ cfg.DEFORMATION_X_PATH,
+ ("x", "y", "z", "u"),
+ {
+ "x": "X Location (m)",
+ "y": "Y Location (m)",
+ "z": "Z Location (m)",
+ "u": "Directional Deformation (m)",
+ },
+ "\t",
+ )
+ input_dict = {
+ "x": ref_xyzu["x"],
+ "y": ref_xyzu["y"],
+ "z": ref_xyzu["z"],
+ }
+ output_dict = predictor.predict(input_dict, cfg.INFER.batch_size)
+
+ # mapping data to cfg.INFER.output_keys
+ output_keys = cfg.MODEL.disp_net.output_keys + cfg.MODEL.stress_net.output_keys
+ output_dict = {
+ store_key: output_dict[infer_key]
+ for store_key, infer_key in zip(output_keys, output_dict.keys())
+ }
+
+ ppsci.visualize.save_vtu_from_dict(
+ "./bracket_pred",
+ {**input_dict, **output_dict},
+ input_dict.keys(),
+ output_keys,
+ )
+
+
@hydra.main(version_base=None, config_path="./conf", config_name="bracket.yaml")
def main(cfg: DictConfig):
if cfg.mode == "train":
train(cfg)
elif cfg.mode == "eval":
evaluate(cfg)
+ elif cfg.mode == "export":
+ export(cfg)
+ elif cfg.mode == "infer":
+ inference(cfg)
else:
raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")
diff --git a/examples/bracket/conf/bracket.yaml b/examples/bracket/conf/bracket.yaml
index 3531ec03d..5e4682b4f 100644
--- a/examples/bracket/conf/bracket.yaml
+++ b/examples/bracket/conf/bracket.yaml
@@ -102,3 +102,21 @@ EVAL:
eval_with_no_grad: true
batch_size:
sup_validator: 128
+
+# inference settings
+INFER:
+ pretrained_model_path: "https://paddle-org.bj.bcebos.com/paddlescience/models/bracket/bracket_pretrained.pdparams"
+ export_path: ./inference/bracket
+ pdmodel_path: ${INFER.export_path}.pdmodel
+ pdiparams_path: ${INFER.export_path}.pdiparams
+ device: gpu
+ engine: native
+ precision: fp32
+ onnx_path: ${INFER.export_path}.onnx
+ ir_optim: true
+ min_subgraph_size: 10
+ gpu_mem: 4000
+ gpu_id: 0
+ max_batch_size: 128
+ num_cpu_threads: 4
+ batch_size: 128