pbwm: put GateState in Pool; bg sim is first-pass working.

rcoreilly committed Oct 17, 2024
1 parent 56b86c5 commit cf74bcf
Showing 11 changed files with 84 additions and 55 deletions.
10 changes: 5 additions & 5 deletions leabra/enumgen.go

Generated file; diff not rendered by default.

11 changes: 3 additions & 8 deletions leabra/layer.go
@@ -56,17 +56,13 @@ func (ly *Layer) InitActs() {
 	}
 	for pi := range ly.Pools {
 		pl := &ly.Pools[pi]
-		pl.Inhib.Init()
+		pl.Init()
 		pl.ActM.Init()
 		pl.ActP.Init()
 	}
 	ly.DA = 0
 	ly.ACh = 0
 	ly.SE = 0
-	for si := range ly.GateStates {
-		gs := &ly.GateStates[si]
-		gs.Init()
-	}
 }

 // InitWeightsSym initializes the weight symmetry.
@@ -446,10 +442,9 @@ func (ly *Layer) DecayState(decay float32) {
 }

 // DecayStatePool decays activation state by given proportion
-// in given sub-pool index (0 based).
+// in given pool index (sub pools start at 1).
 func (ly *Layer) DecayStatePool(pool int, decay float32) {
-	pi := int32(pool + 1) // 1 based
-	pl := &ly.Pools[pi]
+	pl := &ly.Pools[pool]
 	for ni := pl.StIndex; ni < pl.EdIndex; ni++ {
 		nrn := &ly.Neurons[ni]
 		if nrn.IsOff() {
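The indexing change above is easy to misread, so here is a minimal, self-contained sketch of the new convention, with stand-in types rather than the real leabra API: Pools[0] is the whole-layer pool, sub-pools occupy indexes 1..n, and callers now pass that Pools index directly instead of a 0-based sub-pool index that DecayStatePool shifted by +1 internally.

package main

import "fmt"

// Sketch only: Pools[0] covers the whole layer and sub-pools occupy
// indexes 1..n, so DecayStatePool takes a direct Pools index
// (sub pools start at 1) with no internal +1 shift.
type Pool struct{ StIndex, EdIndex int }

type Layer struct {
	Pools []Pool
	Acts  []float32
}

func (ly *Layer) DecayStatePool(pool int, decay float32) {
	pl := &ly.Pools[pool]
	for ni := pl.StIndex; ni < pl.EdIndex; ni++ {
		ly.Acts[ni] -= decay * ly.Acts[ni]
	}
}

func main() {
	ly := &Layer{
		Pools: []Pool{{0, 4}, {0, 2}, {2, 4}}, // Pools[0] = whole layer
		Acts:  []float32{1, 1, 1, 1},
	}
	ly.DecayStatePool(1, 0.5) // first sub-pool, passed directly
	fmt.Println(ly.Acts)      // [0.5 0.5 1 1]
}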
16 changes: 6 additions & 10 deletions leabra/layerbase.go
@@ -95,11 +95,6 @@ type Layer struct {
 	// current serotonin level for this layer
 	SE float32 `read-only:"+"`

-	// GateStates (PBWM) is a slice of gating state values for this layer,
-	// one for each separate gating pool, according to its GateType.
-	// For MaintOut, it is ordered such that 0:MaintN are Maint and MaintN:n are Out.
-	GateStates []GateState
-
 	// SendTo is a list of layers that this layer sends special signals to,
 	// which could be dopamine, gating signals, depending on the layer type.
 	SendTo []string
@@ -174,7 +169,7 @@ func (ly *Layer) ShouldDisplay(field string) bool {
 		return ly.Type == RWPredLayer || ly.Type == RWDaLayer
 	case "TD":
 		return ly.Type == TDRewIntegLayer || ly.Type == TDDaLayer
-	case "PBWM", "GateStates":
+	case "PBWM":
 		return isPBWM
 	case "SendTo":
 		return ly.Type == GPiThalLayer || ly.Type == ClampDaLayer || ly.Type == RWDaLayer || ly.Type == TDDaLayer
@@ -184,6 +179,8 @@
 		return ly.Type == GPiThalLayer
 	case "PFCGate", "PFCMaint":
 		return ly.Type == PFCLayer || ly.Type == PFCDeepLayer
+	case "PFCDyns":
+		return ly.Type == PFCDeepLayer
 	default:
 		return true
 	}
@@ -269,7 +266,6 @@ func (ly *Layer) UnitValue1D(varIndex int, idx int, di int) float32 {
 	nrn := &ly.Neurons[idx]
 	da := NeuronVarsMap["DA"]
 	if varIndex >= da {
-		gs := ly.GateStates[int(nrn.SubPool)-1] // 0-based
 		switch varIndex - da {
 		case 0:
 			return ly.DA
@@ -278,11 +274,11 @@ func (ly *Layer) UnitValue1D(varIndex int, idx int, di int) float32 {
 		case 2:
 			return ly.SE
 		case 3:
-			return gs.Act
+			return ly.Pools[nrn.SubPool].Gate.Act
 		case 4:
-			return num.FromBool[float32](gs.Now)
+			return num.FromBool[float32](ly.Pools[nrn.SubPool].Gate.Now)
 		case 5:
-			return float32(gs.Cnt)
+			return float32(ly.Pools[nrn.SubPool].Gate.Cnt)
 		}
 	}
 	return nrn.VarByIndex(varIndex)
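The removed -1 shift works because Neuron.SubPool is evidently 1-based while the old GateStates slice was 0-based; with GateState moved onto Pool, SubPool indexes Pools directly. A minimal sketch of that lookup, again with stand-in types rather than the real leabra structs:

package main

import "fmt"

// Stand-in types only: GateState now lives on Pool, and the 1-based
// Neuron.SubPool indexes Pools directly, with no -1 adjustment.
type GateState struct {
	Act float32 // activation at the time of gating
	Now bool    // whether gating is happening now
	Cnt int     // counter since gating
}

type Pool struct{ Gate GateState }

type Neuron struct{ SubPool int32 }

func main() {
	pools := []Pool{{}, {Gate: GateState{Act: 0.8, Now: true}}}
	nrn := Neuron{SubPool: 1} // first sub-pool; index 0 is the whole layer
	// Old: gs := ly.GateStates[int(nrn.SubPool)-1] // 0-based
	// New: direct indexing, as in UnitValue1D above:
	fmt.Println(pools[nrn.SubPool].Gate.Act) // 0.8
}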
1 change: 1 addition & 0 deletions leabra/networkbase.go
@@ -368,6 +368,7 @@ func (nt *Network) Build() error {
 	var errs []error
 	for li, ly := range nt.Layers {
 		ly.Index = li
+		ly.Network = nt
 		if ly.Off {
 			continue
 		}
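This back-pointer is presumably what guarantees that layer-level gating code can resolve other layers by name during settling: GPiSendGateStates in pbwm_layers.go below looks up its SendTo targets via ly.Network.LayerByName.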
2 changes: 2 additions & 0 deletions leabra/path.go
@@ -280,6 +280,8 @@ func (pt *Path) DWt() {
 		pt.DWtRW()
 	case pt.Type == TDRewPredPath:
 		pt.DWtTDRewPred()
+	case pt.Type == DaHebbPath:
+		pt.DWtDaHebb()
 	default:
 		pt.DWtStd()
 	}
3 changes: 2 additions & 1 deletion leabra/pathbase.go
@@ -145,7 +145,8 @@ func (pt *Path) DefaultsForType() {
 		pt.RWDefaults()
 	case MatrixPath:
 		pt.MatrixDefaults()
-	case GPiThalPath:
+	case DaHebbPath:
+		pt.DaHebbDefaults()
 	}
 }

4 changes: 4 additions & 0 deletions leabra/pathtypes.go
@@ -70,4 +70,8 @@ const (
 	// GPiThalPath accumulates per-path raw conductance that is needed for
 	// separately weighting NoGo vs. Go inputs.
 	GPiThalPath
+
+	// DaHebbPath does dopamine-modulated Hebbian learning -- i.e., the 3-factor
+	// learning rule: Da * Recv.Act * Send.Act
+	DaHebbPath
 )
56 changes: 26 additions & 30 deletions leabra/pbwm_layers.go
@@ -123,12 +123,11 @@ func (ly *Layer) DaAChFromLay(ctx *Context) {
 // RecGateAct records the gating activation from current activation, when gating occurs
 // based on GateState.Now
 func (ly *Layer) RecGateAct(ctx *Context) {
-	for gi := range ly.GateStates {
-		gs := &ly.GateStates[gi]
-		if !gs.Now { // not gating now
+	for pi := range ly.Pools {
+		pl := &ly.Pools[pi]
+		if !pl.Gate.Now { // not gating now
 			continue
 		}
-		pl := &ly.Pools[1+gi]
 		for ni := pl.StIndex; ni < pl.EdIndex; ni++ {
 			nrn := &ly.Neurons[ni]
 			if nrn.IsOff() {
@@ -388,31 +387,25 @@ func (gs *GateState) CopyFrom(fm *GateState) {
 	gs.Now = fm.Now
 }

-// SetGateState sets the GateState for given pool index (individual pools start at 1) on this layer
-func (ly *Layer) SetGateState(poolIndex int, state *GateState) {
-	gs := &ly.GateStates[poolIndex]
-	gs.CopyFrom(state)
-}
-
 // SetGateStates sets the GateStates from given source states, of given gating type
-func (ly *Layer) SetGateStates(states []GateState, typ GateTypes) {
+func (ly *Layer) SetGateStates(src *Layer, typ GateTypes) {
 	myt := ly.PBWM.Type
 	if myt < MaintOut && typ < MaintOut && myt != typ { // mismatch
 		return
 	}
 	switch {
 	case myt == typ:
-		mx := min(len(states), len(ly.GateStates))
-		for i := 0; i < mx; i++ {
-			ly.SetGateState(i, &states[i])
+		mx := min(len(src.Pools), len(ly.Pools))
+		for i := 1; i < mx; i++ {
+			ly.Pool(i).Gate.CopyFrom(&src.Pool(i).Gate)
 		}
 	default: // typ == MaintOut, myt = Maint or Out
-		mx := len(ly.GateStates)
-		for i := 0; i < mx; i++ {
-			gs := &ly.GateStates[i]
+		mx := len(ly.Pools)
+		for i := 1; i < mx; i++ {
+			gs := &ly.Pool(i).Gate
 			si := ly.PBWM.FullIndex1D(i, myt)
-			src := &states[si]
-			gs.CopyFrom(src)
+			sgs := &src.Pool(si).Gate
+			gs.CopyFrom(sgs)
 		}
 	}
 }
@@ -451,7 +444,7 @@ func (ly *Layer) GPiGateFromAct(ctx *Context) {
 		if nrn.IsOff() {
 			continue
 		}
-		gs := ly.GateStates[int(nrn.SubPool)-1]
+		gs := &ly.Pool(int(nrn.SubPool)).Gate
 		if ctx.Quarter == 0 && qtrCyc == 0 {
 			gs.Act = 0 // reset at start
 		}
@@ -482,7 +475,7 @@ func (ly *Layer) GPiSendGateStates() {
 	myt := MaintOut
 	for _, lnm := range ly.SendTo {
 		gl := ly.Network.LayerByName(lnm)
-		gl.SetGateStates(ly.GateStates, myt)
+		gl.SetGateStates(ly, myt)
 	}
 }

@@ -690,20 +683,20 @@ func (ly *Layer) PFCDeepGating(ctx *Context) {
 		}
 	}

-	for gi := range ly.GateStates {
-		gs := &ly.GateStates[gi]
+	for pi := range ly.Pools {
+		gs := &ly.Pools[pi].Gate
 		if !gs.Now { // not gating now
 			continue
 		}
 		if gs.Act > 0 { // use GPiThal threshold, so anything > 0
 			gs.Cnt = 0 // this is the "just gated" signal
 			if ly.PFCGate.OutGate { // time to clear out maint
 				if ly.PFCMaint.OutClearMaint {
-					ly.ClearMaint(gi)
+					ly.ClearMaint(pi)
 				}
 			} else {
 				pfcs := ly.SuperPFC()
-				pfcs.DecayStatePool(gi, ly.PFCMaint.Clear)
+				pfcs.DecayStatePool(pi, ly.PFCMaint.Clear)
 			}
 		}
 	}
 	// test for over-duration maintenance -- allow for active gating to override
@@ -719,8 +712,8 @@ func (ly *Layer) ClearMaint(pool int) {
 	if pfcm == nil {
 		return
 	}
-	gs := &pfcm.GateStates[pool] // 0 based
-	if gs.Cnt >= 1 { // important: only for established maint, not just gated..
+	gs := &pfcm.Pools[pool].Gate
+	if gs.Cnt >= 1 { // important: only for established maint, not just gated..
 		gs.Cnt = -1 // reset
 		pfcs := pfcm.SuperPFC()
 		pfcs.DecayStatePool(pool, pfcm.PFCMaint.Clear)
@@ -758,7 +751,7 @@ func (ly *Layer) DeepMaint(ctx *Context) {
 	uy := ui / xN
 	ux := ui % xN

-	gs := &ly.GateStates[nrn.SubPool-1]
+	gs := &ly.Pool(int(nrn.SubPool)).Gate
 	if gs.Cnt < 0 {
 		nrn.Maint = 0
 		nrn.MaintGe = 0
@@ -781,8 +774,11 @@ func (ly *Layer) UpdateGateCnt(ctx *Context) {
 	if !ly.PFCGate.GateQtr.HasFlag(ctx.Quarter) {
 		return
 	}
-	for gi := range ly.GateStates {
-		gs := &ly.GateStates[gi]
+	for pi := range ly.Pools {
+		if pi == 0 {
+			continue
+		}
+		gs := &ly.Pools[pi].Gate
 		if gs.Cnt < 0 {
 			// ly.ClearCtxtPool(gi)
 			gs.Cnt--
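Taken together, the pbwm_layers.go changes make Pool.Gate the single conduit for gating signals: the GPi layer writes per-pool gating into its own Pools, then broadcasts to its SendTo layers, which copy per sub-pool starting at index 1 because Pools[0] aggregates the whole layer. A simplified, self-contained sketch of that flow (names echo the diff, but the types and the name map are stand-ins):

package main

import "fmt"

// Stand-in types for the gate-broadcast flow in this commit.
type GateState struct {
	Act float32
	Now bool
	Cnt int
}

type Pool struct{ Gate GateState }

type Layer struct {
	Name   string
	Pools  []Pool
	SendTo []string
}

// SetGateStates copies gate state per sub-pool from a source layer,
// skipping Pools[0] (the whole-layer pool), as in the new code.
func (ly *Layer) SetGateStates(src *Layer) {
	mx := min(len(src.Pools), len(ly.Pools))
	for i := 1; i < mx; i++ {
		ly.Pools[i].Gate = src.Pools[i].Gate
	}
}

func main() {
	gpi := &Layer{Name: "GPi", Pools: make([]Pool, 3), SendTo: []string{"PFCmnt"}}
	pfc := &Layer{Name: "PFCmnt", Pools: make([]Pool, 3)}
	byName := map[string]*Layer{"PFCmnt": pfc}

	gpi.Pools[1].Gate = GateState{Act: 0.9, Now: true} // pool 1 gates
	for _, lnm := range gpi.SendTo {
		byName[lnm].SetGateStates(gpi)
	}
	fmt.Println(pfc.Pools[1].Gate) // {0.9 true 0}
}

Keeping the gate data inside Pool removes the parallel GateStates slice that callers previously had to keep aligned with Pools by hand.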
1 change: 0 additions & 1 deletion leabra/pbwm_net.go
@@ -165,6 +165,5 @@ func (nt *Network) AddPBWM(prefix string, nY, nMaint, nOut, nNeurBgY, nNeurBgX,
 		pfcMnt.PlaceAbove(mtxGo)
 	}
 	gpi.SendToMatrixPFC(prefix) // sends gating to all these layers
-	gpi.SendPBWMParams()
 	return
 }
31 changes: 31 additions & 0 deletions leabra/pbwm_paths.go
@@ -156,3 +156,34 @@ func (pt *Path) DWtMatrix() {
 		}
 	}
 }
+
+//////// DaHebbPath
+
+func (pj *Path) DaHebbDefaults() {
+	pj.Learn.WtSig.Gain = 1
+	pj.Learn.Norm.On = false
+	pj.Learn.Momentum.On = false
+	pj.Learn.WtBal.On = false
+}
+
+// DWtDaHebb computes the weight change (learning), for [DaHebbPath].
+func (pj *Path) DWtDaHebb() {
+	slay := pj.Send
+	rlay := pj.Recv
+	for si := range slay.Neurons {
+		sn := &slay.Neurons[si]
+		nc := int(pj.SConN[si])
+		st := int(pj.SConIndexSt[si])
+		syns := pj.Syns[st : st+nc]
+		scons := pj.SConIndex[st : st+nc]
+
+		for ci := range syns {
+			sy := &syns[ci]
+			ri := scons[ci]
+			rn := &rlay.Neurons[ri]
+			da := rn.DALrn
+			dwt := da * rn.Act * sn.Act
+			sy.DWt += pj.Learn.Lrate * dwt
+		}
+	}
+}
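A worked instance of the three-factor rule implemented by DWtDaHebb above, with made-up numbers: the weight change is the learning rate times the receiver's dopamine, the receiver's activation, and the sender's activation, so the sign of dopamine alone decides whether a co-active synapse strengthens or weakens.

package main

import "fmt"

// Illustrative numbers only: dwt = Lrate * Da * Recv.Act * Send.Act,
// as in DWtDaHebb. Positive dopamine strengthens a co-active synapse;
// negative dopamine weakens it by the same amount.
func main() {
	lrate := 0.04 // assumed value; pj.Learn.Lrate in the real code
	da := 0.5     // rn.DALrn
	recvAct := 0.8
	sendAct := 0.6

	fmt.Println(lrate * da * recvAct * sendAct)  // ≈ 0.0096
	fmt.Println(lrate * -da * recvAct * sendAct) // ≈ -0.0096
}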
4 changes: 4 additions & 0 deletions leabra/pool.go
@@ -29,10 +29,14 @@ type Pool struct {

 	// running-average activation levels used for netinput scaling and adaptive inhibition
 	ActAvg ActAvg
+
+	// Gate is gating state for PBWM layers
+	Gate GateState
 }

 func (pl *Pool) Init() {
 	pl.Inhib.Init()
+	pl.Gate.Init()
 }

 // ActAvg are running-average activation levels used for netinput scaling and adaptive inhibition
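Finally, the GateState.Cnt conventions as they can be inferred from their usage across this commit (PFCDeepGating, ClearMaint, UpdateGateCnt); this little classifier is a reading aid, not code from the repository:

package main

import "fmt"

// Inferred Cnt semantics: 0 marks "just gated", >= 1 marks established
// maintenance, and negative values mark a pool that is not maintaining.
func gateCntMeaning(cnt int) string {
	switch {
	case cnt == 0:
		return "just gated this quarter (set in PFCDeepGating)"
	case cnt >= 1:
		return "established maintenance (eligible for ClearMaint)"
	default:
		return "not maintaining; UpdateGateCnt decrements it further"
	}
}

func main() {
	for _, c := range []int{0, 2, -1} {
		fmt.Printf("Cnt=%d: %s\n", c, gateCntMeaning(c))
	}
}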
