Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generator Constraints and Violation Messages #1559

Merged
merged 4 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import com.microsoft.playwright.Playwright;
import com.zaxxer.hikari.HikariConfig;
import com.zaxxer.hikari.HikariDataSource;
import gov.nasa.ammos.aerie.procedural.timeline.plan.Plan;
import gov.nasa.ammos.aerie.procedural.timeline.plan.SimulationResults;
import gov.nasa.jpl.aerie.e2e.utils.GatewayRequests;
import gov.nasa.jpl.aerie.e2e.utils.HasuraRequests;
import gov.nasa.jpl.aerie.merlin.protocol.types.Duration;
Expand All @@ -12,7 +14,6 @@
import gov.nasa.ammos.aerie.procedural.timeline.payloads.Segment;
import gov.nasa.ammos.aerie.procedural.remote.AeriePostgresPlan;
import gov.nasa.ammos.aerie.procedural.remote.AeriePostgresSimulationResults;
import gov.nasa.ammos.aerie.procedural.timeline.plan.SimulatedPlan;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
Expand Down Expand Up @@ -40,7 +41,8 @@ public class TimelineRemoteTests {
private int planId;
private int activityId;

private SimulatedPlan simulatedPlan;
private Plan plan;
private SimulationResults simResults;
private Connection connection;
private HikariDataSource dataSource;
@BeforeAll
Expand Down Expand Up @@ -101,9 +103,8 @@ void beforeEach() throws IOException, InterruptedException {

// Connect to the database

final var plan = new AeriePostgresPlan(connection, planId);
final var simResults = new AeriePostgresSimulationResults(connection, simDatasetId, plan, false);
simulatedPlan = new SimulatedPlan(plan, simResults);
plan = new AeriePostgresPlan(connection, planId);
simResults = new AeriePostgresSimulationResults(connection, simDatasetId, plan, false);
}

@AfterEach
Expand All @@ -114,7 +115,7 @@ void afterEach() throws IOException {

@Test
void queryActivityInstances() {
final var instances = simulatedPlan.instances().collect();
final var instances = simResults.instances().collect();
assertEquals(1, instances.size());
final var instance = instances.get(0);
assertEquals("BiteBanana", instance.getType());
Expand All @@ -126,7 +127,7 @@ void queryActivityInstances() {

@Test
void queryActivityDirectives() {
final var directives = simulatedPlan.directives().collect();
final var directives = plan.directives().collect();
assertEquals(1, directives.size());
final var directive = directives.get(0);
assertEquals("BiteBanana", directive.getType());
Expand All @@ -136,7 +137,7 @@ void queryActivityDirectives() {

@Test
void queryResources() {
final var fruit = simulatedPlan.resource("/fruit", Real.deserializer()).collect();
final var fruit = simResults.resource("/fruit", Real.deserializer()).collect();
assertIterableEquals(
List.of(
Segment.of(Interval.betweenClosedOpen(Duration.ZERO, Duration.HOUR), new LinearEquation(4.0)),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package gov.nasa.ammos.aerie.procedural.constraints

import gov.nasa.ammos.aerie.procedural.timeline.CollectOptions
import gov.nasa.ammos.aerie.procedural.timeline.plan.SimulatedPlan
import gov.nasa.ammos.aerie.procedural.timeline.plan.Plan
import gov.nasa.ammos.aerie.procedural.timeline.plan.SimulationResults

/** The interface that all constraints must satisfy. */
interface Constraint {
Expand All @@ -12,7 +12,9 @@ interface Constraint {
* the constraint is run. The constraint does not need to use the options unless it collects a timeline prematurely.
*
* @param plan the plan to check the constraint on
* @param options the [CollectOptions] that the result will be collected with
* @param simResults the [SimulationResults] that the result will be collected with
*/
fun run(plan: SimulatedPlan, options: CollectOptions): Violations
fun run(plan: Plan, simResults: SimulationResults): Violations


}
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package gov.nasa.ammos.aerie.procedural.constraints

import gov.nasa.ammos.aerie.procedural.timeline.Interval
import gov.nasa.ammos.aerie.procedural.timeline.collections.Windows
import gov.nasa.ammos.aerie.procedural.timeline.collections.profiles.Real
import gov.nasa.ammos.aerie.procedural.timeline.ops.GeneralOps
import gov.nasa.ammos.aerie.procedural.timeline.ops.ParallelOps
import gov.nasa.ammos.aerie.procedural.timeline.ops.SerialConstantOps
import gov.nasa.ammos.aerie.procedural.timeline.payloads.IntervalLike
import gov.nasa.ammos.aerie.procedural.timeline.plan.Plan
import gov.nasa.ammos.aerie.procedural.timeline.plan.SimulationResults

/**
* A generator-style implementation of [Constraint].
*
* The subclass must implement [generate], and within it call [violate] to produce violations.
* Or if you are using Kotlin, you can use the timeline extension functions such as [windows.violateInside()][violateInside]
* to more easily submit violations.
*/
abstract class GeneratorConstraint: Constraint {
private var violations = mutableListOf<Violation>()

/** Finalizes one or more intervals as violations. */
@JvmOverloads protected fun violate(vararg i: Interval, message: String? = null) {
violate(i.map { Violation(it) }, message)
}

/** Finalizes one or more violations. */
@JvmOverloads protected fun violate(vararg v: Violation, message: String? = null) {
violate(v.toList(), message)
}

/** Finalizes a list of violations. */
@JvmOverloads protected fun violate(l: List<Violation>, message: String? = null) {
violations.addAll(l.map {
if (it.message == null) Violation(
it.interval,
message ?: defaultMessage(),
it.ids
) else it
})
}

/** Collects a [Violations] timeline and finalizes the result. */
@JvmOverloads protected fun violate(tl: Violations, message: String? = null) {
violate(tl.collect(), message)
}

/** Creates a [Violations] object that violates when this profile equals a given value. */
@JvmOverloads protected fun <V: Any> SerialConstantOps<V, *>.violateOn(v: V, message: String? = null) = violate(Violations.on(this, v), message)

/** Creates a [Violations] object that violates when this profile equals a given value. */
@JvmOverloads protected fun Real.violateOn(n: Number, message: String? = null) = violate(Violations.on(this, n), message)

/**
* Creates a [Violations] object that violates on every object in the timeline.
*
* If the object is an activity, it will record the directive or instance id.
*/
@JvmOverloads protected fun <I: IntervalLike<I>> ParallelOps<I, *>.violateOnAll(message: String? = null) {
violate(Violations.onAll(this), message)
}

/** Creates a [Violations] object that violates inside each interval. */
@JvmOverloads protected fun Windows.violateInside(message: String? = null) = violate(Violations.inside(this), message)
/** Creates a [Violations] object that violates outside each interval. */
@JvmOverloads protected fun Windows.violateOutside(message: String? = null) = violate(Violations.outside(this), message)

/**
* Creates a [Violations] object from two timelines, that violates whenever they have overlap.
*
* If either object is an activity, it will record the directive or instance id.
*/
@JvmOverloads protected fun <V: IntervalLike<V>, W: IntervalLike<W>> GeneralOps<V, *>.violateWhenSimultaneous(other: GeneralOps<W, *>, message: String? = null) {
violate(Violations.whenSimultaneous(this, other), message)
}

/**
* A generator function that calls [violate] to produce violations.
*/
abstract fun generate(plan: Plan, simResults: SimulationResults)

/**
* Default violation message to be displayed to user.
*
* Can be overridden on a violation-by-violation basis by manually specifying
* it in the [Violation] object.
*/
open fun defaultMessage(): String? = null

final override fun run(plan: Plan, simResults: SimulationResults): Violations {
violations = mutableListOf()
generate(plan, simResults)
val message = defaultMessage()
return if (message == null) Violations(violations)
else Violations(violations).withDefaultMessage(message)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,22 @@ import gov.nasa.ammos.aerie.procedural.timeline.payloads.IntervalLike
import gov.nasa.jpl.aerie.types.ActivityId

/** A single violation of a constraint. */
data class Violation(
data class Violation @JvmOverloads constructor(
/** Interval on which the violation occurs. */
override val interval: Interval,

/** Violation message to be displayed to user. */
val message: String? = null,

/** List of associated activities (directives or instances) that are related to the violation. */
val ids: List<ActivityId> = listOf()
) : IntervalLike<Violation> {

override fun withNewInterval(i: Interval) = Violation(i, ids)
override fun withNewInterval(i: Interval) = Violation(i, message, ids)

/** Constructs a violation on the same interval with a different list of ids. */
fun withNewIds(vararg id: ActivityId) = Violation(interval, id.asList())
fun withNewIds(vararg id: ActivityId) = Violation(interval, message, id.asList())

/** Constructs a violation on the same interval with a different list of ids. */
fun withNewIds(ids: List<ActivityId>) = Violation(interval, ids)
fun withNewIds(ids: List<ActivityId>) = Violation(interval, message, ids)
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,39 +31,51 @@ data class Violations(private val timeline: Timeline<Violation, Violations>):
*/
fun mapIds(f: (Violation) -> List<ActivityId>) = unsafeMap(BoundsTransformer.IDENTITY, false) { it.withNewIds(f(it)) }

/**
* Sets a default violation message for violations that don't already have one.
*
* @param message the default message to give to the user
*/
fun withDefaultMessage(message: String) = unsafeMap(BoundsTransformer.IDENTITY, false) {
if (it.message == null) Violation(it.interval, message, it.ids)
else it
}

/***/ companion object {
/** Creates a [Violations] object that violates when this profile equals a given value. */
@JvmStatic fun <V: Any> SerialConstantOps<V, *>.violateOn(v: V) = isolateEqualTo(v).violations()
/** Creates a [Violations] object that violates when the profile equals a given value. */
@JvmStatic fun <V: Any> on(tl: SerialConstantOps<V, *>, v: V) = onAll(tl.isolateEqualTo(v))

/** Creates a [Violations] object that violates when this profile equals a given value. */
@JvmStatic fun Real.violateOn(n: Number) = equalTo(n).violateOn(true)
@JvmStatic fun on(tl: Real, n: Number) = on(tl.equalTo(n), true)

/**
* Creates a [Violations] object that violates on every object in the timeline.
*
* If the object is an activity, it will record the directive or instance id.
*/
@JvmStatic fun <I: IntervalLike<I>> ParallelOps<I, *>.violations() =
unsafeMap(::Violations, BoundsTransformer.IDENTITY, false) {
@JvmStatic fun <I: IntervalLike<I>> onAll(tl: ParallelOps<I, *>) =
tl.unsafeMap(::Violations, BoundsTransformer.IDENTITY, false) {
Violation(
it.interval,
null,
listOfNotNull(it.getActivityId())
)
}

/** Creates a [Violations] object that violates inside each interval. */
@JvmStatic fun Windows.violateInside() = unsafeCast(::Universal).violations()
@JvmStatic fun inside(tl: Windows) = onAll(tl.unsafeCast(::Universal))
/** Creates a [Violations] object that violates outside each interval. */
@JvmStatic fun Windows.violateOutside() = complement().violateInside()
@JvmStatic fun outside(tl: Windows) = inside(tl.complement())

/**
* Creates a [Violations] object from two timelines, that violates whenever they have overlap.
*
* If either object is an activity, it will record the directive or instance id.
*/
@JvmStatic infix fun <V: IntervalLike<V>, W: IntervalLike<W>> GeneralOps<V, *>.mutex(other: GeneralOps<W, *>) =
unsafeMap2(::Violations, other) { l, r, i -> Violation(
@JvmStatic fun <V: IntervalLike<V>, W: IntervalLike<W>> whenSimultaneous(left: GeneralOps<V, *>, right: GeneralOps<W, *>) =
left.unsafeMap2(::Violations, right) { l, r, i -> Violation(
i,
null,
listOfNotNull(
l.getActivityId(),
r.getActivityId()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package gov.nasa.ammos.aerie.procedural.constraints

import gov.nasa.ammos.aerie.procedural.timeline.Interval
import gov.nasa.ammos.aerie.procedural.timeline.collections.profiles.Numbers
import gov.nasa.ammos.aerie.procedural.timeline.ops.coalesce.CoalesceSegmentsOp
import gov.nasa.ammos.aerie.procedural.timeline.payloads.Segment
import gov.nasa.ammos.aerie.procedural.timeline.plan.Plan
import gov.nasa.ammos.aerie.procedural.timeline.plan.SimulationResults
import gov.nasa.ammos.aerie.procedural.timeline.util.duration.rangeTo
import gov.nasa.jpl.aerie.merlin.protocol.types.Duration.seconds
import gov.nasa.jpl.aerie.merlin.protocol.types.SerializedValue
import org.junit.jupiter.api.Assertions.assertIterableEquals
import kotlin.test.Test

class GeneratorTest: GeneratorConstraint() {
override fun generate(plan: Plan, simResults: SimulationResults) {
violate(Interval.at(seconds(0)), message = "other message")
simResults.resource("/plant", Numbers.deserializer())
.greaterThan(0)
.violateOn(false)
}

override fun defaultMessage() = "Plant must be greater than 0"

@Test
fun testGenerator() {
val plan = NotImplementedPlan()
val simResults = object : NotImplementedSimulationResults() {
override fun <V : Any, TL : CoalesceSegmentsOp<V, TL>> resource(
name: String,
deserializer: (List<Segment<SerializedValue>>) -> TL
): TL {
if (name == "/plant") {
val list = listOf(
Segment(seconds(-4) .. seconds(-2), SerializedValue.of(-3)),
Segment(seconds(0) .. seconds(1), SerializedValue.of(3)),
Segment(seconds(1) .. seconds(2), SerializedValue.of(-1)),
)
return deserializer(list)
} else {
TODO("Not yet implemented")
}
}
}

val result = run(plan, simResults).collect()

val defaultMessage = "Plant must be greater than 0";
assertIterableEquals(
listOf(
Violation(seconds(-4) .. seconds(-2), defaultMessage),
Violation(Interval.at(seconds(0)), "other message"),
Violation(seconds(1) .. seconds(2), defaultMessage)
),
result
)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package gov.nasa.ammos.aerie.procedural.constraints

import gov.nasa.ammos.aerie.procedural.timeline.Interval
import gov.nasa.ammos.aerie.procedural.timeline.collections.Directives
import gov.nasa.ammos.aerie.procedural.timeline.plan.Plan
import gov.nasa.jpl.aerie.merlin.protocol.types.Duration
import gov.nasa.jpl.aerie.merlin.protocol.types.SerializedValue
import java.time.Instant

open class NotImplementedPlan: Plan {
override fun totalBounds(): Interval = TODO()
override fun toRelative(abs: Instant): Duration = TODO()
override fun toAbsolute(rel: Duration): Instant = TODO()
override fun <A : Any> directives(type: String?, deserializer: (SerializedValue) -> A): Directives<A> = TODO()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package gov.nasa.ammos.aerie.procedural.constraints

import gov.nasa.ammos.aerie.procedural.timeline.Interval
import gov.nasa.ammos.aerie.procedural.timeline.collections.Instances
import gov.nasa.ammos.aerie.procedural.timeline.ops.coalesce.CoalesceSegmentsOp
import gov.nasa.ammos.aerie.procedural.timeline.payloads.Segment
import gov.nasa.ammos.aerie.procedural.timeline.plan.SimulationResults
import gov.nasa.jpl.aerie.merlin.protocol.types.SerializedValue

open class NotImplementedSimulationResults: SimulationResults {
override fun isStale(): Boolean = TODO()
override fun simBounds(): Interval = TODO()
override fun <V : Any, TL : CoalesceSegmentsOp<V, TL>> resource(
name: String,
deserializer: (List<Segment<SerializedValue>>) -> TL
): TL = TODO()
override fun <A : Any> instances(type: String?, deserializer: (SerializedValue) -> A): Instances<A> = TODO()
}
Original file line number Diff line number Diff line change
@@ -1,22 +1,20 @@
package gov.nasa.ammos.aerie.procedural.examples.fooprocedures.constraints;

import gov.nasa.ammos.aerie.procedural.constraints.GeneratorConstraint;
import gov.nasa.ammos.aerie.procedural.constraints.Violations;
import gov.nasa.ammos.aerie.procedural.constraints.Constraint;
import gov.nasa.ammos.aerie.procedural.timeline.CollectOptions;
import gov.nasa.ammos.aerie.procedural.timeline.collections.profiles.Real;
import gov.nasa.ammos.aerie.procedural.timeline.plan.SimulatedPlan;
import gov.nasa.ammos.aerie.procedural.timeline.plan.Plan;
import gov.nasa.ammos.aerie.procedural.timeline.plan.SimulationResults;
import org.jetbrains.annotations.NotNull;

public class ConstFruit implements Constraint {
@NotNull
public class ConstFruit extends GeneratorConstraint {
@Override
public Violations run(SimulatedPlan plan, @NotNull CollectOptions options) {
final var fruit = plan.resource("/fruit", Real.deserializer());
public void generate(@NotNull Plan plan, @NotNull SimulationResults simResults) {
final var fruit = simResults.resource("/fruit", Real.deserializer());


return Violations.violateOn(
violate(Violations.on(
fruit.equalTo(4),
false
);
));
}
}
Loading