Skip to content

Commit

Permalink
Make rootfinding robust to discontinuities (sim exceptions)
Browse files Browse the repository at this point in the history
  • Loading branch information
adrienmaillard committed Oct 13, 2023
1 parent 2473d01 commit 75190b0
Show file tree
Hide file tree
Showing 6 changed files with 409 additions and 87 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,16 @@ public boolean noShorterThan(final Duration other) {
return this.compareTo(other) >= 0;
}

/**
* Determine whether this duration lies in [lower bound, upper bound], bounds comprised
* @param lowerBound the lower bound, inclusive
* @param upperBound the upper bound, inclusive
* @return true if lies in the interval, false otherwise
*/
public boolean between(final Duration lowerBound, final Duration upperBound){
return (this.noShorterThan(lowerBound) && this.noLongerThan(upperBound));
}

/**
* Determine whether this duration is longer than another.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -1,43 +1,65 @@
package gov.nasa.jpl.aerie.scheduler;

import gov.nasa.jpl.aerie.merlin.protocol.types.Duration;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.List;
import java.util.Optional;
import java.util.random.RandomGenerator;
import java.util.random.RandomGeneratorFactory;

public class EquationSolvingAlgorithms {

private static final Logger logger = LoggerFactory.getLogger(EquationSolvingAlgorithms.class);

public record RootFindingResult<T, History>(T x, T fx, History history){}
public record FunctionCoordinate<T>(T x, T fx){}

public record RootFindingResult<T, Metadata>(FunctionCoordinate<T> functionCoordinate, History<T, Metadata> history){}

/**
* Solves f(x) = y for x in [xLow, xHigh] with confidence interval [yLow, yHigh] around y such that we stop when
* value y is in [y-yLow, y+yHigh].
* x0 and x1 are initial guesses for x, they must be close to the solution to avoid diverging
* It is considered that algorithm is diverging when the iterated value of x goes out of [xLow, xHigh].
*/
public interface SecantAlgorithm<T, History>{
RootFindingResult<T, History> findRoot(Function<T, History> f,
History history,
public interface SecantAlgorithm<T, Metadata>{
RootFindingResult<T, Metadata> findRoot(Function<T, Metadata> f,
History<T, Metadata> history,
T x0,
T x1,
T y,
T toleranceYLow,
T toleranceYHigh,
T xLow,
T xHigh,
int maxNbIterations) throws ZeroDerivativeException, InfiniteDerivativeException, DivergenceException,
ExceededMaxIterationException, NoSolutionException;
int maxNbIterations) throws
ZeroDerivativeException,
InfiniteDerivativeException,
DivergenceException,
ExceededMaxIterationException,
NoSolutionException;
}

public interface Function<T, Metadata> {
T valueAt(T x, History<T, Metadata> history) throws DiscontinuityException;
}

public interface Function<T, History> {
T valueAt(T x, History historyType);
public interface History<T, Metadata>{
void add(FunctionCoordinate<T> functionCoordinate, Metadata metadata);
List<Pair<FunctionCoordinate<T>, Optional<Metadata>>> getHistory();
Optional<Pair<FunctionCoordinate<T>, Optional<Metadata>>> getLastEvent();
boolean alreadyVisited(T x);
}

public static class ZeroDerivativeException extends Exception{
public ZeroDerivativeException() {}
}

public static class DiscontinuityException extends Exception{
public DiscontinuityException(){}
}

public static class InfiniteDerivativeException extends Exception{
public InfiniteDerivativeException() {}
}
Expand All @@ -47,11 +69,6 @@ public DivergenceException(String errorMessage) {
super(errorMessage);
}
}
public static class WrongBracketingException extends Exception{
public WrongBracketingException(String errorMessage) {
super(errorMessage);
}
}

public static class ExceededMaxIterationException extends Exception{
public ExceededMaxIterationException() {
Expand All @@ -63,74 +80,184 @@ public static class NoSolutionException extends Exception{
public NoSolutionException() {
super();
}
public NoSolutionException(String errorMessage) {
super(errorMessage);
}
}

public static class SecantDurationAlgorithm<History> implements SecantAlgorithm<Duration, History>{

public RootFindingResult<Duration, History> findRoot(
Function<Duration, History> f,
History history,
Duration x0,
Duration x1,
Duration y,
Duration toleranceYLow,
Duration toleranceYHigh,
Duration xLow,
Duration xHigh,
int maxNbIterations)
throws ZeroDerivativeException, InfiniteDerivativeException, DivergenceException, ExceededMaxIterationException, NoSolutionException
public static class SecantDurationAlgorithm<Metadata> implements SecantAlgorithm<Duration, Metadata>{

private final RandomGenerator randomGenerator = RandomGeneratorFactory.of("Random").create(956756789);

/**
* Randomly selects a value in the interval [bound1, bound2]
* @param bound1 the first bound
* @param bound2 the second bound
* @return a value chosen randomly
*/
private Duration chooseRandomX(final Duration bound1, final Duration bound2){
var low = bound1;
var high = bound2;
if(low.isEqualTo(high)) return low;
if(bound1.longerThan(bound2)) { low = bound2; high = bound1; }
return Duration.of(
randomGenerator.nextLong(low.in(Duration.MICROSECONDS), high.in(Duration.MICROSECONDS)),
Duration.MICROSECONDS);
}

private record IteratingResult(FunctionCoordinate<Duration> result, int nbIterationsPerformed){}

/**
* Querying Function.valueAt may lead to a discontinuity. This procedure starts at an initial x value
* and stops only when the value returned is not a discontinuity or the maximum number of iterations has been reached
* Kind of an infaillible valueAt with a limited number of iterations
* @param function the function we are trying to call
* @param init the initial x value
* @param min the lower bound of the domain of x
* @param max the upper bound of the domain of x
* @param history the querying history of f
* @param maxIteration the maximum number of iteration possible
* @return a coordinate (x, f(x)) s.t. f is continuous at x.
* @throws ExceededMaxIterationException
*/
private IteratingResult nextValueAt(
final Function<Duration, Metadata> function,
final Duration init,
final Duration min,
final Duration max,
final History<Duration, Metadata> history,
final int maxIteration)
throws ExceededMaxIterationException
{
var cur = init;
int i = 0;
do {
//we should not come back to previously visited values
if (!history.alreadyVisited(cur)) {
i++;
try {
final var value = function.valueAt(cur, history);
return new IteratingResult(new FunctionCoordinate<>(cur, value), i);
} catch (DiscontinuityException e) {
//nothing, keep iterating
}
}
cur = chooseRandomX(min, max);
//if min == max, another call to random will have no effect and thus we should exit
} while(i < maxIteration && !min.isEqualTo(max));
throw new ExceededMaxIterationException();
}

/**
* Solves x s.t. f(x) = y by transforming it to the equivalent rootfinding problem x s.t. f(x) - y = 0
* @param f the function
* @param history
* @param x0 one of the initial x value
* @param y the objective
* @param toleranceYLow absolute value of the tolerance below 0
* @param toleranceYHigh absolute value of the tolerance above 0
* @param xLow the lower bound for x
* @param xHigh the upper bound for x
* @param maxNbIterations the maximum number of iterations possible
* @return the solution to the equation, throws an exception otherwise
* @throws ZeroDerivativeException
* @throws NoSolutionException
* @throws ExceededMaxIterationException
* @throws DivergenceException
* @throws InfiniteDerivativeException
*/
public RootFindingResult<Duration, Metadata> findRoot(
final Function<Duration, Metadata> f,
final History<Duration, Metadata> history,
final Duration x0,
final Duration y,
final Duration toleranceYLow,
final Duration toleranceYHigh,
final Duration xLow,
final Duration xHigh,
final int maxNbIterations)
throws ZeroDerivativeException, NoSolutionException, ExceededMaxIterationException, DivergenceException,
InfiniteDerivativeException
{
final var ff = new Function<Duration, History> (){
final var ff = new EquationSolvingAlgorithms.Function<Duration, Metadata>(){
@Override
public Duration valueAt(final Duration x, final History history) {
public Duration valueAt(final Duration x, final History<Duration, Metadata> history) throws EquationSolvingAlgorithms.DiscontinuityException
{
return f.valueAt(x, history).minus(y);
}
};

double x_nminus1_double = x0.in(Duration.MICROSECONDS);
double x_n_double = x1.in(Duration.MICROSECONDS);
var x_n = x1;
final var result = new EquationSolvingAlgorithms
.SecantDurationAlgorithm<Metadata>()
.findRoot(
ff,
history,
x0,
toleranceYLow,
toleranceYHigh,
xLow,
xHigh,
maxNbIterations);
return new RootFindingResult<>(new FunctionCoordinate<>(result.functionCoordinate.x(), result.functionCoordinate.fx().plus(y)), result.history);
}

/**
* Solves x s.t. f(x) = 0
*/
public RootFindingResult<Duration, Metadata> findRoot(
final Function<Duration, Metadata> f,
final History<Duration, Metadata> history,
final Duration x0,
final Duration toleranceYLow,
final Duration toleranceYHigh,
final Duration xLow,
final Duration xHigh,
final int maxNbIterations)
throws ZeroDerivativeException, InfiniteDerivativeException, ExceededMaxIterationException
{
final var xLow_long = xLow.in(Duration.MICROSECONDS);
final var xHigh_long = xHigh.in(Duration.MICROSECONDS);
final var resultX0 = nextValueAt(f, x0, xLow, xHigh, history, maxNbIterations);
int nbItPerformed = resultX0.nbIterationsPerformed();
var ff_x_nminus1 = resultX0.result().fx();
var x_nminus1 = resultX0.result().x();
double x_nminus1_double = x_nminus1.in(Duration.MICROSECONDS);

if (x_n_double < xLow_long || x_n_double > xHigh_long) {
throw new DivergenceException("Looking for root out of prescribed domain :[" + xLow + "," + xHigh + "]");
}
//We check whether the initial bounds might satisfy the exit criteria.
var ff_x_nminus1 = ff.valueAt(x0, history);
if (ff_x_nminus1.noShorterThan(Duration.negate(toleranceYLow)) && ff_x_nminus1.noLongerThan(toleranceYHigh)) {
return new RootFindingResult<>(x0, ff_x_nminus1.plus(y), history);
}
var ff_x_n = ff.valueAt(x_n, history);
if (ff_x_n.noShorterThan(Duration.negate(toleranceYLow)) && ff_x_n.noLongerThan(toleranceYHigh)) {
return new RootFindingResult<>(x_n, ff_x_n.plus(y), history);
if (ff_x_nminus1.between(Duration.negate(toleranceYLow), toleranceYHigh)) {
return new RootFindingResult<>(new FunctionCoordinate<>(x_nminus1, ff_x_nminus1), history);
}
// After these checks, we can be sure that if the two bounds are the same, the derivative will be 0, and thus throw an exception.
if (x0.isEqualTo(x1)) {
throw new NoSolutionException();
//optimistic heuristic based on the first evaluation: we assume the duration of the activity is constant
var x_n = x_nminus1.minus(ff_x_nminus1);
final var resultX1 = nextValueAt(f, x_n, xLow, xHigh, history, maxNbIterations - nbItPerformed);
nbItPerformed += resultX0.nbIterationsPerformed();
var ff_x_n = resultX1.result().fx();
x_n = resultX1.result().x();
double x_n_double = x_n.in(Duration.MICROSECONDS);
if (ff_x_n.between(Duration.negate(toleranceYLow), toleranceYHigh)) {
return new RootFindingResult<>(new FunctionCoordinate<>(x_n, ff_x_n), history);
}
for (int nbIt = 0; nbIt < maxNbIterations; nbIt++) {
while (nbItPerformed < maxNbIterations) {
//(f(xn) - f(xn_m1)) / (xn - xn_m1)
final double localDerivative =
(float) (ff_x_n.minus(ff_x_nminus1)).in(Duration.MICROSECONDS) / (x_n_double - x_nminus1_double);
if (localDerivative == 0) throw new ZeroDerivativeException();
if (Double.isNaN(localDerivative)) throw new InfiniteDerivativeException();
x_nminus1_double = x_n_double;
ff_x_nminus1 = ff_x_n;
//Note : xn_m2 is implicit here as it is used only for computing the derivative
//localDerivative has been computed with what is now xn_m1 and xn_m2
x_n_double = x_n_double - (ff_x_nminus1.in(Duration.MICROSECONDS) / localDerivative);
x_nminus1 = x_n;
x_n = Duration.of((long) x_n_double, Duration.MICROSECONDS);
ff_x_n = ff.valueAt(x_n, history);
if (x_n.isEqualTo(x_nminus1)) throw new InfiniteDerivativeException();
final var resultXn = nextValueAt(f, x_n, xLow, xHigh, history, maxNbIterations - nbItPerformed);
nbItPerformed += resultXn.nbIterationsPerformed();
ff_x_n = resultXn.result().fx();
x_n = resultXn.result().x();
x_n_double = x_n.in(Duration.MICROSECONDS);

//The final solution needs to be in the given bounds which is why this check is added here.
if (ff_x_n.noShorterThan(Duration.negate(toleranceYLow)) &&
ff_x_n.noLongerThan(toleranceYHigh) &&
if (ff_x_n.between(Duration.negate(toleranceYLow), toleranceYHigh) &&
(x_n_double >= xLow_long && x_n_double <= xHigh_long)){
logger.debug("Root found after " + nbIt + " iterations");
return new RootFindingResult<>(x_n, ff_x_n.plus(y), history);
logger.debug("Root found after " + nbItPerformed + " iterations");
return new RootFindingResult<>(new FunctionCoordinate<>(x_n, ff_x_n), history);
}
}
throw new ExceededMaxIterationException();
Expand Down
Loading

0 comments on commit 75190b0

Please sign in to comment.