diff --git a/pygenn/genn_groups.py b/pygenn/genn_groups.py index dedc2ae553..041f5677cf 100644 --- a/pygenn/genn_groups.py +++ b/pygenn/genn_groups.py @@ -359,7 +359,7 @@ def load(self): # Load neuron extra global params self._load_egp() - + def load_init_egps(self): # Load any egps used for variable initialisation self._load_var_init_egps() @@ -763,36 +763,44 @@ def load(self): # which requires initialising manually if not self.is_dense and self.weight_sharing_master is None: if self.is_ragged: - # Get pointers to ragged data structure members - ind = self._assign_ext_ptr_array("ind", - self.weight_update_var_size, - "unsigned int") - row_length = self._assign_ext_ptr_array("rowLength", - self.src.size, - "unsigned int") - # add pointers to the object - self._ind = ind - self._row_lengths = row_length - - # If data is available - if self.connections_set: - - # Copy in row length - row_length[:] = self.row_lengths - - # Create (x)range containing the index where each row starts in ind - row_start_idx = xrange(0, self.weight_update_var_size, - self.max_row_length) - - # Loop through ragged matrix rows - syn = 0 - for i, r in zip(row_start_idx, self.row_lengths): - # Copy row from non-padded indices into correct location - ind[i:i + r] = self.ind[syn:syn + r] - syn += r - elif self.connectivity_initialiser is None: - raise Exception("For sparse projections, the connections" - "must be set before loading a model") + # If connectivity is located on host + conn_loc = self.pop.get_sparse_connectivity_location() + if (conn_loc & VarLocation_HOST) != 0: + # Get pointers to ragged data structure members + ind = self._assign_ext_ptr_array("ind", + self.weight_update_var_size, + "unsigned int") + row_length = self._assign_ext_ptr_array("rowLength", + self.src.size, + "unsigned int") + # add pointers to the object + self._ind = ind + self._row_lengths = row_length + + # If data is available + if self.connections_set: + # Copy in row length + row_length[:] = self.row_lengths + + # Create (x)range containing the index where each row starts in ind + row_start_idx = xrange(0, self.weight_update_var_size, + self.max_row_length) + + # Loop through ragged matrix rows + syn = 0 + for i, r in zip(row_start_idx, self.row_lengths): + # Copy row from non-padded indices into correct location + ind[i:i + r] = self.ind[syn:syn + r] + syn += r + elif self.connectivity_initialiser is None: + raise Exception("For sparse projections, the connections" + "must be set before loading a model") + # Otherwise, if connectivity isn't located on host, + # give error if user tries to manually configure it + elif self.connections_set: + raise Exception("If sparse connectivity is only located " + "on device, it cannot be set with " + "set_sparse_connections") else: raise Exception("Matrix format not supported")