Skip to content

Commit

Permalink
nrnlinz.cpp: use eigen instead of sparse13
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicolas Cornu committed Nov 30, 2023
1 parent 170a3ac commit f71b138
Showing 1 changed file with 43 additions and 44 deletions.
87 changes: 43 additions & 44 deletions src/nrniv/nonlinz.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#include "cspmatrix.h"
#include "membfunc.h"

#include <Eigen/Sparse>

extern void v_setup_vectors();
extern int nrndae_extra_eqn_count();
extern Symlist* hoc_built_in_symlist;
Expand All @@ -29,14 +31,14 @@ class NonLinImpRep {
void dsds();
int gapsolve();

char* m_;
Eigen::SparseMatrix<std::complex<double>, Eigen::RowMajor> m_{};
int scnt_; // structure_change
int n_v_, n_ext_, n_lin_, n_ode_, neq_v_, neq_;
std::vector<neuron::container::data_handle<double>> pv_, pvdot_;
int* v_index_;
double* rv_;
double* jv_;
double** diag_;
std::vector<std::complex<double>*> diag_;
double* deltavec_; // just like cvode.atol*cvode.atolscale for ode's
double delta_; // slightly more efficient and easier for v.
void current(int, Memb_list*, int);
Expand Down Expand Up @@ -152,28 +154,14 @@ void NonLinImp::compute(double omega, double deltafac, int maxiter) {
rep_->omega_ = 1000. * omega;
rep_->delta(deltafac);
// fill matrix
cmplx_spClear(rep_->m_);
rep_->m_.setZero();
rep_->didv();
rep_->dsds();
#if 1 // when 0 equivalent to standard method
rep_->dids();
rep_->dsdv();
#endif

// cmplx_spPrint(rep_->m_, 0, 1, 1);
// for (int i=0; i < rep_->neq_; ++i) {
// printf("i=%d %g %g\n", i, rep_->diag_[i][0], rep_->diag_[i][1]);
// }
int e = cmplx_spFactor(rep_->m_);
switch (e) {
case spZERO_DIAG:
hoc_execerror("cmplx_spFactor error:", "Zero Diagonal");
case spNO_MEMORY:
hoc_execerror("cmplx_spFactor error:", "No Memory");
case spSINGULAR:
hoc_execerror("cmplx_spFactor error:", "Singular");
}

rep_->iloc_ = -2;
}

Expand All @@ -196,8 +184,20 @@ int NonLinImp::solve(int curloc) {
if (nrnthread_v_transfer_) {
rval = rep_->gapsolve();
} else {
assert(rep_->m_);
cmplx_spSolve(rep_->m_, rep_->rv_ - 1, rep_->rv_ - 1, rep_->jv_ - 1, rep_->jv_ - 1);
std::vector<std::complex<double>> rjv{};
rjv.reserve(rep_->neq_);
for (size_t i = 0; i < rep_->neq_; ++i) {
rjv.emplace_back(rep_->rv_[i], rep_->jv_[i]);
}
auto lu = Eigen::SparseLU<decltype(rep_->m_)>(rep_->m_);
lu.analyzePattern(rep_->m_);
lu.factorize(rep_->m_);
auto rjv_ = Eigen::Map<Eigen::Vector<std::complex<double>, Eigen::Dynamic>>(rjv.data(), rjv.size());
rjv_ = lu.solve(rjv_);
for (size_t i = 0; i < rep_->neq_; ++i) {
rep_->rv_[i] = rjv[i].real();
rep_->jv_[i] = rjv[i].imag();
}
}
}
return rval;
Expand All @@ -211,7 +211,6 @@ NonLinImpRep::NonLinImpRep() {
int i, j, ieq, cnt;
NrnThread* _nt = nrn_threads;
maxiter_ = 500;
m_ = NULL;

vsymtol_ = NULL;
Symbol* vsym = hoc_table_lookup("v", hoc_built_in_symlist);
Expand Down Expand Up @@ -245,16 +244,15 @@ NonLinImpRep::NonLinImpRep() {
if (neq_ == 0) {
return;
}
m_ = cmplx_spCreate(neq_, 1, &err);
assert(err == spOKAY);
m_ = Eigen::SparseMatrix<std::complex<double>, Eigen::RowMajor>(neq_, neq_);
pv_.resize(neq_);
pvdot_.resize(neq_);
v_index_ = new int[n_v_];
rv_ = new double[neq_ + 1];
rv_ += 1;
jv_ = new double[neq_ + 1];
jv_ += 1;
diag_ = new double*[neq_];
diag_.resize(neq_);
deltavec_ = new double[neq_];

for (i = 0; i < n_v_; ++i) {
Expand All @@ -265,23 +263,18 @@ NonLinImpRep::NonLinImpRep() {
v_index_[i] = i + 1;
}
for (i = 0; i < n_v_; ++i) {
diag_[i] = cmplx_spGetElement(m_, v_index_[i], v_index_[i]);
diag_[i] = &m_.coeffRef(v_index_[i], v_index_[i]);
}
for (i = neq_v_; i < neq_; ++i) {
diag_[i] = cmplx_spGetElement(m_, i + 1, i + 1);
diag_[i] = &m_.coeffRef(i, i);
}
scnt_ = structure_change_cnt;
}

NonLinImpRep::~NonLinImpRep() {
if (!m_) {
return;
}
cmplx_spDestroy(m_);
delete[] v_index_;
delete[](rv_ - 1);
delete[](jv_ - 1);
delete[] diag_;
delete[] deltavec_;
}

Expand Down Expand Up @@ -317,10 +310,8 @@ void NonLinImpRep::didv() {
for (i = _nt->ncell; i < n_v_; ++i) {
nd = _nt->_v_node[i];
ip = _nt->_v_parent[i]->v_node_index;
double* a = cmplx_spGetElement(m_, v_index_[ip], v_index_[i]);
double* b = cmplx_spGetElement(m_, v_index_[i], v_index_[ip]);
*a += NODEA(nd);
*b += NODEB(nd);
m_.coeffRef(v_index_[ip], v_index_[i]) += NODEA(nd);
m_.coeffRef(v_index_[i], v_index_[ip]) += NODEB(nd);
*diag_[i] -= NODEB(nd);
*diag_[ip] -= NODEA(nd);
}
Expand Down Expand Up @@ -412,9 +403,7 @@ void NonLinImpRep::dids() {
*pv_[is] = x1[is]; // restore s
double g = (NODERHS(nd) - x2[in]) / deltavec_[is];
if (g != 0.) {
double* elm =
cmplx_spGetElement(m_, v_index_[nd->v_node_index], is + 1);
elm[0] = -g;
m_.coeffRef(v_index_[nd->v_node_index], is + 1) = -g;
}
}
// don't know if this is necessary but make sure last
Expand Down Expand Up @@ -477,9 +466,7 @@ void NonLinImpRep::dsdv() {
for (is = ieq + in * cnt, iis = 0; iis < cnt; ++iis, ++is) {
double ds = (x2[is] - *pvdot_[is]) / delta_;
if (ds != 0.) {
double* elm =
cmplx_spGetElement(m_, is + 1, v_index_[nd->v_node_index]);
elm[0] = -ds;
m_.coeffRef(is + 1, v_index_[nd->v_node_index]) = -ds;
}
}
}
Expand All @@ -494,7 +481,7 @@ void NonLinImpRep::dsds() {
NrnThread* nt = nrn_threads;
// jw term
for (i = neq_v_; i < neq_; ++i) {
diag_[i][1] += omega_;
*diag_[i] += std::complex<double>(0, omega_);
}
ieq = neq_v_;
for (NrnThreadMembList* tml = nt->tml; tml; tml = tml->next) {
Expand Down Expand Up @@ -540,8 +527,7 @@ void NonLinImpRep::dsds() {
for (is = ieq + in * cnt, iis = 0; iis < cnt; ++iis, ++is) {
double ds = (*pvdot_[is] - x2[is]) / deltavec_[is];
if (ds != 0.) {
double* elm = cmplx_spGetElement(m_, is + 1, ks + 1);
elm[0] = -ds;
m_.coeffRef(is, ks) = -ds;
}
*pv_[ks] = x1[ks];
}
Expand Down Expand Up @@ -620,7 +606,20 @@ int NonLinImpRep::gapsolve() {

for (iter = 1; iter <= maxiter_; ++iter) {
if (neq_) {
cmplx_spSolve(m_, rb - 1, rx1 - 1, jb - 1, jx1 - 1);
std::vector<std::complex<double>> rjv{};
rjv.reserve(neq_);
for (size_t i = 0; i < neq_; ++i) {
rjv.emplace_back(rb[i], rx1[i]);
}
auto lu = Eigen::SparseLU<decltype(m_)>(m_);
lu.analyzePattern(m_);
lu.factorize(m_);
auto rjv_ = Eigen::Map<Eigen::Vector<std::complex<double>, Eigen::Dynamic>>(rjv.data(), rjv.size());
rjv_ = lu.solve(rjv_);
for (size_t i = 0; i < neq_; ++i) {
jb[i] = rjv[i].real();
jx1[i] = rjv[i].imag();
}
}

// if any change in x > tol, then do another iteration.
Expand Down

0 comments on commit f71b138

Please sign in to comment.