Skip to content

Commit

Permalink
fix backprop and norm
Browse files Browse the repository at this point in the history
  • Loading branch information
Alec Hammond committed Jan 12, 2024
1 parent 16702c5 commit 7fb6c6b
Showing 1 changed file with 147 additions and 146 deletions.
293 changes: 147 additions & 146 deletions python/adjoint/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,153 @@ def tanh_projection(x: np.ndarray, beta: float, eta: float) -> np.ndarray:
)


def smoothed_projection(
    x_smoothed: ArrayLikeType,
    beta: float,
    eta: float,
    resolution: float,
):
    """Project an already-smoothed field using first-order subpixel smoothing.

    The discontinuity of the projection function is integrated out, so β can
    be raised smoothly from 0 to ∞ without losing the gradient. In effect the
    smoothed field is treated as a level set, and pixels that contain an
    interface are replaced by a fill-factor-weighted blend of the material
    bounds inside the pixel.

    Instead of numerical quadrature, the fill factor is evaluated analytically
    by assuming an infinitely sharp projection (β=∞) and a circular smoothing
    kernel. Pixels without an interface fall back to the standard
    `tanh_projection`, so any value of β remains usable. The "effective"
    bounds are built so that the result stays piecewise differentiable both
    when an interface leaves a pixel and when it crosses the pixel center.

    Args:
        x_smoothed: The (2D) input field. It must already be smooth (e.g. the
            output of a filter).
        beta: The thresholding parameter in the range [0, inf]. Determines the
            degree of binarization of the output.
        eta: The threshold point in the range [0, 1].
        resolution: Resolution of the design grid (not the Meep grid
            resolution).

    Returns:
        The projected and smoothed output.

    Example:
        >>> Lx = 2; Ly = 2
        >>> resolution = 50
        >>> eta_i = 0.5; eta_e = 0.75
        >>> lengthscale = 0.1
        >>> filter_radius = get_conic_radius_from_eta_e(lengthscale, eta_e)
        >>> Nx = int(onp.round(Lx * resolution)) + 1
        >>> Ny = int(onp.round(Ly * resolution)) + 1
        >>> A = onp.random.rand(Nx, Ny)
        >>> beta = npa.inf
        >>> A_smoothed = conic_filter(A, filter_radius, Lx, Ly, resolution)
        >>> A_projected = smoothed_projection(A_smoothed, beta, eta_i, resolution)
    """
    # The analytic fill factor assumes a circular kernel, hence dx == dy.
    dx = dy = 1 / resolution
    pixel_radius = dx / 2

    # Standard projection, used wherever no interface crosses the pixel.
    x_projected = tanh_projection(x_smoothed, beta=beta, eta=eta)

    # Finite-difference gradient of the *smoothed* field. Its direction is the
    # normal towards the nearest interface; only its magnitude is needed here.
    field_grad = npa.gradient(x_smoothed)
    grad_sq = (field_grad[0] / dx) ** 2 + (field_grad[1] / dy) ** 2

    # A locally uniform field has zero gradient norm, which would cause a
    # division by zero and break backpropagation. Substitute 1 there: every
    # operation below is element-wise, and those pixels end up using the plain
    # projection anyway, so the placeholder value never reaches the output.
    has_gradient = npa.abs(grad_sq) > 0
    x_grad_norm = npa.sqrt(npa.where(has_gradient, grad_sq, 1))
    safe_grad_norm = npa.where(has_gradient, x_grad_norm, 1)

    # Distance from the pixel center to the nearest interface.
    d = (eta - x_smoothed) / safe_grad_norm

    # Smoothing is required only when an interface lies inside the pixel.
    # Sanitized pixels carry a meaningless d and are therefore excluded.
    needs_smoothing = has_gradient & (npa.abs(d) <= pixel_radius)

    # Analytic (2D, circular kernel) fill factor. arccos() and sqrt() are only
    # fed valid arguments ("double where" trick) so that the pixels which are
    # masked out do not inject NaNs into the backward trace; see
    # https://github.com/google/jax/issues/1052 for background.
    arccos_arg = npa.where(needs_smoothing, d / pixel_radius, 0.0)
    sqrt_arg = npa.where(needs_smoothing, pixel_radius**2 - d**2, 1)
    arccos_term = pixel_radius**2 * npa.arccos(arccos_arg)
    sqrt_term = d * npa.sqrt(sqrt_arg)
    fill_factor = npa.where(
        needs_smoothing,
        (1 / (npa.pi * pixel_radius**2)) * (arccos_term - sqrt_term),
        1,
    )

    # Lower and upper bounds of the material values within the pixel.
    x_minus = x_smoothed - x_grad_norm * pixel_radius
    x_plus = x_smoothed + x_grad_norm * pixel_radius

    # "Effective" bounds: depending on which side of the pixel center the
    # interface sits, one bound is perturbed so that the blend stays piecewise
    # differentiable as the interface moves through or out of the pixel.
    interface_on_positive_side = d > 0
    x_minus_eff_pert = (x_smoothed * d + x_minus * (pixel_radius - d)) / pixel_radius
    x_plus_eff_pert = (-x_smoothed * d + x_plus * (pixel_radius + d)) / pixel_radius
    x_minus_eff = npa.where(interface_on_positive_side, x_minus_eff_pert, x_minus)
    x_plus_eff = npa.where(interface_on_positive_side, x_plus, x_plus_eff_pert)

    # Blend the bounds by the fill factor, but only on interface pixels.
    x_projected_smoothed = (1 - fill_factor) * x_minus_eff + (fill_factor) * x_plus_eff
    return npa.where(needs_smoothing, x_projected_smoothed, x_projected)


