Skip to content

Commit

Permalink
move torsten variadic funcs definitions
Browse files Browse the repository at this point in the history
  • Loading branch information
Yi Zhang committed Dec 24, 2023
1 parent dbc8b10 commit 37ad076
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 10 deletions.
14 changes: 12 additions & 2 deletions src/middle/Stan_math_signatures.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2684,6 +2684,16 @@ let variadic_ode_nonadjoint_fns =
[ "ode_bdf_tol"; "ode_rk45_tol"; "ode_adams_tol"; "ode_bdf"; "ode_rk45"
; "ode_adams"; "ode_ckrk"; "ode_ckrk_tol" ]

(* torsten ode *)
let pmx_variadic_ode_fns =
String.Set.of_list
[ "pmx_ode_bdf_ctrl"; "pmx_ode_rk45_ctrl"; "pmx_ode_adams_ctrl"
; "pmx_ode_bdf"; "pmx_ode_rk45"; "pmx_ode_adams"; "pmx_ode_ckrk"
; "pmx_ode_ckrk_ctrl" ]

let pmx_ode_control_suffix = "_ctrl"
(* end of torsten ode *)

let ode_tolerances_suffix = "_tol"
let is_reduce_sum_fn f = Set.mem reduce_sum_functions f
let variadic_dae_fun_return_type = UnsizedType.UVector
Expand Down Expand Up @@ -2720,12 +2730,12 @@ let () =
let add_ode name =
add_variadic_fn name ~return_type:variadic_ode_return_type
~control_args:
( if String.is_suffix name ~suffix:Torsten.pmx_ode_control_suffix then
( if String.is_suffix name ~suffix:pmx_ode_control_suffix then
variadic_ode_mandatory_arg_types @ variadic_ode_tol_arg_types
else variadic_ode_mandatory_arg_types )
~required_fn_rt:variadic_ode_fun_return_type
~required_fn_args:variadic_ode_mandatory_fun_args () in
Set.iter ~f:add_ode Torsten.pmx_variadic_ode_fns ;
Set.iter ~f:add_ode pmx_variadic_ode_fns ;
(* Adjoint ODE function *)
add_variadic_fn variadic_ode_adjoint_fn ~return_type:variadic_ode_return_type
~control_args:
Expand Down
8 changes: 0 additions & 8 deletions src/middle/Torsten.ml
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,6 @@ let pmx_coupled_ode_func =
, FnPlain
, Mem_pattern.AoS ) ) ]

let pmx_variadic_ode_fns =
String.Set.of_list
[ "pmx_ode_bdf_ctrl"; "pmx_ode_rk45_ctrl"; "pmx_ode_adams_ctrl"
; "pmx_ode_bdf"; "pmx_ode_rk45"; "pmx_ode_adams"; "pmx_ode_ckrk"
; "pmx_ode_ckrk_ctrl" ]

let pmx_ode_control_suffix = "_ctrl"

let pmx_integrate_ode_arg =
[ (UnsizedType.AutoDiffable, UnsizedType.UArray UReal) (* y0 *)
; (UnsizedType.AutoDiffable, UReal) (* t0 *)
Expand Down

0 comments on commit 37ad076

Please sign in to comment.