update simulator to return sparse coo tensors #31

Open · wants to merge 1 commit into base: main

This PR changes the simulate methods of the base and SA models to accumulate (neuron, time step) spike indices during the simulation loop and return a torch.sparse_coo_tensor of shape (n_neurons, n_steps) instead of a dense state matrix. The tests are updated to call .to_dense() on the result before dense-only assertions.
20 changes: 14 additions & 6 deletions spikeometric/models/base_model.py
@@ -217,21 +217,29 @@ def simulate(self, data, n_steps: int, verbose: bool = True, equilibration_steps
         device = edge_index.device
 
         # If verbose is True, a progress bar is shown
-        pbar = tqdm(range(T, n_steps + T), colour="#3E5641") if verbose else range(T, n_steps + T)
+        pbar = tqdm(range(n_steps), colour="#3E5641") if verbose else range(n_steps)
 
         # Simulate the network
-        x = torch.zeros((n_neurons, n_steps + T), device=device, dtype=store_as_dtype)
         initial_state = torch.randint(0, 2, device=device, size=(n_neurons,), generator=self._rng, dtype=store_as_dtype)
-        x[:, :T] = self.equilibrate(edge_index, W, initial_state, n_steps=equilibration_steps, store_as_dtype=store_as_dtype)
+        x = self.equilibrate(edge_index, W, initial_state, n_steps=equilibration_steps, store_as_dtype=store_as_dtype)
+        indices = [[], []]
         for t in pbar:
-            x[:, t] = self(edge_index=edge_index, W=W, state=x[:, t-T:t], t=t-T)
+            state = self(edge_index=edge_index, W=W, state=x, t=t)
+            x = x.roll(-1, 1)
+            x[:, -1] = state
+            sparse = torch.where(state)[0]
+            indices[1] += [t] * len(sparse)
+            indices[0] += sparse.tolist()
+
+        result = torch.sparse_coo_tensor(
+            indices, torch.ones(len(indices[0]), device=device), (n_neurons, n_steps), dtype=store_as_dtype
+        )
         # If the stimulus is batched, we increment the batch in preparation for the next batch
         if isinstance(self.stimulus, BaseStimulus) and self.stimulus.n_batches > 1:
             self.stimulus.next_batch()
 
         # Return the state of the network at each time step
-        return x[:, T:]
+        return result
 
     def tune(
         self,
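The pattern above, in isolation: collect one (neuron index, time step) pair per spike while the loop runs, then build the sparse tensor once at the end. A minimal standalone sketch, with toy values for `n_neurons` and `n_steps` and a random stand-in for the model's per-step output:

```python
import torch

n_neurons, n_steps = 5, 4
rows, cols = [], []  # rows: neuron indices, cols: time steps

for t in range(n_steps):
    state = torch.randint(0, 2, (n_neurons,))  # stand-in for one model step
    fired = torch.where(state)[0]              # indices of neurons that spiked at t
    rows += fired.tolist()
    cols += [t] * len(fired)

# One value per recorded spike; overall shape is (n_neurons, n_steps)
spikes = torch.sparse_coo_tensor([rows, cols], torch.ones(len(rows)), (n_neurons, n_steps))
print(spikes.to_dense())  # densify only when a dense view is actually needed
```

Appending to plain Python lists inside the loop and constructing the sparse tensor once keeps the per-step cost low, and memory now scales with the number of spikes rather than with n_neurons × n_steps.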
17 changes: 12 additions & 5 deletions spikeometric/models/sa_model.py
@@ -54,21 +54,28 @@ def simulate(self, data: Data, n_steps: int, verbose: bool = True, equilibration_
         pbar = tqdm(range(n_steps), colour="#3E5641") if verbose else range(n_steps)
 
         # Initialize the state of the network
-        x = torch.zeros(n_neurons, n_steps, device=device, dtype=store_as_dtype)
         initial_activation = torch.rand((n_neurons, T), device=device)
         activation = self.equilibrate(edge_index, W, initial_activation, equilibration_steps, store_as_dtype=store_as_dtype)
 
+        spikes = torch.zeros((n_neurons, T), device=device)
         # Simulate the network
+        indices = [[], []]
         for t in pbar:
-            x[:, t] = self(edge_index=edge_index, W=W, state=activation, t=t)
-            activation = self.update_activation(spikes=x[:, t:t+T], activation=activation)
+            spikes[:, 0] = self(edge_index=edge_index, W=W, state=activation, t=t)
+            activation = self.update_activation(spikes=spikes, activation=activation)
+            sparse = torch.where(spikes[:, 0])[0]
+            indices[1] += [t] * len(sparse)
+            indices[0] += sparse.tolist()
+
+        result = torch.sparse_coo_tensor(
+            indices, torch.ones(len(indices[0]), device=device), (n_neurons, n_steps), dtype=store_as_dtype
+        )
 
         # If the stimulus is batched, we increment the batch in preparation for the next batch
         if isinstance(self.stimulus, BaseStimulus) and self.stimulus.n_batches > 1:
             self.stimulus.next_batch()
 
         # Return the state of the network at each time step
-        return x
+        return result
 
     def tune(
         self,
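Since both models now return a sparse COO tensor, downstream code can read spikes without materializing the full matrix. A hedged sketch of that, using a toy tensor in place of the real `simulate()` output:

```python
import torch

n_neurons, n_steps = 3, 5
X = torch.sparse_coo_tensor([[0, 2, 2], [1, 0, 4]], torch.ones(3), (n_neurons, n_steps))

X = X.coalesce()                 # sort and deduplicate indices before reading them
neuron_ids, times = X.indices()  # one (neuron, time) pair per spike
total_spikes = X.values().sum()  # number of recorded spikes

# Per-neuron spike counts, staying sparse until the final small vector
counts = torch.sparse.sum(X, dim=1).to_dense()
mean_rate = total_spikes / (n_neurons * n_steps)
```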
6 changes: 3 additions & 3 deletions tests/test_simulation.py
@@ -6,21 +6,21 @@ def test_no_grad(bernoulli_glm, example_data):
     assert not bernoulli_glm.alpha.requires_grad
 
 def test_consistent_output_after_ten_steps(expected_output_after_ten_steps, bernoulli_glm, example_data):
-    X = bernoulli_glm.simulate(example_data, n_steps=10, verbose=False, equilibration_steps=0)
+    X = bernoulli_glm.simulate(example_data, n_steps=10, verbose=False, equilibration_steps=0).to_dense()
     assert_close(X, expected_output_after_ten_steps)
 
 def test_simulation_statistics(bernoulli_glm, saved_glorot_dataset):
     n_steps = 1000
     expected_firing_rate = 7.2
     for example_data in saved_glorot_dataset:
-        X = bernoulli_glm.simulate(example_data, n_steps=n_steps, verbose=False)
+        X = bernoulli_glm.simulate(example_data, n_steps=n_steps, verbose=False).to_dense()
         fr = (X.float().mean() / bernoulli_glm.dt) * 1000
         pytest.approx(fr, expected_firing_rate)
 
 def test_uniform_simulation(threshold_sam, generated_uniform_data):
     from torch import tensor
     example_uniform_data = generated_uniform_data[0]
-    X = threshold_sam.simulate(example_uniform_data, n_steps=1000, verbose=False)
+    X = threshold_sam.simulate(example_uniform_data, n_steps=1000, verbose=False).to_dense()
     fr = (X.float().mean() / threshold_sam.dt) * 1000
     assert_close(fr, tensor(33.6364), atol=0.001, rtol=0.1)
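The `.to_dense()` calls in the updated tests are the cost of the new return type: dense-only operations used by the assertions (such as `.mean()`) generally aren't implemented for the sparse COO layout, so the result is densified before comparison. A small sketch of the equivalence, with a toy tensor standing in for a simulation result:

```python
import torch

X = torch.sparse_coo_tensor([[0, 1], [2, 3]], torch.ones(2), (2, 5))

dense_mean = X.to_dense().float().mean()             # pattern used in the updated tests
sparse_mean = X.coalesce().values().sum() / (2 * 5)  # same number without densifying
assert torch.isclose(dense_mean, sparse_mean)
```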
6 changes: 3 additions & 3 deletions tests/test_tuning.py
@@ -28,9 +28,9 @@ def test_only_model_tunable(bernoulli_glm, sin_stimulus, saved_glorot_dataset):
     assert initial_parameters[parameter] == bernoulli_glm.tunable_parameters[parameter]
 
 def test_tune_rectified_sa_model(rectified_sam, rectified_sam_network):
-    initial_spikes = rectified_sam.simulate(rectified_sam_network, 100)
+    initial_spikes = rectified_sam.simulate(rectified_sam_network, 100).to_dense()
     rectified_sam.tune(rectified_sam_network, 10, n_steps=100, n_epochs=1, lr=0.01, verbose=False)
-    final_spikes = rectified_sam.simulate(rectified_sam_network, 100)
+    final_spikes = rectified_sam.simulate(rectified_sam_network, 100).to_dense()
     assert not initial_spikes.float().mean() == final_spikes.float().mean()
 
 def test_all_parameters_tunable(bernoulli_glm, saved_glorot_dataset):
@@ -62,7 +62,7 @@ def test_tuning_improves_firing_rate(bernoulli_glm, example_data):
     firing_rate = 62.5
     initial_firing_rate = 7.2
     bernoulli_glm.tune(example_data, firing_rate, tunable_parameters, n_steps=10, n_epochs=1, lr=0.1, verbose=False)
-    X = bernoulli_glm.simulate(example_data, n_steps=1000, verbose=False)
+    X = bernoulli_glm.simulate(example_data, n_steps=1000, verbose=False).to_dense()
     final_firing_rate = X.float().mean() / bernoulli_glm.dt * 1000
     assert final_firing_rate > initial_firing_rate