Skip to content

Commit

Permalink
Fix violation message passing
Browse files Browse the repository at this point in the history
  • Loading branch information
JoelCourtney committed Sep 13, 2024
1 parent bfdb0f5 commit f02a8dd
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,5 @@ interface Constraint {
*/
fun run(plan: Plan, simResults: SimulationResults): Violations

/**
* Default violation message to be displayed to user.
*
* Can be overridden on a violation-by-violation basis by manually specifying
* it in the [Violation] object.
*/
fun message(): String? = null

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package gov.nasa.ammos.aerie.procedural.constraints

import gov.nasa.ammos.aerie.procedural.timeline.Interval
import gov.nasa.ammos.aerie.procedural.timeline.collections.Windows
import gov.nasa.ammos.aerie.procedural.timeline.collections.profiles.Real
import gov.nasa.ammos.aerie.procedural.timeline.ops.GeneralOps
Expand All @@ -19,25 +20,30 @@ import gov.nasa.ammos.aerie.procedural.timeline.plan.SimulationResults
abstract class GeneratorConstraint: Constraint {
private var violations = mutableListOf<Violation>()

/** Finalizes one or more intervals as violations. */
@JvmOverloads protected fun violate(vararg i: Interval, message: String? = null) {
violate(i.map { Violation(it) }, message)
}

/** Finalizes one or more violations. */
@JvmOverloads protected fun violate(vararg v: Violation, message: String? = null) {
violate(v.toList())
violate(v.toList(), message)
}

/** Finalizes a list of violations. */
@JvmOverloads protected fun violate(l: List<Violation>, message: String? = null) {
violations.addAll(l.map {
if (it.message == null) Violation(
it.interval,
message,
message ?: defaultMessage(),
it.ids
) else it
})
}

/** Collects a [Violations] timeline and finalizes the result. */
@JvmOverloads protected fun violate(tl: Violations, message: String? = null) {
violate(tl.collect())
violate(tl.collect(), message)
}

/** Creates a [Violations] object that violates when this profile equals a given value. */
Expand Down Expand Up @@ -74,9 +80,19 @@ abstract class GeneratorConstraint: Constraint {
*/
abstract fun generate(plan: Plan, simResults: SimulationResults)

/**
* Default violation message to be displayed to user.
*
* Can be overridden on a violation-by-violation basis by manually specifying
* it in the [Violation] object.
*/
open fun defaultMessage(): String? = null

final override fun run(plan: Plan, simResults: SimulationResults): Violations {
violations = mutableListOf()
generate(plan, simResults)
return Violations(violations)
val message = defaultMessage()
return if (message == null) Violations(violations)
else Violations(violations).withDefaultMessage(message)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,16 @@ data class Violations(private val timeline: Timeline<Violation, Violations>):
*/
fun mapIds(f: (Violation) -> List<ActivityId>) = unsafeMap(BoundsTransformer.IDENTITY, false) { it.withNewIds(f(it)) }

/**
* Sets a default violation message for violations that don't already have one.
*
* @param message the default message to give to the user
*/
fun withDefaultMessage(message: String) = unsafeMap(BoundsTransformer.IDENTITY, false) {
if (it.message == null) Violation(it.interval, message, it.ids)
else it
}

/***/ companion object {
/** Creates a [Violations] object that violates when the profile equals a given value. */
@JvmStatic fun <V: Any> on(tl: SerialConstantOps<V, *>, v: V) = onAll(tl.isolateEqualTo(v))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,18 @@ import kotlin.test.Test

class GeneratorTest: GeneratorConstraint() {
override fun generate(plan: Plan, simResults: SimulationResults) {
violate(Violation(Interval.at(seconds(0))))
violate(Interval.at(seconds(0)), message = "other message")
simResults.resource("/plant", Numbers.deserializer())
.greaterThan(0)
.violateOn(false)
}

override fun message() = "Plant must be greater than 0"
override fun defaultMessage() = "Plant must be greater than 0"

@Test
fun testGenerator() {
val plan = NotImplementedPlan()
val simResults = object : SimulationResults {
override fun isStale() = TODO()

override fun simBounds() = TODO()

val simResults = object : NotImplementedSimulationResults() {
override fun <V : Any, TL : CoalesceSegmentsOp<V, TL>> resource(
name: String,
deserializer: (List<Segment<SerializedValue>>) -> TL
Expand All @@ -45,17 +41,16 @@ class GeneratorTest: GeneratorConstraint() {
TODO("Not yet implemented")
}
}

override fun <A : Any> instances(type: String?, deserializer: (SerializedValue) -> A) = TODO()
}

val result = run(plan, simResults).collect()

val defaultMessage = "Plant must be greater than 0";
assertIterableEquals(
listOf(
Violation(seconds(-4) .. seconds(-2)),
Violation(Interval.at(seconds(0))),
Violation(seconds(1) .. seconds(2))
Violation(seconds(-4) .. seconds(-2), defaultMessage),
Violation(Interval.at(seconds(0)), "other message"),
Violation(seconds(1) .. seconds(2), defaultMessage)
),
result
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,9 @@ import gov.nasa.jpl.aerie.merlin.protocol.types.Duration
import gov.nasa.jpl.aerie.merlin.protocol.types.SerializedValue
import java.time.Instant

class NotImplementedPlan: Plan {
override fun totalBounds(): Interval {
TODO("Not yet implemented")
}

override fun toRelative(abs: Instant): Duration {
TODO("Not yet implemented")
}

override fun toAbsolute(rel: Duration): Instant {
TODO("Not yet implemented")
}

override fun <A : Any> directives(type: String?, deserializer: (SerializedValue) -> A): Directives<A> {
TODO("Not yet implemented")
}
open class NotImplementedPlan: Plan {
override fun totalBounds(): Interval = TODO()
override fun toRelative(abs: Instant): Duration = TODO()
override fun toAbsolute(rel: Duration): Instant = TODO()
override fun <A : Any> directives(type: String?, deserializer: (SerializedValue) -> A): Directives<A> = TODO()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package gov.nasa.ammos.aerie.procedural.constraints

import gov.nasa.ammos.aerie.procedural.timeline.Interval
import gov.nasa.ammos.aerie.procedural.timeline.collections.Instances
import gov.nasa.ammos.aerie.procedural.timeline.ops.coalesce.CoalesceSegmentsOp
import gov.nasa.ammos.aerie.procedural.timeline.payloads.Segment
import gov.nasa.ammos.aerie.procedural.timeline.plan.SimulationResults
import gov.nasa.jpl.aerie.merlin.protocol.types.SerializedValue

open class NotImplementedSimulationResults: SimulationResults {
override fun isStale(): Boolean = TODO()
override fun simBounds(): Interval = TODO()
override fun <V : Any, TL : CoalesceSegmentsOp<V, TL>> resource(
name: String,
deserializer: (List<Segment<SerializedValue>>) -> TL
): TL = TODO()
override fun <A : Any> instances(type: String?, deserializer: (SerializedValue) -> A): Instances<A> = TODO()
}

0 comments on commit f02a8dd

Please sign in to comment.