diff --git a/src/lava/proc/plateau/models.py b/src/lava/proc/plateau/models.py new file mode 100644 index 000000000..684ff29e5 --- /dev/null +++ b/src/lava/proc/plateau/models.py @@ -0,0 +1,149 @@ +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause +# See: https://spdx.org/licenses/ + + +import numpy as np +from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol +from lava.magma.core.model.py.ports import PyInPort, PyOutPort +from lava.magma.core.model.py.type import LavaPyType +from lava.magma.core.resources import CPU +from lava.magma.core.decorator import implements, requires, tag +from lava.magma.core.model.py.model import PyLoihiProcessModel +from lava.proc.plateau.process import Plateau + + +@implements(proc=Plateau, protocol=LoihiProtocol) +@requires(CPU) +@tag("fixed_pt") +class PyPlateauModelFixed(PyLoihiProcessModel): + """ Implementation of Plateau neuron process in fixed point precision. + + Precisions of state variables + + - dv_dend : unsigned 12-bit integer (0 to 4095) + - dv_soma : unsigned 12-bit integer (0 to 4095) + - vth_dend : unsigned 17-bit integer (0 to 131071) + - vth_soma : unsigned 17-bit integer (0 to 131071) + - up_dur : unsigned 8-bit integer (0 to 255) + """ + + a_dend_in: PyInPort = LavaPyType(PyInPort.VEC_DENSE, np.int16, precision=16) + a_soma_in: PyInPort = LavaPyType(PyInPort.VEC_DENSE, np.int16, precision=16) + s_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, np.int32, precision=24) + v_dend: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=24) + v_soma: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=24) + dv_dend: int = LavaPyType(int, np.uint16, precision=12) + dv_soma: int = LavaPyType(int, np.uint16, precision=12) + vth_dend: int = LavaPyType(int, np.int32, precision=17) + vth_soma: int = LavaPyType(int, np.int32, precision=17) + up_dur: int = LavaPyType(int, np.uint16, precision=8) + up_state: int = LavaPyType(np.ndarray, np.uint16, precision=8) + + def __init__(self, proc_params): + super(PyPlateauModelFixed, self).__init__(proc_params) + self._validate_inputs(proc_params) + self.uv_bitwidth = 24 + self.max_uv_val = 2 ** (self.uv_bitwidth - 1) + self.decay_shift = 12 + self.decay_unity = 2 ** self.decay_shift - 1 + self.vth_shift = 6 + self.act_shift = 6 + self.isthrscaled = False + self.effective_vth_dend = None + self.effective_vth_soma = None + self.s_out_buff = None + + def _validate_var(self, var, var_type, min_val, max_val, var_name): + if not isinstance(var, var_type): + raise ValueError(f"'{var_name}' must have type {var_type}") + if var < min_val or var > max_val: + raise ValueError( + f"'{var_name}' must be in range [{min_val}, {max_val}]" + ) + + def _validate_inputs(self, proc_params): + self._validate_var(proc_params['dv_dend'], int, 0, 4095, 'dv_dend') + self._validate_var(proc_params['dv_soma'], int, 0, 4095, 'dv_soma') + self._validate_var(proc_params['vth_dend'], int, 0, 131071, 'vth_dend') + self._validate_var(proc_params['vth_soma'], int, 0, 131071, 'vth_soma') + self._validate_var(proc_params['up_dur'], int, 0, 255, 'up_dur') + + def scale_threshold(self): + self.effective_vth_dend = np.left_shift(self.vth_dend, self.vth_shift) + self.effective_vth_soma = np.left_shift(self.vth_soma, self.vth_shift) + self.isthrscaled = True + + def subthr_dynamics( + self, + activation_dend_in: np.ndarray, + activation_soma_in: np.ndarray + ): + """Run the sub-threshold dynamics for both the dendrite and soma of the + neuron. Both use 'leaky integration'. + """ + for v, dv, a_in in [ + (self.v_dend, self.dv_dend, activation_dend_in), + (self.v_soma, self.dv_soma, activation_soma_in), + ]: + decayed_volt = np.int64(v) * (self.decay_unity - dv) + decayed_volt = np.sign(decayed_volt) * np.right_shift( + np.abs(decayed_volt), 12 + ) + decayed_volt = np.int32(decayed_volt) + updated_volt = decayed_volt + np.left_shift(a_in, self.act_shift) + + neg_voltage_limit = -np.int32(self.max_uv_val) + 1 + pos_voltage_limit = np.int32(self.max_uv_val) - 1 + + v[:] = np.clip( + updated_volt, neg_voltage_limit, pos_voltage_limit + ) + + def update_up_state(self): + """Decrements the up state (if necessary) and checks v_dend to see if + up state needs to be (re)set. If up state is (re)set, then v_dend is + reset to 0. + """ + self.up_state[self.up_state > 0] -= 1 + self.up_state[self.v_dend > self.effective_vth_dend] = self.up_dur + self.v_dend[self.v_dend > self.effective_vth_dend] = 0 + + def soma_spike_and_reset(self): + """Check the spiking conditions for the plateau soma. Checks if: + v_soma > v_th_soma + up_state > 0 + + For any neurons n that satisfy both conditions, sets: + s_out_buff[n] = True + v_soma = 0 + """ + s_out_buff = np.logical_and( + self.v_soma > self.effective_vth_soma, + self.up_state > 0 + ) + self.v_soma[s_out_buff] = 0 + + return s_out_buff + + def run_spk(self): + """The run function that performs the actual computation during + execution orchestrated by a PyLoihiProcessModel using the + LoihiProtocol. + """ + + # Receive synaptic input + a_dend_in_data = self.a_dend_in.recv() + a_soma_in_data = self.a_soma_in.recv() + + # Check threshold scaling + if not self.isthrscaled: + self.scale_threshold() + + self.subthr_dynamics(a_dend_in_data, a_soma_in_data) + + self.update_up_state() + + self.s_out_buff = self.soma_spike_and_reset() + + self.s_out.send(self.s_out_buff) diff --git a/src/lava/proc/plateau/process.py b/src/lava/proc/plateau/process.py new file mode 100644 index 000000000..83fa04fa8 --- /dev/null +++ b/src/lava/proc/plateau/process.py @@ -0,0 +1,70 @@ +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause +# See: https://spdx.org/licenses/ + + +import typing as ty +from lava.magma.core.process.process import AbstractProcess +from lava.magma.core.process.variable import Var +from lava.magma.core.process.ports.ports import InPort, OutPort + + +class Plateau(AbstractProcess): + """Plateau Neuron Process. + + Couples two modified LIF dynamics. The neuron posesses two voltages, + v_dend and v_soma. Both follow sub-threshold LIF dynamics. When v_dend + crosses v_th_dend, it resets and sets the up_state to the value up_dur. + The supra-threshold behavior of v_soma depends on up_state: + if up_state == 0: + v_soma follows sub-threshold dynamics + if up_state > 0: + v_soma resets and the neuron sends out a spike + + Parameters + ---------- + shape : tuple(int) + Number and topology of Plateau neurons. + dv_dend : int + Inverse of the decay time-constant for the dendrite voltage. + dv_soma : int + Inverse of the decay time-constant for the soma voltage. + vth_dend : int + Dendrite threshold voltage, exceeding which, the neuron will enter the + UP state. + vth_soma : int + Soma threshold voltage, exceeding which, the neuron will spike if it is + also in the UP state. + up_dur : int + The duration, in timesteps, of the UP state. + """ + def __init__( + self, + shape: ty.Tuple[int, ...], + dv_dend: int, + dv_soma: int, + vth_dend: int, + vth_soma: int, + up_dur: int, + name: ty.Optional[str] = None, + ): + super().__init__( + shape=shape, + dv_dend=dv_dend, + dv_soma=dv_soma, + name=name, + up_dur=up_dur, + vth_dend=vth_dend, + vth_soma=vth_soma + ) + self.a_dend_in = InPort(shape=shape) + self.a_soma_in = InPort(shape=shape) + self.s_out = OutPort(shape=shape) + self.v_dend = Var(shape=shape, init=0) + self.v_soma = Var(shape=shape, init=0) + self.dv_dend = Var(shape=(1,), init=dv_dend) + self.dv_soma = Var(shape=(1,), init=dv_soma) + self.vth_dend = Var(shape=(1,), init=vth_dend) + self.vth_soma = Var(shape=(1,), init=vth_soma) + self.up_dur = Var(shape=(1,), init=up_dur) + self.up_state = Var(shape=shape, init=0) diff --git a/tests/lava/proc/plateau/test_models.py b/tests/lava/proc/plateau/test_models.py new file mode 100644 index 000000000..b5635ac1f --- /dev/null +++ b/tests/lava/proc/plateau/test_models.py @@ -0,0 +1,137 @@ +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause +# See: https://spdx.org/licenses/ + + +import unittest +import numpy as np +from lava.proc.plateau.process import Plateau +from lava.proc.dense.process import Dense +from lava.proc.io.source import RingBuffer as Source +from lava.magma.core.run_configs import Loihi2SimCfg +from lava.magma.core.run_conditions import RunSteps +from lava.tests.lava.proc.lif.test_models import VecRecvProcess + + +def create_spike_source(spike_list, n_indices, n_timesteps): + """Use list of spikes [(idx, timestep), ...] to create a RingBuffer source + with data shape (n_indices, n_timesteps) and spikes at all specified points + in the spike_list. + """ + data = np.zeros(shape=(n_indices, n_timesteps)) + for idx, timestep in spike_list: + data[idx, timestep - 1] = 1 + return Source(data=data) + + +class TestPlateauProcessModelsFixed(unittest.TestCase): + """Tests for the fixed point Plateau process models.""" + def test_fixed_max_decay(self): + """ + Tests fixed point Plateau with max voltage decays. + """ + shape = (3,) + num_steps = 20 + spikes_in_dend = [(0, 5), (1, 5), (2, 5)] + spikes_in_soma = [(0, 3), (1, 10), (2, 17)] + sg_dend = create_spike_source(spikes_in_dend, shape[0], num_steps) + sg_soma = create_spike_source(spikes_in_soma, shape[0], num_steps) + dense_dend = Dense(weights=2 * np.diag(np.ones(shape=shape))) + dense_soma = Dense(weights=2 * np.diag(np.ones(shape=shape))) + plat = Plateau( + shape=shape, + dv_dend=4095, + dv_soma=4095, + vth_soma=1, + vth_dend=1, + up_dur=10 + ) + vr = VecRecvProcess(shape=(num_steps, shape[0])) + sg_dend.s_out.connect(dense_dend.s_in) + sg_soma.s_out.connect(dense_soma.s_in) + dense_dend.a_out.connect(plat.a_dend_in) + dense_soma.a_out.connect(plat.a_soma_in) + plat.s_out.connect(vr.s_in) + # run model + plat.run(RunSteps(num_steps), Loihi2SimCfg(select_tag='fixed_pt')) + test_spk_data = vr.spk_data.get() + plat.stop() + # Gold standard for the test + expected_spk_data = np.zeros((num_steps, shape[0])) + # Neuron 2 should spike when receiving soma input + expected_spk_data[10, 1] = 1 + self.assertTrue(np.all(expected_spk_data == test_spk_data)) + + def test_up_dur(self): + """ + Tests that the UP state lasts for the time specified by the model. + Checks that up_state decreases by one each time step after activation. + """ + shape = (1,) + num_steps = 10 + spikes_in_dend = [(0, 3)] + sg_dend = create_spike_source(spikes_in_dend, shape[0], num_steps) + dense_dend = Dense(weights=2 * (np.diag(np.ones(shape=shape)))) + plat = Plateau( + shape=shape, + dv_dend=4095, + dv_soma=4095, + vth_soma=1, + vth_dend=1, + up_dur=5 + ) + sg_dend.s_out.connect(dense_dend.s_in) + dense_dend.a_out.connect(plat.a_dend_in) + # run model + test_up_state = [] + for _ in range(num_steps): + plat.run(RunSteps(1), Loihi2SimCfg(select_tag='fixed_pt')) + test_up_state.append(plat.up_state.get().astype(int)[0]) + plat.stop() + # Gold standard for the test + # UP state active time steps 4 - 9 (5 timesteps) + # this is delayed by one b.c. of the Dense process + expected_up_state = [0, 0, 0, 5, 4, 3, 2, 1, 0, 0] + self.assertListEqual(expected_up_state, test_up_state) + + def test_fixed_dvs(self): + """ + Tests fixed point Plateau voltage decays. + """ + shape = (1,) + num_steps = 10 + spikes_in = [(0, 1)] + sg_dend = create_spike_source(spikes_in, shape[0], num_steps) + sg_soma = create_spike_source(spikes_in, shape[0], num_steps) + dense_dend = Dense(weights=100 * np.diag(np.ones(shape=shape))) + dense_soma = Dense(weights=100 * np.diag(np.ones(shape=shape))) + plat = Plateau( + shape=shape, + dv_dend=2048, + dv_soma=1024, + vth_soma=100, + vth_dend=100, + up_dur=10 + ) + sg_dend.s_out.connect(dense_dend.s_in) + sg_soma.s_out.connect(dense_soma.s_in) + dense_dend.a_out.connect(plat.a_dend_in) + dense_soma.a_out.connect(plat.a_soma_in) + # run model + test_v_dend = [] + test_v_soma = [] + for _ in range(num_steps): + plat.run(RunSteps(1), Loihi2SimCfg(select_tag='fixed_pt')) + test_v_dend.append(plat.v_dend.get().astype(int)[0]) + test_v_soma.append(plat.v_soma.get().astype(int)[0]) + plat.stop() + # Gold standard for the test + # 100<<6 = 6400 -- initial value at time step 2 + expected_v_dend = [ + 0, 6400, 3198, 1598, 798, 398, 198, 98, 48, 23 + ] + expected_v_soma = [ + 0, 6400, 4798, 3597, 2696, 2021, 1515, 1135, 850, 637 + ] + self.assertListEqual(expected_v_dend, test_v_dend) + self.assertListEqual(expected_v_soma, test_v_soma) diff --git a/tests/lava/proc/plateau/test_process.py b/tests/lava/proc/plateau/test_process.py new file mode 100644 index 000000000..57872471c --- /dev/null +++ b/tests/lava/proc/plateau/test_process.py @@ -0,0 +1,29 @@ +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause +# See: https://spdx.org/licenses/ + + +import unittest +from lava.proc.plateau.process import Plateau + + +class TestPlateauProcess(unittest.TestCase): + """Tests for Plateau class""" + def test_init(self): + """Tests instantiation of Plateau""" + plat = Plateau( + shape=(100,), + dv_dend=100, + dv_soma=1, + vth_dend=10, + vth_soma=1, + up_dur=10, + name="Plat" + ) + + self.assertEqual(plat.name, "Plat") + self.assertEqual(plat.dv_dend.init, 100) + self.assertEqual(plat.dv_soma.init, 1) + self.assertEqual(plat.vth_dend.init, 10) + self.assertEqual(plat.vth_soma.init, 1) + self.assertEqual(plat.up_dur.init, 10)