From d8c3963fa90b58eef59a9cb5646458e320189ca1 Mon Sep 17 00:00:00 2001 From: smallpoxscattered <1989838596@qq.com> Date: Sat, 6 Apr 2024 03:38:52 +0000 Subject: [PATCH 1/5] eadd export and inference for viv --- examples/fsi/conf/viv.yaml | 22 ++++++++++ examples/fsi/viv.py | 87 +++++++++++++++++++++++++++++++++++++- ppsci/utils/symbolic.py | 2 +- 3 files changed, 109 insertions(+), 2 deletions(-) diff --git a/examples/fsi/conf/viv.yaml b/examples/fsi/conf/viv.yaml index 989fd7af4..e01322091 100644 --- a/examples/fsi/conf/viv.yaml +++ b/examples/fsi/conf/viv.yaml @@ -11,6 +11,8 @@ hydra: - TRAIN.checkpoint_path - TRAIN.pretrained_model_path - EVAL.pretrained_model_path + - INFER.pretrained_model_path + - INFER.export_path - mode - output_dir - log_freq @@ -60,3 +62,23 @@ TRAIN: EVAL: pretrained_model_path: null batch_size: 32 + +# inference settings +INFER: + pretrained_model_path: "./viv_pretrained" + export_path: ./inference/viv + pdmodel_path: ${INFER.export_path}.pdmodel + pdpiparams_path: ${INFER.export_path}.pdiparams + pdmodel_equ_path: ${INFER.export_path}_equ.pdmodel + pdpiparams_equ_path: ${INFER.export_path}_equ.pdiparams + device: gpu + engine: native + precision: fp32 + onnx_path: ${INFER.export_path}.onnx + ir_optim: true + min_subgraph_size: 10 + gpu_mem: 4000 + gpu_id: 0 + max_batch_size: 64 + num_cpu_threads: 4 + batch_size: 16 diff --git a/examples/fsi/viv.py b/examples/fsi/viv.py index b360daeeb..640050ac8 100644 --- a/examples/fsi/viv.py +++ b/examples/fsi/viv.py @@ -200,14 +200,99 @@ def evaluate(cfg: DictConfig): solver.visualize() +def export(cfg: DictConfig): + # set model + model = ppsci.arch.MLP(**cfg.MODEL) + + # initialize equation + equation = {"VIV": ppsci.equation.Vibration(2, -4, 0)} + + # initialize solver + solver = ppsci.solver.Solver( + model, + equation=equation, + pretrained_model_path=cfg.INFER.pretrained_model_path, + ) + # Convert equation to callable function + func = ppsci.lambdify( + solver.equation["VIV"].equations["f"], + solver.model, + list(solver.equation["VIV"].learnable_parameters), + ) + # export model and equation + from paddle.static import InputSpec + + input_spec = [ + {key: InputSpec([None, 1], "float32", name=key) for key in model.input_keys}, + ] + + solver.export(input_spec, cfg.INFER.export_path) + + from paddle import jit + + jit.enable_to_static(True) + + static_model = jit.to_static( + func, + input_spec=input_spec, + full_graph=True, + ) + + jit.save(static_model, cfg.INFER.export_path + "_equ", skip_prune_program=True) + + jit.enable_to_static(False) + + +def inference(cfg: DictConfig): + from deploy.python_infer import pinn_predictor + + # set model predictor + predictor = pinn_predictor.PINNPredictor(cfg) + + # set equation predictor + cfg.INFER.pdmodel_path = cfg.INFER.pdmodel_equ_path + cfg.INFER.pdpiparams_path = cfg.INFER.pdpiparams_equ_path + equ_predictor = pinn_predictor.PINNPredictor(cfg) + + infer_mat = ppsci.utils.reader.load_mat_file( + cfg.VIV_DATA_PATH, + ("t_f", "eta_gt", "f_gt"), + alias_dict={"eta_gt": "eta", "f_gt": "f"}, + ) + + input_dict = {key: infer_mat[key] for key in cfg.MODEL.input_keys} + + output_dict = predictor.predict(input_dict, cfg.INFER.batch_size) + equ_output_dict = equ_predictor.predict(input_dict, cfg.INFER.batch_size) + + # mapping data to cfg.INFER.output_keys + output_dict = { + store_key: output_dict[infer_key] + for store_key, infer_key in zip(cfg.MODEL.output_keys, output_dict.keys()) + } + for value in equ_output_dict.values(): + output_dict["f"] = value + infer_mat.update(output_dict) + + ppsci.visualize.plot.save_plot_from_1d_dict( + "./viv_pred", infer_mat, ("t_f",), ("eta", "eta_gt", "f", "f_gt") + ) + + @hydra.main(version_base=None, config_path="./conf", config_name="viv.yaml") def main(cfg: DictConfig): if cfg.mode == "train": train(cfg) elif cfg.mode == "eval": evaluate(cfg) + elif cfg.mode == "export": + export(cfg) + elif cfg.mode == "infer": + inference(cfg) else: - raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'") + raise ValueError( + f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'" + ) if __name__ == "__main__": diff --git a/ppsci/utils/symbolic.py b/ppsci/utils/symbolic.py index 5a2f1e99c..2537b4ff4 100644 --- a/ppsci/utils/symbolic.py +++ b/ppsci/utils/symbolic.py @@ -490,7 +490,7 @@ class ComposedNode(nn.Layer): def __init__(self, callable_nodes: List[Node]): super().__init__() assert len(callable_nodes) - self.callable_nodes = callable_nodes + self.callable_nodes = paddle.nn.LayerList(callable_nodes) def forward(self, data_dict: DATA_DICT) -> paddle.Tensor: # call all callable_nodes in order From 76b69b28d2fe21fe7eb1f385602864d0a9a37de1 Mon Sep 17 00:00:00 2001 From: smallpoxscattered <1989838596@qq.com> Date: Sat, 6 Apr 2024 04:34:39 +0000 Subject: [PATCH 2/5] add doc --- docs/zh/examples/viv.md | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/docs/zh/examples/viv.md b/docs/zh/examples/viv.md index 4a5bf5712..c9ee6b0ca 100644 --- a/docs/zh/examples/viv.md +++ b/docs/zh/examples/viv.md @@ -16,6 +16,20 @@ python viv.py mode=eval EVAL.pretrained_model_path=./viv_pretrained ``` +=== "模型导出命令" + + ``` sh + wget -nc https://paddle-org.bj.bcebos.com/paddlescience/models/viv/viv_pretrained.pdeqn + wget -nc https://paddle-org.bj.bcebos.com/paddlescience/models/viv/viv_pretrained.pdparams + python viv.py mode=export + ``` + +=== "模型推理命令" + + ``` sh + python viv.py mode=infer + ``` + | 预训练模型 | 指标 | |:--| :--| | [viv_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/aneurysm/viv_pretrained.pdparams)
[viv_pretrained.pdeqn](https://paddle-org.bj.bcebos.com/paddlescience/models/aneurysm/viv_pretrained.pdeqn) | 'eta': 1.1416150300647132e-06
'f': 4.635014192899689e-06 | From e3d25831237720b81e2833a40e14ee16663fc2b2 Mon Sep 17 00:00:00 2001 From: smallpoxscattered <1989838596@qq.com> Date: Sat, 6 Apr 2024 12:16:42 +0000 Subject: [PATCH 3/5] fix viv export&infer --- examples/fsi/conf/viv.yaml | 4 ++-- examples/fsi/viv.py | 48 ++++++++++++++++++++++++-------------- ppsci/utils/symbolic.py | 2 +- 3 files changed, 34 insertions(+), 20 deletions(-) diff --git a/examples/fsi/conf/viv.yaml b/examples/fsi/conf/viv.yaml index e01322091..91d31bdeb 100644 --- a/examples/fsi/conf/viv.yaml +++ b/examples/fsi/conf/viv.yaml @@ -69,8 +69,8 @@ INFER: export_path: ./inference/viv pdmodel_path: ${INFER.export_path}.pdmodel pdpiparams_path: ${INFER.export_path}.pdiparams - pdmodel_equ_path: ${INFER.export_path}_equ.pdmodel - pdpiparams_equ_path: ${INFER.export_path}_equ.pdiparams + input_keys: ["t_f"] + output_keys: ["eta", 'f'] device: gpu engine: native precision: fp32 diff --git a/examples/fsi/viv.py b/examples/fsi/viv.py index 640050ac8..3ae3c29d2 100644 --- a/examples/fsi/viv.py +++ b/examples/fsi/viv.py @@ -12,10 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import List + import hydra from omegaconf import DictConfig +from paddle import nn import ppsci +from ppsci.arch import base + + +class EqnTranArch(base.Arch): + def __init__(self, funcs, input_keys: List, output_keys: List): + super().__init__() + if not isinstance(funcs, list): + funcs = [funcs] + self.modellist = nn.LayerList(funcs) + self.input_keys = input_keys + self.output_keys = output_keys + + def forward(self, x): + output_dict = {} + for i, model in enumerate(self.modellist): + output_dict[self.output_keys[i]] = model(x) + return output_dict def train(cfg: DictConfig): @@ -213,32 +233,34 @@ def export(cfg: DictConfig): equation=equation, pretrained_model_path=cfg.INFER.pretrained_model_path, ) - # Convert equation to callable function - func = ppsci.lambdify( + # Convert equation to Arch + funcs = ppsci.lambdify( solver.equation["VIV"].equations["f"], solver.model, list(solver.equation["VIV"].learnable_parameters), ) - # export model and equation + eqn = EqnTranArch(funcs, cfg.INFER.input_keys, ["f"]) + + # Combine the two instances + models = ppsci.arch.ModelList((solver.model, eqn)) + # export models from paddle.static import InputSpec input_spec = [ {key: InputSpec([None, 1], "float32", name=key) for key in model.input_keys}, ] - solver.export(input_spec, cfg.INFER.export_path) - from paddle import jit jit.enable_to_static(True) static_model = jit.to_static( - func, + models, input_spec=input_spec, full_graph=True, ) - jit.save(static_model, cfg.INFER.export_path + "_equ", skip_prune_program=True) + jit.save(static_model, cfg.INFER.export_path, skip_prune_program=True) jit.enable_to_static(False) @@ -249,29 +271,21 @@ def inference(cfg: DictConfig): # set model predictor predictor = pinn_predictor.PINNPredictor(cfg) - # set equation predictor - cfg.INFER.pdmodel_path = cfg.INFER.pdmodel_equ_path - cfg.INFER.pdpiparams_path = cfg.INFER.pdpiparams_equ_path - equ_predictor = pinn_predictor.PINNPredictor(cfg) - infer_mat = ppsci.utils.reader.load_mat_file( cfg.VIV_DATA_PATH, ("t_f", "eta_gt", "f_gt"), alias_dict={"eta_gt": "eta", "f_gt": "f"}, ) - input_dict = {key: infer_mat[key] for key in cfg.MODEL.input_keys} + input_dict = {key: infer_mat[key] for key in cfg.INFER.input_keys} output_dict = predictor.predict(input_dict, cfg.INFER.batch_size) - equ_output_dict = equ_predictor.predict(input_dict, cfg.INFER.batch_size) # mapping data to cfg.INFER.output_keys output_dict = { store_key: output_dict[infer_key] - for store_key, infer_key in zip(cfg.MODEL.output_keys, output_dict.keys()) + for store_key, infer_key in zip(cfg.INFER.output_keys, output_dict.keys()) } - for value in equ_output_dict.values(): - output_dict["f"] = value infer_mat.update(output_dict) ppsci.visualize.plot.save_plot_from_1d_dict( diff --git a/ppsci/utils/symbolic.py b/ppsci/utils/symbolic.py index 2537b4ff4..a48a7399f 100644 --- a/ppsci/utils/symbolic.py +++ b/ppsci/utils/symbolic.py @@ -490,7 +490,7 @@ class ComposedNode(nn.Layer): def __init__(self, callable_nodes: List[Node]): super().__init__() assert len(callable_nodes) - self.callable_nodes = paddle.nn.LayerList(callable_nodes) + self.callable_nodes = nn.LayerList(callable_nodes) def forward(self, data_dict: DATA_DICT) -> paddle.Tensor: # call all callable_nodes in order From 0818e24cf83868511a9c069f24eb2261cfb5cbb0 Mon Sep 17 00:00:00 2001 From: smallpoxscattered <1989838596@qq.com> Date: Sun, 7 Apr 2024 01:34:06 +0000 Subject: [PATCH 4/5] Rewriting function --- examples/fsi/viv.py | 42 ++++++++++++++++++++---------------------- 1 file changed, 20 insertions(+), 22 deletions(-) diff --git a/examples/fsi/viv.py b/examples/fsi/viv.py index 3ae3c29d2..d7f0c1774 100644 --- a/examples/fsi/viv.py +++ b/examples/fsi/viv.py @@ -12,30 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List - import hydra from omegaconf import DictConfig -from paddle import nn import ppsci -from ppsci.arch import base - - -class EqnTranArch(base.Arch): - def __init__(self, funcs, input_keys: List, output_keys: List): - super().__init__() - if not isinstance(funcs, list): - funcs = [funcs] - self.modellist = nn.LayerList(funcs) - self.input_keys = input_keys - self.output_keys = output_keys - - def forward(self, x): - output_dict = {} - for i, model in enumerate(self.modellist): - output_dict[self.output_keys[i]] = model(x) - return output_dict def train(cfg: DictConfig): @@ -233,14 +213,32 @@ def export(cfg: DictConfig): equation=equation, pretrained_model_path=cfg.INFER.pretrained_model_path, ) - # Convert equation to Arch + # Convert equation to func funcs = ppsci.lambdify( solver.equation["VIV"].equations["f"], solver.model, list(solver.equation["VIV"].learnable_parameters), ) - eqn = EqnTranArch(funcs, cfg.INFER.input_keys, ["f"]) + def wrap_prediction_to_dict(instance, func): + def wrapper(instance, *args, **kwargs): + result = func(*args, **kwargs) + return {"f": result} + + if hasattr(func, "__func__"): + wrapper.__func__ = func.__func__ + return wrapper + + def wrap_forward_methods(instance): + instance.input_keys = cfg.MODEL.input_keys + instance.output_keys = ["f"] + for attr_name in dir(instance): + if attr_name == "forward": + attr = getattr(instance, attr_name) + setattr(instance, attr_name, wrap_prediction_to_dict(instance, attr)) + return instance + + eqn = wrap_forward_methods(funcs) # Combine the two instances models = ppsci.arch.ModelList((solver.model, eqn)) # export models From 3ec9bb715b764a32baa36631a17d1924ffd4190a Mon Sep 17 00:00:00 2001 From: smallpoxscattered <1989838596@qq.com> Date: Sun, 7 Apr 2024 10:44:32 +0000 Subject: [PATCH 5/5] fix viv export&infer --- docs/zh/examples/viv.md | 2 -- examples/fsi/conf/viv.yaml | 6 ++-- examples/fsi/viv.py | 57 ++++++++++++-------------------------- 3 files changed, 20 insertions(+), 45 deletions(-) diff --git a/docs/zh/examples/viv.md b/docs/zh/examples/viv.md index a47bf4dc7..e0b51fa5b 100644 --- a/docs/zh/examples/viv.md +++ b/docs/zh/examples/viv.md @@ -17,8 +17,6 @@ === "模型导出命令" ``` sh - wget -nc https://paddle-org.bj.bcebos.com/paddlescience/models/viv/viv_pretrained.pdeqn - wget -nc https://paddle-org.bj.bcebos.com/paddlescience/models/viv/viv_pretrained.pdparams python viv.py mode=export ``` diff --git a/examples/fsi/conf/viv.yaml b/examples/fsi/conf/viv.yaml index 91d31bdeb..8eb3a0c38 100644 --- a/examples/fsi/conf/viv.yaml +++ b/examples/fsi/conf/viv.yaml @@ -65,12 +65,12 @@ EVAL: # inference settings INFER: - pretrained_model_path: "./viv_pretrained" + pretrained_model_path: "https://paddle-org.bj.bcebos.com/paddlescience/models/viv/viv_pretrained.pdparams" export_path: ./inference/viv pdmodel_path: ${INFER.export_path}.pdmodel pdpiparams_path: ${INFER.export_path}.pdiparams - input_keys: ["t_f"] - output_keys: ["eta", 'f'] + input_keys: ${MODEL.input_keys} + output_keys: ["eta", "f"] device: gpu engine: native precision: fp32 diff --git a/examples/fsi/viv.py b/examples/fsi/viv.py index d7f0c1774..1b27bf52c 100644 --- a/examples/fsi/viv.py +++ b/examples/fsi/viv.py @@ -201,12 +201,13 @@ def evaluate(cfg: DictConfig): def export(cfg: DictConfig): + from paddle import nn + from paddle.static import InputSpec + # set model model = ppsci.arch.MLP(**cfg.MODEL) - # initialize equation equation = {"VIV": ppsci.equation.Vibration(2, -4, 0)} - # initialize solver solver = ppsci.solver.Solver( model, @@ -214,53 +215,29 @@ def export(cfg: DictConfig): pretrained_model_path=cfg.INFER.pretrained_model_path, ) # Convert equation to func - funcs = ppsci.lambdify( + f_func = ppsci.lambdify( solver.equation["VIV"].equations["f"], solver.model, list(solver.equation["VIV"].learnable_parameters), ) - def wrap_prediction_to_dict(instance, func): - def wrapper(instance, *args, **kwargs): - result = func(*args, **kwargs) - return {"f": result} - - if hasattr(func, "__func__"): - wrapper.__func__ = func.__func__ - return wrapper - - def wrap_forward_methods(instance): - instance.input_keys = cfg.MODEL.input_keys - instance.output_keys = ["f"] - for attr_name in dir(instance): - if attr_name == "forward": - attr = getattr(instance, attr_name) - setattr(instance, attr_name, wrap_prediction_to_dict(instance, attr)) - return instance - - eqn = wrap_forward_methods(funcs) - # Combine the two instances - models = ppsci.arch.ModelList((solver.model, eqn)) - # export models - from paddle.static import InputSpec + class Wrapped_Model(nn.Layer): + def __init__(self, model, func): + super().__init__() + self.model = model + self.func = func + + def forward(self, x): + model_out = self.model(x) + func_out = self.func(x) + return {**model_out, "f": func_out} + solver.model = Wrapped_Model(model, f_func) + # export models input_spec = [ {key: InputSpec([None, 1], "float32", name=key) for key in model.input_keys}, ] - - from paddle import jit - - jit.enable_to_static(True) - - static_model = jit.to_static( - models, - input_spec=input_spec, - full_graph=True, - ) - - jit.save(static_model, cfg.INFER.export_path, skip_prune_program=True) - - jit.enable_to_static(False) + solver.export(input_spec, cfg.INFER.export_path, skip_prune_program=True) def inference(cfg: DictConfig):