Skip to content

Commit

Permalink
Merge pull request #344 from genn-team/device_only_connectivity_fix
Browse files Browse the repository at this point in the history
Device only connectivity fix
  • Loading branch information
neworderofjamie authored Jul 8, 2020
2 parents 101a428 + 961dc03 commit 38fa54e
Showing 1 changed file with 39 additions and 31 deletions.
70 changes: 39 additions & 31 deletions pygenn/genn_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ def load(self):

# Load neuron extra global params
self._load_egp()

def load_init_egps(self):
# Load any egps used for variable initialisation
self._load_var_init_egps()
Expand Down Expand Up @@ -763,36 +763,44 @@ def load(self):
# which requires initialising manually
if not self.is_dense and self.weight_sharing_master is None:
if self.is_ragged:
# Get pointers to ragged data structure members
ind = self._assign_ext_ptr_array("ind",
self.weight_update_var_size,
"unsigned int")
row_length = self._assign_ext_ptr_array("rowLength",
self.src.size,
"unsigned int")
# add pointers to the object
self._ind = ind
self._row_lengths = row_length

# If data is available
if self.connections_set:

# Copy in row length
row_length[:] = self.row_lengths

# Create (x)range containing the index where each row starts in ind
row_start_idx = xrange(0, self.weight_update_var_size,
self.max_row_length)

# Loop through ragged matrix rows
syn = 0
for i, r in zip(row_start_idx, self.row_lengths):
# Copy row from non-padded indices into correct location
ind[i:i + r] = self.ind[syn:syn + r]
syn += r
elif self.connectivity_initialiser is None:
raise Exception("For sparse projections, the connections"
"must be set before loading a model")
# If connectivity is located on host
conn_loc = self.pop.get_sparse_connectivity_location()
if (conn_loc & VarLocation_HOST) != 0:
# Get pointers to ragged data structure members
ind = self._assign_ext_ptr_array("ind",
self.weight_update_var_size,
"unsigned int")
row_length = self._assign_ext_ptr_array("rowLength",
self.src.size,
"unsigned int")
# add pointers to the object
self._ind = ind
self._row_lengths = row_length

# If data is available
if self.connections_set:
# Copy in row length
row_length[:] = self.row_lengths

# Create (x)range containing the index where each row starts in ind
row_start_idx = xrange(0, self.weight_update_var_size,
self.max_row_length)

# Loop through ragged matrix rows
syn = 0
for i, r in zip(row_start_idx, self.row_lengths):
# Copy row from non-padded indices into correct location
ind[i:i + r] = self.ind[syn:syn + r]
syn += r
elif self.connectivity_initialiser is None:
raise Exception("For sparse projections, the connections"
"must be set before loading a model")
# Otherwise, if connectivity isn't located on host,
# give error if user tries to manually configure it
elif self.connections_set:
raise Exception("If sparse connectivity is only located "
"on device, it cannot be set with "
"set_sparse_connections")

else:
raise Exception("Matrix format not supported")
Expand Down

0 comments on commit 38fa54e

Please sign in to comment.