Skip to content

Commit

Permalink
Make ModepyElementGroup public, rename attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Nov 28, 2024
1 parent 28446e0 commit b6d3338
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 52 deletions.
6 changes: 3 additions & 3 deletions meshmode/discretization/connection/face.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def make_face_restriction(

# }}}

from meshmode.mesh import _ModepyElementGroup, make_mesh
from meshmode.mesh import ModepyElementGroup, make_mesh
bdry_mesh_groups = []
connection_data = {}

Expand All @@ -235,7 +235,7 @@ def make_face_restriction(

mgrp = grp.mesh_el_group

if not isinstance(mgrp, _ModepyElementGroup):
if not isinstance(mgrp, ModepyElementGroup):
raise NotImplementedError("can only take boundary of "
"meshes based on SimplexElementGroup and "
"TensorProductElementGroup")
Expand Down Expand Up @@ -301,7 +301,7 @@ def make_face_restriction(
bdry_unit_nodes = mp.edge_clustered_nodes_for_space(space, face)

vol_basis = mp.basis_for_space(
mgrp._modepy_space, mgrp._modepy_shape).functions
mgrp.space, mgrp.shape).functions

vertex_indices = np.empty(
(ngroup_bdry_elements, face.nvertices),
Expand Down
6 changes: 3 additions & 3 deletions meshmode/discretization/poly_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,10 @@
def from_mesh_interp_matrix(grp: InterpolatoryElementGroupBase) -> np.ndarray:
meg = grp.mesh_el_group

from meshmode.mesh import _ModepyElementGroup
assert isinstance(meg, _ModepyElementGroup)
from meshmode.mesh import ModepyElementGroup
assert isinstance(meg, ModepyElementGroup)

meg_basis = mp.basis_for_space(meg._modepy_space, meg._modepy_shape)
meg_basis = mp.basis_for_space(meg.space, meg.shape)
return mp.resampling_matrix(
meg_basis.functions,
grp.unit_nodes,
Expand Down
8 changes: 4 additions & 4 deletions meshmode/discretization/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,12 +316,12 @@ def tensor_cell_types(self):
def connectivity_for_element_group(self, grp):
import modepy as mp

from meshmode.mesh import _ModepyElementGroup
from meshmode.mesh import ModepyElementGroup

if isinstance(grp.mesh_el_group, _ModepyElementGroup):
shape = grp.mesh_el_group._modepy_shape
if isinstance(grp.mesh_el_group, ModepyElementGroup):
shape = grp.mesh_el_group.shape
space = mp.space_for_shape(shape, grp.order)
assert type(space) == type(grp.mesh_el_group._modepy_space) # noqa: E721
assert type(space) == type(grp.mesh_el_group.space) # noqa: E721

node_tuples = mp.node_tuples_for_space(space)

Expand Down
73 changes: 50 additions & 23 deletions meshmode/mesh/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from abc import ABC, abstractmethod
from collections.abc import Callable, Collection, Hashable, Iterable, Mapping, Sequence
from dataclasses import InitVar, dataclass, field, replace
from functools import partial
from typing import (
Any,
ClassVar,
Expand All @@ -41,13 +42,14 @@
import numpy.linalg as la

import modepy as mp
from pytools import memoize_method
from pytools import memoize_method, module_getattr_for_deprecations

from meshmode.mesh.tools import AffineMap, optional_array_equal


__doc__ = """
.. autoclass:: MeshElementGroup
.. autoclass:: ModepyElementGroup
.. autoclass:: SimplexElementGroup
.. autoclass:: TensorProductElementGroup
Expand Down Expand Up @@ -321,38 +323,45 @@ def make_group(cls, *args: Any, **kwargs: Any) -> MeshElementGroup:

# {{{ modepy-based element group

# https://stackoverflow.com/a/13624858
class _classproperty(property): # noqa: N801
def __get__(self, owner_self: Any, owner_cls: type | None = None) -> Any:
assert self.fget is not None
return self.fget(owner_cls)


@dataclass(frozen=True, eq=False)
class _ModepyElementGroup(MeshElementGroup):
class ModepyElementGroup(MeshElementGroup):
"""
.. attribute:: _modepy_shape_cls
.. attribute:: modepy_shape_cls
Must be set by subclasses to generate the correct shape and spaces
attributes for the group.
.. attribute:: _modepy_shape
.. attribute:: _modepy_space
.. attribute:: shape
.. attribute:: space
"""

_modepy_shape_cls: ClassVar[type[mp.Shape]]
_modepy_shape: mp.Shape = field(repr=False)
_modepy_space: mp.FunctionSpace = field(repr=False)
shape_cls: ClassVar[type[mp.Shape]]
shape: mp.Shape = field(repr=False)
space: mp.FunctionSpace = field(repr=False)

@property
def nvertices(self) -> int:
return self._modepy_shape.nvertices # pylint: disable=no-member
return self.shape.nvertices # pylint: disable=no-member

@property
@memoize_method
def _modepy_faces(self) -> Sequence[mp.Face]:
return mp.faces_for_shape(self._modepy_shape)
return mp.faces_for_shape(self.shape)

@memoize_method
def face_vertex_indices(self) -> tuple[tuple[int, ...], ...]:
return tuple(face.volume_vertex_indices for face in self._modepy_faces)

@memoize_method
def vertex_unit_coordinates(self) -> np.ndarray:
return mp.unit_vertices_for_shape(self._modepy_shape).T
return mp.unit_vertices_for_shape(self.shape).T

@classmethod
def make_group(cls,
Expand All @@ -361,7 +370,7 @@ def make_group(cls,
nodes: np.ndarray,
*,
unit_nodes: np.ndarray | None = None,
dim: int | None = None) -> _ModepyElementGroup:
dim: int | None = None) -> ModepyElementGroup:

if unit_nodes is None:
if dim is None:
Expand All @@ -374,7 +383,7 @@ def make_group(cls,
raise ValueError("'dim' does not match 'unit_nodes' dimension")

# pylint: disable=abstract-class-instantiated
shape = cls._modepy_shape_cls(dim)
shape = cls.shape_cls(dim)
space = mp.space_for_shape(shape, order)

if unit_nodes is None:
Expand All @@ -388,17 +397,29 @@ def make_group(cls,
vertex_indices=vertex_indices,
nodes=nodes,
unit_nodes=unit_nodes,
_modepy_shape=shape,
_modepy_space=space)
shape=shape,
space=space)

@_classproperty
def _modepy_shape_cls(cls) -> type[mp.Shape]: # noqa: N805
return cls.shape_cls

@property
def _modepy_shape(self) -> mp.Shape:
return self.shape

@property
def _modepy_space(self) -> mp.FunctionSpace:
return self.space

# }}}


@dataclass(frozen=True, eq=False)
class SimplexElementGroup(_ModepyElementGroup):
class SimplexElementGroup(ModepyElementGroup):
r"""Inherits from :class:`MeshElementGroup`."""

_modepy_shape_cls: ClassVar[type[mp.Shape]] = mp.Simplex
shape_cls: ClassVar[type[mp.Shape]] = mp.Simplex

@property
@memoize_method
Expand All @@ -407,10 +428,10 @@ def is_affine(self) -> bool:


@dataclass(frozen=True, eq=False)
class TensorProductElementGroup(_ModepyElementGroup):
class TensorProductElementGroup(ModepyElementGroup):
r"""Inherits from :class:`MeshElementGroup`."""

_modepy_shape_cls: ClassVar[type[mp.Shape]] = mp.Hypercube
shape_cls: ClassVar[type[mp.Shape]] = mp.Hypercube

@property
def is_affine(self) -> bool:
Expand Down Expand Up @@ -1472,8 +1493,8 @@ def __eq__(self, other: object) -> bool:
# {{{ node-vertex consistency test

def _mesh_group_node_vertex_error(mesh: Mesh, mgrp: MeshElementGroup) -> np.ndarray:
if isinstance(mgrp, _ModepyElementGroup):
basis = mp.basis_for_space(mgrp._modepy_space, mgrp._modepy_shape).functions
if isinstance(mgrp, ModepyElementGroup):
basis = mp.basis_for_space(mgrp.space, mgrp.shape).functions
else:
raise TypeError(f"unsupported group type: {type(mgrp).__name__}")

Expand Down Expand Up @@ -1535,7 +1556,7 @@ def _test_node_vertex_consistency(
:raises InconsistentVerticesError: if the vertices are not consistent.
"""
for igrp, mgrp in enumerate(mesh.groups):
if isinstance(mgrp, _ModepyElementGroup):
if isinstance(mgrp, ModepyElementGroup):
_test_group_node_vertex_consistency_resampling(mesh, igrp, tol=tol)
else:
warn("Not implemented: node-vertex consistency check for "
Expand Down Expand Up @@ -2182,7 +2203,7 @@ def is_affine_simplex_group(
return True

# get matrices
basis = mp.basis_for_space(group._modepy_space, group._modepy_shape)
basis = mp.basis_for_space(group.space, group.shape)
vinv = la.inv(mp.vandermonde(basis.functions, group.unit_nodes))
diff = mp.differentiation_matrices(
basis.functions, basis.gradients, group.unit_nodes)
Expand Down Expand Up @@ -2210,4 +2231,10 @@ def is_affine_simplex_group(

# }}}


__getattr__ = partial(module_getattr_for_deprecations, __name__, {
"_ModepyElementGroup": ("ModepyElementGroup", ModepyElementGroup, 2026),
})


# vim: foldmethod=marker
2 changes: 1 addition & 1 deletion meshmode/mesh/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1666,7 +1666,7 @@ def warp_and_refine_until_resolved(
n_tail_orders=1 if warped_mesh.dim > 1 else 2)

basis = mp.orthonormal_basis_for_space(
egrp._modepy_space, egrp._modepy_shape)
egrp.space, egrp.shape)
vdm_inv = la.inv(mp.vandermonde(basis.functions, egrp.unit_nodes))

mapping_coeffs = np.einsum("ij,dej->dei", vdm_inv, egrp.nodes)
Expand Down
6 changes: 3 additions & 3 deletions meshmode/mesh/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,9 +571,9 @@ def find_volume_mesh_element_group_orientation(
each negatively oriented element.
"""

from meshmode.mesh import _ModepyElementGroup
from meshmode.mesh import ModepyElementGroup

if not isinstance(grp, _ModepyElementGroup):
if not isinstance(grp, ModepyElementGroup):
raise NotImplementedError(
"finding element orientations "
"only supported on "
Expand Down Expand Up @@ -756,7 +756,7 @@ def _get_tensor_product_element_flip_matrix_and_vertex_permutation(

flipped_unit_nodes = np.einsum("ij,jn->in", unit_flip_matrix, grp.unit_nodes)

basis = mp.basis_for_space(grp._modepy_space, grp._modepy_shape)
basis = mp.basis_for_space(grp.space, grp.shape)
flip_matrix = mp.resampling_matrix(
basis.functions,
flipped_unit_nodes,
Expand Down
28 changes: 14 additions & 14 deletions meshmode/mesh/refinement/tessellate.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

import modepy as mp

from meshmode.mesh import MeshElementGroup, _ModepyElementGroup
from meshmode.mesh import MeshElementGroup, ModepyElementGroup


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -183,11 +183,11 @@ def _get_ref_midpoints(shape, ref_vertices):

# {{{ modepy.shape tessellation and resampling

@get_group_midpoints.register(_ModepyElementGroup)
@get_group_midpoints.register(ModepyElementGroup)
def _get_group_midpoints_modepy(
meg: _ModepyElementGroup, el_tess_info, elements):
shape = meg._modepy_shape
space = meg._modepy_space
meg: ModepyElementGroup, el_tess_info, elements):
shape = meg.shape
space = meg.space

# get midpoints in reference coordinates
midpoints = -1 + np.array(_get_ref_midpoints(shape, el_tess_info.ref_vertices))
Expand All @@ -204,11 +204,11 @@ def _get_group_midpoints_modepy(
return dict(zip(elements, resampled_midpoints, strict=True))


@get_group_tessellated_nodes.register(_ModepyElementGroup)
@get_group_tessellated_nodes.register(ModepyElementGroup)
def _get_group_tessellated_nodes_modepy(
meg: _ModepyElementGroup, el_tess_info, elements):
shape = meg._modepy_shape
space = meg._modepy_space
meg: ModepyElementGroup, el_tess_info, elements):
shape = meg.shape
space = meg.space

# get child unit node coordinates.
from meshmode.mesh.refinement.utils import map_unit_nodes_to_children
Expand All @@ -234,18 +234,18 @@ def _get_group_tessellated_nodes_modepy(
}


@get_group_tessellation_info.register(_ModepyElementGroup)
def _get_group_tessellation_info_modepy(meg: _ModepyElementGroup):
shape = meg._modepy_shape
@get_group_tessellation_info.register(ModepyElementGroup)
def _get_group_tessellation_info_modepy(meg: ModepyElementGroup):
shape = meg.shape
space = mp.space_for_shape(shape, 2)
assert type(space) == type(meg._modepy_space) # noqa: E721
assert type(space) == type(meg.space) # noqa: E721

ref_vertices = mp.node_tuples_for_space(space)
ref_vertices_to_index = {rv: i for i, rv in enumerate(ref_vertices)}

from pytools import add_tuples
space = mp.space_for_shape(shape, 1)
assert type(space) == type(meg._modepy_space) # noqa: E721
assert type(space) == type(meg.space) # noqa: E721
orig_vertices = tuple(add_tuples(vt, vt) for vt in mp.node_tuples_for_space(space))
orig_vertex_indices = [ref_vertices_to_index[vt] for vt in orig_vertices]

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ dependencies = [
"modepy>=2021.1",
"numpy",
"pymbolic>=2022.2",
"pytools>=2022.1",
"pytools>=2024.1.17",
"recursivenodes",
]

Expand Down

0 comments on commit b6d3338

Please sign in to comment.