Skip to content

Commit

Permalink
Implement RollingThreshold edsl function
Browse files Browse the repository at this point in the history
  • Loading branch information
JoelCourtney committed Sep 5, 2023
1 parent 2a35aa4 commit 2db49eb
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -617,9 +617,26 @@ private static JsonParser<Expression<Spans>> spansExpressionF(JsonParser<Express
untuple((kind, expression) -> new ViolationsOfWindows(expression)),
$ -> tuple(Unit.UNIT, $.expression));

public static final JsonParser<RollingThreshold.RollingThresholdAlgorithm> rollingThresholdAlgorithmP =
enumP(RollingThreshold.RollingThresholdAlgorithm.class, Enum::name);

static final JsonParser<RollingThreshold> rollingThresholdP =
productP
.field("kind", literalP("RollingThreshold"))
.field("spans", spansExpressionP)
.field("width", durationExprP)
.field("threshold", durationExprP)
.field("algorithm", rollingThresholdAlgorithmP)
.map(
untuple((kind, spans, width, threshold, alg) -> new RollingThreshold(spans, width, threshold, alg)),
$ -> tuple(Unit.UNIT, $.spans(), $.width(), $.threshold(), $.algorithm())
);


public static final JsonParser<Expression<ConstraintResult>> constraintP =
recursiveP(selfP -> chooseP(
forEachActivityViolationsF(selfP),
windowsExpressionP.map(ViolationsOfWindows::new, $ -> $.expression),
violationsOfP));
violationsOfP,
rollingThresholdP));
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import java.util.List;
import java.util.Set;

