From 1227ae81f762984bbec3675dfd873da849613c70 Mon Sep 17 00:00:00 2001 From: chaous Date: Tue, 27 Jun 2023 15:01:02 +0300 Subject: [PATCH 1/2] Added NS incompressible and Euler equation --- modulus/sym/eq/pdes/navier_stokes.py | 165 +++++++++++++++++++++++++++ 1 file changed, 165 insertions(+) diff --git a/modulus/sym/eq/pdes/navier_stokes.py b/modulus/sym/eq/pdes/navier_stokes.py index 2bb1f44e..e6f21b9a 100644 --- a/modulus/sym/eq/pdes/navier_stokes.py +++ b/modulus/sym/eq/pdes/navier_stokes.py @@ -339,6 +339,90 @@ def __init__(self, T, dim=3, time=True): normal_x * T.diff(x) + normal_y * T.diff(y) + normal_z * T.diff(z) ) +class NavierStokesIncompressible(PDE): + '''Cystom implementation of incompressible NS eqyastyans''' + + name = "navier stokes incompressible" + + def __init__(self, nu, rho=1, dim=3, time=False): + + '''Parameters + ========== + nu : float, Sympy Symbol/Expr, str + The kinematic viscosity. If `nu` is a str then it is + converted to Sympy Function of form `nu(x,y,z,t)`. + If `nu` is a Sympy Symbol or Expression then this + is substituted into the equation. This allows for + variable viscosity. + rho : float, Sympy Symbol/Expr, str + The density of the fluid. If `rho` is a str then it is + converted to Sympy Function of form 'rho(x,y,z,t)'. + If 'rho' is a Sympy Symbol or Expression then this + is substituted into the equation to allow for + compressible Navier Stokes. Default is 1. + dim : int + Dimension of the Navier Stokes (2 or 3). Default is 3. + time : bool + If time-dependent equations or not. Default is False.''' + + self.dim = dim + self.time = time + + # coordinates + x, y, z = Symbol("x"), Symbol("y"), Symbol("z") + + # time + t = Symbol("t") + + input_variables = {"x": x, "y": y, "z": z, "t": t} + if self.dim == 2: + input_variables.pop("z") + if not self.time: + input_variables.pop("t") + + + # kinematic viscosity + if isinstance(nu, str): + nu = Function(nu)(*input_variables) + elif isinstance(nu, (float, int)): + nu = Number(nu) + + # density + if isinstance(rho, str): + rho = Function(rho)(*input_variables) + elif isinstance(rho, (float, int)): + rho = Number(rho) + + + u = Function("u")(*input_variables) + v = Function("v")(*input_variables) + w = Function("w")(*input_variables) + + + + #pressure + p = Function("p")(*input_variables) + + # set equations + self.equations = {} + self.equations["continuity"] = ( + u.diff(x) + v.diff(y) + w.diff(z) + ) + + self.equations["momentum_x"] = ( + u.diff(t) + u * u.diff(x) + v * u.diff(y) + w * u.diff(z) + 1 / rho * p.diff(x) + - nu * (u.diff(x).diff(x) + u.diff(y).diff(y) + u.diff(z).diff(z)) + ) + + self.equations["momentum_y"] = ( + v.diff(t) +u * v.diff(x) + v * v.diff(y) + w * v.diff(z) + 1 / rho * p.diff(y) + - nu * (v.diff(x).diff(x) + v.diff(y).diff(y) + v.diff(z).diff(z)) + ) + self.equations["momentum_z"] = ( + w.diff(t) +u * w.diff(x) + v * w.diff(y) + w * w.diff(z) + 1 / rho * p.diff(z) + - nu * (w.diff(x).diff(x) + w.diff(y).diff(y) + w.diff(z).diff(z)) + ) + class Curl(PDE): """ @@ -507,3 +591,84 @@ def __init__(self, T="T", D="D", rho=1, vec=["u", "v", "w"]): self.equations[str(T) + "_flux"] += ( Symbol(v) * n * rho * T - rho * D * n * g ) + + +class Euler(PDE): + '''Cystom implementation of compressible Euler eqyastyans''' + def __init__(self, rho=1, dim=3, ratio_heats=1.4, time=False): + ''' + Parameters + ========== + rho : float, Sympy Symbol/Expr, str + The density of the fluid. If `rho` is a str then it is + converted to Sympy Function of form 'rho(x,y,z,t)'. + If 'rho' is a Sympy Symbol or Expression then this + is substituted into the equation to allow for + compressible Navier Stokes. Default is 1. + dim : int + Dimension of the Navier Stokes (2 or 3). Default is 3. + time : bool + If time-dependent equations or not. Default is False. + ratio_heats : float + Ratio of specific heats. Default is 1.4 (for air). + ''' + self.dim = dim + self.time = time + + # coordinates + x, y, z = Symbol("x"), Symbol("y"), Symbol("z") + + # time + t = Symbol("t") + + # make input variables + input_variables = {"x": x, "y": y, "z": z, "t": t} + if self.dim == 2: + input_variables.pop("z") + if not self.time: + input_variables.pop("t") + + u = Function("u")(*input_variables) + v = Function("v")(*input_variables) + if self.dim == 3: + w = Function("w")(*input_variables) + else: + w = Number(0) + + + # pressure + p = Function("p")(*input_variables) + + # density + if isinstance(rho, str): + rho = Function(rho)(*input_variables) + elif isinstance(rho, (float, int)): + rho = Number(rho) + + + self.equations = {} + + e = p /(rho * (ratio_heats - 1)) + # Energy + E = rho * (e + (u **2 + v **2 + w **2) / 2) + + self.equations["continuity"] = ( + rho.diff(t) + (rho * u).diff(x) + (rho * v).diff(y) + (rho * w).diff(z) + ) + + self.equations["momentum_x"] = ( + (rho * u).diff(t) + (rho * u **2 + p).diff(x) + (rho * u * v).diff(y) + (rho * u * w).diff(z) + + ) + + self.equations["momentum_y"] = ( + (rho * v).diff(t) + (rho * u * v).diff(x) + (rho * v **2 + p).diff(y) + (rho * v * w).diff(z) + ) + + self.equations["momentum_z"] = ( + (rho * w).diff(t) + (rho * u * w).diff(x) + (rho * v * w).diff(y) + (rho * w **2 + p).diff(z) + ) + + self.equations["energy"] = ( + E.diff(t) + (u * (E + p)).diff(x) + (v * (E + p)).diff(y) + (w * (E + p)).diff(z) + ) From 3e69c672950b720364b02e4c5d3a78c66f5d4e04 Mon Sep 17 00:00:00 2001 From: Ilya Date: Thu, 28 Mar 2024 17:44:41 +0300 Subject: [PATCH 2/2] added only equations and examples with navier-stikes incompressible and Euler eqation --- examples/cylinder_dynamic/conf/config.yaml | 46 +++ examples/cylinder_dynamic/cylinder_2d_time.py | 289 ++++++++++++++++++ examples/sphere_high_speed/conf/conf.yaml | 29 ++ .../sphere_high_speed/sphere_high_speed.py | 263 ++++++++++++++++ .../stl_files/Inlet_large.stl | 3 + .../stl_files/Outlet_large.stl | 3 + .../stl_files/closed_large.stl | 3 + .../sphere_high_speed/stl_files/refinment.stl | 3 + .../sphere_high_speed/stl_files/sphere.stl | 3 + modulus/sym/eq/pdes/navier_stokes.py | 69 +++-- 10 files changed, 680 insertions(+), 31 deletions(-) create mode 100644 examples/cylinder_dynamic/conf/config.yaml create mode 100644 examples/cylinder_dynamic/cylinder_2d_time.py create mode 100644 examples/sphere_high_speed/conf/conf.yaml create mode 100644 examples/sphere_high_speed/sphere_high_speed.py create mode 100644 examples/sphere_high_speed/stl_files/Inlet_large.stl create mode 100644 examples/sphere_high_speed/stl_files/Outlet_large.stl create mode 100644 examples/sphere_high_speed/stl_files/closed_large.stl create mode 100644 examples/sphere_high_speed/stl_files/refinment.stl create mode 100644 examples/sphere_high_speed/stl_files/sphere.stl diff --git a/examples/cylinder_dynamic/conf/config.yaml b/examples/cylinder_dynamic/conf/config.yaml new file mode 100644 index 00000000..a2721862 --- /dev/null +++ b/examples/cylinder_dynamic/conf/config.yaml @@ -0,0 +1,46 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +defaults : + - modulus_default + - arch: + - fully_connected + - scheduler: tf_exponential_lr + - optimizer: adam + - loss: sum + - _self_ + + +jit: false + +scheduler: + decay_rate: 0.95 + decay_steps: 2000 + +training: + rec_results_freq : 1000 + rec_constraint_freq: 10000 + max_steps : 200000 + +batch_size: + inlet: 840 + outlet: 840 + walls: 840 + no_slip: 940 + interior: 7000 + initial_condition: 4000 + interior_cylinder: 5000 + interior_small: 6000 \ No newline at end of file diff --git a/examples/cylinder_dynamic/cylinder_2d_time.py b/examples/cylinder_dynamic/cylinder_2d_time.py new file mode 100644 index 00000000..0da38832 --- /dev/null +++ b/examples/cylinder_dynamic/cylinder_2d_time.py @@ -0,0 +1,289 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import warnings + +import numpy as np +from sympy import Symbol, Eq + +import modulus.sym +from modulus.sym.hydra import to_absolute_path, instantiate_arch, ModulusConfig +from modulus.sym.domain import Domain +from modulus.sym.geometry import Bounds +from modulus.sym.geometry.primitives_2d import Line, Circle, Channel2D +from modulus.sym.eq.pdes.navier_stokes import NavierStokesIncompressible +from modulus.sym.eq.pdes.basic import NormalDotVec +from modulus.sym.domain.constraint import ( + PointwiseBoundaryConstraint, + PointwiseInteriorConstraint, +) + + +from modulus.sym.key import Key +from modulus.sym import quantity +from modulus.sym.eq.non_dim import NonDimensionalizer, Scaler +from modulus.sym.models.moving_time_window import MovingTimeWindowArch +from modulus.sym.domain.inferencer import PointVTKInferencer +from modulus.sym.utils.io import ( + VTKUniformGrid, +) +from modulus.sym.solver import SequentialSolver +from sympy import Symbol, Function, Number + + +from modulus.sym.eq.pde import PDE + + +@modulus.sym.main(config_path="conf", config_name="config") +def run(cfg: ModulusConfig) -> None: + # physical quantities + nu = quantity(1.48e-5, "kg/(m*s)") + # nu = quantity(8.9e-3, "m^2/s") + rho = quantity(1.225, "kg/m^3") + inlet_u = quantity(1.5, "m/s") + inlet_v = quantity(0.0, "m/s") + noslip_u = quantity(0.0, "m/s") + noslip_v = quantity(0.0, "m/s") + outlet_p = quantity(0.0, "pa") + time_window_size = quantity(2, "s") + t_symbol = Symbol("t") + + nr_time_windows = 20 + velocity_scale = inlet_u + density_scale = rho + length_scale = quantity(20, "m") + nd = NonDimensionalizer( + length_scale=length_scale, + time_scale=length_scale / velocity_scale, + mass_scale=density_scale * (length_scale**3), + ) + time_range = {t_symbol: (nd.ndim(quantity(0.0, "s")), + nd.ndim(time_window_size))} + + # geometry + channel_length = (quantity(-10, "m"), quantity(30, "m")) + channel_width = (quantity(-10, "m"), quantity(10, "m")) + cylinder_center = (quantity(0, "m"), quantity(0, "m")) + cylinder_radius = quantity(0.5, "m") + channel_length_nd = tuple(map(lambda x: nd.ndim(x), channel_length)) + channel_width_nd = tuple(map(lambda x: nd.ndim(x), channel_width)) + cylinder_center_nd = tuple(map(lambda x: nd.ndim(x), cylinder_center)) + cylinder_radius_nd = nd.ndim(cylinder_radius) + + channel = Channel2D( + (channel_length_nd[0], channel_width_nd[0]), + (channel_length_nd[1], channel_width_nd[1]), + ) + inlet = Line( + (channel_length_nd[0], channel_width_nd[0]), + (channel_length_nd[0], channel_width_nd[1]), + normal=1, + ) + outlet = Line( + (channel_length_nd[1], channel_width_nd[0]), + (channel_length_nd[1], channel_width_nd[1]), + normal=1, + ) + wall_top = Line( + (channel_length_nd[1], channel_width_nd[0]), + (channel_length_nd[1], channel_width_nd[1]), + normal=1, + ) + cylinder = Circle(cylinder_center_nd, cylinder_radius_nd) + volume_geo = channel - cylinder + + volume_geo_small = channel.scale(0.2) - cylinder + + cylinder_interior = cylinder.scale(2) - cylinder + + # make list of nodes to unroll graph on + ns = NavierStokesIncompressible(nu=nd.ndim( + nu), rho=nd.ndim(rho), dim=2, time=True) + normal_dot_vel = NormalDotVec(["u", "v"]) + + flow_net = instantiate_arch( + # Include time as an input key + input_keys=[Key("x"), Key("y"), Key("t")], + output_keys=[Key("u"), Key("v"), Key("p")], + cfg=cfg.arch.fully_connected, + ) + time_window_net = MovingTimeWindowArch(flow_net, nd.ndim(time_window_size)) + + nodes = ( + ns.make_nodes() + + normal_dot_vel.make_nodes() + + [time_window_net.make_node(name="time_window_network")] + + Scaler( + ["u", "v", "p", "t"], + ["u_scaled", "v_scaled", "p_scaled", "t_scaled"], + ["m/s", "m/s", "m^2/s^2"], + nd, + ).make_node() + ) + + # make domain + # make initial condition domain + ic_domain = Domain("initial_conditions") + + # make moving window domain + window_domain = Domain("window") + x, y = Symbol("x"), Symbol("y") + + ic = PointwiseInteriorConstraint( + nodes=nodes, + geometry=volume_geo, + outvar={ + "u": nd.ndim(inlet_u), + "v": nd.ndim(inlet_v), + "p": np.ndim(outlet_p) + }, + batch_size=cfg.batch_size.initial_condition, + bounds=Bounds({x: channel_length_nd, y: channel_width_nd}), + parameterization={t_symbol: 0} + ) + ic_domain.add_constraint(ic, name="ic") + + # make constraint for matching previous windows initial condition + ic = PointwiseInteriorConstraint( + nodes=nodes, + geometry=volume_geo, + outvar={"u_prev_step_diff": 0, + "v_prev_step_diff": 0}, + batch_size=cfg.batch_size.interior, + bounds=Bounds({x: channel_length_nd, y: channel_width_nd}), + parameterization={t_symbol: 0}, + ) + window_domain.add_constraint(ic, name="ic") + + # inlet + inlet = PointwiseBoundaryConstraint( + nodes=nodes, + geometry=inlet, + outvar={"u": nd.ndim(inlet_u), "v": nd.ndim(inlet_v)}, + batch_size=cfg.batch_size.inlet, + parameterization=time_range, + ) + window_domain.add_constraint(inlet, "inlet") + ic_domain.add_constraint(inlet, "inlet") + + # outlet + outlet = PointwiseBoundaryConstraint( + nodes=nodes, + geometry=outlet, + outvar={"p": nd.ndim(outlet_p)}, + batch_size=cfg.batch_size.outlet, + parameterization=time_range, + ) + window_domain.add_constraint(outlet, "outlet") + ic_domain.add_constraint(outlet, "outlet") + + # full slip (channel walls) + walls = PointwiseBoundaryConstraint( + nodes=nodes, + geometry=channel, + outvar={"u": nd.ndim(inlet_u), "v": nd.ndim(inlet_v)}, + batch_size=cfg.batch_size.walls, + parameterization=time_range, + ) + window_domain.add_constraint(walls, "walls") + ic_domain.add_constraint(walls, "walls") + + # no slip + no_slip = PointwiseBoundaryConstraint( + nodes=nodes, + geometry=cylinder, + outvar={"u": nd.ndim(noslip_u), "v": nd.ndim(noslip_v)}, + batch_size=cfg.batch_size.no_slip, + parameterization=time_range, + ) + window_domain.add_constraint(no_slip, "no_slip") + ic_domain.add_constraint(no_slip, "no_slip") + + # interior contraints + interior = PointwiseInteriorConstraint( + nodes=nodes, + geometry=volume_geo, + outvar={"continuity": 0, "momentum_x": 0, "momentum_y": 0}, + batch_size=cfg.batch_size.interior, + bounds=Bounds({x: channel_length_nd, y: channel_width_nd}), + parameterization=time_range, + ) + window_domain.add_constraint(interior, "interior") + ic_domain.add_constraint(interior, "interior") + + interior_cylinder = PointwiseInteriorConstraint( + nodes=nodes, + geometry=cylinder_interior, + outvar={"continuity": 0, "momentum_x": 0, "momentum_y": 0}, + batch_size=cfg.batch_size.interior_cylinder, + parameterization=time_range, + ) + window_domain.add_constraint(interior_cylinder, "interior_cylinder") + ic_domain.add_constraint(interior_cylinder, "interior_cylinder") + + interior_small = PointwiseInteriorConstraint( + nodes=nodes, + geometry=volume_geo_small, + outvar={"continuity": 0, "momentum_x": 0, "momentum_y": 0}, + batch_size=cfg.batch_size.interior_small, + parameterization=time_range, + ) + + window_domain.add_constraint(interior_small, "interior_small") + ic_domain.add_constraint(interior_small, "interior_small") + + bounds_nd = [ + # Normalized bounds for x-direction + [channel_length_nd[0], channel_length_nd[1]], + # Normalized bounds for y-direction + [channel_width_nd[0], channel_width_nd[1]] + ] + + for i, specific_time in enumerate(np.linspace(0, nd.ndim(time_window_size), 100)): + vtk_obj = VTKUniformGrid( + bounds=bounds_nd, + npoints=[256, 256], + export_map={"u": ["u", "v"], "p": ["p"]}, + ) + grid_inference = PointVTKInferencer( + vtk_obj=vtk_obj, + nodes=nodes, + input_vtk_map={"x": "x", "y": "y"}, + output_names=["u", "v", "p"], + requires_grad=False, + invar={"t": np.full([256**2, 1], specific_time)}, + batch_size=100000, + ) + ic_domain.add_inferencer( + grid_inference, name="time_slice_" + str(i).zfill(4)) + window_domain.add_inferencer( + grid_inference, name="time_slice_" + str(i).zfill(4) + ) + + # make solver + slv = SequentialSolver( + cfg, + [(1, ic_domain), (nr_time_windows, window_domain)], + custom_update_operation=time_window_net.move_window, + ) + + # start solver + slv.solve() + + +if __name__ == "__main__": + run() diff --git a/examples/sphere_high_speed/conf/conf.yaml b/examples/sphere_high_speed/conf/conf.yaml new file mode 100644 index 00000000..7f90e0e8 --- /dev/null +++ b/examples/sphere_high_speed/conf/conf.yaml @@ -0,0 +1,29 @@ +defaults : + - modulus_default + - arch: + - fully_connected + - scheduler: tf_exponential_lr + - optimizer: adam + - loss: sum + - _self_ + +optimizer: + lr: 0.0001 + +scheduler: + decay_rate: 0.9 + decay_steps: 4000 + +training: + rec_results_freq : 10000 + rec_constraint_freq: 10000 + max_steps : 12000000 + +batch_size: + inlet: 900 + outlet: 900 + no_slip: 1800 + interior: 4500 + refinment: 5300 + interior_init: 2000 + refinment_init: 2000 diff --git a/examples/sphere_high_speed/sphere_high_speed.py b/examples/sphere_high_speed/sphere_high_speed.py new file mode 100644 index 00000000..1a78bb45 --- /dev/null +++ b/examples/sphere_high_speed/sphere_high_speed.py @@ -0,0 +1,263 @@ +import torch +from torch.utils.data import DataLoader, Dataset + +import numpy as np +from sympy import Symbol, sqrt, Max + +import modulus +import modulus.sym +from modulus.sym.hydra import to_absolute_path, instantiate_arch, ModulusConfig +from modulus.sym.solver import Solver +from modulus.sym.domain import Domain +from modulus.sym.domain.constraint import ( + PointwiseBoundaryConstraint, + PointwiseInteriorConstraint +) + +from modulus.sym.domain.inferencer import VoxelInferencer +from modulus.sym.domain.monitor import PointwiseMonitor +from modulus.sym.key import Key +from modulus.sym.eq.pdes.navier_stokes import GradNormal, Euler +from modulus.sym.eq.pdes.basic import NormalDotVec +from modulus.sym.geometry.tessellation import Tessellation +from modulus.sym import quantity +from modulus.sym.eq.non_dim import NonDimensionalizer, Scaler + +import stl +from stl import mesh + + +from sympy import Symbol + +from modulus.sym.eq.pde import PDE + + +class Magnityde(PDE): + def __init__(self): + u_scaled = Symbol("u_scaled") + v_scaled = Symbol("v_scaled") + w_scaled = Symbol("w_scaled") + + self.equations = {} + self.equations["magnityte_scaled"] = ( + u_scaled ** 2 + v_scaled ** 2 + w_scaled**2)**(0.5) + + +@modulus.sym.main(config_path="conf", config_name="conf") +def run(cfg: ModulusConfig) -> None: + # print(to_yaml(cfg)) + + # path definitions + point_path = to_absolute_path("./stl_files") + path_inlet = point_path + "/Inlet_large.stl" + dict_path_outlet = {'path_outlet': point_path + "/Outlet_large.stl", + } + path_noslip = point_path + "/sphere.stl" + path_interior = point_path + "/closed_large.stl" + path_outlet_combined = point_path + '/Outlet_large.stl' + refinment_path = point_path + '/refinment.stl' + + # create and save combined outlet stl + def combined_stl(meshes, save_path="./combined.stl"): + combined = mesh.Mesh(np.concatenate([m.data for m in meshes])) + combined.save(save_path, mode=stl.Mode.ASCII) + + meshes = [mesh.Mesh.from_file(file_) + for file_ in dict_path_outlet.values()] + combined_stl(meshes, path_outlet_combined) + + # read stl files to make geometry + inlet_mesh = Tessellation.from_stl(path_inlet, airtight=False) + dict_outlet = {} + for idx_, key_ in enumerate(dict_path_outlet): + dict_outlet['outlet'+str(idx_)+'_mesh'] = Tessellation.from_stl( + dict_path_outlet[key_], airtight=False) + noslip_mesh = Tessellation.from_stl(path_noslip, airtight=True) + interior_mesh = Tessellation.from_stl(path_interior, airtight=True) + refinment_mesh = Tessellation.from_stl(refinment_path, airtight=True) + + rho = quantity(1.225, "kg/m^3") + D = quantity(25, "m") + inlet_u = quantity(396.0, "m/s") + inlet_v = quantity(0.0, "m/s") + inlet_w = quantity(0.0, "m/s") + inlet_p = quantity(10 ** 5, "pa") + + nd = NonDimensionalizer( + length_scale=D, + time_scale=D / inlet_u, + mass_scale=rho * ((D / 2)**3) * 4 / 3 * 3.1415, + ) + + # normalize meshes + def normalize_mesh(mesh, center, scale): + mesh.translate([-c for c in center]) + mesh.scale(scale) + + # normalize invars + def normalize_invar(invar, center, scale, dims=2): + invar["x"] -= center[0] + invar["y"] -= center[1] + invar["z"] -= center[2] + invar["x"] *= scale + invar["y"] *= scale + invar["z"] *= scale + if "area" in invar.keys(): + invar["area"] *= scale ** dims + return invar + + # geometry scaling + + print(nd.ndim(D)) + scale = 1 # / nd.ndim(D) # turn off scaling + + # center of overall geometry + center = (0, 0, 0) + print('Overall geometry center: ', center) + + # center of inlet in original coordinate system + inlet_center_abs = (0, 0, 0) + print("inlet_center_abs:", inlet_center_abs) + + # scale end center the inlet center + inlet_center = list( + (np.array(inlet_center_abs) - np.array(center)) * scale) + print("inlet_center:", inlet_center) + + # inlet normal vector; should point into the cylinder, not outwards + inlet_normal = (1, 0, 0) + print("inlet_normal:", inlet_normal) + + # inlet velocity profile + + def circular_parabola(x, y, z, center, normal, radius, max_vel): + centered_x = x - center[0] + centered_y = y - center[1] + centered_z = z - center[2] + + distance = sqrt(centered_x ** 2 + centered_y ** 2 + centered_z ** 2) + parabola = max_vel * Max((1 - (distance / radius) ** 2), 0) + return normal[0] * parabola, normal[1] * parabola, normal[2] * parabola + + # make aneurysm domain + domain = Domain() + + # make list of nodes to unroll graph on + es = Euler(rho="rho", dim=3, time=False) + navier_stokes_nodes = es.make_nodes() # ns.make_nodes() + ze.make_nodes() + gn_p = GradNormal("p", dim=3, time=False) + gn_u = GradNormal("u", dim=3, time=False) + gn_v = GradNormal("v", dim=3, time=False) + gn_w = GradNormal("w", dim=3, time=False) + gn_rho = GradNormal("rho", dim=3, time=False) + normal_dot_vel = NormalDotVec(["u", "v", "w"]) + flow_net = instantiate_arch( + input_keys=[Key("x"), Key("y"), Key("z")], + output_keys=[Key("u"), Key("v"), Key("w"), Key("p"), Key("rho")], + cfg=cfg.arch.fully_connected, + ) + + mg = Magnityde() + + nodes = ( + navier_stokes_nodes + + gn_p.make_nodes() + + gn_u.make_nodes() + + gn_v.make_nodes() + + gn_w.make_nodes() + + gn_rho.make_nodes() + + normal_dot_vel.make_nodes() + + [flow_net.make_node(name="flow_network")] + + Scaler( + ["u", "v", "w", "p"], + ["u_scaled", "v_scaled", "w_scaled", "p_scaled"], + ["m/s", "m/s", "m/s", "m^2/s^2"], + nd, + ).make_node() + + mg.make_nodes() + ) + + inlet = PointwiseBoundaryConstraint( + nodes=nodes, + geometry=inlet_mesh, + outvar={"u": nd.ndim(inlet_u), "v": nd.ndim(inlet_v), "w": nd.ndim( + inlet_w), "p": nd.ndim(inlet_p), "rho": nd.ndim(rho)}, + batch_size=cfg.batch_size.inlet, + ) + domain.add_constraint(inlet, "inlet") + + # outlet + for idx_, key_ in enumerate(dict_outlet): + outlet = PointwiseBoundaryConstraint( + nodes=nodes, + geometry=dict_outlet[key_], + # outvar={"normal_gradient_p": 0, "normal_gradient_u": 0, "normal_gradient_v": 0, "normal_gradient_w": 0,"rho" : nd.ndim(rho)}, + outvar={"u": nd.ndim(inlet_u), "v": nd.ndim(inlet_v), "w": nd.ndim( + inlet_w), "p": nd.ndim(inlet_p), "rho": nd.ndim(rho)}, + batch_size=cfg.batch_size.outlet, + ) + domain.add_constraint(outlet, "outlet"+str(idx_)) + + # no slip + no_slip = PointwiseBoundaryConstraint( + nodes=nodes, + geometry=noslip_mesh, + outvar={"normal_gradient_p": 0, "normal_gradient_u": 0, + "normal_gradient_v": 0, "normal_gradient_w": 0}, + batch_size=cfg.batch_size.no_slip, + fixed_dataset=True + ) + domain.add_constraint(no_slip, "no_slip") + # interior + interior = PointwiseInteriorConstraint( + nodes=nodes, + geometry=interior_mesh, + outvar={"continuity": 0, "momentum_x": 0, + "momentum_y": 0, "momentum_z": 0, "energy": 0}, + batch_size=cfg.batch_size.interior, + fixed_dataset=True + ) + domain.add_constraint(interior, "interior") + + refinment = PointwiseInteriorConstraint( + nodes=nodes, + geometry=refinment_mesh, + outvar={"continuity": 0, "momentum_x": 0, + "momentum_y": 0, "momentum_z": 0, "energy": 0}, + batch_size=cfg.batch_size.refinment, + fixed_dataset=True + ) + domain.add_constraint(refinment, "refinment") + + voxel_grid = VoxelInferencer( + bounds=[[-7 * scale, 7 * scale], + [-7 * scale, 7 * scale], [-7 * scale, 7 * scale]], + npoints=[128, 128, 128], + nodes=nodes, + output_names=["u", "v", "w", "p", + "u_scaled", "v_scaled", "w_scaled", "p_scaled", "magnityte_scaled", "rho"] + ) + + domain.add_inferencer(voxel_grid, 'voxel_inf') + + force = PointwiseMonitor( + noslip_mesh.sample_boundary(900), + output_names=["p_scaled"], + metrics={ + "force_x": lambda var: torch.sum(var["normal_x"] * var["area"] * var["p_scaled"]), + "force_y": lambda var: torch.sum(var["normal_y"] * var["area"] * var["p_scaled"]), + "force_z": lambda var: torch.sum(var["normal_z"] * var["area"] * var["p_scaled"]), + }, + nodes=nodes, + ) + domain.add_monitor(force) + + # make solver + slv = Solver(cfg, domain) + + # start solver + slv.solve() + + +if __name__ == "__main__": + run() diff --git a/examples/sphere_high_speed/stl_files/Inlet_large.stl b/examples/sphere_high_speed/stl_files/Inlet_large.stl new file mode 100644 index 00000000..dfcdb773 --- /dev/null +++ b/examples/sphere_high_speed/stl_files/Inlet_large.stl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:102ab2323aa9fb44917500a39549558cdb74af5d6cd321b1974f7f2ee4d9d073 +size 24084 diff --git a/examples/sphere_high_speed/stl_files/Outlet_large.stl b/examples/sphere_high_speed/stl_files/Outlet_large.stl new file mode 100644 index 00000000..fcdb0296 --- /dev/null +++ b/examples/sphere_high_speed/stl_files/Outlet_large.stl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:196a538512335ba938f59373c974fe1214508dede7f33a6629c6f5bd794c2cff +size 92329 diff --git a/examples/sphere_high_speed/stl_files/closed_large.stl b/examples/sphere_high_speed/stl_files/closed_large.stl new file mode 100644 index 00000000..a00dc712 --- /dev/null +++ b/examples/sphere_high_speed/stl_files/closed_large.stl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a53db64455afaf1dbaea6f2b102e7ac41c57bc035245e6e885ef27ad8f5e2265 +size 96084 diff --git a/examples/sphere_high_speed/stl_files/refinment.stl b/examples/sphere_high_speed/stl_files/refinment.stl new file mode 100644 index 00000000..f2ea9d00 --- /dev/null +++ b/examples/sphere_high_speed/stl_files/refinment.stl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b321cab80c075736133e096e3513663dc9cb07df30ed8abd5e1b3aea5818a853 +size 96084 diff --git a/examples/sphere_high_speed/stl_files/sphere.stl b/examples/sphere_high_speed/stl_files/sphere.stl new file mode 100644 index 00000000..0cc5fa89 --- /dev/null +++ b/examples/sphere_high_speed/stl_files/sphere.stl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d3326ea5fc29852e8c233faa1c403f7d0a21645251c0e7e68acde70f92d17dd0 +size 48084 diff --git a/modulus/sym/eq/pdes/navier_stokes.py b/modulus/sym/eq/pdes/navier_stokes.py index e6f21b9a..789f5309 100644 --- a/modulus/sym/eq/pdes/navier_stokes.py +++ b/modulus/sym/eq/pdes/navier_stokes.py @@ -26,7 +26,7 @@ class NavierStokes(PDE): Compressible Navier Stokes equations Reference: https://turbmodels.larc.nasa.gov/implementrans.html - + Parameters ========== nu : float, Sympy Symbol/Expr, str @@ -112,11 +112,13 @@ def __init__(self, nu, rho=1, dim=3, time=True, mixed_form=False): # set equations self.equations = {} self.equations["continuity"] = ( - rho.diff(t) + (rho * u).diff(x) + (rho * v).diff(y) + (rho * w).diff(z) + rho.diff(t) + (rho * u).diff(x) + + (rho * v).diff(y) + (rho * w).diff(z) ) if not self.mixed_form: - curl = Number(0) if rho.diff(x) == 0 else u.diff(x) + v.diff(y) + w.diff(z) + curl = Number(0) if rho.diff(x) == 0 else u.diff( + x) + v.diff(y) + w.diff(z) self.equations["momentum_x"] = ( (rho * u).diff(t) + ( @@ -339,13 +341,13 @@ def __init__(self, T, dim=3, time=True): normal_x * T.diff(x) + normal_y * T.diff(y) + normal_z * T.diff(z) ) + class NavierStokesIncompressible(PDE): - '''Cystom implementation of incompressible NS eqyastyans''' + '''Custom implementation of incompressible NS equations''' name = "navier stokes incompressible" def __init__(self, nu, rho=1, dim=3, time=False): - '''Parameters ========== nu : float, Sympy Symbol/Expr, str @@ -354,12 +356,8 @@ def __init__(self, nu, rho=1, dim=3, time=False): If `nu` is a Sympy Symbol or Expression then this is substituted into the equation. This allows for variable viscosity. - rho : float, Sympy Symbol/Expr, str - The density of the fluid. If `rho` is a str then it is - converted to Sympy Function of form 'rho(x,y,z,t)'. - If 'rho' is a Sympy Symbol or Expression then this - is substituted into the equation to allow for - compressible Navier Stokes. Default is 1. + rho : float, int + The density of the fluid. Default is 1. dim : int Dimension of the Navier Stokes (2 or 3). Default is 3. time : bool @@ -380,7 +378,6 @@ def __init__(self, nu, rho=1, dim=3, time=False): if not self.time: input_variables.pop("t") - # kinematic viscosity if isinstance(nu, str): nu = Function(nu)(*input_variables) @@ -389,18 +386,18 @@ def __init__(self, nu, rho=1, dim=3, time=False): # density if isinstance(rho, str): - rho = Function(rho)(*input_variables) + raise Exception("rho must be number") elif isinstance(rho, (float, int)): rho = Number(rho) - u = Function("u")(*input_variables) v = Function("v")(*input_variables) - w = Function("w")(*input_variables) - - + if self.dim == 3: + w = Function("w")(*input_variables) + else: + w = Number(0) - #pressure + # pressure p = Function("p")(*input_variables) # set equations @@ -410,19 +407,25 @@ def __init__(self, nu, rho=1, dim=3, time=False): ) self.equations["momentum_x"] = ( - u.diff(t) + u * u.diff(x) + v * u.diff(y) + w * u.diff(z) + 1 / rho * p.diff(x) + u.diff(t) + u * u.diff(x) + v * u.diff(y) + + w * u.diff(z) + 1 / rho * p.diff(x) - nu * (u.diff(x).diff(x) + u.diff(y).diff(y) + u.diff(z).diff(z)) ) self.equations["momentum_y"] = ( - v.diff(t) +u * v.diff(x) + v * v.diff(y) + w * v.diff(z) + 1 / rho * p.diff(y) + v.diff(t) + u * v.diff(x) + v * v.diff(y) + + w * v.diff(z) + 1 / rho * p.diff(y) - nu * (v.diff(x).diff(x) + v.diff(y).diff(y) + v.diff(z).diff(z)) ) self.equations["momentum_z"] = ( - w.diff(t) +u * w.diff(x) + v * w.diff(y) + w * w.diff(z) + 1 / rho * p.diff(z) + w.diff(t) + u * w.diff(x) + v * w.diff(y) + + w * w.diff(z) + 1 / rho * p.diff(z) - nu * (w.diff(x).diff(x) + w.diff(y).diff(y) + w.diff(z).diff(z)) ) + if self.dim == 2: + self.equations.pop("momentum_z") + class Curl(PDE): """ @@ -594,7 +597,8 @@ def __init__(self, T="T", D="D", rho=1, vec=["u", "v", "w"]): class Euler(PDE): - '''Cystom implementation of compressible Euler eqyastyans''' + '''Custom implementation of compressible Euler equations''' + def __init__(self, rho=1, dim=3, ratio_heats=1.4, time=False): ''' Parameters @@ -634,7 +638,6 @@ def __init__(self, rho=1, dim=3, ratio_heats=1.4, time=False): w = Function("w")(*input_variables) else: w = Number(0) - # pressure p = Function("p")(*input_variables) @@ -644,31 +647,35 @@ def __init__(self, rho=1, dim=3, ratio_heats=1.4, time=False): rho = Function(rho)(*input_variables) elif isinstance(rho, (float, int)): rho = Number(rho) - self.equations = {} - e = p /(rho * (ratio_heats - 1)) + e = p / (rho * (ratio_heats - 1)) # Energy - E = rho * (e + (u **2 + v **2 + w **2) / 2) + E = rho * (e + (u ** 2 + v ** 2 + w ** 2) / 2) self.equations["continuity"] = ( - rho.diff(t) + (rho * u).diff(x) + (rho * v).diff(y) + (rho * w).diff(z) + rho.diff(t) + (rho * u).diff(x) + + (rho * v).diff(y) + (rho * w).diff(z) ) self.equations["momentum_x"] = ( - (rho * u).diff(t) + (rho * u **2 + p).diff(x) + (rho * u * v).diff(y) + (rho * u * w).diff(z) + (rho * u).diff(t) + (rho * u ** 2 + p).diff(x) + + (rho * u * v).diff(y) + (rho * u * w).diff(z) ) self.equations["momentum_y"] = ( - (rho * v).diff(t) + (rho * u * v).diff(x) + (rho * v **2 + p).diff(y) + (rho * v * w).diff(z) + (rho * v).diff(t) + (rho * u * v).diff(x) + + (rho * v ** 2 + p).diff(y) + (rho * v * w).diff(z) ) self.equations["momentum_z"] = ( - (rho * w).diff(t) + (rho * u * w).diff(x) + (rho * v * w).diff(y) + (rho * w **2 + p).diff(z) + (rho * w).diff(t) + (rho * u * w).diff(x) + + (rho * v * w).diff(y) + (rho * w ** 2 + p).diff(z) ) self.equations["energy"] = ( - E.diff(t) + (u * (E + p)).diff(x) + (v * (E + p)).diff(y) + (w * (E + p)).diff(z) + E.diff(t) + (u * (E + p)).diff(x) + + (v * (E + p)).diff(y) + (w * (E + p)).diff(z) )