Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/randomness #7

Merged
merged 4 commits into from
May 20, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions bsb_nest/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
from neo import SpikeTrain
from tqdm import tqdm

from .exceptions import NestConnectError, NestModelError, NestModuleError
from .exceptions import KernelWarning, NestConnectError, NestModelError, NestModuleError

if typing.TYPE_CHECKING:
from bsb import Simulation
from .simulation import NestSimulation


class NestResult(SimulationResult):
Expand Down Expand Up @@ -167,10 +167,12 @@ def create_devices(self, simulation):
for device_model in simulation.devices.values():
device_model.implement(self, simulation, simdata)

def set_settings(self, simulation: "Simulation"):
def set_settings(self, simulation: "NestSimulation"):
nest.set_verbosity(simulation.verbosity)
nest.resolution = simulation.resolution
nest.overwrite_files = True
if simulation.seed is not None:
nest.rng_seed = simulation.seed

def check_comm(self):
if nest.NumProcesses() != MPI.get_size():
Expand Down
13 changes: 10 additions & 3 deletions bsb_nest/cell.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import nest
from bsb import CellModel, config, types
from bsb import CellModel, config

from .distributions import NestRandomDistribution, nest_parameter


@config.node
class NestCell(CellModel):
model = config.attr(type=str, default="iaf_psc_alpha")
constants = config.dict(type=types.any_())
constants = config.dict(type=nest_parameter())

def create_population(self, simdata):
n = len(simdata.placement[self])
Expand All @@ -15,7 +17,12 @@ def create_population(self, simdata):
return population

def set_constants(self, population):
population.set(self.constants)
population.set(
{
k: (v() if isinstance(v, NestRandomDistribution) else v)
for k, v in self.constants.items()
}
)

def set_parameters(self, population, simdata):
ps = simdata.placement[self]
Expand Down
3 changes: 2 additions & 1 deletion bsb_nest/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from bsb import MPI, ConnectionModel, compose_nodes, config, types
from tqdm import tqdm

from .distributions import nest_parameter
from .exceptions import NestConnectError


Expand All @@ -16,7 +17,7 @@ class NestSynapseSettings:
weight = config.attr(type=float, required=True)
delay = config.attr(type=float, required=True)
receptor_type = config.attr(type=int)
constants = config.catch_all(type=types.any_())
constants = config.catch_all(type=nest_parameter())


@config.node
Expand Down
61 changes: 61 additions & 0 deletions bsb_nest/distributions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import builtins
import typing

import errr
import nest.random.hl_api_random as _distributions
from bsb import DistributionCastError, TypeHandler, config, types
Helveg marked this conversation as resolved.
Show resolved Hide resolved

if typing.TYPE_CHECKING:
from bsb import Scaffold

_available_distributions = [d for d in _distributions.__all__]


@config.node
class NestRandomDistribution:
"""
Class to handle NEST random distributions.
"""

scaffold: "Scaffold"
distribution: str = config.attr(
type=types.in_(_available_distributions), required=True
)
"""Distribution name. Should correspond to a function of nest.random.hl_api_random"""
parameters: dict[str, typing.Any] = config.catch_all(type=types.any_())
"""Dictionary of parameters to assign to the distribution. Should correspond to NEST's"""

def __init__(self, **kwargs):
try:
self._distr = getattr(_distributions, self.distribution)(**self.parameters)
except Exception as e:
errr.wrap(
DistributionCastError, e, prepend=f"Can't cast to '{self.distribution}': "
)

def __call__(self):
return self._distr

def __getattr__(self, attr):
if "_distr" not in self.__dict__:
drodarie marked this conversation as resolved.
Show resolved Hide resolved
raise AttributeError("No underlying _distr found for distribution node.")
return getattr(self._distr, attr)


class nest_parameter(TypeHandler):
"""
Type validator. Type casts the value or node to a Nest parameter, that can be either a value or
a NestRandomDistribution.
"""

