Skip to content

Commit

Permalink
view update calls counters directly
Browse files Browse the repository at this point in the history
  • Loading branch information
rcoreilly committed Nov 23, 2024
1 parent a7e3490 commit 7a41f99
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 82 deletions.
3 changes: 2 additions & 1 deletion axon/act-net.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 13 additions & 7 deletions axon/looper.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,21 +154,25 @@ var ViewTimeCycles = []int{1, 10, 25, 50, 100, 150, 200}
// Use one of these for each mode you want to control separately.
type NetViewUpdate struct {

// toggles update of display on
// On toggles update of display on
On bool

// Time scale to update the network view.
// Time scale to update the network view (Cycle to Trial timescales).
Time ViewTimes

// CounterFunc returns the counter string showing current counters etc.
CounterFunc func(mode, level enums.Enum) string `display:"-"`

// View is the network view.
View *netview.NetView `display:"-"`
}

// Config configures for given NetView and time.
func (vu *NetViewUpdate) Config(nv *netview.NetView, tm ViewTimes) {
// Config configures for given NetView, time and counter func.
func (vu *NetViewUpdate) Config(nv *netview.NetView, tm ViewTimes, fun func(mode, level enums.Enum) string) {
vu.View = nv
vu.On = true
vu.Time = tm
vu.CounterFunc = fun
}

// ShouldUpdate returns true if the view is On,
Expand Down Expand Up @@ -196,10 +200,11 @@ func (vu *NetViewUpdate) GoUpdate(counters string) {
// including recording new data and driving update of display.
// This version is only for calling from the main event loop
// (see also GoUpdate).
func (vu *NetViewUpdate) Update(counters string) {
func (vu *NetViewUpdate) Update(mode, level enums.Enum) {
if !vu.ShouldUpdate() {
return
}
counters := vu.CounterFunc(mode, level)
vu.View.Record(counters, -1) // -1 = default incrementing raster
vu.View.UpdateView()
}
Expand All @@ -209,12 +214,13 @@ func (vu *NetViewUpdate) Update(counters string) {
// This has different logic for the raster view vs. regular.
// This is only for calling from a separate goroutine,
// not the main event loop.
func (vu *NetViewUpdate) UpdateWhenStopped() {
func (vu *NetViewUpdate) UpdateWhenStopped(mode, level enums.Enum) {
if !vu.ShouldUpdate() {
return
}
if !vu.View.Options.Raster.On { // always record when not in raster mode
vu.View.Record("", -1) // -1 = use a dummy counter
counters := vu.CounterFunc(mode, level)
vu.View.Record(counters, -1) // -1 = use a dummy counter
}
vu.View.GoUpdateView()
}
Expand Down
4 changes: 2 additions & 2 deletions examples/deep_fsa/fsa_env.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ func (ev *FSAEnv) Init(run int) {
ev.Trial.Init()
ev.Trial.Cur = -1 // init state -- key so that first Step() = 0
ev.AState.Cur = 0
ev.AState.Prv = -1
ev.AState.Prev = -1
}

// NextState sets NextStates including randomly chosen one at start
Expand Down Expand Up @@ -167,7 +167,7 @@ func (ev *FSAEnv) Step() bool {
ev.NextState()
ev.Trial.Incr()
ev.Tick.Incr()
if ev.AState.Prv == 0 {
if ev.AState.Prev == 0 {
ev.Tick.Init()
ev.Seq.Incr()
}
Expand Down
134 changes: 63 additions & 71 deletions examples/ra25/ra25.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import (

func main() {
opts := cli.DefaultOptions("ra25", "Random associator.")
opts.DefaultFiles = append(opts.DefaultFiles, "config.toml")
cfg := &Config{}
cli.Run(opts, cfg, RunSim)
}
Expand Down Expand Up @@ -190,48 +191,48 @@ type Sim struct {
// simulation configuration parameters -- set by .toml config file and / or args
Config *Config `new-window:"+"`

// the network -- click to view / edit parameters for layers, paths, etc
// Net is the network: click to view / edit parameters for layers, paths, etc.
Net *axon.Network `new-window:"+" display:"no-inline"`

// network parameter management
// Params manages network parameter setting.
Params axon.Params

// contains looper control loops for running sim
// Loops are the the control loops for running the sim, in different Modes
// across stacks of Levels.
Loops *looper.Stacks `new-window:"+" display:"no-inline"`

// the training patterns to use
Pats *tensorfs.Node `display:"-"`

// Environments
// Envs provides mode-string based storage of environments.
Envs env.Envs `new-window:"+" display:"no-inline"`

// train mode netview update parameters
// TrainUpdate has Train mode netview update parameters.
TrainUpdate axon.NetViewUpdate `display:"inline"`

// test mode netview update parameters
// TestUpdate has Test mode netview update parameters.
TestUpdate axon.NetViewUpdate `display:"inline"`

// Root is the root data dir.
// Root is the root tensorfs directory, where all stats and other misc sim data goes.
Root *tensorfs.Node `display:"-"`

// StatFuncs are statistics functions, per stat, handles everything.
StatFuncs []func(mode Modes, level Levels, phase StatsPhase) `display:"-"`

// Stats has the stats dir.
// Stats has the stats directory within Root.
Stats *tensorfs.Node `display:"-"`

// Current has the current stats values.
// Current has the current stats values within Stats.
Current *tensorfs.Node `display:"-"`

// manages all the gui elements
// StatFuncs are statistics functions called at given mode and level,
// to perform all stats computations. phase = Start does init at start of given level,
// and all intialization / configuration (called during Init too).
StatFuncs []func(mode Modes, level Levels, phase StatsPhase) `display:"-"`

// GUI manages all the GUI elements
GUI egui.GUI `display:"-"`

// a list of random seeds to use for each run
// RandSeeds is a list of random seeds to use for each run.
RandSeeds randx.Seeds `display:"-"`
}

// RunSim runs the simulation.
func RunSim(cfg *Config) error { //cli:cmd -root
// RunSim runs the simulation with given configuration.
func RunSim(cfg *Config) error {
sim := &Sim{}
sim.Config = cfg
sim.Run()
Expand Down Expand Up @@ -277,7 +278,16 @@ func (ss *Sim) ConfigEnv() {
tst = ss.Envs.ByMode(Test).(*env.FixedTable)
}

pats := tensorfs.DirTable(ss.Pats, nil)
pats := tensorfs.DirTable(ss.Root.RecycleDir("Pats"), nil)

// this logic can be used to create train-test splits of a set of patterns:
// n := pats.NumRows()
// order := rand.Perm(n)
// ntrn := int(0.85 * float64(n))
// trnEnv := table.NewView(pats)
// tstEnv := table.NewView(pats)
// trnEnv.Indexes = order[:ntrn]
// tstEnv.Indexes = order[ntrn:]

// note: names must be standard here!
trn.Name = Train.String()
Expand All @@ -289,12 +299,6 @@ func (ss *Sim) ConfigEnv() {
tst.Sequential = true
tst.Validate()

// note: to create a train / test split of pats, do this:
// all := table.NewIndexView(ss.Pats)
// splits, _ := split.Permuted(all, []float64{.8, .2}, []string{"Train", "Test"})
// trn.Table = splits.Splits[0]
// tst.Table = splits.Splits[1]

trn.Init(0)
tst.Init(0)

Expand Down Expand Up @@ -325,7 +329,7 @@ func (ss *Sim) ConfigNet(net *axon.Network) {
// net.LateralConnectLayerPath(hid1, full, &axon.HebbPath{}).SetType(InhibPath)

// note: if you wanted to change a layer type from e.g., Target to Compare, do this:
// out.SetType(emer.Compare)
// out.Type = axon.CompareLayer
// that would mean that the output layer doesn't reflect target values in plus phase
// and thus removes error-driven learning -- but stats are still computed.

Expand All @@ -338,9 +342,6 @@ func (ss *Sim) ConfigNet(net *axon.Network) {

func (ss *Sim) ApplyParams() {
ss.Params.ApplyAll(ss.Net)
// if ss.Config.Params.Network != nil {
// ss.Params.SetNetworkMap(ss.Net, ss.Config.Params.Network)
// }
}

//////// Init, utils
Expand All @@ -359,7 +360,8 @@ func (ss *Sim) Init() {
ss.InitStats()
ss.NewRun()
ss.TrainUpdate.RecordSyns()
ss.TrainUpdate.Update(ss.StatCounters(Train, Trial))
// todo: need to pass the counters function here, instead of calling each time.
ss.TrainUpdate.Update(Train, Trial)
}

// InitRandSeed initializes the random seed based on current training run number
Expand Down Expand Up @@ -388,20 +390,22 @@ func (ss *Sim) NetViewUpdater(mode enums.Enum) *axon.NetViewUpdate {
func (ss *Sim) ConfigLoops() {
ls := looper.NewStacks()

trls := int(math32.IntMultipleGE(float32(ss.Config.Run.NTrials), float32(ss.Config.Run.NData)))
trials := int(math32.IntMultipleGE(float32(ss.Config.Run.NTrials), float32(ss.Config.Run.NData)))
cycles := 200
plusPhase := 50

ls.AddStack(Train, Trial).
AddLevel(Run, ss.Config.Run.NRuns).
AddLevel(Epoch, ss.Config.Run.NEpochs).
AddLevelIncr(Trial, trls, ss.Config.Run.NData).
AddLevel(Cycle, 200)
AddLevelIncr(Trial, trials, ss.Config.Run.NData).
AddLevel(Cycle, cycles)

ls.AddStack(Test, Trial).
AddLevel(Epoch, 1).
AddLevelIncr(Trial, trls, ss.Config.Run.NData).
AddLevel(Cycle, 200)
AddLevelIncr(Trial, trials, ss.Config.Run.NData).
AddLevel(Cycle, cycles)

axon.LooperStandard(ls, ss.Net, ss.NetViewUpdater, 50, 150, 199, Cycle, Trial, Train)
axon.LooperStandard(ls, ss.Net, ss.NetViewUpdater, 50, cycles-plusPhase, cycles-1, Cycle, Trial, Train)

ls.Stacks[Train].OnInit.Add("Init", func() { ss.Init() })

Expand Down Expand Up @@ -430,19 +434,14 @@ func (ss *Sim) ConfigLoops() {
}
})

//////// Stats

ls.AddOnStartToAll("StatsStart", ss.StatsStart)
ls.AddOnEndToAll("StatsStep", ss.StatsStep)

// Save weights to file, to look at later
ls.Loop(Train, Run).OnEnd.Add("SaveWeights", func() {
ctrString := fmt.Sprintf("%03d_%05d", ls.Loop(Train, Run).Counter.Cur, ls.Loop(Train, Epoch).Counter.Cur)
axon.SaveWeightsIfConfigSet(ss.Net, ss.Config.Log.SaveWeights, ctrString, ss.RunName())
})

//////// GUI

if ss.Config.GUI {
axon.LooperUpdateNetView(ls, Cycle, Trial, ss.NetViewUpdater, ss.StatCounters)

Expand All @@ -456,41 +455,36 @@ func (ss *Sim) ConfigLoops() {
ss.Loops = ls
}

// ApplyInputs applies input patterns from given environment.
// It is good practice to have this be a separate method with appropriate
// args so that it can be used for various different contexts
// (training, testing, etc).
// ApplyInputs applies input patterns from given environment for given mode.
// Any other start-of-trial logic can also be put here.
func (ss *Sim) ApplyInputs(mode Modes) {
net := ss.Net
ctx := net.Context()
ndata := int(ctx.NData)
ndata := int(net.Context().NData)
curModeDir := ss.Current.RecycleDir(mode.String())
ev := ss.Envs.ByMode(mode).(*env.FixedTable)
ev := ss.Envs.ByMode(mode)
lays := net.LayersByType(axon.InputLayer, axon.TargetLayer)
net.InitExt()
for di := range ndata {
ev.Step()
tensorfs.Value[string](curModeDir, "TrialName", ndata).SetString1D(ev.TrialName.Cur, di)
tensorfs.Value[string](curModeDir, "TrialName", ndata).SetString1D(ev.String(), di)
for _, lnm := range lays {
ly := ss.Net.LayerByName(lnm)
pats := ev.State(ly.Name)
if pats != nil {
ly.ApplyExt(uint32(di), pats)
st := ev.State(ly.Name)
if st != nil {
ly.ApplyExt(uint32(di), st)
}
}
}
net.ApplyExts() // now required for GPU mode
net.ApplyExts()
}

// NewRun intializes a new run of the model, using the TrainEnv.Run counter
// for the new run value
// NewRun intializes a new Run level of the model.
func (ss *Sim) NewRun() {
ctx := ss.Net.Context()
ss.InitRandSeed(ss.Loops.Loop(Train, Run).Counter.Cur)
ss.Envs.ByMode(Train).Init(0)
ss.Envs.ByMode(Test).Init(0)
ctx.Reset()
ctx.Mode = int32(Train)
ss.Net.InitWeights()
if ss.Config.Run.StartWts != "" { // this is just for testing -- not usually needed
ss.Net.OpenWeightsJSON(core.Filename(ss.Config.Run.StartWts))
Expand All @@ -502,10 +496,10 @@ func (ss *Sim) NewRun() {
func (ss *Sim) TestAll() {
ss.Envs.ByMode(Test).Init(0)
ss.Loops.ResetAndRun(Test)
ss.Loops.Mode = Train // Important to reset Mode back to Train because this is called from within the Train Run.
ss.Loops.Mode = Train // important because this is called from Train Run: go back.
}

//////// Pats
//////// Patterns

func (ss *Sim) ConfigPats() {
dt := table.New()
Expand All @@ -520,17 +514,15 @@ func (ss *Sim) ConfigPats() {
patgen.PermutedBinaryMinDiff(dt.ColumnByIndex(2).Tensor.(*tensor.Float32), 6, 1, 0, 3)
dt.SaveCSV("random_5x5_25_gen.tsv", tensor.Tab, table.Headers)

ss.Pats = ss.Root.RecycleDir("Pats")
tensorfs.DirFromTable(ss.Pats, dt)
tensorfs.DirFromTable(ss.Root.RecycleDir("Pats"), dt)
}

func (ss *Sim) OpenPats() {
dt := table.New()
metadata.SetName(dt, "TrainPats")
metadata.SetDoc(dt, "Training patterns")
errors.Log(dt.OpenCSV("random_5x5_25.tsv", tensor.Tab))
ss.Pats = ss.Root.RecycleDir("Pats")
tensorfs.DirFromTable(ss.Pats, dt)
tensorfs.DirFromTable(ss.Root.RecycleDir("Pats"), dt)
}

//////// Stats
Expand Down Expand Up @@ -794,12 +786,11 @@ func (ss *Sim) StatCounters(mode, level enums.Enum) string {
di := vu.View.Di
counters += fmt.Sprintf(" Di: %d", di)
curModeDir := ss.Current.RecycleDir(mode.String())
trialName := tensorfs.Value[string](curModeDir, "TrialName")
if trialName.Len() == 0 {
if curModeDir.Node("TrialName") == nil {
return counters
}
counters += fmt.Sprintf(" TrialName: %s", trialName.String1D(di))
if curModeDir.Node("UnitErr") == nil {
counters += fmt.Sprintf(" TrialName: %s", tensorfs.Value[string](curModeDir, "TrialName").String1D(di))
if level == Cycle || curModeDir.Node("UnitErr") == nil {
return counters
}
counters += fmt.Sprintf(" UnitErr: %g", tensorfs.Value[float64](curModeDir, "UnitErr").Float1D(di))
Expand All @@ -819,10 +810,11 @@ func (ss *Sim) ConfigGUI() {
nv := ss.GUI.AddNetView("Network")
nv.Options.MaxRecs = 300
nv.SetNet(ss.Net)
ss.TrainUpdate.Config(nv, axon.Phase)
ss.TestUpdate.Config(nv, axon.Phase)
ss.GUI.OnStop = func() {
ss.TrainUpdate.UpdateWhenStopped()
ss.TrainUpdate.Config(nv, axon.Phase, ss.StatCounters)
ss.TestUpdate.Config(nv, axon.Phase, ss.StatCounters)
ss.GUI.OnStop = func(mode, level enums.Enum) {
vu := ss.NetViewUpdater(mode)
vu.UpdateWhenStopped(mode, level) // todo: carry this all the way through
}

nv.SceneXYZ().Camera.Pose.Pos.Set(0, 1, 2.75) // more "head on" than default which is more "top down"
Expand Down
Loading

0 comments on commit 7a41f99

Please sign in to comment.