Skip to content

Commit

Permalink
Move MOID computation to rust (#99)
Browse files Browse the repository at this point in the history
* move MOID computation to rust
  • Loading branch information
dahlend authored Aug 21, 2024
1 parent fec6f94 commit 651dbee
Show file tree
Hide file tree
Showing 9 changed files with 144 additions and 62 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added an Omni-Directional Field of View.
- Added a "getting started" example.

### Changed

- Optimized the `moid` computation, improving performance by over 30x.

### Fixed

- Field of View checks for states was optimized for multi-core processing on millions
Expand Down
20 changes: 16 additions & 4 deletions src/neospy/irsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,9 @@ def annotate_plot(
c="red",
text_color="White",
style="+",
text_dx=0,
text_dy=0,
text_fs=None,
):
"""
Add an annotation for a point in a FITS plot, this requires a world coordinate
Expand All @@ -315,7 +318,13 @@ def annotate_plot(
text_color :
If text is provided, this defines the text color.
style :
Style of marker, this can be either "o" or "+".
Style of marker, this can be either "o", "+", or "L".
text_dx :
Offset of the text x location in pixels.
text_dy :
Offset of the text y location in pixels.
text_fs :
Text font size.
"""
ra, dec = _ra_dec(ra, dec)
x, y = wcs.world_to_pixel(SkyCoord(ra, dec, unit="deg"))
Expand All @@ -325,10 +334,13 @@ def annotate_plot(
plt.plot([x + px_gap, x + total], [y, y], c=c, lw=lw)
plt.plot([x, x], [y - px_gap, y - total], c=c, lw=lw)
plt.plot([x, x], [y + px_gap, y + total], c=c, lw=lw)
if style == "L":
plt.plot([x + px_gap, x + total], [y, y], c=c, lw=lw)
plt.plot([x, x], [y + px_gap, y + total], c=c, lw=lw)
elif style == "o":
plt.scatter(x, y, fc="None", ec=c, s=5 * px_gap)
plt.scatter(x, y, fc="None", ec=c, s=5 * px_gap, lw=lw)
else:
raise ValueError("Style is not recognized, must be one of: ['o', '+']")
raise ValueError("Style is not recognized, must be one of: o, +, L")

if text:
plt.text(x, y + px_gap / 2, " " + text, c=text_color)
plt.text(x + text_dx, y + text_dy, text, c=text_color, fontsize=text_fs)
2 changes: 1 addition & 1 deletion src/neospy/neos.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@
[400.0, 1.05708, 0.95284],
[500.0, 1.01897, 0.96149],
]
"""Expected color correction required for black body sources"""
"""Expected color correction required for black body sources at 300k"""
54 changes: 1 addition & 53 deletions src/neospy/propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,13 @@
"""

from __future__ import annotations
from typing import Optional
from scipy import optimize
import numpy as np

from .vector import State
from . import spice
from ._core import (
NonGravModel,
propagate_n_body,
propagate_n_body_long,
propagate_two_body,
moid,
)


Expand All @@ -25,51 +21,3 @@
"NonGravModel",
"moid",
]


def _moid_single(obj0: State, other: State):
"""
Given the state of 2 objects, compute the MOID between them. This is used by the
moid function below and is not intended to be used directly.
"""
obj0_elem = obj0.elements
obj1_elem = other.elements
self_center = obj0_elem.peri_time
self_period = obj0_elem.orbital_period
other_center = obj1_elem.peri_time
other_period = obj1_elem.orbital_period

def _err(x):
jd0, jd1 = x
jd0 = jd0 * self_period / 4 + self_center
jd1 = jd1 * other_period / 4 + other_center
pos0 = propagate_two_body([obj0], jd0)[0].pos
pos1 = propagate_two_body([other], jd1)[0].pos
return np.linalg.norm(pos0 - pos1)

soln = []
soln.append(optimize.minimize(_err, [1, 1]).fun)
soln.append(optimize.minimize(_err, [-1, -1]).fun)
soln.append(optimize.minimize(_err, [-1, 1]).fun)
soln.append(optimize.minimize(_err, [1, -1]).fun)
return min(soln)


def moid(state: State, other: Optional[State] = None):
"""
Compute the MOID between two objects assuming 2 body mechanics.
If other is not provided, it is assumed to be Earth.
Parameters
----------
state:
The state describing an object.
other:
The state of the object to calculate the MOID for, if this is not provided,
then Earth is fetched from :mod:`~neospy.spice` and is used in the
calculation.
"""
if other is None:
other = spice.get_state("Earth", state.jd)
return _moid_single(state, other)
1 change: 1 addition & 0 deletions src/neospy/rust/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ fn _core(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {

m.add_function(wrap_pyfunction!(propagation::propagation_n_body_spk_py, m)?)?;
m.add_function(wrap_pyfunction!(propagation::propagation_n_body_py, m)?)?;
m.add_function(wrap_pyfunction!(propagation::moid_py, m)?)?;

m.add_function(wrap_pyfunction!(fovs::fov_checks_py, m)?)?;
m.add_function(wrap_pyfunction!(fovs::fov_spk_checks_py, m)?)?;
Expand Down
30 changes: 28 additions & 2 deletions src/neospy/rust/propagation.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use itertools::Itertools;
use neospy_core::{
errors::Error,
propagation::{self, NonGravModel},
spice::{self},
propagation::{self, moid, NonGravModel},
spice::{self, get_spk_singleton},
state::State,
time::{scales::TDB, Time},
};
Expand All @@ -12,6 +12,32 @@ use rayon::prelude::*;
use crate::state::PyState;
use crate::{nongrav::PyNonGravModel, time::PyTime};

/// Compute the MOID between the input state and an optional second state.
/// If the second state is not provided, default to Earth.
///
/// Returns the MOID in units of au.
///
/// Parameters
/// ----------
/// state_a:
/// State of the first object.
/// state_b:
/// Optional state of the second object, defaults to Earth.
#[pyfunction]
#[pyo3(name = "moid", signature = (state_a, state_b=None))]
pub fn moid_py(state_a: PyState, state_b: Option<PyState>) -> PyResult<f64> {
let state_b =
state_b
.map(|x| x.0)
.unwrap_or(get_spk_singleton().read().unwrap().try_get_state(
399,
state_a.0.jd,
10,
state_a.0.frame,
)?);
Ok(moid(state_a.0, state_b)?)
}

/// Propagate the provided :class:`~neospy.State` using N body mechanics to the
/// specified times, no approximations are made, this can be very CPU intensive.
///
Expand Down
2 changes: 2 additions & 0 deletions src/neospy_core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ edition = "2021"
name = "neospy_core"

[dependencies]
argmin = "*"
argmin-math = "*"
itertools = "^0.13.0"
kdtree = "^0.7.0"
lazy_static = "^1.5.0"
Expand Down
85 changes: 85 additions & 0 deletions src/neospy_core/src/propagation/kepler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@
use crate::errors::Error;
use crate::fitting::newton_raphson;
use crate::prelude::CometElements;
use crate::state::State;
use crate::{constants::*, prelude::NeosResult};
use argmin::solver::neldermead::NelderMead;
use core::f64;
use nalgebra::{ComplexField, Vector3};
use std::f64::consts::TAU;

Expand Down Expand Up @@ -313,6 +316,88 @@ pub fn propagate_two_body(state: &State, jd_final: f64) -> NeosResult<State> {
))
}

use argmin::core::{CostFunction, Error as ArgminErr, Executor};
struct MoidCost {
state_a: State,

state_b: State,
}

impl CostFunction for MoidCost {
type Param = Vec<f64>;
type Output = f64;

fn cost(&self, param: &Self::Param) -> Result<Self::Output, ArgminErr> {
let dt_a = param.first().unwrap();
let dt_b = param.last().unwrap();

let s0 = propagate_two_body(&self.state_a, self.state_a.jd + dt_a)?;
let s1 = propagate_two_body(&self.state_b, self.state_b.jd + dt_b)?;

Ok((Vector3::from(s0.pos) - Vector3::from(s1.pos)).norm())
}
}

/// Compute the MOID between two states in au.
/// MOID = Minimum Orbital Intersection Distance
pub fn moid(mut state_a: State, mut state_b: State) -> NeosResult<f64> {
state_a.try_change_frame_mut(crate::frames::Frame::Ecliptic)?;
state_b.try_change_frame_mut(crate::frames::Frame::Ecliptic)?;

let elements_a = CometElements::from_state(&state_a);
state_a = propagate_two_body(&state_a, elements_a.peri_time)?;
let elements_b = CometElements::from_state(&state_b);
state_b = propagate_two_body(&state_b, elements_b.peri_time)?;

const N_STEPS: i32 = 50;

let state_a_step_size = match elements_a.orbital_period() {
p if p.is_finite() => p / N_STEPS as f64,
_ => 300.0 / N_STEPS as f64,
};
let state_b_step_size = match elements_b.orbital_period() {
p if p.is_finite() => p / N_STEPS as f64,
_ => 300.0 / N_STEPS as f64,
};

let mut states_b: Vec<State> = Vec::with_capacity(N_STEPS as usize);
let mut states_a: Vec<State> = Vec::with_capacity(N_STEPS as usize);

for idx in (-N_STEPS)..N_STEPS {
states_a.push(propagate_two_body(
&state_a,
state_a.jd + idx as f64 * state_a_step_size,
)?);
states_b.push(propagate_two_body(
&state_b,
state_b.jd + idx as f64 * state_b_step_size,
)?);
}
let mut best = (f64::INFINITY, state_a.clone(), state_b.clone());
for s0 in &states_a {
for s1 in &states_b {
let d = (Vector3::from(s0.pos) - Vector3::from(s1.pos)).norm();
if d < best.0 {
best = (d, s0.clone(), s1.clone());
}
}
}

let cost = MoidCost {
state_a: best.1,
state_b: best.2,
};

let solver = NelderMead::new(vec![vec![-15.0, -15.0], vec![15.0, -15.0], vec![0.0, 15.0]]);

let res = Executor::new(cost, solver)
.configure(|state| state.max_iters(1000))
.run()
.unwrap();

Ok(res.state().get_best_cost())
}

#[cfg(test)]
mod tests {
use std::f64::consts::TAU;
Expand Down
8 changes: 6 additions & 2 deletions src/tests/test_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,9 @@ def test_propagation_light_delay(self):
assert np.allclose(calculated.pos, should_be.pos)


@pytest.mark.parametrize("planet", [None, "Earth", "Mercury"])
def test_moid(planet):
@pytest.mark.parametrize("planet", [(None, 1.58), ("Earth", 1.58), ("Mercury", 2.18)])
def test_moid(planet, ceres_traj):
planet, ceres_moid = planet
if planet is None:
state = None
vs = spice.get_state("Earth", 2461161.5)
Expand All @@ -115,3 +116,6 @@ def test_moid(planet):
vs = state

assert np.isclose(moid(vs, state), 0)

ceres = ceres_traj[0]
assert np.isclose(moid(ceres, state), ceres_moid, atol=1e-2)

0 comments on commit 651dbee

Please sign in to comment.