From 0b1015d9cf1d126b75fe23f507c4b2678034a790 Mon Sep 17 00:00:00 2001 From: "Randall C. O'Reilly" Date: Tue, 5 Nov 2024 00:39:54 -0800 Subject: [PATCH] leabra updated to new looper and looper gui. --- examples/deep_fsa/deep_fsa.go | 78 +++++++++++++++++------------------ examples/hip/hip.go | 76 +++++++++++++++------------------- examples/ra25/ra25.go | 78 +++++++++++++++++------------------ examples/sir2/sir2.go | 76 +++++++++++++++------------------- go.mod | 2 +- go.sum | 4 +- leabra/hip.go | 16 ++----- leabra/looper.go | 56 +++++++++++-------------- 8 files changed, 174 insertions(+), 212 deletions(-) diff --git a/examples/deep_fsa/deep_fsa.go b/examples/deep_fsa/deep_fsa.go index 08d86e8..a0c1377 100644 --- a/examples/deep_fsa/deep_fsa.go +++ b/examples/deep_fsa/deep_fsa.go @@ -15,6 +15,7 @@ import ( "cogentcore.org/core/base/mpi" "cogentcore.org/core/base/randx" "cogentcore.org/core/core" + "cogentcore.org/core/enums" "cogentcore.org/core/icons" "cogentcore.org/core/math32/vecint" "cogentcore.org/core/tensor/table" @@ -242,7 +243,7 @@ type Sim struct { Params emer.NetParams `display:"add-fields"` // contains looper control loops for running sim - Loops *looper.Manager `new-window:"+" display:"no-inline"` + Loops *looper.Stacks `new-window:"+" display:"no-inline"` // contains computed statistic values Stats estats.Stats `new-window:"+"` @@ -405,35 +406,37 @@ func (ss *Sim) InitRandSeed(run int) { // ConfigLoops configures the control loops: Training, Testing func (ss *Sim) ConfigLoops() { - man := looper.NewManager() + ls := looper.NewStacks() trls := ss.Config.Run.NTrials - man.AddStack(etime.Train). + ls.AddStack(etime.Train). AddTime(etime.Run, ss.Config.Run.NRuns). AddTime(etime.Epoch, ss.Config.Run.NEpochs). AddTime(etime.Trial, trls). AddTime(etime.Cycle, 100) - man.AddStack(etime.Test). + ls.AddStack(etime.Test). AddTime(etime.Epoch, 1). AddTime(etime.Trial, trls). AddTime(etime.Cycle, 100) - leabra.LooperStdPhases(man, &ss.Context, ss.Net, 75, 99) // plus phase timing - leabra.LooperSimCycleAndLearn(man, ss.Net, &ss.Context, &ss.ViewUpdate) // std algo code + leabra.LooperStdPhases(ls, &ss.Context, ss.Net, 75, 99) // plus phase timing + leabra.LooperSimCycleAndLearn(ls, ss.Net, &ss.Context, &ss.ViewUpdate) // std algo code - for m, _ := range man.Stacks { - stack := man.Stacks[m] + ls.Stacks[etime.Train].OnInit.Add("Init", func() { ss.Init() }) + + for m, _ := range ls.Stacks { + stack := ls.Stacks[m] stack.Loops[etime.Trial].OnStart.Add("ApplyInputs", func() { ss.ApplyInputs() }) } - man.GetLoop(etime.Train, etime.Run).OnStart.Add("NewRun", ss.NewRun) + ls.Loop(etime.Train, etime.Run).OnStart.Add("NewRun", ss.NewRun) // Train stop early condition - man.GetLoop(etime.Train, etime.Epoch).IsDone["NZeroStop"] = func() bool { + ls.Loop(etime.Train, etime.Epoch).IsDone.AddBool("NZeroStop", func() bool { // This is calculated in TrialStats stopNz := ss.Config.Run.NZero if stopNz <= 0 { @@ -442,10 +445,10 @@ func (ss *Sim) ConfigLoops() { curNZero := ss.Stats.Int("NZero") stop := curNZero >= stopNz return stop - } + }) // Add Testing - trainEpoch := man.GetLoop(etime.Train, etime.Epoch) + trainEpoch := ls.Loop(etime.Train, etime.Epoch) trainEpoch.OnStart.Add("TestAtInterval", func() { if (ss.Config.Run.TestInterval > 0) && ((trainEpoch.Counter.Cur+1)%ss.Config.Run.TestInterval == 0) { // Note the +1 so that it doesn't occur at the 0th timestep. @@ -456,33 +459,35 @@ func (ss *Sim) ConfigLoops() { ///////////////////////////////////////////// // Logging - man.GetLoop(etime.Test, etime.Epoch).OnEnd.Add("LogTestErrors", func() { + ls.Loop(etime.Test, etime.Epoch).OnEnd.Add("LogTestErrors", func() { leabra.LogTestErrors(&ss.Logs) }) - man.GetLoop(etime.Train, etime.Epoch).OnEnd.Add("PCAStats", func() { - trnEpc := man.Stacks[etime.Train].Loops[etime.Epoch].Counter.Cur + ls.Loop(etime.Train, etime.Epoch).OnEnd.Add("PCAStats", func() { + trnEpc := ls.Stacks[etime.Train].Loops[etime.Epoch].Counter.Cur if ss.Config.Run.PCAInterval > 0 && trnEpc%ss.Config.Run.PCAInterval == 0 { leabra.PCAStats(ss.Net, &ss.Logs, &ss.Stats) ss.Logs.ResetLog(etime.Analyze, etime.Trial) } }) - man.AddOnEndToAll("Log", ss.Log) - leabra.LooperResetLogBelow(man, &ss.Logs) + ls.AddOnEndToAll("Log", func(mode, time enums.Enum) { + ss.Log(mode.(etime.Modes), time.(etime.Times)) + }) + leabra.LooperResetLogBelow(ls, &ss.Logs) - man.GetLoop(etime.Train, etime.Trial).OnEnd.Add("LogAnalyze", func() { - trnEpc := man.Stacks[etime.Train].Loops[etime.Epoch].Counter.Cur + ls.Loop(etime.Train, etime.Trial).OnEnd.Add("LogAnalyze", func() { + trnEpc := ls.Stacks[etime.Train].Loops[etime.Epoch].Counter.Cur if (ss.Config.Run.PCAInterval > 0) && (trnEpc%ss.Config.Run.PCAInterval == 0) { ss.Log(etime.Analyze, etime.Trial) } }) - man.GetLoop(etime.Train, etime.Run).OnEnd.Add("RunStats", func() { + ls.Loop(etime.Train, etime.Run).OnEnd.Add("RunStats", func() { ss.Logs.RunStats("PctCor", "FirstZero", "LastZero") }) // Save weights to file, to look at later - man.GetLoop(etime.Train, etime.Run).OnEnd.Add("SaveWeights", func() { + ls.Loop(etime.Train, etime.Run).OnEnd.Add("SaveWeights", func() { ctrString := ss.Stats.PrintValues([]string{"Run", "Epoch"}, []string{"%03d", "%05d"}, "_") leabra.SaveWeightsIfConfigSet(ss.Net, ss.Config.Log.SaveWeights, ctrString, ss.Stats.String("RunName")) }) @@ -492,19 +497,21 @@ func (ss *Sim) ConfigLoops() { if !ss.Config.GUI { if ss.Config.Log.NetData { - man.GetLoop(etime.Test, etime.Trial).Main.Add("NetDataRecord", func() { + ls.Loop(etime.Test, etime.Trial).OnEnd.Add("NetDataRecord", func() { ss.GUI.NetDataRecord(ss.ViewUpdate.Text) }) } } else { - leabra.LooperUpdateNetView(man, &ss.ViewUpdate, ss.Net, ss.NetViewCounters) - leabra.LooperUpdatePlots(man, &ss.GUI) + leabra.LooperUpdateNetView(ls, &ss.ViewUpdate, ss.Net, ss.NetViewCounters) + leabra.LooperUpdatePlots(ls, &ss.GUI) + ls.Stacks[etime.Train].OnInit.Add("GUI-Init", func() { ss.GUI.UpdateWindow() }) + ls.Stacks[etime.Test].OnInit.Add("GUI-Init", func() { ss.GUI.UpdateWindow() }) } if ss.Config.Debug { - mpi.Println(man.DocString()) + mpi.Println(ls.DocString()) } - ss.Loops = man + ss.Loops = ls } // ApplyInputs applies input patterns from given environment. @@ -542,7 +549,7 @@ func (ss *Sim) ApplyInputs() { // for the new run value func (ss *Sim) NewRun() { ctx := &ss.Context - ss.InitRandSeed(ss.Loops.GetLoop(etime.Train, etime.Run).Counter.Cur) + ss.InitRandSeed(ss.Loops.Loop(etime.Train, etime.Run).Counter.Cur) ss.Envs.ByMode(etime.Train).Init(0) ss.Envs.ByMode(etime.Test).Init(0) ctx.Reset() @@ -691,8 +698,7 @@ func (ss *Sim) Log(mode etime.Modes, time etime.Times) { ss.Logs.LogRow(mode, time, row) // also logs to file, etc } -//////////////////////////////////////////////////////////////////////////////////////////// -// Gui +//////// GUI // ConfigGUI configures the Cogent Core GUI interface for this simulation. func (ss *Sim) ConfigGUI() { @@ -715,18 +721,8 @@ func (ss *Sim) ConfigGUI() { } func (ss *Sim) MakeToolbar(p *tree.Plan) { - ss.GUI.AddToolbarItem(p, egui.ToolbarItem{Label: "Init", Icon: icons.Update, - Tooltip: "Initialize everything including network weights, and start over. Also applies current params.", - Active: egui.ActiveStopped, - Func: func() { - ss.Init() - ss.GUI.UpdateWindow() - }, - }) + ss.GUI.AddLooperCtrl(p, ss.Loops) - ss.GUI.AddLooperCtrl(p, ss.Loops, []etime.Modes{etime.Train, etime.Test}) - - //////////////////////////////////////////////// tree.Add(p, func(w *core.Separator) {}) ss.GUI.AddToolbarItem(p, egui.ToolbarItem{Label: "Reset RunLog", Icon: icons.Reset, @@ -789,7 +785,7 @@ func (ss *Sim) RunNoGUI() { ss.Init() mpi.Printf("Running %d Runs starting at %d\n", ss.Config.Run.NRuns, ss.Config.Run.Run) - ss.Loops.GetLoop(etime.Train, etime.Run).Counter.SetCurMaxPlusN(ss.Config.Run.Run, ss.Config.Run.NRuns) + ss.Loops.Loop(etime.Train, etime.Run).Counter.SetCurMaxPlusN(ss.Config.Run.Run, ss.Config.Run.NRuns) if ss.Config.Run.StartWts != "" { // this is just for testing -- not usually needed ss.Loops.Step(etime.Train, 1, etime.Trial) // get past NewRun diff --git a/examples/hip/hip.go b/examples/hip/hip.go index 9f3a28a..29a51dc 100644 --- a/examples/hip/hip.go +++ b/examples/hip/hip.go @@ -18,6 +18,7 @@ import ( "cogentcore.org/core/base/errors" "cogentcore.org/core/base/randx" "cogentcore.org/core/core" + "cogentcore.org/core/enums" "cogentcore.org/core/icons" "cogentcore.org/core/plot/plotcore" "cogentcore.org/core/tensor/stats/split" @@ -195,7 +196,7 @@ type Sim struct { Params emer.NetParams `display:"add-fields"` // contains looper control loops for running sim - Loops *looper.Manager `new-window:"+" display:"no-inline"` + Loops *looper.Stacks `new-window:"+" display:"no-inline"` // contains computed statistic values Stats estats.Stats `new-window:"+"` @@ -395,7 +396,7 @@ func (ss *Sim) Init() { } func (ss *Sim) TestInit() { - ss.Loops.ResetCountersByMode(etime.Test) + ss.Loops.InitMode(etime.Test) tst := ss.Envs.ByMode(etime.Test).(*env.FixedTable) tst.Init(0) } @@ -410,29 +411,32 @@ func (ss *Sim) InitRandSeed(run int) { // ConfigLoops configures the control loops: Training, Testing func (ss *Sim) ConfigLoops() { - man := looper.NewManager() + ls := looper.NewStacks() trls := ss.TrainAB.Rows ttrls := ss.TestAll.Rows - man.AddStack(etime.Train).AddTime(etime.Run, ss.Config.NRuns).AddTime(etime.Epoch, ss.Config.NEpochs).AddTime(etime.Trial, trls).AddTime(etime.Cycle, 100) + ls.AddStack(etime.Train).AddTime(etime.Run, ss.Config.NRuns).AddTime(etime.Epoch, ss.Config.NEpochs).AddTime(etime.Trial, trls).AddTime(etime.Cycle, 100) - man.AddStack(etime.Test).AddTime(etime.Epoch, 1).AddTime(etime.Trial, ttrls).AddTime(etime.Cycle, 100) + ls.AddStack(etime.Test).AddTime(etime.Epoch, 1).AddTime(etime.Trial, ttrls).AddTime(etime.Cycle, 100) - leabra.LooperStdPhases(man, &ss.Context, ss.Net, 75, 99) // plus phase timing - leabra.LooperSimCycleAndLearn(man, ss.Net, &ss.Context, &ss.ViewUpdate) // std algo code - ss.Net.ConfigLoopsHip(&ss.Context, man) + leabra.LooperStdPhases(ls, &ss.Context, ss.Net, 75, 99) // plus phase timing + leabra.LooperSimCycleAndLearn(ls, ss.Net, &ss.Context, &ss.ViewUpdate) // std algo code + ss.Net.ConfigLoopsHip(&ss.Context, ls) - for m, _ := range man.Stacks { - stack := man.Stacks[m] + ls.Stacks[etime.Train].OnInit.Add("Init", func() { ss.Init() }) + ls.Stacks[etime.Test].OnInit.Add("Init", func() { ss.TestInit() }) + + for m, _ := range ls.Stacks { + stack := ls.Stacks[m] stack.Loops[etime.Trial].OnStart.Add("ApplyInputs", func() { ss.ApplyInputs() }) } - man.GetLoop(etime.Train, etime.Run).OnStart.Add("NewRun", ss.NewRun) + ls.Loop(etime.Train, etime.Run).OnStart.Add("NewRun", ss.NewRun) - man.GetLoop(etime.Train, etime.Run).OnEnd.Add("RunDone", func() { + ls.Loop(etime.Train, etime.Run).OnEnd.Add("RunDone", func() { if ss.Stats.Int("Run") >= ss.Config.NRuns-1 { ss.RunStats() expt := ss.Stats.Int("Expt") @@ -441,7 +445,7 @@ func (ss *Sim) ConfigLoops() { }) // Add Testing - trainEpoch := man.GetLoop(etime.Train, etime.Epoch) + trainEpoch := ls.Loop(etime.Train, etime.Epoch) trainEpoch.OnEnd.Add("TestAtInterval", func() { if (ss.Config.TestInterval > 0) && ((trainEpoch.Counter.Cur+1)%ss.Config.TestInterval == 0) { // Note the +1 so that it doesn't occur at the 0th timestep. @@ -461,28 +465,33 @@ func (ss *Sim) ConfigLoops() { }) // early stop - man.GetLoop(etime.Train, etime.Epoch).IsDone["ACMemStop"] = func() bool { + ls.Loop(etime.Train, etime.Epoch).IsDone.AddBool("ACMemStop", func() bool { // This is calculated in TrialStats tstEpcLog := ss.Logs.Tables[etime.Scope(etime.Test, etime.Epoch)] acMem := float32(tstEpcLog.Table.Float("ACMem", ss.Stats.Int("Epoch"))) stop := acMem >= ss.Config.StopMem return stop - } + }) ///////////////////////////////////////////// // Logging - man.GetLoop(etime.Test, etime.Epoch).OnEnd.Add("LogTestErrors", func() { + ls.Loop(etime.Test, etime.Epoch).OnEnd.Add("LogTestErrors", func() { leabra.LogTestErrors(&ss.Logs) }) - man.AddOnEndToAll("Log", ss.Log) - leabra.LooperResetLogBelow(man, &ss.Logs) + ls.AddOnEndToAll("Log", func(mode, time enums.Enum) { + ss.Log(mode.(etime.Modes), time.(etime.Times)) + }) + leabra.LooperResetLogBelow(ls, &ss.Logs) - leabra.LooperUpdateNetView(man, &ss.ViewUpdate, ss.Net, ss.NetViewCounters) - leabra.LooperUpdatePlots(man, &ss.GUI) - ss.Loops = man + leabra.LooperUpdateNetView(ls, &ss.ViewUpdate, ss.Net, ss.NetViewCounters) + leabra.LooperUpdatePlots(ls, &ss.GUI) + ls.Stacks[etime.Train].OnInit.Add("GUI-Init", func() { ss.GUI.UpdateWindow() }) + ls.Stacks[etime.Test].OnInit.Add("GUI-Init", func() { ss.GUI.UpdateWindow() }) + + ss.Loops = ls // mpi.Println(man.DocString()) } @@ -519,7 +528,7 @@ func (ss *Sim) ApplyInputs() { // for the new run value func (ss *Sim) NewRun() { ctx := &ss.Context - ss.InitRandSeed(ss.Loops.GetLoop(etime.Train, etime.Run).Counter.Cur) + ss.InitRandSeed(ss.Loops.Loop(etime.Train, etime.Run).Counter.Cur) // ss.ConfigPats() ss.ConfigEnv() ctx.Reset() @@ -681,7 +690,7 @@ func (ss *Sim) NetViewCounters(tm etime.Times) { // TrialStats computes the trial-level statistics. // Aggregation is done directly from log data. func (ss *Sim) TrialStats() { - ss.MemStats(ss.Loops.Mode) + ss.MemStats(ss.Loops.Mode.(etime.Modes)) } // MemStats computes ActM vs. Target on ECout with binary counts @@ -936,27 +945,8 @@ func (ss *Sim) ConfigGUI() { } func (ss *Sim) MakeToolbar(p *tree.Plan) { - ss.GUI.AddToolbarItem(p, egui.ToolbarItem{Label: "Init", Icon: icons.Update, - Tooltip: "Initialize everything including network weights, and start over. Also applies current params.", - Active: egui.ActiveStopped, - Func: func() { - ss.Init() - ss.GUI.UpdateWindow() - }, - }) - - ss.GUI.AddLooperCtrl(p, ss.Loops, []etime.Modes{etime.Train, etime.Test}) + ss.GUI.AddLooperCtrl(p, ss.Loops) - ss.GUI.AddToolbarItem(p, egui.ToolbarItem{Label: "Test Init", Icon: icons.Update, - Tooltip: "Initialize the testing process.", - Active: egui.ActiveStopped, - Func: func() { - ss.TestInit() - ss.GUI.UpdateWindow() - }, - }) - - //////////////////////////////////////////////// tree.Add(p, func(w *core.Separator) {}) ss.GUI.AddToolbarItem(p, egui.ToolbarItem{Label: "Reset RunLog", Icon: icons.Reset, diff --git a/examples/ra25/ra25.go b/examples/ra25/ra25.go index 734d647..f7436b6 100644 --- a/examples/ra25/ra25.go +++ b/examples/ra25/ra25.go @@ -18,6 +18,7 @@ import ( "cogentcore.org/core/base/mpi" "cogentcore.org/core/base/randx" "cogentcore.org/core/core" + "cogentcore.org/core/enums" "cogentcore.org/core/icons" "cogentcore.org/core/math32" "cogentcore.org/core/math32/vecint" @@ -238,7 +239,7 @@ type Sim struct { Params emer.NetParams `display:"add-fields"` // contains looper control loops for running sim - Loops *looper.Manager `new-window:"+" display:"no-inline"` + Loops *looper.Stacks `new-window:"+" display:"no-inline"` // contains computed statistic values Stats estats.Stats `new-window:"+"` @@ -400,35 +401,39 @@ func (ss *Sim) InitRandSeed(run int) { // ConfigLoops configures the control loops: Training, Testing func (ss *Sim) ConfigLoops() { - man := looper.NewManager() + ls := looper.NewStacks() trls := ss.Config.Run.NTrials - man.AddStack(etime.Train). + ls.AddStack(etime.Train). AddTime(etime.Run, ss.Config.Run.NRuns). AddTime(etime.Epoch, ss.Config.Run.NEpochs). AddTime(etime.Trial, trls). AddTime(etime.Cycle, 100) - man.AddStack(etime.Test). + ls.AddStack(etime.Test). AddTime(etime.Epoch, 1). AddTime(etime.Trial, trls). AddTime(etime.Cycle, 100) - leabra.LooperStdPhases(man, &ss.Context, ss.Net, 75, 99) // plus phase timing - leabra.LooperSimCycleAndLearn(man, ss.Net, &ss.Context, &ss.ViewUpdate) // std algo code + leabra.LooperStdPhases(ls, &ss.Context, ss.Net, 75, 99) // plus phase timing + leabra.LooperSimCycleAndLearn(ls, ss.Net, &ss.Context, &ss.ViewUpdate) // std algo code - for m, _ := range man.Stacks { - stack := man.Stacks[m] - stack.Loops[etime.Trial].OnStart.Add("ApplyInputs", func() { + ls.Stacks[etime.Train].OnInit.Add("Init", func() { + ss.Init() + }) + + for m, _ := range ls.Stacks { + st := ls.Stacks[m] + st.Loops[etime.Trial].OnStart.Add("ApplyInputs", func() { ss.ApplyInputs() }) } - man.GetLoop(etime.Train, etime.Run).OnStart.Add("NewRun", ss.NewRun) + ls.Loop(etime.Train, etime.Run).OnStart.Add("NewRun", ss.NewRun) // Train stop early condition - man.GetLoop(etime.Train, etime.Epoch).IsDone["NZeroStop"] = func() bool { + ls.Loop(etime.Train, etime.Epoch).IsDone.AddBool("NZeroStop", func() bool { // This is calculated in TrialStats stopNz := ss.Config.Run.NZero if stopNz <= 0 { @@ -437,10 +442,10 @@ func (ss *Sim) ConfigLoops() { curNZero := ss.Stats.Int("NZero") stop := curNZero >= stopNz return stop - } + }) // Add Testing - trainEpoch := man.GetLoop(etime.Train, etime.Epoch) + trainEpoch := ls.Loop(etime.Train, etime.Epoch) trainEpoch.OnStart.Add("TestAtInterval", func() { if (ss.Config.Run.TestInterval > 0) && ((trainEpoch.Counter.Cur+1)%ss.Config.Run.TestInterval == 0) { // Note the +1 so that it doesn't occur at the 0th timestep. @@ -451,33 +456,35 @@ func (ss *Sim) ConfigLoops() { ///////////////////////////////////////////// // Logging - man.GetLoop(etime.Test, etime.Epoch).OnEnd.Add("LogTestErrors", func() { + ls.Loop(etime.Test, etime.Epoch).OnEnd.Add("LogTestErrors", func() { leabra.LogTestErrors(&ss.Logs) }) - man.GetLoop(etime.Train, etime.Epoch).OnEnd.Add("PCAStats", func() { - trnEpc := man.Stacks[etime.Train].Loops[etime.Epoch].Counter.Cur + ls.Loop(etime.Train, etime.Epoch).OnEnd.Add("PCAStats", func() { + trnEpc := ls.Stacks[etime.Train].Loops[etime.Epoch].Counter.Cur if ss.Config.Run.PCAInterval > 0 && trnEpc%ss.Config.Run.PCAInterval == 0 { leabra.PCAStats(ss.Net, &ss.Logs, &ss.Stats) ss.Logs.ResetLog(etime.Analyze, etime.Trial) } }) - man.AddOnEndToAll("Log", ss.Log) - leabra.LooperResetLogBelow(man, &ss.Logs) + ls.AddOnEndToAll("Log", func(mode, time enums.Enum) { + ss.Log(mode.(etime.Modes), time.(etime.Times)) + }) + leabra.LooperResetLogBelow(ls, &ss.Logs) - man.GetLoop(etime.Train, etime.Trial).OnEnd.Add("LogAnalyze", func() { - trnEpc := man.Stacks[etime.Train].Loops[etime.Epoch].Counter.Cur + ls.Loop(etime.Train, etime.Trial).OnEnd.Add("LogAnalyze", func() { + trnEpc := ls.Stacks[etime.Train].Loops[etime.Epoch].Counter.Cur if (ss.Config.Run.PCAInterval > 0) && (trnEpc%ss.Config.Run.PCAInterval == 0) { ss.Log(etime.Analyze, etime.Trial) } }) - man.GetLoop(etime.Train, etime.Run).OnEnd.Add("RunStats", func() { + ls.Loop(etime.Train, etime.Run).OnEnd.Add("RunStats", func() { ss.Logs.RunStats("PctCor", "FirstZero", "LastZero") }) // Save weights to file, to look at later - man.GetLoop(etime.Train, etime.Run).OnEnd.Add("SaveWeights", func() { + ls.Loop(etime.Train, etime.Run).OnEnd.Add("SaveWeights", func() { ctrString := ss.Stats.PrintValues([]string{"Run", "Epoch"}, []string{"%03d", "%05d"}, "_") leabra.SaveWeightsIfConfigSet(ss.Net, ss.Config.Log.SaveWeights, ctrString, ss.Stats.String("RunName")) }) @@ -487,19 +494,21 @@ func (ss *Sim) ConfigLoops() { if !ss.Config.GUI { if ss.Config.Log.NetData { - man.GetLoop(etime.Test, etime.Trial).Main.Add("NetDataRecord", func() { + ls.Loop(etime.Test, etime.Trial).OnEnd.Add("NetDataRecord", func() { ss.GUI.NetDataRecord(ss.ViewUpdate.Text) }) } } else { - leabra.LooperUpdateNetView(man, &ss.ViewUpdate, ss.Net, ss.NetViewCounters) - leabra.LooperUpdatePlots(man, &ss.GUI) + leabra.LooperUpdateNetView(ls, &ss.ViewUpdate, ss.Net, ss.NetViewCounters) + leabra.LooperUpdatePlots(ls, &ss.GUI) + ls.Stacks[etime.Train].OnInit.Add("GUI-Init", func() { ss.GUI.UpdateWindow() }) + ls.Stacks[etime.Test].OnInit.Add("GUI-Init", func() { ss.GUI.UpdateWindow() }) } if ss.Config.Debug { - mpi.Println(man.DocString()) + mpi.Println(ls.DocString()) } - ss.Loops = man + ss.Loops = ls } // ApplyInputs applies input patterns from given environment. @@ -527,7 +536,7 @@ func (ss *Sim) ApplyInputs() { // for the new run value func (ss *Sim) NewRun() { ctx := &ss.Context - ss.InitRandSeed(ss.Loops.GetLoop(etime.Train, etime.Run).Counter.Cur) + ss.InitRandSeed(ss.Loops.Loop(etime.Train, etime.Run).Counter.Cur) ss.Envs.ByMode(etime.Train).Init(0) ss.Envs.ByMode(etime.Test).Init(0) ctx.Reset() @@ -711,16 +720,7 @@ func (ss *Sim) ConfigGUI() { } func (ss *Sim) MakeToolbar(p *tree.Plan) { - ss.GUI.AddToolbarItem(p, egui.ToolbarItem{Label: "Init", Icon: icons.Update, - Tooltip: "Initialize everything including network weights, and start over. Also applies current params.", - Active: egui.ActiveStopped, - Func: func() { - ss.Init() - ss.GUI.UpdateWindow() - }, - }) - - ss.GUI.AddLooperCtrl(p, ss.Loops, []etime.Modes{etime.Train, etime.Test}) + ss.GUI.AddLooperCtrl(p, ss.Loops) //////////////////////////////////////////////// tree.Add(p, func(w *core.Separator) {}) @@ -785,7 +785,7 @@ func (ss *Sim) RunNoGUI() { ss.Init() mpi.Printf("Running %d Runs starting at %d\n", ss.Config.Run.NRuns, ss.Config.Run.Run) - ss.Loops.GetLoop(etime.Train, etime.Run).Counter.SetCurMaxPlusN(ss.Config.Run.Run, ss.Config.Run.NRuns) + ss.Loops.Loop(etime.Train, etime.Run).Counter.SetCurMaxPlusN(ss.Config.Run.Run, ss.Config.Run.NRuns) if ss.Config.Run.StartWts != "" { // this is just for testing -- not usually needed ss.Loops.Step(etime.Train, 1, etime.Trial) // get past NewRun diff --git a/examples/sir2/sir2.go b/examples/sir2/sir2.go index f7c7041..bcf5497 100644 --- a/examples/sir2/sir2.go +++ b/examples/sir2/sir2.go @@ -17,6 +17,7 @@ import ( "cogentcore.org/core/base/randx" "cogentcore.org/core/core" + "cogentcore.org/core/enums" "cogentcore.org/core/icons" "cogentcore.org/core/math32" "cogentcore.org/core/tree" @@ -227,7 +228,7 @@ type Sim struct { Params emer.NetParams `display:"add-fields"` // contains looper control loops for running sim - Loops *looper.Manager `new-window:"+" display:"no-inline"` + Loops *looper.Stacks `new-window:"+" display:"no-inline"` // contains computed statistic values Stats estats.Stats `new-window:"+"` @@ -416,49 +417,52 @@ func (ss *Sim) InitRandSeed(run int) { // ConfigLoops configures the control loops: Training, Testing func (ss *Sim) ConfigLoops() { - man := looper.NewManager() + ls := looper.NewStacks() trls := ss.Config.NTrials - man.AddStack(etime.Train). + ls.AddStack(etime.Train). AddTime(etime.Run, ss.Config.NRuns). AddTime(etime.Epoch, ss.Config.NEpochs). AddTime(etime.Trial, trls). AddTime(etime.Cycle, 100) - man.AddStack(etime.Test). + ls.AddStack(etime.Test). AddTime(etime.Epoch, 1). AddTime(etime.Trial, trls). AddTime(etime.Cycle, 100) - leabra.LooperStdPhases(man, &ss.Context, ss.Net, 75, 99) // plus phase timing - leabra.LooperSimCycleAndLearn(man, ss.Net, &ss.Context, &ss.ViewUpdate) // std algo code + leabra.LooperStdPhases(ls, &ss.Context, ss.Net, 75, 99) // plus phase timing + leabra.LooperSimCycleAndLearn(ls, ss.Net, &ss.Context, &ss.ViewUpdate) // std algo code - for m, _ := range man.Stacks { - stack := man.Stacks[m] + ls.Stacks[etime.Train].OnInit.Add("Init", func() { ss.Init() }) + + for m, _ := range ls.Stacks { + stack := ls.Stacks[m] stack.Loops[etime.Trial].OnStart.Add("ApplyInputs", func() { ss.ApplyInputs() }) } - man.GetLoop(etime.Train, etime.Run).OnStart.Add("NewRun", ss.NewRun) + ls.Loop(etime.Train, etime.Run).OnStart.Add("NewRun", ss.NewRun) - man.GetLoop(etime.Train, etime.Run).OnEnd.Add("RunDone", func() { + ls.Loop(etime.Train, etime.Run).OnEnd.Add("RunDone", func() { if ss.Stats.Int("Run") >= ss.Config.NRuns-1 { expt := ss.Stats.Int("Expt") ss.Stats.SetInt("Expt", expt+1) } }) - stack := man.Stacks[etime.Train] + stack := ls.Stacks[etime.Train] cyc, _ := stack.Loops[etime.Cycle] - plus := cyc.EventByName("PlusPhase") - plus.OnEvent.InsertBefore("MinusPhase:End", "ApplyReward", func() { + plus := cyc.EventByName("MinusPhase:End") + plus.OnEvent.InsertBefore("MinusPhase:End", "ApplyReward", func() bool { ss.ApplyReward(true) + return true }) // Train stop early condition - man.GetLoop(etime.Train, etime.Epoch).IsDone["NZeroStop"] = func() bool { + ls.Loop(etime.Train, etime.Epoch).IsDone.AddBool("NZeroStop", func() bool { // This is calculated in TrialStats stopNz := ss.Config.NZero if stopNz <= 0 { @@ -467,10 +471,10 @@ func (ss *Sim) ConfigLoops() { curNZero := ss.Stats.Int("NZero") stop := curNZero >= stopNz return stop - } + }) // Add Testing - trainEpoch := man.GetLoop(etime.Train, etime.Epoch) + trainEpoch := ls.Loop(etime.Train, etime.Epoch) trainEpoch.OnStart.Add("TestAtInterval", func() { if (ss.Config.TestInterval > 0) && ((trainEpoch.Counter.Cur+1)%ss.Config.TestInterval == 0) { // Note the +1 so that it doesn't occur at the 0th timestep. @@ -481,22 +485,26 @@ func (ss *Sim) ConfigLoops() { ///////////////////////////////////////////// // Logging - man.GetLoop(etime.Test, etime.Epoch).OnEnd.Add("LogTestErrors", func() { + ls.Loop(etime.Test, etime.Epoch).OnEnd.Add("LogTestErrors", func() { leabra.LogTestErrors(&ss.Logs) }) - man.AddOnEndToAll("Log", ss.Log) - leabra.LooperResetLogBelow(man, &ss.Logs) - man.GetLoop(etime.Train, etime.Run).OnEnd.Add("RunStats", func() { + ls.AddOnEndToAll("Log", func(mode, time enums.Enum) { + ss.Log(mode.(etime.Modes), time.(etime.Times)) + }) + leabra.LooperResetLogBelow(ls, &ss.Logs) + ls.Loop(etime.Train, etime.Run).OnEnd.Add("RunStats", func() { ss.Logs.RunStats("PctCor", "FirstZero", "LastZero") }) //////////////////////////////////////////// // GUI - leabra.LooperUpdateNetView(man, &ss.ViewUpdate, ss.Net, ss.NetViewCounters) - leabra.LooperUpdatePlots(man, &ss.GUI) + leabra.LooperUpdateNetView(ls, &ss.ViewUpdate, ss.Net, ss.NetViewCounters) + leabra.LooperUpdatePlots(ls, &ss.GUI) + ls.Stacks[etime.Train].OnInit.Add("GUI-Init", func() { ss.GUI.UpdateWindow() }) + ls.Stacks[etime.Test].OnInit.Add("GUI-Init", func() { ss.GUI.UpdateWindow() }) - ss.Loops = man + ss.Loops = ls } // ApplyInputs applies input patterns from given environment. @@ -548,7 +556,7 @@ func (ss *Sim) ApplyReward(train bool) { // for the new run value func (ss *Sim) NewRun() { ctx := &ss.Context - ss.InitRandSeed(ss.Loops.GetLoop(etime.Train, etime.Run).Counter.Cur) + ss.InitRandSeed(ss.Loops.Loop(etime.Train, etime.Run).Counter.Cur) ss.Envs.ByMode(etime.Train).Init(0) ss.Envs.ByMode(etime.Test).Init(0) ctx.Reset() @@ -719,26 +727,8 @@ func (ss *Sim) ConfigGUI() { } func (ss *Sim) MakeToolbar(p *tree.Plan) { - ss.GUI.AddToolbarItem(p, egui.ToolbarItem{Label: "Init", Icon: icons.Update, - Tooltip: "Initialize everything including network weights, and start over. Also applies current params.", - Active: egui.ActiveStopped, - Func: func() { - ss.Init() - ss.GUI.UpdateWindow() - }, - }) + ss.GUI.AddLooperCtrl(p, ss.Loops) - ss.GUI.AddLooperCtrl(p, ss.Loops, []etime.Modes{etime.Train, etime.Test}) - - ss.GUI.AddToolbarItem(p, egui.ToolbarItem{Label: "Test Init", Icon: icons.Update, - Tooltip: "Initialize testing to start over -- if Test Step doesn't work, then do this.", - Active: egui.ActiveStopped, - Func: func() { - ss.Loops.ResetCountersByMode(etime.Test) - }, - }) - - //////////////////////////////////////////////// tree.Add(p, func(w *core.Separator) {}) ss.GUI.AddToolbarItem(p, egui.ToolbarItem{Label: "Reset RunLog", Icon: icons.Reset, diff --git a/go.mod b/go.mod index d1e85be..c06c5ce 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.22 require ( cogentcore.org/core v0.3.5 - github.com/emer/emergent/v2 v2.0.0-dev0.1.3 + github.com/emer/emergent/v2 v2.0.0-dev0.1.4 ) require ( diff --git a/go.sum b/go.sum index 708d361..805a07a 100644 --- a/go.sum +++ b/go.sum @@ -32,8 +32,8 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI= github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= -github.com/emer/emergent/v2 v2.0.0-dev0.1.3 h1:3ixqy2VIubETYEj3GMi7JU0RF6K/Cb9TXC8ITq+Udmk= -github.com/emer/emergent/v2 v2.0.0-dev0.1.3/go.mod h1:9QhWnj/IHq/TrVzcXsC96GlC1Gg/pK0pwMI5sZr+/yU= +github.com/emer/emergent/v2 v2.0.0-dev0.1.4 h1:HCBBq6s/t3n6xlLPS9TzOL5BO9+U4BYIboJdBrFeliA= +github.com/emer/emergent/v2 v2.0.0-dev0.1.4/go.mod h1:9QhWnj/IHq/TrVzcXsC96GlC1Gg/pK0pwMI5sZr+/yU= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= diff --git a/leabra/hip.go b/leabra/hip.go index 54ee1c3..d7621a6 100644 --- a/leabra/hip.go +++ b/leabra/hip.go @@ -206,7 +206,7 @@ func (pt *Path) DWtEcCa1() { // ConfigLoopsHip configures the hippocampal looper and should be included in ConfigLoops // in model to make sure hip loops is configured correctly. // see hip.go for an instance of implementation of this function. -func (net *Network) ConfigLoopsHip(ctx *Context, man *looper.Manager) { +func (net *Network) ConfigLoopsHip(ctx *Context, ls *looper.Stacks) { var tmpValues []float32 ecout := net.LayerByName("ECout") ecin := net.LayerByName("ECin") @@ -218,21 +218,14 @@ func (net *Network) ConfigLoopsHip(ctx *Context, man *looper.Manager) { dgPjScale := ca3FromDg.WtScale.Rel - // configure events -- note that events are shared between Train, Test - // so only need to do it once on Train - mode := etime.Train - stack := man.Stacks[mode] - cyc, _ := stack.Loops[etime.Cycle] - minusStart := cyc.EventByName("MinusPhase") // cycle 0 - minusStart.OnEvent.Add("HipMinusPhase:Start", func() { + ls.AddEventAllModes(etime.Cycle, "HipMinusPhase:Start", 0, func() { ca1FromECin.WtScale.Abs = 1 ca1FromCa3.WtScale.Abs = 0 ca3FromDg.WtScale.Rel = 0 net.GScaleFromAvgAct() net.InitGInc() }) - quarter1 := cyc.EventByName("Quarter1") - quarter1.OnEvent.Add("Hip:Quarter1", func() { + ls.AddEventAllModes(etime.Cycle, "Hip:Quarter1", 25, func() { ca1FromECin.WtScale.Abs = 0 ca1FromCa3.WtScale.Abs = 1 if ctx.Mode == etime.Test { @@ -243,8 +236,7 @@ func (net *Network) ConfigLoopsHip(ctx *Context, man *looper.Manager) { net.GScaleFromAvgAct() net.InitGInc() }) - plus := cyc.EventByName("PlusPhase") - plus.OnEvent.InsertBefore("MinusPhase:End", "HipPlusPhase:Start", func() { + ls.AddEventAllModes(etime.Cycle, "HipPlusPhase:Start", 75, func() { ca1FromECin.WtScale.Abs = 1 ca1FromCa3.WtScale.Abs = 0 if ctx.Mode == etime.Train { diff --git a/leabra/looper.go b/leabra/looper.go index 8609462..1149cc0 100644 --- a/leabra/looper.go +++ b/leabra/looper.go @@ -18,36 +18,31 @@ import ( // and plusEnd is end of plus phase, typically 99 // resets the state at start of trial. // Can pass a trial-level time scale to use instead of the default etime.Trial -func LooperStdPhases(man *looper.Manager, ctx *Context, net *Network, plusStart, plusEnd int, trial ...etime.Times) { +func LooperStdPhases(ls *looper.Stacks, ctx *Context, net *Network, plusStart, plusEnd int, trial ...etime.Times) { trl := etime.Trial if len(trial) > 0 { trl = trial[0] } - minusPhase := &looper.Event{Name: "MinusPhase", AtCounter: 0} - minusPhase.OnEvent.Add("MinusPhase:Start", func() { + ls.AddEventAllModes(etime.Cycle, "MinusPhase:Start", 0, func() { ctx.PlusPhase = false }) - quarter1 := looper.NewEvent("Quarter1", 25, func() { + ls.AddEventAllModes(etime.Cycle, "Quarter1", 25, func() { net.QuarterFinal(ctx) ctx.QuarterInc() }) - quarter2 := looper.NewEvent("Quarter2", 50, func() { + ls.AddEventAllModes(etime.Cycle, "Quarter2", 50, func() { net.QuarterFinal(ctx) ctx.QuarterInc() }) - plusPhase := &looper.Event{Name: "PlusPhase", AtCounter: plusStart} - plusPhase.OnEvent.Add("MinusPhase:End", func() { + ls.AddEventAllModes(etime.Cycle, "MinusPhase:End", plusStart, func() { net.QuarterFinal(ctx) ctx.QuarterInc() }) - plusPhase.OnEvent.Add("PlusPhase:Start", func() { + ls.AddEventAllModes(etime.Cycle, "PlusPhase:Start", plusStart, func() { ctx.PlusPhase = true }) - man.AddEventAllModes(etime.Cycle, minusPhase, quarter1, quarter2, plusPhase) - - for m, _ := range man.Stacks { - stack := man.Stacks[m] + for m, stack := range ls.Stacks { stack.Loops[trl].OnStart.Add("AlphaCycInit", func() { net.AlphaCycInit(m == etime.Train) ctx.AlphaCycStart() @@ -61,19 +56,18 @@ func LooperStdPhases(man *looper.Manager, ctx *Context, net *Network, plusStart, // LooperSimCycleAndLearn adds Cycle and DWt, WtFromDWt functions to looper // for given network, ctx, and netview update manager // Can pass a trial-level time scale to use instead of the default etime.Trial -func LooperSimCycleAndLearn(man *looper.Manager, net *Network, ctx *Context, viewupdt *netview.ViewUpdate, trial ...etime.Times) { +func LooperSimCycleAndLearn(ls *looper.Stacks, net *Network, ctx *Context, viewupdt *netview.ViewUpdate, trial ...etime.Times) { trl := etime.Trial if len(trial) > 0 { trl = trial[0] } - for m, _ := range man.Stacks { - cycLoop := man.Stacks[m].Loops[etime.Cycle] - cycLoop.Main.Add("Cycle", func() { + for m := range ls.Stacks { + ls.Stacks[m].Loops[etime.Cycle].OnStart.Add("Cycle", func() { net.Cycle(ctx) ctx.CycleInc() }) } - ttrl := man.GetLoop(etime.Train, trl) + ttrl := ls.Loop(etime.Train, trl) if ttrl != nil { ttrl.OnEnd.Add("UpdateWeights", func() { net.DWt() @@ -85,10 +79,10 @@ func LooperSimCycleAndLearn(man *looper.Manager, net *Network, ctx *Context, vie } // Set variables on ss that are referenced elsewhere, such as ApplyInputs. - for m, loops := range man.Stacks { + for m, loops := range ls.Stacks { for _, loop := range loops.Loops { loop.OnStart.Add("SetCtxMode", func() { - ctx.Mode = m + ctx.Mode = m.(etime.Modes) }) } } @@ -98,8 +92,8 @@ func LooperSimCycleAndLearn(man *looper.Manager, net *Network, ctx *Context, vie // to reset the log at the level below each loop -- this is good default behavior. // Exceptions can be passed to exclude specific levels -- e.g., if except is Epoch // then Epoch does not reset the log below it -func LooperResetLogBelow(man *looper.Manager, logs *elog.Logs, except ...etime.Times) { - for m, stack := range man.Stacks { +func LooperResetLogBelow(ls *looper.Stacks, logs *elog.Logs, except ...etime.Times) { + for m, stack := range ls.Stacks { for t, loop := range stack.Loops { curTime := t isExcept := false @@ -111,7 +105,7 @@ func LooperResetLogBelow(man *looper.Manager, logs *elog.Logs, except ...etime.T } if below := stack.TimeBelow(curTime); !isExcept && below != etime.NoTime { loop.OnStart.Add("ResetLog"+below.String(), func() { - logs.ResetLog(m, below) + logs.ResetLog(m.(etime.Modes), below.(etime.Times)) }) } } @@ -119,10 +113,10 @@ func LooperResetLogBelow(man *looper.Manager, logs *elog.Logs, except ...etime.T } // LooperUpdateNetView adds netview update calls at each time level -func LooperUpdateNetView(man *looper.Manager, viewupdt *netview.ViewUpdate, net *Network, ctrUpdateFunc func(tm etime.Times)) { - for m, stack := range man.Stacks { +func LooperUpdateNetView(ls *looper.Stacks, viewupdt *netview.ViewUpdate, net *Network, ctrUpdateFunc func(tm etime.Times)) { + for m, stack := range ls.Stacks { for t, loop := range stack.Loops { - curTime := t + curTime := t.(etime.Times) if curTime != etime.Cycle { loop.OnEnd.Add("GUI:UpdateNetView", func() { ctrUpdateFunc(curTime) @@ -131,7 +125,7 @@ func LooperUpdateNetView(man *looper.Manager, viewupdt *netview.ViewUpdate, net }) } } - cycLoop := man.GetLoop(m, etime.Cycle) + cycLoop := ls.Loop(m, etime.Cycle) cycLoop.OnEnd.Add("GUI:UpdateNetView", func() { cyc := cycLoop.Counter.Cur ctrUpdateFunc(etime.Cycle) @@ -142,19 +136,19 @@ func LooperUpdateNetView(man *looper.Manager, viewupdt *netview.ViewUpdate, net } // LooperUpdatePlots adds plot update calls at each time level -func LooperUpdatePlots(man *looper.Manager, gui *egui.GUI) { - for m, stack := range man.Stacks { +func LooperUpdatePlots(ls *looper.Stacks, gui *egui.GUI) { + for m, stack := range ls.Stacks { for t, loop := range stack.Loops { - curTime := t + curTime := t.(etime.Times) curLoop := loop if curTime == etime.Cycle { curLoop.OnEnd.Add("GUI:UpdatePlot", func() { cyc := curLoop.Counter.Cur - gui.GoUpdateCyclePlot(m, cyc) + gui.GoUpdateCyclePlot(m.(etime.Modes), cyc) }) } else { curLoop.OnEnd.Add("GUI:UpdatePlot", func() { - gui.GoUpdatePlot(m, curTime) + gui.GoUpdatePlot(m.(etime.Modes), curTime) }) } }