Skip to content

Commit

Permalink
complete rewrite of compute code to use GPU indexing consistently on …
Browse files Browse the repository at this point in the history
…CPU -- same code now.
  • Loading branch information
rcoreilly committed Oct 10, 2024
1 parent 64ab03d commit df9554b
Show file tree
Hide file tree
Showing 33 changed files with 1,476 additions and 1,740 deletions.
File renamed without changes.
File renamed without changes.
12 changes: 4 additions & 8 deletions axon/act.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 4 additions & 8 deletions axon/act.goal
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
// act.go contains the activation params and functions for axon

//gosl:start
//gosl:import chans
//gosl:import "github.com/emer/axon/v2/chans"

//////////////////////////////////////////////////////////////////////////////////////
// SpikeParams
Expand Down Expand Up @@ -212,8 +212,6 @@ func (ai *ActInitParams) Defaults() {
ai.GiVar = 0
}

//gosl:end

// GeBase returns the baseline Ge value: Ge + rand(GeVar) > 0
func (ai *ActInitParams) GetGeBase(rnd randx.Rand) float32 {
ge := ai.GeBase
Expand Down Expand Up @@ -458,7 +456,7 @@ func (an *SpikeNoiseParams) ShouldDisplay(field string) bool {
// PGe updates the GeNoiseP probability, multiplying a uniform random number [0-1]
// and returns Ge from spiking if a spike is triggered
func (an *SpikeNoiseParams) PGe(ctx *Context, p *float32, ni, di uint32) float32 {
ndi := di*ctx.NetIndexes.NNeurons + ni
ndi := di*NetIxs().NNeurons + ni
*p *= GetRandomNumber(ndi, ctx.RandCtr, RandFunActPGe)
if *p <= an.GeExpInt {
*p = 1
Expand All @@ -470,7 +468,7 @@ func (an *SpikeNoiseParams) PGe(ctx *Context, p *float32, ni, di uint32) float32
// PGi updates the GiNoiseP probability, multiplying a uniform random number [0-1]
// and returns Gi from spiking if a spike is triggered
func (an *SpikeNoiseParams) PGi(ctx *Context, p *float32, ni, di uint32) float32 {
ndi := di*ctx.NetIndexes.NNeurons + ni
ndi := di*NetIxs().NNeurons + ni
*p *= GetRandomNumber(ndi, ctx.RandCtr, RandFunActPGi)
if *p <= an.GiExpInt {
*p = 1
Expand Down Expand Up @@ -937,8 +935,6 @@ func (ac *ActParams) DecayState(ctx *Context, ni, di uint32, decay, glong, ahp f
Neurons[CtxtGeOrig, ni, di] -= glong * Neurons[CtxtGeOrig, ni, di]
}

//gosl:end

// InitActs initializes activation state in neuron -- called during InitWeights but otherwise not
// automatically called (DecayState is used instead)
func (ac *ActParams) InitActs(ctx *Context, ni, di uint32) {
Expand Down Expand Up @@ -1074,7 +1070,7 @@ func (ac *ActParams) SMaintFromISI(ctx *Context, ni, di uint32) {
if isi < ac.SMaint.ISI.Min || isi > ac.SMaint.ISI.Max {
return
}
ndi := di*ctx.NetIndexes.NNeurons + ni
ndi := di*NetIxs().NNeurons + ni
smp := Neurons[SMaintP, ni, di]
smp *= GetRandomNumber(ndi, ctx.RandCtr, RandFunActSMaintP)
trg := ac.SMaint.ExpInt(isi)
Expand Down
6 changes: 3 additions & 3 deletions axon/basic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1416,23 +1416,23 @@ func TestGlobalIndexes(t *testing.T) {
for vv := GvCost; vv <= GvCostRaw; vv++ {
for ui := uint32(0); ui < pv.NCosts; ui++ {
for di := uint32(0); di < nData; di++ {
SetGlbCostV(ctx, di, vv, ui, val)
GlobalVectors[vv, ui, di] = val
val += 1
}
}
}
for vv := GvUSneg; vv <= GvUSnegRaw; vv++ {
for ui := uint32(0); ui < pv.NNegUSs; ui++ {
for di := uint32(0); di < nData; di++ {
SetGlbUSnegV(ctx, di, vv, ui, val)
GlobalVectors[vv, ui, di] = val
val += 1
}
}
}
for vv := GvDrives; vv < GlobalVarsN; vv++ {
for ui := uint32(0); ui < pv.NPosUSs; ui++ {
for di := uint32(0); di < nData; di++ {
SetGlbUSposV(ctx, di, vv, ui, val)
GlobalVectors[vv, ui, di] = val
val += 1
}
}
Expand Down
108 changes: 5 additions & 103 deletions axon/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,112 +9,18 @@ import (
"github.com/emer/emergent/v2/etime"
)

// CopyNetStridesFrom copies strides and NetIndexes for accessing
// variables on a Network -- these must be set properly for
// the Network in question (from its Ctx field) before calling
// any compute methods with the context. See SetCtxStrides on Network.
func (ctx *Context) CopyNetStridesFrom(srcCtx *Context) {
ctx.NetIndexes = srcCtx.NetIndexes
}

//gosl:start

// NetIndexes are indexes and sizes for processing network
type NetIndexes struct {

// number of data parallel items to process currently
NData uint32 `min:"1"`

// network index in global Networks list of networks
NetIndex uint32 `edit:"-"`

// maximum amount of data parallel
MaxData uint32 `edit:"-"`

// number of layers in the network
NLayers uint32 `edit:"-"`

// total number of neurons
NNeurons uint32 `edit:"-"`

// total number of pools excluding * MaxData factor
NPools uint32 `edit:"-"`

// total number of synapses
NSyns uint32 `edit:"-"`

// maximum size in float32 (4 bytes) of a GPU buffer -- needed for GPU access
GPUMaxBuffFloats uint32 `edit:"-"`

// total number of SynCa banks of GPUMaxBufferBytes arrays in GPU
GPUSynCaBanks uint32 `edit:"-"`

// total number of .Rubicon Drives / positive USs
RubiconNPosUSs uint32 `edit:"-"`

// total number of .Rubicon Costs
RubiconNCosts uint32 `edit:"-"`

// total number of .Rubicon Negative USs
RubiconNNegUSs uint32 `edit:"-"`
}

// ValuesIndex returns the global network index for LayerValues
// with given layer index and data parallel index.
func (ctx *NetIndexes) ValuesIndex(li, di uint32) uint32 {
return li*ctx.MaxData + di
}

// ItemIndex returns the main item index from an overall index over NItems * MaxData
// (items = layers, neurons, synapeses)
func (ctx *NetIndexes) ItemIndex(idx uint32) uint32 {
return idx / ctx.MaxData
}

// DataIndex returns the data index from an overall index over N * MaxData
func (ctx *NetIndexes) DataIndex(idx uint32) uint32 {
return idx % ctx.MaxData
}

// DataIndexIsValid returns true if the data index is valid (< NData)
func (ctx *NetIndexes) DataIndexIsValid(li uint32) bool {
return (li < ctx.NData)
}

// LayerIndexIsValid returns true if the layer index is valid (< NLayers)
func (ctx *NetIndexes) LayerIndexIsValid(li uint32) bool {
return (li < ctx.NLayers)
}

// NeurIndexIsValid returns true if the neuron index is valid (< NNeurons)
func (ctx *NetIndexes) NeurIndexIsValid(ni uint32) bool {
return (ni < ctx.NNeurons)
}

// PoolIndexIsValid returns true if the pool index is valid (< NPools)
func (ctx *NetIndexes) PoolIndexIsValid(pi uint32) bool {
return (pi < ctx.NPools)
}

// PoolDataIndexIsValid returns true if the pool*data index is valid (< NPools*MaxData)
func (ctx *NetIndexes) PoolDataIndexIsValid(pi uint32) bool {
return (pi < ctx.NPools*ctx.MaxData)
}

// SynIndexIsValid returns true if the synapse index is valid (< NSyns)
func (ctx *NetIndexes) SynIndexIsValid(si uint32) bool {
return (si < ctx.NSyns)
}

// Context contains all of the global context state info
// that is shared across every step of the computation.
// It is passed around to all relevant computational functions,
// and is updated on the CPU and synced to the GPU after every cycle.
// It is the *only* mechanism for communication from CPU to GPU.
// It contains timing, Testing vs. Training mode, random number context,
// global neuromodulation, etc.
// It contains timing, Testing vs. Training mode, random number context, etc.
type Context struct {

// number of data parallel items to process currently.
NData uint32 `min:"1"`

// current evaluation mode, e.g., Train, Test, etc
Mode etime.Modes

Expand Down Expand Up @@ -162,14 +68,10 @@ type Context struct {
// many are actually used. This is shared across all layers so must
// encompass all possible param settings.
RandCtr uint64

// indexes and sizes of current network
NetIndexes NetIndexes `display:"inline"`
}

// Defaults sets default values
func (ctx *Context) Defaults() {
ctx.NetIndexes.NData = 1
ctx.TimePerCycle = 0.001
ctx.ThetaCycles = 200
ctx.SlowInterval = 100
Expand All @@ -195,7 +97,7 @@ func (ctx *Context) CycleInc() {
// SlowInc increments the Slow counter and returns true if time
// to perform SlowAdapt functions (associated with sleep).
func (ctx *Context) SlowInc() bool {
ctx.SlowCtr += int32(ctx.NetIndexes.NData)
ctx.SlowCtr += int32(ctx.NData)
if ctx.SlowCtr < ctx.SlowInterval {
return false
}
Expand Down
Loading

0 comments on commit df9554b

Please sign in to comment.