diff --git a/src/darsia/measure/wasserstein.py b/src/darsia/measure/wasserstein.py index 66b8e045..cb6f78b5 100644 --- a/src/darsia/measure/wasserstein.py +++ b/src/darsia/measure/wasserstein.py @@ -42,6 +42,7 @@ class VariationalWassersteinDistance(darsia.EMD): def __init__( self, grid: darsia.Grid, + weight: Optional[darsia.Image] = None, options: dict = {}, ) -> None: """Initialization of the variational Wasserstein distance. @@ -131,8 +132,12 @@ def __init__( self.mobility_mode = self.options.get("mobility_mode", "cell_based") """str: mode for computing the mobility""" + self.weight = weight + """Weight defined on cells""" + # Setup of method self._setup_dof_management() + self._setup_face_weights() self._setup_discretization() self._setup_linear_solver() self._setup_acceleration() @@ -193,6 +198,20 @@ def _setup_dof_management(self) -> None: ) """sps.csc_matrix: embedding operator for fluxes""" + def _setup_face_weights(self) -> None: + """Convert cell weights to face weights by harmonic averaging.""" + + if self.weight is None: + self.cell_weights = np.ones(self.grid.shape, dtype=float) + """np.ndarray: cell weights""" + self.face_weights = np.ones(self.grid.num_faces, dtype=float) + """np.ndarray: face weights""" + else: + self.cell_weights = self.weight.img + self.face_weights = darsia.cell_to_face_average( + self.grid, self.cell_weights, mode="harmonic" + ) + def _setup_discretization(self) -> None: """Setup of fixed discretization operators.""" @@ -224,7 +243,9 @@ def _setup_discretization(self) -> None: """sps.csc_matrix: mass matrix on cells: flat pressures -> flat pressures""" lumping = self.options.get("lumping", True) - self.mass_matrix_faces = darsia.FVMass(self.grid, "faces", lumping).mat + self.mass_matrix_faces = sps.diags(self.face_weights, format="csc") @ ( + darsia.FVMass(self.grid, "faces", lumping).mat + ) """sps.csc_matrix: mass matrix on faces: flat fluxes -> flat fluxes""" L_init = self.options.get("L_init", 1.0) @@ -555,12 +576,13 @@ def _setup_acceleration(self) -> None: # ! ---- Effective quantities ---- def transport_density( - self, flat_flux: np.ndarray, flatten: bool = True + self, flat_flux: np.ndarray, weighted: bool = True, flatten: bool = True ) -> np.ndarray: """Compute the transport density from the solution. Args: flat_flux (np.ndarray): face fluxes + weighted (bool): apply weighting. Defaults to True. flatten (bool): flatten the result. Defaults to True. Returns: @@ -608,7 +630,11 @@ def transport_density( transport_density = np.zeros(self.grid.shape, dtype=float) for quad_pt, quad_weight in zip(quad_pts, quad_weights): cell_flux = darsia.face_to_cell(self.grid, flat_flux, pt=quad_pt) - cell_flux_norm = np.linalg.norm(cell_flux, 2, axis=-1) + if weighted: + weighted_cell_flux = cell_flux * self.cell_weights[..., np.newaxis] + cell_flux_norm = np.linalg.norm(weighted_cell_flux, 2, axis=-1) + else: + cell_flux_norm = np.linalg.norm(cell_flux, 2, axis=-1) transport_density += quad_weight * cell_flux_norm if flatten: @@ -661,8 +687,10 @@ def vector_face_flux_norm(self, flat_flux: np.ndarray, mode: str) -> np.ndarray: else: average_mode = mode.split("_")[2] - # The flux norm is identical to the transport density - cell_flux_norm = self.transport_density(flat_flux, flatten=False) + # The flux norm is identical to the transport density without weights + cell_flux_norm = self.transport_density( + flat_flux, weighted=False, flatten=False + ) # Map to faces via averaging of neighboring cells flat_flux_norm = darsia.cell_to_face_average( @@ -1054,6 +1082,9 @@ def __call__( # Determine transport density transport_density = self.transport_density(flat_flux, flatten=False) + # Cell-weighted flux + weighted_flux = flux * self.cell_weights[..., np.newaxis] + # Return solution return_info = self.options.get("return_info", False) if return_info: @@ -1062,6 +1093,9 @@ def __call__( "grid": self.grid, "mass_diff": mass_diff, "flux": flux, + "weight": self.cell_weights, + "weight_inv": 1.0 / self.cell_weights, + "weighted_flux": weighted_flux, "pressure": pressure, "transport_density": transport_density, "src": img_1, @@ -1111,8 +1145,8 @@ class WassersteinDistanceNewton(VariationalWassersteinDistance): """ - def __init__(self, grid, options) -> None: - super().__init__(grid, options) + def __init__(self, grid, weight, options) -> None: + super().__init__(grid, weight, options) self.L = self.options.get("L", np.finfo(float).max) """float: relaxation/cut-off parameter for mobility, deactivated by default""" @@ -1379,6 +1413,7 @@ class WassersteinDistanceBregman(VariationalWassersteinDistance): def __init__( self, grid: darsia.Grid, + weight: Optional[darsia.Image] = None, options: dict = {}, ) -> None: """Initialize the Bregman method. @@ -1388,7 +1423,7 @@ def __init__( options (dict, optional): options. Defaults to {}. """ - super().__init__(grid, options) + super().__init__(grid, weight, options) self.L = self.options.get("L", 1.0) """Penality parameter for the Bregman iteration, associated to face mobility.""" @@ -1745,6 +1780,7 @@ def wasserstein_distance( mass_1: darsia.Image, mass_2: darsia.Image, method: str, + weight: Optional[darsia.Image] = None, **kwargs, ): """Unified access to Wasserstein distance computation between images with same mass. @@ -1770,12 +1806,13 @@ def wasserstein_distance( # Define method if method.lower() == "newton": - w1 = WassersteinDistanceNewton(grid, options) + w1 = WassersteinDistanceNewton(grid, weight, options) elif method.lower() == "bregman": - w1 = WassersteinDistanceBregman(grid, options) + w1 = WassersteinDistanceBregman(grid, weight, options) elif method.lower() == "cv2.emd": # Use Earth Mover's Distance from CV2 + assert weight is None, "Weighted EMD not supported by cv2." preprocess = kwargs.get("preprocess") w1 = darsia.EMD(preprocess) @@ -1801,6 +1838,16 @@ def wasserstein_distance_to_vtk( """ data = [ (key, info[key]) - for key in ["src", "dst", "mass_diff", "flux", "pressure", "transport_density"] + for key in [ + "src", + "dst", + "mass_diff", + "flux", + "weighted_flux", + "pressure", + "transport_density", + "weight", + "weight_inv", + ] ] darsia.plotting.to_vtk(path, data)