Skip to content

Commit

Permalink
minus phase fully on GPU. not getting any state back until plus phase…
Browse files Browse the repository at this point in the history
… now.
  • Loading branch information
rcoreilly committed Nov 25, 2024
1 parent 176abe6 commit f98952d
Show file tree
Hide file tree
Showing 13 changed files with 294 additions and 224 deletions.
43 changes: 32 additions & 11 deletions axon/act-layer.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

43 changes: 32 additions & 11 deletions axon/act-layer.goal
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,8 @@ func (ly *LayerParams) CyclePost(ctx *Context, di uint32) {
lpi := ly.PoolIndex(0)
ly.CyclePostLayer(ctx, lpi, di)
switch ly.Type {
case MatrixLayer, BGThalLayer:
ly.GatedFromSpkMax(ctx, di)
case CeMLayer:
ly.CyclePostCeMLayer(ctx, lpi, di)
case VSPatchLayer:
Expand Down Expand Up @@ -849,6 +851,24 @@ func (ly *LayerParams) CyclePostVSPatchLayer(ctx *Context, pi, di uint32, spi in

//////// Phase timescale

// DecayStateNeuronsAll decays neural activation state by given proportion
// (default decay values are ly.Params.Acts.Decay.Act, Glong, AHP)
// for all data parallel indexes. Does not decay pool or layer state.
// This is used for minus phase of Pulvinar layers to clear state in prep
// for driver plus phase.
func (ly *LayerParams) DecayStateNeuronsAll(ctx *Context, decay, glong, ahp float32) {
nn := ly.Indexes.NNeurons
for lni := uint32(0); lni < nn; lni++ {
ni := ly.Indexes.NeurSt + lni
if NeuronIsOff(ni) {
continue
}
for di := uint32(0); di < ctx.NData; di++ {
ly.Acts.DecayState(ctx, ni, di, decay, glong, ahp)
}
}
}

// NewStateLayer does NewState at the layer level, called
func (ly *LayerParams) NewStateLayer(ctx *Context) {
actMinusAvg := float32(0)
Expand Down Expand Up @@ -950,6 +970,17 @@ func (ly *LayerParams) MinusPhaseNeuron(ctx *Context, ni, di uint32) {
Neurons[ni, di, CaSpkPM] = Neurons[ni, di, CaSpkP]
}

// MinusPhasePost does special algorithm processing at end of minus
func (ly *LayerParams) MinusPhasePost(ctx *Context) {
switch ly.Type {
case MatrixLayer:
ly.MatrixGated(ctx) // need gated state for decisions about action processing, so do in minus too
case PulvinarLayer:
ly.DecayStateNeuronsAll(ctx, 1, 1, 0)
default:
}
}

// PlusPhaseStartNeuron does neuron level plus-phase start:
// applies Target inputs as External inputs.
func (ly *LayerParams) PlusPhaseStartNeuron(ctx *Context, ni, di uint32) {
Expand Down Expand Up @@ -1195,16 +1226,6 @@ func (ly *Layer) UpdateExtFlags(ctx *Context) {
}
}

// MinusPhasePost does special algorithm processing at end of minus
func (ly *Layer) MinusPhasePost(ctx *Context) {
switch ly.Type {
case MatrixLayer:
ly.MatrixGated(ctx) // need gated state for decisions about action processing, so do in minus too
case PulvinarLayer:
ly.DecayStateNeuronsAll(ctx, 1, 1, 0)
}
}

// PlusPhasePost does special algorithm processing at end of plus
func (ly *Layer) PlusPhasePost(ctx *Context) {
ly.PlusPhaseActAvg(ctx)
Expand Down Expand Up @@ -1235,7 +1256,7 @@ func (ly *Layer) PlusPhasePost(ctx *Context) {
}
switch ly.Type {
case MatrixLayer:
ly.MatrixGated(ctx)
ly.Params.MatrixGated(ctx)
}
}

Expand Down
27 changes: 7 additions & 20 deletions axon/act-net.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

27 changes: 7 additions & 20 deletions axon/act-net.goal
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ func (nt *Network) Cycle(ncyc int, getNeurons bool) {

if getNeurons {
RunDoneLayersNeurons()
} else {
RunDoneLayers()
}

// todo: fix this:
Expand Down Expand Up @@ -99,7 +97,6 @@ func (nt *Network) ApplyExts() {
ctx := nt.Context()
nd := int(nix.NNeurons * ctx.NData)
RunApplyExtsNeuron(nd)
// note: not completed until cycle is run
}

// MinusPhase does updating after end of minus phase.
Expand All @@ -110,22 +107,7 @@ func (nt *Network) MinusPhase() {
pd := int(nix.NPools * ctx.NData)
RunMinusPhasePool(pd)
RunMinusPhaseNeuron(nd)
RunDoneLayersNeurons()
nt.MinusPhasePost()
ToGPULayersNeurons()
// todo:
// nt.GPU.SyncStateToGPU()
}

// MinusPhasePost does special CPU post processing.
func (nt *Network) MinusPhasePost() {
ctx := nt.Context()
for _, ly := range nt.Layers {
if ly.Off {
continue
}
ly.MinusPhasePost(ctx)
}
RunMinusPhasePost(int(nix.NLayers))
}

// PlusPhaseStart does updating at the start of the plus phase:
Expand All @@ -135,7 +117,6 @@ func (nt *Network) PlusPhaseStart() {
ctx := nt.Context()
nd := int(nix.NNeurons * ctx.NData)
RunPlusPhaseStartNeuron(nd)
RunDone()
}

// PlusPhase does updating after end of plus phase
Expand Down Expand Up @@ -351,6 +332,12 @@ func MinusPhaseNeuron(i uint32) { //gosl:kernel
Layers[li].MinusPhaseNeuron(ctx, ni, di)
}

// MinusPhasePost does special algorithm post processing.
func MinusPhasePost(li uint32) { //gosl:kernel
ctx := GetCtx(0)
Layers[li].MinusPhasePost(ctx)
}

// PlusPhaseStartNeuron is the kernel over Neurons * Data to
// do neuron-level updating at start of plus phase.
func PlusPhaseStartNeuron(i uint32) { //gosl:kernel
Expand Down
43 changes: 43 additions & 0 deletions axon/gosl.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 0 additions & 18 deletions axon/init-layer.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 0 additions & 18 deletions axon/init-layer.goal
Original file line number Diff line number Diff line change
Expand Up @@ -343,22 +343,4 @@ func (ly *Layer) DecayStatePool(ctx *Context, pool int, decay, glong, ahp float3
}
}

// DecayStateNeuronsAll decays neural activation state by given proportion
// (default decay values are ly.Params.Acts.Decay.Act, Glong, AHP)
// for all data parallel indexes. Does not decay pool or layer state.
// This is used for minus phase of Pulvinar layers to clear state in prep
// for driver plus phase.
func (ly *Layer) DecayStateNeuronsAll(ctx *Context, decay, glong, ahp float32) {
nn := ly.NNeurons
for lni := uint32(0); lni < nn; lni++ {
ni := ly.NeurStIndex + lni
if NeuronIsOff(ni) {
continue
}
for di := uint32(0); di < ctx.NData; di++ {
ly.Params.Acts.DecayState(ctx, ni, di, decay, glong, ahp)
}
}
}


Loading

0 comments on commit f98952d

Please sign in to comment.