diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 17601cb51..ab771dfb0 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -49,6 +49,7 @@ set(SOURCE_FILES ${CMAKE_CURRENT_SOURCE_DIR}/entity.cpp ${CMAKE_CURRENT_SOURCE_DIR}/enums.cpp ${CMAKE_CURRENT_SOURCE_DIR}/generator.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/generatorinterpreter.cpp ${CMAKE_CURRENT_SOURCE_DIR}/generatorprofile.cpp ${CMAKE_CURRENT_SOURCE_DIR}/generatorprofiletools.cpp ${CMAKE_CURRENT_SOURCE_DIR}/importedentity.cpp @@ -133,6 +134,8 @@ set(GIT_HEADER_FILES ${CMAKE_CURRENT_SOURCE_DIR}/debug.h ${CMAKE_CURRENT_SOURCE_DIR}/entity_p.h ${CMAKE_CURRENT_SOURCE_DIR}/generator_p.h + ${CMAKE_CURRENT_SOURCE_DIR}/generatorinterpreter_p.h + ${CMAKE_CURRENT_SOURCE_DIR}/generatorinterpreter.h ${CMAKE_CURRENT_SOURCE_DIR}/generatorprofile_p.h ${CMAKE_CURRENT_SOURCE_DIR}/generatorprofilesha1values.h ${CMAKE_CURRENT_SOURCE_DIR}/generatorprofiletools.h diff --git a/src/debug.cpp b/src/debug.cpp index 2c9dbcd8a..ff462fec1 100644 --- a/src/debug.cpp +++ b/src/debug.cpp @@ -26,7 +26,7 @@ limitations under the License. #include "libcellml/model.h" #include "libcellml/variable.h" -#include "generator_p.h" +#include "generatorinterpreter_p.h" #include "utilities.h" namespace libcellml { diff --git a/src/generator.cpp b/src/generator.cpp index 1a8b0a2c7..88839555e 100644 --- a/src/generator.cpp +++ b/src/generator.cpp @@ -28,6 +28,7 @@ limitations under the License. #include "libcellml/version.h" #include "generator_p.h" +#include "generatorinterpreter_p.h" #include "generatorprofilesha1values.h" #include "generatorprofiletools.h" #include "utilities.h" @@ -41,186 +42,6 @@ void Generator::GeneratorImpl::reset() mCode = {}; } -bool Generator::GeneratorImpl::modelHasOdes() const -{ - switch (mModel->type()) { - case AnalyserModel::Type::ODE: - case AnalyserModel::Type::DAE: - return true; - default: - return false; - } -} - -bool Generator::GeneratorImpl::modelHasNlas() const -{ - switch (mModel->type()) { - case AnalyserModel::Type::NLA: - case AnalyserModel::Type::DAE: - return true; - default: - return false; - } -} - -AnalyserVariablePtr analyserVariable(const AnalyserModelPtr &model, const VariablePtr &variable) -{ - // Find and return the analyser variable associated with the given variable. - - AnalyserVariablePtr res; - auto modelVoi = model->voi(); - auto modelVoiVariable = (modelVoi != nullptr) ? modelVoi->variable() : nullptr; - - if ((modelVoiVariable != nullptr) - && model->areEquivalentVariables(variable, modelVoiVariable)) { - res = modelVoi; - } else { - // Normally, we would have something like: - // - // for (const auto &modelVariable : variables(model)) { - // if (model->areEquivalentVariables(variable, modelVariable->variable())) { - // res = modelVariable; - // - // break; - // } - // } - // - // but we always have variables, so llvm-cov will complain that the false branch of our for loop is never - // reached. The below code is a bit more verbose but at least it makes llvm-cov happy. - - auto modelVariables = variables(model); - auto modelVariable = modelVariables.begin(); - - do { - if (model->areEquivalentVariables(variable, (*modelVariable)->variable())) { - res = *modelVariable; - } else { - ++modelVariable; - } - } while (res == nullptr); - } - - return res; -} - -double Generator::GeneratorImpl::scalingFactor(const VariablePtr &variable) const -{ - // Return the scaling factor for the given variable, accounting for the fact that a constant may be initialised by - // another variable which initial value may be defined in a different component. - - auto analyserVariable = libcellml::analyserVariable(mModel, variable); - - if ((analyserVariable->type() == AnalyserVariable::Type::CONSTANT) - && !isCellMLReal(variable->initialValue())) { - auto initialValueVariable = owningComponent(variable)->variable(variable->initialValue()); - auto initialValueAnalyserVariable = libcellml::analyserVariable(mModel, initialValueVariable); - - if (owningComponent(variable) != owningComponent(initialValueAnalyserVariable->variable())) { - return Units::scalingFactor(initialValueVariable->units(), variable->units()); - } - } - - return Units::scalingFactor(analyserVariable->variable()->units(), variable->units()); -} - -bool Generator::GeneratorImpl::isNegativeNumber(const AnalyserEquationAstPtr &ast) const -{ - if (ast->type() == AnalyserEquationAst::Type::CN) { - double doubleValue; - - convertToDouble(ast->value(), doubleValue); - - return doubleValue < 0.0; - } - - return false; -} - -bool Generator::GeneratorImpl::isRelationalOperator(const AnalyserEquationAstPtr &ast) const -{ - switch (ast->type()) { - case AnalyserEquationAst::Type::EQ: - return mProfile->hasEqOperator(); - case AnalyserEquationAst::Type::NEQ: - return mProfile->hasNeqOperator(); - case AnalyserEquationAst::Type::LT: - return mProfile->hasLtOperator(); - case AnalyserEquationAst::Type::LEQ: - return mProfile->hasLeqOperator(); - case AnalyserEquationAst::Type::GT: - return mProfile->hasGtOperator(); - case AnalyserEquationAst::Type::GEQ: - return mProfile->hasGeqOperator(); - default: - return false; - } -} - -bool Generator::GeneratorImpl::isAndOperator(const AnalyserEquationAstPtr &ast) const -{ - return (ast->type() == AnalyserEquationAst::Type::AND) - && mProfile->hasAndOperator(); -} - -bool Generator::GeneratorImpl::isOrOperator(const AnalyserEquationAstPtr &ast) const -{ - return (ast->type() == AnalyserEquationAst::Type::OR) - && mProfile->hasOrOperator(); -} - -bool Generator::GeneratorImpl::isXorOperator(const AnalyserEquationAstPtr &ast) const -{ - return (ast->type() == AnalyserEquationAst::Type::XOR) - && mProfile->hasXorOperator(); -} - -bool Generator::GeneratorImpl::isLogicalOperator(const AnalyserEquationAstPtr &ast) const -{ - // Note: AnalyserEquationAst::Type::NOT is a unary logical operator, hence - // we don't include it here since this method is only used to - // determine whether parentheses should be added around some code. - - return isAndOperator(ast) || isOrOperator(ast) || isXorOperator(ast); -} - -bool Generator::GeneratorImpl::isPlusOperator(const AnalyserEquationAstPtr &ast) const -{ - return ast->type() == AnalyserEquationAst::Type::PLUS; -} - -bool Generator::GeneratorImpl::isMinusOperator(const AnalyserEquationAstPtr &ast) const -{ - return ast->type() == AnalyserEquationAst::Type::MINUS; -} - -bool Generator::GeneratorImpl::isTimesOperator(const AnalyserEquationAstPtr &ast) const -{ - return ast->type() == AnalyserEquationAst::Type::TIMES; -} - -bool Generator::GeneratorImpl::isDivideOperator(const AnalyserEquationAstPtr &ast) const -{ - return ast->type() == AnalyserEquationAst::Type::DIVIDE; -} - -bool Generator::GeneratorImpl::isPowerOperator(const AnalyserEquationAstPtr &ast) const -{ - return (ast->type() == AnalyserEquationAst::Type::POWER) - && mProfile->hasPowerOperator(); -} - -bool Generator::GeneratorImpl::isRootOperator(const AnalyserEquationAstPtr &ast) const -{ - return (ast->type() == AnalyserEquationAst::Type::ROOT) - && mProfile->hasPowerOperator(); -} - -bool Generator::GeneratorImpl::isPiecewiseStatement(const AnalyserEquationAstPtr &ast) const -{ - return (ast->type() == AnalyserEquationAst::Type::PIECEWISE) - && mProfile->hasConditionalOperator(); -} - void Generator::GeneratorImpl::updateVariableInfoSizes(size_t &componentSize, size_t &nameSize, size_t &unitsSize, @@ -246,11 +67,6 @@ bool Generator::GeneratorImpl::modifiedProfile() const sha1(profileContents) != PYTHON_GENERATOR_PROFILE_SHA1; } -std::string Generator::GeneratorImpl::newLineIfNeeded() -{ - return mCode.empty() ? "" : "\n"; -} - void Generator::GeneratorImpl::addOriginCommentCode() { if (!mProfile->commentString().empty() @@ -264,7 +80,7 @@ void Generator::GeneratorImpl::addOriginCommentCode() "Python"; profileInformation += " profile of"; - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + replace(mProfile->commentString(), "[CODE]", replace(replace(mProfile->originCommentString(), "[PROFILE_INFORMATION]", profileInformation), "[LIBCELLML_VERSION]", versionString())); } @@ -273,7 +89,7 @@ void Generator::GeneratorImpl::addOriginCommentCode() void Generator::GeneratorImpl::addInterfaceHeaderCode() { if (!mProfile->interfaceHeaderString().empty()) { - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + mProfile->interfaceHeaderString(); } } @@ -287,7 +103,7 @@ void Generator::GeneratorImpl::addImplementationHeaderCode() if (!mProfile->implementationHeaderString().empty() && ((hasInterfaceFileName && !mProfile->interfaceFileNameString().empty()) || !hasInterfaceFileName)) { - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + replace(mProfile->implementationHeaderString(), "[INTERFACE_FILE_NAME]", mProfile->interfaceFileNameString()); } @@ -321,7 +137,7 @@ void Generator::GeneratorImpl::addVersionAndLibcellmlVersionCode(bool interface) } if (!code.empty()) { - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + code; } } @@ -330,7 +146,7 @@ void Generator::GeneratorImpl::addStateAndVariableCountCode(bool interface) { std::string code; - if (modelHasOdes() + if (mModelHasOdes && ((interface && !mProfile->interfaceStateCountString().empty()) || (!interface && !mProfile->implementationStateCountString().empty()))) { code += interface ? @@ -373,7 +189,7 @@ void Generator::GeneratorImpl::addStateAndVariableCountCode(bool interface) } if (!code.empty()) { - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + code; } } @@ -397,7 +213,7 @@ std::string Generator::GeneratorImpl::generateVariableInfoObjectCode(const std:: void Generator::GeneratorImpl::addVariableInfoObjectCode() { if (!mProfile->variableInfoObjectString().empty()) { - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + generateVariableInfoObjectCode(mProfile->variableInfoObjectString()); } } @@ -416,12 +232,12 @@ void Generator::GeneratorImpl::addInterfaceVariableInfoCode() { std::string code; - if (modelHasOdes() + if (mModelHasOdes && !mProfile->interfaceVoiInfoString().empty()) { code += mProfile->interfaceVoiInfoString(); } - if (modelHasOdes() + if (mModelHasOdes && !mProfile->interfaceStateInfoString().empty()) { code += mProfile->interfaceStateInfoString(); } @@ -444,7 +260,7 @@ void Generator::GeneratorImpl::addInterfaceVariableInfoCode() } if (!code.empty()) { - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + code; } } @@ -475,18 +291,18 @@ void Generator::GeneratorImpl::doAddImplementationVariableInfoCode(const std::st infoElementsCode += "\n"; } - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + replace(variableInfoString, "[CODE]", infoElementsCode); } } void Generator::GeneratorImpl::addImplementationVariableInfoCode() { - if (modelHasOdes()) { + if (mModelHasOdes) { doAddImplementationVariableInfoCode(mProfile->implementationVoiInfoString(), {mModel->voi()}, true); } - if (modelHasOdes()) { + if (mModelHasOdes) { doAddImplementationVariableInfoCode(mProfile->implementationStateInfoString(), mModel->states(), false); } @@ -503,73 +319,73 @@ void Generator::GeneratorImpl::addArithmeticFunctionsCode() { if (mModel->needEqFunction() && !mProfile->hasEqOperator() && !mProfile->eqFunctionString().empty()) { - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + mProfile->eqFunctionString(); } if (mModel->needNeqFunction() && !mProfile->hasNeqOperator() && !mProfile->neqFunctionString().empty()) { - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + mProfile->neqFunctionString(); } if (mModel->needLtFunction() && !mProfile->hasLtOperator() && !mProfile->ltFunctionString().empty()) { - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + mProfile->ltFunctionString(); } if (mModel->needLeqFunction() && !mProfile->hasLeqOperator() && !mProfile->leqFunctionString().empty()) { - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + mProfile->leqFunctionString(); } if (mModel->needGtFunction() && !mProfile->hasGtOperator() && !mProfile->gtFunctionString().empty()) { - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + mProfile->gtFunctionString(); } if (mModel->needGeqFunction() && !mProfile->hasGeqOperator() && !mProfile->geqFunctionString().empty()) { - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + mProfile->geqFunctionString(); } if (mModel->needAndFunction() && !mProfile->hasAndOperator() && !mProfile->andFunctionString().empty()) { - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + mProfile->andFunctionString(); } if (mModel->needOrFunction() && !mProfile->hasOrOperator() && !mProfile->orFunctionString().empty()) { - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + mProfile->orFunctionString(); } if (mModel->needXorFunction() && !mProfile->hasXorOperator() && !mProfile->xorFunctionString().empty()) { - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + mProfile->xorFunctionString(); } if (mModel->needNotFunction() && !mProfile->hasNotOperator() && !mProfile->notFunctionString().empty()) { - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + mProfile->notFunctionString(); } if (mModel->needMinFunction() && !mProfile->minFunctionString().empty()) { - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + mProfile->minFunctionString(); } if (mModel->needMaxFunction() && !mProfile->maxFunctionString().empty()) { - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + mProfile->maxFunctionString(); } } @@ -578,73 +394,73 @@ void Generator::GeneratorImpl::addTrigonometricFunctionsCode() { if (mModel->needSecFunction() && !mProfile->secFunctionString().empty()) { - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + mProfile->secFunctionString(); } if (mModel->needCscFunction() && !mProfile->cscFunctionString().empty()) { - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + mProfile->cscFunctionString(); } if (mModel->needCotFunction() && !mProfile->cotFunctionString().empty()) { - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + mProfile->cotFunctionString(); } if (mModel->needSechFunction() && !mProfile->sechFunctionString().empty()) { - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + mProfile->sechFunctionString(); } if (mModel->needCschFunction() && !mProfile->cschFunctionString().empty()) { - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + mProfile->cschFunctionString(); } if (mModel->needCothFunction() && !mProfile->cothFunctionString().empty()) { - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + mProfile->cothFunctionString(); } if (mModel->needAsecFunction() && !mProfile->asecFunctionString().empty()) { - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + mProfile->asecFunctionString(); } if (mModel->needAcscFunction() && !mProfile->acscFunctionString().empty()) { - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + mProfile->acscFunctionString(); } if (mModel->needAcotFunction() && !mProfile->acotFunctionString().empty()) { - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + mProfile->acotFunctionString(); } if (mModel->needAsechFunction() && !mProfile->asechFunctionString().empty()) { - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + mProfile->asechFunctionString(); } if (mModel->needAcschFunction() && !mProfile->acschFunctionString().empty()) { - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + mProfile->acschFunctionString(); } if (mModel->needAcothFunction() && !mProfile->acothFunctionString().empty()) { - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + mProfile->acothFunctionString(); } } @@ -653,7 +469,7 @@ void Generator::GeneratorImpl::addInterfaceCreateDeleteArrayMethodsCode() { std::string code; - if (modelHasOdes() + if (mModelHasOdes && !mProfile->interfaceCreateStatesArrayMethodString().empty()) { code += mProfile->interfaceCreateStatesArrayMethodString(); } @@ -681,42 +497,42 @@ void Generator::GeneratorImpl::addInterfaceCreateDeleteArrayMethodsCode() } if (!code.empty()) { - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + code; } } void Generator::GeneratorImpl::addImplementationCreateDeleteArrayMethodsCode() { - if (modelHasOdes() + if (mModelHasOdes && !mProfile->implementationCreateStatesArrayMethodString().empty()) { - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + mProfile->implementationCreateStatesArrayMethodString(); } if (!mProfile->implementationCreateConstantsArrayMethodString().empty()) { - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + mProfile->implementationCreateConstantsArrayMethodString(); } if (!mProfile->implementationCreateComputedConstantsArrayMethodString().empty()) { - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + mProfile->implementationCreateComputedConstantsArrayMethodString(); } if (!mProfile->implementationCreateAlgebraicArrayMethodString().empty()) { - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + mProfile->implementationCreateAlgebraicArrayMethodString(); } if (mModel->hasExternalVariables() && !mProfile->implementationCreateExternalsArrayMethodString().empty()) { - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + mProfile->implementationCreateExternalsArrayMethodString(); } if (!mProfile->implementationDeleteArrayMethodString().empty()) { - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + mProfile->implementationDeleteArrayMethodString(); } } @@ -724,10 +540,10 @@ void Generator::GeneratorImpl::addImplementationCreateDeleteArrayMethodsCode() void Generator::GeneratorImpl::addExternalVariableMethodTypeDefinitionCode() { if (mModel->hasExternalVariables()) { - auto externalVariableMethodTypeDefinitionString = mProfile->externalVariableMethodTypeDefinitionString(modelHasOdes()); + auto externalVariableMethodTypeDefinitionString = mProfile->externalVariableMethodTypeDefinitionString(mModelHasOdes); if (!externalVariableMethodTypeDefinitionString.empty()) { - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + externalVariableMethodTypeDefinitionString; } } @@ -735,28 +551,28 @@ void Generator::GeneratorImpl::addExternalVariableMethodTypeDefinitionCode() void Generator::GeneratorImpl::addRootFindingInfoObjectCode() { - if (modelHasNlas() - && !mProfile->rootFindingInfoObjectString(modelHasOdes(), mModel->hasExternalVariables()).empty()) { - mCode += newLineIfNeeded() - + mProfile->rootFindingInfoObjectString(modelHasOdes(), mModel->hasExternalVariables()); + if (mModelHasNlas + && !mProfile->rootFindingInfoObjectString(mModelHasOdes, mModel->hasExternalVariables()).empty()) { + mCode += newLineIfNeeded(mCode) + + mProfile->rootFindingInfoObjectString(mModelHasOdes, mModel->hasExternalVariables()); } } void Generator::GeneratorImpl::addExternNlaSolveMethodCode() { - if (modelHasNlas() + if (mModelHasNlas && !mProfile->externNlaSolveMethodString().empty()) { - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + mProfile->externNlaSolveMethodString(); } } void Generator::GeneratorImpl::addNlaSystemsCode() { - if (modelHasNlas() - && !mProfile->objectiveFunctionMethodString(modelHasOdes(), mModel->hasExternalVariables()).empty() - && !mProfile->findRootMethodString(modelHasOdes(), mModel->hasExternalVariables()).empty() - && !mProfile->nlaSolveCallString(modelHasOdes(), mModel->hasExternalVariables()).empty()) { + if (mModelHasNlas + && !mProfile->objectiveFunctionMethodString(mModelHasOdes, mModel->hasExternalVariables()).empty() + && !mProfile->findRootMethodString(mModelHasOdes, mModel->hasExternalVariables()).empty() + && !mProfile->nlaSolveCallString(mModelHasOdes, mModel->hasExternalVariables()).empty()) { // Note: only states and algebraic variables can be computed through an NLA system. Constants, computed // constants, and external variables cannot, by definition, be computed through an NLA system. @@ -781,7 +597,7 @@ void Generator::GeneratorImpl::addNlaSystemsCode() + mProfile->commandSeparatorString() + "\n"; } - methodBody += newLineIfNeeded(); + methodBody += newLineIfNeeded(mCode); i = MAX_SIZE_T; @@ -803,8 +619,8 @@ void Generator::GeneratorImpl::addNlaSystemsCode() handledNlaEquations.push_back(nlaSibling); } - mCode += newLineIfNeeded() - + replace(replace(mProfile->objectiveFunctionMethodString(modelHasOdes(), mModel->hasExternalVariables()), + mCode += newLineIfNeeded(mCode) + + replace(replace(mProfile->objectiveFunctionMethodString(mModelHasOdes, mModel->hasExternalVariables()), "[INDEX]", convertToString(equation->nlaSystemIndex())), "[CODE]", generateMethodBodyCode(methodBody)); @@ -826,13 +642,13 @@ void Generator::GeneratorImpl::addNlaSystemsCode() auto variablesCount = variables.size(); - methodBody += newLineIfNeeded() + methodBody += newLineIfNeeded(mCode) + mProfile->indentString() - + replace(replace(mProfile->nlaSolveCallString(modelHasOdes(), mModel->hasExternalVariables()), + + replace(replace(mProfile->nlaSolveCallString(mModelHasOdes, mModel->hasExternalVariables()), "[INDEX]", convertToString(equation->nlaSystemIndex())), "[SIZE]", convertToString(variablesCount)); - methodBody += newLineIfNeeded(); + methodBody += newLineIfNeeded(mCode); i = MAX_SIZE_T; @@ -848,8 +664,8 @@ void Generator::GeneratorImpl::addNlaSystemsCode() + mProfile->commandSeparatorString() + "\n"; } - mCode += newLineIfNeeded() - + replace(replace(replace(mProfile->findRootMethodString(modelHasOdes(), mModel->hasExternalVariables()), + mCode += newLineIfNeeded(mCode) + + replace(replace(replace(mProfile->findRootMethodString(mModelHasOdes, mModel->hasExternalVariables()), "[INDEX]", convertToString(equation->nlaSystemIndex())), "[SIZE]", convertToString(variablesCount)), "[CODE]", generateMethodBodyCode(methodBody)); @@ -1744,16 +1560,16 @@ std::string Generator::GeneratorImpl::generateEquationCode(const AnalyserEquatio res += mProfile->indentString() + generateVariableNameCode(variable->variable()) + mProfile->equalityString() - + replace(mProfile->externalVariableMethodCallString(modelHasOdes()), + + replace(mProfile->externalVariableMethodCallString(mModelHasOdes), "[INDEX]", convertToString(variable->index())) + mProfile->commandSeparatorString() + "\n"; } break; case AnalyserEquation::Type::NLA: - if (!mProfile->findRootCallString(modelHasOdes(), mModel->hasExternalVariables()).empty()) { + if (!mProfile->findRootCallString(mModelHasOdes, mModel->hasExternalVariables()).empty()) { res += mProfile->indentString() - + replace(mProfile->findRootCallString(modelHasOdes(), mModel->hasExternalVariables()), + + replace(mProfile->findRootCallString(mModelHasOdes, mModel->hasExternalVariables()), "[INDEX]", convertToString(equation->nlaSystemIndex())); } @@ -1778,7 +1594,7 @@ std::string Generator::GeneratorImpl::generateEquationCode(const AnalyserEquatio void Generator::GeneratorImpl::addInterfaceComputeModelMethodsCode() { - auto interfaceInitialiseVariablesMethodString = mProfile->interfaceInitialiseVariablesMethodString(modelHasOdes()); + auto interfaceInitialiseVariablesMethodString = mProfile->interfaceInitialiseVariablesMethodString(mModelHasOdes); std::string code; if (!interfaceInitialiseVariablesMethodString.empty()) { @@ -1791,12 +1607,12 @@ void Generator::GeneratorImpl::addInterfaceComputeModelMethodsCode() auto interfaceComputeRatesMethodString = mProfile->interfaceComputeRatesMethodString(mModel->hasExternalVariables()); - if (modelHasOdes() + if (mModelHasOdes && !interfaceComputeRatesMethodString.empty()) { code += interfaceComputeRatesMethodString; } - auto interfaceComputeVariablesMethodString = mProfile->interfaceComputeVariablesMethodString(modelHasOdes(), + auto interfaceComputeVariablesMethodString = mProfile->interfaceComputeVariablesMethodString(mModelHasOdes, mModel->hasExternalVariables()); if (!interfaceComputeVariablesMethodString.empty()) { @@ -1804,7 +1620,7 @@ void Generator::GeneratorImpl::addInterfaceComputeModelMethodsCode() } if (!code.empty()) { - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + code; } } @@ -1836,7 +1652,7 @@ std::string Generator::GeneratorImpl::generateConstantInitialisationCode(const s void Generator::GeneratorImpl::addImplementationInitialiseVariablesMethodCode(std::vector &remainingEquations) { - auto implementationInitialiseVariablesMethodString = mProfile->implementationInitialiseVariablesMethodString(modelHasOdes()); + auto implementationInitialiseVariablesMethodString = mProfile->implementationInitialiseVariablesMethodString(mModelHasOdes); if (!implementationInitialiseVariablesMethodString.empty()) { // Initialise our states (after, if needed, initialising the constant on which it depends). @@ -1904,7 +1720,7 @@ void Generator::GeneratorImpl::addImplementationInitialiseVariablesMethodCode(st } } - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + replace(implementationInitialiseVariablesMethodString, "[CODE]", generateMethodBodyCode(methodBody)); } @@ -1921,7 +1737,7 @@ void Generator::GeneratorImpl::addImplementationComputeComputedConstantsMethodCo } } - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + replace(mProfile->implementationComputeComputedConstantsMethodString(), "[CODE]", generateMethodBodyCode(methodBody)); } @@ -1931,7 +1747,7 @@ void Generator::GeneratorImpl::addImplementationComputeRatesMethodCode(std::vect { auto implementationComputeRatesMethodString = mProfile->implementationComputeRatesMethodString(mModel->hasExternalVariables()); - if (modelHasOdes() + if (mModelHasOdes && !implementationComputeRatesMethodString.empty()) { std::string methodBody; @@ -1950,7 +1766,7 @@ void Generator::GeneratorImpl::addImplementationComputeRatesMethodCode(std::vect } } - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + replace(implementationComputeRatesMethodString, "[CODE]", generateMethodBodyCode(methodBody)); } @@ -1958,7 +1774,7 @@ void Generator::GeneratorImpl::addImplementationComputeRatesMethodCode(std::vect void Generator::GeneratorImpl::addImplementationComputeVariablesMethodCode(std::vector &remainingEquations) { - auto implementationComputeVariablesMethodString = mProfile->implementationComputeVariablesMethodString(modelHasOdes(), + auto implementationComputeVariablesMethodString = mProfile->implementationComputeVariablesMethodString(mModelHasOdes, mModel->hasExternalVariables()); if (!implementationComputeVariablesMethodString.empty()) { @@ -1973,7 +1789,7 @@ void Generator::GeneratorImpl::addImplementationComputeVariablesMethodCode(std:: } } - mCode += newLineIfNeeded() + mCode += newLineIfNeeded(mCode) + replace(implementationComputeVariablesMethodString, "[CODE]", generateMethodBodyCode(methodBody)); } @@ -2012,6 +1828,8 @@ AnalyserModelPtr Generator::model() void Generator::setModel(const AnalyserModelPtr &model) { mPimpl->mModel = model; + mPimpl->mModelHasOdes = modelHasOdes(model); + mPimpl->mModelHasNlas = modelHasNlas(model); } std::string Generator::interfaceCode() const @@ -2153,13 +1971,9 @@ std::string Generator::implementationCode() const std::string Generator::equationCode(const AnalyserEquationAstPtr &ast, const GeneratorProfilePtr &profile) { - GeneratorPtr generator = libcellml::Generator::create(); - - if (profile != nullptr) { - generator->setProfile(profile); - } + auto generatorInterpreter = GeneratorInterpreter::create(ast, profile); - return generator->mPimpl->generateCode(ast); + return generatorInterpreter->code(); } std::string Generator::equationCode(const AnalyserEquationAstPtr &ast) diff --git a/src/generator_p.h b/src/generator_p.h index fce8d2e7d..50cb6dbb9 100644 --- a/src/generator_p.h +++ b/src/generator_p.h @@ -26,8 +26,6 @@ namespace libcellml { std::string generateDoubleCode(const std::string &value); -AnalyserVariablePtr analyserVariable(const AnalyserModelPtr &model, const VariablePtr &variable); - /** * @brief The Generator::GeneratorImpl struct. * @@ -37,15 +35,15 @@ struct Generator::GeneratorImpl { AnalyserModelPtr mModel; + bool mModelHasOdes = false; + bool mModelHasNlas = false; + std::string mCode; GeneratorProfilePtr mProfile = GeneratorProfile::create(); void reset(); - bool modelHasOdes() const; - bool modelHasNlas() const; - double scalingFactor(const VariablePtr &variable) const; bool isNegativeNumber(const AnalyserEquationAstPtr &ast) const; @@ -69,8 +67,6 @@ struct Generator::GeneratorImpl bool modifiedProfile() const; - std::string newLineIfNeeded(); - void addOriginCommentCode(); void addInterfaceHeaderCode(); diff --git a/src/generatorinterpreter.cpp b/src/generatorinterpreter.cpp new file mode 100644 index 000000000..9bb9e1ca2 --- /dev/null +++ b/src/generatorinterpreter.cpp @@ -0,0 +1,1495 @@ +/* +Copyright libCellML Contributors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include "generatorinterpreter_p.h" + +#include "libcellml/analyserequationast.h" +#include "libcellml/analysermodel.h" +#include "libcellml/analyservariable.h" + +#include "utilities.h" + +namespace libcellml { + +GeneratorInterpreter::GeneratorInterpreterImpl::GeneratorInterpreterImpl(const AnalyserEquationAstPtr &ast, + const GeneratorProfilePtr &profile) +{ + if (profile != nullptr) { + mProfile = profile; + } + + mCode = generateCode(ast); +} + +void GeneratorInterpreter::GeneratorInterpreterImpl::initialise(const AnalyserModelPtr &model, + const GeneratorProfilePtr &profile, + const std::string &code) +{ + mModel = model; + mModelHasOdes = modelHasOdes(model); + mModelHasNlas = modelHasNlas(model); + mProfile = profile; + mCode = code; + + // Add code for solving the NLA systems. + + addNlaSystemsCode(); + + // Add code for the implementation to initialise our variables. + + auto equations = mModel->equations(); + std::vector remainingEquations {std::begin(equations), std::end(equations)}; + + initialiseVariables(remainingEquations); + + // Add code for the implementation to compute our computed constants. + + computeComputedConstants(remainingEquations); + + // Add code for the implementation to compute our rates (and any variables on which they depend). + + computeRates(remainingEquations); + + // Add code for the implementation to compute our variables. + // Note: this method computes the remaining variables, i.e. the ones not needed to compute our rates, but also the + // variables that depend on the value of some states/rates and all the external variables. This method is + // typically called after having integrated a model, thus ensuring that variables that rely on the value of + // some states/rates are up to date. + + computeVariables(remainingEquations); +} + +bool modelHasOdes(const AnalyserModelPtr &model) +{ + switch (model->type()) { + case AnalyserModel::Type::ODE: + case AnalyserModel::Type::DAE: + return true; + default: + return false; + } +} + +bool modelHasNlas(const AnalyserModelPtr &model) +{ + switch (model->type()) { + case AnalyserModel::Type::NLA: + case AnalyserModel::Type::DAE: + return true; + default: + return false; + } +} + +AnalyserVariablePtr analyserVariable(const AnalyserModelPtr &model, const VariablePtr &variable) +{ + // Find and return the analyser variable associated with the given variable. + + AnalyserVariablePtr res; + auto modelVoi = model->voi(); + auto modelVoiVariable = (modelVoi != nullptr) ? modelVoi->variable() : nullptr; + + if ((modelVoiVariable != nullptr) + && model->areEquivalentVariables(variable, modelVoiVariable)) { + res = modelVoi; + } else { + // Normally, we would have something like: + // + // for (const auto &modelVariable : variables(model)) { + // if (model->areEquivalentVariables(variable, modelVariable->variable())) { + // res = modelVariable; + // + // break; + // } + // } + // + // but we always have variables, so llvm-cov will complain that the false branch of our for loop is never + // reached. The below code is a bit more verbose but at least it makes llvm-cov happy. + + auto modelVariables = variables(model); + auto modelVariable = modelVariables.begin(); + + do { + if (model->areEquivalentVariables(variable, (*modelVariable)->variable())) { + res = *modelVariable; + } else { + ++modelVariable; + } + } while (res == nullptr); + } + + return res; +} + +double GeneratorInterpreter::GeneratorInterpreterImpl::scalingFactor(const VariablePtr &variable) const +{ + // Return the scaling factor for the given variable, accounting for the fact that a constant may be initialised by + // another variable which initial value may be defined in a different component. + + auto analyserVariable = libcellml::analyserVariable(mModel, variable); + + if ((analyserVariable->type() == AnalyserVariable::Type::CONSTANT) + && !isCellMLReal(variable->initialValue())) { + auto initialValueVariable = owningComponent(variable)->variable(variable->initialValue()); + auto initialValueAnalyserVariable = libcellml::analyserVariable(mModel, initialValueVariable); + + if (owningComponent(variable) != owningComponent(initialValueAnalyserVariable->variable())) { + return Units::scalingFactor(initialValueVariable->units(), variable->units()); + } + } + + return Units::scalingFactor(analyserVariable->variable()->units(), variable->units()); +} + +bool GeneratorInterpreter::GeneratorInterpreterImpl::isNegativeNumber(const AnalyserEquationAstPtr &ast) const +{ + if (ast->type() == AnalyserEquationAst::Type::CN) { + double doubleValue; + + convertToDouble(ast->value(), doubleValue); + + return doubleValue < 0.0; + } + + return false; +} + +bool GeneratorInterpreter::GeneratorInterpreterImpl::isRelationalOperator(const AnalyserEquationAstPtr &ast) const +{ + switch (ast->type()) { + case AnalyserEquationAst::Type::EQ: + return mProfile->hasEqOperator(); + case AnalyserEquationAst::Type::NEQ: + return mProfile->hasNeqOperator(); + case AnalyserEquationAst::Type::LT: + return mProfile->hasLtOperator(); + case AnalyserEquationAst::Type::LEQ: + return mProfile->hasLeqOperator(); + case AnalyserEquationAst::Type::GT: + return mProfile->hasGtOperator(); + case AnalyserEquationAst::Type::GEQ: + return mProfile->hasGeqOperator(); + default: + return false; + } +} + +bool GeneratorInterpreter::GeneratorInterpreterImpl::isAndOperator(const AnalyserEquationAstPtr &ast) const +{ + return (ast->type() == AnalyserEquationAst::Type::AND) + && mProfile->hasAndOperator(); +} + +bool GeneratorInterpreter::GeneratorInterpreterImpl::isOrOperator(const AnalyserEquationAstPtr &ast) const +{ + return (ast->type() == AnalyserEquationAst::Type::OR) + && mProfile->hasOrOperator(); +} + +bool GeneratorInterpreter::GeneratorInterpreterImpl::isXorOperator(const AnalyserEquationAstPtr &ast) const +{ + return (ast->type() == AnalyserEquationAst::Type::XOR) + && mProfile->hasXorOperator(); +} + +bool GeneratorInterpreter::GeneratorInterpreterImpl::isLogicalOperator(const AnalyserEquationAstPtr &ast) const +{ + // Note: AnalyserEquationAst::Type::NOT is a unary logical operator, hence we don't include it here since this + // method is only used to determine whether parentheses should be added around some code. + + return isAndOperator(ast) || isOrOperator(ast) || isXorOperator(ast); +} + +bool GeneratorInterpreter::GeneratorInterpreterImpl::isPlusOperator(const AnalyserEquationAstPtr &ast) const +{ + return ast->type() == AnalyserEquationAst::Type::PLUS; +} + +bool GeneratorInterpreter::GeneratorInterpreterImpl::isMinusOperator(const AnalyserEquationAstPtr &ast) const +{ + return ast->type() == AnalyserEquationAst::Type::MINUS; +} + +bool GeneratorInterpreter::GeneratorInterpreterImpl::isTimesOperator(const AnalyserEquationAstPtr &ast) const +{ + return ast->type() == AnalyserEquationAst::Type::TIMES; +} + +bool GeneratorInterpreter::GeneratorInterpreterImpl::isDivideOperator(const AnalyserEquationAstPtr &ast) const +{ + return ast->type() == AnalyserEquationAst::Type::DIVIDE; +} + +bool GeneratorInterpreter::GeneratorInterpreterImpl::isPowerOperator(const AnalyserEquationAstPtr &ast) const +{ + return (ast->type() == AnalyserEquationAst::Type::POWER) + && mProfile->hasPowerOperator(); +} + +bool GeneratorInterpreter::GeneratorInterpreterImpl::isRootOperator(const AnalyserEquationAstPtr &ast) const +{ + return (ast->type() == AnalyserEquationAst::Type::ROOT) + && mProfile->hasPowerOperator(); +} + +bool GeneratorInterpreter::GeneratorInterpreterImpl::isPiecewiseStatement(const AnalyserEquationAstPtr &ast) const +{ + return (ast->type() == AnalyserEquationAst::Type::PIECEWISE) + && mProfile->hasConditionalOperator(); +} + +std::string newLineIfNeeded(const std::string &code) +{ + return code.empty() ? "" : "\n"; +} + +std::string GeneratorInterpreter::GeneratorInterpreterImpl::generateMethodBodyCode(const std::string &methodBody) const +{ + return methodBody.empty() ? + (mProfile->emptyMethodString().empty() ? + "" : + mProfile->indentString() + mProfile->emptyMethodString()) : + methodBody; +} + +std::string generateDoubleCode(const std::string &value) +{ + if (value.find('.') != std::string::npos) { + return value; + } + + auto ePos = value.find('e'); + + if (ePos == std::string::npos) { + return value + ".0"; + } + + return value.substr(0, ePos) + ".0" + value.substr(ePos); +} + +std::string GeneratorInterpreter::GeneratorInterpreterImpl::generateVariableNameCode(const VariablePtr &variable, + bool rate) const +{ + // Generate some code for a variable name, but only if we have a model. If we don't have a model, it means that we + // are using the generator from the analyser, in which case we just want to return the original name of the + // variable. + + if (mModel == nullptr) { + return variable->name(); + } + + auto analyserVariable = libcellml::analyserVariable(mModel, variable); + + if (analyserVariable->type() == AnalyserVariable::Type::VARIABLE_OF_INTEGRATION) { + return mProfile->voiString(); + } + + std::string arrayName; + + if (analyserVariable->type() == AnalyserVariable::Type::STATE) { + arrayName = rate ? + mProfile->ratesArrayString() : + mProfile->statesArrayString(); + } else if (analyserVariable->type() == AnalyserVariable::Type::CONSTANT) { + arrayName = mProfile->constantsArrayString(); + } else if (analyserVariable->type() == AnalyserVariable::Type::COMPUTED_CONSTANT) { + arrayName = mProfile->computedConstantsArrayString(); + } else if (analyserVariable->type() == AnalyserVariable::Type::ALGEBRAIC) { + arrayName = mProfile->algebraicArrayString(); + } else { + arrayName = mProfile->externalArrayString(); + } + + return arrayName + mProfile->openArrayString() + convertToString(analyserVariable->index()) + mProfile->closeArrayString(); +} + +std::string GeneratorInterpreter::GeneratorInterpreterImpl::generateOperatorCode(const std::string &op, + const AnalyserEquationAstPtr &ast) const +{ + // Generate the code for the left and right branches of the given AST. + + std::string res; + auto astLeftChild = ast->leftChild(); + auto astRightChild = ast->rightChild(); + auto leftCode = generateCode(astLeftChild); + auto rightCode = generateCode(astRightChild); + + // Determine whether parentheses should be added around the left and/or right piece of code, and this based on the + // precedence of the operators used in CellML, which are listed below from higher to lower precedence: + // 1. Parentheses [Left to right] + // 2. POWER (as an operator, not as a function, i.e. [Left to right] + // as in Matlab and not in C, for example) + // 3. Unary PLUS, Unary MINUS, NOT [Right to left] + // 4. TIMES, DIVIDE [Left to right] + // 5. PLUS, MINUS [Left to right] + // 6. LT, LEQ, GT, GEQ [Left to right] + // 7. EQ, NEQ [Left to right] + // 8. XOR (bitwise) [Left to right] + // 9. AND (logical) [Left to right] + // 10. OR (logical) [Left to right] + // 11. PIECEWISE (as an operator) [Right to left] + + if (isPlusOperator(ast)) { + if (isRelationalOperator(astLeftChild) + || isLogicalOperator(astLeftChild) + || isPiecewiseStatement(astLeftChild)) { + leftCode = "(" + leftCode + ")"; + } + + if (isRelationalOperator(astRightChild) + || isLogicalOperator(astRightChild) + || isPiecewiseStatement(astRightChild)) { + rightCode = "(" + rightCode + ")"; + } + } else if (isMinusOperator(ast)) { + if (isRelationalOperator(astLeftChild) + || isLogicalOperator(astLeftChild) + || isPiecewiseStatement(astLeftChild)) { + leftCode = "(" + leftCode + ")"; + } + + if (isNegativeNumber(astRightChild) + || isRelationalOperator(astRightChild) + || isLogicalOperator(astRightChild) + || isMinusOperator(astRightChild) + || isPiecewiseStatement(astRightChild) + || (rightCode.rfind(mProfile->minusString(), 0) == 0)) { + rightCode = "(" + rightCode + ")"; + } else if (isPlusOperator(astRightChild)) { + if (astRightChild->rightChild() != nullptr) { + rightCode = "(" + rightCode + ")"; + } + } + } else if (isTimesOperator(ast)) { + if (isRelationalOperator(astLeftChild) + || isLogicalOperator(astLeftChild) + || isPiecewiseStatement(astLeftChild)) { + leftCode = "(" + leftCode + ")"; + } else if (isPlusOperator(astLeftChild) + || isMinusOperator(astLeftChild)) { + if (astLeftChild->rightChild() != nullptr) { + leftCode = "(" + leftCode + ")"; + } + } + + if (isRelationalOperator(astRightChild) + || isLogicalOperator(astRightChild) + || isPiecewiseStatement(astRightChild)) { + rightCode = "(" + rightCode + ")"; + } else if (isPlusOperator(astRightChild) + || isMinusOperator(astRightChild)) { + if (astRightChild->rightChild() != nullptr) { + rightCode = "(" + rightCode + ")"; + } + } + } else if (isDivideOperator(ast)) { + if (isRelationalOperator(astLeftChild) + || isLogicalOperator(astLeftChild) + || isPiecewiseStatement(astLeftChild)) { + leftCode = "(" + leftCode + ")"; + } else if (isPlusOperator(astLeftChild) + || isMinusOperator(astLeftChild)) { + if (astLeftChild->rightChild() != nullptr) { + leftCode = "(" + leftCode + ")"; + } + } + + if (isRelationalOperator(astRightChild) + || isLogicalOperator(astRightChild) + || isTimesOperator(astRightChild) + || isDivideOperator(astRightChild) + || isPiecewiseStatement(astRightChild)) { + rightCode = "(" + rightCode + ")"; + } else if (isPlusOperator(astRightChild) + || isMinusOperator(astRightChild)) { + if (astRightChild->rightChild() != nullptr) { + rightCode = "(" + rightCode + ")"; + } + } + } else if (isAndOperator(ast)) { + // Note: according to the precedence rules above, we only need to add parentheses around OR and PIECEWISE. + // However, it looks better/clearer to have some around some other operators (agreed, this is somewhat + // subjective). + + if (isRelationalOperator(astLeftChild) + || isOrOperator(astLeftChild) + || isXorOperator(astLeftChild) + || isPiecewiseStatement(astLeftChild)) { + leftCode = "(" + leftCode + ")"; + } else if (isPlusOperator(astLeftChild) + || isMinusOperator(astLeftChild)) { + if (astLeftChild->rightChild() != nullptr) { + leftCode = "(" + leftCode + ")"; + } + } else if (isPowerOperator(astLeftChild)) { + leftCode = "(" + leftCode + ")"; + } else if (isRootOperator(astLeftChild)) { + leftCode = "(" + leftCode + ")"; + } + + if (isRelationalOperator(astRightChild) + || isOrOperator(astRightChild) + || isXorOperator(astRightChild) + || isPiecewiseStatement(astRightChild)) { + rightCode = "(" + rightCode + ")"; + } else if (isPlusOperator(astRightChild) + || isMinusOperator(astRightChild)) { + if (astRightChild->rightChild() != nullptr) { + rightCode = "(" + rightCode + ")"; + } + } else if (isPowerOperator(astRightChild)) { + rightCode = "(" + rightCode + ")"; + } else if (isRootOperator(astRightChild)) { + rightCode = "(" + rightCode + ")"; + } + } else if (isOrOperator(ast)) { + // Note: according to the precedence rules above, we only need to add parentheses around PIECEWISE. However, it + // looks better/clearer to have some around some other operators (agreed, this is somewhat subjective). + + if (isRelationalOperator(astLeftChild) + || isAndOperator(astLeftChild) + || isXorOperator(astLeftChild) + || isPiecewiseStatement(astLeftChild)) { + leftCode = "(" + leftCode + ")"; + } else if (isPlusOperator(astLeftChild) + || isMinusOperator(astLeftChild)) { + if (astLeftChild->rightChild() != nullptr) { + leftCode = "(" + leftCode + ")"; + } + } else if (isPowerOperator(astLeftChild)) { + leftCode = "(" + leftCode + ")"; + } else if (isRootOperator(astLeftChild)) { + leftCode = "(" + leftCode + ")"; + } + + if (isRelationalOperator(astRightChild) + || isAndOperator(astRightChild) + || isXorOperator(astRightChild) + || isPiecewiseStatement(astRightChild)) { + rightCode = "(" + rightCode + ")"; + } else if (isPlusOperator(astRightChild) + || isMinusOperator(astRightChild)) { + if (astRightChild->rightChild() != nullptr) { + rightCode = "(" + rightCode + ")"; + } + } else if (isPowerOperator(astRightChild)) { + rightCode = "(" + rightCode + ")"; + } else if (isRootOperator(astRightChild)) { + rightCode = "(" + rightCode + ")"; + } + } else if (isXorOperator(ast)) { + // Note: according to the precedence rules above, we only need to add parentheses around AND, OR and PIECEWISE. + // However, it looks better/clearer to have some around some other operators (agreed, this is somewhat + // subjective). + + if (isRelationalOperator(astLeftChild) + || isAndOperator(astLeftChild) + || isOrOperator(astLeftChild) + || isPiecewiseStatement(astLeftChild)) { + leftCode = "(" + leftCode + ")"; + } else if (isPlusOperator(astLeftChild) + || isMinusOperator(astLeftChild)) { + if (astLeftChild->rightChild() != nullptr) { + leftCode = "(" + leftCode + ")"; + } + } else if (isPowerOperator(astLeftChild)) { + leftCode = "(" + leftCode + ")"; + } else if (isRootOperator(astLeftChild)) { + leftCode = "(" + leftCode + ")"; + } + + if (isRelationalOperator(astRightChild) + || isAndOperator(astRightChild) + || isOrOperator(astRightChild) + || isPiecewiseStatement(astRightChild)) { + rightCode = "(" + rightCode + ")"; + } else if (isPlusOperator(astRightChild) + || isMinusOperator(astRightChild)) { + if (astRightChild->rightChild() != nullptr) { + rightCode = "(" + rightCode + ")"; + } + } else if (isPowerOperator(astRightChild)) { + rightCode = "(" + rightCode + ")"; + } else if (isRootOperator(astRightChild)) { + rightCode = "(" + rightCode + ")"; + } + } else if (isPowerOperator(ast)) { + if (isRelationalOperator(astLeftChild) + || isLogicalOperator(astLeftChild) + || isMinusOperator(astLeftChild) + || isTimesOperator(astLeftChild) + || isDivideOperator(astLeftChild) + || isPiecewiseStatement(astLeftChild)) { + leftCode = "(" + leftCode + ")"; + } else if (isPlusOperator(astLeftChild)) { + if (astLeftChild->rightChild() != nullptr) { + leftCode = "(" + leftCode + ")"; + } + } + + if (isRelationalOperator(astRightChild) + || isLogicalOperator(astRightChild) + || isMinusOperator(astLeftChild) + || isTimesOperator(astRightChild) + || isDivideOperator(astRightChild) + || isPowerOperator(astRightChild) + || isRootOperator(astRightChild) + || isPiecewiseStatement(astRightChild)) { + rightCode = "(" + rightCode + ")"; + } else if (isPlusOperator(astRightChild)) { + if (astRightChild->rightChild() != nullptr) { + rightCode = "(" + rightCode + ")"; + } + } + } else if (isRootOperator(ast)) { + if (isRelationalOperator(astRightChild) + || isLogicalOperator(astRightChild) + || isMinusOperator(astRightChild) + || isTimesOperator(astRightChild) + || isDivideOperator(astRightChild) + || isPiecewiseStatement(astRightChild)) { + rightCode = "(" + rightCode + ")"; + } else if (isPlusOperator(astRightChild)) { + if (astRightChild->rightChild() != nullptr) { + rightCode = "(" + rightCode + ")"; + } + } + + auto astLeftChildLeftChild = astLeftChild->leftChild(); + + if (isRelationalOperator(astLeftChildLeftChild) + || isLogicalOperator(astLeftChildLeftChild) + || isMinusOperator(astLeftChildLeftChild) + || isTimesOperator(astLeftChildLeftChild) + || isDivideOperator(astLeftChildLeftChild) + || isPowerOperator(astLeftChildLeftChild) + || isRootOperator(astLeftChildLeftChild) + || isPiecewiseStatement(astLeftChildLeftChild)) { + leftCode = "(" + leftCode + ")"; + } else if (isPlusOperator(astLeftChildLeftChild)) { + if (astLeftChildLeftChild->rightChild() != nullptr) { + leftCode = "(" + leftCode + ")"; + } + } + + return rightCode + op + "(1.0/" + leftCode + ")"; + } + + return leftCode + op + rightCode; +} + +std::string GeneratorInterpreter::GeneratorInterpreterImpl::generateMinusUnaryCode(const AnalyserEquationAstPtr &ast) const +{ + // Generate the code for the left branch of the given AST. + + auto astLeftChild = ast->leftChild(); + auto leftCode = generateCode(astLeftChild); + + // Determine whether parentheses should be added around the left code. + + if (isRelationalOperator(astLeftChild) + || isLogicalOperator(astLeftChild) + || isPlusOperator(astLeftChild) + || isMinusOperator(astLeftChild) + || isPiecewiseStatement(astLeftChild)) { + leftCode = "(" + leftCode + ")"; + } + + return mProfile->minusString() + leftCode; +} + +std::string GeneratorInterpreter::GeneratorInterpreterImpl::generateOneParameterFunctionCode(const std::string &function, + const AnalyserEquationAstPtr &ast) const +{ + auto leftCode = generateCode(ast->leftChild()); + + return function + "(" + leftCode + ")"; +} + +std::string GeneratorInterpreter::GeneratorInterpreterImpl::generateTwoParameterFunctionCode(const std::string &function, + const AnalyserEquationAstPtr &ast) const +{ + auto leftCode = generateCode(ast->leftChild()); + auto rightCode = generateCode(ast->rightChild()); + + return function + "(" + leftCode + ", " + rightCode + ")"; +} + +std::string GeneratorInterpreter::GeneratorInterpreterImpl::generatePiecewiseIfCode(const std::string &condition, + const std::string &value) const +{ + return replace(replace(mProfile->hasConditionalOperator() ? + mProfile->conditionalOperatorIfString() : + mProfile->piecewiseIfString(), + "[CONDITION]", condition), + "[IF_STATEMENT]", value); +} + +std::string GeneratorInterpreter::GeneratorInterpreterImpl::generatePiecewiseElseCode(const std::string &value) const +{ + return replace(mProfile->hasConditionalOperator() ? + mProfile->conditionalOperatorElseString() : + mProfile->piecewiseElseString(), + "[ELSE_STATEMENT]", value); +} + +std::string GeneratorInterpreter::GeneratorInterpreterImpl::generateCode(const AnalyserEquationAstPtr &ast) const +{ + // Make sure that we have an AST to work on. + + if (ast == nullptr) { + return {}; + } + + // Generate the code for the given AST. + // Note: AnalyserEquationAst::Type::BVAR is only relevant when there is no model (in which case we want to generate + // something like dx/dt, as is in the case of the analyser when we want to mention an equation) since + // otherwise we don't need to generate any code for it (since we will, instead, want to generate something + // like rates[0]). + + std::string code; + + switch (ast->type()) { + case AnalyserEquationAst::Type::EQUALITY: + code = generateOperatorCode(mProfile->equalityString(), ast); + + break; + case AnalyserEquationAst::Type::EQ: + if (mProfile->hasEqOperator()) { + code = generateOperatorCode(mProfile->eqString(), ast); + } else { + code = generateTwoParameterFunctionCode(mProfile->eqString(), ast); + } + + break; + case AnalyserEquationAst::Type::NEQ: + if (mProfile->hasNeqOperator()) { + code = generateOperatorCode(mProfile->neqString(), ast); + } else { + code = generateTwoParameterFunctionCode(mProfile->neqString(), ast); + } + + break; + case AnalyserEquationAst::Type::LT: + if (mProfile->hasLtOperator()) { + code = generateOperatorCode(mProfile->ltString(), ast); + } else { + code = generateTwoParameterFunctionCode(mProfile->ltString(), ast); + } + + break; + case AnalyserEquationAst::Type::LEQ: + if (mProfile->hasLeqOperator()) { + code = generateOperatorCode(mProfile->leqString(), ast); + } else { + code = generateTwoParameterFunctionCode(mProfile->leqString(), ast); + } + + break; + case AnalyserEquationAst::Type::GT: + if (mProfile->hasGtOperator()) { + code = generateOperatorCode(mProfile->gtString(), ast); + } else { + code = generateTwoParameterFunctionCode(mProfile->gtString(), ast); + } + + break; + case AnalyserEquationAst::Type::GEQ: + if (mProfile->hasGeqOperator()) { + code = generateOperatorCode(mProfile->geqString(), ast); + } else { + code = generateTwoParameterFunctionCode(mProfile->geqString(), ast); + } + + break; + case AnalyserEquationAst::Type::AND: + if (mProfile->hasAndOperator()) { + code = generateOperatorCode(mProfile->andString(), ast); + } else { + code = generateTwoParameterFunctionCode(mProfile->andString(), ast); + } + + break; + case AnalyserEquationAst::Type::OR: + if (mProfile->hasOrOperator()) { + code = generateOperatorCode(mProfile->orString(), ast); + } else { + code = generateTwoParameterFunctionCode(mProfile->orString(), ast); + } + + break; + case AnalyserEquationAst::Type::XOR: + if (mProfile->hasXorOperator()) { + code = generateOperatorCode(mProfile->xorString(), ast); + } else { + code = generateTwoParameterFunctionCode(mProfile->xorString(), ast); + } + + break; + case AnalyserEquationAst::Type::NOT: + if (mProfile->hasNotOperator()) { + code = mProfile->notString() + generateCode(ast->leftChild()); + } else { + code = generateOneParameterFunctionCode(mProfile->notString(), ast); + } + + break; + case AnalyserEquationAst::Type::PLUS: + if (ast->rightChild() != nullptr) { + code = generateOperatorCode(mProfile->plusString(), ast); + } else { + code = generateCode(ast->leftChild()); + } + + break; + case AnalyserEquationAst::Type::MINUS: + code = (ast->rightChild() != nullptr) ? + generateOperatorCode(mProfile->minusString(), ast) : + generateMinusUnaryCode(ast); + + break; + case AnalyserEquationAst::Type::TIMES: + code = generateOperatorCode(mProfile->timesString(), ast); + + break; + case AnalyserEquationAst::Type::DIVIDE: + code = generateOperatorCode(mProfile->divideString(), ast); + + break; + case AnalyserEquationAst::Type::POWER: { + auto rightCode = generateCode(ast->rightChild()); + double doubleValue; + auto validConversion = convertToDouble(rightCode, doubleValue); + + if (validConversion && areEqual(doubleValue, 0.5)) { + code = mProfile->squareRootString() + "(" + generateCode(ast->leftChild()) + ")"; + } else if (validConversion && areEqual(doubleValue, 2.0) + && !mProfile->squareString().empty()) { + code = mProfile->squareString() + "(" + generateCode(ast->leftChild()) + ")"; + } else if (mProfile->hasPowerOperator()) { + code = generateOperatorCode(mProfile->powerString(), ast); + } else { + code = mProfile->powerString() + "(" + generateCode(ast->leftChild()) + ", " + rightCode + ")"; + } + } break; + case AnalyserEquationAst::Type::ROOT: { + auto astRightChild = ast->rightChild(); + + if (astRightChild != nullptr) { + auto astLeftChild = ast->leftChild(); + auto leftCode = generateCode(astLeftChild); + double doubleValue; + auto validConversion = convertToDouble(leftCode, doubleValue); + + if (validConversion && areEqual(doubleValue, 2.0)) { + code = mProfile->squareRootString() + "(" + generateCode(astRightChild) + ")"; + } else if (validConversion && areEqual(doubleValue, 0.5) + && !mProfile->squareString().empty()) { + code = mProfile->squareString() + "(" + generateCode(astRightChild) + ")"; + } else if (mProfile->hasPowerOperator()) { + code = generateOperatorCode(mProfile->powerString(), ast); + } else { + auto inverseValueAst = AnalyserEquationAst::create(); + + inverseValueAst->setType(AnalyserEquationAst::Type::DIVIDE); + inverseValueAst->setParent(ast); + + auto inverseValueAstLeftChild = AnalyserEquationAst::create(); + + inverseValueAstLeftChild->setType(AnalyserEquationAst::Type::CN); + inverseValueAstLeftChild->setValue("1.0"); + inverseValueAstLeftChild->setParent(inverseValueAst); + + inverseValueAst->setLeftChild(inverseValueAstLeftChild); + inverseValueAst->setRightChild(astLeftChild->leftChild()); + + code = mProfile->powerString() + "(" + generateCode(astRightChild) + ", " + generateOperatorCode(mProfile->divideString(), inverseValueAst) + ")"; + } + } else { + code = mProfile->squareRootString() + "(" + generateCode(ast->leftChild()) + ")"; + } + } break; + case AnalyserEquationAst::Type::ABS: + code = generateOneParameterFunctionCode(mProfile->absoluteValueString(), ast); + + break; + case AnalyserEquationAst::Type::EXP: + code = generateOneParameterFunctionCode(mProfile->exponentialString(), ast); + + break; + case AnalyserEquationAst::Type::LN: + code = generateOneParameterFunctionCode(mProfile->naturalLogarithmString(), ast); + + break; + case AnalyserEquationAst::Type::LOG: { + auto astRightChild = ast->rightChild(); + + if (astRightChild != nullptr) { + auto leftCode = generateCode(ast->leftChild()); + double doubleValue; + auto rightCode = generateCode(astRightChild); + + if (convertToDouble(leftCode, doubleValue) + && areEqual(doubleValue, 10.0)) { + code = mProfile->commonLogarithmString() + "(" + rightCode + ")"; + } else { + code = mProfile->naturalLogarithmString() + "(" + rightCode + ")/" + mProfile->naturalLogarithmString() + "(" + leftCode + ")"; + } + } else { + code = generateOneParameterFunctionCode(mProfile->commonLogarithmString(), ast); + } + } break; + case AnalyserEquationAst::Type::CEILING: + code = generateOneParameterFunctionCode(mProfile->ceilingString(), ast); + + break; + case AnalyserEquationAst::Type::FLOOR: + code = generateOneParameterFunctionCode(mProfile->floorString(), ast); + + break; + case AnalyserEquationAst::Type::MIN: + code = generateTwoParameterFunctionCode(mProfile->minString(), ast); + + break; + case AnalyserEquationAst::Type::MAX: + code = generateTwoParameterFunctionCode(mProfile->maxString(), ast); + + break; + case AnalyserEquationAst::Type::REM: + code = generateTwoParameterFunctionCode(mProfile->remString(), ast); + + break; + case AnalyserEquationAst::Type::DIFF: + if (mModel != nullptr) { + code = generateCode(ast->rightChild()); + } else { + code = "d" + generateCode(ast->rightChild()) + "/d" + generateCode(ast->leftChild()); + } + + break; + case AnalyserEquationAst::Type::SIN: + code = generateOneParameterFunctionCode(mProfile->sinString(), ast); + + break; + case AnalyserEquationAst::Type::COS: + code = generateOneParameterFunctionCode(mProfile->cosString(), ast); + + break; + case AnalyserEquationAst::Type::TAN: + code = generateOneParameterFunctionCode(mProfile->tanString(), ast); + + break; + case AnalyserEquationAst::Type::SEC: + code = generateOneParameterFunctionCode(mProfile->secString(), ast); + + break; + case AnalyserEquationAst::Type::CSC: + code = generateOneParameterFunctionCode(mProfile->cscString(), ast); + + break; + case AnalyserEquationAst::Type::COT: + code = generateOneParameterFunctionCode(mProfile->cotString(), ast); + + break; + case AnalyserEquationAst::Type::SINH: + code = generateOneParameterFunctionCode(mProfile->sinhString(), ast); + + break; + case AnalyserEquationAst::Type::COSH: + code = generateOneParameterFunctionCode(mProfile->coshString(), ast); + + break; + case AnalyserEquationAst::Type::TANH: + code = generateOneParameterFunctionCode(mProfile->tanhString(), ast); + + break; + case AnalyserEquationAst::Type::SECH: + code = generateOneParameterFunctionCode(mProfile->sechString(), ast); + + break; + case AnalyserEquationAst::Type::CSCH: + code = generateOneParameterFunctionCode(mProfile->cschString(), ast); + + break; + case AnalyserEquationAst::Type::COTH: + code = generateOneParameterFunctionCode(mProfile->cothString(), ast); + + break; + case AnalyserEquationAst::Type::ASIN: + code = generateOneParameterFunctionCode(mProfile->asinString(), ast); + + break; + case AnalyserEquationAst::Type::ACOS: + code = generateOneParameterFunctionCode(mProfile->acosString(), ast); + + break; + case AnalyserEquationAst::Type::ATAN: + code = generateOneParameterFunctionCode(mProfile->atanString(), ast); + + break; + case AnalyserEquationAst::Type::ASEC: + code = generateOneParameterFunctionCode(mProfile->asecString(), ast); + + break; + case AnalyserEquationAst::Type::ACSC: + code = generateOneParameterFunctionCode(mProfile->acscString(), ast); + + break; + case AnalyserEquationAst::Type::ACOT: + code = generateOneParameterFunctionCode(mProfile->acotString(), ast); + + break; + case AnalyserEquationAst::Type::ASINH: + code = generateOneParameterFunctionCode(mProfile->asinhString(), ast); + + break; + case AnalyserEquationAst::Type::ACOSH: + code = generateOneParameterFunctionCode(mProfile->acoshString(), ast); + + break; + case AnalyserEquationAst::Type::ATANH: + code = generateOneParameterFunctionCode(mProfile->atanhString(), ast); + + break; + case AnalyserEquationAst::Type::ASECH: + code = generateOneParameterFunctionCode(mProfile->asechString(), ast); + + break; + case AnalyserEquationAst::Type::ACSCH: + code = generateOneParameterFunctionCode(mProfile->acschString(), ast); + + break; + case AnalyserEquationAst::Type::ACOTH: + code = generateOneParameterFunctionCode(mProfile->acothString(), ast); + + break; + case AnalyserEquationAst::Type::PIECEWISE: { + auto astLeftChild = ast->leftChild(); + auto astRightChild = ast->rightChild(); + auto leftCode = generateCode(astLeftChild); + auto rightCode = generateCode(astRightChild); + + if (astRightChild != nullptr) { + if (astRightChild->type() == AnalyserEquationAst::Type::PIECE) { + code = leftCode + generatePiecewiseElseCode(rightCode + generatePiecewiseElseCode(mProfile->nanString())); + } else { + code = leftCode + generatePiecewiseElseCode(rightCode); + } + } else if (astLeftChild != nullptr) { + if (astLeftChild->type() == AnalyserEquationAst::Type::PIECE) { + code = leftCode + generatePiecewiseElseCode(mProfile->nanString()); + } else { + code = leftCode; + } + } else { + code = mProfile->nanString(); + } + } break; + case AnalyserEquationAst::Type::PIECE: { + auto leftCode = generateCode(ast->leftChild()); + auto rightCode = generateCode(ast->rightChild()); + + code = generatePiecewiseIfCode(rightCode, leftCode); + } break; + case AnalyserEquationAst::Type::OTHERWISE: { + code = generateCode(ast->leftChild()); + } break; + case AnalyserEquationAst::Type::CI: { + auto variable = ast->variable(); + bool rate = ast->parent()->type() == AnalyserEquationAst::Type::DIFF; + + code = generateVariableNameCode(variable, rate); + } break; + case AnalyserEquationAst::Type::CN: { + double doubleValue; + + convertToDouble(ast->value(), doubleValue); + + code = generateDoubleCode(ast->value()); + } break; + case AnalyserEquationAst::Type::DEGREE: + case AnalyserEquationAst::Type::LOGBASE: { + code = generateCode(ast->leftChild()); + } break; + case AnalyserEquationAst::Type::BVAR: { + code = generateCode(ast->leftChild()); + } break; + case AnalyserEquationAst::Type::TRUE: + code = mProfile->trueString(); + + break; + case AnalyserEquationAst::Type::FALSE: + code = mProfile->falseString(); + + break; + case AnalyserEquationAst::Type::E: + code = mProfile->eString(); + + break; + case AnalyserEquationAst::Type::PI: + code = mProfile->piString(); + + break; + case AnalyserEquationAst::Type::INF: + code = mProfile->infString(); + + break; + default: // AnalyserEquationAst::Type::NAN. + code = mProfile->nanString(); + + break; + } + + return code; +} + +bool GeneratorInterpreter::GeneratorInterpreterImpl::isToBeComputedAgain(const AnalyserEquationPtr &equation) const +{ + // NLA and algebraic equations that are state/rate-based and external equations are to be computed again (in the + // computeVariables() method). + + switch (equation->type()) { + case AnalyserEquation::Type::NLA: + case AnalyserEquation::Type::ALGEBRAIC: + return equation->isStateRateBased(); + case AnalyserEquation::Type::EXTERNAL: + return true; + default: + return false; + } +} + +bool GeneratorInterpreter::GeneratorInterpreterImpl::isSomeConstant(const AnalyserEquationPtr &equation, + bool includeComputedConstants) const +{ + auto type = equation->type(); + + return (type == AnalyserEquation::Type::TRUE_CONSTANT) + || (!includeComputedConstants && (type == AnalyserEquation::Type::VARIABLE_BASED_CONSTANT)); +} + +std::string GeneratorInterpreter::GeneratorInterpreterImpl::generateZeroInitialisationCode(const AnalyserVariablePtr &variable) +{ + bool rate = variable->type() == AnalyserVariable::Type::STATE; + + return mProfile->indentString() + + generateVariableNameCode(variable->variable(), rate) + + mProfile->equalityString() + + "0.0" + + mProfile->commandSeparatorString() + "\n"; +} + +std::string GeneratorInterpreter::GeneratorInterpreterImpl::generateInitialisationCode(const AnalyserVariablePtr &variable) +{ + // Determine whether the initialising variable has an initial value per se or if it is initialised by another + // variable. + + auto initialisingVariable = variable->initialisingVariable(); + std::string initialValueCode; + + if (isCellMLReal(initialisingVariable->initialValue())) { + initialValueCode = generateDoubleCode(initialisingVariable->initialValue()); + } else { + auto initialValueVariable = owningComponent(initialisingVariable)->variable(initialisingVariable->initialValue()); + auto analyserInitialValueVariable = analyserVariable(mModel, initialValueVariable); + + initialValueCode = mProfile->constantsArrayString() + mProfile->openArrayString() + convertToString(analyserInitialValueVariable->index()) + mProfile->closeArrayString(); + } + + // Determine the scaling factor, if any. + + auto scalingFactor = GeneratorInterpreter::GeneratorInterpreterImpl::scalingFactor(initialisingVariable); + + if (!areNearlyEqual(scalingFactor, 1.0)) { + initialValueCode = generateDoubleCode(convertToString(scalingFactor)) + mProfile->timesString() + initialValueCode; + } + + return mProfile->indentString() + + generateVariableNameCode(variable->variable()) + + mProfile->equalityString() + + initialValueCode + + mProfile->commandSeparatorString() + "\n"; +} + +std::string GeneratorInterpreter::GeneratorInterpreterImpl::generateEquationCode(const AnalyserEquationPtr &equation, + std::vector &remainingEquations, + std::vector &equationsForDependencies, + bool includeComputedConstants) +{ + std::string res; + + if (std::find(remainingEquations.begin(), remainingEquations.end(), equation) != remainingEquations.end()) { + // Stop tracking the equation and its NLA siblings, if any. + // Note: we need to do this as soon as possible to avoid recursive calls, something that would happen if we were + // to do this at the end of this if statement. + + remainingEquations.erase(std::find(remainingEquations.begin(), remainingEquations.end(), equation)); + + for (const auto &nlaSibling : equation->nlaSiblings()) { + remainingEquations.erase(std::find(remainingEquations.begin(), remainingEquations.end(), nlaSibling)); + } + + // Generate any dependency that this equation may have. + + if (!isSomeConstant(equation, includeComputedConstants)) { + for (const auto &dependency : equation->dependencies()) { + if ((dependency->type() != AnalyserEquation::Type::ODE) + && !isSomeConstant(dependency, includeComputedConstants) + && (equationsForDependencies.empty() + || isToBeComputedAgain(dependency) + || (std::find(equationsForDependencies.begin(), equationsForDependencies.end(), dependency) != equationsForDependencies.end()))) { + res += generateEquationCode(dependency, remainingEquations, equationsForDependencies, includeComputedConstants); + } + } + } + + // Generate the equation code itself, based on the equation type. + + switch (equation->type()) { + case AnalyserEquation::Type::EXTERNAL: + for (const auto &variable : variables(equation)) { + res += mProfile->indentString() + + generateVariableNameCode(variable->variable()) + + mProfile->equalityString() + + replace(mProfile->externalVariableMethodCallString(mModelHasOdes), + "[INDEX]", convertToString(variable->index())) + + mProfile->commandSeparatorString() + "\n"; + } + + break; + case AnalyserEquation::Type::NLA: + if (!mProfile->findRootCallString(mModelHasOdes, mModel->hasExternalVariables()).empty()) { + res += mProfile->indentString() + + replace(mProfile->findRootCallString(mModelHasOdes, mModel->hasExternalVariables()), + "[INDEX]", convertToString(equation->nlaSystemIndex())); + } + + break; + default: + res += mProfile->indentString() + generateCode(equation->ast()) + mProfile->commandSeparatorString() + "\n"; + + break; + } + } + + return res; +} + +std::string GeneratorInterpreter::GeneratorInterpreterImpl::generateEquationCode(const AnalyserEquationPtr &equation, + std::vector &remainingEquations) +{ + std::vector dummyEquationsForComputeVariables; + + return generateEquationCode(equation, remainingEquations, dummyEquationsForComputeVariables, true); +} + +void GeneratorInterpreter::GeneratorInterpreterImpl::addNlaSystemsCode() +{ + if (mModelHasNlas + && !mProfile->objectiveFunctionMethodString(mModelHasOdes, mModel->hasExternalVariables()).empty() + && !mProfile->findRootMethodString(mModelHasOdes, mModel->hasExternalVariables()).empty() + && !mProfile->nlaSolveCallString(mModelHasOdes, mModel->hasExternalVariables()).empty()) { + // Note: only states and algebraic variables can be computed through an NLA system. Constants, computed + // constants, and external variables cannot, by definition, be computed through an NLA system. + + std::vector handledNlaEquations; + + for (const auto &equation : mModel->equations()) { + if ((equation->type() == AnalyserEquation::Type::NLA) + && (std::find(handledNlaEquations.begin(), handledNlaEquations.end(), equation) == handledNlaEquations.end())) { + std::string methodBody; + auto i = MAX_SIZE_T; + auto variables = libcellml::variables(equation); + + for (const auto &variable : variables) { + auto arrayString = (variable->type() == AnalyserVariable::Type::STATE) ? + mProfile->ratesArrayString() : + mProfile->algebraicArrayString(); + + methodBody += mProfile->indentString() + + arrayString + mProfile->openArrayString() + convertToString(variable->index()) + mProfile->closeArrayString() + + mProfile->equalityString() + + mProfile->uArrayString() + mProfile->openArrayString() + convertToString(++i) + mProfile->closeArrayString() + + mProfile->commandSeparatorString() + "\n"; + } + + methodBody += newLineIfNeeded(mCode); + + i = MAX_SIZE_T; + + methodBody += mProfile->indentString() + + mProfile->fArrayString() + mProfile->openArrayString() + convertToString(++i) + mProfile->closeArrayString() + + mProfile->equalityString() + + generateCode(equation->ast()) + + mProfile->commandSeparatorString() + "\n"; + + handledNlaEquations.push_back(equation); + + for (const auto &nlaSibling : equation->nlaSiblings()) { + methodBody += mProfile->indentString() + + mProfile->fArrayString() + mProfile->openArrayString() + convertToString(++i) + mProfile->closeArrayString() + + mProfile->equalityString() + + generateCode(nlaSibling->ast()) + + mProfile->commandSeparatorString() + "\n"; + + handledNlaEquations.push_back(nlaSibling); + } + + mCode += newLineIfNeeded(mCode) + + replace(replace(mProfile->objectiveFunctionMethodString(mModelHasOdes, mModel->hasExternalVariables()), + "[INDEX]", convertToString(equation->nlaSystemIndex())), + "[CODE]", generateMethodBodyCode(methodBody)); + + methodBody = {}; + + i = MAX_SIZE_T; + + for (const auto &variable : variables) { + auto arrayString = (variable->type() == AnalyserVariable::Type::STATE) ? + mProfile->ratesArrayString() : + mProfile->algebraicArrayString(); + + methodBody += mProfile->indentString() + + mProfile->uArrayString() + mProfile->openArrayString() + convertToString(++i) + mProfile->closeArrayString() + + mProfile->equalityString() + + arrayString + mProfile->openArrayString() + convertToString(variable->index()) + mProfile->closeArrayString() + + mProfile->commandSeparatorString() + "\n"; + } + + auto variablesCount = variables.size(); + + methodBody += newLineIfNeeded(mCode) + + mProfile->indentString() + + replace(replace(mProfile->nlaSolveCallString(mModelHasOdes, mModel->hasExternalVariables()), + "[INDEX]", convertToString(equation->nlaSystemIndex())), + "[SIZE]", convertToString(variablesCount)); + + methodBody += newLineIfNeeded(mCode); + + i = MAX_SIZE_T; + + for (const auto &variable : variables) { + auto arrayString = (variable->type() == AnalyserVariable::Type::STATE) ? + mProfile->ratesArrayString() : + mProfile->algebraicArrayString(); + + methodBody += mProfile->indentString() + + arrayString + mProfile->openArrayString() + convertToString(variable->index()) + mProfile->closeArrayString() + + mProfile->equalityString() + + mProfile->uArrayString() + mProfile->openArrayString() + convertToString(++i) + mProfile->closeArrayString() + + mProfile->commandSeparatorString() + "\n"; + } + + mCode += newLineIfNeeded(mCode) + + replace(replace(replace(mProfile->findRootMethodString(mModelHasOdes, mModel->hasExternalVariables()), + "[INDEX]", convertToString(equation->nlaSystemIndex())), + "[SIZE]", convertToString(variablesCount)), + "[CODE]", generateMethodBodyCode(methodBody)); + } + } + } +} + +std::string GeneratorInterpreter::GeneratorInterpreterImpl::generateConstantInitialisationCode(const std::vector::iterator constant, + std::vector &remainingConstants) +{ + auto initialisingVariable = (*constant)->initialisingVariable(); + auto initialValue = initialisingVariable->initialValue(); + + if (!isCellMLReal(initialValue)) { + auto initialisingComponent = owningComponent(initialisingVariable); + auto crtConstant = std::find_if(remainingConstants.begin(), remainingConstants.end(), + [=](const AnalyserVariablePtr &av) -> bool { + return initialisingComponent->variable(initialValue) == av->variable(); + }); + + if (crtConstant != remainingConstants.end()) { + return generateConstantInitialisationCode(crtConstant, remainingConstants); + } + } + + auto code = generateInitialisationCode(*constant); + + remainingConstants.erase(constant); + + return code; +} + +void GeneratorInterpreter::GeneratorInterpreterImpl::initialiseVariables(std::vector &remainingEquations) +{ + auto implementationInitialiseVariablesMethodString = mProfile->implementationInitialiseVariablesMethodString(mModelHasOdes); + + if (!implementationInitialiseVariablesMethodString.empty()) { + // Initialise our states (after, if needed, initialising the constant on which it depends). + + std::string methodBody; + auto constants = mModel->constants(); + + for (const auto &state : mModel->states()) { + auto initialisingVariable = state->initialisingVariable(); + auto initialValue = initialisingVariable->initialValue(); + + if (!isCellMLReal(initialValue)) { + // The initial value references a constant. + + auto initialisingComponent = owningComponent(initialisingVariable); + auto constant = std::find_if(constants.begin(), constants.end(), + [=](const AnalyserVariablePtr &av) -> bool { + return initialisingComponent->variable(initialValue)->hasEquivalentVariable(av->variable()); + }); + + methodBody += generateConstantInitialisationCode(constant, constants); + } + + methodBody += generateInitialisationCode(state); + } + + // Use an initial guess of zero for rates computed using an NLA system (see the note below). + + for (const auto &state : mModel->states()) { + if (state->equation(0)->type() == AnalyserEquation::Type::NLA) { + methodBody += generateZeroInitialisationCode(state); + } + } + + // Initialise our (remaining) constants. + + while (!constants.empty()) { + methodBody += generateConstantInitialisationCode(constants.begin(), constants); + } + + // Initialise our computed constants that are initialised using an equation (e.g., x = 3 rather than x with an + // initial value of 3). + + auto equations = mModel->equations(); + + for (const auto &equation : equations) { + if (equation->type() == AnalyserEquation::Type::TRUE_CONSTANT) { + methodBody += generateEquationCode(equation, remainingEquations); + } + } + + // Initialise our algebraic variables that have an initial value. Also use an initial guess of zero for + // algebraic variables computed using an NLA system. + // Note: a variable which is the only unknown in an equation, but which is not on its own on either the LHS or + // RHS of that equation (e.g., x = y+z with x and y known and z unknown) is (currently) to be computed + // using an NLA system for which we need an initial guess. We use an initial guess of zero, which is fine + // since such an NLA system has only one solution. + + for (const auto &algebraic : mModel->algebraic()) { + if (algebraic->initialisingVariable() != nullptr) { + methodBody += generateInitialisationCode(algebraic); + } else if (algebraic->equation(0)->type() == AnalyserEquation::Type::NLA) { + methodBody += generateZeroInitialisationCode(algebraic); + } + } + + mCode += newLineIfNeeded(mCode) + + replace(implementationInitialiseVariablesMethodString, + "[CODE]", generateMethodBodyCode(methodBody)); + } +} + +void GeneratorInterpreter::GeneratorInterpreterImpl::computeComputedConstants(std::vector &remainingEquations) +{ + if (!mProfile->implementationComputeComputedConstantsMethodString().empty()) { + std::string methodBody; + + for (const auto &equation : mModel->equations()) { + if (equation->type() == AnalyserEquation::Type::VARIABLE_BASED_CONSTANT) { + methodBody += generateEquationCode(equation, remainingEquations); + } + } + + mCode += newLineIfNeeded(mCode) + + replace(mProfile->implementationComputeComputedConstantsMethodString(), + "[CODE]", generateMethodBodyCode(methodBody)); + } +} + +void GeneratorInterpreter::GeneratorInterpreterImpl::computeRates(std::vector &remainingEquations) +{ + auto implementationComputeRatesMethodString = mProfile->implementationComputeRatesMethodString(mModel->hasExternalVariables()); + + if (mModelHasOdes + && !implementationComputeRatesMethodString.empty()) { + std::string methodBody; + + for (const auto &equation : mModel->equations()) { + // A rate is computed either through an ODE equation or through an NLA equation in case the rate is not on + // its own on either the LHS or RHS of the equation. + + auto variables = libcellml::variables(equation); + + if ((equation->type() == AnalyserEquation::Type::ODE) + || ((equation->type() == AnalyserEquation::Type::NLA) + && (variables.size() == 1) + && (variables[0]->type() == AnalyserVariable::Type::STATE))) { + methodBody += generateEquationCode(equation, remainingEquations); + } + } + + mCode += newLineIfNeeded(mCode) + + replace(implementationComputeRatesMethodString, + "[CODE]", generateMethodBodyCode(methodBody)); + } +} + +void GeneratorInterpreter::GeneratorInterpreterImpl::computeVariables(std::vector &remainingEquations) +{ + auto implementationComputeVariablesMethodString = mProfile->implementationComputeVariablesMethodString(mModelHasOdes, + mModel->hasExternalVariables()); + + if (!implementationComputeVariablesMethodString.empty()) { + std::string methodBody; + auto equations = mModel->equations(); + std::vector newRemainingEquations {std::begin(equations), std::end(equations)}; + + for (const auto &equation : equations) { + if ((std::find(remainingEquations.begin(), remainingEquations.end(), equation) != remainingEquations.end()) + || isToBeComputedAgain(equation)) { + methodBody += generateEquationCode(equation, newRemainingEquations, remainingEquations, false); + } + } + + mCode += newLineIfNeeded(mCode) + + replace(implementationComputeVariablesMethodString, + "[CODE]", generateMethodBodyCode(methodBody)); + } +} + +GeneratorInterpreter::GeneratorInterpreter(const AnalyserEquationAstPtr &ast, const GeneratorProfilePtr &profile) + : mPimpl(new GeneratorInterpreterImpl(ast, profile)) +{ +} + +GeneratorInterpreter::~GeneratorInterpreter() +{ + delete mPimpl; +} + +GeneratorInterpreterPtr GeneratorInterpreter::create(const AnalyserEquationAstPtr &ast, + const GeneratorProfilePtr &profile) noexcept +{ + return std::shared_ptr {new GeneratorInterpreter {ast, profile}}; +} + +std::string GeneratorInterpreter::code() const +{ + return mPimpl->mCode; +} + +} // namespace libcellml diff --git a/src/generatorinterpreter.h b/src/generatorinterpreter.h new file mode 100644 index 000000000..f6aa33424 --- /dev/null +++ b/src/generatorinterpreter.h @@ -0,0 +1,73 @@ +/* +Copyright libCellML Contributors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#pragma once + +#include "libcellml/analyserequation.h" + +namespace libcellml { + +class GeneratorInterpreter; /**< Forward declaration of GeneratorInterpreter class. */ +using GeneratorInterpreterPtr = std::shared_ptr; /**< Type definition for shared GeneratorInterpreter pointer. */ + +/** + * @brief The GeneratorInterpreter class. + * + * The GeneratorInterpreter class is used by the Generator and Interpreter classes to generate the code to compute a + * model. + */ +class GeneratorInterpreter +{ +public: + ~GeneratorInterpreter(); /**< Destructor, @private. */ + GeneratorInterpreter(const GeneratorInterpreter &rhs) = delete; /**< Copy constructor, @private. */ + GeneratorInterpreter(GeneratorInterpreter &&rhs) noexcept = delete; /**< Move constructor, @private. */ + GeneratorInterpreter &operator=(GeneratorInterpreter rhs) = delete; /**< Assignment operator, @private. */ + + /** + * @brief Create an @ref GeneratorInterpreter object. + * + * Factory method to create an @ref GeneratorInterpreter. Create a generator-interpreter with:: + * + * @code + * auto interpreterStatement = libcellml::GeneratorInterpreter::create(profile); + * @endcode + * + * @param ast The AST for which we want to generate some code. + * @param profile The profile to be used to generate some code. + * + * @return A smart pointer to an @ref GeneratorInterpreter object. + */ + static GeneratorInterpreterPtr create(const AnalyserEquationAstPtr &ast, + const GeneratorProfilePtr &profile) noexcept; + + /** + * @brief Get the code to compute the model. + * + * Get the @c std::string code to compute the model. + * + * @return The @c std::string code to compute the model. + */ + std::string code() const; + +private: + GeneratorInterpreter(const AnalyserEquationAstPtr &ast, const GeneratorProfilePtr &profile); /**< Constructor, @private. */ + + struct GeneratorInterpreterImpl; + GeneratorInterpreterImpl *mPimpl; /**< Private member to implementation pointer, @private. */ +}; + +} // namespace libcellml diff --git a/src/generatorinterpreter_p.h b/src/generatorinterpreter_p.h new file mode 100644 index 000000000..5efee23d2 --- /dev/null +++ b/src/generatorinterpreter_p.h @@ -0,0 +1,104 @@ +/* +Copyright libCellML Contributors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#pragma once + +#include "generatorinterpreter.h" + +#include "libcellml/generatorprofile.h" + +namespace libcellml { + +bool modelHasOdes(const AnalyserModelPtr &model); +bool modelHasNlas(const AnalyserModelPtr &model); + +AnalyserVariablePtr analyserVariable(const AnalyserModelPtr &model, const VariablePtr &variable); + +std::string newLineIfNeeded(const std::string &code); + +/** + * @brief The GeneratorInterpreter::GeneratorInterpreterImpl struct. + * + * The private implementation for the GeneratorInterpreter class. + */ +struct GeneratorInterpreter::GeneratorInterpreterImpl +{ + AnalyserModelPtr mModel; + bool mModelHasOdes = false; + bool mModelHasNlas = false; + + GeneratorProfilePtr mProfile = GeneratorProfile::create(); + std::string mCode; + + explicit GeneratorInterpreterImpl(const AnalyserEquationAstPtr &ast, const GeneratorProfilePtr &profile); + + void initialise(const AnalyserModelPtr &model, const GeneratorProfilePtr &profile, const std::string &code); + + double scalingFactor(const VariablePtr &variable) const; + + bool isNegativeNumber(const AnalyserEquationAstPtr &ast) const; + + bool isRelationalOperator(const AnalyserEquationAstPtr &ast) const; + bool isAndOperator(const AnalyserEquationAstPtr &ast) const; + bool isOrOperator(const AnalyserEquationAstPtr &ast) const; + bool isXorOperator(const AnalyserEquationAstPtr &ast) const; + bool isLogicalOperator(const AnalyserEquationAstPtr &ast) const; + bool isPlusOperator(const AnalyserEquationAstPtr &ast) const; + bool isMinusOperator(const AnalyserEquationAstPtr &ast) const; + bool isTimesOperator(const AnalyserEquationAstPtr &ast) const; + bool isDivideOperator(const AnalyserEquationAstPtr &ast) const; + bool isPowerOperator(const AnalyserEquationAstPtr &ast) const; + bool isRootOperator(const AnalyserEquationAstPtr &ast) const; + bool isPiecewiseStatement(const AnalyserEquationAstPtr &ast) const; + + std::string generateMethodBodyCode(const std::string &methodBody) const; + + std::string generateVariableNameCode(const VariablePtr &variable, + bool rate = false) const; + + std::string generateOperatorCode(const std::string &op, const AnalyserEquationAstPtr &ast) const; + std::string generateMinusUnaryCode(const AnalyserEquationAstPtr &ast) const; + std::string generateOneParameterFunctionCode(const std::string &function, const AnalyserEquationAstPtr &ast) const; + std::string generateTwoParameterFunctionCode(const std::string &function, const AnalyserEquationAstPtr &ast) const; + std::string generatePiecewiseIfCode(const std::string &condition, + const std::string &value) const; + std::string generatePiecewiseElseCode(const std::string &value) const; + std::string generateCode(const AnalyserEquationAstPtr &ast) const; + + bool isToBeComputedAgain(const AnalyserEquationPtr &equation) const; + bool isSomeConstant(const AnalyserEquationPtr &equation, + bool includeComputedConstants) const; + + std::string generateZeroInitialisationCode(const AnalyserVariablePtr &variable); + std::string generateInitialisationCode(const AnalyserVariablePtr &variable); + std::string generateEquationCode(const AnalyserEquationPtr &equation, + std::vector &remainingEquations, + std::vector &equationsForDependencies, + bool includeComputedConstants); + std::string generateEquationCode(const AnalyserEquationPtr &equation, + std::vector &remainingEquations); + + void addNlaSystemsCode(); + + std::string generateConstantInitialisationCode(const std::vector::iterator constant, + std::vector &remainingConstants); + void initialiseVariables(std::vector &remainingEquations); + void computeComputedConstants(std::vector &remainingEquations); + void computeRates(std::vector &remainingEquations); + void computeVariables(std::vector &remainingEquations); +}; + +} // namespace libcellml