From d8c3963fa90b58eef59a9cb5646458e320189ca1 Mon Sep 17 00:00:00 2001
From: smallpoxscattered <1989838596@qq.com>
Date: Sat, 6 Apr 2024 03:38:52 +0000
Subject: [PATCH 1/5] eadd export and inference for viv
---
examples/fsi/conf/viv.yaml | 22 ++++++++++
examples/fsi/viv.py | 87 +++++++++++++++++++++++++++++++++++++-
ppsci/utils/symbolic.py | 2 +-
3 files changed, 109 insertions(+), 2 deletions(-)
diff --git a/examples/fsi/conf/viv.yaml b/examples/fsi/conf/viv.yaml
index 989fd7af4..e01322091 100644
--- a/examples/fsi/conf/viv.yaml
+++ b/examples/fsi/conf/viv.yaml
@@ -11,6 +11,8 @@ hydra:
- TRAIN.checkpoint_path
- TRAIN.pretrained_model_path
- EVAL.pretrained_model_path
+ - INFER.pretrained_model_path
+ - INFER.export_path
- mode
- output_dir
- log_freq
@@ -60,3 +62,23 @@ TRAIN:
EVAL:
pretrained_model_path: null
batch_size: 32
+
+# inference settings
+INFER:
+ pretrained_model_path: "./viv_pretrained"
+ export_path: ./inference/viv
+ pdmodel_path: ${INFER.export_path}.pdmodel
+ pdpiparams_path: ${INFER.export_path}.pdiparams
+ pdmodel_equ_path: ${INFER.export_path}_equ.pdmodel
+ pdpiparams_equ_path: ${INFER.export_path}_equ.pdiparams
+ device: gpu
+ engine: native
+ precision: fp32
+ onnx_path: ${INFER.export_path}.onnx
+ ir_optim: true
+ min_subgraph_size: 10
+ gpu_mem: 4000
+ gpu_id: 0
+ max_batch_size: 64
+ num_cpu_threads: 4
+ batch_size: 16
diff --git a/examples/fsi/viv.py b/examples/fsi/viv.py
index b360daeeb..640050ac8 100644
--- a/examples/fsi/viv.py
+++ b/examples/fsi/viv.py
@@ -200,14 +200,99 @@ def evaluate(cfg: DictConfig):
solver.visualize()
+def export(cfg: DictConfig):
+ # set model
+ model = ppsci.arch.MLP(**cfg.MODEL)
+
+ # initialize equation
+ equation = {"VIV": ppsci.equation.Vibration(2, -4, 0)}
+
+ # initialize solver
+ solver = ppsci.solver.Solver(
+ model,
+ equation=equation,
+ pretrained_model_path=cfg.INFER.pretrained_model_path,
+ )
+ # Convert equation to callable function
+ func = ppsci.lambdify(
+ solver.equation["VIV"].equations["f"],
+ solver.model,
+ list(solver.equation["VIV"].learnable_parameters),
+ )
+ # export model and equation
+ from paddle.static import InputSpec
+
+ input_spec = [
+ {key: InputSpec([None, 1], "float32", name=key) for key in model.input_keys},
+ ]
+
+ solver.export(input_spec, cfg.INFER.export_path)
+
+ from paddle import jit
+
+ jit.enable_to_static(True)
+
+ static_model = jit.to_static(
+ func,
+ input_spec=input_spec,
+ full_graph=True,
+ )
+
+ jit.save(static_model, cfg.INFER.export_path + "_equ", skip_prune_program=True)
+
+ jit.enable_to_static(False)
+
+
+def inference(cfg: DictConfig):
+ from deploy.python_infer import pinn_predictor
+
+ # set model predictor
+ predictor = pinn_predictor.PINNPredictor(cfg)
+
+ # set equation predictor
+ cfg.INFER.pdmodel_path = cfg.INFER.pdmodel_equ_path
+ cfg.INFER.pdpiparams_path = cfg.INFER.pdpiparams_equ_path
+ equ_predictor = pinn_predictor.PINNPredictor(cfg)
+
+ infer_mat = ppsci.utils.reader.load_mat_file(
+ cfg.VIV_DATA_PATH,
+ ("t_f", "eta_gt", "f_gt"),
+ alias_dict={"eta_gt": "eta", "f_gt": "f"},
+ )
+
+ input_dict = {key: infer_mat[key] for key in cfg.MODEL.input_keys}
+
+ output_dict = predictor.predict(input_dict, cfg.INFER.batch_size)
+ equ_output_dict = equ_predictor.predict(input_dict, cfg.INFER.batch_size)
+
+ # mapping data to cfg.INFER.output_keys
+ output_dict = {
+ store_key: output_dict[infer_key]
+ for store_key, infer_key in zip(cfg.MODEL.output_keys, output_dict.keys())
+ }
+ for value in equ_output_dict.values():
+ output_dict["f"] = value
+ infer_mat.update(output_dict)
+
+ ppsci.visualize.plot.save_plot_from_1d_dict(
+ "./viv_pred", infer_mat, ("t_f",), ("eta", "eta_gt", "f", "f_gt")
+ )
+
+
@hydra.main(version_base=None, config_path="./conf", config_name="viv.yaml")
def main(cfg: DictConfig):
if cfg.mode == "train":
train(cfg)
elif cfg.mode == "eval":
evaluate(cfg)
+ elif cfg.mode == "export":
+ export(cfg)
+ elif cfg.mode == "infer":
+ inference(cfg)
else:
- raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")
+ raise ValueError(
+ f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
+ )
if __name__ == "__main__":
diff --git a/ppsci/utils/symbolic.py b/ppsci/utils/symbolic.py
index 5a2f1e99c..2537b4ff4 100644
--- a/ppsci/utils/symbolic.py
+++ b/ppsci/utils/symbolic.py
@@ -490,7 +490,7 @@ class ComposedNode(nn.Layer):
def __init__(self, callable_nodes: List[Node]):
super().__init__()
assert len(callable_nodes)
- self.callable_nodes = callable_nodes
+ self.callable_nodes = paddle.nn.LayerList(callable_nodes)
def forward(self, data_dict: DATA_DICT) -> paddle.Tensor:
# call all callable_nodes in order
From 76b69b28d2fe21fe7eb1f385602864d0a9a37de1 Mon Sep 17 00:00:00 2001
From: smallpoxscattered <1989838596@qq.com>
Date: Sat, 6 Apr 2024 04:34:39 +0000
Subject: [PATCH 2/5] add doc
---
docs/zh/examples/viv.md | 14 ++++++++++++++
1 file changed, 14 insertions(+)
diff --git a/docs/zh/examples/viv.md b/docs/zh/examples/viv.md
index 4a5bf5712..c9ee6b0ca 100644
--- a/docs/zh/examples/viv.md
+++ b/docs/zh/examples/viv.md
@@ -16,6 +16,20 @@
python viv.py mode=eval EVAL.pretrained_model_path=./viv_pretrained
```
+=== "模型导出命令"
+
+ ``` sh
+ wget -nc https://paddle-org.bj.bcebos.com/paddlescience/models/viv/viv_pretrained.pdeqn
+ wget -nc https://paddle-org.bj.bcebos.com/paddlescience/models/viv/viv_pretrained.pdparams
+ python viv.py mode=export
+ ```
+
+=== "模型推理命令"
+
+ ``` sh
+ python viv.py mode=infer
+ ```
+
| 预训练模型 | 指标 |
|:--| :--|
| [viv_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/aneurysm/viv_pretrained.pdparams)
[viv_pretrained.pdeqn](https://paddle-org.bj.bcebos.com/paddlescience/models/aneurysm/viv_pretrained.pdeqn) | 'eta': 1.1416150300647132e-06
'f': 4.635014192899689e-06 |
From e3d25831237720b81e2833a40e14ee16663fc2b2 Mon Sep 17 00:00:00 2001
From: smallpoxscattered <1989838596@qq.com>
Date: Sat, 6 Apr 2024 12:16:42 +0000
Subject: [PATCH 3/5] fix viv export&infer
---
examples/fsi/conf/viv.yaml | 4 ++--
examples/fsi/viv.py | 48 ++++++++++++++++++++++++--------------
ppsci/utils/symbolic.py | 2 +-
3 files changed, 34 insertions(+), 20 deletions(-)
diff --git a/examples/fsi/conf/viv.yaml b/examples/fsi/conf/viv.yaml
index e01322091..91d31bdeb 100644
--- a/examples/fsi/conf/viv.yaml
+++ b/examples/fsi/conf/viv.yaml
@@ -69,8 +69,8 @@ INFER:
export_path: ./inference/viv
pdmodel_path: ${INFER.export_path}.pdmodel
pdpiparams_path: ${INFER.export_path}.pdiparams
- pdmodel_equ_path: ${INFER.export_path}_equ.pdmodel
- pdpiparams_equ_path: ${INFER.export_path}_equ.pdiparams
+ input_keys: ["t_f"]
+ output_keys: ["eta", 'f']
device: gpu
engine: native
precision: fp32
diff --git a/examples/fsi/viv.py b/examples/fsi/viv.py
index 640050ac8..3ae3c29d2 100644
--- a/examples/fsi/viv.py
+++ b/examples/fsi/viv.py
@@ -12,10 +12,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import List
+
import hydra
from omegaconf import DictConfig
+from paddle import nn
import ppsci
+from ppsci.arch import base
+
+
+class EqnTranArch(base.Arch):
+ def __init__(self, funcs, input_keys: List, output_keys: List):
+ super().__init__()
+ if not isinstance(funcs, list):
+ funcs = [funcs]
+ self.modellist = nn.LayerList(funcs)
+ self.input_keys = input_keys
+ self.output_keys = output_keys
+
+ def forward(self, x):
+ output_dict = {}
+ for i, model in enumerate(self.modellist):
+ output_dict[self.output_keys[i]] = model(x)
+ return output_dict
def train(cfg: DictConfig):
@@ -213,32 +233,34 @@ def export(cfg: DictConfig):
equation=equation,
pretrained_model_path=cfg.INFER.pretrained_model_path,
)
- # Convert equation to callable function
- func = ppsci.lambdify(
+ # Convert equation to Arch
+ funcs = ppsci.lambdify(
solver.equation["VIV"].equations["f"],
solver.model,
list(solver.equation["VIV"].learnable_parameters),
)
- # export model and equation
+ eqn = EqnTranArch(funcs, cfg.INFER.input_keys, ["f"])
+
+ # Combine the two instances
+ models = ppsci.arch.ModelList((solver.model, eqn))
+ # export models
from paddle.static import InputSpec
input_spec = [
{key: InputSpec([None, 1], "float32", name=key) for key in model.input_keys},
]
- solver.export(input_spec, cfg.INFER.export_path)
-
from paddle import jit
jit.enable_to_static(True)
static_model = jit.to_static(
- func,
+ models,
input_spec=input_spec,
full_graph=True,
)
- jit.save(static_model, cfg.INFER.export_path + "_equ", skip_prune_program=True)
+ jit.save(static_model, cfg.INFER.export_path, skip_prune_program=True)
jit.enable_to_static(False)
@@ -249,29 +271,21 @@ def inference(cfg: DictConfig):
# set model predictor
predictor = pinn_predictor.PINNPredictor(cfg)
- # set equation predictor
- cfg.INFER.pdmodel_path = cfg.INFER.pdmodel_equ_path
- cfg.INFER.pdpiparams_path = cfg.INFER.pdpiparams_equ_path
- equ_predictor = pinn_predictor.PINNPredictor(cfg)
-
infer_mat = ppsci.utils.reader.load_mat_file(
cfg.VIV_DATA_PATH,
("t_f", "eta_gt", "f_gt"),
alias_dict={"eta_gt": "eta", "f_gt": "f"},
)
- input_dict = {key: infer_mat[key] for key in cfg.MODEL.input_keys}
+ input_dict = {key: infer_mat[key] for key in cfg.INFER.input_keys}
output_dict = predictor.predict(input_dict, cfg.INFER.batch_size)
- equ_output_dict = equ_predictor.predict(input_dict, cfg.INFER.batch_size)
# mapping data to cfg.INFER.output_keys
output_dict = {
store_key: output_dict[infer_key]
- for store_key, infer_key in zip(cfg.MODEL.output_keys, output_dict.keys())
+ for store_key, infer_key in zip(cfg.INFER.output_keys, output_dict.keys())
}
- for value in equ_output_dict.values():
- output_dict["f"] = value
infer_mat.update(output_dict)
ppsci.visualize.plot.save_plot_from_1d_dict(
diff --git a/ppsci/utils/symbolic.py b/ppsci/utils/symbolic.py
index 2537b4ff4..a48a7399f 100644
--- a/ppsci/utils/symbolic.py
+++ b/ppsci/utils/symbolic.py
@@ -490,7 +490,7 @@ class ComposedNode(nn.Layer):
def __init__(self, callable_nodes: List[Node]):
super().__init__()
assert len(callable_nodes)
- self.callable_nodes = paddle.nn.LayerList(callable_nodes)
+ self.callable_nodes = nn.LayerList(callable_nodes)
def forward(self, data_dict: DATA_DICT) -> paddle.Tensor:
# call all callable_nodes in order
From 0818e24cf83868511a9c069f24eb2261cfb5cbb0 Mon Sep 17 00:00:00 2001
From: smallpoxscattered <1989838596@qq.com>
Date: Sun, 7 Apr 2024 01:34:06 +0000
Subject: [PATCH 4/5] Rewriting function
---
examples/fsi/viv.py | 42 ++++++++++++++++++++----------------------
1 file changed, 20 insertions(+), 22 deletions(-)
diff --git a/examples/fsi/viv.py b/examples/fsi/viv.py
index 3ae3c29d2..d7f0c1774 100644
--- a/examples/fsi/viv.py
+++ b/examples/fsi/viv.py
@@ -12,30 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List
-
import hydra
from omegaconf import DictConfig
-from paddle import nn
import ppsci
-from ppsci.arch import base
-
-
-class EqnTranArch(base.Arch):
- def __init__(self, funcs, input_keys: List, output_keys: List):
- super().__init__()
- if not isinstance(funcs, list):
- funcs = [funcs]
- self.modellist = nn.LayerList(funcs)
- self.input_keys = input_keys
- self.output_keys = output_keys
-
- def forward(self, x):
- output_dict = {}
- for i, model in enumerate(self.modellist):
- output_dict[self.output_keys[i]] = model(x)
- return output_dict
def train(cfg: DictConfig):
@@ -233,14 +213,32 @@ def export(cfg: DictConfig):
equation=equation,
pretrained_model_path=cfg.INFER.pretrained_model_path,
)
- # Convert equation to Arch
+ # Convert equation to func
funcs = ppsci.lambdify(
solver.equation["VIV"].equations["f"],
solver.model,
list(solver.equation["VIV"].learnable_parameters),
)
- eqn = EqnTranArch(funcs, cfg.INFER.input_keys, ["f"])
+ def wrap_prediction_to_dict(instance, func):
+ def wrapper(instance, *args, **kwargs):
+ result = func(*args, **kwargs)
+ return {"f": result}
+
+ if hasattr(func, "__func__"):
+ wrapper.__func__ = func.__func__
+ return wrapper
+
+ def wrap_forward_methods(instance):
+ instance.input_keys = cfg.MODEL.input_keys
+ instance.output_keys = ["f"]
+ for attr_name in dir(instance):
+ if attr_name == "forward":
+ attr = getattr(instance, attr_name)
+ setattr(instance, attr_name, wrap_prediction_to_dict(instance, attr))
+ return instance
+
+ eqn = wrap_forward_methods(funcs)
# Combine the two instances
models = ppsci.arch.ModelList((solver.model, eqn))
# export models
From 3ec9bb715b764a32baa36631a17d1924ffd4190a Mon Sep 17 00:00:00 2001
From: smallpoxscattered <1989838596@qq.com>
Date: Sun, 7 Apr 2024 10:44:32 +0000
Subject: [PATCH 5/5] fix viv export&infer
---
docs/zh/examples/viv.md | 2 --
examples/fsi/conf/viv.yaml | 6 ++--
examples/fsi/viv.py | 57 ++++++++++++--------------------------
3 files changed, 20 insertions(+), 45 deletions(-)
diff --git a/docs/zh/examples/viv.md b/docs/zh/examples/viv.md
index a47bf4dc7..e0b51fa5b 100644
--- a/docs/zh/examples/viv.md
+++ b/docs/zh/examples/viv.md
@@ -17,8 +17,6 @@
=== "模型导出命令"
``` sh
- wget -nc https://paddle-org.bj.bcebos.com/paddlescience/models/viv/viv_pretrained.pdeqn
- wget -nc https://paddle-org.bj.bcebos.com/paddlescience/models/viv/viv_pretrained.pdparams
python viv.py mode=export
```
diff --git a/examples/fsi/conf/viv.yaml b/examples/fsi/conf/viv.yaml
index 91d31bdeb..8eb3a0c38 100644
--- a/examples/fsi/conf/viv.yaml
+++ b/examples/fsi/conf/viv.yaml
@@ -65,12 +65,12 @@ EVAL:
# inference settings
INFER:
- pretrained_model_path: "./viv_pretrained"
+ pretrained_model_path: "https://paddle-org.bj.bcebos.com/paddlescience/models/viv/viv_pretrained.pdparams"
export_path: ./inference/viv
pdmodel_path: ${INFER.export_path}.pdmodel
pdpiparams_path: ${INFER.export_path}.pdiparams
- input_keys: ["t_f"]
- output_keys: ["eta", 'f']
+ input_keys: ${MODEL.input_keys}
+ output_keys: ["eta", "f"]
device: gpu
engine: native
precision: fp32
diff --git a/examples/fsi/viv.py b/examples/fsi/viv.py
index d7f0c1774..1b27bf52c 100644
--- a/examples/fsi/viv.py
+++ b/examples/fsi/viv.py
@@ -201,12 +201,13 @@ def evaluate(cfg: DictConfig):
def export(cfg: DictConfig):
+ from paddle import nn
+ from paddle.static import InputSpec
+
# set model
model = ppsci.arch.MLP(**cfg.MODEL)
-
# initialize equation
equation = {"VIV": ppsci.equation.Vibration(2, -4, 0)}
-
# initialize solver
solver = ppsci.solver.Solver(
model,
@@ -214,53 +215,29 @@ def export(cfg: DictConfig):
pretrained_model_path=cfg.INFER.pretrained_model_path,
)
# Convert equation to func
- funcs = ppsci.lambdify(
+ f_func = ppsci.lambdify(
solver.equation["VIV"].equations["f"],
solver.model,
list(solver.equation["VIV"].learnable_parameters),
)
- def wrap_prediction_to_dict(instance, func):
- def wrapper(instance, *args, **kwargs):
- result = func(*args, **kwargs)
- return {"f": result}
-
- if hasattr(func, "__func__"):
- wrapper.__func__ = func.__func__
- return wrapper
-
- def wrap_forward_methods(instance):
- instance.input_keys = cfg.MODEL.input_keys
- instance.output_keys = ["f"]
- for attr_name in dir(instance):
- if attr_name == "forward":
- attr = getattr(instance, attr_name)
- setattr(instance, attr_name, wrap_prediction_to_dict(instance, attr))
- return instance
-
- eqn = wrap_forward_methods(funcs)
- # Combine the two instances
- models = ppsci.arch.ModelList((solver.model, eqn))
- # export models
- from paddle.static import InputSpec
+ class Wrapped_Model(nn.Layer):
+ def __init__(self, model, func):
+ super().__init__()
+ self.model = model
+ self.func = func
+
+ def forward(self, x):
+ model_out = self.model(x)
+ func_out = self.func(x)
+ return {**model_out, "f": func_out}
+ solver.model = Wrapped_Model(model, f_func)
+ # export models
input_spec = [
{key: InputSpec([None, 1], "float32", name=key) for key in model.input_keys},
]
-
- from paddle import jit
-
- jit.enable_to_static(True)
-
- static_model = jit.to_static(
- models,
- input_spec=input_spec,
- full_graph=True,
- )
-
- jit.save(static_model, cfg.INFER.export_path, skip_prune_program=True)
-
- jit.enable_to_static(False)
+ solver.export(input_spec, cfg.INFER.export_path, skip_prune_program=True)
def inference(cfg: DictConfig):