diff --git a/python/adjoint/filters.py b/python/adjoint/filters.py index c62e3df3a..197e0b02f 100644 --- a/python/adjoint/filters.py +++ b/python/adjoint/filters.py @@ -706,156 +706,6 @@ def tanh_projection(x: np.ndarray, beta: float, eta: float) -> np.ndarray: ) -def smoothed_projection( - x_smoothed: ArrayLikeType, - beta: float, - eta: float, - resolution: float, -): - """Project using subpixel smoothing, which allows for β→∞. - - This technique integrates out the discontinuity within the projection - function, allowing the user to smoothly increase β from 0 to ∞ without - losing the gradient. Effectively, a level set is created, and from this - level set, first-order subpixel smoothing is applied to the interfaces (if - any are present). - - In order for this to work, the input array must already be smooth (e.g. by - filtering). - - While the original approach involves numerical quadrature, this approach - performs a "trick" by assuming that the user is always infinitely projecting - (β=∞). In this case, the expensive quadrature simplifies to an analytic - fill-factor expression. When to use this fill factor requires some careful - logic. - - For one, we want to make sure that the user can indeed project at any level - (not just infinity). So in these cases, we simply check if in interface is - within the pixel. If not, we revert to the standard filter plus project - technique. - - If there is an interface, we want to make sure the derivative remains - continuous both as the interface leaves the cell, *and* as it crosses the - center. To ensure this, we need to account for the different possibilities. - - Args: - x: The (2D) input design parameters. - beta: The thresholding parameter in the range [0, inf]. Determines the - degree of binarization of the output. - eta: The threshold point in the range [0, 1]. - resolution: resolution of the design grid (not the Meep grid - resolution). - Returns: - The projected and smoothed output. - - Example: - >>> Lx = 2; Ly = 2 - >>> resolution = 50 - >>> eta_i = 0.5; eta_e = 0.75 - >>> lengthscale = 0.1 - >>> filter_radius = get_conic_radius_from_eta_e(lengthscale, eta_e) - >>> Nx = onp.round(Lx * resolution) + 1 - >>> Ny = onp.round(Ly * resolution) + 1 - >>> A = onp.random.rand(Nx, Ny) - >>> beta = npa.inf - >>> A_smoothed = conic_filter(A, filter_radius, Lx, Ly, resolution) - >>> A_projected = smoothed_projection(A_smoothed, beta, eta_i, resolution) - """ - # Note that currently, the underlying assumption is that the smoothing - # kernel is a circle, which means dx = dy. - dx = dy = 1 / resolution - pixel_radius = dx / 2 - - x_projected = tanh_projection(x_smoothed, beta=beta, eta=eta) - - # Compute the spatial gradient (using finite differences) of the *filtered* - # field, which will always be smooth and is the key to our approach. This - # gradient essentially represents the normal direction pointing the the - # nearest inteface. - x_grad = npa.gradient(x_smoothed) - x_grad_helper = (x_grad[0] / dx) ** 2 + (x_grad[1] / dy) ** 2 - - # Note that a uniform field (norm=0) is problematic, because it creates - # divide by zero issues and makes backpropagation difficult, so we sanitize - # and determine where smoothing is actually needed. The value where we don't - # need smoothings doesn't actually matter, since all our computations our - # purely element-wise (no spatial locality) and those pixels will instead - # rely on the standard projection. - nonzero_norm = npa.abs(x_grad_helper) > sys.float_info.epsilon - - x_grad_norm = npa.sqrt(npa.where(nonzero_norm, x_grad_helper, 1)) - x_grad_norm_eff = npa.where(nonzero_norm, x_grad_norm, 1) - - # The distance for the center of the pixel to the nearest interface - d = (eta - x_smoothed) / x_grad_norm_eff - - # Only need smoothing if an interface lies within the voxel. Since d is - # actually an "effective" d by this point, we need to ignore values that may - # have been sanitized earlier on. - needs_smoothing = nonzero_norm & (npa.abs(d) <= pixel_radius) - d = npa.where(needs_smoothing, d, 1) - - # The fill factor is used to perform simple, first-order subpixel smoothing. - # We use the (2D) analytic expression that comes when assuming the smoothing - # kernel is a circle. Note that because the kernel contains some - # expressions that are sensitive to NaNs, we have to use the "double where" - # trick to avoid the Nans in the backward trace. This is a common problem - # with array-based AD tracers, apparently. See here: - # https://github.com/google/jax/issues/1052#issuecomment-5140833520 - - arccos_term = pixel_radius**2 * npa.arccos( - npa.where( - needs_smoothing, - d / pixel_radius, - 0.0, - ) - ) - - sqrt_term = d * npa.sqrt( - npa.where( - needs_smoothing, - pixel_radius**2 - d**2, - 1, - ) - ) - fill_factor = (1 / npa.pi) * (arccos_term - sqrt_term) - fill_factor_eff = npa.where( - needs_smoothing, - fill_factor, - 1, - ) - - # Determine the upper and lower bounds of materials in the current pixel. - x_minus = x_smoothed - x_grad_norm_eff * pixel_radius - x_plus = x_smoothed + x_grad_norm_eff * pixel_radius - - # Create an "effective" set of materials that will ensure everything is - # piecewise differentiable, even if an interface moves out of a pixel, or - # through the pixel center. - x_minus_eff_pert = (x_smoothed * d + x_minus * (pixel_radius - d)) / pixel_radius - x_minus_eff = npa.where( - (d > 0), - x_minus_eff_pert, - x_minus, - ) - x_plus_eff_pert = (-x_smoothed * d + x_plus * (pixel_radius + d)) / pixel_radius - x_plus_eff = npa.where( - (d > 0), - x_plus, - x_plus_eff_pert, - ) - - # Only apply smoothing to interfaces - x_projected_smoothed = (1 - fill_factor_eff) * x_minus_eff + ( - fill_factor_eff - ) * x_plus_eff - - return npa.where( - needs_smoothing, - x_projected_smoothed, - x_projected, - ) - def smoothed_projection( x_smoothed: ArrayLikeType, @@ -975,7 +825,7 @@ def smoothed_projection( 1, ) - # Determine the upper and lower bounds of materials in the current pixel. + # Determine the upper and lower bounds of materials in the current pixel (before projection). x_minus = x_smoothed - x_grad_norm * pixel_radius x_plus = x_smoothed + x_grad_norm * pixel_radius @@ -995,8 +845,12 @@ def smoothed_projection( x_plus_eff_pert, ) + # Finally, we project the extents of our range. + x_plus_eff_projected = tanh_projection(x_plus_eff, beta=beta, eta=eta) + x_minus_eff_projected = tanh_projection(x_minus_eff, beta=beta, eta=eta) + # Only apply smoothing to interfaces - x_projected_smoothed = (1 - fill_factor) * x_minus_eff + (fill_factor) * x_plus_eff + x_projected_smoothed = (1 - fill_factor) * x_minus_eff_projected + (fill_factor) * x_plus_eff_projected return npa.where( needs_smoothing, x_projected_smoothed,