Skip to content

Commit

Permalink
Merge pull request #1559 from NASA-AMMOS/feature/generator-constraints
Browse files Browse the repository at this point in the history
Generator Constraints and Violation Messages
  • Loading branch information
JoelCourtney authored Sep 13, 2024
2 parents dbc8ff2 + f02a8dd commit 9af4b3c
Show file tree
Hide file tree
Showing 9 changed files with 240 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import com.microsoft.playwright.Playwright;
import com.zaxxer.hikari.HikariConfig;
import com.zaxxer.hikari.HikariDataSource;
import gov.nasa.ammos.aerie.procedural.timeline.plan.Plan;
import gov.nasa.ammos.aerie.procedural.timeline.plan.SimulationResults;
import gov.nasa.jpl.aerie.e2e.utils.GatewayRequests;
import gov.nasa.jpl.aerie.e2e.utils.HasuraRequests;
import gov.nasa.jpl.aerie.merlin.protocol.types.Duration;
Expand All @@ -12,7 +14,6 @@
import gov.nasa.ammos.aerie.procedural.timeline.payloads.Segment;
import gov.nasa.ammos.aerie.procedural.remote.AeriePostgresPlan;
import gov.nasa.ammos.aerie.procedural.remote.AeriePostgresSimulationResults;
import gov.nasa.ammos.aerie.procedural.timeline.plan.SimulatedPlan;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
Expand Down Expand Up @@ -40,7 +41,8 @@ public class TimelineRemoteTests {
private int planId;
private int activityId;

private SimulatedPlan simulatedPlan;
private Plan plan;
private SimulationResults simResults;
private Connection connection;
private HikariDataSource dataSource;
@BeforeAll
Expand Down Expand Up @@ -101,9 +103,8 @@ void beforeEach() throws IOException, InterruptedException {

// Connect to the database

final var plan = new AeriePostgresPlan(connection, planId);
final var simResults = new AeriePostgresSimulationResults(connection, simDatasetId, plan, false);
simulatedPlan = new SimulatedPlan(plan, simResults);
plan = new AeriePostgresPlan(connection, planId);
simResults = new AeriePostgresSimulationResults(connection, simDatasetId, plan, false);
}

@AfterEach
Expand All @@ -114,7 +115,7 @@ void afterEach() throws IOException {

@Test
void queryActivityInstances() {
final var instances = simulatedPlan.instances().collect();
final var instances = simResults.instances().collect();
assertEquals(1, instances.size());
final var instance = instances.get(0);
assertEquals("BiteBanana", instance.getType());
Expand All @@ -126,7 +127,7 @@ void queryActivityInstances() {

@Test
void queryActivityDirectives() {
final var directives = simulatedPlan.directives().collect();
final var directives = plan.directives().collect();
assertEquals(1, directives.size());
final var directive = directives.get(0);
assertEquals("BiteBanana", directive.getType());
Expand All @@ -136,7 +137,7 @@ void queryActivityDirectives() {

@Test
void queryResources() {
final var fruit = simulatedPlan.resource("/fruit", Real.deserializer()).collect();
final var fruit = simResults.resource("/fruit", Real.deserializer()).collect();
assertIterableEquals(
List.of(
Segment.of(Interval.betweenClosedOpen(Duration.ZERO, Duration.HOUR), new LinearEquation(4.0)),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package gov.nasa.ammos.aerie.procedural.constraints

import gov.nasa.ammos.aerie.procedural.timeline.CollectOptions
import gov.nasa.ammos.aerie.procedural.timeline.plan.SimulatedPlan
import gov.nasa.ammos.aerie.procedural.timeline.plan.Plan
import gov.nasa.ammos.aerie.procedural.timeline.plan.SimulationResults

/** The interface that all constraints must satisfy. */
interface Constraint {
Expand All @@ -12,7 +12,9 @@ interface Constraint {
* the constraint is run. The constraint does not need to use the options unless it collects a timeline prematurely.
*
* @param plan the plan to check the constraint on
* @param options the [CollectOptions] that the result will be collected with
* @param simResults the [SimulationResults] that the result will be collected with
*/
fun run(plan: SimulatedPlan, options: CollectOptions): Violations
fun run(plan: Plan, simResults: SimulationResults): Violations


}
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package gov.nasa.ammos.aerie.procedural.constraints

import gov.nasa.ammos.aerie.procedural.timeline.Interval
import gov.nasa.ammos.aerie.procedural.timeline.collections.Windows
import gov.nasa.ammos.aerie.procedural.timeline.collections.profiles.Real
import gov.nasa.ammos.aerie.procedural.timeline.ops.GeneralOps
import gov.nasa.ammos.aerie.procedural.timeline.ops.ParallelOps
import gov.nasa.ammos.aerie.procedural.timeline.ops.SerialConstantOps
import gov.nasa.ammos.aerie.procedural.timeline.payloads.IntervalLike
import gov.nasa.ammos.aerie.procedural.timeline.plan.Plan
import gov.nasa.ammos.aerie.procedural.timeline.plan.SimulationResults

/**
* A generator-style implementation of [Constraint].
*
* The subclass must implement [generate], and within it call [violate] to produce violations.
* Or if you are using Kotlin, you can use the timeline extension functions such as [windows.violateInside()][violateInside]
* to more easily submit violations.
*/
abstract class GeneratorConstraint: Constraint {
private var violations = mutableListOf<Violation>()

/** Finalizes one or more intervals as violations. */
@JvmOverloads protected fun violate(vararg i: Interval, message: String? = null) {
violate(i.map { Violation(it) }, message)
}

/** Finalizes one or more violations. */
@JvmOverloads protected fun violate(vararg v: Violation, message: String? = null) {
violate(v.toList(), message)
}

/** Finalizes a list of violations. */
@JvmOverloads protected fun violate(l: List<Violation>, message: String? = null) {
violations.addAll(l.map {
if (it.message == null) Violation(
it.interval,
message ?: defaultMessage(),
it.ids
) else it
})
}

/** Collects a [Violations] timeline and finalizes the result. */
@JvmOverloads protected fun violate(tl: Violations, message: String? = null) {
violate(tl.collect(), message)
}

/** Creates a [Violations] object that violates when this profile equals a given value. */
@JvmOverloads protected fun <V: Any> SerialConstantOps<V, *>.violateOn(v: V, message: String? = null) = violate(Violations.on(this, v), message)

/** Creates a [Violations] object that violates when this profile equals a given value. */
@JvmOverloads protected fun Real.violateOn(n: Number, message: String? = null) = violate(Violations.on(this, n), message)

/**
* Creates a [Violations] object that violates on every object in the timeline.
*
* If the object is an activity, it will record the directive or instance id.
*/
@JvmOverloads protected fun <I: IntervalLike<I>> ParallelOps<I, *>.violateOnAll(message: String? = null) {
violate(Violations.onAll(this), message)
}

/** Creates a [Violations] object that violates inside each interval. */
@JvmOverloads protected fun Windows.violateInside(message: String? = null) = violate(Violations.inside(this), message)
/** Creates a [Violations] object that violates outside each interval. */
@JvmOverloads protected fun Windows.violateOutside(message: String? = null) = violate(Violations.outside(this), message)

/**
* Creates a [Violations] object from two timelines, that violates whenever they have overlap.
*
* If either object is an activity, it will record the directive or instance id.
*/
@JvmOverloads protected fun <V: IntervalLike<V>, W: IntervalLike<W>> GeneralOps<V, *>.violateWhenSimultaneous(other: GeneralOps<W, *>, message: String? = null) {
violate(Violations.whenSimultaneous(this, other), message)
}

/**
* A generator function that calls [violate] to produce violations.
*/
abstract fun generate(plan: Plan, simResults: SimulationResults)

/**
* Default violation message to be displayed to user.
*
* Can be overridden on a violation-by-violation basis by manually specifying
* it in the [Violation] object.
*/
open fun defaultMessage(): String? = null

final override fun run(plan: Plan, simResults: SimulationResults): Violations {
violations = mutableListOf()
generate(plan, simResults)
val message = defaultMessage()
return if (message == null) Violations(violations)
else Violations(violations).withDefaultMessage(message)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,22 @@ import gov.nasa.ammos.aerie.procedural.timeline.payloads.IntervalLike
import gov.nasa.jpl.aerie.types.ActivityId

/** A single violation of a constraint. */
data class Violation(
data class Violation @JvmOverloads constructor(
/** Interval on which the violation occurs. */
override val interval: Interval,

/** Violation message to be displayed to user. */
val message: String? = null,

/** List of associated activities (directives or instances) that are related to the violation. */
val ids: List<ActivityId> = listOf()
) : IntervalLike<Violation> {

override fun withNewInterval(i: Interval) = Violation(i, ids)
override fun withNewInterval(i: Interval) = Violation(i, message, ids)

/** Constructs a violation on the same interval with a different list of ids. */
fun withNewIds(vararg id: ActivityId) = Violation(interval, id.asList())
fun withNewIds(vararg id: ActivityId) = Violation(interval, message, id.asList())

/** Constructs a violation on the same interval with a different list of ids. */
fun withNewIds(ids: List<ActivityId>) = Violation(interval, ids)
fun withNewIds(ids: List<ActivityId>) = Violation(interval, message, ids)
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,39 +31,51 @@ data class Violations(private val timeline: Timeline<Violation, Violations>):
*/
fun mapIds(f: (Violation) -> List<ActivityId>) = unsafeMap(BoundsTransformer.IDENTITY, false) { it.withNewIds(f(it)) }

/**
* Sets a default violation message for violations that don't already have one.
*
* @param message the default message to give to the user
*/
fun withDefaultMessage(message: String) = unsafeMap(BoundsTransformer.IDENTITY, false) {
if (it.message == null) Violation(it.interval, message, it.ids)
else it
}

/***/ companion object {
/** Creates a [Violations] object that violates when this profile equals a given value. */
@JvmStatic fun <V: Any> SerialConstantOps<V, *>.violateOn(v: V) = isolateEqualTo(v).violations()
/** Creates a [Violations] object that violates when the profile equals a given value. */
@JvmStatic fun <V: Any> on(tl: SerialConstantOps<V, *>, v: V) = onAll(tl.isolateEqualTo(v))

/** Creates a [Violations] object that violates when this profile equals a given value. */
@JvmStatic fun Real.violateOn(n: Number) = equalTo(n).violateOn(true)
@JvmStatic fun on(tl: Real, n: Number) = on(tl.equalTo(n), true)

/**
* Creates a [Violations] object that violates on every object in the timeline.
*
* If the object is an activity, it will record the directive or instance id.
*/
@JvmStatic fun <I: IntervalLike<I>> ParallelOps<I, *>.violations() =
unsafeMap(::Violations, BoundsTransformer.IDENTITY, false) {
@JvmStatic fun <I: IntervalLike<I>> onAll(tl: ParallelOps<I, *>) =
tl.unsafeMap(::Violations, BoundsTransformer.IDENTITY, false) {
Violation(
it.interval,
null,
listOfNotNull(it.getActivityId())
)
}

/** Creates a [Violations] object that violates inside each interval. */
@JvmStatic fun Windows.violateInside() = unsafeCast(::Universal).violations()
@JvmStatic fun inside(tl: Windows) = onAll(tl.unsafeCast(::Universal))
/** Creates a [Violations] object that violates outside each interval. */
@JvmStatic fun Windows.violateOutside() = complement().violateInside()
@JvmStatic fun outside(tl: Windows) = inside(tl.complement())

/**
* Creates a [Violations] object from two timelines, that violates whenever they have overlap.
*
* If either object is an activity, it will record the directive or instance id.
*/
@JvmStatic infix fun <V: IntervalLike<V>, W: IntervalLike<W>> GeneralOps<V, *>.mutex(other: GeneralOps<W, *>) =
unsafeMap2(::Violations, other) { l, r, i -> Violation(
@JvmStatic fun <V: IntervalLike<V>, W: IntervalLike<W>> whenSimultaneous(left: GeneralOps<V, *>, right: GeneralOps<W, *>) =
left.unsafeMap2(::Violations, right) { l, r, i -> Violation(
i,
null,
listOfNotNull(
l.getActivityId(),
r.getActivityId()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package gov.nasa.ammos.aerie.procedural.constraints

import gov.nasa.ammos.aerie.procedural.timeline.Interval
import gov.nasa.ammos.aerie.procedural.timeline.collections.profiles.Numbers
import gov.nasa.ammos.aerie.procedural.timeline.ops.coalesce.CoalesceSegmentsOp
import gov.nasa.ammos.aerie.procedural.timeline.payloads.Segment
import gov.nasa.ammos.aerie.procedural.timeline.plan.Plan
import gov.nasa.ammos.aerie.procedural.timeline.plan.SimulationResults
import gov.nasa.ammos.aerie.procedural.timeline.util.duration.rangeTo
import gov.nasa.jpl.aerie.merlin.protocol.types.Duration.seconds
import gov.nasa.jpl.aerie.merlin.protocol.types.SerializedValue
import org.junit.jupiter.api.Assertions.assertIterableEquals
import kotlin.test.Test

class GeneratorTest: GeneratorConstraint() {
override fun generate(plan: Plan, simResults: SimulationResults) {
violate(Interval.at(seconds(0)), message = "other message")
simResults.resource("/plant", Numbers.deserializer())
.greaterThan(0)
.violateOn(false)
}

override fun defaultMessage() = "Plant must be greater than 0"

@Test
fun testGenerator() {
val plan = NotImplementedPlan()
val simResults = object : NotImplementedSimulationResults() {
override fun <V : Any, TL : CoalesceSegmentsOp<V, TL>> resource(
name: String,
deserializer: (List<Segment<SerializedValue>>) -> TL
): TL {
if (name == "/plant") {
val list = listOf(
Segment(seconds(-4) .. seconds(-2), SerializedValue.of(-3)),
Segment(seconds(0) .. seconds(1), SerializedValue.of(3)),
Segment(seconds(1) .. seconds(2), SerializedValue.of(-1)),
)
return deserializer(list)
} else {
TODO("Not yet implemented")
}
}
}

val result = run(plan, simResults).collect()

val defaultMessage = "Plant must be greater than 0";
assertIterableEquals(
listOf(
Violation(seconds(-4) .. seconds(-2), defaultMessage),
Violation(Interval.at(seconds(0)), "other message"),
Violation(seconds(1) .. seconds(2), defaultMessage)
),
result
)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package gov.nasa.ammos.aerie.procedural.constraints

import gov.nasa.ammos.aerie.procedural.timeline.Interval
import gov.nasa.ammos.aerie.procedural.timeline.collections.Directives
import gov.nasa.ammos.aerie.procedural.timeline.plan.Plan
import gov.nasa.jpl.aerie.merlin.protocol.types.Duration
import gov.nasa.jpl.aerie.merlin.protocol.types.SerializedValue
import java.time.Instant

open class NotImplementedPlan: Plan {
override fun totalBounds(): Interval = TODO()
override fun toRelative(abs: Instant): Duration = TODO()
override fun toAbsolute(rel: Duration): Instant = TODO()
override fun <A : Any> directives(type: String?, deserializer: (SerializedValue) -> A): Directives<A> = TODO()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package gov.nasa.ammos.aerie.procedural.constraints

import gov.nasa.ammos.aerie.procedural.timeline.Interval
import gov.nasa.ammos.aerie.procedural.timeline.collections.Instances
import gov.nasa.ammos.aerie.procedural.timeline.ops.coalesce.CoalesceSegmentsOp
import gov.nasa.ammos.aerie.procedural.timeline.payloads.Segment
import gov.nasa.ammos.aerie.procedural.timeline.plan.SimulationResults
import gov.nasa.jpl.aerie.merlin.protocol.types.SerializedValue

open class NotImplementedSimulationResults: SimulationResults {
override fun isStale(): Boolean = TODO()
override fun simBounds(): Interval = TODO()
override fun <V : Any, TL : CoalesceSegmentsOp<V, TL>> resource(
name: String,
deserializer: (List<Segment<SerializedValue>>) -> TL
): TL = TODO()
override fun <A : Any> instances(type: String?, deserializer: (SerializedValue) -> A): Instances<A> = TODO()
}
Original file line number Diff line number Diff line change
@@ -1,22 +1,20 @@
package gov.nasa.ammos.aerie.procedural.examples.fooprocedures.constraints;

import gov.nasa.ammos.aerie.procedural.constraints.GeneratorConstraint;
import gov.nasa.ammos.aerie.procedural.constraints.Violations;
import gov.nasa.ammos.aerie.procedural.constraints.Constraint;
import gov.nasa.ammos.aerie.procedural.timeline.CollectOptions;
import gov.nasa.ammos.aerie.procedural.timeline.collections.profiles.Real;
import gov.nasa.ammos.aerie.procedural.timeline.plan.SimulatedPlan;
import gov.nasa.ammos.aerie.procedural.timeline.plan.Plan;
import gov.nasa.ammos.aerie.procedural.timeline.plan.SimulationResults;
import org.jetbrains.annotations.NotNull;

public class ConstFruit implements Constraint {
@NotNull
public class ConstFruit extends GeneratorConstraint {
@Override
public Violations run(SimulatedPlan plan, @NotNull CollectOptions options) {
final var fruit = plan.resource("/fruit", Real.deserializer());
public void generate(@NotNull Plan plan, @NotNull SimulationResults simResults) {
final var fruit = simResults.resource("/fruit", Real.deserializer());


return Violations.violateOn(
violate(Violations.on(
fruit.equalTo(4),
false
);
));
}
}

0 comments on commit 9af4b3c

Please sign in to comment.