codeStrings, const std::string &paramName,
+ size_t index, P getParamValuesFn) const
+ {
+ // If none of the code strings reference the parameter, return false
+ if(std::none_of(codeStrings.begin(), codeStrings.end(),
+ [&paramName](const std::string &c)
+ {
+ return (c.find("$(" + paramName + ")") != std::string::npos);
+ }))
+ {
+ return false;
+ }
+ // Otherwise check if values are heterogeneous
+ else {
+ return isParamValueHeterogeneous(index, getParamValuesFn);
+ }
+ }
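+
+ // Illustrative usage (accessor names are examples, not part of this header): with nm being
+ // the archetype's neuron model, a derived group could test whether "tau" can be hard-coded
+ // into generated code by passing every code string that might reference it:
+ //
+ //   isParamValueHeterogeneous({nm->getSimCode(), nm->getThresholdConditionCode()}, "tau",
+ //                             paramIndex, [](const NeuronGroupInternal &ng){ return ng.getParams(); });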
+
private:
//------------------------------------------------------------------------
// Members
@@ -49,15 +101,275 @@ class GroupMerged
};
//----------------------------------------------------------------------------
-// CodeGenerator::NeuronGroupMerged
+// CodeGenerator::NeuronSpikeQueueUpdateGroupMerged
//----------------------------------------------------------------------------
-class GENN_EXPORT NeuronGroupMerged : public GroupMerged<NeuronGroupInternal>
+class GENN_EXPORT NeuronSpikeQueueUpdateGroupMerged : public GroupMerged<NeuronGroupInternal>
{
public:
- NeuronGroupMerged(size_t index, const std::vector<std::reference_wrapper<const NeuronGroupInternal>> &groups)
+ NeuronSpikeQueueUpdateGroupMerged(size_t index, const std::vector<std::reference_wrapper<const NeuronGroupInternal>> &groups)
: GroupMerged(index, groups)
{}
+ //------------------------------------------------------------------------
+ // Public API
+ //------------------------------------------------------------------------
+ void generate(const BackendBase &backend, CodeStream &definitionsInternal,
+ CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar,
+ CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc,
+ MergedStructData &mergedStructData, const std::string &precision) const;
+
+ void genMergedGroupSpikeCountReset(CodeStream &os) const;
+};
+
+//----------------------------------------------------------------------------
+// CodeGenerator::NeuronGroupMergedBase
+//----------------------------------------------------------------------------
+class GENN_EXPORT NeuronGroupMergedBase : public GroupMerged<NeuronGroupInternal>
+{
+public:
+ //------------------------------------------------------------------------
+ // Public API
+ //------------------------------------------------------------------------
+ //! Should the parameter be implemented heterogeneously?
+ bool isParamHeterogeneous(size_t index) const;
+
+ //! Should the derived parameter be implemented heterogeneously?
+ bool isDerivedParamHeterogeneous(size_t index) const;
+
+ //! Should the var init parameter be implemented heterogeneously?
+ bool isVarInitParamHeterogeneous(size_t varIndex, size_t paramIndex) const;
+
+ //! Should the var init derived parameter be implemented heterogeneously?
+ bool isVarInitDerivedParamHeterogeneous(size_t varIndex, size_t paramIndex) const;
+
+ //! Should the current source parameter be implemented heterogeneously?
+ bool isCurrentSourceParamHeterogeneous(size_t childIndex, size_t paramIndex) const;
+
+ //! Should the current source derived parameter be implemented heterogeneously?
+ bool isCurrentSourceDerivedParamHeterogeneous(size_t childIndex, size_t paramIndex) const;
+
+ //! Should the current source var init parameter be implemented heterogeneously?
+ bool isCurrentSourceVarInitParamHeterogeneous(size_t childIndex, size_t varIndex, size_t paramIndex) const;
+
+ //! Should the current source var init derived parameter be implemented heterogeneously?
+ bool isCurrentSourceVarInitDerivedParamHeterogeneous(size_t childIndex, size_t varIndex, size_t paramIndex) const;
+
+ //! Should the postsynaptic model parameter be implemented heterogeneously?
+ bool isPSMParamHeterogeneous(size_t childIndex, size_t paramIndex) const;
+
+ //! Should the postsynaptic model derived parameter be implemented heterogeneously?
+ bool isPSMDerivedParamHeterogeneous(size_t childIndex, size_t paramIndex) const;
+
+ //! Should the GLOBALG postsynaptic model variable be implemented heterogeneously?
+ bool isPSMGlobalVarHeterogeneous(size_t childIndex, size_t varIndex) const;
+
+ //! Should the postsynaptic model var init parameter be implemented heterogeneously?
+ bool isPSMVarInitParamHeterogeneous(size_t childIndex, size_t varIndex, size_t paramIndex) const;
+
+ //! Should the postsynaptic model var init derived parameter be implemented heterogeneously?
+ bool isPSMVarInitDerivedParamHeterogeneous(size_t childIndex, size_t varIndex, size_t paramIndex) const;
+
+protected:
+ //------------------------------------------------------------------------
+ // Protected methods
+ //------------------------------------------------------------------------
+ NeuronGroupMergedBase(size_t index, bool init, const std::vector<std::reference_wrapper<const NeuronGroupInternal>> &groups);
+
+ void generate(MergedStructGenerator<NeuronGroupMergedBase> &gen, const BackendBase &backend,
+ const std::string &precision, const std::string &timePrecision, bool init) const;
+
+ template<typename T, typename G, typename C>
+ void orderNeuronGroupChildren(const std::vector<T> &archetypeChildren,
+                               std::vector<std::vector<T>> &sortedGroupChildren,
+ G getVectorFunc, C isCompatibleFunc) const
+ {
+ // Reserve vector of vectors to hold children for all neuron groups, in archetype order
+ sortedGroupChildren.reserve(archetypeChildren.size());
+
+ // Loop through groups
+ for(const auto &g : getGroups()) {
+ // Make temporary copy of this group's children
+ std::vector<T> tempChildren((g.get().*getVectorFunc)());
+
+ assert(tempChildren.size() == archetypeChildren.size());
+
+ // Reserve vector for this group's children
+ sortedGroupChildren.emplace_back();
+ sortedGroupChildren.back().reserve(tempChildren.size());
+
+ // Loop through archetype group's children
+ for(const auto &archetypeG : archetypeChildren) {
+ // Find compatible child in temporary list
+ const auto otherChild = std::find_if(tempChildren.cbegin(), tempChildren.cend(),
+ [archetypeG, isCompatibleFunc](const T &g)
+ {
+ return isCompatibleFunc(archetypeG, g);
+ });
+ assert(otherChild != tempChildren.cend());
+
+ // Add compatible child to this group's sorted children
+ sortedGroupChildren.back().push_back(*otherChild);
+
+ // Remove from original vector
+ tempChildren.erase(otherChild);
+ }
+ }
+ }
+
+ template<typename T, typename G, typename C>
+ void orderNeuronGroupChildren(std::vector<std::vector<T>> &sortedGroupChildren,
+ G getVectorFunc, C isCompatibleFunc) const
+ {
+ const std::vector<T> &archetypeChildren = (getArchetype().*getVectorFunc)();
+ orderNeuronGroupChildren(archetypeChildren, sortedGroupChildren, getVectorFunc, isCompatibleFunc);
+ }
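+
+ // For example (lambda shown is illustrative), current sources could be sorted so that
+ // index i in every group refers to a child mergeable with archetype child i:
+ //
+ //   orderNeuronGroupChildren(m_SortedCurrentSources, &NeuronGroupInternal::getCurrentSources,
+ //                            [](const CurrentSourceInternal *a, const CurrentSourceInternal *b)
+ //                            { return a->canBeMerged(*b); });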
+
+
+ template<typename T, typename G>
+ bool isChildParamValueHeterogeneous(std::initializer_list<std::string> codeStrings,
+                                     const std::string &paramName, size_t childIndex, size_t paramIndex,
+                                     const std::vector<std::vector<T>> &sortedGroupChildren, G getParamValuesFn) const
+ {
+ // If any of the code strings reference the parameter
+ if(std::any_of(codeStrings.begin(), codeStrings.end(),
+ [&paramName](const std::string &c)
+ {
+ return (c.find("$(" + paramName + ")") != std::string::npos);
+ }))
+ {
+ // Get parameter value from archetype group's child
+ const double firstValue = getParamValuesFn(sortedGroupChildren[0][childIndex]).at(paramIndex);
+
+ // Loop through groups within merged group
+ for(size_t i = 0; i < sortedGroupChildren.size(); i++) {
+ const auto group = sortedGroupChildren[i][childIndex];
+ if(getParamValuesFn(group).at(paramIndex) != firstValue) {
+ return true;
+ }
+ }
+ }
+
+ return false;
+ }
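+
+ // Sketch (names illustrative) of how a derived class can build a child query on top of this
+ // helper, e.g. isCurrentSourceParamHeterogeneous(childIndex, paramIndex) might boil down to:
+ //
+ //   return isChildParamValueHeterogeneous({csm->getInjectionCode()}, paramName, childIndex,
+ //                                         paramIndex, m_SortedCurrentSources,
+ //                                         [](const CurrentSourceInternal *cs){ return cs->getParams(); });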
+
+ template<typename T, typename H, typename V>
+ void addHeterogeneousChildParams(MergedStructGenerator<NeuronGroupMergedBase> &gen,
+                                  const Snippet::Base::StringVec &paramNames, size_t childIndex,
+ const std::string &prefix,
+ H isChildParamHeterogeneousFn, V getValueFn) const
+ {
+ // Loop through parameters
+ for(size_t p = 0; p < paramNames.size(); p++) {
+ // If parameter is heterogeneous
+ if((static_cast<const T*>(this)->*isChildParamHeterogeneousFn)(childIndex, p)) {
+ gen.addScalarField(paramNames[p] + prefix + std::to_string(childIndex),
+ [childIndex, p, getValueFn](const NeuronGroupInternal &, size_t groupIndex)
+ {
+ return Utils::writePreciseString(getValueFn(groupIndex, childIndex, p));
+ });
+ }
+ }
+ }
+
+ template<typename T, typename H, typename V>
+ void addHeterogeneousChildDerivedParams(MergedStructGenerator<NeuronGroupMergedBase> &gen,
+ const Snippet::Base::DerivedParamVec &derivedParams, size_t childIndex,
+ const std::string &prefix, H isChildDerivedParamHeterogeneousFn, V getValueFn) const
+ {
+ // Loop through derived parameters
+ for(size_t p = 0; p < derivedParams.size(); p++) {
+ // If derived parameter is heterogeneous
+ if((static_cast<const T*>(this)->*isChildDerivedParamHeterogeneousFn)(childIndex, p)) {
+ gen.addScalarField(derivedParams[p].name + prefix + std::to_string(childIndex),
+ [childIndex, p, getValueFn](const NeuronGroupInternal &, size_t groupIndex)
+ {
+ return Utils::writePreciseString(getValueFn(groupIndex, childIndex, p));
+ });
+ }
+ }
+ }
+
+ template<typename T, typename H, typename V>
+ void addHeterogeneousChildVarInitParams(MergedStructGenerator<NeuronGroupMergedBase> &gen,
+                                         const Snippet::Base::StringVec &paramNames, size_t childIndex,
+ size_t varIndex, const std::string &prefix,
+ H isChildParamHeterogeneousFn, V getVarInitialiserFn) const
+ {
+ // Loop through parameters
+ for(size_t p = 0; p < paramNames.size(); p++) {
+ // If parameter is heterogeneous
+ if((static_cast<const T*>(this)->*isChildParamHeterogeneousFn)(childIndex, varIndex, p)) {
+ gen.addScalarField(paramNames[p] + prefix + std::to_string(childIndex),
+ [childIndex, varIndex, p, getVarInitialiserFn](const NeuronGroupInternal &, size_t groupIndex)
+ {
+ const std::vector<Models::VarInit> &varInit = getVarInitialiserFn(groupIndex, childIndex);
+ return Utils::writePreciseString(varInit.at(varIndex).getParams().at(p));
+ });
+ }
+ }
+ }
+
+ template<typename T, typename H, typename V>
+ void addHeterogeneousChildVarInitDerivedParams(MergedStructGenerator<NeuronGroupMergedBase> &gen,
+ const Snippet::Base::DerivedParamVec &derivedParams, size_t childIndex,
+ size_t varIndex, const std::string &prefix,
+ H isChildDerivedParamHeterogeneousFn, V getVarInitialiserFn) const
+ {
+ // Loop through derived parameters
+ for(size_t p = 0; p < derivedParams.size(); p++) {
+ // If derived parameter is heterogeneous
+ if((static_cast<const T*>(this)->*isChildDerivedParamHeterogeneousFn)(childIndex, varIndex, p)) {
+ gen.addScalarField(derivedParams[p].name + prefix + std::to_string(childIndex),
+ [childIndex, varIndex, p, getVarInitialiserFn](const NeuronGroupInternal &, size_t groupIndex)
+ {
+ const std::vector<Models::VarInit> &varInit = getVarInitialiserFn(groupIndex, childIndex);
+ return Utils::writePreciseString(varInit.at(varIndex).getDerivedParams().at(p));
+ });
+ }
+ }
+ }
+
+ template<typename S>
+ void addChildEGPs(MergedStructGenerator<NeuronGroupMergedBase> &gen,
+                   const std::vector<Snippet::Base::EGP> &egps, size_t childIndex,
+ const std::string &arrayPrefix, const std::string &prefix,
+ S getEGPSuffixFn) const
+ {
+ using FieldType = typename std::remove_reference<decltype(gen)>::type::FieldType;
+ for(const auto &e : egps) {
+ const bool isPointer = Utils::isTypePointer(e.type);
+ const std::string varPrefix = isPointer ? arrayPrefix : "";
+ gen.addField(e.type, e.name + prefix + std::to_string(childIndex),
+ [getEGPSuffixFn, childIndex, e, varPrefix](const NeuronGroupInternal&, size_t groupIndex)
+ {
+ return varPrefix + e.name + getEGPSuffixFn(groupIndex, childIndex);
+ },
+ isPointer ? FieldType::PointerEGP : FieldType::ScalarEGP);
+ }
+ }
+
+
+
+ void addMergedInSynPointerField(MergedStructGenerator<NeuronGroupMergedBase> &gen,
+ const std::string &type, const std::string &name,
+ size_t archetypeIndex, const std::string &prefix) const;
+
+private:
+ //------------------------------------------------------------------------
+ // Members
+ //------------------------------------------------------------------------
+ std::vector<std::vector<std::pair<SynapseGroupInternal *, std::vector<SynapseGroupInternal *>>>> m_SortedMergedInSyns;
+ std::vector<std::vector<CurrentSourceInternal *>> m_SortedCurrentSources;
+};
+
+//----------------------------------------------------------------------------
+// CodeGenerator::NeuronUpdateGroupMerged
+//----------------------------------------------------------------------------
+class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase
+{
+public:
+ NeuronUpdateGroupMerged(size_t index, const std::vector<std::reference_wrapper<const NeuronGroupInternal>> &groups);
+
//------------------------------------------------------------------------
// Public API
//------------------------------------------------------------------------
@@ -67,18 +379,167 @@ class GENN_EXPORT NeuronGroupMerged : public GroupMerged<NeuronGroupInternal>
//! Get the expression to calculate the queue offset for accessing state of variables in previous timestep
std::string getPrevQueueOffset() const;
+ //! Should the incoming synapse weight update model parameter be implemented heterogeneously?
+ bool isInSynWUMParamHeterogeneous(size_t childIndex, size_t paramIndex) const;
+
+ //! Should the incoming synapse weight update model derived parameter be implemented heterogeneously?
+ bool isInSynWUMDerivedParamHeterogeneous(size_t childIndex, size_t paramIndex) const;
+
+ //! Should the outgoing synapse weight update model parameter be implemented heterogeneously?
+ bool isOutSynWUMParamHeterogeneous(size_t childIndex, size_t paramIndex) const;
+
+ //! Should the outgoing synapse weight update model derived parameter be implemented heterogeneously?
+ bool isOutSynWUMDerivedParamHeterogeneous(size_t childIndex, size_t paramIndex) const;
+
+ void generate(const BackendBase &backend, CodeStream &definitionsInternal,
+ CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar,
+ CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc,
+ MergedStructData &mergedStructData, const std::string &precision,
+ const std::string &timePrecision) const;
+
+private:
+ //------------------------------------------------------------------------
+ // Private methods
+ //------------------------------------------------------------------------
+ //! Helper to generate merged struct fields for WU pre and post vars
+ void generateWUVar(MergedStructGenerator<NeuronGroupMergedBase> &gen, const BackendBase &backend,
+                    const std::string &fieldPrefixStem,
+                    const std::vector<SynapseGroupInternal *> &archetypeSyn,
+                    const std::vector<std::vector<SynapseGroupInternal *>> &sortedSyn,
+ Models::Base::VarVec(WeightUpdateModels::Base::*getVars)(void) const,
+ bool(NeuronUpdateGroupMerged::*isParamHeterogeneous)(size_t, size_t) const,
+ bool(NeuronUpdateGroupMerged::*isDerivedParamHeterogeneous)(size_t, size_t) const) const;
+
+ //------------------------------------------------------------------------
+ // Members
+ //------------------------------------------------------------------------
+ std::vector<std::vector<SynapseGroupInternal *>> m_SortedInSynWithPostCode;
+ std::vector<std::vector<SynapseGroupInternal *>> m_SortedOutSynWithPreCode;
+};
+
+//----------------------------------------------------------------------------
+// CodeGenerator::NeuronInitGroupMerged
+//----------------------------------------------------------------------------
+class GENN_EXPORT NeuronInitGroupMerged : public NeuronGroupMergedBase
+{
+public:
+ NeuronInitGroupMerged(size_t index, const std::vector<std::reference_wrapper<const NeuronGroupInternal>> &groups);
+
+ //! Should the incoming synapse weight update model var init parameter be implemented heterogeneously?
+ bool isInSynWUMVarInitParamHeterogeneous(size_t childIndex, size_t varIndex, size_t paramIndex) const;
+
+ //! Should the incoming synapse weight update model var init derived parameter be implemented heterogeneously?
+ bool isInSynWUMVarInitDerivedParamHeterogeneous(size_t childIndex, size_t varIndex, size_t paramIndex) const;
+
+ //! Should the outgoing synapse weight update model var init parameter be implemented heterogeneously?
+ bool isOutSynWUMVarInitParamHeterogeneous(size_t childIndex, size_t varIndex, size_t paramIndex) const;
+
+ //! Should the outgoing synapse weight update model var init derived parameter be implemented heterogeneously?
+ bool isOutSynWUMVarInitDerivedParamHeterogeneous(size_t childIndex, size_t varIndex, size_t paramIndex) const;
+
+ void generate(const BackendBase &backend, CodeStream &definitionsInternal,
+ CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar,
+ CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc,
+ MergedStructData &mergedStructData, const std::string &precision,
+ const std::string &timePrecision) const;
+
+private:
+ //------------------------------------------------------------------------
+ // Private methods
+ //------------------------------------------------------------------------
+ //! Helper to generate merged struct fields for WU pre and post vars
+ void generateWUVar(MergedStructGenerator<NeuronGroupMergedBase> &gen, const BackendBase &backend,
+                    const std::string &fieldPrefixStem,
+                    const std::vector<SynapseGroupInternal *> &archetypeSyn,
+                    const std::vector<std::vector<SynapseGroupInternal *>> &sortedSyn,
+                    Models::Base::VarVec(WeightUpdateModels::Base::*getVars)(void) const,
+                    const std::vector<Models::VarInit>&(SynapseGroupInternal::*getVarInitialisers)(void) const,
+ bool(NeuronInitGroupMerged::*isParamHeterogeneous)(size_t, size_t, size_t) const,
+ bool(NeuronInitGroupMerged::*isDerivedParamHeterogeneous)(size_t, size_t, size_t) const) const;
+
+
+ //------------------------------------------------------------------------
+ // Members
+ //------------------------------------------------------------------------
+ std::vector<std::vector<SynapseGroupInternal *>> m_SortedInSynWithPostVars;
+ std::vector<std::vector<SynapseGroupInternal *>> m_SortedOutSynWithPreVars;
+};
+
+//----------------------------------------------------------------------------
+// CodeGenerator::SynapseDendriticDelayUpdateGroupMerged
+//----------------------------------------------------------------------------
+class GENN_EXPORT SynapseDendriticDelayUpdateGroupMerged : public GroupMerged<SynapseGroupInternal>
+{
+public:
+ SynapseDendriticDelayUpdateGroupMerged(size_t index, const std::vector<std::reference_wrapper<const SynapseGroupInternal>> &groups)
+ : GroupMerged(index, groups)
+ {}
+
+ //------------------------------------------------------------------------
+ // Public API
+ //------------------------------------------------------------------------
+ void generate(const BackendBase &backend, CodeStream &definitionsInternal,
+ CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar,
+ CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc,
+ MergedStructData &mergedStructData, const std::string &precision) const;
};
+// ----------------------------------------------------------------------------
+// SynapseConnectivityHostInitGroupMerged
//----------------------------------------------------------------------------
-// CodeGenerator::SynapseGroupMerged
+class GENN_EXPORT SynapseConnectivityHostInitGroupMerged : public GroupMerged<SynapseGroupInternal>
+{
+public:
+ SynapseConnectivityHostInitGroupMerged(size_t index, const std::vector<std::reference_wrapper<const SynapseGroupInternal>> &groups)
+ : GroupMerged(index, groups)
+ {}
+
+ //------------------------------------------------------------------------
+ // Public API
+ //------------------------------------------------------------------------
+ void generate(const BackendBase &backend, CodeStream &definitionsInternal,
+ CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar,
+ CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc,
+ MergedStructData &mergedStructData, const std::string &precision) const;
+
+ //! Should the connectivity initialization parameter be implemented heterogeneously for EGP init?
+ bool isConnectivityInitParamHeterogeneous(size_t paramIndex) const;
+
+ //! Should the connectivity initialization derived parameter be implemented heterogeneously for EGP init?
+ bool isConnectivityInitDerivedParamHeterogeneous(size_t paramIndex) const;
+};
+
+// ----------------------------------------------------------------------------
+// SynapseConnectivityInitGroupMerged
//----------------------------------------------------------------------------
-class GENN_EXPORT SynapseGroupMerged : public GroupMerged<SynapseGroupInternal>
+class GENN_EXPORT SynapseConnectivityInitGroupMerged : public GroupMerged<SynapseGroupInternal>
{
public:
- SynapseGroupMerged(size_t index, const std::vector<std::reference_wrapper<const SynapseGroupInternal>> &groups)
+ SynapseConnectivityInitGroupMerged(size_t index, const std::vector<std::reference_wrapper<const SynapseGroupInternal>> &groups)
: GroupMerged(index, groups)
{}
+ //------------------------------------------------------------------------
+ // Public API
+ //------------------------------------------------------------------------
+ //! Should the connectivity initialization parameter be implemented heterogeneously?
+ bool isConnectivityInitParamHeterogeneous(size_t paramIndex) const;
+
+ //! Should the connectivity initialization derived parameter be implemented heterogeneously?
+ bool isConnectivityInitDerivedParamHeterogeneous(size_t paramIndex) const;
+
+ void generate(const BackendBase &backend, CodeStream &definitionsInternal,
+ CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar,
+ CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc,
+ MergedStructData &mergedStructData, const std::string &precision) const;
+};
+
+//----------------------------------------------------------------------------
+// CodeGenerator::SynapseGroupMergedBase
+//----------------------------------------------------------------------------
+class GENN_EXPORT SynapseGroupMergedBase : public GroupMerged<SynapseGroupInternal>
+{
+public:
//------------------------------------------------------------------------
// Public API
//------------------------------------------------------------------------
@@ -92,5 +553,223 @@ class GENN_EXPORT SynapseGroupMerged : public GroupMerged<SynapseGroupInternal>
std::string getDendriticDelayOffset(const std::string &offset = "") const;
+ //! Should the weight update model parameter be implemented heterogeneously?
+ bool isWUParamHeterogeneous(size_t paramIndex) const;
+
+ //! Should the weight update model derived parameter be implemented heterogeneously?
+ bool isWUDerivedParamHeterogeneous(size_t paramIndex) const;
+
+ //! Should the GLOBALG weight update model variable be implemented heterogeneously?
+ bool isWUGlobalVarHeterogeneous(size_t varIndex) const;
+
+ //! Should the weight update model variable initialization parameter be implemented heterogeneously?
+ bool isWUVarInitParamHeterogeneous(size_t varIndex, size_t paramIndex) const;
+
+ //! Should the weight update model variable initialization derived parameter be implemented heterogeneously?
+ bool isWUVarInitDerivedParamHeterogeneous(size_t varIndex, size_t paramIndex) const;
+
+ //! Should the connectivity initialization parameter be implemented heterogeneously?
+ bool isConnectivityInitParamHeterogeneous(size_t paramIndex) const;
+
+ //! Should the connectivity initialization derived parameter be implemented heterogeneously?
+ bool isConnectivityInitDerivedParamHeterogeneous(size_t paramIndex) const;
+
+ //! Is presynaptic neuron parameter heterogeneous?
+ bool isSrcNeuronParamHeterogeneous(size_t paramIndex) const;
+
+ //! Is presynaptic neuron derived parameter heterogeneous?
+ bool isSrcNeuronDerivedParamHeterogeneous(size_t paramIndex) const;
+
+ //! Is postsynaptic neuron parameter heterogeneous?
+ bool isTrgNeuronParamHeterogeneous(size_t paramIndex) const;
+
+ //! Is postsynaptic neuron derived parameter heterogeneous?
+ bool isTrgNeuronDerivedParamHeterogeneous(size_t paramIndex) const;
+
+protected:
+ //----------------------------------------------------------------------------
+ // Enumerations
+ //----------------------------------------------------------------------------
+ enum class Role
+ {
+ PresynapticUpdate,
+ PostsynapticUpdate,
+ SynapseDynamics,
+ DenseInit,
+ SparseInit,
+ };
+
+ SynapseGroupMergedBase(size_t index, const std::vector<std::reference_wrapper<const SynapseGroupInternal>> &groups)
+ : GroupMerged(index, groups)
+ {}
+
+ //----------------------------------------------------------------------------
+ // Declared virtuals
+ //----------------------------------------------------------------------------
+ virtual std::string getArchetypeCode() const { return ""; }
+
+ //----------------------------------------------------------------------------
+ // Protected methods
+ //----------------------------------------------------------------------------
+ void generate(const BackendBase &backend, CodeStream &definitionsInternal,
+ CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar,
+ CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc,
+ MergedStructData &mergedStructData, const std::string &precision,
+ const std::string &timePrecision, const std::string &name, Role role) const;
+private:
+ //------------------------------------------------------------------------
+ // Private methods
+ //------------------------------------------------------------------------
+ void addPSPointerField(MergedStructGenerator<SynapseGroupMergedBase> &gen,
+                        const std::string &type, const std::string &name, const std::string &prefix) const;
+ void addSrcPointerField(MergedStructGenerator<SynapseGroupMergedBase> &gen,
+                         const std::string &type, const std::string &name, const std::string &prefix) const;
+ void addTrgPointerField(MergedStructGenerator<SynapseGroupMergedBase> &gen,
+                         const std::string &type, const std::string &name, const std::string &prefix) const;
+ void addWeightSharingPointerField(MergedStructGenerator<SynapseGroupMergedBase> &gen,
+                                   const std::string &type, const std::string &name, const std::string &prefix) const;
+};
+
+//----------------------------------------------------------------------------
+// CodeGenerator::PresynapticUpdateGroupMerged
+//----------------------------------------------------------------------------
+class GENN_EXPORT PresynapticUpdateGroupMerged : public SynapseGroupMergedBase
+{
+public:
+ PresynapticUpdateGroupMerged(size_t index, const std::vector<std::reference_wrapper<const SynapseGroupInternal>> &groups)
+ : SynapseGroupMergedBase(index, groups)
+ {}
+
+ void generate(const BackendBase &backend, CodeStream &definitionsInternal,
+ CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar,
+ CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc,
+ MergedStructData &mergedStructData, const std::string &precision,
+ const std::string &timePrecision) const
+ {
+ SynapseGroupMergedBase::generate(backend, definitionsInternal, definitionsInternalFunc,
+ definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc,
+ mergedStructData, precision, timePrecision,
+ "PresynapticUpdate", SynapseGroupMergedBase::Role::PresynapticUpdate);
+ }
+
+protected:
+ //----------------------------------------------------------------------------
+ // SynapseGroupMergedBase virtuals
+ //----------------------------------------------------------------------------
+ virtual std::string getArchetypeCode() const override
+ {
+ // **NOTE** we concatenate sim code, event code and threshold code so all get tested
+ return getArchetype().getWUModel()->getSimCode() + getArchetype().getWUModel()->getEventCode() + getArchetype().getWUModel()->getEventThresholdConditionCode();
+ }
+};
+
+//----------------------------------------------------------------------------
+// CodeGenerator::PostsynapticUpdateGroupMerged
+//----------------------------------------------------------------------------
+class GENN_EXPORT PostsynapticUpdateGroupMerged : public SynapseGroupMergedBase
+{
+public:
+ PostsynapticUpdateGroupMerged(size_t index, const std::vector<std::reference_wrapper<const SynapseGroupInternal>> &groups)
+ : SynapseGroupMergedBase(index, groups)
+ {}
+
+ void generate(const BackendBase &backend, CodeStream &definitionsInternal,
+ CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar,
+ CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc,
+ MergedStructData &mergedStructData, const std::string &precision,
+ const std::string &timePrecision) const
+ {
+ SynapseGroupMergedBase::generate(backend, definitionsInternal, definitionsInternalFunc,
+ definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc,
+ mergedStructData, precision, timePrecision,
+ "PostsynapticUpdate", SynapseGroupMergedBase::Role::PostsynapticUpdate);
+ }
+
+protected:
+ //----------------------------------------------------------------------------
+ // SynapseGroupMergedBase virtuals
+ //----------------------------------------------------------------------------
+ virtual std::string getArchetypeCode() const override
+ {
+ return getArchetype().getWUModel()->getLearnPostCode();
+ }
+};
+
+//----------------------------------------------------------------------------
+// CodeGenerator::SynapseDynamicsGroupMerged
+//----------------------------------------------------------------------------
+class GENN_EXPORT SynapseDynamicsGroupMerged : public SynapseGroupMergedBase
+{
+public:
+ SynapseDynamicsGroupMerged(size_t index, const std::vector<std::reference_wrapper<const SynapseGroupInternal>> &groups)
+ : SynapseGroupMergedBase(index, groups)
+ {}
+
+ void generate(const BackendBase &backend, CodeStream &definitionsInternal,
+ CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar,
+ CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc,
+ MergedStructData &mergedStructData, const std::string &precision,
+ const std::string &timePrecision) const
+ {
+ SynapseGroupMergedBase::generate(backend, definitionsInternal, definitionsInternalFunc,
+ definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc,
+ mergedStructData, precision, timePrecision,
+ "SynapseDynamics", SynapseGroupMergedBase::Role::SynapseDynamics);
+ }
+
+protected:
+ //----------------------------------------------------------------------------
+ // SynapseGroupMergedBase virtuals
+ //----------------------------------------------------------------------------
+ virtual std::string getArchetypeCode() const override
+ {
+ return getArchetype().getWUModel()->getSynapseDynamicsCode();
+ }
+};
+
+//----------------------------------------------------------------------------
+// CodeGenerator::SynapseDenseInitGroupMerged
+//----------------------------------------------------------------------------
+class GENN_EXPORT SynapseDenseInitGroupMerged : public SynapseGroupMergedBase
+{
+public:
+ SynapseDenseInitGroupMerged(size_t index, const std::vector<std::reference_wrapper<const SynapseGroupInternal>> &groups)
+ : SynapseGroupMergedBase(index, groups)
+ {}
+
+ void generate(const BackendBase &backend, CodeStream &definitionsInternal,
+ CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar,
+ CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc,
+ MergedStructData &mergedStructData, const std::string &precision,
+ const std::string &timePrecision) const
+ {
+ SynapseGroupMergedBase::generate(backend, definitionsInternal, definitionsInternalFunc,
+ definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc,
+ mergedStructData, precision, timePrecision,
+ "SynapseDenseInit", SynapseGroupMergedBase::Role::DenseInit);
+ }
+};
+
+//----------------------------------------------------------------------------
+// CodeGenerator::SynapseSparseInitGroupMerged
+//----------------------------------------------------------------------------
+class GENN_EXPORT SynapseSparseInitGroupMerged : public SynapseGroupMergedBase
+{
+public:
+ SynapseSparseInitGroupMerged(size_t index, const std::vector<std::reference_wrapper<const SynapseGroupInternal>> &groups)
+ : SynapseGroupMergedBase(index, groups)
+ {}
+
+ void generate(const BackendBase &backend, CodeStream &definitionsInternal,
+ CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar,
+ CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc,
+ MergedStructData &mergedStructData, const std::string &precision,
+ const std::string &timePrecision) const
+ {
+ SynapseGroupMergedBase::generate(backend, definitionsInternal, definitionsInternalFunc,
+ definitionsInternalVar, runnerVarDecl, runnerMergedStructAlloc,
+ mergedStructData, precision, timePrecision,
+ "SynapseSparseInit", SynapseGroupMergedBase::Role::SparseInit);
+ }
};
} // namespace CodeGenerator
diff --git a/include/genn/genn/code_generator/mergedStructGenerator.h b/include/genn/genn/code_generator/mergedStructGenerator.h
index b756bcc39f..4d75457833 100644
--- a/include/genn/genn/code_generator/mergedStructGenerator.h
+++ b/include/genn/genn/code_generator/mergedStructGenerator.h
@@ -9,6 +9,7 @@
#include "gennUtils.h"
// GeNN code generator includes
+#include "code_generator/codeGenUtils.h"
#include "code_generator/codeStream.h"
#include "code_generator/groupMerged.h"
@@ -35,8 +36,8 @@ class MergedStructGenerator
// Typedefines
//------------------------------------------------------------------------
typedef std::function<std::string(const typename T::GroupInternal &, size_t)> GetFieldValueFunc;
-
- MergedStructGenerator(const T &mergedGroup) : m_MergedGroup(mergedGroup)
+
+ MergedStructGenerator(const T &mergedGroup, const std::string &precision) : m_MergedGroup(mergedGroup), m_LiteralSuffix((precision == "float") ? "f" : "")
{
}
@@ -48,78 +49,205 @@ class MergedStructGenerator
m_Fields.emplace_back(type, name, getFieldValue, fieldType);
}
+ void addScalarField(const std::string &name, GetFieldValueFunc getFieldValue, FieldType fieldType = FieldType::Standard)
+ {
+ addField("scalar", name,
+ [getFieldValue, this](const typename T::GroupInternal &g, size_t i)
+ {
+ return getFieldValue(g, i) + m_LiteralSuffix;
+ },
+ fieldType);
+ }
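+
+ // For example, with model precision "float" a field value of "0.25" is emitted into the
+ // merged struct initialiser as "0.25f", so scalar literals always match the precision the
+ // rest of the generated code is compiled with.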
+
void addPointerField(const std::string &type, const std::string &name, const std::string &prefix)
{
assert(!Utils::isTypePointer(type));
addField(type + "*", name, [prefix](const typename T::GroupInternal &g, size_t){ return prefix + g.getName(); });
}
- void addVars(const std::vector<Models::Base::Var> &vars, const std::string &prefix)
+
+ void addVars(const Models::Base::VarVec &vars, const std::string &arrayPrefix)
{
+ // Loop through variables
for(const auto &v : vars) {
- addPointerField(v.type, v.name, prefix + v.name);
+ addPointerField(v.type, v.name, arrayPrefix + v.name);
+
}
}
- void addEGPs(const std::vector<Snippet::Base::EGP> &egps)
+ void addEGPs(const Snippet::Base::EGPVec &egps, const std::string &arrayPrefix, const std::string &varName = "")
{
for(const auto &e : egps) {
- addField(e.type, e.name,
- [e](const typename T::GroupInternal &g, size_t){ return e.name + g.getName(); },
- Utils::isTypePointer(e.type) ? FieldType::PointerEGP : FieldType::ScalarEGP);
+ const bool isPointer = Utils::isTypePointer(e.type);
+ const std::string prefix = isPointer ? arrayPrefix : "";
+ addField(e.type, e.name + varName,
+ [e, prefix, varName](const typename T::GroupInternal &g, size_t){ return prefix + e.name + varName + g.getName(); },
+ isPointer ? FieldType::PointerEGP : FieldType::ScalarEGP);
+ }
+ }
+
+ template<typename G, typename H>
+ void addHeterogeneousParams(const Snippet::Base::StringVec &paramNames,
+ G getParamValues, H isHeterogeneous)
+ {
+ // Loop through params
+ for(size_t p = 0; p < paramNames.size(); p++) {
+ // If parameter is heterogeneous
+ if((getMergedGroup().*isHeterogeneous)(p)) {
+ // Add field
+ addScalarField(paramNames[p],
+ [p, getParamValues](const typename T::GroupInternal &g, size_t)
+ {
+ const auto &values = getParamValues(g);
+ return Utils::writePreciseString(values.at(p));
+ });
+ }
+ }
+ }
+
+ template<typename G, typename H>
+ void addHeterogeneousDerivedParams(const Snippet::Base::DerivedParamVec &derivedParams,
+ G getDerivedParamValues, H isHeterogeneous)
+ {
+ // Loop through derived params
+ for(size_t p = 0; p < derivedParams.size(); p++) {
+ // If derived parameter is heterogeneous
+ if((getMergedGroup().*isHeterogeneous)(p)) {
+ // Add field
+ addScalarField(derivedParams[p].name,
+ [p, getDerivedParamValues](const typename T::GroupInternal &g, size_t)
+ {
+ const auto &values = getDerivedParamValues(g);
+ return Utils::writePreciseString(values.at(p));
+ });
+ }
}
}
- void generate(CodeGenerator::CodeStream &definitionsInternal, CodeGenerator::CodeStream &definitionsInternalFunc,
- CodeGenerator::CodeStream &runnerVarAlloc, CodeGenerator::MergedEGPMap &mergedEGPs, const std::string &name)
+ template<typename V, typename H>
+ void addHeterogeneousVarInitParams(const Models::Base::VarVec &vars, V getVarInitialisers, H isHeterogeneous)
{
- const size_t index = getMergedGroup().getIndex();
+ // Loop through the archetype's variable initialisers
+ const std::vector<Models::VarInit> &archetypeVarInitialisers = (getMergedGroup().getArchetype().*getVarInitialisers)();
+ for(size_t v = 0; v < archetypeVarInitialisers.size(); v++) {
+ // Loop through parameters
+ const Models::VarInit &varInit = archetypeVarInitialisers[v];
+ for(size_t p = 0; p < varInit.getParams().size(); p++) {
+ if((getMergedGroup().*isHeterogeneous)(v, p)) {
+ addScalarField(varInit.getSnippet()->getParamNames()[p] + vars[v].name,
+ [p, v, getVarInitialisers](const typename T::GroupInternal &g, size_t)
+ {
+ const auto &values = (g.*getVarInitialisers)()[v].getParams();
+ return Utils::writePreciseString(values.at(p));
+ });
+ }
+ }
+ }
+ }
+
+ template<typename V, typename H>
+ void addHeterogeneousVarInitDerivedParams(const Models::Base::VarVec &vars, V getVarInitialisers, H isHeterogeneous)
+ {
+ // Loop through the archetype's variable initialisers
+ const std::vector<Models::VarInit> &archetypeVarInitialisers = (getMergedGroup().getArchetype().*getVarInitialisers)();
+ for(size_t v = 0; v < archetypeVarInitialisers.size(); v++) {
+ // Loop through parameters
+ const Models::VarInit &varInit = archetypeVarInitialisers[v];
+ for(size_t d = 0; d < varInit.getDerivedParams().size(); d++) {
+ if((getMergedGroup().*isHeterogeneous)(v, d)) {
+ addScalarField(varInit.getSnippet()->getDerivedParams()[d].name + vars[v].name,
+ [d, v, getVarInitialisers](const typename T::GroupInternal &g, size_t)
+ {
+ const auto &values = (g.*getVarInitialisers)()[v].getDerivedParams();
+ return Utils::writePreciseString(values.at(d));
+ });
+ }
+ }
+ }
+ }
+
+ void generate(const BackendBase &backend, CodeStream &definitionsInternal,
+ CodeStream &definitionsInternalFunc, CodeStream &definitionsInternalVar,
+ CodeStream &runnerVarDecl, CodeStream &runnerMergedStructAlloc,
+ MergedStructData &mergedStructData, const std::string &name, bool host = false) const
+ {
+ const size_t mergedGroupIndex = getMergedGroup().getIndex();
+
+ // Make a copy of the fields and sort them so the largest come first. Under standard
+ // structure packing rules this should save significant memory and make the size estimate more precise
+ auto sortedFields = m_Fields;
+ std::sort(sortedFields.begin(), sortedFields.end(),
+ [&backend](const Field &a, const Field &b)
+ {
+ return (backend.getSize(std::get<0>(a)) > backend.getSize(std::get<0>(b)));
+ });
// Write struct declaration to top of definitions internal
- definitionsInternal << "struct Merged" << name << "Group" << index << std::endl;
+ size_t structSize = 0;
+ size_t largestFieldSize = 0;
+ definitionsInternal << "struct Merged" << name << "Group" << mergedGroupIndex << std::endl;
{
- CodeGenerator::CodeStream::Scope b(definitionsInternal);
- for(const auto &f : m_Fields) {
+ CodeStream::Scope b(definitionsInternal);
+ for(const auto &f : sortedFields) {
+ // Add field to structure
definitionsInternal << std::get<0>(f) << " " << std::get<1>(f) << ";" << std::endl;
+
+ // Add size of field to total
+ const size_t fieldSize = backend.getSize(std::get<0>(f));
+ structSize += fieldSize;
+
+ // Update largest field size
+ largestFieldSize = std::max(fieldSize, largestFieldSize);
+
+ // If this field is for a pointer EGP, also declare function to push it
+ if(std::get<3>(f) == FieldType::PointerEGP) {
+ definitionsInternalFunc << "EXPORT_FUNC void pushMerged" << name << mergedGroupIndex << std::get<1>(f) << "ToDevice(unsigned int idx, " << std::get<0>(f) << " value);" << std::endl;
+ }
}
definitionsInternal << std::endl;
}
definitionsInternal << ";" << std::endl;
- // Write local array of these structs containing individual neuron group pointers etc
- // **NOTE** scope will hopefully reduce stack usage
- {
- CodeStream::Scope b(runnerVarAlloc);
- runnerVarAlloc << "Merged" << name << "Group" << index << " merged" << name << "Group" << index << "[] = ";
- {
- CodeGenerator::CodeStream::Scope b(runnerVarAlloc);
- for(size_t i = 0; i < getMergedGroup().getGroups().size(); i++) {
- const auto &g = getMergedGroup().getGroups()[i];
-
- // Add all fields to merged group array
- runnerVarAlloc << "{";
- for(const auto &f : m_Fields) {
- const std::string fieldInitVal = std::get<2>(f)(g, i);
- runnerVarAlloc << fieldInitVal << ", ";
-
- // If field is an EGP, add record to merged EGPS
- if(std::get<3>(f) != FieldType::Standard) {
- mergedEGPs[fieldInitVal].emplace(
- std::piecewise_construct, std::forward_as_tuple(name),
- std::forward_as_tuple(index, i, (std::get<3>(f) == FieldType::PointerEGP), std::get<1>(f)));
- }
- }
- runnerVarAlloc << "}," << std::endl;
+ // Add total size of array of merged structures to merged struct data
+ // **NOTE** to match standard struct packing rules we pad to a multiple of the largest field size
+ const size_t arraySize = padSize(structSize, largestFieldSize) * getMergedGroup().getGroups().size();
+ mergedStructData.addMergedGroupSize(name, mergedGroupIndex, arraySize);
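+
+ // For example, on a typical 64-bit backend a struct holding a double* (8 bytes), a float
+ // (4 bytes) and an unsigned int (4 bytes) occupies 16 bytes - already a multiple of the
+ // largest field size (8) - whereas adding one more float would pad each 20-byte element
+ // up to 24 bytes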
+
+ // Declare array of these structs containing individual neuron group pointers etc
+ runnerVarDecl << "Merged" << name << "Group" << mergedGroupIndex << " merged" << name << "Group" << mergedGroupIndex << "[" << getMergedGroup().getGroups().size() << "];" << std::endl;
+
+ for(size_t groupIndex = 0; groupIndex < getMergedGroup().getGroups().size(); groupIndex++) {
+ const auto &g = getMergedGroup().getGroups()[groupIndex];
+
+ // Set all fields in array of structs
+ runnerMergedStructAlloc << "merged" << name << "Group" << mergedGroupIndex << "[" << groupIndex << "] = {";
+ for(const auto &f : sortedFields) {
+ const std::string fieldInitVal = std::get<2>(f)(g, groupIndex);
+ runnerMergedStructAlloc << fieldInitVal << ", ";
+
+ // If field is an EGP, add record to merged EGPS
+ if(std::get<3>(f) != FieldType::Standard) {
+ mergedStructData.addMergedEGP(fieldInitVal, name, mergedGroupIndex, groupIndex,
+ std::get<0>(f), std::get<1>(f));
}
}
- runnerVarAlloc << ";" << std::endl;
+ runnerMergedStructAlloc << "};" << std::endl;
+ }
+
+ // If this is a host merged struct, export the variable
+ if(host) {
+ definitionsInternalVar << "EXPORT_VAR Merged" << name << "Group" << mergedGroupIndex << " merged" << name << "Group" << mergedGroupIndex << "[" << getMergedGroup().getGroups().size() << "]; " << std::endl;
+ }
+ // Otherwise
+ else {
// Then generate call to function to copy local array to device
- runnerVarAlloc << "pushMerged" << name << "Group" << index << "ToDevice(merged" << name << "Group" << index << ");" << std::endl;
+ runnerMergedStructAlloc << "pushMerged" << name << "Group" << mergedGroupIndex << "ToDevice(merged" << name << "Group" << mergedGroupIndex << ");" << std::endl;
+
+ // Finally add declaration to function to definitions internal
+ definitionsInternalFunc << "EXPORT_FUNC void pushMerged" << name << "Group" << mergedGroupIndex << "ToDevice(const Merged" << name << "Group" << mergedGroupIndex << " *group);" << std::endl;
}
- // Finally add declaration to function to definitions internal
- definitionsInternalFunc << "EXPORT_FUNC void pushMerged" << name << "Group" << index << "ToDevice(const Merged" << name << "Group" << index << " *group);" << std::endl;
}
protected:
@@ -130,106 +258,15 @@ class MergedStructGenerator
private:
//------------------------------------------------------------------------
- // Members
- //------------------------------------------------------------------------
- const T &m_MergedGroup;
- std::vector<std::tuple<std::string, std::string, GetFieldValueFunc, FieldType>> m_Fields;
-};
-
-//--------------------------------------------------------------------------
-// CodeGenerator::MergedNeuronStructGenerator
-//--------------------------------------------------------------------------
-class MergedNeuronStructGenerator : public MergedStructGenerator<CodeGenerator::NeuronGroupMerged>
-{
-public:
- MergedNeuronStructGenerator(const CodeGenerator::NeuronGroupMerged &mergedGroup)
- : MergedStructGenerator(mergedGroup)
- {
- }
-
- //------------------------------------------------------------------------
- // Public API
+ // Typedefines
//------------------------------------------------------------------------
- void addMergedInSynPointerField(const std::string &type, const std::string &name, size_t archetypeIndex, const std::string &prefix,
-                                   const std::vector<std::vector<std::pair<SynapseGroupInternal *, std::vector<SynapseGroupInternal *>>>> &sortedMergedInSyns)
- {
- assert(!Utils::isTypePointer(type));
- addField(type + "*", name + std::to_string(archetypeIndex),
- [prefix, &sortedMergedInSyns, archetypeIndex](const NeuronGroupInternal&, size_t groupIndex)
- {
- return prefix + sortedMergedInSyns.at(groupIndex).at(archetypeIndex).first->getPSModelTargetName();
- });
- }
-
- void addCurrentSourcePointerField(const std::string &type, const std::string &name, size_t archetypeIndex, const std::string &prefix,
-                                      const std::vector<std::vector<CurrentSourceInternal *>> &sortedCurrentSources)
- {
- assert(!Utils::isTypePointer(type));
- addField(type + "*", name + std::to_string(archetypeIndex),
- [prefix, &sortedCurrentSources, archetypeIndex](const NeuronGroupInternal&, size_t groupIndex)
- {
- return prefix + sortedCurrentSources.at(groupIndex).at(archetypeIndex)->getName();
- });
- }
-
- void addSynPointerField(const std::string &type, const std::string &name, size_t archetypeIndex, const std::string &prefix,
-                            const std::vector<std::vector<SynapseGroupInternal *>> &sortedSyn)
- {
- assert(!Utils::isTypePointer(type));
- addField(type + "*", name + std::to_string(archetypeIndex),
- [prefix, &sortedSyn, archetypeIndex](const NeuronGroupInternal&, size_t groupIndex)
- {
- return prefix + sortedSyn.at(groupIndex).at(archetypeIndex)->getName();
- });
-
- }
-};
-//--------------------------------------------------------------------------
-// CodeGenerator::MergedSynapseStructGenerator
-//--------------------------------------------------------------------------
-class MergedSynapseStructGenerator : public MergedStructGenerator<CodeGenerator::SynapseGroupMerged>
-{
-public:
- MergedSynapseStructGenerator(const CodeGenerator::SynapseGroupMerged &mergedGroup)
- : MergedStructGenerator(mergedGroup)
- {
- }
+ typedef std::tuple<std::string, std::string, GetFieldValueFunc, FieldType> Field;
//------------------------------------------------------------------------
- // Public API
+ // Members
//------------------------------------------------------------------------
- void addPSPointerField(const std::string &type, const std::string &name, const std::string &prefix)
- {
- assert(!Utils::isTypePointer(type));
- addField(type + "*", name, [prefix](const SynapseGroupInternal &sg, size_t){ return prefix + sg.getPSModelTargetName(); });
- }
-
- void addSrcPointerField(const std::string &type, const std::string &name, const std::string &prefix)
- {
- assert(!Utils::isTypePointer(type));
- addField(type + "*", name, [prefix](const SynapseGroupInternal &sg, size_t){ return prefix + sg.getSrcNeuronGroup()->getName(); });
- }
-
- void addTrgPointerField(const std::string &type, const std::string &name, const std::string &prefix)
- {
- assert(!Utils::isTypePointer(type));
- addField(type + "*", name, [prefix](const SynapseGroupInternal &sg, size_t){ return prefix + sg.getTrgNeuronGroup()->getName(); });
- }
-
- void addSrcEGPField(const Snippet::Base::EGP &egp)
- {
- addField(egp.type, egp.name + "Pre",
- [egp](const SynapseGroupInternal &sg, size_t){ return egp.name + sg.getSrcNeuronGroup()->getName(); },
- Utils::isTypePointer(egp.type) ? FieldType::PointerEGP : FieldType::ScalarEGP);
- }
-
- void addTrgEGPField(const Snippet::Base::EGP &egp)
- {
- addField(egp.type, egp.name + "Post",
- [egp](const SynapseGroupInternal &sg, size_t){ return egp.name + sg.getTrgNeuronGroup()->getName(); },
- Utils::isTypePointer(egp.type) ? FieldType::PointerEGP : FieldType::ScalarEGP);
- }
-
+ const T &m_MergedGroup;
+ const std::string m_LiteralSuffix;
+ std::vector<Field> m_Fields;
};
-
} // namespace CodeGenerator
diff --git a/include/genn/genn/code_generator/modelSpecMerged.h b/include/genn/genn/code_generator/modelSpecMerged.h
index 89f5ac1a05..2c9519fb60 100644
--- a/include/genn/genn/code_generator/modelSpecMerged.h
+++ b/include/genn/genn/code_generator/modelSpecMerged.h
@@ -33,34 +33,37 @@ class ModelSpecMerged
const ModelSpecInternal &getModel() const{ return m_Model; }
//! Get merged neuron groups which require updating
- const std::vector<NeuronGroupMerged> &getMergedNeuronUpdateGroups() const{ return m_MergedNeuronUpdateGroups; }
+ const std::vector<NeuronUpdateGroupMerged> &getMergedNeuronUpdateGroups() const{ return m_MergedNeuronUpdateGroups; }
//! Get merged synapse groups which require presynaptic updates
- const std::vector<SynapseGroupMerged> &getMergedPresynapticUpdateGroups() const{ return m_MergedPresynapticUpdateGroups; }
+ const std::vector<PresynapticUpdateGroupMerged> &getMergedPresynapticUpdateGroups() const{ return m_MergedPresynapticUpdateGroups; }
//! Get merged synapse groups which require postsynaptic updates
- const std::vector<SynapseGroupMerged> &getMergedPostsynapticUpdateGroups() const{ return m_MergedPostsynapticUpdateGroups; }
+ const std::vector<PostsynapticUpdateGroupMerged> &getMergedPostsynapticUpdateGroups() const{ return m_MergedPostsynapticUpdateGroups; }
//! Get merged synapse groups which require synapse dynamics
- const std::vector<SynapseGroupMerged> &getMergedSynapseDynamicsGroups() const{ return m_MergedSynapseDynamicsGroups; }
+ const std::vector<SynapseDynamicsGroupMerged> &getMergedSynapseDynamicsGroups() const{ return m_MergedSynapseDynamicsGroups; }
//! Get merged neuron groups which require initialisation
- const std::vector<NeuronGroupMerged> &getMergedNeuronInitGroups() const{ return m_MergedNeuronInitGroups; }
+ const std::vector<NeuronInitGroupMerged> &getMergedNeuronInitGroups() const{ return m_MergedNeuronInitGroups; }
//! Get merged synapse groups with dense connectivity which require initialisation
- const std::vector<SynapseGroupMerged> &getMergedSynapseDenseInitGroups() const{ return m_MergedSynapseDenseInitGroups; }
+ const std::vector<SynapseDenseInitGroupMerged> &getMergedSynapseDenseInitGroups() const{ return m_MergedSynapseDenseInitGroups; }
//! Get merged synapse groups which require connectivity initialisation
- const std::vector<SynapseGroupMerged> &getMergedSynapseConnectivityInitGroups() const{ return m_MergedSynapseConnectivityInitGroups; }
+ const std::vector<SynapseConnectivityInitGroupMerged> &getMergedSynapseConnectivityInitGroups() const{ return m_MergedSynapseConnectivityInitGroups; }
//! Get merged synapse groups with sparse connectivity which require initialisation
- const std::vector<SynapseGroupMerged> &getMergedSynapseSparseInitGroups() const{ return m_MergedSynapseSparseInitGroups; }
+ const std::vector<SynapseSparseInitGroupMerged> &getMergedSynapseSparseInitGroups() const{ return m_MergedSynapseSparseInitGroups; }
//! Get merged neuron groups which require their spike queues updating
- const std::vector<NeuronGroupMerged> &getMergedNeuronSpikeQueueUpdateGroups() const { return m_MergedNeuronSpikeQueueUpdateGroups; }
+ const std::vector<NeuronSpikeQueueUpdateGroupMerged> &getMergedNeuronSpikeQueueUpdateGroups() const { return m_MergedNeuronSpikeQueueUpdateGroups; }
//! Get merged synapse groups which require their dendritic delay updating
- const std::vector<SynapseGroupMerged> &getMergedSynapseDendriticDelayUpdateGroups() const { return m_MergedSynapseDendriticDelayUpdateGroups; }
+ const std::vector<SynapseDendriticDelayUpdateGroupMerged> &getMergedSynapseDendriticDelayUpdateGroups() const { return m_MergedSynapseDendriticDelayUpdateGroups; }
+
+ //! Get merged synapse groups which require host code to initialise their synaptic connectivity
+ const std::vector<SynapseConnectivityHostInitGroupMerged> &getMergedSynapseConnectivityHostInitGroups() const{ return m_MergedSynapseConnectivityHostInitGroups; }
void genNeuronUpdateGroupSupportCode(CodeStream &os) const{ m_NeuronUpdateSupportCode.gen(os, getModel().getPrecision()); }
@@ -90,34 +93,37 @@ class ModelSpecMerged
const ModelSpecInternal &m_Model;
//! Merged neuron groups which require updating
- std::vector<NeuronGroupMerged> m_MergedNeuronUpdateGroups;
+ std::vector<NeuronUpdateGroupMerged> m_MergedNeuronUpdateGroups;
//! Merged synapse groups which require presynaptic updates
- std::vector<SynapseGroupMerged> m_MergedPresynapticUpdateGroups;
+ std::vector<PresynapticUpdateGroupMerged> m_MergedPresynapticUpdateGroups;
//! Merged synapse groups which require postsynaptic updates
- std::vector<SynapseGroupMerged> m_MergedPostsynapticUpdateGroups;
+ std::vector<PostsynapticUpdateGroupMerged> m_MergedPostsynapticUpdateGroups;
//! Merged synapse groups which require synapse dynamics update
- std::vector<SynapseGroupMerged> m_MergedSynapseDynamicsGroups;
+ std::vector<SynapseDynamicsGroupMerged> m_MergedSynapseDynamicsGroups;
//! Merged neuron groups which require initialisation
- std::vector<NeuronGroupMerged> m_MergedNeuronInitGroups;
+ std::vector<NeuronInitGroupMerged> m_MergedNeuronInitGroups;
//! Merged synapse groups with dense connectivity which require initialisation
- std::vector<SynapseGroupMerged> m_MergedSynapseDenseInitGroups;
+ std::vector<SynapseDenseInitGroupMerged> m_MergedSynapseDenseInitGroups;
//! Merged synapse groups which require connectivity initialisation
- std::vector<SynapseGroupMerged> m_MergedSynapseConnectivityInitGroups;
+ std::vector<SynapseConnectivityInitGroupMerged> m_MergedSynapseConnectivityInitGroups;
//! Merged synapse groups with sparse connectivity which require initialisation
- std::vector<SynapseGroupMerged> m_MergedSynapseSparseInitGroups;
+ std::vector<SynapseSparseInitGroupMerged> m_MergedSynapseSparseInitGroups;
//! Merged neuron groups which require their spike queues updating
- std::vector<NeuronGroupMerged> m_MergedNeuronSpikeQueueUpdateGroups;
+ std::vector<NeuronSpikeQueueUpdateGroupMerged> m_MergedNeuronSpikeQueueUpdateGroups;
//! Merged synapse groups which require their dendritic delay updating
- std::vector<SynapseGroupMerged> m_MergedSynapseDendriticDelayUpdateGroups;
+ std::vector<SynapseDendriticDelayUpdateGroupMerged> m_MergedSynapseDendriticDelayUpdateGroups;
+
+ //! Merged synapse groups which require host code to initialise their synaptic connectivity
+ std::vector<SynapseConnectivityHostInitGroupMerged> m_MergedSynapseConnectivityHostInitGroups;
//! Unique support code strings for neuron update
SupportCodeMerged m_NeuronUpdateSupportCode;
diff --git a/include/genn/genn/code_generator/substitutions.h b/include/genn/genn/code_generator/substitutions.h
index 67f11788d3..e6a8e39295 100644
--- a/include/genn/genn/code_generator/substitutions.h
+++ b/include/genn/genn/code_generator/substitutions.h
@@ -9,22 +9,42 @@
#include
// GeNN includes
-#include "logging.h"
-
-// GeNN code generator includes
-#include "codeGenUtils.h"
-
-// GeNN includes
+#include "gennExport.h"
#include "gennUtils.h"
+#include "logging.h"
//--------------------------------------------------------------------------
// Substitutions
//--------------------------------------------------------------------------
namespace CodeGenerator
{
-class Substitutions
+class GENN_EXPORT Substitutions
{
public:
+ //! Immutable structure for specifying how to implement
+ //! a generic function e.g. gennrand_uniform
+ /*! **NOTE** for the sake of easy initialisation, the first two parameters of GenericFunction are repeated (C++17 fixes this) */
+ struct FunctionTemplate
+ {
+ // **HACK** while GCC and Clang generate this automatically and don't require it, VS2013 seems to need it
+ FunctionTemplate operator = (const FunctionTemplate &o)
+ {
+ return FunctionTemplate{o.genericName, o.numArguments, o.doublePrecisionTemplate, o.singlePrecisionTemplate};
+ }
+
+ //! Generic name used to refer to function in user code
+ const std::string genericName;
+
+ //! Number of function arguments
+ const unsigned int numArguments;
+
+ //! The function template (for use with ::functionSubstitute) used when model uses double precision
+ const std::string doublePrecisionTemplate;
+
+ //! The function template (for use with ::functionSubstitute) used when model uses single precision
+ const std::string singlePrecisionTemplate;
+ };
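+
+ // An illustrative (hypothetical) entry - a backend could describe its uniform RNG as:
+ //
+ //   {"gennrand_uniform", 0, "dist_double($(rng))", "dist_float($(rng))"}
+ //
+ // so that $(gennrand_uniform) in user code is expanded with the template matching the
+ // model's precision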
+
Substitutions(const Substitutions *parent = nullptr) : m_Parent(parent)
{
assert(m_Parent != this);
@@ -52,15 +72,6 @@ class Substitutions
}
}
- void addParamNameSubstitution(const std::vector<std::string> &paramNames, const std::string &sourceSuffix = "",
- const std::string &destPrefix = "", const std::string &destSuffix = "")
- {
- for(const auto &p : paramNames) {
- addVarSubstitution(p + sourceSuffix,
- destPrefix + p + destSuffix);
- }
- }
-
template<typename T>
void addVarValueSubstitution(const std::vector<T> &variables, const std::vector<double> &values,
const std::string &sourceSuffix = "")
@@ -78,71 +89,56 @@ class Substitutions
}
void addParamValueSubstitution(const std::vector<std::string> &paramNames, const std::vector<double> &values,
- const std::string &sourceSuffix = "")
+ const std::string &sourceSuffix = "");
+
+ template<typename G>
+ void addParamValueSubstitution(const std::vector<std::string> &paramNames, const std::vector<double> &values, G isHeterogeneousFn,
+ const std::string &sourceSuffix = "", const std::string &destPrefix = "", const std::string &destSuffix = "")
{
if(paramNames.size() != values.size()) {
throw std::runtime_error("Number of parameters does not match number of values");
}
- auto param = paramNames.cbegin();
- auto val = values.cbegin();
- for (;param != paramNames.cend() && val != values.cend(); param++, val++) {
- addVarSubstitution(*param + sourceSuffix,
- "(" + Utils::writePreciseString(*val) + ")");
+ for(size_t i = 0; i < paramNames.size(); i++) {
+ if(isHeterogeneousFn(i)) {
+ addVarSubstitution(paramNames[i] + sourceSuffix,
+ destPrefix + paramNames[i] + destSuffix);
+ }
+ else {
+ addVarSubstitution(paramNames[i] + sourceSuffix,
+ "(" + Utils::writePreciseString(values[i]) + ")");
+ }
}
-
}
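+
+ // Sketch of the intended effect (caller names illustrative): each heterogeneous parameter
+ // $(name) is redirected to a per-group struct field, while homogeneous ones are still
+ // substituted as literals the compiler can constant-fold:
+ //
+ //   subs.addParamValueSubstitution(nm->getParamNames(), ng.getParams(),
+ //                                  [this](size_t i){ return isParamHeterogeneous(i); },
+ //                                  "", "group->", "");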
- void addVarSubstitution(const std::string &source, const std::string &destination, bool allowOverride = false)
+ template<typename T, typename G>
+ void addVarValueSubstitution(const std::vector<T> &variables, const std::vector<double> &values, G isHeterogeneousFn,
+ const std::string &sourceSuffix = "", const std::string &destPrefix = "", const std::string &destSuffix = "")
{
- auto res = m_VarSubstitutions.emplace(source, destination);
- if(!allowOverride && !res.second) {
- throw std::runtime_error("'" + source + "' already has a variable substitution");
+ if(variables.size() != values.size()) {
+ throw std::runtime_error("Number of variables does not match number of values");
}
- }
- void addFuncSubstitution(const std::string &source, unsigned int numArguments, const std::string &funcTemplate, bool allowOverride = false)
- {
- auto res = m_FuncSubstitutions.emplace(std::piecewise_construct,
- std::forward_as_tuple(source),
- std::forward_as_tuple(numArguments, funcTemplate));
- if(!allowOverride && !res.second) {
- throw std::runtime_error("'" + source + "' already has a function substitution");
+ for(size_t i = 0; i < variables.size(); i++) {
+ if(isHeterogeneousFn(i)) {
+ addVarSubstitution(variables[i].name + sourceSuffix,
+ destPrefix + variables[i].name + destSuffix);
+ }
+ else {
+ addVarSubstitution(variables[i].name + sourceSuffix,
+ "(" + Utils::writePreciseString(values[i]) + ")");
+ }
}
}
- bool hasVarSubstitution(const std::string &source) const
- {
- return (m_VarSubstitutions.find(source) != m_VarSubstitutions.end());
- }
-
- const std::string &getVarSubstitution(const std::string &source) const
- {
- auto var = m_VarSubstitutions.find(source);
- if(var != m_VarSubstitutions.end()) {
- return var->second;
- }
- else if(m_Parent) {
- return m_Parent->getVarSubstitution(source);
- }
- else {
- throw std::runtime_error("Nothing to substitute for '" + source + "'");
- }
- }
+ void addVarSubstitution(const std::string &source, const std::string &destination, bool allowOverride = false);
+ void addFuncSubstitution(const std::string &source, unsigned int numArguments, const std::string &funcTemplate, bool allowOverride = false);
+ bool hasVarSubstitution(const std::string &source) const;
- void apply(std::string &code) const
- {
- // Apply function and variable substitutions
- // **NOTE** functions may contain variables so evaluate ALL functions first
- applyFuncs(code);
- applyVars(code);
- }
+ const std::string &getVarSubstitution(const std::string &source) const;
- void applyCheckUnreplaced(std::string &code, const std::string &context) const
- {
- apply(code);
- checkUnreplacedVariables(code, context);
- }
+ void apply(std::string &code) const;
+ void applyCheckUnreplaced(std::string &code, const std::string &context) const;
//--------------------------------------------------------------------------
// Public API
@@ -156,32 +152,8 @@ class Substitutions
//--------------------------------------------------------------------------
// Private API
//--------------------------------------------------------------------------
- void applyFuncs(std::string &code) const
- {
- // Apply function substitutions
- for(const auto &f : m_FuncSubstitutions) {
- functionSubstitute(code, f.first, f.second.first, f.second.second);
- }
-
- // If we have a parent, apply their function substitutions too
- if(m_Parent) {
- m_Parent->applyFuncs(code);
- }
- }
-
- void applyVars(std::string &code) const
- {
- // Apply variable substitutions
- for(const auto &v : m_VarSubstitutions) {
- LOGD_CODE_GEN << "Substituting '$(" << v.first << ")' for '" << v.second << "'";
- substitute(code, "$(" + v.first + ")", v.second);
- }
-
- // If we have a parent, apply their variable substitutions too
- if(m_Parent) {
- m_Parent->applyVars(code);
- }
- }
+ void applyFuncs(std::string &code) const;
+ void applyVars(std::string &code) const;
//--------------------------------------------------------------------------
// Members
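The apply() ordering being moved out-of-line is unchanged: function substitutions run before variable substitutions because a function template can expand to text that itself still contains $(var) placeholders, and anything unresolved is delegated to the parent Substitutions. A standalone toy of that two-pass idea (plain string replacement, not the real argument-parsing functionSubstitute):

#include <iostream>
#include <string>

// Replace every occurrence of 'from' with 'to'
static void substituteAll(std::string &code, const std::string &from, const std::string &to)
{
    for(size_t p = code.find(from); p != std::string::npos; p = code.find(from, p + to.length())) {
        code.replace(p, from.length(), to);
    }
}

int main()
{
    std::string code = "$(addToInSyn, g);";
    substituteAll(code, "$(addToInSyn, g)", "inSyn[$(id)] += g");  // function pass first...
    substituteAll(code, "$(id)", "i");                             // ...then variables
    std::cout << code << std::endl;                                // prints: inSyn[i] += g;
}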
diff --git a/include/genn/genn/gennUtils.h b/include/genn/genn/gennUtils.h
index 144d32c22d..61ed256734 100644
--- a/include/genn/genn/gennUtils.h
+++ b/include/genn/genn/gennUtils.h
@@ -36,6 +36,11 @@ GENN_EXPORT bool isRNGRequired(const std::vector<Models::VarInit> &varInitialise
//--------------------------------------------------------------------------
GENN_EXPORT bool isTypePointer(const std::string &type);
+//--------------------------------------------------------------------------
+//! \brief Function to determine whether a string containing a type is a pointer to a pointer
+//--------------------------------------------------------------------------
+GENN_EXPORT bool isTypePointerToPointer(const std::string &type);
+
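Presumably this is needed now that extra global parameters such as preCalcRowLength below are pointers that must themselves be allocated, giving pointer-to-pointer types in generated struct members. A sketch of the kind of string test involved — the real implementation may normalise the type string differently:

#include <algorithm>
#include <cassert>
#include <string>

// Assumption: a type string such as "uint16_t**" is pointer-to-pointer
// exactly when it contains two '*' characters
bool isTypePointerToPointerSketch(const std::string &type)
{
    return std::count(type.cbegin(), type.cend(), '*') == 2;
}

int main()
{
    assert(isTypePointerToPointerSketch("uint16_t**"));
    assert(!isTypePointerToPointerSketch("uint16_t*"));
    assert(!isTypePointerToPointerSketch("double"));
    return 0;
}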
//--------------------------------------------------------------------------
//! \brief Assuming type is a string containing a pointer type, function to return the underlying type
//--------------------------------------------------------------------------
@@ -45,7 +50,7 @@ GENN_EXPORT std::string getUnderlyingType(const std::string &type);
//! \brief This function writes a floating point value to a stream - setting the precision so no digits are lost
//--------------------------------------------------------------------------
template<typename T, typename std::enable_if<std::is_floating_point<T>::value>::type * = nullptr>
-GENN_EXPORT void writePreciseString(std::ostream &os, T value)
+void writePreciseString(std::ostream &os, T value)
{
// Cache previous precision
const std::streamsize previousPrecision = os.precision();
@@ -72,7 +77,7 @@ GENN_EXPORT void writePreciseString(std::ostream &os, T value)
//! \brief This function writes a floating point value to a string - setting the precision so no digits are lost
//--------------------------------------------------------------------------
template<typename T, typename std::enable_if<std::is_floating_point<T>::value>::type * = nullptr>
-GENN_EXPORT std::string writePreciseString(T value)
+std::string writePreciseString(T value)
{
std::stringstream s;
writePreciseString(s, value);
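Dropping GENN_EXPORT from these two is the right call: function templates are instantiated in every translation unit, so there is nothing for a DLL to export. The guarantee the functions provide is the usual max_digits10 round-trip, sketched standalone here:

#include <iomanip>
#include <iostream>
#include <limits>
#include <sstream>

int main()
{
    const double value = 1.0 / 3.0;

    // Write with max_digits10 significant digits, as writePreciseString does
    std::stringstream s;
    s << std::setprecision(std::numeric_limits<double>::max_digits10) << value;

    // Reading the text back recovers the bit-identical double
    double readBack;
    s >> readBack;
    std::cout << s.str() << " round-trips exactly: " << std::boolalpha
              << (readBack == value) << std::endl;
}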
diff --git a/include/genn/genn/initSparseConnectivitySnippet.h b/include/genn/genn/initSparseConnectivitySnippet.h
index 2858a0bb3a..1a7b093f7c 100644
--- a/include/genn/genn/initSparseConnectivitySnippet.h
+++ b/include/genn/genn/initSparseConnectivitySnippet.h
@@ -18,14 +18,14 @@
#define SET_ROW_BUILD_CODE(CODE) virtual std::string getRowBuildCode() const override{ return CODE; }
#define SET_ROW_BUILD_STATE_VARS(...) virtual ParamValVec getRowBuildStateVars() const override{ return __VA_ARGS__; }
+#define SET_HOST_INIT_CODE(CODE) virtual std::string getHostInitCode() const override{ return CODE; }
+
#define SET_CALC_MAX_ROW_LENGTH_FUNC(FUNC) virtual CalcMaxLengthFunc getCalcMaxRowLengthFunc() const override{ return FUNC; }
#define SET_CALC_MAX_COL_LENGTH_FUNC(FUNC) virtual CalcMaxLengthFunc getCalcMaxColLengthFunc() const override{ return FUNC; }
#define SET_MAX_ROW_LENGTH(MAX_ROW_LENGTH) virtual CalcMaxLengthFunc getCalcMaxRowLengthFunc() const override{ return [](unsigned int, unsigned int, const std::vector<double> &){ return MAX_ROW_LENGTH; }; }
#define SET_MAX_COL_LENGTH(MAX_COL_LENGTH) virtual CalcMaxLengthFunc getCalcMaxColLengthFunc() const override{ return [](unsigned int, unsigned int, const std::vector<double> &){ return MAX_COL_LENGTH; }; }
-#define SET_EXTRA_GLOBAL_PARAMS(...) virtual EGPVec getExtraGlobalParams() const override{ return __VA_ARGS__; }
-
//----------------------------------------------------------------------------
// InitSparseConnectivitySnippet::Base
//----------------------------------------------------------------------------
@@ -46,25 +46,17 @@ class GENN_EXPORT Base : public Snippet::Base
virtual std::string getRowBuildCode() const{ return ""; }
virtual ParamValVec getRowBuildStateVars() const{ return {}; }
+ virtual std::string getHostInitCode() const{ return ""; }
+
//! Get function to calculate the maximum row length of this connector based on the parameters and the size of the pre and postsynaptic population
virtual CalcMaxLengthFunc getCalcMaxRowLengthFunc() const{ return CalcMaxLengthFunc(); }
//! Get function to calculate the maximum column length of this connector based on the parameters and the size of the pre and postsynaptic population
virtual CalcMaxLengthFunc getCalcMaxColLengthFunc() const{ return CalcMaxLengthFunc(); }
- //! Gets names and types (as strings) of additional
- //! per-population parameters for the connection initialisation snippet
- virtual EGPVec getExtraGlobalParams() const{ return {}; }
-
//------------------------------------------------------------------------
// Public methods
//------------------------------------------------------------------------
- //! Find the index of a named extra global parameter
- size_t getExtraGlobalParamIndex(const std::string &paramName) const
- {
- return getNamedVecIndex(paramName, getExtraGlobalParams());
- }
-
//! Can this connectivity snippet be merged with other? i.e. can they be initialised using the same generated code
bool canBeMerged(const Base *other) const;
};
@@ -79,11 +71,6 @@ class Init : public Snippet::Init<Base>
: Snippet::Init(snippet, params)
{
}
-
- bool canBeMerged(const Init &other) const
- {
- return Snippet::Init::canBeMerged(other, getSnippet()->getRowBuildCode());
- }
};
//----------------------------------------------------------------------------
@@ -287,7 +274,46 @@ class FixedNumberTotalWithReplacement : public Base
SET_ROW_BUILD_STATE_VARS({{"x", "scalar", 0.0},{"c", "unsigned int", "$(preCalcRowLength)[($(id_pre) * $(num_threads)) + $(id_thread)]"}});
SET_PARAM_NAMES({"total"});
- SET_EXTRA_GLOBAL_PARAMS({{"preCalcRowLength", "unsigned int*"}})
+ SET_EXTRA_GLOBAL_PARAMS({{"preCalcRowLength", "uint16_t*"}})
+
+ SET_HOST_INIT_CODE(
+ "// Allocate pre-calculated row length array\n"
+ "$(allocatepreCalcRowLength, $(num_pre) * $(num_threads));\n"
+ "// Calculate row lengths\n"
+ "const size_t numPostPerThread = ($(num_post) + $(num_threads) - 1) / $(num_threads);\n"
+ "const size_t leftOverNeurons = $(num_post) % numPostPerThread;\n"
+ "size_t remainingConnections = $(total);\n"
+ "size_t matrixSize = (size_t)$(num_pre) * (size_t)$(num_post);\n"
+ "uint16_t *subRowLengths = $(preCalcRowLength);\n"
+ "// Loop through rows\n"
+ "for(size_t i = 0; i < $(num_pre); i++) {\n"
+ " const bool lastPre = (i == ($(num_pre) - 1));\n"
+ " // Loop through subrows\n"
+ " for(size_t j = 0; j < $(num_threads); j++) {\n"
+ " const bool lastSubRow = (j == ($(num_threads) - 1));\n"
+ " // If this isn't the last sub-row of the matrix\n"
+ " if(!lastPre || ! lastSubRow) {\n"
+ " // Get length of this subrow\n"
+ " const unsigned int numSubRowNeurons = (leftOverNeurons != 0 && lastSubRow) ? leftOverNeurons : numPostPerThread;\n"
+ " // Calculate probability\n"
+ " const double probability = (double)numSubRowNeurons / (double)matrixSize;\n"
+ " // Create distribution to sample row length\n"
+ " std::binomial_distribution rowLengthDist(remainingConnections, probability);\n"
+ " // Sample row length;\n"
+ " const size_t subRowLength = rowLengthDist($(rng));\n"
+ " // Update counters\n"
+ " remainingConnections -= subRowLength;\n"
+ " matrixSize -= numSubRowNeurons;\n"
+ " // Add row length to array\n"
+ " assert(subRowLength < std::numeric_limits::max());\n"
+ " *subRowLengths++ = (uint16_t)subRowLength;\n"
+ " }\n"
+ " }\n"
+ "}\n"
+ "// Insert remaining connections into last sub-row\n"
+ "*subRowLengths = (uint16_t)remainingConnections;\n"
+ "// Push populated row length array\n"
+ "$(pushpreCalcRowLength, $(num_pre) * $(num_threads));\n");
SET_CALC_MAX_ROW_LENGTH_FUNC(
[](unsigned int numPre, unsigned int numPost, const std::vector<double> &pars)
@@ -298,7 +324,7 @@ class FixedNumberTotalWithReplacement : public Base
// There are numConnections connections amongst the numPre*numPost possible connections.
// Each of the numConnections connections has an independent p=float(numPost)/(numPre*numPost)
// probability of being selected and the number of synapses in the sub-row is binomially distributed
- return binomialInverseCDF(quantile, pars[0], (double)numPost / ((double)numPre * (double)numPost));
+ return binomialInverseCDF(quantile, (unsigned int)pars[0], (double)numPost / ((double)numPre * (double)numPost));
});
SET_CALC_MAX_COL_LENGTH_FUNC(
@@ -310,7 +336,7 @@ class FixedNumberTotalWithReplacement : public Base
// There are numConnections connections amongst the numPre*numPost possible connections.
// Each of the numConnections connections has an independent p=float(numPre)/(numPre*numPost)
// probability of being selected and the number of synapses in the sub-row is binomially distributed
- return binomialInverseCDF(quantile, pars[0], (double)numPre / ((double)numPre * (double)numPost));
+ return binomialInverseCDF(quantile, (unsigned int)pars[0], (double)numPre / ((double)numPre * (double)numPost));
});
};
} // namespace InitSparseConnectivitySnippet
diff --git a/include/genn/genn/modelSpec.h b/include/genn/genn/modelSpec.h
index 54da07bbb6..2b2f8b0424 100644
--- a/include/genn/genn/modelSpec.h
+++ b/include/genn/genn/modelSpec.h
@@ -272,7 +272,7 @@ class GENN_EXPORT ModelSpec
auto result = m_LocalSynapseGroups.emplace(
std::piecewise_construct,
std::forward_as_tuple(name),
- std::forward_as_tuple(name, mtype, delaySteps,
+ std::forward_as_tuple(name, nullptr, mtype, delaySteps,
wum, weightParamValues.getValues(), weightVarInitialisers.getInitialisers(), weightPreVarInitialisers.getInitialisers(), weightPostVarInitialisers.getInitialisers(),
psm, postsynapticParamValues.getValues(), postsynapticVarInitialisers.getInitialisers(),
srcNeuronGrp, trgNeuronGrp,
@@ -350,6 +350,70 @@ class GENN_EXPORT ModelSpec
}
+
+ template<typename PostsynapticModel>
+ SynapseGroup *addSlaveSynapsePopulation(const std::string &name, const std::string &weightSharingMasterName, unsigned int delaySteps, const std::string &src, const std::string &trg,
+ const PostsynapticModel *psm, const typename PostsynapticModel::ParamValues &postsynapticParamValues, const typename PostsynapticModel::VarValues &postsynapticVarInitialisers)
+ {
+ // Get source and target neuron groups
+ auto srcNeuronGrp = findNeuronGroupInternal(src);
+ auto trgNeuronGrp = findNeuronGroupInternal(trg);
+
+ // Find weight sharing master group
+ auto masterGrp = findSynapseGroupInternal(weightSharingMasterName);
+ const auto *wum = masterGrp->getWUModel();
+
+ // If the weight sharing master has individual weights and any are read-write, give error
+ const auto wumVars = wum->getVars();
+ if((masterGrp->getMatrixType() & SynapseMatrixWeight::INDIVIDUAL) &&
+ std::any_of(wumVars.cbegin(), wumVars.cend(),
+ [](const Models::Base::Var &v)
+ {
+ return (v.access == VarAccess::READ_WRITE);
+ }))
+ {
+ throw std::runtime_error("Individual synapse variables can only be shared if they are read-only");
+ }
+
+ // Check that population sizes match
+ if ((srcNeuronGrp->getNumNeurons() != masterGrp->getSrcNeuronGroup()->getNumNeurons())
+ || (trgNeuronGrp->getNumNeurons() != masterGrp->getTrgNeuronGroup()->getNumNeurons()))
+ {
+ throw std::runtime_error("Size of populations connected by shared weights must match");
+ }
+
+ // If weight update model has any pre or postsynaptic variables, give error
+ // **THINK** this could be supported but quite what the semantics are is ambiguous
+ if(!wum->getPreVars().empty() || !wum->getPostVars().empty()) {
+ throw std::runtime_error("Synapse groups with pre and postsynpatic variables cannot be shared");
+ }
+
+ // Add synapse group to map
+ auto result = m_LocalSynapseGroups.emplace(
+ std::piecewise_construct,
+ std::forward_as_tuple(name),
+ std::forward_as_tuple(name, masterGrp, masterGrp->getMatrixType(), delaySteps,
+ wum, masterGrp->getWUParams(), masterGrp->getWUVarInitialisers(), masterGrp->getWUPreVarInitialisers(), masterGrp->getWUPostVarInitialisers(),
+ psm, postsynapticParamValues.getValues(), postsynapticVarInitialisers.getInitialisers(),
+ srcNeuronGrp, trgNeuronGrp, masterGrp->getConnectivityInitialiser(),
+ m_DefaultVarLocation, m_DefaultExtraGlobalParamLocation, m_DefaultSparseConnectivityLocation, m_DefaultNarrowSparseIndEnabled));
+
+ if(!result.second) {
+ throw std::runtime_error("Cannot add a synapse population with duplicate name:" + name);
+ }
+ else {
+ return &result.first->second;
+ }
+ }
+
+ template<typename PostsynapticModel>
+ SynapseGroup *addSlaveSynapsePopulation(const std::string &name, const std::string &weightSharingMasterName, unsigned int delaySteps, const std::string &src, const std::string &trg,
+ const typename PostsynapticModel::ParamValues &postsynapticParamValues, const typename PostsynapticModel::VarValues &postsynapticVarInitialisers)
+ {
+ return addSlaveSynapsePopulation(name, weightSharingMasterName, delaySteps, src, trg,
+ PostsynapticModel::getInstance(), postsynapticParamValues, postsynapticVarInitialisers);
+ }
+
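Hypothetical usage of the new API — all population names, sizes and parameter values below are illustrative, and StaticPulse works here because its weight variable is read-only, as the new check requires:

#include "modelSpec.h"

void modelDefinition(ModelSpec &model)
{
    model.setName("weight_sharing_example");
    model.setDT(0.1);

    // Four equally-sized populations (sizes must match between master and slave)
    NeuronModels::LIF::ParamValues lifParams(0.25, 10.0, -65.0, -65.0, -50.0, 0.0, 2.0);
    NeuronModels::LIF::VarValues lifInit(-65.0, 0.0);
    model.addNeuronPopulation<NeuronModels::LIF>("PreA", 1000, lifParams, lifInit);
    model.addNeuronPopulation<NeuronModels::LIF>("PostA", 1000, lifParams, lifInit);
    model.addNeuronPopulation<NeuronModels::LIF>("PreB", 1000, lifParams, lifInit);
    model.addNeuronPopulation<NeuronModels::LIF>("PostB", 1000, lifParams, lifInit);

    // Master group owns the (read-only) individual weights...
    model.addSynapsePopulation<WeightUpdateModels::StaticPulse, PostsynapticModels::DeltaCurr>(
        "Syn", SynapseMatrixType::DENSE_INDIVIDUALG, NO_DELAY, "PreA", "PostA",
        {}, WeightUpdateModels::StaticPulse::VarValues(0.1), {}, {});

    // ...and the slave reuses them between the second pair of populations
    model.addSlaveSynapsePopulation<PostsynapticModels::DeltaCurr>(
        "SynShared", "Syn", NO_DELAY, "PreB", "PostB", {}, {});
}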
// PUBLIC CURRENT SOURCE FUNCTIONS
//================================
//! Find a current source by name
@@ -436,6 +500,9 @@ class GENN_EXPORT ModelSpec
//! Find a neuron group by name
NeuronGroupInternal *findNeuronGroupInternal(const std::string &name);
+ //! Find a synapse group by name
+ SynapseGroupInternal *findSynapseGroupInternal(const std::string &name);
+
//--------------------------------------------------------------------------
// Private members
//--------------------------------------------------------------------------
diff --git a/include/genn/genn/models.h b/include/genn/genn/models.h
index 0e015e2b4e..f89ca10fa0 100644
--- a/include/genn/genn/models.h
+++ b/include/genn/genn/models.h
@@ -21,7 +21,6 @@
#define IMPLEMENT_MODEL(TYPE) IMPLEMENT_SNIPPET(TYPE)
#define SET_VARS(...) virtual VarVec getVars() const override{ return __VA_ARGS__; }
-#define SET_EXTRA_GLOBAL_PARAMS(...) virtual EGPVec getExtraGlobalParams() const override{ return __VA_ARGS__; }
//----------------------------------------------------------------------------
// VarAccess
@@ -53,11 +52,6 @@ class VarInit : public Snippet::Init<InitVarSnippet::Base>
: Snippet::Init(InitVarSnippet::Constant::getInstance(), {constant})
{
}
-
- bool canBeMerged(const VarInit &other) const
- {
- return Snippet::Init::canBeMerged(other, getSnippet()->getCode());
- }
};
//----------------------------------------------------------------------------
@@ -182,10 +176,6 @@ class GENN_EXPORT Base : public Snippet::Base
//! Gets names and types (as strings) of model variables
virtual VarVec getVars() const{ return {}; }
- //! Gets names and types (as strings) of additional
- //! per-population parameters for the weight update model.
- virtual EGPVec getExtraGlobalParams() const{ return {}; }
-
//------------------------------------------------------------------------
// Public methods
//------------------------------------------------------------------------
@@ -195,12 +185,6 @@ class GENN_EXPORT Base : public Snippet::Base
return getNamedVecIndex(varName, getVars());
}
- //! Find the index of a named extra global parameter
- size_t getExtraGlobalParamIndex(const std::string &paramName) const
- {
- return getNamedVecIndex(paramName, getExtraGlobalParams());
- }
-
protected:
//------------------------------------------------------------------------
// Protected methods
@@ -209,8 +193,7 @@ class GENN_EXPORT Base : public Snippet::Base
{
// Return true if vars and egps match
return (Snippet::Base::canBeMerged(other)
- && (getVars() == other->getVars())
- && (getExtraGlobalParams() == other->getExtraGlobalParams()));
+ && (getVars() == other->getVars()));
}
};
} // Models
diff --git a/include/genn/genn/neuronGroup.h b/include/genn/genn/neuronGroup.h
index 7d58c04957..df95cd204a 100644
--- a/include/genn/genn/neuronGroup.h
+++ b/include/genn/genn/neuronGroup.h
@@ -216,6 +216,12 @@ class GENN_EXPORT NeuronGroup
//! Helper to get vector of outgoing synapse groups which have presynaptic update code
std::vector getOutSynWithPreCode() const;
+ //! Helper to get vector of incoming synapse groups which have postsynaptic variables
+ std::vector getInSynWithPostVars() const;
+
+ //! Helper to get vector of outgoing synapse groups which have presynaptic variables
+ std::vector getOutSynWithPreVars() const;
+
bool isVarQueueRequired(const std::string &var) const;
bool isVarQueueRequired(size_t index) const{ return m_VarQueueRequired[index]; }
diff --git a/include/genn/genn/neuronGroupInternal.h b/include/genn/genn/neuronGroupInternal.h
index 74d2fcd768..69f767732b 100644
--- a/include/genn/genn/neuronGroupInternal.h
+++ b/include/genn/genn/neuronGroupInternal.h
@@ -34,6 +34,8 @@ class NeuronGroupInternal : public NeuronGroup
using NeuronGroup::getSpikeEventCondition;
using NeuronGroup::getInSynWithPostCode;
using NeuronGroup::getOutSynWithPreCode;
+ using NeuronGroup::getInSynWithPostVars;
+ using NeuronGroup::getOutSynWithPreVars;
using NeuronGroup::isVarQueueRequired;
using NeuronGroup::canBeMerged;
using NeuronGroup::canInitBeMerged;
diff --git a/include/genn/genn/snippet.h b/include/genn/genn/snippet.h
index 500737f0e8..fd9dda7c9f 100644
--- a/include/genn/genn/snippet.h
+++ b/include/genn/genn/snippet.h
@@ -35,6 +35,7 @@ public: \
#define SET_PARAM_NAMES(...) virtual StringVec getParamNames() const override{ return __VA_ARGS__; }
#define SET_DERIVED_PARAMS(...) virtual DerivedParamVec getDerivedParams() const override{ return __VA_ARGS__; }
+#define SET_EXTRA_GLOBAL_PARAMS(...) virtual EGPVec getExtraGlobalParams() const override{ return __VA_ARGS__; }
//----------------------------------------------------------------------------
// Snippet::ValueBase
@@ -183,6 +184,18 @@ class GENN_EXPORT Base
//! Calculate their value from a vector of model parameter values
virtual DerivedParamVec getDerivedParams() const{ return {}; }
+ //! Gets names and types (as strings) of additional
+ //! per-population parameters for the snippet
+ virtual EGPVec getExtraGlobalParams() const { return {}; }
+
+ //------------------------------------------------------------------------
+ // Public methods
+ //------------------------------------------------------------------------
+ //! Find the index of a named extra global parameter
+ size_t getExtraGlobalParamIndex(const std::string &paramName) const
+ {
+ return getNamedVecIndex(paramName, getExtraGlobalParams());
+ }
protected:
//------------------------------------------------------------------------
@@ -191,7 +204,9 @@ class GENN_EXPORT Base
bool canBeMerged(const Base *other) const
{
// Return true if parameters names and derived parameter names match
- return ((getParamNames() == other->getParamNames()) && (getDerivedParams() == other->getDerivedParams()));
+ return ((getParamNames() == other->getParamNames())
+ && (getDerivedParams() == other->getDerivedParams())
+ && (getExtraGlobalParams() == other->getExtraGlobalParams()));
}
//------------------------------------------------------------------------
@@ -244,40 +259,9 @@ class Init
}
}
-protected:
- bool canBeMerged(const Init &other, const std::string &codeString) const
+ bool canBeMerged(const Init &other) const
{
- // If snippets can be merged
- if(getSnippet()->canBeMerged(other.getSnippet())) {
- // Loop through parameters
- const auto paramNames = getSnippet()->getParamNames();
- for(size_t i = 0; i < paramNames.size(); i++) {
- // If parameter is referenced in code string
- if(codeString.find("$(" + paramNames[i] + ")") != std::string::npos) {
- // If parameter values don't match, return true
- if(getParams()[i] != other.getParams()[i]) {
- return false;
- }
- }
- }
-
- // Loop through derived parameters
- const auto derivedParams = getSnippet()->getDerivedParams();
- assert(derivedParams.size() == getDerivedParams().size());
- assert(derivedParams.size() == other.getDerivedParams().size());
- for(size_t i = 0; i < derivedParams.size(); i++) {
- // If derived parameter is referenced in code string
- if(codeString.find("$(" + derivedParams[i].name + ")") != std::string::npos) {
- // If derived parameter values don't match, return true
- if(getDerivedParams()[i] != other.getDerivedParams()[i]) {
- return false;
- }
- }
- }
- return true;
- }
-
- return false;
+ return getSnippet()->canBeMerged(other.getSnippet());
}
private:
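With extra global parameters promoted to Snippet::Base, any snippet type can declare them, which is what the relocated SET_EXTRA_GLOBAL_PARAMS macro enables. A hypothetical connectivity snippet using it (the class, its code strings and the table EGP are illustrative, not part of this patch):

#include "initSparseConnectivitySnippet.h"

// Hypothetical one-synapse-per-row connectivity driven by a lookup table:
// the table is an extra global parameter, declared through the macro now
// provided by snippet.h
class TableConnectivity : public InitSparseConnectivitySnippet::Base
{
public:
    DECLARE_SNIPPET(TableConnectivity, 0);

    SET_ROW_BUILD_CODE("$(addSynapse, $(table)[$(id_pre)]);\n$(endRow);\n");
    SET_EXTRA_GLOBAL_PARAMS({{"table", "unsigned int*"}});
    SET_MAX_ROW_LENGTH(1);
};
IMPLEMENT_SNIPPET(TableConnectivity);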
diff --git a/include/genn/genn/synapseGroup.h b/include/genn/genn/synapseGroup.h
index 96a4a58f4d..6c9a56d27a 100644
--- a/include/genn/genn/synapseGroup.h
+++ b/include/genn/genn/synapseGroup.h
@@ -16,6 +16,7 @@
// Forward declarations
class NeuronGroupInternal;
+class SynapseGroupInternal;
//------------------------------------------------------------------------
// SynapseGroup
@@ -75,7 +76,7 @@ class GENN_EXPORT SynapseGroup
//! Set variable mode used for sparse connectivity
/*! This is ignored for simulations on hardware with a single memory space */
- void setSparseConnectivityLocation(VarLocation loc){ m_SparseConnectivityLocation = loc; }
+ void setSparseConnectivityLocation(VarLocation loc);
//! Set variable mode used for this synapse group's dendritic delay buffers
void setDendriticDelayLocation(VarLocation loc) { m_DendriticDelayLocation = loc; }
@@ -114,8 +115,8 @@ class GENN_EXPORT SynapseGroup
unsigned int getNumThreadsPerSpike() const{ return m_NumThreadsPerSpike; }
unsigned int getDelaySteps() const{ return m_DelaySteps; }
unsigned int getBackPropDelaySteps() const{ return m_BackPropDelaySteps; }
- unsigned int getMaxConnections() const{ return m_MaxConnections; }
- unsigned int getMaxSourceConnections() const{ return m_MaxSourceConnections; }
+ unsigned int getMaxConnections() const;
+ unsigned int getMaxSourceConnections() const;
unsigned int getMaxDendriticDelayTimesteps() const{ return m_MaxDendriticDelayTimesteps; }
SynapseMatrixType getMatrixType() const{ return m_MatrixType; }
@@ -123,7 +124,7 @@ class GENN_EXPORT SynapseGroup
VarLocation getInSynLocation() const { return m_InSynLocation; }
//! Get variable mode used for sparse connectivity
- VarLocation getSparseConnectivityLocation() const{ return m_SparseConnectivityLocation; }
+ VarLocation getSparseConnectivityLocation() const;
//! Get variable mode used for this synapse group's dendritic delay buffers
VarLocation getDendriticDelayLocation() const{ return m_DendriticDelayLocation; }
@@ -134,6 +135,9 @@ class GENN_EXPORT SynapseGroup
//! Does synapse group need to handle spike-like events
bool isSpikeEventRequired() const;
+ //! Is this synapse group a weight-sharing slave
+ bool isWeightSharingSlave() const { return (getWeightSharingMaster() != nullptr); }
+
const WeightUpdateModels::Base *getWUModel() const{ return m_WUModel; }
const std::vector<double> &getWUParams() const{ return m_WUParams; }
@@ -156,7 +160,7 @@ class GENN_EXPORT SynapseGroup
VarLocation getWUVarLocation(const std::string &var) const;
//! Get location of weight update model per-synapse state variable by index
- VarLocation getWUVarLocation(size_t index) const{ return m_WUVarLocation.at(index); }
+ VarLocation getWUVarLocation(size_t index) const;
//! Get location of weight update model presynaptic state variable by name
VarLocation getWUPreVarLocation(const std::string &var) const;
@@ -198,7 +202,7 @@ class GENN_EXPORT SynapseGroup
//! Get location of sparse connectivity initialiser extra global parameter by index
/*! This is only used by extra global parameters which are pointers*/
- VarLocation getSparseConnectivityExtraGlobalParamLocation(size_t index) const{ return m_ConnectivityExtraGlobalParamLocation.at(index); }
+ VarLocation getSparseConnectivityExtraGlobalParamLocation(size_t index) const;
//! Does this synapse group require dendritic delay?
bool isDendriticDelayRequired() const;
@@ -212,6 +216,15 @@ class GENN_EXPORT SynapseGroup
//! Does this synapse group require an RNG for its weight update init code?
bool isWUInitRNGRequired() const;
+ //! Does this synapse group require an RNG for its weight update presynaptic variable init code?
+ bool isWUPreInitRNGRequired() const;
+
+ //! Does this synapse group require an RNG for its weight update postsynaptic variable init code?
+ bool isWUPostInitRNGRequired() const;
+
+ //! Does this synapse group require an RNG for any sort of initialisation?
+ bool isHostInitRNGRequired() const;
+
//! Is var init code required for any variables in this synapse group's weight update model?
bool isWUVarInitRequired() const;
@@ -219,10 +232,10 @@ class GENN_EXPORT SynapseGroup
bool isSparseConnectivityInitRequired() const;
protected:
- SynapseGroup(const std::string name, SynapseMatrixType matrixType, unsigned int delaySteps,
+ SynapseGroup(const std::string &name, SynapseMatrixType matrixType, unsigned int delaySteps,
const WeightUpdateModels::Base *wu, const std::vector<double> &wuParams, const std::vector<Models::VarInit> &wuVarInitialisers, const std::vector<Models::VarInit> &wuPreVarInitialisers, const std::vector<Models::VarInit> &wuPostVarInitialisers,
const PostsynapticModels::Base *ps, const std::vector<double> &psParams, const std::vector<Models::VarInit> &psVarInitialisers,
- NeuronGroupInternal *srcNeuronGroup, NeuronGroupInternal *trgNeuronGroup,
+ NeuronGroupInternal *srcNeuronGroup, NeuronGroupInternal *trgNeuronGroup, const SynapseGroupInternal *weightSharingMaster,
const InitSparseConnectivitySnippet::Init &connectivityInitialiser,
VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation,
VarLocation defaultSparseConnectivityLocation, bool defaultNarrowSparseIndEnabled);
@@ -251,6 +264,8 @@ class GENN_EXPORT SynapseGroup
const std::vector<double> &getWUDerivedParams() const{ return m_WUDerivedParams; }
const std::vector<double> &getPSDerivedParams() const{ return m_PSDerivedParams; }
+ const SynapseGroupInternal *getWeightSharingMaster() const { return m_WeightSharingMaster; }
+
//!< Does the event threshold needs to be retested in the synapse kernel?
/*! This is required when the pre-synaptic neuron population's outgoing synapse groups require different event threshold */
bool isEventThresholdReTestRequired() const{ return m_EventThresholdReTestRequired; }
@@ -258,7 +273,6 @@ class GENN_EXPORT SynapseGroup
const std::string &getPSModelTargetName() const{ return m_PSModelTargetName; }
bool isPSModelMerged() const{ return m_PSModelTargetName != getName(); }
-
//! Get the type to use for sparse connectivity indices for synapse group
std::string getSparseIndType() const;
@@ -302,6 +316,10 @@ class GENN_EXPORT SynapseGroup
/*! NOTE: this can only be called after model is finalized */
bool canConnectivityInitBeMerged(const SynapseGroup &other) const;
+ //! Can connectivity host initialisation for this synapse group be merged with other? i.e. can they be performed using same generated code
+ /*! NOTE: this can only be called after model is finalized */
+ bool canConnectivityHostInitBeMerged(const SynapseGroup &other) const;
+
private:
//------------------------------------------------------------------------
// Members
@@ -339,6 +357,9 @@ class GENN_EXPORT SynapseGroup
//! Pointer to postsynaptic neuron group
NeuronGroupInternal * const m_TrgNeuronGroup;
+ //! Pointer to 'master' weight sharing group if this is a slave
+ const SynapseGroupInternal *m_WeightSharingMaster;
+
//! Does the event threshold need to be retested in the synapse kernel?
/*! This is required when the pre-synaptic neuron population's outgoing synapse groups require different event thresholds */
bool m_EventThresholdReTestRequired;
diff --git a/include/genn/genn/synapseGroupInternal.h b/include/genn/genn/synapseGroupInternal.h
index 8f4517a08c..702806a0b2 100644
--- a/include/genn/genn/synapseGroupInternal.h
+++ b/include/genn/genn/synapseGroupInternal.h
@@ -9,7 +9,7 @@
class SynapseGroupInternal : public SynapseGroup
{
public:
- SynapseGroupInternal(const std::string name, SynapseMatrixType matrixType, unsigned int delaySteps,
+ SynapseGroupInternal(const std::string &name, const SynapseGroupInternal *weightSharingMaster, SynapseMatrixType matrixType, unsigned int delaySteps,
const WeightUpdateModels::Base *wu, const std::vector<double> &wuParams, const std::vector<Models::VarInit> &wuVarInitialisers, const std::vector<Models::VarInit> &wuPreVarInitialisers, const std::vector<Models::VarInit> &wuPostVarInitialisers,
const PostsynapticModels::Base *ps, const std::vector<double> &psParams, const std::vector<Models::VarInit> &psVarInitialisers,
NeuronGroupInternal *srcNeuronGroup, NeuronGroupInternal *trgNeuronGroup,
@@ -17,7 +17,7 @@ class SynapseGroupInternal : public SynapseGroup
VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation,
VarLocation defaultSparseConnectivityLocation, bool defaultNarrowSparseIndEnabled)
: SynapseGroup(name, matrixType, delaySteps, wu, wuParams, wuVarInitialisers, wuPreVarInitialisers, wuPostVarInitialisers,
- ps, psParams, psVarInitialisers, srcNeuronGroup, trgNeuronGroup,
+ ps, psParams, psVarInitialisers, srcNeuronGroup, trgNeuronGroup, weightSharingMaster,
connectivityInitialiser, defaultVarLocation, defaultExtraGlobalParamLocation,
defaultSparseConnectivityLocation, defaultNarrowSparseIndEnabled)
{
@@ -28,6 +28,7 @@ class SynapseGroupInternal : public SynapseGroup
using SynapseGroup::getSrcNeuronGroup;
using SynapseGroup::getTrgNeuronGroup;
+ using SynapseGroup::getWeightSharingMaster;
using SynapseGroup::getWUDerivedParams;
using SynapseGroup::getPSDerivedParams;
using SynapseGroup::setEventThresholdReTestRequired;
@@ -47,4 +48,5 @@ class SynapseGroupInternal : public SynapseGroup
using SynapseGroup::canWUPostInitBeMerged;
using SynapseGroup::canPSInitBeMerged;
using SynapseGroup::canConnectivityInitBeMerged;
+ using SynapseGroup::canConnectivityHostInitBeMerged;
};
diff --git a/include/genn/genn/synapseMatrixType.h b/include/genn/genn/synapseMatrixType.h
index 496881fa73..8d566fd6bb 100644
--- a/include/genn/genn/synapseMatrixType.h
+++ b/include/genn/genn/synapseMatrixType.h
@@ -27,6 +27,7 @@ enum class SynapseMatrixType : unsigned int
DENSE_GLOBALG = static_cast<unsigned int>(SynapseMatrixConnectivity::DENSE) | static_cast<unsigned int>(SynapseMatrixWeight::GLOBAL),
DENSE_GLOBALG_INDIVIDUAL_PSM = static_cast<unsigned int>(SynapseMatrixConnectivity::DENSE) | static_cast<unsigned int>(SynapseMatrixWeight::GLOBAL) | static_cast<unsigned int>(SynapseMatrixWeight::INDIVIDUAL_PSM),
DENSE_INDIVIDUALG = static_cast<unsigned int>(SynapseMatrixConnectivity::DENSE) | static_cast<unsigned int>(SynapseMatrixWeight::INDIVIDUAL) | static_cast<unsigned int>(SynapseMatrixWeight::INDIVIDUAL_PSM),
+ DENSE_PROCEDURALG = static_cast<unsigned int>(SynapseMatrixConnectivity::DENSE) | static_cast<unsigned int>(SynapseMatrixWeight::PROCEDURAL) | static_cast<unsigned int>(SynapseMatrixWeight::INDIVIDUAL_PSM),
BITMASK_GLOBALG = static_cast<unsigned int>(SynapseMatrixConnectivity::BITMASK) | static_cast<unsigned int>(SynapseMatrixWeight::GLOBAL),
BITMASK_GLOBALG_INDIVIDUAL_PSM = static_cast<unsigned int>(SynapseMatrixConnectivity::BITMASK) | static_cast<unsigned int>(SynapseMatrixWeight::GLOBAL) | static_cast<unsigned int>(SynapseMatrixWeight::INDIVIDUAL_PSM),
SPARSE_GLOBALG = static_cast<unsigned int>(SynapseMatrixConnectivity::SPARSE) | static_cast<unsigned int>(SynapseMatrixWeight::GLOBAL),
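The new DENSE_PROCEDURALG value composes the same way as its siblings: one connectivity flag OR-ed with weight flags, recoverable by masking. A standalone toy of the scheme (flag values are illustrative, not GeNN's actual bit assignments):

#include <cstdio>

enum class Connectivity : unsigned int { DENSE = 1, SPARSE = 2, BITMASK = 4 };
enum class Weight : unsigned int { GLOBAL = 8, INDIVIDUAL = 16, PROCEDURAL = 32, INDIVIDUAL_PSM = 64 };

int main()
{
    // Compose a DENSE_PROCEDURALG-style value...
    const unsigned int denseProcedural = static_cast<unsigned int>(Connectivity::DENSE)
                                         | static_cast<unsigned int>(Weight::PROCEDURAL)
                                         | static_cast<unsigned int>(Weight::INDIVIDUAL_PSM);

    // ...and recover the individual traits by masking
    std::printf("dense? %d, procedural weights? %d\n",
                (denseProcedural & static_cast<unsigned int>(Connectivity::DENSE)) != 0,
                (denseProcedural & static_cast<unsigned int>(Weight::PROCEDURAL)) != 0);
}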
diff --git a/include/genn/genn/weightUpdateModels.h b/include/genn/genn/weightUpdateModels.h
index 7fc3711f67..1fc383dd7e 100644
--- a/include/genn/genn/weightUpdateModels.h
+++ b/include/genn/genn/weightUpdateModels.h
@@ -210,7 +210,7 @@ class StaticGraded : public Base
SET_PARAM_NAMES({"Epre", "Vslope"});
SET_VARS({{"g", "scalar", VarAccess::READ_ONLY}});
- SET_EVENT_CODE("$(addToInSyn, max(0.0, $(g) * tanh(($(V_pre) - $(Epre)) / $(Vslope))* DT));\n");
+ SET_EVENT_CODE("$(addToInSyn, fmax(0.0, $(g) * tanh(($(V_pre) - $(Epre)) / $(Vslope))* DT));\n");
SET_EVENT_THRESHOLD_CONDITION_CODE("$(V_pre) > $(Epre)");
};
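The max -> fmax swap matters because std::max is a template that requires both arguments to share one type, and generated C-like device code may have no max() for floating point at all; fmax instead applies the usual arithmetic conversions. A minimal illustration:

#include <cmath>
#include <cstdio>

int main()
{
    const float g = -0.5f;
    // std::max(0.0, g);               // would not compile: deduces double vs float
    std::printf("%f\n", fmax(0.0, g)); // fine: g converts to double, prints 0.0
}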
diff --git a/include/genn/third_party/plog/Appenders/AndroidAppender.h b/include/genn/third_party/plog/Appenders/AndroidAppender.h
index 92f6d33c9f..8443a3260d 100644
--- a/include/genn/third_party/plog/Appenders/AndroidAppender.h
+++ b/include/genn/third_party/plog/Appenders/AndroidAppender.h
@@ -1,47 +1,47 @@
-#pragma once
-#include