Skip to content

Commit

Permalink
Spk optimizations - 15-20% speedup of orbit propagation (#73)
Browse files Browse the repository at this point in the history
* Optimizations in interpolation
This leads to about 15-25% speedup in querying state vectors.
This will lead to commensurate speedups in integration.
  • Loading branch information
dahlend authored Jul 9, 2024
1 parent 5ed67a8 commit 5013967
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 46 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- Optimized SPICE kernel queries for Type 2 records, which is the format DE440 is stored in.
This means effectively all n-body propagation is now 15-20% faster than before.
- Removed `SpiceKernel` as a class, lowering all its methods to the submodule level,
see #68 for more discussion.
- Removed `Time` object which was a wrapper over astropy.Time, instead making a
Expand Down
5 changes: 1 addition & 4 deletions src/neospy/ztf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from functools import lru_cache
import os
from collections import defaultdict
import numpy as np

from .cache import cached_file_download, cache_path
from .fov import ZtfCcdQuad, ZtfField, FOVList
Expand Down Expand Up @@ -98,9 +97,7 @@ def fetch_ZTF_fovs(year: int):
)

# Exposures are 30 seconds
jds_str = [x.split("+")[0] for x in irsa_query["obsdate"]]
jds = np.array(Time(jds_str, "iso", "utc").jd)

jds = [Time.from_iso(x + ":00").jd for x in irsa_query["obsdate"]]
obs_info = find_obs_code("ZTF")

