Skip to content

Commit

Permalink
choice added GroupMinMax option for organizing min / max distances in…
Browse files Browse the repository at this point in the history
… groups -- provides clearest test case for effort-based choice -- not learning enough about diff CSs b/c BLA is US!
  • Loading branch information
rcoreilly committed Apr 21, 2024
1 parent f7615a9 commit 7a7ad13
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 15 deletions.
8 changes: 6 additions & 2 deletions examples/choose/armaze/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,11 @@ type Config struct {
// number of different arms
NArms int

// maximum arm length (distance)
MaxArmLength int
// minimum arm length (distance) range (inclusive)
ArmLengths minmax.Int

// If true, group arms by minimum vs. maximum lengths -- provides a better test when using the always left strategy
GroupMinMax bool

// number of different CSs -- typically at least a unique CS per US -- relationship is determined in the US params
NCSs int
Expand All @@ -75,6 +78,7 @@ func (cfg *Config) Defaults() {
if cfg.NDrives == 0 {
cfg.NDrives = 4
}
cfg.ArmLengths.Set(4, 4)
cfg.Update()
if cfg.NCSs == 0 {
cfg.NCSs = cfg.NUSs
Expand Down
41 changes: 32 additions & 9 deletions examples/choose/armaze/maze.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,28 +172,29 @@ func (ev *Env) ConfigEnv(di int) {
cfg.Update()

ev.Drives = make([]float32, cfg.NDrives)
ev.Config.USs = make([]*USParams, cfg.NUSs)
ev.Config.Arms = make([]*Arm, cfg.NArms)
cfg.USs = make([]*USParams, cfg.NUSs)
cfg.Arms = make([]*Arm, cfg.NArms)

// log.Printf("drives: %d, USs: %d, CSs: %d", cfg.NDrives, cfg.NUSs, cfg.NCSs)
// log.Printf("max arm length: %d", cfg.MaxArmLength)

// defaults
for i := range ev.Config.Arms {
for i := range cfg.Arms {
// TODO: if we permute CSs do we also want to keep the USs aligned?
length := 4
if ev.Config.MaxArmLength > 0 {
length = length + ev.Rand.Intn(ev.Config.MaxArmLength, -1)
length := cfg.ArmLengths.Min
lrng := cfg.ArmLengths.Range()
if lrng > 0 {
length += ev.Rand.Intn(lrng, -1)
}
arm := &Arm{Length: length, CS: i % cfg.NCSs, US: i % cfg.NUSs}
ev.Config.Arms[i] = arm
cfg.Arms[i] = arm
arm.Effort.Set(1, 1)
}

// defaults
for i := range ev.Config.USs {
for i := range cfg.USs {
us := &USParams{Prob: 1}
ev.Config.USs[i] = us
cfg.USs[i] = us
if i < cfg.NDrives {
us.Negative = false
} else {
Expand All @@ -202,9 +203,31 @@ func (ev *Env) ConfigEnv(di int) {
us.Mag.Set(1, 1)
}

if cfg.GroupMinMax {
ev.ConfigGroupMinMax()
}

ev.UpdateMaxLength()
}

func (ev *Env) ConfigGroupMinMax() {
cfg := &ev.Config
// defaults
// narms := cfg.NArms
// nalts := narms / cfg.NUSs
for i, arm := range cfg.Arms {
ci := i / cfg.NUSs
ui := i % cfg.NUSs
arm.CS = i
arm.US = ui
if ci%2 == 0 {
arm.Length = cfg.ArmLengths.Max
} else {
arm.Length = cfg.ArmLengths.Min
}
}
}

func (ev *Env) Validate() error {
return nil
}
Expand Down
2 changes: 1 addition & 1 deletion examples/choose/choose.go
Original file line number Diff line number Diff line change
Expand Up @@ -938,7 +938,7 @@ func (ss *Sim) ConfigLogs() {
// todo: PCA items should apply to CT layers too -- pass a type here.
axon.LogAddPCAItems(&ss.Logs, ss.Net, etime.Train, etime.Run, etime.Epoch, etime.Trial)

ss.Logs.PlotItems("ActMatch", "GateCS", "Deciding", "GateUS", "WrongCSGate", "Rew_R", "RewPred_R", "DA_R", "RewPred_NR", "DA_NR", "MaintEarly")
ss.Logs.PlotItems("GateCS", "GateUS", "WrongCSGate", "Rew_R", "RewPred_R", "DA_R", "MaintEarly")

ss.Logs.CreateTables()
ss.Logs.SetContext(&ss.Stats, ss.Net)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,7 @@
NArms = 8
NDrives = 4
NCSs = 8
MaxArmLength = 8
ArmLengths.Min = 4
ArmLengths.Max = 8
GroupMinMax = true

2 changes: 1 addition & 1 deletion examples/choose/params.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ var ParamSets = netparams.Sets{
Params: params.Params{
"Prjn.PrjnScale.Abs": "3", // 4 > 3 > 2 -- key for rapid learning
"Prjn.Learn.Trace.LearnThr": "0",
"Prjn.Learn.LRate.Base": "0.02", // 0.02 needed in test
"Prjn.Learn.LRate.Base": "0.05", // 0.02 needed in test
}},
{Sel: "#CSToBLAposAcqD1", Desc: "",
Params: params.Params{
Expand Down
4 changes: 3 additions & 1 deletion examples/pvlv/cond/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,6 @@ func (ev *CondEnv) RenderSequence(trli, tick int) {
ev.CurTick.USOn = false
if trl.USOn && (tick >= trl.USStart) && (tick <= trl.USEnd) {
ev.CurTick.USOn = true
ev.CurTick.Type = US
if trl.Valence == Pos {
SetUS(ev.CurStates["USpos"], ev.NYReps, trl.US, trl.USMag)
ev.SequenceName += fmt.Sprintf("_Pos%d", trl.US)
Expand All @@ -261,6 +260,9 @@ func (ev *CondEnv) RenderSequence(trli, tick int) {
ev.SequenceName += fmt.Sprintf("_Neg%d", trl.US)
}
}
if (tick >= trl.USStart) && (tick <= trl.USEnd) {
ev.CurTick.Type = US // even if not on, this is the type
}
if tick > trl.USEnd {
ev.CurTick.Type = Post
}
Expand Down

0 comments on commit 7a7ad13

Please sign in to comment.