public record RollingThreshold(Expression<Spans> expression, Expression<Duration> width, Expression<Duration> threshold, RollingThresholdAlgorithm algorithm) implements Expression<ConstraintResult> {
public record RollingThreshold(Expression<Spans> spans, Expression<Duration> width, Expression<Duration> threshold, RollingThresholdAlgorithm algorithm) implements Expression<ConstraintResult> {

public enum RollingThresholdAlgorithm {
InputSpans,
Expand All @@ -25,7 +25,7 @@ public enum RollingThresholdAlgorithm {
@Override
public ConstraintResult evaluate(SimulationResults results, final Interval bounds, EvaluationEnvironment environment) {
final var width = this.width.evaluate(results, bounds, environment);
var spans = this.expression.evaluate(results, bounds, environment);
var spans = this.spans.evaluate(results, bounds, environment);
final var threshold = this.threshold.evaluate(results, bounds, environment);

final var accDuration = spans.accumulatedDuration(threshold);
Expand Down Expand Up @@ -65,7 +65,7 @@ public ConstraintResult evaluate(SimulationResults results, final Interval bound

@Override
public void extractResources(final Set<String> names) {
this.expression.extractResources(names);
this.spans.extractResources(names);
this.width.extractResources(names);
this.threshold.extractResources(names);
}
Expand All @@ -75,7 +75,7 @@ public String prettyPrint(final String prefix) {
return String.format(
"\n%s(rolling-threshold on %s, width %s, threshold %s, algorithm %s)",
prefix,
this.expression.prettyPrint(prefix + " "),
this.spans.prettyPrint(prefix + " "),
this.width.prettyPrint(prefix + " "),
this.threshold.prettyPrint(prefix + " "),
this.algorithm
Expand Down
13 changes: 11 additions & 2 deletions merlin-server/constraints-dsl-compiler/src/libs/constraints-ast.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,11 @@ export enum NodeKind {
ViolationsOf = 'ViolationsOf',
AbsoluteInterval = 'AbsoluteInterval',
IntervalAlias = 'IntervalAlias',
IntervalDuration = 'IntervalDuration'
IntervalDuration = 'IntervalDuration',
RollingThreshold = 'RollingThreshold'
}

export type Constraint = ViolationsOf | WindowsExpression | SpansExpression | ForEachActivityConstraints;
export type Constraint = ViolationsOf | WindowsExpression | SpansExpression | ForEachActivityConstraints | RollingThreshold;

export interface ViolationsOf {
kind: NodeKind.ViolationsOf;
Expand All @@ -71,6 +72,14 @@ export interface ForEachActivitySpans {
expression: SpansExpression;
}

export interface RollingThreshold {
kind: NodeKind.RollingThreshold;
spans: SpansExpression,
width: Duration,
threshold: Duration,
algorithm: API.RollingThresholdAlgorithm
}

export interface AssignGapsExpression<P extends ProfileExpression> {
kind: NodeKind.AssignGapsExpression,
originalProfile: P,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,40 @@ export class Constraint {
expression: expression(new ActivityInstance(activityType, alias)).__astNode,
});
}

/**
* Detect when a spans object's cumulative duration exceeds a threshold within any interval of a given width.
*
* Violations can be reported in two different ways by setting the `algorithm` argument:
* - `RollingThresholdAlgorithm.Spans` highlights the individual spans that contributed to the threshold violation.
* - `RollingThresholdAlgorithm.Hull` highlights the single interval that contains all the violating spans.
*
* @param spans spans object to detect threshold events on
* @param width width of the rolling interval
* @param threshold maximum allowable duration within any `width` interval
* @param algorithm algorithm for reporting violations
* @constructor
*/
public static RollingThreshold(
spans: Spans,
width: AST.Duration,
threshold: AST.Duration,
algorithm: RollingThresholdAlgorithm
): Constraint {
return new Constraint({
kind: AST.NodeKind.RollingThreshold,
spans: spans.__astNode,
width,
threshold,
algorithm
});
}
}

/** Algorithm to use when reporting violations from rolling threshold */
export enum RollingThresholdAlgorithm {
Spans = 'Spans',
Hull = 'Hull'
}

/** A boolean profile; a function from time to truth values. */
Expand Down Expand Up @@ -105,7 +139,7 @@ export class Windows {
public static During(...activityTypes: Gen.ActivityType[]) : Windows {
return Windows.Or(
...activityTypes.map<Windows>((activityType) =>
Spans.ForEachActivity(activityType, (activity) => activity.span()).windows())
Spans.ForEachActivity(activityType).windows())
);
}

Expand Down Expand Up @@ -464,13 +498,14 @@ export class Spans {
* Check a constraint for each instance of an activity type.
*
* @param activityType activity type to check
* @param expression function of an activity instance that returns a Constraint
* @param expression function of an activity instance that returns a Constraint; default returns the instance's span.
* @constructor
*/
public static ForEachActivity<A extends Gen.ActivityType>(
activityType: A,
expression: (instance: ActivityInstance<A>) => Spans,
expression?: (instance: ActivityInstance<A>) => Spans,
): Spans {
if (expression === undefined) expression = instance => instance.span();
let alias = 'span activity alias ' + Spans.__numGeneratedAliases;
Spans.__numGeneratedAliases += 1;
return new Spans({
Expand Down Expand Up @@ -1039,8 +1074,8 @@ declare global {
* @constructor
*/
public static ForbiddenActivityOverlap(
activityType1: Gen.ActivityType,
activityType2: Gen.ActivityType,
activityType1: Gen.ActivityType,
activityType2: Gen.ActivityType,
): Constraint;

/**
Expand All @@ -1054,6 +1089,32 @@ declare global {
activityType: A,
expression: (instance: ActivityInstance<A>) => Constraint,
): Constraint;

/**
* Detect when a spans object's cumulative duration exceeds a threshold within any interval of a given width.
*
* Violations can be reported in two different ways by setting the `algorithm` argument:
* - `RollingThresholdAlgorithm.Spans` highlights the individual spans that contributed to the threshold violation.
* - `RollingThresholdAlgorithm.Hull` highlights the single interval that contains all the violating spans.
*
* @param spans spans object to detect threshold events on
* @param width width of the rolling interval
* @param threshold maximum allowable duration within any `width` interval
* @param algorithm algorithm for reporting violations
* @constructor
*/
public static RollingThreshold(
spans: Spans,
width: AST.Duration,
threshold: AST.Duration,
algorithm: RollingThresholdAlgorithm
): Constraint;
}

/** Algorithm to use when reporting violations from rolling threshold */
export enum RollingThresholdAlgorithm {
Spans = 'Spans',
Hull = 'Hull'
}

/** A boolean profile; a function from time to truth values. */
Expand Down Expand Up @@ -1261,12 +1322,12 @@ declare global {
* Applies an expression producing spans for each instance of an activity type and returns the aggregated set of spans.
*
* @param activityType activity type to check
* @param expression function of an activity instance that returns a Spans
* @param expression function of an activity instance that returns a Spans; default returns the instance's span.
* @constructor
*/
public static ForEachActivity<A extends Gen.ActivityType>(
activityType: A,
expression: (instance: ActivityInstance<A>) => Spans,
expression?: (instance: ActivityInstance<A>) => Spans,
): Spans;

/**
Expand Down Expand Up @@ -1523,4 +1584,4 @@ declare global {
}

// Make Constraint available on the global object
Object.assign(globalThis, { Constraint, Windows, Spans, Real, Discrete, Inclusivity, Interval });
Object.assign(globalThis, { Constraint, Windows, Spans, Real, Discrete, Inclusivity, Interval, RollingThresholdAlgorithm });
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import gov.nasa.jpl.aerie.constraints.tree.RealParameter;
import gov.nasa.jpl.aerie.constraints.tree.RealResource;
import gov.nasa.jpl.aerie.constraints.tree.RealValue;
import gov.nasa.jpl.aerie.constraints.tree.RollingThreshold;
import gov.nasa.jpl.aerie.constraints.tree.ShiftBy;
import gov.nasa.jpl.aerie.constraints.tree.ShiftWindowsEdges;
import gov.nasa.jpl.aerie.constraints.tree.ShorterThan;
Expand Down Expand Up @@ -1275,4 +1276,30 @@ export default() => {
);
}

@Test
void testRollingThreshold() {
checkSuccessfulCompilation(
"""
export default () => {
return Constraint.RollingThreshold(
Spans.ForEachActivity(ActivityType.activity),
Temporal.Duration.from({hours: 1}),
Temporal.Duration.from({minutes: 5}),
RollingThresholdAlgorithm.Hull
);
}
""",
new RollingThreshold(
new ForEachActivitySpans(
"activity",
"span activity alias 0",
new ActivitySpan("span activity alias 0")
),
new DurationLiteral(Duration.of(1, Duration.HOUR)),
new DurationLiteral(Duration.of(5, Duration.MINUTE)),
RollingThreshold.RollingThresholdAlgorithm.Hull
)
);
}

}

0 comments on commit 2db49eb

Please sign in to comment.