# ZTF fields are made up of up to 64 individual CCD quads, here we first construct
Expand Down
42 changes: 28 additions & 14 deletions src/neospy_core/src/spice/interpolation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,24 @@ use nalgebra::DVector;
/// This is useful for reading values out of JPL SPK format, be aware though that time
/// scaling is also important for that particular use case.
///
/// This evaluates the coefficients at a single point of time, but for 3 sets of
/// coefficients at once. This is specifically done for performance reasons.
///
/// # Arguments
///
/// * `x` - Time at which to evaluate the chebyshev polynomials.
/// * `coef` - List of coefficients of the chebyshev polynomials.
/// * `coefx` - Slice of chebyshev coefficients for the first (x) component.
/// * `coefy` - Slice of chebyshev coefficients for the second (y) component.
/// * `coefz` - Slice of chebyshev coefficients for the third (z) component.
///
#[inline(always)]
pub fn chebyshev_evaluate_both(x: f64, coef: &[f64]) -> Result<(f64, f64), NEOSpyError> {
let n_coef = coef.len();
pub fn chebyshev3_evaluate_both(
x: f64,
coefx: &[f64],
coefy: &[f64],
coefz: &[f64],
) -> Result<([f64; 3], [f64; 3]), NEOSpyError> {
let n_coef = coefx.len();

if n_coef < 2 {
return Err(NEOSpyError::IOError(
Expand All @@ -28,38 +38,42 @@ pub fn chebyshev_evaluate_both(x: f64, coef: &[f64]) -> Result<(f64, f64), NEOSp
}
let x2 = 2.0 * x;

let mut val = 0.0;
let mut val = [
coefx[0] + coefx[1] * x,
coefy[0] + coefy[1] * x,
coefz[0] + coefz[1] * x,
];
let mut second_t = 1.0;
let mut last_t = x;
let mut next_t;

val += coef[0] * second_t;
val += coef[1] * last_t;

// The derivative of the chebyshev polynomials of the first kind (T) is given by the
// chebyshev polynomials of the second kind (U) through the relation:
// d T_i / dx = i * U_{i-1}
let mut der_vel;
let mut second_u = 1.0;
let mut last_u = x2;
let mut next_u;

der_vel = coef[1] * second_u;
let mut der_val = [coefx[1], coefy[1], coefz[1]];

for (idx, &c) in coef.iter().enumerate().skip(2) {
for (idx, ((x, y), z)) in coefx.iter().zip(coefy).zip(coefz).enumerate().skip(2) {
next_t = x2 * last_t - second_t;
val += c * next_t;
val[0] += x * next_t;
val[1] += y * next_t;
val[2] += z * next_t;

second_t = last_t;
last_t = next_t;

next_u = x2 * last_u - second_u;
der_vel += c * last_u * (idx as f64);
der_val[0] += x * last_u * (idx as f64);
der_val[1] += y * last_u * (idx as f64);
der_val[2] += z * last_u * (idx as f64);

second_u = last_u;
last_u = next_u;
}

Ok((val, der_vel))
Ok((val, der_val))
}

/// Interpolate using Hermite interpolation.
Expand All @@ -70,7 +84,7 @@ pub fn chebyshev_evaluate_both(x: f64, coef: &[f64]) -> Result<(f64, f64), NEOSp
/// * `x` - The values of the function `f` evaluated at the specified times.
/// * `dx` - The values of the derivative of the function `f`.
/// * `eval_time` - Time at which to evaluate the interpolation function.
pub fn hermite_interpolation(times: &[&f64], x: &[f64], dx: &[f64], eval_time: f64) -> (f64, f64) {
pub fn hermite_interpolation(times: &[f64], x: &[f64], dx: &[f64], eval_time: f64) -> (f64, f64) {
assert_eq!(times.len(), x.len());
assert_eq!(times.len(), dx.len());

Expand Down
5 changes: 2 additions & 3 deletions src/neospy_core/src/spice/pck_segments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,8 @@ impl PckSegmentType2 {
let dec_coef = &record[(self.n_coef + 2)..(2 * self.n_coef + 2)];
let w_coef = &record[(2 * self.n_coef + 2)..(3 * self.n_coef + 2)];

let (ra, ra_der) = chebyshev_evaluate_both(t, ra_coef)?;
let (dec, dec_der) = chebyshev_evaluate_both(t, dec_coef)?;
let (w, w_der) = chebyshev_evaluate_both(t, w_coef)?;
let ([ra, dec, w], [ra_der, dec_der, w_der]) =
chebyshev3_evaluate_both(t, ra_coef, dec_coef, w_coef)?;

// rem_euclid returns the non-negative remainder (unlike `%`, which keeps the sign of
// the dividend), so this maps w into the range [0, 2pi)
let w = w.rem_euclid(std::f64::consts::TAU);
Expand Down
78 changes: 53 additions & 25 deletions src/neospy_core/src/spice/spk_segments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,12 +314,33 @@ pub struct SpkSegmentType2 {
record_len: usize,
}

/// Type 2 Record View
///
/// A borrowed view into a single record of type 2, provided mainly to make the layout
/// of the underlying data explicit. A record is a flat run of `f64` values: two time
/// values followed by three equally sized blocks of chebyshev coefficients.
struct Type2RecordView<'a> {
/// First value of the record; subtracted from the query time before scaling.
t_mid: &'a f64,
/// Second value of the record; the query time offset is divided by this value.
t_step: &'a f64,

/// Chebyshev coefficients for the x component (`n_coef` values).
x_coef: &'a [f64],
/// Chebyshev coefficients for the y component (`n_coef` values).
y_coef: &'a [f64],
/// Chebyshev coefficients for the z component (`n_coef` values).
z_coef: &'a [f64],
}

impl SpkSegmentType2 {
fn get_record(&self, idx: usize) -> &[f64] {
fn get_record(&self, idx: usize) -> Type2RecordView {
unsafe {
self.array
let vals = self
.array
.data
.get_unchecked(idx * self.record_len..(idx + 1) * self.record_len)
.get_unchecked(idx * self.record_len..(idx + 1) * self.record_len);

Type2RecordView {
t_mid: vals.get_unchecked(0),
t_step: vals.get_unchecked(1),
x_coef: vals.get_unchecked(2..(self.n_coef + 2)),
y_coef: vals.get_unchecked((self.n_coef + 2)..(2 * self.n_coef + 2)),
z_coef: vals.get_unchecked((2 * self.n_coef + 2)..(3 * self.n_coef + 2)),
}
}
}

Expand All @@ -333,29 +354,26 @@ impl SpkSegmentType2 {
let record_index = ((jd - jd_start) / self.jd_step).floor() as usize;
let record = self.get_record(record_index);

let x_coef = unsafe { record.get_unchecked(2..(self.n_coef + 2)) };
let y_coef = unsafe { record.get_unchecked((self.n_coef + 2)..(2 * self.n_coef + 2)) };
let z_coef = unsafe { record.get_unchecked((2 * self.n_coef + 2)..(3 * self.n_coef + 2)) };
let t_step = record.t_step;

let t_mid = unsafe { record.get_unchecked(0) };
let t_step = unsafe { record.get_unchecked(1) };
let t = (jd - t_mid) / t_step;
let t = (jd - record.t_mid) / t_step;

let t_step_scaled = 86400.0 / t_step / AU_KM;

let (x, vx) = chebyshev_evaluate_both(t, x_coef)?;
let (y, vy) = chebyshev_evaluate_both(t, y_coef)?;
let (z, vz) = chebyshev_evaluate_both(t, z_coef)?;
let (p, v) = chebyshev3_evaluate_both(t, record.x_coef, record.y_coef, record.z_coef)?;
Ok((
[x / AU_KM, y / AU_KM, z / AU_KM],
[vx * t_step_scaled, vy * t_step_scaled, vz * t_step_scaled],
[p[0] / AU_KM, p[1] / AU_KM, p[2] / AU_KM],
[
v[0] * t_step_scaled,
v[1] * t_step_scaled,
v[2] * t_step_scaled,
],
))
}
}

impl From<DafArray> for SpkSegmentType2 {
fn from(array: DafArray) -> Self {
// let n_records = array[array.len() - 1] as usize;
let record_len = array[array.len() - 2] as usize;
let jd_step = array[array.len() - 3];

Expand Down Expand Up @@ -527,9 +545,23 @@ impl From<DafArray> for SpkSegmentType13 {
}
}

/// Type 13 Record View
///
/// A borrowed view into a single record of type 13, provided mainly to make the layout
/// of the underlying data explicit. A record is 6 consecutive `f64` values.
struct Type13RecordView<'a> {
/// First three values of the record: the position components.
pos: &'a [f64; 3],
/// Last three values of the record: the velocity components.
vel: &'a [f64; 3],
}

impl SpkSegmentType13 {
fn get_record(&self, idx: usize) -> &[f64] {
unsafe { self.array.data.get_unchecked(idx * 6..(idx + 1) * 6) }
fn get_record(&self, idx: usize) -> Type13RecordView {
unsafe {
let rec = self.array.data.get_unchecked(idx * 6..(idx + 1) * 6);
Type13RecordView {
pos: rec[0..3].try_into().unwrap(),
vel: rec[3..6].try_into().unwrap(),
}
}
}

fn get_times(&self) -> &[f64] {
Expand Down Expand Up @@ -562,19 +594,15 @@ impl SpkSegmentType13 {

let mut pos = [0.0; 3];
let mut vel = [0.0; 3];
let times: Box<[&f64]> = times
.iter()
.skip(start_idx)
.take(self.window_size)
.collect();
for idx in 0..3 {
let p: Box<[f64]> = (0..self.window_size)
.map(|i| self.get_record(i + start_idx)[idx])
.map(|i| self.get_record(i + start_idx).pos[idx])
.collect();
let dp: Box<[f64]> = (0..self.window_size)
.map(|i| self.get_record(i + start_idx)[idx + 3])
.map(|i| self.get_record(i + start_idx).vel[idx])
.collect();
let (p, v) = hermite_interpolation(&times, &p, &dp, jd);
let (p, v) =
hermite_interpolation(&times[start_idx..start_idx + self.window_size], &p, &dp, jd);
pos[idx] = p / AU_KM;
vel[idx] = v / AU_KM * 86400.;
}
Expand Down

0 comments on commit 5013967

Please sign in to comment.