def heaviside_projection(x, beta, eta):
"""Projection filter that thresholds the input parameters between 0 and 1.
Expand Down Expand Up @@ -1099,149 +1246,3 @@ def gray_indicator(x):
density-based topology optimization. Archive of Applied Mechanics, 86(1-2), 189-218.
"""
return npa.mean(4 * x.flatten() * (1 - x.flatten())) * 100


def hybrid_levelset(
    x: ArrayLikeType,
    beta: float,
    eta: float,
    Lx: float,
    Ly: float,
    resolution: float,
    filter_radius: float,
    periodic_axes: ArrayLikeType = None,
):
    """Filter and project using subpixel smoothing, which allows for β→∞.

    This technique integrates out the discontinuity within the projection
    function, allowing the user to smoothly increase β from 0 to ∞ without
    losing the gradient. Effectively, a level set is created, and from this
    level set, first-order subpixel smoothing is applied to the interfaces (if
    any are present).

    While the original approach involves numerical quadrature, this approach
    performs a "trick" by assuming that the user is always infinitely
    projecting (β=∞). In this case, the expensive quadrature simplifies to an
    analytic fill-factor expression. When to use this fill factor requires
    some careful logic.

    For one, we want to make sure that the user can indeed project at any
    level (not just infinity). So in these cases, we simply check if an
    interface is within the pixel. If not, we revert to the standard filter
    plus project technique. Pixels where the filtered field is locally uniform
    (zero gradient norm) cannot contain a well-defined interface and are
    therefore always treated this way.

    If there is an interface, we want to make sure the derivative remains
    continuous both as the interface leaves the cell, *and* as it crosses the
    center. To ensure this, we need to account for the different
    possibilities.

    Args:
        x: The (2D) input design parameters.
        beta: The thresholding parameter in the range [0, inf]. Determines the
            degree of binarization of the output.
        eta: The threshold point in the range [0, 1].
        Lx: The length of design region in X direction (in Meep units).
        Ly: The length of design region in Y direction (in Meep units).
        resolution: The resolution of the design grid (not the Meep grid
            resolution).
        filter_radius: The filter radius (in Meep units).
        periodic_axes: The list of axes (x, y = 0, 1) that are to be treated as
            periodic. Default is None (all axes are non-periodic).

    Returns:
        The projected and smoothed output.

    Example:
        >>> Lx = 2
        >>> Ly = 2
        >>> resolution = 50
        >>> eta_i = 0.5
        >>> eta_e = 0.75
        >>> lengthscale = 0.1
        >>> filter_radius = get_conic_radius_from_eta_e(lengthscale, eta_e)
        >>> Nx = int(onp.round(Lx * resolution)) + 1
        >>> Ny = int(onp.round(Ly * resolution)) + 1
        >>> A = onp.random.rand(Nx, Ny)
        >>> beta = npa.inf
        >>> A_smoothed = hybrid_levelset(A, beta, eta_i, Lx, Ly, resolution, filter_radius)
    """
    # Note that currently, the underlying assumption is that the smoothing
    # kernel is a circle, which means dx = dy.
    dx = dy = 1 / resolution
    pixel_radius = dx / 2

    # Perform the standard filter and projection common in density-based TO.
    x_filtered = conic_filter(
        x,
        radius=filter_radius,
        Lx=Lx,
        Ly=Ly,
        resolution=resolution,
        periodic_axes=periodic_axes,
    )
    x_projected = npa.clip(tanh_projection(x_filtered, beta=beta, eta=eta), 0, 1)

    # Compute the spatial gradient (using finite differences) of the *filtered*
    # field, which will always be smooth and is the key to our approach. This
    # gradient essentially represents the normal direction pointing to the
    # nearest interface.
    x_grad = npa.gradient(x_filtered)
    x_grad_norm_sq = (x_grad[0] / dx) ** 2 + (x_grad[1] / dy) ** 2

    # A locally uniform field has zero gradient norm. Dividing by it would
    # produce inf/NaN distances (and NaN outputs when x_filtered == eta), and
    # sqrt(0) has an unbounded derivative, which poisons backpropagation. So
    # we sanitize the norm. The placeholder value (1) is irrelevant: all
    # computations are element-wise and those pixels use the plain projection.
    nonzero_norm = x_grad_norm_sq > 0
    x_grad_norm = npa.sqrt(npa.where(nonzero_norm, x_grad_norm_sq, 1))

    # The distance from the center of the pixel to the nearest interface.
    d = (eta - x_filtered) / x_grad_norm

    # Only need smoothing if an interface lies within the pixel. Sanitized
    # pixels carry a meaningless d, so they are explicitly excluded.
    needs_smoothing = nonzero_norm & (npa.abs(d) <= pixel_radius)

    # The fill factor is used to perform simple, first-order subpixel
    # smoothing. We use the (2D) analytic expression that comes when assuming
    # the smoothing kernel is a circle. Because arccos() and sqrt() are
    # sensitive to out-of-domain arguments, we use the "double where" trick so
    # that masked-out pixels never see an invalid argument (or an unbounded
    # derivative) in the backward trace. See
    # https://github.com/google/jax/issues/1052 for background.
    arccos_arg = npa.where(
        npa.abs(d / pixel_radius) < 1,
        d / pixel_radius,
        0.0,
    )
    radicand = pixel_radius**2 - d**2
    positive_radicand = radicand > 0
    # Outside the disk the term is 0 (as d * sqrt(0) would be), but computed
    # through sqrt(1) so that the derivative stays finite.
    sqrt_term = npa.where(
        positive_radicand,
        d * npa.sqrt(npa.where(positive_radicand, radicand, 1)),
        0.0,
    )
    fill_factor = npa.where(
        needs_smoothing,
        (1 / (npa.pi * pixel_radius**2))
        * (pixel_radius**2 * npa.arccos(arccos_arg) - sqrt_term),
        1,
    )

    # Determine the upper and lower bounds of materials in the current pixel.
    x_minus = x_projected - x_grad_norm * pixel_radius
    x_plus = x_projected + x_grad_norm * pixel_radius

    # Create an "effective" set of materials that will ensure everything is
    # piecewise differentiable, even if an interface moves out of a pixel, or
    # through the pixel center.
    x_minus_eff = npa.where(
        d > 0,
        (x_filtered * d + x_minus * (pixel_radius - d)) / pixel_radius,
        x_minus,
    )
    x_plus_eff = npa.where(
        d > 0,
        x_plus,
        (-x_filtered * d + x_plus * (pixel_radius + d)) / pixel_radius,
    )

    # Only apply smoothing to interfaces.
    return npa.where(
        needs_smoothing,
        (1 - fill_factor) * x_minus_eff + (fill_factor) * x_plus_eff,
        x_projected,
    )

0 comments on commit 7fb6c6b

Please sign in to comment.