Skip to content

Commit

Permalink
Clean-up interface of hsa ode_factory
Browse files Browse the repository at this point in the history
- Allow `u` to be also a normal argument (instead of only keyword arguments)
- Pass `phi` as `args` keyword when calling `diffeqsolve`
  • Loading branch information
mstoelzle committed Oct 18, 2023
1 parent 0a1ea65 commit 652ac79
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
3 changes: 2 additions & 1 deletion examples/simulate_planar_hsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def chi2u(chi: Array) -> Array:
x0 = x0.at[: q0.shape[0]].set(q0) # set initial configuration

ode_fn = planar_hsa.ode_factory(dynamical_matrices_fn, params)
ode_term = ODETerm(partial(ode_fn, u=phi))
ode_term = ODETerm(ode_fn)

sol = diffeqsolve(
ode_term,
Expand All @@ -295,6 +295,7 @@ def chi2u(chi: Array) -> Array:
t1=ts[-1],
dt0=dt,
y0=x0,
args=phi,
max_steps=None,
saveat=SaveAt(ts=video_ts),
)
Expand Down
3 changes: 1 addition & 2 deletions src/jsrm/systems/planar_hsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,13 +634,12 @@ def ode_factory(
num_rods = params["rout"].shape[0] * params["rout"].shape[1]

@jit
def ode_fn(t: float, x: Array, *args, u: Array) -> Array:
def ode_fn(t: float, x: Array, u: Array) -> Array:
"""
ODE of the dynamical Lagrangian system.
Args:
t: time
x: state vector of shape (2 * n_q, )
args: additional arguments
u: input to the system.
- if consider_underactuation_model is True, then this is an array of shape (n_phi) with
motor positions / twist angles of the proximal end of the rods
Expand Down

0 comments on commit 652ac79

Please sign in to comment.