From 1493e8fb4c92236fb13a06d0c12bed931b100ba8 Mon Sep 17 00:00:00 2001 From: "Randall C. O'Reilly" Date: Fri, 20 Dec 2024 17:00:35 -0800 Subject: [PATCH] bgdorsal: rest of stats, stopping crit etc --- sims/bgdorsal/bg-dorsal.go | 38 ++++++++++++++++++++++++++++++++++---- sims/bgdorsal/config.go | 2 +- 2 files changed, 35 insertions(+), 5 deletions(-) diff --git a/sims/bgdorsal/bg-dorsal.go b/sims/bgdorsal/bg-dorsal.go index 724575b3..71e7e9f3 100644 --- a/sims/bgdorsal/bg-dorsal.go +++ b/sims/bgdorsal/bg-dorsal.go @@ -423,8 +423,8 @@ func (ss *Sim) ConfigLoops() { ls.Loop(Train, Run).OnStart.Add("NewRun", ss.NewRun) ls.Loop(Train, Epoch).IsDone.AddBool("StopCrit", func() bool { - curModeDir := ss.Current.Dir(Train.String()) - rew := curModeDir.Value("RewEpc").Float1D(-1) + epcDir := ss.Stats.Dir(Train.String()).Dir(Epoch.String()) + rew := epcDir.Value("Rew").Float1D(-1) stop := rew >= 0.98 return stop }) @@ -785,7 +785,7 @@ func (ss *Sim) ConfigStats() { } }) - seqStats := []string{"NCorrect", "Rew", "RewPred", "RPE", "RewEpc"} + seqStats := []string{"NCorrect", "Rew", "RewPred", "RPE"} ss.AddStat(func(mode Modes, level Levels, phase StatsPhase) { if level <= Trial { return @@ -824,12 +824,42 @@ func (ss *Sim) ConfigStats() { curModeDir.Float32(name, ndata).SetFloat1D(float64(stat), di) tsr.AppendRowFloat(float64(stat)) } - default: + case Epoch: stat = stats.StatMean.Call(subDir.Value(name)).Float1D(0) tsr.AppendRowFloat(stat) + default: // Run, Expt + stat = stats.StatFinal.Call(subDir.Value(name)).Float1D(0) + tsr.AppendRowFloat(stat) } } }) + ss.AddStat(func(mode Modes, level Levels, phase StatsPhase) { + if level <= Epoch { + return + } + name := "EpochsToCrit" + modeDir := ss.Stats.Dir(mode.String()) + levelDir := modeDir.Dir(level.String()) + subDir := modeDir.Dir((level - 1).String()) // note: will fail for Cycle + tsr := levelDir.Float64(name) + if phase == Start { + tsr.SetNumRows(0) + plot.SetFirstStylerTo(tsr, func(s *plot.Style) { + s.Range.SetMin(0) + s.On = true + }) + return + } + var stat float64 + switch level { + case Run: + stat = float64(ss.Loops.Loop(mode, (level - 1)).Counter.Cur) + tsr.AppendRowFloat(stat) + default: // in case higher + stat = stats.StatFinal.Call(subDir.Value(name)).Float1D(0) + tsr.AppendRowFloat(stat) + } + }) } // StatCounters returns counters string to show at bottom of netview. diff --git a/sims/bgdorsal/config.go b/sims/bgdorsal/config.go index f401dec4..32341acd 100644 --- a/sims/bgdorsal/config.go +++ b/sims/bgdorsal/config.go @@ -85,7 +85,7 @@ type RunConfig struct { Runs int `default:"25" min:"1"` // Epochs is the total number of epochs per run. - Epochs int `default:"50"` + Epochs int `default:"100"` // Sequences is the total number of sequences per epoch. // Should be an even multiple of NData.