Skip to content

Commit

Permalink
ENH: Extend to weighted Wasserstein-1 distance.
Browse files Browse the repository at this point in the history
  • Loading branch information
jwboth committed Aug 17, 2024
1 parent 9638a72 commit 930f215
Showing 1 changed file with 58 additions and 11 deletions.
69 changes: 58 additions & 11 deletions src/darsia/measure/wasserstein.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class VariationalWassersteinDistance(darsia.EMD):
def __init__(
self,
grid: darsia.Grid,
weight: Optional[darsia.Image] = None,
options: dict = {},
) -> None:
"""Initialization of the variational Wasserstein distance.
Expand Down Expand Up @@ -131,8 +132,12 @@ def __init__(
self.mobility_mode = self.options.get("mobility_mode", "cell_based")
"""str: mode for computing the mobility"""

self.weight = weight
"""Weight defined on cells"""

# Setup of method
self._setup_dof_management()
self._setup_face_weights()
self._setup_discretization()
self._setup_linear_solver()
self._setup_acceleration()
Expand Down Expand Up @@ -193,6 +198,20 @@ def _setup_dof_management(self) -> None:
)
"""sps.csc_matrix: embedding operator for fluxes"""

def _setup_face_weights(self) -> None:
"""Convert cell weights to face weights by harmonic averaging."""

if self.weight is None:
self.cell_weights = np.ones(self.grid.shape, dtype=float)
"""np.ndarray: cell weights"""
self.face_weights = np.ones(self.grid.num_faces, dtype=float)
"""np.ndarray: face weights"""
else:
self.cell_weights = self.weight.img
self.face_weights = darsia.cell_to_face_average(
self.grid, self.cell_weights, mode="harmonic"
)

def _setup_discretization(self) -> None:
"""Setup of fixed discretization operators."""

Expand Down Expand Up @@ -224,7 +243,9 @@ def _setup_discretization(self) -> None:
"""sps.csc_matrix: mass matrix on cells: flat pressures -> flat pressures"""

lumping = self.options.get("lumping", True)
self.mass_matrix_faces = darsia.FVMass(self.grid, "faces", lumping).mat
self.mass_matrix_faces = sps.diags(self.face_weights, format="csc") @ (
darsia.FVMass(self.grid, "faces", lumping).mat
)
"""sps.csc_matrix: mass matrix on faces: flat fluxes -> flat fluxes"""

L_init = self.options.get("L_init", 1.0)
Expand Down Expand Up @@ -555,12 +576,13 @@ def _setup_acceleration(self) -> None:
# ! ---- Effective quantities ----

def transport_density(
self, flat_flux: np.ndarray, flatten: bool = True
self, flat_flux: np.ndarray, weighted: bool = True, flatten: bool = True
) -> np.ndarray:
"""Compute the transport density from the solution.
Args:
flat_flux (np.ndarray): face fluxes
weighted (bool): apply weighting. Defaults to True.
flatten (bool): flatten the result. Defaults to True.
Returns:
Expand Down Expand Up @@ -608,7 +630,11 @@ def transport_density(
transport_density = np.zeros(self.grid.shape, dtype=float)
for quad_pt, quad_weight in zip(quad_pts, quad_weights):
cell_flux = darsia.face_to_cell(self.grid, flat_flux, pt=quad_pt)
cell_flux_norm = np.linalg.norm(cell_flux, 2, axis=-1)
if weighted:
weighted_cell_flux = cell_flux * self.cell_weights[..., np.newaxis]
cell_flux_norm = np.linalg.norm(weighted_cell_flux, 2, axis=-1)
else:
cell_flux_norm = np.linalg.norm(cell_flux, 2, axis=-1)
transport_density += quad_weight * cell_flux_norm

if flatten:
Expand Down Expand Up @@ -661,8 +687,10 @@ def vector_face_flux_norm(self, flat_flux: np.ndarray, mode: str) -> np.ndarray:
else:
average_mode = mode.split("_")[2]

# The flux norm is identical to the transport density
cell_flux_norm = self.transport_density(flat_flux, flatten=False)
# The flux norm is identical to the transport density without weights
cell_flux_norm = self.transport_density(
flat_flux, weighted=False, flatten=False
)

# Map to faces via averaging of neighboring cells
flat_flux_norm = darsia.cell_to_face_average(
Expand Down Expand Up @@ -1054,6 +1082,9 @@ def __call__(
# Determine transport density
transport_density = self.transport_density(flat_flux, flatten=False)

# Cell-weighted flux
weighted_flux = flux * self.cell_weights[..., np.newaxis]

# Return solution
return_info = self.options.get("return_info", False)
if return_info:
Expand All @@ -1062,6 +1093,9 @@ def __call__(
"grid": self.grid,
"mass_diff": mass_diff,
"flux": flux,
"weight": self.cell_weights,
"weight_inv": 1.0 / self.cell_weights,
"weighted_flux": weighted_flux,
"pressure": pressure,
"transport_density": transport_density,
"src": img_1,
Expand Down Expand Up @@ -1111,8 +1145,8 @@ class WassersteinDistanceNewton(VariationalWassersteinDistance):
"""

def __init__(self, grid, options) -> None:
super().__init__(grid, options)
def __init__(self, grid, weight, options) -> None:
super().__init__(grid, weight, options)

self.L = self.options.get("L", np.finfo(float).max)
"""float: relaxation/cut-off parameter for mobility, deactivated by default"""
Expand Down Expand Up @@ -1379,6 +1413,7 @@ class WassersteinDistanceBregman(VariationalWassersteinDistance):
def __init__(
self,
grid: darsia.Grid,
weight: Optional[darsia.Image] = None,
options: dict = {},
) -> None:
"""Initialize the Bregman method.
Expand All @@ -1388,7 +1423,7 @@ def __init__(
options (dict, optional): options. Defaults to {}.
"""
super().__init__(grid, options)
super().__init__(grid, weight, options)
self.L = self.options.get("L", 1.0)
"""Penality parameter for the Bregman iteration, associated to face mobility."""

Expand Down Expand Up @@ -1745,6 +1780,7 @@ def wasserstein_distance(
mass_1: darsia.Image,
mass_2: darsia.Image,
method: str,
weight: Optional[darsia.Image] = None,
**kwargs,
):
"""Unified access to Wasserstein distance computation between images with same mass.
Expand All @@ -1770,12 +1806,13 @@ def wasserstein_distance(

# Define method
if method.lower() == "newton":
w1 = WassersteinDistanceNewton(grid, options)
w1 = WassersteinDistanceNewton(grid, weight, options)
elif method.lower() == "bregman":
w1 = WassersteinDistanceBregman(grid, options)
w1 = WassersteinDistanceBregman(grid, weight, options)

elif method.lower() == "cv2.emd":
# Use Earth Mover's Distance from CV2
assert weight is None, "Weighted EMD not supported by cv2."
preprocess = kwargs.get("preprocess")
w1 = darsia.EMD(preprocess)

Expand All @@ -1801,6 +1838,16 @@ def wasserstein_distance_to_vtk(
"""
data = [
(key, info[key])
for key in ["src", "dst", "mass_diff", "flux", "pressure", "transport_density"]
for key in [
"src",
"dst",
"mass_diff",
"flux",
"weighted_flux",
"pressure",
"transport_density",
"weight",
"weight_inv",
]
]
darsia.plotting.to_vtk(path, data)

0 comments on commit 930f215

Please sign in to comment.