-
Notifications
You must be signed in to change notification settings - Fork 0
/
Train.test.scala
79 lines (62 loc) · 2.05 KB
/
Train.test.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
package genovese
import munit.FunSuite
/** Exercises `Train` end to end on a small nested configuration and checks that the best
  * specimen it produces reaches the highest score the fitness function below can award.
  */
class TrainTest extends FunSuite:
  test("Training"):
    case class Breaks(one: Boolean, two: Boolean, width: Float)

    // `width` is given an explicit float range of 50..150.
    given Featureful[Breaks] =
      Featureful.derive(
        FieldConfig(Map("width" -> Feature.FloatRange(50, 150)))
      )

    case class FormattingConfig(
        format: Boolean,
        scala3: Boolean,
        lineLength: Float,
        breaks: Breaks
    )

    // `lineLength` is given an explicit float range of 50..140.
    given Featureful[FormattingConfig] =
      Featureful.derive(
        FieldConfig(Map("lineLength" -> Feature.FloatRange(50, 140)))
      )

    given RuntimeChecks = RuntimeChecks.Full

    // The random source is seeded with a constant so the run is reproducible.
    val trainingConfig = TrainingConfig(
      populationSize = 100,
      mutationRate = NormalisedFloat(0.5f),
      steps = 100,
      random = scala.util.Random(80085L),
      selection = Selection.Top(0.8)
    )

    // Every satisfied criterion contributes one `step`. There are six criteria, so the
    // accumulated score tops out at six steps; the handler and the final assertion both
    // compare against 0.6f.
    val fitness = Fitness[FormattingConfig]: c =>
      val step = 0.1f
      val criteria = List(
        c.format,
        c.scala3,
        c.lineLength > 110 && c.lineLength < 112,
        c.breaks.width > 81 && c.breaks.width < 83f,
        c.breaks.one,
        !c.breaks.two
      )
      // Steps are added one at a time, in criterion order, starting from 0.0f.
      val score =
        criteria.foldLeft(0.0f)((acc, met) => if met then acc + step else acc)
      NormalisedFloat(score)

    /** Halts training once a `ReportFitness` event carries a maximum of at least 0.6f;
      * every other event lets training continue.
      */
    object Handler extends EventHandler:
      import TrainingEvent.*, TrainingInstruction.*

      def handle[T](t: TrainingEvent[T], data: T | Null): TrainingInstruction =
        t match
          case ReportFitness =>
            if ReportFitness.cast(data).max >= 0.6f then Halt
            else Continue
          case _ => Continue
      end handle

      // Everything in `all` except the three events listed here.
      // NOTE(review): presumably this is the set of events delivered to `handle` — confirm
      // against `EventHandler`.
      override val allowed: Set[TrainingEvent[?]] =
        all -- Set(
          TopSpecimen,
          EpochFinished,
          EpochStarted
        )
    end Handler

    // `train()` yields pairs; keep the first component of the pair whose second component
    // is largest.
    val top = Train(
      summon[Featureful[FormattingConfig]],
      config = trainingConfig,
      fitness = fitness,
      events = Handler
    ).train().maxBy(_._2)._1

    assertEquals(fitness(top), 0.6f)
end TrainTest