diff --git a/merlin-sdk/src/main/java/gov/nasa/jpl/aerie/merlin/protocol/types/Duration.java b/merlin-sdk/src/main/java/gov/nasa/jpl/aerie/merlin/protocol/types/Duration.java index 2883953668..e8680f788b 100644 --- a/merlin-sdk/src/main/java/gov/nasa/jpl/aerie/merlin/protocol/types/Duration.java +++ b/merlin-sdk/src/main/java/gov/nasa/jpl/aerie/merlin/protocol/types/Duration.java @@ -439,6 +439,16 @@ public boolean noShorterThan(final Duration other) { return this.compareTo(other) >= 0; } + /** + * Determine whether this duration lies in [lower bound, upper bound], bounds comprised + * @param lowerBound the lower bound, inclusive + * @param upperBound the upper bound, inclusive + * @return true if lies in the interval, false otherwise + */ + public boolean between(final Duration lowerBound, final Duration upperBound){ + return (this.noShorterThan(lowerBound) && this.noLongerThan(upperBound)); + } + /** * Determine whether this duration is longer than another. * diff --git a/scheduler-driver/src/main/java/gov/nasa/jpl/aerie/scheduler/EquationSolvingAlgorithms.java b/scheduler-driver/src/main/java/gov/nasa/jpl/aerie/scheduler/EquationSolvingAlgorithms.java index ea24a04833..6a9eb69f63 100644 --- a/scheduler-driver/src/main/java/gov/nasa/jpl/aerie/scheduler/EquationSolvingAlgorithms.java +++ b/scheduler-driver/src/main/java/gov/nasa/jpl/aerie/scheduler/EquationSolvingAlgorithms.java @@ -1,14 +1,22 @@ package gov.nasa.jpl.aerie.scheduler; import gov.nasa.jpl.aerie.merlin.protocol.types.Duration; +import org.apache.commons.lang3.tuple.Pair; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.List; +import java.util.Optional; +import java.util.random.RandomGenerator; +import java.util.random.RandomGeneratorFactory; + public class EquationSolvingAlgorithms { private static final Logger logger = LoggerFactory.getLogger(EquationSolvingAlgorithms.class); - public record RootFindingResult(T x, T fx, History history){} + public record FunctionCoordinate(T x, T fx){} + + public record RootFindingResult(FunctionCoordinate functionCoordinate, History history){} /** * Solves f(x) = y for x in [xLow, xHigh] with confidence interval [yLow, yHigh] around y such that we stop when @@ -16,28 +24,42 @@ public record RootFindingResult(T x, T fx, History history){} * x0 and x1 are initial guesses for x, they must be close to the solution to avoid diverging * It is considered that algorithm is diverging when the iterated value of x goes out of [xLow, xHigh]. */ - public interface SecantAlgorithm{ - RootFindingResult findRoot(Function f, - History history, + public interface SecantAlgorithm{ + RootFindingResult findRoot(Function f, + History history, T x0, - T x1, T y, T toleranceYLow, T toleranceYHigh, T xLow, T xHigh, - int maxNbIterations) throws ZeroDerivativeException, InfiniteDerivativeException, DivergenceException, - ExceededMaxIterationException, NoSolutionException; + int maxNbIterations) throws + ZeroDerivativeException, + InfiniteDerivativeException, + DivergenceException, + ExceededMaxIterationException, + NoSolutionException; + } + + public interface Function { + T valueAt(T x, History history) throws DiscontinuityException; } - public interface Function { - T valueAt(T x, History historyType); + public interface History{ + void add(FunctionCoordinate functionCoordinate, Metadata metadata); + List, Optional>> getHistory(); + Optional, Optional>> getLastEvent(); + boolean alreadyVisited(T x); } public static class ZeroDerivativeException extends Exception{ public ZeroDerivativeException() {} } + public static class DiscontinuityException extends Exception{ + public DiscontinuityException(){} + } + public static class InfiniteDerivativeException extends Exception{ public InfiniteDerivativeException() {} } @@ -47,11 +69,6 @@ public DivergenceException(String errorMessage) { super(errorMessage); } } - public static class WrongBracketingException extends Exception{ - public WrongBracketingException(String errorMessage) { - super(errorMessage); - } - } public static class ExceededMaxIterationException extends Exception{ public ExceededMaxIterationException() { @@ -63,74 +80,184 @@ public static class NoSolutionException extends Exception{ public NoSolutionException() { super(); } - public NoSolutionException(String errorMessage) { - super(errorMessage); - } } - public static class SecantDurationAlgorithm implements SecantAlgorithm{ - - public RootFindingResult findRoot( - Function f, - History history, - Duration x0, - Duration x1, - Duration y, - Duration toleranceYLow, - Duration toleranceYHigh, - Duration xLow, - Duration xHigh, - int maxNbIterations) - throws ZeroDerivativeException, InfiniteDerivativeException, DivergenceException, ExceededMaxIterationException, NoSolutionException + public static class SecantDurationAlgorithm implements SecantAlgorithm{ + + private final RandomGenerator randomGenerator = RandomGeneratorFactory.of("Random").create(956756789); + + /** + * Randomly selects a value in the interval [bound1, bound2] + * @param bound1 the first bound + * @param bound2 the second bound + * @return a value chosen randomly + */ + private Duration chooseRandomX(final Duration bound1, final Duration bound2){ + var low = bound1; + var high = bound2; + if(low.isEqualTo(high)) return low; + if(bound1.longerThan(bound2)) { low = bound2; high = bound1; } + return Duration.of( + randomGenerator.nextLong(low.in(Duration.MICROSECONDS), high.in(Duration.MICROSECONDS)), + Duration.MICROSECONDS); + } + + private record IteratingResult(FunctionCoordinate result, int nbIterationsPerformed){} + + /** + * Querying Function.valueAt may lead to a discontinuity. This procedure starts at an initial x value + * and stops only when the value returned is not a discontinuity or the maximum number of iterations has been reached + * Kind of an infaillible valueAt with a limited number of iterations + * @param function the function we are trying to call + * @param init the initial x value + * @param min the lower bound of the domain of x + * @param max the upper bound of the domain of x + * @param history the querying history of f + * @param maxIteration the maximum number of iteration possible + * @return a coordinate (x, f(x)) s.t. f is continuous at x. + * @throws ExceededMaxIterationException + */ + private IteratingResult nextValueAt( + final Function function, + final Duration init, + final Duration min, + final Duration max, + final History history, + final int maxIteration) + throws ExceededMaxIterationException + { + var cur = init; + int i = 0; + do { + //we should not come back to previously visited values + if (!history.alreadyVisited(cur)) { + i++; + try { + final var value = function.valueAt(cur, history); + return new IteratingResult(new FunctionCoordinate<>(cur, value), i); + } catch (DiscontinuityException e) { + //nothing, keep iterating + } + } + cur = chooseRandomX(min, max); + //if min == max, another call to random will have no effect and thus we should exit + } while(i < maxIteration && !min.isEqualTo(max)); + throw new ExceededMaxIterationException(); + } + + /** + * Solves x s.t. f(x) = y by transforming it to the equivalent rootfinding problem x s.t. f(x) - y = 0 + * @param f the function + * @param history + * @param x0 one of the initial x value + * @param y the objective + * @param toleranceYLow absolute value of the tolerance below 0 + * @param toleranceYHigh absolute value of the tolerance above 0 + * @param xLow the lower bound for x + * @param xHigh the upper bound for x + * @param maxNbIterations the maximum number of iterations possible + * @return the solution to the equation, throws an exception otherwise + * @throws ZeroDerivativeException + * @throws NoSolutionException + * @throws ExceededMaxIterationException + * @throws DivergenceException + * @throws InfiniteDerivativeException + */ + public RootFindingResult findRoot( + final Function f, + final History history, + final Duration x0, + final Duration y, + final Duration toleranceYLow, + final Duration toleranceYHigh, + final Duration xLow, + final Duration xHigh, + final int maxNbIterations) + throws ZeroDerivativeException, NoSolutionException, ExceededMaxIterationException, DivergenceException, + InfiniteDerivativeException { - final var ff = new Function (){ + final var ff = new EquationSolvingAlgorithms.Function(){ @Override - public Duration valueAt(final Duration x, final History history) { + public Duration valueAt(final Duration x, final History history) throws EquationSolvingAlgorithms.DiscontinuityException + { return f.valueAt(x, history).minus(y); } }; - double x_nminus1_double = x0.in(Duration.MICROSECONDS); - double x_n_double = x1.in(Duration.MICROSECONDS); - var x_n = x1; + final var result = new EquationSolvingAlgorithms + .SecantDurationAlgorithm() + .findRoot( + ff, + history, + x0, + toleranceYLow, + toleranceYHigh, + xLow, + xHigh, + maxNbIterations); + return new RootFindingResult<>(new FunctionCoordinate<>(result.functionCoordinate.x(), result.functionCoordinate.fx().plus(y)), result.history); + } + + /** + * Solves x s.t. f(x) = 0 + */ + public RootFindingResult findRoot( + final Function f, + final History history, + final Duration x0, + final Duration toleranceYLow, + final Duration toleranceYHigh, + final Duration xLow, + final Duration xHigh, + final int maxNbIterations) + throws ZeroDerivativeException, InfiniteDerivativeException, ExceededMaxIterationException + { final var xLow_long = xLow.in(Duration.MICROSECONDS); final var xHigh_long = xHigh.in(Duration.MICROSECONDS); + final var resultX0 = nextValueAt(f, x0, xLow, xHigh, history, maxNbIterations); + int nbItPerformed = resultX0.nbIterationsPerformed(); + var ff_x_nminus1 = resultX0.result().fx(); + var x_nminus1 = resultX0.result().x(); + double x_nminus1_double = x_nminus1.in(Duration.MICROSECONDS); - if (x_n_double < xLow_long || x_n_double > xHigh_long) { - throw new DivergenceException("Looking for root out of prescribed domain :[" + xLow + "," + xHigh + "]"); - } //We check whether the initial bounds might satisfy the exit criteria. - var ff_x_nminus1 = ff.valueAt(x0, history); - if (ff_x_nminus1.noShorterThan(Duration.negate(toleranceYLow)) && ff_x_nminus1.noLongerThan(toleranceYHigh)) { - return new RootFindingResult<>(x0, ff_x_nminus1.plus(y), history); - } - var ff_x_n = ff.valueAt(x_n, history); - if (ff_x_n.noShorterThan(Duration.negate(toleranceYLow)) && ff_x_n.noLongerThan(toleranceYHigh)) { - return new RootFindingResult<>(x_n, ff_x_n.plus(y), history); + if (ff_x_nminus1.between(Duration.negate(toleranceYLow), toleranceYHigh)) { + return new RootFindingResult<>(new FunctionCoordinate<>(x_nminus1, ff_x_nminus1), history); } - // After these checks, we can be sure that if the two bounds are the same, the derivative will be 0, and thus throw an exception. - if (x0.isEqualTo(x1)) { - throw new NoSolutionException(); + //optimistic heuristic based on the first evaluation: we assume the duration of the activity is constant + var x_n = x_nminus1.minus(ff_x_nminus1); + final var resultX1 = nextValueAt(f, x_n, xLow, xHigh, history, maxNbIterations - nbItPerformed); + nbItPerformed += resultX0.nbIterationsPerformed(); + var ff_x_n = resultX1.result().fx(); + x_n = resultX1.result().x(); + double x_n_double = x_n.in(Duration.MICROSECONDS); + if (ff_x_n.between(Duration.negate(toleranceYLow), toleranceYHigh)) { + return new RootFindingResult<>(new FunctionCoordinate<>(x_n, ff_x_n), history); } - for (int nbIt = 0; nbIt < maxNbIterations; nbIt++) { + while (nbItPerformed < maxNbIterations) { //(f(xn) - f(xn_m1)) / (xn - xn_m1) final double localDerivative = (float) (ff_x_n.minus(ff_x_nminus1)).in(Duration.MICROSECONDS) / (x_n_double - x_nminus1_double); if (localDerivative == 0) throw new ZeroDerivativeException(); - if (Double.isNaN(localDerivative)) throw new InfiniteDerivativeException(); x_nminus1_double = x_n_double; ff_x_nminus1 = ff_x_n; //Note : xn_m2 is implicit here as it is used only for computing the derivative //localDerivative has been computed with what is now xn_m1 and xn_m2 x_n_double = x_n_double - (ff_x_nminus1.in(Duration.MICROSECONDS) / localDerivative); + x_nminus1 = x_n; x_n = Duration.of((long) x_n_double, Duration.MICROSECONDS); - ff_x_n = ff.valueAt(x_n, history); + if (x_n.isEqualTo(x_nminus1)) throw new InfiniteDerivativeException(); + final var resultXn = nextValueAt(f, x_n, xLow, xHigh, history, maxNbIterations - nbItPerformed); + nbItPerformed += resultXn.nbIterationsPerformed(); + ff_x_n = resultXn.result().fx(); + x_n = resultXn.result().x(); + x_n_double = x_n.in(Duration.MICROSECONDS); + //The final solution needs to be in the given bounds which is why this check is added here. - if (ff_x_n.noShorterThan(Duration.negate(toleranceYLow)) && - ff_x_n.noLongerThan(toleranceYHigh) && + if (ff_x_n.between(Duration.negate(toleranceYLow), toleranceYHigh) && (x_n_double >= xLow_long && x_n_double <= xHigh_long)){ - logger.debug("Root found after " + nbIt + " iterations"); - return new RootFindingResult<>(x_n, ff_x_n.plus(y), history); + logger.debug("Root found after " + nbItPerformed + " iterations"); + return new RootFindingResult<>(new FunctionCoordinate<>(x_n, ff_x_n), history); } } throw new ExceededMaxIterationException(); diff --git a/scheduler-driver/src/main/java/gov/nasa/jpl/aerie/scheduler/solver/PrioritySolver.java b/scheduler-driver/src/main/java/gov/nasa/jpl/aerie/scheduler/solver/PrioritySolver.java index a33d75c673..1e6d898612 100644 --- a/scheduler-driver/src/main/java/gov/nasa/jpl/aerie/scheduler/solver/PrioritySolver.java +++ b/scheduler-driver/src/main/java/gov/nasa/jpl/aerie/scheduler/solver/PrioritySolver.java @@ -88,26 +88,39 @@ public class PrioritySolver implements Solver { private final SimulationFacade simulationFacade; - public record EventWithActivity(Duration start, Duration end, SchedulingActivityDirective activity){} - - public static class HistoryWithActivity{ - List events; + public record ActivityMetadata(SchedulingActivityDirective activityDirective){} + public static class HistoryWithActivity implements EquationSolvingAlgorithms.History { + List, Optional>> events; public HistoryWithActivity(){ events = new ArrayList<>(); } - public void add(EventWithActivity event){ - this.events.add(event); + public void add(EquationSolvingAlgorithms.FunctionCoordinate functionCoordinate, ActivityMetadata activityMetadata){ + this.events.add(Pair.of(functionCoordinate, Optional.ofNullable(activityMetadata))); } - public Optional getLastEvent(){ - if(!events.isEmpty()) return Optional.of(events.get(events.size()-1)); - return Optional.empty(); + + @Override + public List, Optional>> getHistory() { + return events; + } + + public Optional, Optional>> getLastEvent(){ + if(events.isEmpty()) return Optional.empty(); + return Optional.of(events.get(events.size() - 1)); + } + + @Override + public boolean alreadyVisited(final Duration x) { + for(final var event:events){ + if(event.getLeft().x().isEqualTo(x)) return true; + } + return false; } public void logHistory(){ logger.info("Rootfinding history"); for(final var event: events){ - logger.info("Start:" + event.start + " end:" + (event.end==null ? "error" : event.end)); + logger.info("Start:" + event.getLeft().x() + " end:" + (event.getLeft().fx()==null ? "error" : event.getLeft().fx())); } } } @@ -979,9 +992,11 @@ private Optional instantiateActivity( //CASE 1: activity has an uncontrollable duration if(activityExpression.type().getDurationType() instanceof DurationType.Uncontrollable){ final var history = new HistoryWithActivity(); - final var f = new EquationSolvingAlgorithms.Function(){ + final var f = new EquationSolvingAlgorithms.Function(){ @Override - public Duration valueAt(Duration start, HistoryWithActivity history) { + public Duration valueAt(Duration start, final EquationSolvingAlgorithms.History history) + throws EquationSolvingAlgorithms.DiscontinuityException + { final var latestConstraintsSimulationResults = getLatestSimResultsUpTo(start); final var actToSim = SchedulingActivityDirective.of( activityExpression.type(), @@ -999,21 +1014,21 @@ public Duration valueAt(Duration start, HistoryWithActivity history) { final var lastInsertion = history.getLastEvent(); Optional computedDuration = Optional.empty(); final var toRemove = new ArrayList(); - lastInsertion.ifPresent(eventWithActivity -> toRemove.add(eventWithActivity.activity())); + lastInsertion.ifPresent(eventWithActivity -> toRemove.add(eventWithActivity.getValue().get().activityDirective())); try { simulationFacade.removeAndInsertActivitiesFromSimulation(toRemove, List.of(actToSim)); computedDuration = simulationFacade.getActivityDuration(actToSim); if(computedDuration.isPresent()) { - history.add(new EventWithActivity(start, start.plus(computedDuration.get()), actToSim)); + history.add(new EquationSolvingAlgorithms.FunctionCoordinate<>(start, start.plus(computedDuration.get())), new ActivityMetadata(actToSim)); } else{ logger.debug("No simulation error but activity duration could not be found in simulation, likely caused by unfinished activity."); - history.add(new EventWithActivity(start, null, actToSim)); + history.add(new EquationSolvingAlgorithms.FunctionCoordinate<>(start, null), new ActivityMetadata(actToSim)); } } catch (SimulationFacade.SimulationException e) { logger.debug("Simulation error while trying to simulate activities: " + e); - history.add(new EventWithActivity(start, null, actToSim)); + history.add(new EquationSolvingAlgorithms.FunctionCoordinate<>(start, null), new ActivityMetadata(actToSim)); } - return computedDuration.map(start::plus).orElse(Duration.MAX_VALUE); + return computedDuration.map(start::plus).orElseThrow(EquationSolvingAlgorithms.DiscontinuityException::new); } }; @@ -1085,9 +1100,9 @@ public Duration valueAt(Duration start, HistoryWithActivity history) { true)); } else if (activityExpression.type().getDurationType() instanceof DurationType.Parametric dt) { final var history = new HistoryWithActivity(); - final var f = new EquationSolvingAlgorithms.Function() { + final var f = new EquationSolvingAlgorithms.Function() { @Override - public Duration valueAt(final Duration start, final HistoryWithActivity history) { + public Duration valueAt(final Duration start, final EquationSolvingAlgorithms.History history) { final var instantiatedArgs = SchedulingActivityDirective.instantiateArguments( activityExpression.arguments(), start, @@ -1104,7 +1119,7 @@ public Duration valueAt(final Duration start, final HistoryWithActivity history) null, null, true); - history.add(new EventWithActivity(start, start.plus(duration), activity)); + history.add(new EquationSolvingAlgorithms.FunctionCoordinate<>(start, start.plus(duration)), new ActivityMetadata(activity)); return duration.plus(start); } catch (InstantiationException e) { logger.error("Cannot instantiate parameterized duration activity type: " + activityExpression.type().getName()); @@ -1120,7 +1135,7 @@ public Duration valueAt(final Duration start, final HistoryWithActivity history) } private Optional rootFindingHelper( - final EquationSolvingAlgorithms.Function f, + final EquationSolvingAlgorithms.Function f, final HistoryWithActivity history, final TaskNetworkAdapter.TNActData solved) { try { @@ -1130,12 +1145,11 @@ private Optional rootFindingHelper( final var durationHalfEndInterval = endInterval.duration().dividedBy(2); final var result = new EquationSolvingAlgorithms - .SecantDurationAlgorithm() + .SecantDurationAlgorithm() .findRoot( f, history, startInterval.start, - startInterval.end, endInterval.start.plus(durationHalfEndInterval), durationHalfEndInterval, durationHalfEndInterval, @@ -1144,10 +1158,10 @@ private Optional rootFindingHelper( 20); // TODO: When scheduling is allowed to create activities with anchors, this constructor should pull from an expanded creation template - final var lastActivityTested = result.history().getLastEvent(); logger.info("Finished rootfinding: SUCCESS"); history.logHistory(); - return Optional.of(lastActivityTested.get().activity); + final var lastActivityTested = result.history().getHistory().get(history.getHistory().size() - 1); + return Optional.of(lastActivityTested.getRight().get().activityDirective()); } catch (EquationSolvingAlgorithms.ZeroDerivativeException zeroOrInfiniteDerivativeException) { logger.info("Rootfinding encountered a zero-derivative"); } catch (EquationSolvingAlgorithms.InfiniteDerivativeException infiniteDerivativeException) { @@ -1161,7 +1175,7 @@ private Optional rootFindingHelper( } if(!history.events.isEmpty()) { try { - simulationFacade.removeActivitiesFromSimulation(List.of(history.getLastEvent().get().activity())); + simulationFacade.removeActivitiesFromSimulation(List.of(history.getLastEvent().get().getRight().get().activityDirective())); } catch (SimulationFacade.SimulationException e) { throw new RuntimeException("Exception while simulating original plan after activity insertion failure" ,e); } diff --git a/scheduler-driver/src/test/java/gov/nasa/jpl/aerie/scheduler/RootfindingTest.java b/scheduler-driver/src/test/java/gov/nasa/jpl/aerie/scheduler/RootfindingTest.java new file mode 100644 index 0000000000..b8418b37c3 --- /dev/null +++ b/scheduler-driver/src/test/java/gov/nasa/jpl/aerie/scheduler/RootfindingTest.java @@ -0,0 +1,170 @@ +package gov.nasa.jpl.aerie.scheduler; + +import gov.nasa.jpl.aerie.merlin.protocol.types.Duration; +import gov.nasa.jpl.aerie.scheduler.solver.PrioritySolver; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class RootfindingTest { + final Duration zeroSecond = Duration.of(0, Duration.SECONDS); + final static Duration oneSecond = Duration.of(1, Duration.SECONDS); + final Duration twoSecond = Duration.of(2, Duration.SECONDS); + final Duration threeSecond = Duration.of(3, Duration.SECONDS); + final Duration thirtySecond = Duration.of(30, Duration.SECONDS); + + //compared with the testSimpleDiscontinuous, the function is discontinuous for all odd x, rootfinding is hitting + //discontinuous values multiple times + @Test + void testHighlyDiscontinuous() + throws EquationSolvingAlgorithms.ZeroDerivativeException, EquationSolvingAlgorithms.NoSolutionException, + EquationSolvingAlgorithms.ExceededMaxIterationException, EquationSolvingAlgorithms.DivergenceException, + EquationSolvingAlgorithms.InfiniteDerivativeException + { + final var durationFunctionDiscontinuousAtEverySecond = + new EquationSolvingAlgorithms.Function() { + @Override + public Duration valueAt( + final Duration x, + final EquationSolvingAlgorithms.History historyType) + throws EquationSolvingAlgorithms.DiscontinuityException + { + if (x.in(Duration.MICROSECONDS) % 2 != 0) { + throw new EquationSolvingAlgorithms.DiscontinuityException(); + } + final var ret = x.times(2); + historyType.add(new EquationSolvingAlgorithms.FunctionCoordinate<>(x, ret), null); + return ret; + } + }; + final var alg = new EquationSolvingAlgorithms.SecantDurationAlgorithm(); + final var history = new PrioritySolver.HistoryWithActivity(); + final var solution = alg.findRoot( + durationFunctionDiscontinuousAtEverySecond, + history, + oneSecond, + Duration.of(39, Duration.SECONDS).plus(121, Duration.MICROSECONDS), + Duration.of(50, Duration.MICROSECONDS), + Duration.of(50, Duration.MICROSECONDS), + zeroSecond, + thirtySecond, + 100); + assertEquals(3, solution.history().getHistory().size()); + assertEquals(new EquationSolvingAlgorithms.FunctionCoordinate<>(Duration.of(19500060, Duration.MICROSECONDS), Duration.of(39000120, Duration.MICROSECONDS)), solution.functionCoordinate()); + } + + + @Test + //this is reproducing issue 1139 : + // f(x0) throws an exception + inf val which leads in the end to a ZeroDerivative + public void testSimpleDiscontinuous() + throws EquationSolvingAlgorithms.ZeroDerivativeException, EquationSolvingAlgorithms.NoSolutionException, + EquationSolvingAlgorithms.ExceededMaxIterationException, EquationSolvingAlgorithms.DivergenceException, + EquationSolvingAlgorithms.InfiniteDerivativeException + { + final var alg = new EquationSolvingAlgorithms.SecantDurationAlgorithm(); + + //function only discontinuous at x = 1 + final var durationFunctionDiscontinuousAtOne = + new EquationSolvingAlgorithms.Function() { + @Override + public Duration valueAt( + final Duration x, + final EquationSolvingAlgorithms.History historyType) + throws EquationSolvingAlgorithms.DiscontinuityException + { + if (x.isEqualTo(oneSecond)) { + throw new EquationSolvingAlgorithms.DiscontinuityException(); + } + final var ret = x.times(2); + historyType.add(new EquationSolvingAlgorithms.FunctionCoordinate<>(x, ret), null); + return ret; + } + }; + + final var history = new PrioritySolver.HistoryWithActivity(); + final var solution = alg.findRoot( + durationFunctionDiscontinuousAtOne, + history, + oneSecond, + Duration.of(3, Duration.SECONDS).plus(500, Duration.MILLISECONDS), + Duration.of(50, Duration.MICROSECONDS), + Duration.of(50, Duration.MICROSECONDS), + zeroSecond, + threeSecond, + 10); + assertEquals(3, solution.history().getHistory().size()); + assertEquals(new EquationSolvingAlgorithms.FunctionCoordinate<>(Duration.of(1750000, Duration.MICROSECONDS), Duration.of(3500000, Duration.MICROSECONDS)), solution.functionCoordinate()); + } + + @Test + public void squareZeros() + throws EquationSolvingAlgorithms.ZeroDerivativeException, + EquationSolvingAlgorithms.ExceededMaxIterationException, + EquationSolvingAlgorithms.InfiniteDerivativeException + { + final var alg = new EquationSolvingAlgorithms.SecantDurationAlgorithm(); + //f(x) = x^2 + final var squareFunc = + new EquationSolvingAlgorithms.Function() { + @Override + public Duration valueAt( + final Duration x, + final EquationSolvingAlgorithms.History historyType) { + final var ret = Duration.of((long) Math.pow(x.in(Duration.MICROSECONDS),2), Duration.MICROSECONDS); + historyType.add(new EquationSolvingAlgorithms.FunctionCoordinate<>(x, ret), null); + return ret; + } + }; + + final var history = new PrioritySolver.HistoryWithActivity(); + final var solution = alg.findRoot( + squareFunc, + history, + Duration.of(-2, Duration.SECONDS), + Duration.of(0, Duration.MICROSECONDS), + Duration.of(0, Duration.MICROSECONDS), + Duration.of(-2, Duration.SECONDS), + twoSecond, + 100); + assertEquals(29, solution.history().getHistory().size()); + assertEquals(new EquationSolvingAlgorithms.FunctionCoordinate<>(Duration.of(0, Duration.MICROSECONDS), Duration.of(0, Duration.MICROSECONDS)), solution.functionCoordinate()); + } + + + @Test + public void floorZeros() + throws EquationSolvingAlgorithms.ZeroDerivativeException, + EquationSolvingAlgorithms.ExceededMaxIterationException, + EquationSolvingAlgorithms.InfiniteDerivativeException + { + final var alg = new EquationSolvingAlgorithms.SecantDurationAlgorithm(); + final var floorFunc = + new EquationSolvingAlgorithms.Function() { + @Override + public Duration valueAt( + final Duration x, + final EquationSolvingAlgorithms.History historyType) + throws EquationSolvingAlgorithms.DiscontinuityException + { + if(x.in(Duration.SECONDS) == 1) throw new EquationSolvingAlgorithms.DiscontinuityException(); + final var ret = Duration.of(x.dividedBy(Duration.SECONDS), Duration.SECONDS); + historyType.add(new EquationSolvingAlgorithms.FunctionCoordinate<>(x, ret), null); + return ret; + } + }; + + final var history = new PrioritySolver.HistoryWithActivity(); + final var solution = alg.findRoot( + floorFunc, + history, + Duration.of(-5, Duration.SECONDS), + Duration.of(0, Duration.MICROSECONDS), + Duration.of(0, Duration.MICROSECONDS), + Duration.of(-2, Duration.SECONDS), + twoSecond, + 100); + assertEquals(2, solution.history().getHistory().size()); + assertEquals(new EquationSolvingAlgorithms.FunctionCoordinate<>(Duration.of(0, Duration.MICROSECONDS), Duration.of(0, Duration.MICROSECONDS)), solution.functionCoordinate()); + } +} diff --git a/scheduler-driver/src/test/java/gov/nasa/jpl/aerie/scheduler/TestApplyWhen.java b/scheduler-driver/src/test/java/gov/nasa/jpl/aerie/scheduler/TestApplyWhen.java index bfb5207c65..9759a0a46a 100644 --- a/scheduler-driver/src/test/java/gov/nasa/jpl/aerie/scheduler/TestApplyWhen.java +++ b/scheduler-driver/src/test/java/gov/nasa/jpl/aerie/scheduler/TestApplyWhen.java @@ -470,7 +470,7 @@ public void testRecurrenceCutoffUncontrollable() { assertTrue(TestUtility.activityStartingAtTime(plan,Duration.of(6, Duration.SECONDS), activityType)); assertFalse(TestUtility.activityStartingAtTime(plan,Duration.of(11, Duration.SECONDS), activityType)); assertFalse(TestUtility.activityStartingAtTime(plan,Duration.of(16, Duration.SECONDS), activityType)); - assertEquals(8, problem.getSimulationFacade().countSimulationRestarts()); + assertEquals(5, problem.getSimulationFacade().countSimulationRestarts()); } diff --git a/scheduler-driver/src/test/java/gov/nasa/jpl/aerie/scheduler/UncontrollableDurationTest.java b/scheduler-driver/src/test/java/gov/nasa/jpl/aerie/scheduler/UncontrollableDurationTest.java index 4c1202bbc5..d79ab600a9 100644 --- a/scheduler-driver/src/test/java/gov/nasa/jpl/aerie/scheduler/UncontrollableDurationTest.java +++ b/scheduler-driver/src/test/java/gov/nasa/jpl/aerie/scheduler/UncontrollableDurationTest.java @@ -1,5 +1,6 @@ package gov.nasa.jpl.aerie.scheduler; +import gov.nasa.jpl.aerie.constraints.time.Interval; import gov.nasa.jpl.aerie.constraints.time.Windows; import gov.nasa.jpl.aerie.constraints.tree.SpansFromWindows; import gov.nasa.jpl.aerie.constraints.tree.WindowsWrapperExpression; @@ -95,7 +96,7 @@ public void testNonLinear(){ assertTrue(TestUtility.containsActivity(plan, planningHorizon.fromStart("PT0S"), planningHorizon.fromStart("PT1M29S"), problem.getActivityType("SolarPanelNonLinear"))); assertTrue(TestUtility.containsActivity(plan, planningHorizon.fromStart("PT16M40S"), planningHorizon.fromStart("PT18M9S"), problem.getActivityType("SolarPanelNonLinear"))); assertTrue(TestUtility.containsActivity(plan, planningHorizon.fromStart("PT33M20S"), planningHorizon.fromStart("PT34M49S"), problem.getActivityType("SolarPanelNonLinear"))); - assertEquals(13, problem.getSimulationFacade().countSimulationRestarts()); + assertEquals(11, problem.getSimulationFacade().countSimulationRestarts()); } @Test @@ -219,10 +220,10 @@ public void testScheduleExceptionThrowingTask(){ final var plan = solver.getNextSolution().get(); //Activity can be started in [0, 2m] but this activity will throw an exception if ran in [0, 1m] so it is scheduled at 2m (as being the second bounds the rootfinding tries before search). assertTrue(TestUtility.containsActivity(plan, - planningHorizon.fromStart("PT120S"), - planningHorizon.fromStart("PT120S"), + planningHorizon.fromStart("PT1M38.886061S"), + planningHorizon.fromStart("PT1M38.886061S"), problem.getActivityType("LateRiser"))); - assertEquals(3, problem.getSimulationFacade().countSimulationRestarts()); + assertEquals(4, problem.getSimulationFacade().countSimulationRestarts()); } }