diff --git a/src/neospy/rust/propagation.rs b/src/neospy/rust/propagation.rs index abaff03..ed7c2f9 100644 --- a/src/neospy/rust/propagation.rs +++ b/src/neospy/rust/propagation.rs @@ -158,18 +158,27 @@ pub fn propagation_n_body_spk_py( /// Iterable /// A :class:`~neospy.State` at the new time. #[pyfunction] -#[pyo3(name = "propagate_n_body_vec", signature = (states, jd_final, non_gravs=None))] +#[pyo3(name = "propagate_n_body_vec", signature = (states, jd_final, planet_states=None, non_gravs=None))] pub fn propagation_n_body_py( states: Vec, jd_final: PyTime, + planet_states: Option>, non_gravs: Option>, -) -> PyResult> { +) -> PyResult<(Vec, Vec)> { let states: Vec = states.into_iter().map(|x| x.0).collect(); + let planet_states: Option> = + planet_states.map(|s| s.into_iter().map(|x| x.0).collect()); let non_gravs: Option> = non_gravs.map(|x| x.into_iter().map(|x| x.0).collect()); let jd = jd_final.jd(); - let res = propagation::propagate_n_body_vec(states, jd, non_gravs) - .map(|x| x.into_iter().map(PyState::from).collect::>())?; - Ok(res) + let res = propagation::propagate_n_body_vec(states, jd, planet_states, non_gravs).map( + |(planets, states)| { + ( + planets.into_iter().map(PyState::from).collect::>(), + states.into_iter().map(PyState::from).collect::>(), + ) + }, + ); + Ok(res?) } diff --git a/src/neospy_core/src/propagation/mod.rs b/src/neospy_core/src/propagation/mod.rs index d84a063..618af6d 100644 --- a/src/neospy_core/src/propagation/mod.rs +++ b/src/neospy_core/src/propagation/mod.rs @@ -142,10 +142,9 @@ pub fn propagation_central(state: &State, jd_final: f64) -> NeosResult<[[f64; 3] pub fn propagate_n_body_vec( states: Vec, jd_final: f64, + planet_states: Option>, non_gravs: Option>, -) -> NeosResult> { - let spk = get_spk_singleton().try_read().unwrap(); - +) -> NeosResult<(Vec, Vec)> { if states.is_empty() { Err(Error::ValueError( "State vector is empty, propagation cannot continue".into(), @@ -161,34 +160,59 @@ pub fn propagate_n_body_vec( } let jd_init = states.first().unwrap().jd; - let mass_list = SIMPLE_PLANETS; let mut pos: Vec = Vec::new(); let mut vel: Vec = Vec::new(); let mut desigs: Vec = Vec::new(); - for obj in mass_list.iter() { - let planet = spk.try_get_state(obj.naif_id, jd_init, 0, Frame::Ecliptic)?; - pos.append(&mut planet.pos.into()); - vel.append(&mut planet.vel.into()); - desigs.push(planet.desig.to_owned()); + let planet_states = planet_states.unwrap_or_else(|| { + let spk = get_spk_singleton().try_read().unwrap(); + let mut planet_states = Vec::with_capacity(SIMPLE_PLANETS.len()); + for obj in SIMPLE_PLANETS.iter() { + let planet = spk + .try_get_state(obj.naif_id, jd_init, 10, Frame::Ecliptic) + .expect("Failed to find state for the provided initial jd"); + planet_states.push(planet); + } + planet_states + }); + + if planet_states.len() != SIMPLE_PLANETS.len() { + Err(Error::ValueError( + "Input planet states must contain the correct number of states.".into(), + ))?; + } + if planet_states.first().unwrap().jd != jd_init { + Err(Error::ValueError( + "Planet states JD must match JD of input state.".into(), + ))?; + } + for planet_state in planet_states.into_iter() { + pos.append(&mut planet_state.pos.into()); + vel.append(&mut planet_state.vel.into()); + desigs.push(planet_state.desig); } for mut state in states.into_iter() { - spk.try_change_center(&mut state, 0)?; state.try_change_frame_mut(Frame::Ecliptic)?; if jd_init != state.jd { Err(Error::ValueError( "All input states must have the same JD".into(), ))?; } + if state.center_id != 10 { + Err(Error::ValueError( + "Center of all states must be 10 (the Sun).".into(), + ))? + } pos.append(&mut state.pos.into()); vel.append(&mut state.vel.into()); - desigs.push(state.desig.to_owned()); + desigs.push(state.desig); } + let meta = AccelVecMeta { non_gravs, - massive_obj: mass_list, + massive_obj: SIMPLE_PLANETS, }; let (pos, vel, _) = { @@ -201,21 +225,15 @@ pub fn propagate_n_body_vec( meta, )? }; - let sun_pos = pos.rows(0, 3); - let sun_vel = vel.rows(0, 3); + let sun_pos = pos.fixed_rows::<3>(0); + let sun_vel = vel.fixed_rows::<3>(0); let mut final_states: Vec = Vec::new(); for (idx, desig) in desigs.into_iter().enumerate() { - let pos_chunk = pos.rows(idx * 3, 3) - sun_pos; - let vel_chunk = vel.rows(idx * 3, 3) - sun_vel; - let state = State::new( - desig, - jd_final, - nalgebra::convert(pos_chunk), - nalgebra::convert(vel_chunk), - Frame::Ecliptic, - 10, - ); + let pos = pos.fixed_rows::<3>(idx * 3) - sun_pos; + let vel = vel.fixed_rows::<3>(idx * 3) - sun_vel; + let state = State::new(desig, jd_final, pos, vel, Frame::Ecliptic, 10); final_states.push(state); } - Ok(final_states) + let final_planets = final_states.split_off(SIMPLE_PLANETS.len()); + Ok((final_planets, final_states)) }