diff --git a/src/visitors/cvode_visitor.cpp b/src/visitors/cvode_visitor.cpp index 617e1e6c6..13ef1e819 100644 --- a/src/visitors/cvode_visitor.cpp +++ b/src/visitors/cvode_visitor.cpp @@ -13,6 +13,7 @@ #include "utils/logger.hpp" #include "visitors/visitor_utils.hpp" #include +#include #include namespace pywrap = nmodl::pybind_wrappers; @@ -35,6 +36,25 @@ static void remove_conserve_statements(ast::StatementBlock& node) { } } +// remove units from CVODE block so sympy can parse it properly +static void remove_units(ast::BinaryExpression& node) { + // matches either an int or a float, followed by any (including zero) + // number of spaces, followed by an expression in parentheses, that only + // has letters of the alphabet + std::regex unit_pattern(R"((\d+\.?\d*|\.\d+)\s*\([a-zA-Z]+\))"); + auto rhs_string = to_nmodl(node.get_rhs()); + auto rhs_string_no_units = fmt::format("{} = {}", + to_nmodl(node.get_lhs()), + std::regex_replace(rhs_string, unit_pattern, "$1")); + logger->debug("CvodeVisitor :: removing units from statement {}", to_nmodl(node)); + logger->debug("CvodeVisitor :: result: {}", rhs_string_no_units); + auto expr_statement = std::dynamic_pointer_cast( + create_statement(rhs_string_no_units)); + const auto bin_expr = std::dynamic_pointer_cast( + expr_statement->get_expression()); + node.set_rhs(std::shared_ptr(bin_expr->get_rhs()->clone())); +} + static std::pair> parse_independent_var( std::shared_ptr node) { auto variable = std::make_pair(node->get_node_name(), std::optional()); @@ -152,7 +172,10 @@ class StiffVisitor: public CvodeHelperVisitor { program_symtab->insert(symbol); } + remove_units(node); + auto rhs = node.get_rhs(); + // all indexed variables (need special treatment in SymPy) auto indexed_variables = get_indexed_variables(*rhs, name->get_node_name()); auto diff2c = pywrap::EmbeddedPythonLoader::get_instance().api().diff2c; diff --git a/test/usecases/cvode/derivative.mod b/test/usecases/cvode/derivative.mod index d3715352f..2a8ba6ca6 100644 --- a/test/usecases/cvode/derivative.mod +++ b/test/usecases/cvode/derivative.mod @@ -2,6 +2,10 @@ NEURON { SUFFIX scalar } +UNITS { + (um) = (micron) +} + PARAMETER { freq = 10 a = 5 @@ -14,7 +18,7 @@ PARAMETER { k = 0.2 } -STATE {var1 var2 var3} +STATE {var1 var2 var3 var4} INITIAL { var1 = v1 @@ -34,4 +38,6 @@ DERIVATIVE equation { var2' = -var2 * a : logistic ODE var3' = r * var3 * (1 - var3 / k) + : ODE with some units + var4' = 1(um) * var4 + a * .1(um) + r * 1.(um) + 1.0 (um) }