def smoothed_projection(
    x_smoothed: ArrayLikeType,
    beta: float,
    eta: float,
    resolution: float,
):
    """Project using subpixel smoothing, which allows for β→∞.

    This technique integrates out the discontinuity within the projection
    function, allowing the user to smoothly increase β from 0 to ∞ without
    losing the gradient. Effectively, a level set is created, and from this
    level set, first-order subpixel smoothing is applied to the interfaces (if
    any are present).

    In order for this to work, the input array must already be smooth (e.g. by
    filtering).

    While the original approach involves numerical quadrature, this approach
    performs a "trick" by assuming that the user is always infinitely projecting
    (β=∞). In this case, the expensive quadrature simplifies to an analytic
    fill-factor expression. When to use this fill factor requires some careful
    logic.

    For one, we want to make sure that the user can indeed project at any level
    (not just infinity). So in these cases, we simply check if an interface is
    within the pixel. If not, we revert to the standard filter plus project
    technique.

    If there is an interface, we want to make sure the derivative remains
    continuous both as the interface leaves the cell, *and* as it crosses the
    center. To ensure this, we need to account for the different possibilities.

    Args:
        x_smoothed: The (2D) input design parameters. Must already be smooth
            (e.g. the output of a filter); only the first two gradient
            components are used.
        beta: The thresholding parameter in the range [0, inf]. Determines the
            degree of binarization of the output.
        eta: The threshold point in the range [0, 1].
        resolution: resolution of the design grid (not the Meep grid
            resolution).
    Returns:
        The projected and smoothed output: the subpixel-smoothed value for
        pixels that contain an interface, and the plain `tanh_projection`
        value everywhere else.

    Example:
        >>> Lx = 2; Ly = 2
        >>> resolution = 50
        >>> eta_i = 0.5; eta_e = 0.75
        >>> lengthscale = 0.1
        >>> filter_radius = get_conic_radius_from_eta_e(lengthscale, eta_e)
        >>> Nx = onp.round(Lx * resolution) + 1
        >>> Ny = onp.round(Ly * resolution) + 1
        >>> A = onp.random.rand(Nx, Ny)
        >>> beta = npa.inf
        >>> A_smoothed = conic_filter(A, filter_radius, Lx, Ly, resolution)
        >>> A_projected = smoothed_projection(A_smoothed, beta, eta_i, resolution)
    """
    # Note that currently, the underlying assumption is that the smoothing
    # kernel is a circle, which means dx = dy.
    dx = dy = 1 / resolution
    pixel_radius = dx / 2

    # Standard projection; used as-is for every pixel without an interface.
    x_projected = tanh_projection(x_smoothed, beta=beta, eta=eta)

    # Compute the spatial gradient (using finite differences) of the *filtered*
    # field, which will always be smooth and is the key to our approach. This
    # gradient essentially represents the normal direction pointing to the
    # nearest interface.
    x_grad = npa.gradient(x_smoothed)
    x_grad_helper = (x_grad[0] / dx) ** 2 + (x_grad[1] / dy) ** 2

    # Note that a uniform field (norm=0) is problematic, because it creates
    # divide by zero issues and makes backpropagation difficult, so we sanitize
    # and determine where smoothing is actually needed. The value where we don't
    # need smoothing doesn't actually matter, since all our computations are
    # purely element-wise (no spatial locality) and those pixels will instead
    # rely on the standard projection. So just use 1, since it's well behaved.
    nonzero_norm = npa.abs(x_grad_helper) > 0

    # Sanitizing *inside* the sqrt already yields exactly 1 wherever the norm
    # is zero, so no second `where` around the result is required.
    x_grad_norm = npa.sqrt(npa.where(nonzero_norm, x_grad_helper, 1))

    # The distance from the center of the pixel to the nearest interface.
    d = (eta - x_smoothed) / x_grad_norm

    # Only need smoothing if an interface lies within the voxel. Since d is
    # actually an "effective" d by this point, we need to ignore values that may
    # have been sanitized earlier on.
    #
    # The comparison is deliberately strict: at |d| == pixel_radius the
    # fill-factor expression evaluates arccos(±1) and sqrt(0), both of which
    # have unbounded derivatives and would poison the backward pass with
    # inf/NaN. Pixels sitting exactly on that boundary fall back to the
    # standard projection instead.
    needs_smoothing = nonzero_norm & (npa.abs(d) < pixel_radius)

    # The fill factor is used to perform simple, first-order subpixel smoothing.
    # We use the (2D) analytic expression that comes when assuming the smoothing
    # kernel is a circle. Note that because the kernel contains some
    # expressions that are sensitive to NaNs, we have to use the "double where"
    # trick to avoid the NaNs in the backward trace. This is a common problem
    # with array-based AD tracers, apparently. See here:
    # https://github.com/google/jax/issues/1052#issuecomment-5140833520

    # Inner `where`: feed arccos a harmless argument (0) outside the mask.
    arccos_term = pixel_radius**2 * npa.arccos(
        npa.where(
            needs_smoothing,
            d / pixel_radius,
            0.0,
        )
    )

    # Inner `where`: feed sqrt a harmless argument (1) outside the mask.
    sqrt_term = d * npa.sqrt(
        npa.where(
            needs_smoothing,
            pixel_radius**2 - d**2,
            1,
        )
    )

    # Outer `where`: discard the placeholder results outside the mask.
    fill_factor = npa.where(
        needs_smoothing,
        (1 / (npa.pi * pixel_radius**2)) * (arccos_term - sqrt_term),
        1,
    )

    # Determine the upper and lower bounds of materials in the current pixel.
    x_minus = x_smoothed - x_grad_norm * pixel_radius
    x_plus = x_smoothed + x_grad_norm * pixel_radius

    # Create an "effective" set of materials that will ensure everything is
    # piecewise differentiable, even if an interface moves out of a pixel, or
    # through the pixel center.
    x_minus_eff_pert = (x_smoothed * d + x_minus * (pixel_radius - d)) / pixel_radius
    x_minus_eff = npa.where(
        (d > 0),
        x_minus_eff_pert,
        x_minus,
    )
    x_plus_eff_pert = (-x_smoothed * d + x_plus * (pixel_radius + d)) / pixel_radius
    x_plus_eff = npa.where(
        (d > 0),
        x_plus,
        x_plus_eff_pert,
    )

    # Only apply smoothing to interfaces
    x_projected_smoothed = (1 - fill_factor) * x_minus_eff + (fill_factor) * x_plus_eff
    return npa.where(
        needs_smoothing,
        x_projected_smoothed,
        x_projected,
    )