MeepJaxWrapper raises AxisError when computing the gradient of an objective that depends on a MeepJaxWrapper object defined with a single frequency #2246

Open
rafael-fuente opened this issue Sep 24, 2022 · 5 comments

rafael-fuente commented Sep 24, 2022

The JAX wrapper for Meep simulations, MeepJaxWrapper, works fine when its frequencies argument contains more than one frequency. It fails when the list contains a single frequency, e.g. frequencies = [fcen].

Here is an example that reproduces the error, adapted and simplified from one of the Meep adjoint tutorials:

import meep as mp
import meep.adjoint as mpa
import numpy as np
import jax.numpy as jnp
from jax import grad

seed = 240
np.random.seed(seed)
Si = mp.Medium(index=3.4)
SiO2 = mp.Medium(index=1.44)
resolution = 50
Sx = 6
Sy = 5
cell_size = mp.Vector3(Sx,Sy)
pml_layers = [mp.PML(1.0)]


fcen = 1/1.55
width = 0.2
fwidth = width * fcen
source_center  = [-1,0,0]
source_size    = mp.Vector3(0,2,0)
kpoint = mp.Vector3(1,0,0)
src = mp.GaussianSource(frequency=fcen,fwidth=fwidth)
source = [mp.EigenModeSource(src,
                    eig_band = 1,
                    direction=mp.NO_DIRECTION,
                    eig_kpoint=kpoint,
                    size = source_size,
                    center=source_center)]


design_region_resolution = 10
Nx = design_region_resolution
Ny = design_region_resolution

design_variables = mp.MaterialGrid(mp.Vector3(Nx,Ny),SiO2,Si,grid_type='U_MEAN')
design_region = mpa.DesignRegion(design_variables,volume=mp.Volume(center=mp.Vector3(), size=mp.Vector3(1, 1, 0)))


geometry = [
    mp.Block(center=mp.Vector3(x=-Sx/4), material=Si, size=mp.Vector3(Sx/2, 0.5, 0)), # horizontal waveguide
    mp.Block(center=mp.Vector3(y=Sy/4), material=Si, size=mp.Vector3(0.5, Sy/2, 0)),  # vertical waveguide
    mp.Block(center=design_region.center, size=design_region.size, material=design_variables), # design region
    mp.Block(center=design_region.center, size=design_region.size, material=design_variables,
             e1=mp.Vector3(x=-1).rotate(mp.Vector3(z=1), np.pi/2), e2=mp.Vector3(y=1).rotate(mp.Vector3(z=1), np.pi/2))
]


x0 = np.random.rand(Nx*Ny)
x = jnp.array(x0.reshape([Nx,Ny]))

sim = mp.Simulation(cell_size=cell_size,
                    boundary_layers=pml_layers,
                    geometry=geometry,
                    sources=source,
                    eps_averaging=False,
                    resolution=resolution)

TE0 = mpa.EigenmodeCoefficient(sim,mp.Volume(center=mp.Vector3(0,1,0),size=mp.Vector3(x=2)),mode=1)
monitor_list = [TE0]


wrapped_meep = mpa.MeepJaxWrapper(
    simulation = sim,
    sources = source,
    monitors = monitor_list,
    design_regions =[design_region] ,
    frequencies = [fcen],
    dft_threshold = 1e-6,
    minimum_run_time = 0,
    maximum_run_time = np.inf,
    until_after_sources = True
)


def loss(x):
    monitor_values = wrapped_meep([x])
    return (jnp.abs(monitor_values[0,0])**2)

grad_loss = grad(loss)(x)

The script raises AxisError: axis 1 is out of bounds for array of dimension 1 when grad_loss = grad(loss)(x) is called.
The error does not occur if the frequencies list contains more than one frequency, e.g. frequencies = [fcen, 0.5*fcen].

@smartalecH
Collaborator

cc @ianwilliamson

@ianwilliamson
Contributor

Can you please provide the full stack trace?

@rafael-fuente
Author

> Can you please provide the full stack trace?

---------------------------------------------------------------------------
JaxStackTraceBeforeTransformation         Traceback (most recent call last)
File ~/anaconda3/envs/mp124/lib/python3.10/runpy.py:191, in _run_module_as_main(***failed resolving arguments***)
    190 except _Error as exc:
--> 191     msg = "%s: %s" % (sys.executable, exc)
    192     sys.exit(msg)

File ~/anaconda3/envs/mp124/lib/python3.10/runpy.py:75, in _run_code(***failed resolving arguments***)
     74 loader = mod_spec.loader
---> 75 fname = mod_spec.origin
     76 cached = mod_spec.cached

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/ipykernel_launcher.py:12, in <module>
      9 if __name__ == "__main__":
     10     # Remove the CWD from sys.path while we load stuff.
     11     # This is added back by InteractiveShellApp.init_path()
---> 12     if sys.path[0] == "":
     13         del sys.path[0]

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/traitlets/config/application.py:974, in Application.launch_instance(***failed resolving arguments***)
    970 """Launch a global instance of this Application
    971 
    972 If a global instance already exists, this reinitializes and starts it
    973 """
--> 974 app = cls.instance(**kwargs)
    975 app.initialize(argv)

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/ipykernel/kernelapp.py:702, in IPKernelApp.start(***failed resolving arguments***)
    701 if self.trio_loop:
--> 702     from ipykernel.trio_runner import TrioRunner
    704     tr = TrioRunner()

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/tornado/platform/asyncio.py:212, in BaseAsyncIOLoop.start(***failed resolving arguments***)
    211 except (RuntimeError, AssertionError):
--> 212     old_loop = None  # type: ignore
    213 try:

File ~/anaconda3/envs/mp124/lib/python3.10/asyncio/base_events.py:594, in BaseEventLoop.run_forever(***failed resolving arguments***)
    592 self._thread_id = threading.get_ident()
--> 594 old_agen_hooks = sys.get_asyncgen_hooks()
    595 sys.set_asyncgen_hooks(firstiter=self._asyncgen_firstiter_hook,
    596                        finalizer=self._asyncgen_finalizer_hook)

File ~/anaconda3/envs/mp124/lib/python3.10/asyncio/base_events.py:1860, in BaseEventLoop._run_once(***failed resolving arguments***)
   1858     timeout = min(max(0, when - self.time()), MAXIMUM_SELECT_TIMEOUT)
-> 1860 event_list = self._selector.select(timeout)
   1861 self._process_events(event_list)

File ~/anaconda3/envs/mp124/lib/python3.10/asyncio/events.py:80, in Handle._run(***failed resolving arguments***)
     79 try:
---> 80     self._context.run(self._callback, *self._args)
     81 except (SystemExit, KeyboardInterrupt):

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/ipykernel/kernelbase.py:510, in Kernel.dispatch_queue(***failed resolving arguments***)
    509 try:
--> 510     await self.process_one()
    511 except Exception:

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/ipykernel/kernelbase.py:496, in Kernel.process_one(***failed resolving arguments***)
    495 try:
--> 496     t, dispatch, args = self.msg_queue.get_nowait()
    497 except (asyncio.QueueEmpty, QueueEmpty):

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/ipykernel/kernelbase.py:383, in Kernel.dispatch_shell(***failed resolving arguments***)
    382     self.shell_stream.flush(zmq.POLLOUT)
--> 383     return
    385 # Print some info about this message and leave a '--->' marker, so it's
    386 # easier to trace visually the message chain when debugging.  Each
    387 # handler prints its message at the end.

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/ipykernel/kernelbase.py:701, in Kernel.execute_request(***failed resolving arguments***)
    699 stop_on_error = content.get("stop_on_error", True)
--> 701 metadata = self.init_metadata(parent)
    703 # Re-broadcast our input for the benefit of listening clients, and
    704 # start computing output

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/ipykernel/ipkernel.py:352, in IPythonKernel.do_execute(***failed resolving arguments***)
    350 if with_cell_id:
    351     coro = run_cell(
--> 352         code,
    353         store_history=store_history,
    354         silent=silent,
    355         transformed_cell=transformed_cell,
    356         preprocessing_exc_tuple=preprocessing_exc_tuple,
    357         cell_id=cell_id,
    358     )
    359 else:

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/ipykernel/zmqshell.py:528, in ZMQInteractiveShell.run_cell(***failed resolving arguments***)
    527 self._last_traceback = None
--> 528 return super().run_cell(*args, **kwargs)

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/IPython/core/interactiveshell.py:2882, in InteractiveShell.run_cell(***failed resolving arguments***)
   2880 try:
   2881     result = self._run_cell(
-> 2882         raw_cell, store_history, silent, shell_futures, cell_id
   2883     )
   2884 finally:

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/IPython/core/interactiveshell.py:2911, in InteractiveShell._run_cell(***failed resolving arguments***)
   2909 assert transformed_cell is not None
   2910 coro = self.run_cell_async(
-> 2911     raw_cell,
   2912     store_history=store_history,
   2913     silent=silent,
   2914     shell_futures=shell_futures,
   2915     transformed_cell=transformed_cell,
   2916     preprocessing_exc_tuple=preprocessing_exc_tuple,
   2917     cell_id=cell_id,
   2918 )
   2920 # run_cell_async is async, but may not actually need an eventloop.
   2921 # when this is the case, we want to run it using the pseudo_sync_runner
   2922 # so that code can invoke eventloops (for example via the %run , and
   2923 # `%paste` magic.

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/IPython/core/async_helpers.py:129, in _pseudo_sync_runner(***failed resolving arguments***)
    128 try:
--> 129     coro.send(None)
    130 except StopIteration as exc:

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3109, in InteractiveShell.run_cell_async(***failed resolving arguments***)
   3108     code_ast = compiler.ast_parse(cell, filename=cell_name)
-> 3109 except self.custom_exceptions as e:
   3110     etype, value, tb = sys.exc_info()

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3306, in InteractiveShell.run_ast_nodes(***failed resolving arguments***)
   3305     to_run_exec, to_run_interactive = nodelist[:-1], nodelist[-1:]
-> 3306 elif interactivity == 'all':
   3307     to_run_exec, to_run_interactive = [], nodelist

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3395, in InteractiveShell.run_code(***failed resolving arguments***)
   3394 try:
-> 3395     if async_:
   3396         await eval(code_obj, self.user_global_ns, self.user_ns)

Input In [1], in <cell line: 81>()
     79     return (jnp.abs(monitor_values[0,0])**2)
---> 81 grad_loss = grad(loss)(x)

Input In [1], in loss(***failed resolving arguments***)
     77 def loss(x):
---> 78     monitor_values = wrapped_meep([x])
     79     return (jnp.abs(monitor_values[0,0])**2)

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/meep/adjoint/wrapper.py:138, in MeepJaxWrapper.__call__(***failed resolving arguments***)
    122 """Performs a Meep simulation, taking a list of designs and returning mode overlaps.
    123 
    124 Args:
   (...)
    136   a shape of (num monitors, num frequencies).
    137 """
--> 138 return self._simulate_fn(designs)

JaxStackTraceBeforeTransformation: numpy.AxisError: axis 1 is out of bounds for array of dimension 1

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

AxisError                                 Traceback (most recent call last)
Input In [1], in <cell line: 81>()
     78     monitor_values = wrapped_meep([x])
     79     return (jnp.abs(monitor_values[0,0])**2)
---> 81 grad_loss = grad(loss)(x)

    [... skipping hidden 13 frame]

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/meep/adjoint/wrapper.py:227, in MeepJaxWrapper._initialize_callable.<locals>._simulate_rev(res, monitor_values_grad)
    225 design_variable_shapes = res
    226 self.adj_design_region_monitors = self._run_adjoint_simulation(monitor_values_grad)
--> 227 vjps = self._calculate_vjps(self.fwd_design_region_monitors, self.adj_design_region_monitors,
    228                             design_variable_shapes)
    229 return ([jnp.asarray(vjp) for vjp in vjps], )

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/meep/adjoint/wrapper.py:197, in MeepJaxWrapper._calculate_vjps(self, fwd_fields, adj_fields, design_variable_shapes, sum_freq_partials)
    189 def _calculate_vjps(
    190     self,
    191     fwd_fields,
   (...)
    194     sum_freq_partials=True,
    195 ):
    196     """Calculates the VJP for a given set of forward and adjoint fields."""
--> 197     return utils.calculate_vjps(
    198         self.simulation,
    199         self.design_regions,
    200         self.frequencies,
    201         fwd_fields,
    202         adj_fields,
    203         design_variable_shapes,
    204         sum_freq_partials=sum_freq_partials,
    205         finite_difference_step=self.finite_difference_step
    206     )

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/meep/adjoint/utils.py:93, in calculate_vjps(simulation, design_regions, frequencies, fwd_fields, adj_fields, design_variable_shapes, sum_freq_partials, finite_difference_step)
     83 vjps = [
     84     design_region.get_gradient(
     85         simulation,
   (...)
     90     ) for i, design_region in enumerate(design_regions)
     91 ]
     92 if sum_freq_partials:
---> 93     vjps = [
     94         onp.sum(vjp, axis=_GRADIENT_FREQ_AXIS).reshape(shape)
     95         for vjp, shape in zip(vjps, design_variable_shapes)
     96     ]
     97 else:
     98     vjps = [
     99         vjp.reshape(shape + (-1, ))
    100         for vjp, shape in zip(vjps, design_variable_shapes)
    101     ]

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/meep/adjoint/utils.py:94, in <listcomp>(.0)
     83 vjps = [
     84     design_region.get_gradient(
     85         simulation,
   (...)
     90     ) for i, design_region in enumerate(design_regions)
     91 ]
     92 if sum_freq_partials:
     93     vjps = [
---> 94         onp.sum(vjp, axis=_GRADIENT_FREQ_AXIS).reshape(shape)
     95         for vjp, shape in zip(vjps, design_variable_shapes)
     96     ]
     97 else:
     98     vjps = [
     99         vjp.reshape(shape + (-1, ))
    100         for vjp, shape in zip(vjps, design_variable_shapes)
    101     ]

File <__array_function__ internals>:180, in sum(*args, **kwargs)

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/numpy/core/fromnumeric.py:2296, in sum(a, axis, dtype, out, keepdims, initial, where)
   2293         return out
   2294     return res
-> 2296 return _wrapreduction(a, np.add, 'sum', axis, dtype, out, keepdims=keepdims,
   2297                       initial=initial, where=where)

File ~/anaconda3/envs/mp124/lib/python3.10/site-packages/numpy/core/fromnumeric.py:86, in _wrapreduction(obj, ufunc, method, axis, dtype, out, **kwargs)
     83         else:
     84             return reduction(axis=axis, out=out, **passkwargs)
---> 86 return ufunc.reduce(obj, axis, dtype, out, **passkwargs)

AxisError: axis 1 is out of bounds for array of dimension 1

@ianwilliamson
Contributor

Thanks. This looks like a bug in the get_gradient() method on the design region. It is not maintaining a singleton frequency axis in the returned ndarray when there is just one frequency.
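
To illustrate the diagnosis above, here is a minimal NumPy-only sketch of the failure mode. It does not call Meep. The frequency axis index of 1 is inferred from the error message, and `sum_over_frequencies` is a hypothetical helper written for this example, not part of meep.adjoint. It shows one way a dropped singleton frequency axis could be restored before summing, and is not necessarily how the library should be fixed.

```python
import numpy as np

# Assumption: the frequency axis is axis 1, inferred from the message
# "axis 1 is out of bounds for array of dimension 1" in the traceback.
GRADIENT_FREQ_AXIS = 1


def sum_over_frequencies(vjp, num_freqs):
    """Hypothetical helper: sum a gradient over its frequency axis,
    tolerating a singleton frequency axis that was dropped upstream."""
    vjp = np.asarray(vjp)
    if vjp.ndim == 1 and num_freqs == 1:
        # Restore the (num_design_params, 1) layout before reducing.
        vjp = vjp[:, np.newaxis]
    return np.sum(vjp, axis=GRADIENT_FREQ_AXIS)


num_params = 100  # Nx * Ny in the script above

# With two frequencies the gradient is 2-D and the reduction succeeds.
multi = np.ones((num_params, 2))
print(np.sum(multi, axis=GRADIENT_FREQ_AXIS).shape)

# With one frequency and no singleton axis the same reduction fails.
single = np.ones(num_params)
try:
    np.sum(single, axis=GRADIENT_FREQ_AXIS)
except ValueError as err:  # AxisError is a subclass of ValueError
    print(type(err).__name__, err)

# Restoring the singleton axis makes the reduction well defined again.
print(sum_over_frequencies(single, num_freqs=1).shape)
```

In this sketch the guard only reshapes when the array is 1-D and exactly one frequency was requested, so multi-frequency gradients pass through unchanged.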

smartalecH added the bug label Sep 29, 2022
@smartalecH
Collaborator

Hmm this may be an artifact from #1855.
