Skip to content

Commit

Permalink
Merge pull request #1322 from NASA-AMMOS/bugfix--polynomial-rootfindi…
Browse files Browse the repository at this point in the history
…ng-convergence

Bugfix for polynomial root-finding
  • Loading branch information
mattdailis authored Mar 1, 2024
2 parents c137793 + fac6582 commit e8bfb35
Show file tree
Hide file tree
Showing 7 changed files with 116 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
/**
* The time at which a value expires.
*/
public record Expiry(Optional<Duration> value) {
public record Expiry(Optional<Duration> value) implements Comparable<Expiry> {
public static Expiry NEVER = expiry(Optional.empty());

public static Expiry at(Duration t) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,11 @@ public static <D extends Dynamics<?, D>> Condition dynamicsChange(Resource<D> re
final Duration startTime = currentTime();
Condition result = (positive, atEarliest, atLatest) -> {
var currentDynamics = resource.getDynamics();
var elapsedTime = currentTime().minus(startTime);
boolean haveChanged = startingDynamics.match(
start -> currentDynamics.match(
current -> !current.data().equals(start.data().step(currentTime().minus(startTime))),
current -> !current.data().equals(start.data().step(elapsedTime)) ||
!current.expiry().equals(start.expiry().minus(elapsedTime)),
ignored -> true),
startException -> currentDynamics.match(
ignored -> true,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,23 +59,26 @@ public static <D extends Dynamics<?, D>> Function<Expiring<D>, Duration> byBound
exp.data(),
t,
maximumError));
var effectiveMinSamplePeriod = Duration.min(e, minimumSamplePeriod);
var effectiveMaxSamplePeriod = Duration.min(e, maximumSamplePeriod);

try {
double intervalSize = solver.solve(
100,
errorFn,
Duration.min(e, minimumSamplePeriod).ratioOver(SECOND),
Duration.min(e, maximumSamplePeriod).ratioOver(SECOND));
effectiveMinSamplePeriod.ratioOver(SECOND),
effectiveMaxSamplePeriod.ratioOver(SECOND));
return Duration.roundNearest(intervalSize, SECOND);
} catch (NoBracketingException x) {
if (errorFn.value(minimumSamplePeriod.ratioOver(SECOND)) > 0) {
// maximum error > estimated error, best case
return maximumSamplePeriod;
return effectiveMaxSamplePeriod;
} else {
// maximum error < estimated error, worst case
return minimumSamplePeriod;
return effectiveMinSamplePeriod;
}
} catch (TooManyEvaluationsException | NumberIsTooLargeException x) {
return minimumSamplePeriod;
return effectiveMinSamplePeriod;
}
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import static gov.nasa.jpl.aerie.contrib.streamline.core.Expiring.neverExpiring;
import static gov.nasa.jpl.aerie.contrib.streamline.core.Reactions.whenever;
import static gov.nasa.jpl.aerie.contrib.streamline.core.monads.ExpiringMonad.bind;
import static gov.nasa.jpl.aerie.contrib.streamline.core.Resources.eraseExpiry;
import static gov.nasa.jpl.aerie.contrib.streamline.debugging.Context.contextualized;
import static gov.nasa.jpl.aerie.contrib.streamline.debugging.Dependencies.addDependency;
import static gov.nasa.jpl.aerie.contrib.streamline.debugging.Naming.getName;
Expand Down Expand Up @@ -314,7 +313,7 @@ private Stream<DirectionalConstraint> directionalConstraints(Variable constraine
// Expiry for driven terms is captured by re-solving rather than expiring the solution.
// If solver has a feedback loop from last iteration (which is common)
// feeding that expiry in here can loop the solver forever.
var result = eraseExpiry(drivenTerm);
var result = drivenTerm;
for (var drivingVariable : drivingVariables) {
var scale = controlledTerm.get(drivingVariable);
var domain = domains.get(drivingVariable);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ private Expiry findExpiryNearRoot(Predicate<Duration> expires) {

// Do a binary search to find the exact transition time
while (end.longerThan(start.plus(EPSILON))) {
Duration midpoint = start.plus(end).dividedBy(2);
Duration midpoint = start.plus(end.minus(start).dividedBy(2));
if (expires.test(midpoint)) {
end = midpoint;
} else {
Expand Down Expand Up @@ -256,19 +256,38 @@ public Expiring<Polynomial> max(Polynomial other) {
* Finds all roots of this function in the future
*/
private Stream<Duration> findFutureRoots() {
// TODO: In some sense, isn't having an infinite coefficient the same as a vertical line,
// hence the same as having a root at x = 0?
// Unless the value itself is non-finite, that is...
// If this polynomial can never have a root, fail immediately
if (this.isNonFinite() || this.isConstant()) {
return Stream.empty();
}

// Defining epsilon keeps the Laguerre solver fast and stable for poorly-behaved polynomials.
final double epsilon = 2 * Arrays.stream(coefficients).map(Math::ulp).max().orElseThrow();
if (coefficients[0] == 0.0) {
return Stream.of(ZERO);
}

// If the polynomial is linear, solve it analytically for performance
if (this.degree() <= 1) {
double t = -getCoefficient(0) / getCoefficient(1);
if (t >= -ABSOLUTE_ACCURACY_FOR_DURATIONS / 2 && t <= MAX_SECONDS_FOR_DURATION) {
return Stream.of(Duration.roundNearest(t, SECOND));
} else {
return Stream.empty();
}
}

// Condition the problem by dividing through by the first coefficient:
double[] conditionedCoefficients = Arrays.stream(coefficients).map(c -> c / coefficients[0]).toArray();
// Defining epsilon keeps the Laguerre solver faster and more stable for poorly-behaved polynomials.
final double epsilon = 2 * Arrays.stream(conditionedCoefficients).map(Math::ulp).max().orElseThrow();
final Complex[] solutions = new LaguerreSolver(0, ABSOLUTE_ACCURACY_FOR_DURATIONS, epsilon)
.solveAllComplex(coefficients, 0);
.solveAllComplex(conditionedCoefficients, 0);
return Arrays.stream(solutions)
.filter(solution -> Math.abs(solution.getImaginary()) < epsilon)
.map(Complex::getReal)
.filter(t -> t >= 0 && t <= MAX_SECONDS_FOR_DURATION)
.filter(t -> t >= -ABSOLUTE_ACCURACY_FOR_DURATIONS / 2 && t <= MAX_SECONDS_FOR_DURATION)
.sorted()
.map(t -> Duration.roundNearest(t, SECOND));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,7 @@ public static ClampedIntegrateResult clampedIntegrate(
Resource<Polynomial> integrand, Resource<Polynomial> lowerBound, Resource<Polynomial> upperBound, double startingValue) {
LinearBoundaryConsistencySolver rateSolver = new LinearBoundaryConsistencySolver("clampedIntegrate rate solver");
var integral = resource(polynomial(startingValue));
var neverExpiringIntegral = eraseExpiry(integral);

// Solve for the rate as a function of value
var overflowRate = rateSolver.variable("overflowRate", Domain::lowerBound);
Expand All @@ -377,11 +378,11 @@ public static ClampedIntegrateResult clampedIntegrate(

// Set up rate clamping conditions
var integrandUB = choose(
greaterThanOrEquals(integral, upperBound),
greaterThanOrEquals(neverExpiringIntegral, upperBound),
differentiate(upperBound),
constant(Double.POSITIVE_INFINITY));
var integrandLB = choose(
lessThanOrEquals(integral, lowerBound),
lessThanOrEquals(neverExpiringIntegral, lowerBound),
differentiate(lowerBound),
constant(Double.NEGATIVE_INFINITY));

Expand All @@ -390,10 +391,10 @@ public static ClampedIntegrateResult clampedIntegrate(

// Use a simple feedback loop on volumes to do the integration and clamping.
// Clamping here takes care of discrete under-/overflows and overshooting bounds due to discrete time steps.
var clampedCell = clamp(integral, lowerBound, upperBound);
var clampedCell = clamp(neverExpiringIntegral, lowerBound, upperBound);
var correctedCell = map(clampedCell, rate.resource(), (v, r) -> r.integral(v.extract()));
// Use the corrected integral values to set volumes, but erase expiry information in the process to avoid loops
forward(eraseExpiry(correctedCell), integral);
forward(correctedCell, integral);

name(integral, "Clamped Integral (%s)", integrand);
name(overflowRate.resource(), "Overflow of %s", integral);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,80 @@ void comparing_converging_nonlinear_terms_with_fine_precision() {
check_extrema(false, true);
}

// Unrepresentable convergence:
// These tests reflect polynomials that in theory converge, but do so in timespans
// that are too large to represent. Thus, they should be treated as non-converging.

@Test
void comparing_linear_terms_with_convergence_unrepresentable_by_double() {
setup(() -> {
set(p, polynomial(Double.MAX_VALUE));
set(q, polynomial(0, 0.1));
});

check_comparison(p_lt_q, false, false);
check_comparison(p_lte_q, false, false);
check_comparison(p_gt_q, true, false);
check_comparison(p_gte_q, true, false);
check_extrema(true, false);
}

@Test
void comparing_linear_terms_with_convergence_unrepresentable_by_duration() {
setup(() -> {
set(p, polynomial(Duration.MAX_VALUE.ratioOver(SECOND)));
set(q, polynomial(0, 0.1));
});

check_comparison(p_lt_q, false, false);
check_comparison(p_lte_q, false, false);
check_comparison(p_gt_q, true, false);
check_comparison(p_gte_q, true, false);
check_extrema(true, false);
}

@Test
void comparing_nonlinear_terms_with_convergence_unrepresentable_by_double() {
setup(() -> {
set(p, polynomial(Double.MAX_VALUE));
set(q, polynomial(0, 0, 0.1));
});

check_comparison(p_lt_q, false, false);
check_comparison(p_lte_q, false, false);
check_comparison(p_gt_q, true, false);
check_comparison(p_gte_q, true, false);
check_extrema(true, false);
}

@Test
void comparing_nonlinear_terms_with_convergence_unrepresentable_by_duration() {
setup(() -> {
set(p, polynomial(Duration.MAX_VALUE.ratioOver(SECOND) * Duration.MAX_VALUE.ratioOver(SECOND)));
set(q, polynomial(0, 0, 0.1));
});

check_comparison(p_lt_q, false, false);
check_comparison(p_lte_q, false, false);
check_comparison(p_gt_q, true, false);
check_comparison(p_gte_q, true, false);
check_extrema(true, false);
}

@Test
void comparing_pathological_nonlinear_terms_with_convergence_unrepresentable_by_duration() {
setup(() -> {
set(p, polynomial(Duration.MAX_VALUE.ratioOver(SECOND) * Duration.MAX_VALUE.ratioOver(SECOND)));
set(q, polynomial(0, Duration.MIN_VALUE.ratioOver(SECOND), 1.0 + Math.ulp(1.0)));
});

check_comparison(p_lt_q, false, false);
check_comparison(p_lte_q, false, false);
check_comparison(p_gt_q, true, false);
check_comparison(p_gte_q, true, false);
check_extrema(true, false);
}

private void check_comparison(Resource<Discrete<Boolean>> result, boolean expectedValue, boolean expectCrossover) {
reset();
var resultDynamics = result.getDynamics().getOrThrow();
Expand Down

0 comments on commit e8bfb35

Please sign in to comment.