def __call__(self, value, _key=None, _parent=None):
if isinstance(value, builtins.dict) and "distribution" in value.keys():
return NestRandomDistribution(**value, _key=_key, _parent=_parent)
return value

@property
def __name__(self): # pragma: nocover
return "nest_parameter"
drodarie marked this conversation as resolved.
Show resolved Hide resolved

def __inv__(self, value):
return value
6 changes: 6 additions & 0 deletions bsb_nest/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,15 @@ class NestSimulation(Simulation):
"""

modules = config.list(type=str)
"""List of NEST modules to load at the beginning of the simulation"""
threads = config.attr(type=types.int(min=1), default=1)
"""Number of threads to use during simulation"""
resolution = config.attr(type=types.float(min=0.0), required=True)
"""Simulation time step size in milliseconds"""
verbosity = config.attr(type=str, default="M_ERROR")
"""NEST verbosity level"""
seed = config.attr(type=int, default=None)
"""Random seed for the simulations"""

cell_models = config.dict(type=NestCell, required=True)
connection_models = config.dict(type=NestConnection, required=True)
Expand Down
88 changes: 88 additions & 0 deletions tests/test_nest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import nest
import numpy as np
from bsb import CastError
from bsb.config import Configuration
from bsb.core import Scaffold
from bsb.services import MPI
Expand Down Expand Up @@ -210,3 +211,90 @@ def test_iaf_cond_alpha(self):
results = netw.run_simulation("test")
spike_times_bsb = results.spiketrains[0]
self.assertClose(np.array(spike_times_nest), np.array(spike_times_bsb))

def test_nest_randomness(self):
nest.ResetKernel()
nest.resolution = 0.1
nest.rng_seed = 1234
# gif_cond_exp implements a random spiking process.
# So it's perfect to test the seed
A = nest.Create(
"gif_cond_exp",
1,
params={"I_e": 200.0, "V_m": nest.random.normal(mean=-70, std=20.0)},
)
spikeA = nest.Create("spike_recorder")
nest.Connect(A, spikeA)
nest.Simulate(1000.0)
spike_times_nest = spikeA.get("events")["times"]
print(spike_times_nest)

conf = {
"name": "test",
"storage": {"engine": "hdf5"},
"network": {"x": 1, "y": 1, "z": 1},
"partitions": {"B": {"type": "layer", "thickness": 1}},
"cell_types": {"A": {"spatial": {"radius": 1, "count": 1}}},
"placement": {
"placement_A": {
"strategy": "bsb.placement.strategy.FixedPositions",
"cell_types": ["A"],
"partitions": ["B"],
"positions": [[1, 1, 1]],
}
},
"connectivity": {},
"after_connectivity": {},
"simulations": {
"test": {
"simulator": "nest",
"duration": 1000,
"resolution": 0.1,
"seed": 1234,
"cell_models": {
"A": {
"model": "gif_cond_exp",
"constants": {
"I_e": 200.0,
"V_m": {
"distribution": "normal",
"mean": -70,
"std": 20.0,
},
},
}
},
"connection_models": {},
"devices": {
"record_A_spikes": {
"device": "spike_recorder",
"delay": 0.5,
"targetting": {
"strategy": "cell_model",
"cell_models": ["A"],
},
}
},
}
},
}
cfg = Configuration(conf)
netw = Scaffold(cfg, self.storage)
netw.compile()
results = netw.run_simulation("test")
spike_times_bsb = results.spiketrains[0]
self.assertClose(np.array(spike_times_nest), np.array(spike_times_bsb))
self.assertEqual(
cfg.__tree__()["simulations"]["test"]["cell_models"]["A"]["constants"]["V_m"],
{
"distribution": "normal",
"mean": -70,
"std": 20.0,
},
)
# Test with an unknown distribution
conf["simulations"]["test"]["cell_models"]["A"]["constants"]["V_m"][
"distribution"
] = "bean"
with self.assertRaises(CastError):
_ = Configuration(conf)
drodarie marked this conversation as resolved.
Show resolved Hide resolved