Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: revamped adjoint filters #1625

Draft
wants to merge 16 commits into
base: master
Choose a base branch
from

Conversation

smartalecH
Copy link
Collaborator

@smartalecH smartalecH commented Jun 21, 2021

Updates all the adjoint filters/projections to:

  • use jax rather than autograd.
  • filter/project along any arbitrary dimension (i.e. support for 1D, 2D, and 3D designs)
  • simplifies the interface for various filters/projections
  • various bug fixes

TODO

  • Update adjoint test with new filter function API
  • Update tutorials
  • Allow for different filter radii in each dimension (for nonuniform lengthscale constraints)
  • Resolve decision to filtfilt (shouldn't technically be needed with our even filters)

@mochen4 to pull into an existing dev branch, make sure you rebase on top of master, as described here.

Note: it looks that while rebasing, not all of the formatting changes were correctly committed (hence some of the extraneous diffs). I'll clean those up too.

@mochen4
Copy link
Collaborator

mochen4 commented Jun 22, 2021

I have been trying to use the new code. I don't know if I missed something, but currently conic_filter(x,radius) gives me nan.

Specifically,

xv, yv = np.meshgrid(np.linspace(-Lx / 2, Lx / 2, Nx),
np.linspace(-Ly / 2, Ly / 2, Ny),
sparse=True,
indexing='ij')
now becomes

    xl = npj.linspace(-x.shape[0] / 2, x.shape[0] / 2, x.shape[0])
    yl = npj.linspace(-x.shape[1] / 2, x.shape[1] / 2, x.shape[1])
    zl = npj.linspace(-x.shape[2] / 2, x.shape[2] / 2, x.shape[2])
    X, Y, Z = npj.meshgrid(xl, yl, zl, sparse=True, indexing='ij')

Thus, kernel = np.where(np.abs(X**2 + Y**2 + Z**2) <= radius**2, (1 - np.sqrt(npj.abs(X**2 + Y**2 + Z**2)) / radius), 0)
is now a zero array. Shouldn't we have x.shape[0] / (2*resolution)or Lx/2 or instead of x.shape[0] / 2 etc. in the linspace?

@smartalecH
Copy link
Collaborator Author

Shouldn't we have x.shape[0] / (2*resolution)or Lx/2 or instead of x.shape[0] / 2 etc. in the linspace?

No, the radius is in units of pixels, so there is no need to factor in the resolution.

The input array x needs to be three-dimensional (in your case) and radius needs to be >= 1.

@mochen4
Copy link
Collaborator

mochen4 commented Jun 22, 2021

Thanks! So for get_conic_radius_from_eta_e(b,eta_e), the minimum length b will be in units of pixels instead of meep unit as before? @smartalecH

@smartalecH
Copy link
Collaborator Author

the minimum length b will be in units of pixels instead of meep unit as before

No, that function remains unchanged. So the user must manually multiply the radius in meep units by the corresponding design region resolution and pass the result of that value into the new filtering routines.

The idea behind the new filtering API is to keep things as simple as possible.

@mochen4
Copy link
Collaborator

mochen4 commented Jun 22, 2021

Thanks!

@@ -3,14 +3,15 @@
"""

import numpy as np
from autograd import numpy as npa
import jax
from jax import numpy as npj
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just FYI- convention is import jax.numpy as jnp rather than npj: https://jax.readthedocs.io/en/latest/notebooks/quickstart.html

@stevengj
Copy link
Collaborator

stevengj commented Jun 22, 2021

So the user must manually multiply the radius in meep units by the corresponding design region resolution and pass the result of that value into the new filtering routines.

Can we avoid that? As much as possible, I'd like distances to be specified in "real" units.

I guess the problem with filters is that the same design variables (material grid) can be applied to multiple objects, and so effectively have different resolutions in the same simulation.

@smartalecH
Copy link
Collaborator Author

I'd like distances to be specified in "real" units.

I agree, but it comes at the cost of a more complicated API with lots more parameters to pass (and generally less flexibility).

I guess the problem with filters is that the same design variables (material grid) can be applied to multiple objects, and so effectively have different resolutions in the same simulation.

Exactly. I think it's better for the user to handle things on their side so that any required bookkeeping doesn't restrict any applications.

axis=1)

return out
def atleast_3d(ary):
Copy link
Contributor

@ianwilliamson ianwilliamson Jun 23, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there are already two other versions of a similar function for this (one that I added in utils and one that was already in another module). It might be worth consolidating these into a single function?

In some places it will be used for standard numpy arrays and here you're using it for jax arrays. I assume you'll want to track gradients through this, so I think to support both numpy and jax arrays, the var.reshape(...) approach could be used. This could work with a slight modification to:

def _make_at_least_nd(x: onp.ndarray, dims: int = 3) -> onp.ndarray:
"""Makes an array have at least the specified number of dimensions."""
return onp.reshape(x, x.shape + onp.maximum(dims - x.ndim, 0) * (1, ))

@mochen4
Copy link
Collaborator

mochen4 commented Aug 10, 2021

It seems that the Near2FarFields now fails. Specifically, the adjoint source amplitudes (scale) from

all_nearsrcdata = self._monitor.swigobj.near_sourcedata(
far_pt_vec, farpt_list, self._nfar_pts, dJ)
for near_data in all_nearsrcdata:
cur_comp = near_data.near_fd_comp
amp_arr = np.array(near_data.amp_arr).reshape(-1, self.num_freq)
scale = amp_arr * self._adj_src_scale(include_resolution=False)
becomes negligible (10^-313) in my test case.
However, in optimization_problem.py, if I import jacobian from autograd instead of jax, it would work as usual. The values of the jacobian dJ are the same, and I tried to convert the jax jacobian to a numpy array with jnp.asarray, but it doesn't make a difference.

@smartalecH
Copy link
Collaborator Author

smartalecH commented Aug 10, 2021

However, in optimization_problem.py, if I import jacobian from autograd instead of jax, it would work as usual.

Looks like this is a result of earlier PRs that transitioned the solver to jax. (Note the only actual additions offered by this PR are in filters.py).

What I typically do is use autograd for the main adjoint solver, but use jax for the filter functions used here. It's not ideal, but easier than fiddling with jax.

Although more recently I had to abandon jax altogether to get it to work on multi-node systems.

In an ideal world, we would have a package consisting of jax's autodiff function library, without losing autograd's simplicity.

@mochen4
Copy link
Collaborator

mochen4 commented Aug 10, 2021

Thanks Alec!

@stevengj
Copy link
Collaborator

Any update on this?

@smartalecH
Copy link
Collaborator Author

Resolve decision to filtfilt (shouldn't technically be needed with our even filters)

I bet this issue is due to the off-by-one bug we've been fixing in various PRs (#1769, #1760).

Also, all our filters are separable, which means we can actually just perform a simple 1D filter with the corresponding 1D kernel in each dimension. This not only speeds things up significantly, but should clean up the code quite a bit too (we can keep the full fftfilt routine on the off chance we discover a non-separable filter that's useful).

@smartalecH smartalecH marked this pull request as draft October 1, 2021 01:24
This was referenced Mar 21, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants