Skip to content

Commit

Permalink
Add several tests
Browse files Browse the repository at this point in the history
  • Loading branch information
JoelCourtney committed Feb 16, 2024
1 parent de759e1 commit de7e9fe
Show file tree
Hide file tree
Showing 15 changed files with 460 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ data class Interval(
fun compareEnds(other: Interval): Int {
val timeComparison: Int = end.compareTo(other.end)
return if (timeComparison != 0) timeComparison
else if (startInclusivity == other.startInclusivity) 0
else if (startInclusivity == Inclusivity.Inclusive) 1
else if (endInclusivity == other.endInclusivity) 0
else if (endInclusivity == Inclusivity.Inclusive) 1
else -1
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ interface GeneralOps<V: IntervalLike<V>, THIS: GeneralOps<V, THIS>>: Timeline<V,
unsafeOperate(ctor) { opts ->
val mapped = collect(opts.transformBounds(boundsTransformer)).flatMap {
val nested = f(it)
nested.value.collect(CollectOptions(nested.interval))
nested.value.collect(nested.interval)
}
if (truncate) truncateList(mapped, opts)
else mapped
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ interface SerialConstantOps<V: Any, THIS: SerialConstantOps<V, THIS>>: SerialOps

/** [(DOC)][notEqualTo] Returns a [Windows] that is `true` when this and another profile are not equal. */
fun <OTHER: SerialConstantOps<V, OTHER>> notEqualTo(other: OTHER) =
map2Values(::Windows, other, BinaryOperation.combineOrNull { l, r, _ -> l == r })
map2Values(::Windows, other, BinaryOperation.combineOrNull { l, r, _ -> l != r })

override fun changes() = detectEdges(BinaryOperation.combineOrNull { l, r, _-> l != r })

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package gov.nasa.jpl.aerie.timeline.ops

import gov.nasa.jpl.aerie.timeline.*
import gov.nasa.jpl.aerie.timeline.Interval.Companion.at
import gov.nasa.jpl.aerie.timeline.collections.profiles.Windows
import gov.nasa.jpl.aerie.timeline.util.coalesceList
import gov.nasa.jpl.aerie.timeline.util.map2Serial
import gov.nasa.jpl.aerie.timeline.util.truncateList

Expand Down Expand Up @@ -83,7 +85,10 @@ interface SerialOps<V : Any, THIS: SerialOps<V, THIS>>: SegmentOps<V, THIS> {
* @return a coalesced flattened profile; an instance of the return type of [ctor]
*/
fun <W: Any, OTHER: SerialOps<W, OTHER>, R: Any, NESTED: SerialOps<R, NESTED>, RESULT: GeneralOps<Segment<R>, RESULT>> flatMap2Values(ctor: (Timeline<Segment<R>, RESULT>) -> RESULT, other: SerialOps<W, OTHER>, op: BinaryOperation<V, W, NESTED?>) =
unsafeOperate(ctor) { opts -> map2Serial(collect(opts), other.collect(opts), op).flatMap { it.value.collect(CollectOptions(it.interval, true)) } }
unsafeOperate(ctor) { opts ->
map2Serial(collect(opts), other.collect(opts), op)
.flatMap { it.value.collect(CollectOptions(it.interval, true)) }
}

/**
* [(DOC)][detectEdges] Uses a [BinaryOperation] as a predicate to highlight edges between segments.
Expand All @@ -105,34 +110,19 @@ interface SerialOps<V : Any, THIS: SerialOps<V, THIS>>: SegmentOps<V, THIS> {
var buffer: Segment<V>? = null
val result = collect(CollectOptions(bounds, false))
.flatMap { currentSegment ->
val leftEdge: Boolean?
val rightEdge: Boolean?

val previous = buffer
buffer = currentSegment
val currentInterval = currentSegment.interval

val leftEdgeInterval = Interval.at(currentInterval.start)
val rightEdgeInterval = Interval.at(currentInterval.end)
val leftEdgeInterval = at(currentInterval.start)
val rightEdgeInterval = at(currentInterval.end)

rightEdge = if (currentInterval.end.isEqualTo(bounds.end) && currentInterval.endInclusivity == bounds.endInclusivity) {
if (bounds.includesEnd()) false else null
} else {
edgePredicate.invoke(currentSegment.value, null, rightEdgeInterval)
}
val rightEdge = edgePredicate(currentSegment.value, null, rightEdgeInterval)

leftEdge = if (previous != null) {
if (previous.interval.compareEndToStart(currentInterval) == 0) {
edgePredicate.invoke(previous.value, currentSegment.value, leftEdgeInterval)
} else {
edgePredicate.invoke(null, currentSegment.value, leftEdgeInterval)
}
val leftEdge = if (previous == null || previous.interval.compareEndToStart(currentInterval) == -1) {
edgePredicate(null, currentSegment.value, leftEdgeInterval)
} else {
if (currentInterval.start.isEqualTo(bounds.start) && currentInterval.startInclusivity == bounds.startInclusivity) {
if (bounds.includesStart()) false else null
} else {
edgePredicate.invoke(null, currentSegment.value, leftEdgeInterval)
}
edgePredicate(previous.value, currentSegment.value, leftEdgeInterval)
}

listOfNotNull(
Expand All @@ -144,7 +134,7 @@ interface SerialOps<V : Any, THIS: SerialOps<V, THIS>>: SegmentOps<V, THIS> {
Segment(rightEdgeInterval, rightEdge).transpose()
)
}
truncateList(result, opts)
truncateList(coalesceList(result, Segment<Boolean>::valueEquals), opts)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,9 @@ data class LinearEquation(
}

/***/
override fun equals(other: Any?): Boolean {
return if (other !is LinearEquation) false
else initialValue == other.valueAt(initialTime) && rate == other.rate
}
override fun equals(other: Any?) =
if (other !is LinearEquation) false
else initialValue == other.valueAt(initialTime) && rate == other.rate

/***/
override fun hashCode(): Int {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package gov.nasa.jpl.aerie.timeline

import gov.nasa.jpl.aerie.merlin.protocol.types.Duration.seconds
import gov.nasa.jpl.aerie.timeline.Interval.Companion.between
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Test

class BoundsTransformerTest {
@Test
fun shift() {
val transformer = BoundsTransformer.shift(seconds(1))

assertEquals(
between(seconds(-1), seconds(1)),
transformer(between(seconds(0), seconds(2)))
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ class IntervalTest {
assertEquals(1, between(seconds(0), seconds(1)).compareEnds(at(seconds(0))))
assertEquals(-1, between(seconds(0), seconds(1), Exclusive).compareEnds(at(seconds(1))))
assertEquals(0, between(seconds(0), seconds(1)).compareEnds(at(seconds(1))))
assertEquals(-1, betweenClosedOpen(seconds(0), seconds(1)).compareEnds(at(seconds(1))))
}

@Test
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package gov.nasa.jpl.aerie.timeline.collections.profiles

import gov.nasa.jpl.aerie.merlin.protocol.types.Duration
import gov.nasa.jpl.aerie.merlin.protocol.types.Duration.seconds
import gov.nasa.jpl.aerie.timeline.Interval
import gov.nasa.jpl.aerie.timeline.Interval.Companion.between
import gov.nasa.jpl.aerie.timeline.Segment
import org.junit.jupiter.api.Test

import org.junit.jupiter.api.Assertions.*

class NumbersTest {

@Test
fun plus() {
val four = Numbers(4)
val five = Numbers(Segment(between(Duration.ZERO, seconds(1)), 5)).assignGaps(Numbers(0))

assertIterableEquals(
listOf(
Segment(between(Duration.ZERO, seconds(1)), 9),
Segment(between(seconds(1), seconds(2), Interval.Inclusivity.Exclusive, Interval.Inclusivity.Inclusive), 4)
),
four.plus(five).collect(between(Duration.ZERO, seconds(2)))
)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
package gov.nasa.jpl.aerie.timeline.ops

import gov.nasa.jpl.aerie.merlin.protocol.types.Duration
import gov.nasa.jpl.aerie.merlin.protocol.types.Duration.milliseconds
import gov.nasa.jpl.aerie.merlin.protocol.types.Duration.seconds
import gov.nasa.jpl.aerie.timeline.*
import gov.nasa.jpl.aerie.timeline.Interval.Companion.at
import gov.nasa.jpl.aerie.timeline.Interval.Companion.between
import gov.nasa.jpl.aerie.timeline.Interval.Companion.betweenClosedOpen
import gov.nasa.jpl.aerie.timeline.collections.Intervals
import gov.nasa.jpl.aerie.timeline.collections.profiles.Discrete
import gov.nasa.jpl.aerie.timeline.collections.profiles.Numbers
import gov.nasa.jpl.aerie.timeline.collections.profiles.Windows
import org.junit.jupiter.api.Assertions.*
import org.junit.jupiter.api.Test

class GeneralOpsTest {

@Test
fun operate() {
val result = Discrete(Segment(between(seconds(0), seconds(1)), "hello")).unsafeOperate {
collect(it).map { s -> Segment(s.interval, s.value + " world")}
}.collect()

val expected = listOf(Segment(between(seconds(0), seconds(1)), "hello world"))

assertIterableEquals(expected, result)
}

@Test
fun operateAutoCoalesce() {
val result = Discrete(
Segment(between(seconds(0), seconds(1)), "hello world"),
Segment(between(seconds(1), seconds(2)), "hello there")
).unsafeOperate {
collect(it).map { s -> Segment(s.interval, s.value.substring(0..4))}
}.collect()

val expected = listOf(Segment(between(seconds(0), seconds(2)), "hello"))

assertIterableEquals(expected, result)
}

@Test
fun operateType() {
// this is just a test to make sure the return type of unsafeOperate is correct.
// we would get a compile error if it failed.
@Suppress("UNUSED_VARIABLE")
val result: Windows = Windows(true).unsafeOperate { collect(it) }
}

@Test
fun inspect() {
var count: Int? = null
val tl = Intervals(
at(seconds(1)),
at(seconds(2))
).inspect {
count = it.size
}

assertNull(count)

tl.collect()

assertEquals(2, count)
}

@Test
fun unset() {
val result = Intervals(
between(seconds(0), seconds(2)),
between(seconds(2), seconds(3)),
between(seconds(3), seconds(5)),
between(seconds(10), seconds(11))
).unset(between(seconds(1), seconds(4))).collect()

assertIterableEquals(
listOf(
betweenClosedOpen(seconds(0), seconds(1)),
between(seconds(4), seconds(5), Interval.Inclusivity.Exclusive, Interval.Inclusivity.Inclusive),
between(seconds(10), seconds(11))
),
result
)
}

@Test
fun filter() {
val result = Numbers(
Segment(at(seconds(1)), 4),
Segment(at(seconds(2)), 5)
).filter { it.value.toInt() % 2 == 0 }.collect()

assertIterableEquals(
listOf(Segment(at(seconds(1)), 4)),
result
)
}

@Test
fun filterPreserveMargin() {
val intervals = Intervals(
between(seconds(-1), seconds(1)),
between(seconds(1), seconds(4)),
between(seconds(4), seconds(8)),
)

// without preserve margin
assertIterableEquals(
listOf(between(seconds(1), seconds(4))),
intervals.filter(false) { it.duration().noShorterThan(Duration.of(2, Duration.SECOND)) }
.collect(between(seconds(0), seconds(5)))
)

// with preserve margin and truncate margin
// notice that the marginal intervals are retained and then later truncated to within the bounds
assertIterableEquals(
listOf(
between(seconds(0), seconds(1)),
between(seconds(1), seconds(4)),
between(seconds(4), seconds(5)),
),
intervals.filter(true) { it.duration().noShorterThan(Duration.of(2, Duration.SECOND)) }
.collect(between(seconds(0), seconds(5)))
)

// with preserve margin, without truncate margin
// notice that the marginal intervals are retained and NOT truncated later
assertIterableEquals(
intervals.collect(),
intervals.filter(true) { it.duration().noShorterThan(Duration.of(2, Duration.SECOND)) }
.collect(CollectOptions(between(seconds(0), seconds(5)), false))
)
}

@Test
fun map() {
val result = Intervals(
at(seconds(1)),
between(seconds(2), seconds(3))
).unsafeMap(::Windows, BoundsTransformer.IDENTITY, false) { Segment(it.interval, it.interval.isPoint()) }
.collect()

assertIterableEquals(
listOf(
Segment(at(seconds(1)), true),
Segment(between(seconds(2), seconds(3)), false),
),
result
)
}

@Test
fun shiftBoundsTransform() {
val intervals = Intervals(
between(seconds(-1), seconds(0)),
between(seconds(2), seconds(4)),
).shift(Duration.SECOND)

val expected = listOf(
between(seconds(0), seconds(1)),
between(seconds(3), seconds(5))
)

assertIterableEquals(
expected,
intervals.collect()
)

assertIterableEquals(
expected,
intervals.collect(between(seconds(0), seconds(5)))
)
}

@Test
fun shiftOutOfBounds() {
val intervals = Intervals(at(seconds(3))).shift(seconds(3))

assertIterableEquals(
listOf<Interval>(),
intervals.collect(between(seconds(0), seconds(5)))
)
}

@Test
fun flatMapTest() {
val result = Intervals(
between(seconds(2), seconds(8)),
between(seconds(0), seconds(3))
)
// converts each interval to a windows object.
// false for the first half of the interval, true for the second half
.unsafeFlatMap(::Windows, BoundsTransformer.IDENTITY, false) {
val midpoint = it.interval.start.plus(it.interval.end).dividedBy(2)
Segment(
it.interval.interval,
Windows(false).set(Windows(Segment(between(midpoint, Duration.MAX_VALUE), true)))
)
}
.collect()

val expected = listOf(
Segment(betweenClosedOpen(seconds(0), milliseconds(1500)), false),
Segment(betweenClosedOpen(milliseconds(1500), seconds(2)), true),
Segment(betweenClosedOpen(seconds(2), seconds(5)), false),
Segment(between(seconds(5), seconds(8)), true)
)

assertIterableEquals(expected, result)
}
}
Loading

0 comments on commit de7e9fe

Please sign in to comment.