JIT Error encountered when optimizing GammaC #1288

Open
dpanici opened this issue Oct 2, 2024 · 14 comments · May be fixed by #1229
Labels: bug (Something isn't working), optimization (Adding or improving optimization methods), P3 (Highest Priority, someone is/should be actively working on this)

dpanici (Collaborator) commented Oct 2, 2024

The error seems to occur when optimizing the GammaC objective on the gh/Gamma_c branch. It happens on the second optimization step and seems related to the JIT cache. The error also only occurs when attempting an optimization at a resolution that was previously optimized at; changing the eq resolution between steps seems to avoid the issue, so I assume it is related to the caching.

MWE:

from desc import set_device

set_device("gpu")

import jax
import numpy as np

from desc.backend import jnp
from desc.examples import get
from desc.grid import ConcentricGrid, LinearGrid
from desc.objectives import (
    Elongation,
    FixBoundaryR,
    FixBoundaryZ,
    FixCurrent,
    FixPressure,
    FixPsi,
    ForceBalance,
    GammaC,
    GenericObjective,
    ObjectiveFunction,
)
from desc.optimize import Optimizer


def run_opt_step(k, eq):
    """Run a step of the optimization example."""
    # this step will only optimize boundary modes with |m|,|n| <= k
    # we create an ObjectiveFunction, in this case made up of multiple objectives
    # which will be combined in a least squares sense

    shape_grid = LinearGrid(
        M=int(eq.M), N=int(eq.N), rho=np.array([1.0]), NFP=eq.NFP, sym=True, axis=False
    )

    ntransits = 8

    zeta_field_line = np.linspace(0, 2 * np.pi * ntransits, 64 * ntransits)
    alpha = jnp.array([0.0])
    rho = jnp.linspace(0.85, 1.0, 2)
    # rho = np.linspace(0.85, 1.0, 2)
    flux_surface_grid = LinearGrid(
        rho=rho, M=eq.M_grid, N=eq.N_grid, sym=eq.sym, NFP=eq.NFP
    )

    objective = ObjectiveFunction(
        (
            GammaC(
                eq=eq,
                rho=rho,
                alpha=alpha,
                deriv_mode="fwd",
                batch=False,
                weight=1e3,
                Nemov=False,
            ),
            Elongation(eq=eq, grid=shape_grid, target=1),  # bounds=(0.5, 2.0), weight=1e3
            GenericObjective(
                f="curvature_k2_rho",
                thing=eq,
                grid=shape_grid,
                bounds=(-75, 15),
                weight=2e3,
            ),
        ),
    )
    R_modes = np.vstack(
        (
            [0, 0, 0],
            eq.surface.R_basis.modes[
                np.max(np.abs(eq.surface.R_basis.modes), 1) > k, :
            ],
        )
    )
    Z_modes = eq.surface.Z_basis.modes[
        np.max(np.abs(eq.surface.Z_basis.modes), 1) > k, :
    ]
    constraints = (
        ForceBalance(
            eq,
            grid=ConcentricGrid(
                L=round(2 * eq.L),
                M=round(1.5 * eq.M),
                N=round(1.5 * eq.N),
                NFP=eq.NFP,
                sym=eq.sym,
            ),
        ),
        FixBoundaryR(eq=eq, modes=R_modes),
        FixBoundaryZ(eq=eq, modes=Z_modes),
        FixPressure(eq=eq),
        FixCurrent(eq=eq),
        FixPsi(eq=eq),
    )
    # this is the default optimizer, which re-solves the equilibrium at each step
    optimizer = Optimizer("proximal-lsq-exact")
    eq_new, result = optimizer.optimize(
        things=eq,
        objective=objective,
        constraints=constraints,
        maxiter=3,  # we don't need to solve to optimality at each multigrid step
        verbose=3,
        copy=True,  # don't modify original, return a new optimized copy
        options={
            # Sometimes the default initial trust radius is too big, allowing the
            # optimizer to take too large a step in a bad direction. If this happens,
            # we can manually specify a smaller starting radius. Each optimizer has a
            # number of different options that can be used to tune the performance.
            # See the documentation for more info.
            "initial_trust_ratio": 1e-2,
            "maxiter": 125,
            "ftol": 1e-3,
            "xtol": 1e-8,
        },
    )
    eq_new = eq_new[0]

    return eq_new

eq = get("ESTELL")
for k in np.arange(1, eq.M + 1, 1):
    if not eq.is_nested():
        # abort if the equilibrium surfaces are no longer nested
        # (the assert raises, so the break below is never reached)
        print("NOT NESTED")
        assert eq.is_nested()
        break
    jax.clear_caches()
    eq = run_opt_step(k, eq)

Error:

ValueError                                Traceback (most recent call last)
Cell In[1], line 137
    135     break
    136 jax.clear_caches()
--> 137 eq = run_opt_step(k, eq)

Cell In[1], line 107, in run_opt_step(k, eq)
    103 optimizer = Optimizer("proximal-lsq-exact")
    105 print("spot 1:", type(eq))
--> 107 eq_new, result = optimizer.optimize(
    108     things = eq,
    109     objective=objective,
    110     constraints=constraints,
    111     maxiter=3,  # we don't need to solve to optimality at each multigrid step
    112     verbose=3,
    113     copy=True,  # don't modify original, return a new optimized copy
    114     options={
    115         # Sometimes the default initial trust radius is too big, allowing the
    116         # optimizer to take too large a step in a bad direction. If this happens,
    117         # we can manually specify a smaller starting radius. Each optimizer has a
    118         # number of different options that can be used to tune the performance.
    119         # See the documentation for more info.
    120         "initial_trust_ratio": 1e-2,
    121         "maxiter": 125,
    122         "ftol": 1e-3,
    123         "xtol": 1e-8,
    124     },
    125 )
    126 eq_new = eq_new[0]
    128 return eq_new

File ~/DESC/desc/optimize/optimizer.py:311, in Optimizer.optimize(self, things, objective, constraints, ftol, xtol, gtol, ctol, x_scale, verbose, maxiter, options, copy)
    307     print("Using method: " + str(self.method))
    309 timer.start("Solution time")
--> 311 result = optimizers[method]["fun"](
    312     objective,
    313     nonlinear_constraint,
    314     x0,
    315     method,
    316     x_scale,
    317     verbose,
    318     stoptol,
    319     options,
    320 )
    322 if isinstance(objective, LinearConstraintProjection):
    323     # remove wrapper to get at underlying objective
    324     result["allx"] = [objective.recover(x) for x in result["allx"]]

File ~/DESC/desc/optimize/_desc_wrappers.py:270, in _optimize_desc_least_squares(objective, constraint, x0, method, x_scale, verbose, stoptol, options)
    267     options.setdefault("initial_trust_ratio", 0.1)
    268 options["max_nfev"] = stoptol["max_nfev"]
--> 270 result = lsqtr(
    271     objective.compute_scaled_error,
    272     x0=x0,
    273     jac=objective.jac_scaled_error,
    274     args=(objective.constants,),
    275     x_scale=x_scale,
    276     ftol=stoptol["ftol"],
    277     xtol=stoptol["xtol"],
    278     gtol=stoptol["gtol"],
    279     maxiter=stoptol["maxiter"],
    280     verbose=verbose,
    281     callback=None,
    282     options=options,
    283 )
    284 return result

File ~/DESC/desc/optimize/least_squares.py:176, in lsqtr(fun, x0, jac, bounds, args, x_scale, ftol, xtol, gtol, verbose, maxiter, callback, options)
    173 assert in_bounds(x, lb, ub), "x0 is infeasible"
    174 x = make_strictly_feasible(x, lb, ub)
--> 176 f = fun(x, *args)
    177 nfev += 1
    178 cost = 0.5 * jnp.dot(f, f)

File ~/DESC/desc/optimize/_constraint_wrappers.py:224, in LinearConstraintProjection.compute_scaled_error(self, x_reduced, constants)
    208 """Compute the objective function and apply weighting / bounds.
    209 
    210 Parameters
   (...)
    221 
    222 """
    223 x = self.recover(x_reduced)
--> 224 f = self._objective.compute_scaled_error(x, constants)
    225 return f

File ~/DESC/desc/optimize/_constraint_wrappers.py:843, in ProximalProjection.compute_scaled_error(self, x, constants)
    841 constants = setdefault(constants, self.constants)
    842 xopt, _ = self._update_equilibrium(x, store=False)
--> 843 return self._objective.compute_scaled_error(xopt, constants[0])

    [... skipping hidden 6 frame]

File ~/.conda/envs/desc-env-latest/lib/python3.11/site-packages/jax/_src/pjit.py:1339, in seen_attrs_get(fun, in_type)
   1337 cache = _seen_attrs.setdefault(fun.f, defaultdict(list))
   1338 assert fun.in_type is None or fun.in_type == in_type
-> 1339 return cache[(fun.transforms, fun.params, in_type)]

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
dpanici added the bug (Something isn't working) label on Oct 2, 2024
unalmis (Collaborator) commented Oct 2, 2024

Here are the steps that should be taken to debug:

  1. Has this error ever occurred on the Gamma_c branch, where this quantity is developed? Which JAX version?
  2. If yes to 1, check which recent changes to master (perhaps the Jacobian changes) caused this.

Un-jitting the compute function tends to help with debugging; see the sketch below.
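
A minimal sketch of that debugging approach, using JAX's own jax.disable_jit context manager (the DESC-level use_jit=False flag mentioned later in this thread is an alternative):

import jax

# Disable all JIT compilation so the failure happens in plain Python
# frames with a readable traceback, instead of inside a cached,
# staged-out computation.
with jax.disable_jit():
    eq = run_opt_step(k, eq)  # the failing step from the MWE above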

f0uriest (Member) commented Oct 3, 2024

It seems to be unique to the bounce integral objectives: if I comment out GammaC it works fine, and if I change to EffectiveRipple it still happens.

Other things:

  • The bounce integral objectives take a long time to compile, even at very low resolution: a few minutes, compared to a few seconds for the other objectives.
  • When either bounce integral objective is included, the optimizer seems to stall out (it rejects a lot of steps and exits).

These probably aren't related to the error above, but they might be another source of concern.

f0uriest (Member) commented Oct 3, 2024

With use_jit=False (and with the jit in the constraint wrappers commented out) I'm unable to reproduce the error; a sketch of that configuration follows.
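
For reference, a minimal sketch of that configuration applied to the MWE above, assuming ObjectiveFunction accepts a use_jit keyword as this comment implies (check the signature on your branch):

objective = ObjectiveFunction(
    (
        GammaC(eq=eq, rho=rho, alpha=alpha, deriv_mode="fwd",
               batch=False, weight=1e3, Nemov=False),
    ),
    use_jit=False,  # skip DESC's jit wrapping of compute/jac while debugging
)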

unalmis (Collaborator) commented Oct 3, 2024

OK, I ran optimizations leading up to ISHW, so commit 5cd7ebd should not have this issue. State of the branch at that commit: https://github.com/PlasmaControl/DESC/tree/5cd7ebde563258f754a0401d9da6aa143bc3376f

unalmis (Collaborator) commented Oct 3, 2024

> With use_jit=False (and with the jit in the constraint wrappers commented out) I'm unable to reproduce the error.

There is also a jit call wrapping the compute function in _compute. When you could no longer reproduce the error, was that jit call still active?

> the bounce integral objectives take a long time to compile, even at very low resolution, like a few minutes, compared to a few seconds for the other objectives

Aren't these compiled only once? The BallooningStability objective requires less resolution than bounce integrals along a field line, but it still does a coordinate mapping inside the objective and builds transforms on the resulting grid. How do the compilation time and optimizer stalling compare when "low resolution" means typical resolution for BallooningStability?

> the optimizer seems to stall out (it rejects a lot of steps and exits)

Can memory usage affect this? Is it forward or reverse mode? I ran forward-mode optimizations before ISHW and did not see the optimizer exit.

dpanici (Collaborator, Author) commented Oct 3, 2024

5cd7ebd...Gamma_c is the diff between the commit Kaya mentioned and the current Gamma_c branch.

dpanici (Collaborator, Author) commented Oct 3, 2024

I won't have time to debug tonight or tomorrow, but I will look more this weekend. Thanks for starting to look into this so quickly. On Gamma_c I see the same bug for both the GammaC and EffectiveRipple objectives.

unalmis (Collaborator) commented Oct 27, 2024

I think this is some JAX issue, and the caching behavior suggests it is problem dependent. In any case, I suggest trying #1290; if the issue disappears there, we can mark this resolved.

The objectives there use a transforms grid that is independent of the optimization step, so that might resolve the caching issue you came across.

dpanici (Collaborator, Author) commented Oct 28, 2024

The same error occurs in #1290. Once I find the specific cause I can commit a fix.

dpanici added the P3 (Highest Priority) label on Nov 11, 2024
dpanici added a commit referencing this issue on Nov 23, 2024
unalmis (Collaborator) commented Nov 23, 2024

I accidentally ran the tutorial's optimization cell (block 6) another time after the optimization completed successfully, and I got the same error. No JIT caching is done there, and the block is self-contained, so the second run is a completely new optimization, not a second step; I am therefore uncertain whether it is related to this issue.

The error message suggests JAX is getting an array with a different dimension than it expects. Flattening all inputs, in particular those in constants, from tuples and higher-dimensional arrays to 1D arrays before they reach the objective function avoided the issue for some reason.

The omnigenity objective also passes a 2D array in constants, so it might have the same issue; it could be worth looking into how 2D arrays are interpreted in the compute-scaled-error functions. A sketch of the flattening workaround is below.
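
A minimal sketch of the flattening workaround, assuming constants is an arbitrary pytree of arrays; the tree utilities are standard JAX APIs, while flatten_constants itself is a hypothetical helper:

import jax
import jax.numpy as jnp

def flatten_constants(constants):
    # Hypothetical helper: ravel every array leaf of the constants
    # pytree to 1D before it reaches the jitted objective, leaving
    # non-array leaves (floats, strings, etc.) untouched.
    return jax.tree_util.tree_map(
        lambda leaf: jnp.ravel(leaf) if hasattr(leaf, "ndim") else leaf,
        constants,
    )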

unalmis linked a pull request (#1229) on Nov 24, 2024 that will close this issue
dpanici (Collaborator, Author) commented Nov 24, 2024

Yep, the fix in #1229 is actually pretty simple. Greta basically changed the way rho is passed: instead of an array stored in constants, it now comes in through the nodes attribute of a LinearGrid. I asked her to make the same change in the ripple branch and in the branch where you implemented the objectives using the 2D interpolated version of the bounce functions, once she confirms that this specific change is the one that fixes the bug.
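
A minimal sketch of the kind of change described, with illustrative values; LinearGrid and its nodes attribute are existing DESC APIs, but the exact code in #1229 may differ:

import numpy as np
from desc.grid import LinearGrid

rho = np.linspace(0.85, 1.0, 2)

# Before: rho was stored as a bare array inside the objective's
# constants, so an array ended up in data that JAX inspects when
# checking its jit cache.
# constants = {"rho": rho, ...}

# After: rho is encoded in a LinearGrid, and the objective recovers the
# unique rho values from the grid's nodes (columns are rho, theta, zeta).
grid = LinearGrid(rho=rho, M=8, N=8, NFP=1, sym=True)
rho_from_grid = np.unique(grid.nodes[:, 0])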

dpanici (Collaborator, Author) commented Nov 24, 2024

Hm, what is the error message, actually? That seems different from what we get; ours is a NumPy logical (truth-value) error, not quite related to the shape mismatch yours sounds like. Could this be a separate issue?

dpanici (Collaborator, Author) commented Nov 24, 2024

> I accidentally ran the tutorial's optimization cell (block 6) another time after the optimization completed successfully, and I got the same error. No JIT caching is done there, and the block is self-contained, so the second run is a completely new optimization, not a second step; I am therefore uncertain whether it is related to this issue. The error message suggests JAX is getting an array with a different dimension than it expects. Flattening all inputs, in particular those in constants, from tuples and higher-dimensional arrays to 1D arrays before they reach the objective function avoided the issue for some reason.
>
> The omnigenity objective also passes a 2D array in constants, so it might have the same issue; it could be worth looking into how 2D arrays are interpreted in the compute-scaled-error functions.

I think that even if the block is self-contained, running it again will still find a cached jitted version of the ObjectiveFunction.compute method and attempt to reuse it, and it is in that cache-compatibility check that we get the error. So it is not that it is a second step; the error occurs any time an equilibrium at the same resolution (with the same grids, etc.) is used to build and then compile an ObjectiveFunction. The test I have here shows the kind of thing that fails before the issue is fixed. (The test does not re-instantiate the objective, but I know that even re-instantiating it would trigger the bug: previously, if you ran pytest on the two tests in that file when they used the same-resolution eq, the second would fail with this bug, because JAX attempted to reuse the first test's cached jitted objective compute, and the compatibility check threw the NumPy truth-value error above.) A minimal reproduction of that failure mode follows.
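
A minimal sketch of the failure mode being described, using plain NumPy and no DESC at all: if an array ends up inside a tuple that gets compared against a cached key, the tuple comparison raises exactly the error above.

import numpy as np

# Two "cache keys" built at the same resolution; the rho array is a new
# object each time the objective is rebuilt.
old_key = ("compute", np.array([0.85, 1.0]))
new_key = ("compute", np.array([0.85, 1.0]))

try:
    # Tuple equality compares elements pairwise; the array comparison
    # yields a boolean array, whose truth value is ambiguous.
    old_key == new_key
except ValueError as err:
    print(err)  # The truth value of an array with more than one element is ambiguous...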

unalmis added the optimization (Adding or improving optimization methods) label on Nov 24, 2024
dpanici (Collaborator, Author) commented Nov 25, 2024

Check whether @rahulgaur104's ballooning objective is also affected by this.
