From 06ce2d368b053cdd71cc5a3dc9227c32d7a0f20b Mon Sep 17 00:00:00 2001 From: Theodore Chang Date: Sat, 16 Nov 2024 18:56:15 +0100 Subject: [PATCH] Update `armadillo` version 14.2.0 --- .../armadillo/armadillo_bits/Base_bones.hpp | 8 +- .../armadillo_bits/CubeToMatOp_bones.hpp | 3 + .../armadillo/armadillo_bits/Cube_bones.hpp | 9 +- .../armadillo/armadillo_bits/Gen_bones.hpp | 3 + .../armadillo/armadillo_bits/Glue_bones.hpp | 3 + .../armadillo/armadillo_bits/Glue_meat.hpp | 13 + .../armadillo/armadillo_bits/Mat_bones.hpp | 11 +- Include/armadillo/armadillo_bits/Mat_meat.hpp | 13 + Include/armadillo/armadillo_bits/Op_bones.hpp | 3 + Include/armadillo/armadillo_bits/Op_meat.hpp | 13 + Include/armadillo/armadillo_bits/Proxy.hpp | 16 +- .../armadillo/armadillo_bits/SpBase_bones.hpp | 8 +- .../armadillo_bits/SpToDGlue_bones.hpp | 3 + .../armadillo_bits/SpToDOp_bones.hpp | 3 + .../armadillo/armadillo_bits/arma_forward.hpp | 2 + .../armadillo/armadillo_bits/arma_version.hpp | 6 +- .../armadillo/armadillo_bits/auxlib_bones.hpp | 36 +- .../armadillo/armadillo_bits/auxlib_meat.hpp | 777 ++++++++++++++++-- .../armadillo_bits/compiler_check.hpp | 6 +- .../armadillo_bits/compiler_setup.hpp | 8 +- Include/armadillo/armadillo_bits/config.hpp | 25 +- .../armadillo/armadillo_bits/config.hpp.cmake | 25 +- .../armadillo/armadillo_bits/def_lapack.hpp | 112 +++ .../armadillo_bits/diagview_bones.hpp | 3 + .../armadillo_bits/diagview_meat.hpp | 53 +- .../armadillo/armadillo_bits/eGlue_bones.hpp | 3 + .../armadillo/armadillo_bits/eGlue_meat.hpp | 11 + .../armadillo/armadillo_bits/eOp_bones.hpp | 3 + Include/armadillo/armadillo_bits/eOp_meat.hpp | 13 + Include/armadillo/armadillo_bits/fn_misc.hpp | 90 +- .../armadillo_bits/glue_solve_bones.hpp | 3 + .../armadillo_bits/glue_solve_meat.hpp | 93 ++- .../armadillo_bits/glue_times_meat.hpp | 14 +- .../armadillo_bits/gmm_diag_meat.hpp | 4 +- .../armadillo_bits/gmm_full_meat.hpp | 4 +- .../armadillo/armadillo_bits/mtGlue_bones.hpp | 3 + .../armadillo/armadillo_bits/mtGlue_meat.hpp | 13 + .../armadillo/armadillo_bits/mtOp_bones.hpp | 5 +- .../armadillo/armadillo_bits/mtOp_meat.hpp | 13 + .../armadillo/armadillo_bits/op_cond_meat.hpp | 11 +- .../armadillo_bits/op_expmat_meat.hpp | 18 +- .../armadillo_bits/op_inv_gen_meat.hpp | 48 +- .../armadillo_bits/op_inv_spd_meat.hpp | 8 +- .../armadillo_bits/op_log_det_meat.hpp | 2 +- .../armadillo_bits/op_logmat_meat.hpp | 6 +- .../armadillo/armadillo_bits/op_pinv_meat.hpp | 12 +- .../armadillo_bits/op_powmat_meat.hpp | 43 +- .../armadillo/armadillo_bits/op_rank_meat.hpp | 12 +- .../armadillo_bits/op_rcond_meat.hpp | 25 +- .../armadillo_bits/op_sqrtmat_meat.hpp | 6 +- .../armadillo_bits/operator_minus.hpp | 102 --- .../armadillo_bits/operator_plus.hpp | 103 --- .../armadillo_bits/subview_bones.hpp | 7 +- .../armadillo_bits/subview_elem1_bones.hpp | 3 + .../armadillo_bits/subview_elem1_meat.hpp | 82 +- .../armadillo_bits/subview_elem2_bones.hpp | 3 + .../armadillo_bits/subview_elem2_meat.hpp | 13 + .../armadillo/armadillo_bits/subview_meat.hpp | 13 + .../armadillo/armadillo_bits/sym_helper.hpp | 146 ++-- Include/armadillo/armadillo_bits/traits.hpp | 24 + .../armadillo_bits/translate_lapack.hpp | 144 ++++ .../armadillo/armadillo_bits/typedef_elem.hpp | 5 +- Include/armadillo/armadillo_bits/unwrap.hpp | 12 +- README.md | 2 +- 64 files changed, 1607 insertions(+), 677 deletions(-) diff --git a/Include/armadillo/armadillo_bits/Base_bones.hpp b/Include/armadillo/armadillo_bits/Base_bones.hpp index ac947856b..9c7e9ca58 100644 --- a/Include/armadillo/armadillo_bits/Base_bones.hpp +++ b/Include/armadillo/armadillo_bits/Base_bones.hpp @@ -127,11 +127,11 @@ struct Base arma_warn_unused inline elem_type min() const; arma_warn_unused inline elem_type max() const; - inline elem_type min(uword& index_of_min_val) const; - inline elem_type max(uword& index_of_max_val) const; + arma_frown("use .index_min() instead") inline elem_type min(uword& index_of_min_val) const; + arma_frown("use .index_max() instead") inline elem_type max(uword& index_of_max_val) const; - inline elem_type min(uword& row_of_min_val, uword& col_of_min_val) const; - inline elem_type max(uword& row_of_max_val, uword& col_of_max_val) const; + arma_deprecated inline elem_type min(uword& row_of_min_val, uword& col_of_min_val) const; + arma_deprecated inline elem_type max(uword& row_of_max_val, uword& col_of_max_val) const; arma_warn_unused inline uword index_min() const; arma_warn_unused inline uword index_max() const; diff --git a/Include/armadillo/armadillo_bits/CubeToMatOp_bones.hpp b/Include/armadillo/armadillo_bits/CubeToMatOp_bones.hpp index cd2ba5997..a53dd2d09 100644 --- a/Include/armadillo/armadillo_bits/CubeToMatOp_bones.hpp +++ b/Include/armadillo/armadillo_bits/CubeToMatOp_bones.hpp @@ -36,6 +36,9 @@ class CubeToMatOp : public Base< typename T1::elem_type, CubeToMatOp + constexpr bool is_alias(const Mat&) const { return false; } + static constexpr bool is_row = op_type::template traits::is_row; static constexpr bool is_col = op_type::template traits::is_col; static constexpr bool is_xvec = op_type::template traits::is_xvec; diff --git a/Include/armadillo/armadillo_bits/Cube_bones.hpp b/Include/armadillo/armadillo_bits/Cube_bones.hpp index 91449e185..6227ce9c4 100644 --- a/Include/armadillo/armadillo_bits/Cube_bones.hpp +++ b/Include/armadillo/armadillo_bits/Cube_bones.hpp @@ -387,12 +387,11 @@ class Cube : public BaseCube< eT, Cube > arma_warn_unused inline eT min() const; arma_warn_unused inline eT max() const; - inline eT min(uword& index_of_min_val) const; - inline eT max(uword& index_of_max_val) const; - - inline eT min(uword& row_of_min_val, uword& col_of_min_val, uword& slice_of_min_val) const; - inline eT max(uword& row_of_max_val, uword& col_of_max_val, uword& slice_of_max_val) const; + arma_frown("use .index_min() instead") inline eT min(uword& index_of_min_val) const; + arma_frown("use .index_max() instead") inline eT max(uword& index_of_max_val) const; + arma_deprecated inline eT min(uword& row_of_min_val, uword& col_of_min_val, uword& slice_of_min_val) const; + arma_deprecated inline eT max(uword& row_of_max_val, uword& col_of_max_val, uword& slice_of_max_val) const; arma_cold inline bool save(const std::string name, const file_type type = arma_binary) const; arma_cold inline bool save(const hdf5_name& spec, const file_type type = hdf5_binary) const; diff --git a/Include/armadillo/armadillo_bits/Gen_bones.hpp b/Include/armadillo/armadillo_bits/Gen_bones.hpp index 172e5b9c2..352bfcdf6 100644 --- a/Include/armadillo/armadillo_bits/Gen_bones.hpp +++ b/Include/armadillo/armadillo_bits/Gen_bones.hpp @@ -54,6 +54,9 @@ class Gen inline void apply_inplace_div (Mat& out) const; inline void apply(subview& out) const; + + template + constexpr bool is_alias(const Mat&) const { return false; } }; diff --git a/Include/armadillo/armadillo_bits/Glue_bones.hpp b/Include/armadillo/armadillo_bits/Glue_bones.hpp index 197ae7464..0b4c73f6b 100644 --- a/Include/armadillo/armadillo_bits/Glue_bones.hpp +++ b/Include/armadillo/armadillo_bits/Glue_bones.hpp @@ -56,6 +56,9 @@ class Glue inline Glue(const T1& in_A, const T2& in_B, const uword in_aux_uword); inline ~Glue(); + template + inline bool is_alias(const Mat& X) const; + const T1& A; //!< first operand; must be derived from Base const T2& B; //!< second operand; must be derived from Base uword aux_uword; //!< storage of auxiliary data, uword format diff --git a/Include/armadillo/armadillo_bits/Glue_meat.hpp b/Include/armadillo/armadillo_bits/Glue_meat.hpp index cf4cfc68f..66834a17d 100644 --- a/Include/armadillo/armadillo_bits/Glue_meat.hpp +++ b/Include/armadillo/armadillo_bits/Glue_meat.hpp @@ -53,4 +53,17 @@ Glue::~Glue() +template +template +inline +bool +Glue::is_alias(const Mat& X) const + { + arma_debug_sigprint(); + + return (A.is_alias(X) || B.is_alias(X)); + } + + + //! @} diff --git a/Include/armadillo/armadillo_bits/Mat_bones.hpp b/Include/armadillo/armadillo_bits/Mat_bones.hpp index 079c9f1ac..69d92cf8d 100644 --- a/Include/armadillo/armadillo_bits/Mat_bones.hpp +++ b/Include/armadillo/armadillo_bits/Mat_bones.hpp @@ -541,11 +541,11 @@ class Mat : public Base< eT, Mat > arma_warn_unused inline eT min() const; arma_warn_unused inline eT max() const; - inline eT min(uword& index_of_min_val) const; - inline eT max(uword& index_of_max_val) const; + arma_frown("use .index_min() instead") inline eT min(uword& index_of_min_val) const; + arma_frown("use .index_max() instead") inline eT max(uword& index_of_max_val) const; - inline eT min(uword& row_of_min_val, uword& col_of_min_val) const; - inline eT max(uword& row_of_max_val, uword& col_of_max_val) const; + arma_deprecated inline eT min(uword& row_of_min_val, uword& col_of_min_val) const; + arma_deprecated inline eT max(uword& row_of_max_val, uword& col_of_max_val) const; arma_cold inline bool save(const std::string name, const file_type type = arma_binary) const; @@ -771,6 +771,9 @@ class Mat : public Base< eT, Mat > inline void steal_mem_col(Mat& X, const uword max_n_rows); + template + arma_inline bool is_alias(const Mat& X) const; //!< don't use this unless you're writing code internal to Armadillo + template class fixed; diff --git a/Include/armadillo/armadillo_bits/Mat_meat.hpp b/Include/armadillo/armadillo_bits/Mat_meat.hpp index 0f785d5c6..51524abdf 100644 --- a/Include/armadillo/armadillo_bits/Mat_meat.hpp +++ b/Include/armadillo/armadillo_bits/Mat_meat.hpp @@ -1324,6 +1324,19 @@ Mat::steal_mem_col(Mat& x, const uword max_n_rows) +template +template +arma_inline +bool +Mat::is_alias(const Mat& X) const + { + arma_debug_sigprint(); + + return (is_same_type::yes) && (void_ptr(this) == void_ptr(&X)); + } + + + //! construct a matrix from a given auxiliary array of eTs. //! if copy_aux_mem is true, new memory is allocated and the array is copied. //! if copy_aux_mem is false, the auxiliary array is used directly (without allocating memory and copying). diff --git a/Include/armadillo/armadillo_bits/Op_bones.hpp b/Include/armadillo/armadillo_bits/Op_bones.hpp index fa8c3efd8..7fc4088f2 100644 --- a/Include/armadillo/armadillo_bits/Op_bones.hpp +++ b/Include/armadillo/armadillo_bits/Op_bones.hpp @@ -58,6 +58,9 @@ class Op inline Op(const T1& in_m, const uword in_aux_uword_a, const uword in_aux_uword_b); inline ~Op(); + template + inline bool is_alias(const Mat& X) const; + arma_aligned const T1& m; //!< the operand; must be derived from Base arma_aligned elem_type aux; //!< auxiliary data, using the element type as used by T1 arma_aligned uword aux_uword_a; //!< auxiliary data, uword format diff --git a/Include/armadillo/armadillo_bits/Op_meat.hpp b/Include/armadillo/armadillo_bits/Op_meat.hpp index 66fbaba6b..879f56826 100644 --- a/Include/armadillo/armadillo_bits/Op_meat.hpp +++ b/Include/armadillo/armadillo_bits/Op_meat.hpp @@ -76,4 +76,17 @@ Op::~Op() +template +template +inline +bool +Op::is_alias(const Mat& X) const + { + arma_debug_sigprint(); + + return m.is_alias(X); + } + + + //! @} diff --git a/Include/armadillo/armadillo_bits/Proxy.hpp b/Include/armadillo/armadillo_bits/Proxy.hpp index a51580ddc..441f68ab1 100644 --- a/Include/armadillo/armadillo_bits/Proxy.hpp +++ b/Include/armadillo/armadillo_bits/Proxy.hpp @@ -188,7 +188,7 @@ struct Proxy< Mat > arma_inline aligned_ea_type get_aligned_ea() const { return Q; } template - arma_inline bool is_alias(const Mat& X) const { return (is_same_type::value) ? (void_ptr(&Q) == void_ptr(&X)) : false; } + arma_inline bool is_alias(const Mat& X) const { return (is_same_type::yes) && (void_ptr(&Q) == void_ptr(&X)); } template arma_inline bool has_overlap(const subview& X) const { return is_alias(X.m); } @@ -235,7 +235,7 @@ struct Proxy< Col > arma_inline aligned_ea_type get_aligned_ea() const { return Q; } template - arma_inline bool is_alias(const Mat& X) const { return (is_same_type::value) ? (void_ptr(&Q) == void_ptr(&X)) : false; } + arma_inline bool is_alias(const Mat& X) const { return (is_same_type::yes) && (void_ptr(&Q) == void_ptr(&X)); } template arma_inline bool has_overlap(const subview& X) const { return is_alias(X.m); } @@ -282,7 +282,7 @@ struct Proxy< Row > arma_inline aligned_ea_type get_aligned_ea() const { return Q; } template - arma_inline bool is_alias(const Mat& X) const { return (is_same_type::value) ? (void_ptr(&Q) == void_ptr(&X)) : false; } + arma_inline bool is_alias(const Mat& X) const { return (is_same_type::yes) && (void_ptr(&Q) == void_ptr(&X)); } template arma_inline bool has_overlap(const subview& X) const { return is_alias(X.m); } @@ -1013,7 +1013,7 @@ struct Proxy< subview > arma_inline aligned_ea_type get_aligned_ea() const { return Q; } template - arma_inline bool is_alias(const Mat& X) const { return (is_same_type::value) ? (void_ptr(&(Q.m)) == void_ptr(&X)) : false; } + arma_inline bool is_alias(const Mat& X) const { return (is_same_type::yes) && (void_ptr(&(Q.m)) == void_ptr(&X)); } template arma_inline bool has_overlap(const subview& X) const { return Q.check_overlap(X); } @@ -1060,7 +1060,7 @@ struct Proxy< subview_col > arma_inline aligned_ea_type get_aligned_ea() const { return Q; } template - arma_inline bool is_alias(const Mat& X) const { return (is_same_type::value) ? (void_ptr(&(Q.m)) == void_ptr(&X)) : false; } + arma_inline bool is_alias(const Mat& X) const { return (is_same_type::yes) && (void_ptr(&(Q.m)) == void_ptr(&X)); } template arma_inline bool has_overlap(const subview& X) const { return Q.check_overlap(X); } @@ -1109,7 +1109,7 @@ struct Proxy< subview_cols > arma_inline aligned_ea_type get_aligned_ea() const { return Q; } template - arma_inline bool is_alias(const Mat& X) const { return (is_same_type::value) ? (void_ptr(&(sv.m)) == void_ptr(&X)) : false; } + arma_inline bool is_alias(const Mat& X) const { return (is_same_type::yes) && (void_ptr(&(sv.m)) == void_ptr(&X)); } template arma_inline bool has_overlap(const subview& X) const { return sv.check_overlap(X); } @@ -1156,7 +1156,7 @@ struct Proxy< subview_row > arma_inline aligned_ea_type get_aligned_ea() const { return Q; } template - arma_inline bool is_alias(const Mat& X) const { return (is_same_type::value) ? (void_ptr(&(Q.m)) == void_ptr(&X)) : false; } + arma_inline bool is_alias(const Mat& X) const { return (is_same_type::yes) && (void_ptr(&(Q.m)) == void_ptr(&X)); } template arma_inline bool has_overlap(const subview& X) const { return Q.check_overlap(X); } @@ -1304,7 +1304,7 @@ struct Proxy< diagview > arma_inline aligned_ea_type get_aligned_ea() const { return Q; } template - arma_inline bool is_alias(const Mat& X) const { return (is_same_type::value) ? (void_ptr(&(Q.m)) == void_ptr(&X)) : false; } + arma_inline bool is_alias(const Mat& X) const { return (is_same_type::yes) && (void_ptr(&(Q.m)) == void_ptr(&X)); } template arma_inline bool has_overlap(const subview& X) const { return is_alias(X.m); } diff --git a/Include/armadillo/armadillo_bits/SpBase_bones.hpp b/Include/armadillo/armadillo_bits/SpBase_bones.hpp index cc53e93b9..74ff3b463 100644 --- a/Include/armadillo/armadillo_bits/SpBase_bones.hpp +++ b/Include/armadillo/armadillo_bits/SpBase_bones.hpp @@ -76,11 +76,11 @@ struct SpBase arma_warn_unused inline elem_type min() const; arma_warn_unused inline elem_type max() const; - inline elem_type min(uword& index_of_min_val) const; - inline elem_type max(uword& index_of_max_val) const; + arma_frown("use .index_min() instead") inline elem_type min(uword& index_of_min_val) const; + arma_frown("use .index_max() instead") inline elem_type max(uword& index_of_max_val) const; - inline elem_type min(uword& row_of_min_val, uword& col_of_min_val) const; - inline elem_type max(uword& row_of_max_val, uword& col_of_max_val) const; + arma_deprecated inline elem_type min(uword& row_of_min_val, uword& col_of_min_val) const; + arma_deprecated inline elem_type max(uword& row_of_max_val, uword& col_of_max_val) const; arma_warn_unused inline uword index_min() const; arma_warn_unused inline uword index_max() const; diff --git a/Include/armadillo/armadillo_bits/SpToDGlue_bones.hpp b/Include/armadillo/armadillo_bits/SpToDGlue_bones.hpp index 158dd6b89..e21b31270 100644 --- a/Include/armadillo/armadillo_bits/SpToDGlue_bones.hpp +++ b/Include/armadillo/armadillo_bits/SpToDGlue_bones.hpp @@ -36,6 +36,9 @@ class SpToDGlue : public Base< typename T1::elem_type, SpToDGlue + constexpr bool is_alias(const Mat&) const { return false; } + const T1& A; //!< first operand; must be derived from Base or SpBase const T2& B; //!< second operand; must be derived from Base or SpBase }; diff --git a/Include/armadillo/armadillo_bits/SpToDOp_bones.hpp b/Include/armadillo/armadillo_bits/SpToDOp_bones.hpp index 44fa78950..2215c97a5 100644 --- a/Include/armadillo/armadillo_bits/SpToDOp_bones.hpp +++ b/Include/armadillo/armadillo_bits/SpToDOp_bones.hpp @@ -39,6 +39,9 @@ class SpToDOp : public Base< typename T1::elem_type, SpToDOp > inline SpToDOp(const T1& in_m, const uword in_aux_uword_a, const uword in_aux_uword_b); inline ~SpToDOp(); + template + constexpr bool is_alias(const Mat&) const { return false; } + arma_aligned const T1& m; //!< the operand; must be derived from SpBase arma_aligned elem_type aux; //!< auxiliary data, using the element type as used by T1 arma_aligned uword aux_uword_a; //!< auxiliary data, uword format diff --git a/Include/armadillo/armadillo_bits/arma_forward.hpp b/Include/armadillo/armadillo_bits/arma_forward.hpp index b35d64ef4..b56d4e4d0 100644 --- a/Include/armadillo/armadillo_bits/arma_forward.hpp +++ b/Include/armadillo/armadillo_bits/arma_forward.hpp @@ -91,6 +91,8 @@ class op_diagmat; class op_trimat; class op_vectorise_row; class op_vectorise_col; +class op_symmatu; +class op_symmatl; class op_row_as_mat; class op_col_as_mat; diff --git a/Include/armadillo/armadillo_bits/arma_version.hpp b/Include/armadillo/armadillo_bits/arma_version.hpp index 227bf826c..d47f0c8c4 100644 --- a/Include/armadillo/armadillo_bits/arma_version.hpp +++ b/Include/armadillo/armadillo_bits/arma_version.hpp @@ -22,9 +22,9 @@ #define ARMA_VERSION_MAJOR 14 -#define ARMA_VERSION_MINOR 0 -#define ARMA_VERSION_PATCH 3 -#define ARMA_VERSION_NAME "Stochastic Parrot" +#define ARMA_VERSION_MINOR 2 +#define ARMA_VERSION_PATCH 0 +#define ARMA_VERSION_NAME "Smooth Caffeine" diff --git a/Include/armadillo/armadillo_bits/auxlib_bones.hpp b/Include/armadillo/armadillo_bits/auxlib_bones.hpp index 63292dd9e..f68ba1f5f 100644 --- a/Include/armadillo/armadillo_bits/auxlib_bones.hpp +++ b/Include/armadillo/armadillo_bits/auxlib_bones.hpp @@ -43,6 +43,18 @@ class auxlib template inline static bool inv_tr_rcond(Mat& A, typename get_pod_type::result& out_rcond, const uword layout); + template + inline static bool inv_sym(Mat& A); + + template + inline static bool inv_sym(Mat< std::complex >& A); + + template + inline static bool inv_sym_rcond(Mat& A, eT& out_rcond); + + template + inline static bool inv_sym_rcond(Mat< std::complex >& A, T& out_rcond); + template inline static bool inv_sympd(Mat& A, bool& out_sympd_state); @@ -50,10 +62,10 @@ class auxlib inline static bool inv_sympd(Mat& out, const Mat& X); template - inline static bool inv_sympd_rcond(Mat& A, bool& out_sympd_state, eT& out_rcond); + inline static bool inv_sympd_rcond(Mat& A, eT& out_rcond); template - inline static bool inv_sympd_rcond(Mat< std::complex >& A, bool& out_sympd_state, T& out_rcond); + inline static bool inv_sympd_rcond(Mat< std::complex >& A, T& out_rcond); // @@ -269,6 +281,20 @@ class auxlib // + template + inline static bool solve_sym_fast(Mat& out, Mat& A, const Base& B_expr); + + template + inline static bool solve_sym_fast(Mat< std::complex >& out, Mat< std::complex >& A, const Base< std::complex, T1 >& B_expr); + + template + inline static bool solve_sym_rcond(Mat& out, typename T1::pod_type& out_rcond, Mat& A, const Base& B_expr); + + template + inline static bool solve_sym_rcond(Mat< std::complex >& out, typename T1::pod_type& out_rcond, Mat< std::complex >& A, const Base< std::complex,T1>& B_expr); + + // + template inline static bool solve_sympd_fast(Mat& out, Mat& A, const Base& B_expr); @@ -388,10 +414,10 @@ class auxlib inline static T rcond(Mat< std::complex >& A); template - inline static eT rcond_sympd(Mat& A, bool& calc_ok); + inline static eT rcond_sym(Mat& A); - template - inline static T rcond_sympd(Mat< std::complex >& A, bool& calc_ok); + template + inline static T rcond_sym(Mat< std::complex >& A); template inline static eT rcond_trimat(const Mat& A, const uword layout); diff --git a/Include/armadillo/armadillo_bits/auxlib_meat.hpp b/Include/armadillo/armadillo_bits/auxlib_meat.hpp index bc70fdadc..7546b641b 100644 --- a/Include/armadillo/armadillo_bits/auxlib_meat.hpp +++ b/Include/armadillo/armadillo_bits/auxlib_meat.hpp @@ -28,6 +28,10 @@ auxlib::inv(Mat& A) { arma_debug_sigprint(); + // NOTE: given a matrix with NaN values, lapack::getrf() and lapack::getri() do not necessarily fail, + // NOTE: and can produce matrices with NaN values. + // NOTE: we're not checking for non-finite values to avoid breaking existing user code. + if(A.is_empty()) { return true; } #if defined(ARMA_USE_LAPACK) @@ -46,7 +50,7 @@ auxlib::inv(Mat& A) if(info != 0) { return false; } - if(n > 16) + if(n > blas_int(podarray_prealloc_n_elem::val)) { eT work_query[2] = {}; blas_int lwork_query = -1; @@ -130,7 +134,7 @@ auxlib::inv_rcond(Mat& A, typename get_pod_type::result& out_rcond) out_rcond = auxlib::lu_rcond(A, norm_val); - if(n > 16) + if(n > blas_int(podarray_prealloc_n_elem::val)) { eT work_query[2] = {}; blas_int lwork_query = -1; @@ -242,6 +246,316 @@ auxlib::inv_tr_rcond(Mat& A, typename get_pod_type::result& out_rcond, c +template +inline +bool +auxlib::inv_sym(Mat& A) + { + arma_debug_sigprint(); + + if(A.is_empty()) { return true; } + + #if defined(ARMA_USE_LAPACK) + { + arma_conform_assert_blas_size(A); + + char uplo = 'L'; + blas_int n = blas_int(A.n_rows); + blas_int lda = blas_int(A.n_rows); + blas_int lwork = (std::max)(blas_int(podarray_prealloc_n_elem::val), n); + blas_int info = 0; + + podarray ipiv(A.n_rows); + + if(n > blas_int(podarray_prealloc_n_elem::val)) + { + eT work_query[2] = {}; + blas_int lwork_query = -1; + + arma_debug_print("lapack::sytrf()"); + lapack::sytrf(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), &work_query[0], &lwork_query, &info); + + if(info != 0) { return false; } + + blas_int lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + + lwork = (std::max)(lwork_proposed, lwork); + } + + podarray work( static_cast(lwork) ); + + arma_debug_print("lapack::sytrf()"); + lapack::sytrf(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), work.memptr(), &lwork, &info); + + if(info != 0) { return false; } + + arma_debug_print("lapack::sytri()"); + lapack::sytri(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), work.memptr(), &info); + + if(info != 0) { return false; } + + A = symmatl(A); + + return true; + } + #else + { + arma_ignore(A); + arma_stop_logic_error("inv_sym(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::inv_sym(Mat< std::complex >& A) + { + arma_debug_sigprint(); + + // NOTE: the function name is required for overloading, but is a misnomer: it processes complex hermitian matrices + + if(A.is_empty()) { return true; } + + #if defined(ARMA_CRIPPLED_LAPACK) + { + arma_debug_print("auxlib::inv_sym(): redirecting to auxlib::inv() due to crippled LAPACK"); + + return auxlib::inv(A); + } + #elif defined(ARMA_USE_LAPACK) + { + typedef typename std::complex eT; + + arma_conform_assert_blas_size(A); + + char uplo = 'L'; + blas_int n = blas_int(A.n_rows); + blas_int lda = blas_int(A.n_rows); + blas_int lwork = (std::max)(blas_int(podarray_prealloc_n_elem::val), n); + blas_int info = 0; + + podarray ipiv(A.n_rows); + + if(n > blas_int(podarray_prealloc_n_elem::val)) + { + eT work_query[2] = {}; + blas_int lwork_query = -1; + + arma_debug_print("lapack::hetrf()"); + lapack::hetrf(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), &work_query[0], &lwork_query, &info); + + if(info != 0) { return false; } + + blas_int lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + + lwork = (std::max)(lwork_proposed, lwork); + } + + podarray work( static_cast(lwork) ); + + arma_debug_print("lapack::hetrf()"); + lapack::hetrf(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), work.memptr(), &lwork, &info); + + if(info != 0) { return false; } + + arma_debug_print("lapack::hetri()"); + lapack::hetri(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), work.memptr(), &info); + + if(info != 0) { return false; } + + A = symmatl(A); + + return true; + } + #else + { + arma_ignore(A); + arma_stop_logic_error("inv_sym(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::inv_sym_rcond(Mat& A, eT& out_rcond) + { + arma_debug_sigprint(); + + out_rcond = eT(0); + + if(A.is_empty()) { return true; } + + #if defined(ARMA_USE_LAPACK) + { + arma_conform_assert_blas_size(A); + + char norm_id = '1'; + char uplo = 'L'; + blas_int n = blas_int(A.n_rows); + blas_int lda = blas_int(A.n_rows); + blas_int lwork = (std::max)(blas_int(podarray_prealloc_n_elem::val), 2*n); // 2*n due to lapack::sycon() requirements + blas_int info = 0; + eT norm_val = eT(0); + eT tmp_rcond = eT(0); + + podarray ipiv(A.n_rows); + podarray iwork(A.n_rows); + + if( (2*n) > blas_int(podarray_prealloc_n_elem::val) ) + { + eT work_query[2] = {}; + blas_int lwork_query = -1; + + arma_debug_print("lapack::sytrf()"); + lapack::sytrf(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), &work_query[0], &lwork_query, &info); + + if(info != 0) { return false; } + + blas_int lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + + lwork = (std::max)(lwork_proposed, lwork); + } + + podarray work( static_cast(lwork) ); + + arma_debug_print("lapack::lansy()"); + norm_val = (has_blas_float_bug::value) ? auxlib::norm1_sym(A) : lapack::lansy(&norm_id, &uplo, &n, A.memptr(), &lda, work.memptr()); + + arma_debug_print("lapack::sytrf()"); + lapack::sytrf(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), work.memptr(), &lwork, &info); + + if(info != 0) { return false; } + + arma_debug_print("lapack::sycon()"); + lapack::sycon(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), &norm_val, &tmp_rcond, work.memptr(), iwork.memptr(), &info); + + if(info != 0) { return false; } + + out_rcond = tmp_rcond; + + if(arma_isnan(out_rcond)) { return false; } + + arma_debug_print("lapack::sytri()"); + lapack::sytri(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), work.memptr(), &info); + + if(info != 0) { return false; } + + A = symmatl(A); + + return true; + } + #else + { + arma_ignore(A); + arma_ignore(out_rcond); + arma_stop_logic_error("inv_sym_rcond(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::inv_sym_rcond(Mat< std::complex >& A, T& out_rcond) + { + arma_debug_sigprint(); + + // NOTE: the function name is required for overloading, but is a misnomer: it processes complex hermitian matrices + + out_rcond = T(0); + + if(A.is_empty()) { return true; } + + #if defined(ARMA_CRIPPLED_LAPACK) + { + arma_debug_print("auxlib::inv_sym_rcond(): redirecting to auxlib::inv_rcond() due to crippled LAPACK"); + + return auxlib::inv_rcond(A, out_rcond); + } + #elif defined(ARMA_USE_LAPACK) + { + typedef typename std::complex eT; + + arma_conform_assert_blas_size(A); + + char norm_id = '1'; + char uplo = 'L'; + blas_int n = blas_int(A.n_rows); + blas_int lda = blas_int(A.n_rows); + blas_int lwork = (std::max)(blas_int(podarray_prealloc_n_elem::val), 2*n); // 2*n due to lapack::hecon() requirements + blas_int info = 0; + T norm_val = T(0); + T tmp_rcond = T(0); + + podarray ipiv(A.n_rows); + podarray lanhe_work(A.n_rows); + + if( (2*n) > blas_int(podarray_prealloc_n_elem::val) ) + { + eT work_query[2] = {}; + blas_int lwork_query = -1; + + arma_debug_print("lapack::hetrf()"); + lapack::hetrf(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), &work_query[0], &lwork_query, &info); + + if(info != 0) { return false; } + + blas_int lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + + lwork = (std::max)(lwork_proposed, lwork); + } + + podarray work( static_cast(lwork) ); + + arma_debug_print("lapack::lanhe()"); + norm_val = (has_blas_float_bug::value) ? auxlib::norm1_sym(A) : lapack::lanhe(&norm_id, &uplo, &n, A.memptr(), &lda, lanhe_work.memptr()); + + arma_debug_print("lapack::hetrf()"); + lapack::hetrf(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), work.memptr(), &lwork, &info); + + if(info != 0) { return false; } + + arma_debug_print("lapack::hecon()"); + lapack::hecon(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), &norm_val, &tmp_rcond, work.memptr(), &info); + + if(info != 0) { return false; } + + out_rcond = tmp_rcond; + + if(arma_isnan(out_rcond)) { return false; } + + arma_debug_print("lapack::hetri()"); + lapack::hetri(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), work.memptr(), &info); + + if(info != 0) { return false; } + + A = symmatl(A); + + return true; + } + #else + { + arma_ignore(A); + arma_ignore(out_rcond); + arma_stop_logic_error("inv_sym_rcond(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + template inline bool @@ -310,12 +624,10 @@ auxlib::inv_sympd(Mat& out, const Mat& X) template inline bool -auxlib::inv_sympd_rcond(Mat& A, bool& out_sympd_state, eT& out_rcond) +auxlib::inv_sympd_rcond(Mat& A, eT& out_rcond) { arma_debug_sigprint(); - out_sympd_state = false; - if(A.is_empty()) { return true; } #if defined(ARMA_USE_LAPACK) @@ -340,8 +652,6 @@ auxlib::inv_sympd_rcond(Mat& A, bool& out_sympd_state, eT& out_rcond) if(info != 0) { out_rcond = eT(0); return false; } - out_sympd_state = true; - out_rcond = auxlib::lu_rcond_sympd(A, norm_val); if(arma_isnan(out_rcond)) { return false; } @@ -358,7 +668,6 @@ auxlib::inv_sympd_rcond(Mat& A, bool& out_sympd_state, eT& out_rcond) #else { arma_ignore(A); - arma_ignore(out_sympd_state); arma_ignore(out_rcond); arma_stop_logic_error("inv_sympd_rcond(): use LAPACK must be enabled"); return false; @@ -371,18 +680,15 @@ auxlib::inv_sympd_rcond(Mat& A, bool& out_sympd_state, eT& out_rcond) template inline bool -auxlib::inv_sympd_rcond(Mat< std::complex >& A, bool& out_sympd_state, T& out_rcond) +auxlib::inv_sympd_rcond(Mat< std::complex >& A, T& out_rcond) { arma_debug_sigprint(); - out_sympd_state = false; - if(A.is_empty()) { return true; } #if defined(ARMA_CRIPPLED_LAPACK) { arma_ignore(A); - arma_ignore(out_sympd_state); arma_ignore(out_rcond); return false; } @@ -406,8 +712,6 @@ auxlib::inv_sympd_rcond(Mat< std::complex >& A, bool& out_sympd_state, T& out if(info != 0) { out_rcond = T(0); return false; } - out_sympd_state = true; - out_rcond = auxlib::lu_rcond_sympd(A, norm_val); if(arma_isnan(out_rcond)) { return false; } @@ -424,7 +728,6 @@ auxlib::inv_sympd_rcond(Mat< std::complex >& A, bool& out_sympd_state, T& out #else { arma_ignore(A); - arma_ignore(out_sympd_state); arma_ignore(out_rcond); arma_stop_logic_error("inv_sympd_rcond(): use LAPACK must be enabled"); return false; @@ -4262,6 +4565,328 @@ auxlib::solve_square_refine(Mat< std::complex >& out, typ +template +inline +bool +auxlib::solve_sym_fast(Mat& out, Mat& A, const Base& B_expr) + { + arma_debug_sigprint(); + + out = B_expr.get_ref(); + + const uword B_n_rows = out.n_rows; + const uword B_n_cols = out.n_cols; + + arma_conform_check( (A.n_rows != B_n_rows), "solve(): number of rows in given matrices must be the same", [&](){ out.soft_reset(); } ); + + if(A.is_empty() || out.is_empty()) { out.zeros(A.n_cols, B_n_cols); return true; } + + #if defined(ARMA_USE_LAPACK) + { + typedef typename T1::pod_type eT; + + arma_conform_assert_blas_size(A,out); + + char uplo = 'L'; + blas_int n = blas_int(A.n_rows); + blas_int lda = blas_int(A.n_rows); + blas_int ldb = blas_int(out.n_rows); + blas_int nrhs = blas_int(out.n_cols); + blas_int lwork = (std::max)(blas_int(podarray_prealloc_n_elem::val), n); + blas_int info = 0; + + podarray ipiv(A.n_rows); + + if(n > blas_int(podarray_prealloc_n_elem::val)) + { + eT work_query[2] = {}; + blas_int lwork_query = -1; + + arma_debug_print("lapack::sytrf()"); + lapack::sytrf(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), &work_query[0], &lwork_query, &info); + + if(info != 0) { return false; } + + blas_int lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + + lwork = (std::max)(lwork_proposed, lwork); + } + + podarray work( static_cast(lwork) ); + + arma_debug_print("lapack::sytrf()"); + lapack::sytrf(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), work.memptr(), &lwork, &info); + + if(info != 0) { return false; } + + arma_debug_print("lapack::sytrs()"); + lapack::sytrs(&uplo, &n, &nrhs, A.memptr(), &lda, ipiv.memptr(), out.memptr(), &ldb, &info); + + return (info == 0); + } + #else + { + arma_stop_logic_error("solve(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::solve_sym_fast(Mat< std::complex >& out, Mat< std::complex >& A, const Base< std::complex, T1 >& B_expr) + { + arma_debug_sigprint(); + + out = B_expr.get_ref(); + + const uword B_n_rows = out.n_rows; + const uword B_n_cols = out.n_cols; + + arma_conform_check( (A.n_rows != B_n_rows), "solve(): number of rows in given matrices must be the same", [&](){ out.soft_reset(); } ); + + if(A.is_empty() || out.is_empty()) { out.zeros(A.n_cols, B_n_cols); return true; } + + #if defined(ARMA_CRIPPLED_LAPACK) + { + arma_debug_print("auxlib::solve_sym_fast(): redirecting to auxlib::solve_square_fast() due to crippled LAPACK"); + + return auxlib::solve_square_fast(out, A, B_expr); + } + #elif defined(ARMA_USE_LAPACK) + { + typedef typename T1::pod_type T; + typedef std::complex eT; + + arma_conform_assert_blas_size(A,out); + + char uplo = 'L'; + blas_int n = blas_int(A.n_rows); + blas_int lda = blas_int(A.n_rows); + blas_int ldb = blas_int(out.n_rows); + blas_int nrhs = blas_int(out.n_cols); + blas_int lwork = (std::max)(blas_int(podarray_prealloc_n_elem::val), n); + blas_int info = 0; + + podarray ipiv(A.n_rows); + + if(n > blas_int(podarray_prealloc_n_elem::val)) + { + eT work_query[2] = {}; + blas_int lwork_query = -1; + + arma_debug_print("lapack::hetrf()"); + lapack::hetrf(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), &work_query[0], &lwork_query, &info); + + if(info != 0) { return false; } + + blas_int lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + + lwork = (std::max)(lwork_proposed, lwork); + } + + podarray work( static_cast(lwork) ); + + arma_debug_print("lapack::hetrf()"); + lapack::hetrf(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), work.memptr(), &lwork, &info); + + if(info != 0) { return false; } + + arma_debug_print("lapack::hetrs()"); + lapack::hetrs(&uplo, &n, &nrhs, A.memptr(), &lda, ipiv.memptr(), out.memptr(), &ldb, &info); + + return (info == 0); + } + #else + { + arma_stop_logic_error("solve(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::solve_sym_rcond(Mat& out, typename T1::pod_type& out_rcond, Mat& A, const Base& B_expr) + { + arma_debug_sigprint(); + + out = B_expr.get_ref(); + + const uword B_n_rows = out.n_rows; + const uword B_n_cols = out.n_cols; + + arma_conform_check( (A.n_rows != B_n_rows), "solve(): number of rows in given matrices must be the same", [&](){ out.soft_reset(); } ); + + if(A.is_empty() || out.is_empty()) { out.zeros(A.n_cols, B_n_cols); return true; } + + #if defined(ARMA_CRIPPLED_LAPACK) + { + arma_debug_print("auxlib::solve_sym_rcond(): redirecting to auxlib::solve_square_rcond() due to crippled LAPACK"); + + return auxlib::solve_square_rcond(out, out_rcond, A, B_expr); + } + #elif defined(ARMA_USE_LAPACK) + { + typedef typename T1::pod_type eT; + + out_rcond = eT(0); + + arma_conform_assert_blas_size(A,out); + + char norm_id = '1'; + char uplo = 'L'; + blas_int n = blas_int(A.n_rows); + blas_int lda = blas_int(A.n_rows); + blas_int ldb = blas_int(out.n_rows); + blas_int nrhs = blas_int(out.n_cols); + blas_int lwork = (std::max)(blas_int(podarray_prealloc_n_elem::val), 2*n); // 2*n due to lapack::sycon() requirements + blas_int info = 0; + eT norm_val = eT(0); + eT tmp_rcond = eT(0); + + podarray ipiv(A.n_rows); + podarray iwork(A.n_rows); + + if( (2*n) > blas_int(podarray_prealloc_n_elem::val) ) + { + eT work_query[2] = {}; + blas_int lwork_query = -1; + + arma_debug_print("lapack::sytrf()"); + lapack::sytrf(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), &work_query[0], &lwork_query, &info); + + if(info != 0) { return false; } + + blas_int lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + + lwork = (std::max)(lwork_proposed, lwork); + } + + podarray work( static_cast(lwork) ); + + arma_debug_print("lapack::lansy()"); + norm_val = (has_blas_float_bug::value) ? auxlib::norm1_sym(A) : lapack::lansy(&norm_id, &uplo, &n, A.memptr(), &n, work.memptr()); + + arma_debug_print("lapack::sytrf()"); + lapack::sytrf(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), work.memptr(), &lwork, &info); + + if(info != 0) { return false; } + + arma_debug_print("lapack::sytrs()"); + lapack::sytrs(&uplo, &n, &nrhs, A.memptr(), &lda, ipiv.memptr(), out.memptr(), &ldb, &info); + + if(info != 0) { return false; } + + arma_debug_print("lapack::sycon()"); + lapack::sycon(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), &norm_val, &tmp_rcond, work.memptr(), iwork.memptr(), &info); + + out_rcond = tmp_rcond; + + return (info == 0); + } + #else + { + arma_stop_logic_error("solve(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::solve_sym_rcond(Mat< std::complex >& out, typename T1::pod_type& out_rcond, Mat< std::complex >& A, const Base< std::complex,T1>& B_expr) + { + arma_debug_sigprint(); + + out = B_expr.get_ref(); + + const uword B_n_rows = out.n_rows; + const uword B_n_cols = out.n_cols; + + arma_conform_check( (A.n_rows != B_n_rows), "solve(): number of rows in given matrices must be the same", [&](){ out.soft_reset(); } ); + + if(A.is_empty() || out.is_empty()) { out.zeros(A.n_cols, B_n_cols); return true; } + + #if defined(ARMA_USE_LAPACK) + { + typedef typename T1::pod_type T; + typedef typename std::complex eT; + + out_rcond = T(0); + + arma_conform_assert_blas_size(A,out); + + char norm_id = '1'; + char uplo = 'L'; + blas_int n = blas_int(A.n_rows); + blas_int lda = blas_int(A.n_rows); + blas_int ldb = blas_int(out.n_rows); + blas_int nrhs = blas_int(out.n_cols); + blas_int lwork = (std::max)(blas_int(podarray_prealloc_n_elem::val), 2*n); // 2*n due to lapack::hecon() requirements + blas_int info = 0; + T norm_val = T(0); + T tmp_rcond = T(0); + + podarray ipiv(A.n_rows); + podarray lanhe_work(A.n_rows); + + if( (2*n) > blas_int(podarray_prealloc_n_elem::val) ) + { + eT work_query[2] = {}; + blas_int lwork_query = -1; + + arma_debug_print("lapack::hetrf()"); + lapack::hetrf(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), &work_query[0], &lwork_query, &info); + + if(info != 0) { return false; } + + blas_int lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + + lwork = (std::max)(lwork_proposed, lwork); + } + + podarray work( static_cast(lwork) ); + + arma_debug_print("lapack::lanhe()"); + norm_val = (has_blas_float_bug::value) ? auxlib::norm1_sym(A) : lapack::lanhe(&norm_id, &uplo, &n, A.memptr(), &lda, lanhe_work.memptr()); + + arma_debug_print("lapack::hetrf()"); + lapack::hetrf(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), work.memptr(), &lwork, &info); + + if(info != 0) { return false; } + + arma_debug_print("lapack::hetrs()"); + lapack::hetrs(&uplo, &n, &nrhs, A.memptr(), &lda, ipiv.memptr(), out.memptr(), &ldb, &info); + + if(info != 0) { return false; } + + arma_debug_print("lapack::hecon()"); + lapack::hecon(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), &norm_val, &tmp_rcond, work.memptr(), &info); + + out_rcond = tmp_rcond; + + return (info == 0); + } + #else + { + arma_stop_logic_error("solve(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + template inline bool @@ -6189,47 +6814,60 @@ auxlib::rcond(Mat< std::complex >& A) template inline eT -auxlib::rcond_sympd(Mat& A, bool& calc_ok) +auxlib::rcond_sym(Mat& A) { #if defined(ARMA_USE_LAPACK) { arma_conform_assert_blas_size(A); - calc_ok = false; + char norm_id = '1'; + char uplo = 'L'; + blas_int n = blas_int(A.n_rows); + blas_int lda = blas_int(A.n_rows); + blas_int lwork = (std::max)(blas_int(podarray_prealloc_n_elem::val), 2*n); // 2*n due to lapack::sycon() requirements + blas_int info = 0; + eT norm_val = eT(0); + eT out_rcond = eT(0); - char norm_id = '1'; - char uplo = 'L'; - blas_int n = blas_int(A.n_rows); // assuming square matrix - blas_int lda = blas_int(A.n_rows); - eT norm_val = eT(0); - eT rcond = eT(0); - blas_int info = blas_int(0); + podarray ipiv(A.n_rows); + podarray iwork(A.n_rows); - podarray work(3*A.n_rows); - podarray iwork( A.n_rows); + if( (2*n) > blas_int(podarray_prealloc_n_elem::val) ) + { + eT work_query[2] = {}; + blas_int lwork_query = -1; + + arma_debug_print("lapack::sytrf()"); + lapack::sytrf(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), &work_query[0], &lwork_query, &info); + + if(info != 0) { return eT(0); } + + blas_int lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + + lwork = (std::max)(lwork_proposed, lwork); + } + + podarray work( static_cast(lwork) ); arma_debug_print("lapack::lansy()"); norm_val = (has_blas_float_bug::value) ? auxlib::norm1_sym(A) : lapack::lansy(&norm_id, &uplo, &n, A.memptr(), &lda, work.memptr()); - arma_debug_print("lapack::potrf()"); - lapack::potrf(&uplo, &n, A.memptr(), &lda, &info); + arma_debug_print("lapack::sytrf()"); + lapack::sytrf(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), work.memptr(), &lwork, &info); - if(info != blas_int(0)) { return eT(0); } + if(info != 0) { return eT(0); } - arma_debug_print("lapack::pocon()"); - lapack::pocon(&uplo, &n, A.memptr(), &lda, &norm_val, &rcond, work.memptr(), iwork.memptr(), &info); - - if(info != blas_int(0)) { return eT(0); } + arma_debug_print("lapack::sycon()"); + lapack::sycon(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), &norm_val, &out_rcond, work.memptr(), iwork.memptr(), &info); - calc_ok = true; + if(info != 0) { return eT(0); } - return rcond; + return out_rcond; } #else { arma_ignore(A); - calc_ok = false; - arma_stop_logic_error("rcond(): use of LAPACK must be enabled"); + arma_stop_logic_error("rcond_sym(): use of LAPACK must be enabled"); return eT(0); } #endif @@ -6240,13 +6878,13 @@ auxlib::rcond_sympd(Mat& A, bool& calc_ok) template inline T -auxlib::rcond_sympd(Mat< std::complex >& A, bool& calc_ok) +auxlib::rcond_sym(Mat< std::complex >& A) { + // NOTE: the function name is required for overloading, but is a misnomer: it processes complex hermitian matrices + #if defined(ARMA_CRIPPLED_LAPACK) { - arma_debug_print("auxlib::rcond_sympd(): redirecting to auxlib::rcond() due to crippled LAPACK"); - - calc_ok = true; + arma_debug_print("auxlib::rcond_sym(): redirecting to auxlib::rcond() due to crippled LAPACK"); return auxlib::rcond(A); } @@ -6256,41 +6894,54 @@ auxlib::rcond_sympd(Mat< std::complex >& A, bool& calc_ok) arma_conform_assert_blas_size(A); - calc_ok = false; + char norm_id = '1'; + char uplo = 'L'; + blas_int n = blas_int(A.n_rows); + blas_int lda = blas_int(A.n_rows); + blas_int lwork = (std::max)(blas_int(podarray_prealloc_n_elem::val), 2*n); // 2*n due to lapack::hecon() requirements + blas_int info = 0; + T norm_val = T(0); + T out_rcond = T(0); - char norm_id = '1'; - char uplo = 'L'; - blas_int n = blas_int(A.n_rows); // assuming square matrix - blas_int lda = blas_int(A.n_rows); - T norm_val = T(0); - T rcond = T(0); - blas_int info = blas_int(0); + podarray ipiv(A.n_rows); + podarray lanhe_work(A.n_rows); - podarray work(2*A.n_rows); - podarray< T> rwork( A.n_rows); + if( (2*n) > blas_int(podarray_prealloc_n_elem::val) ) + { + eT work_query[2] = {}; + blas_int lwork_query = -1; + + arma_debug_print("lapack::hetrf()"); + lapack::hetrf(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), &work_query[0], &lwork_query, &info); + + if(info != 0) { return T(0); } + + blas_int lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + + lwork = (std::max)(lwork_proposed, lwork); + } + + podarray work( static_cast(lwork) ); arma_debug_print("lapack::lanhe()"); - norm_val = (has_blas_float_bug::value) ? auxlib::norm1_sym(A) : lapack::lanhe(&norm_id, &uplo, &n, A.memptr(), &lda, rwork.memptr()); + norm_val = (has_blas_float_bug::value) ? auxlib::norm1_sym(A) : lapack::lanhe(&norm_id, &uplo, &n, A.memptr(), &lda, lanhe_work.memptr()); - arma_debug_print("lapack::potrf()"); - lapack::potrf(&uplo, &n, A.memptr(), &lda, &info); + arma_debug_print("lapack::hetrf()"); + lapack::hetrf(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), work.memptr(), &lwork, &info); - if(info != blas_int(0)) { return T(0); } + if(info != 0) { return T(0); } - arma_debug_print("lapack::cx_pocon()"); - lapack::cx_pocon(&uplo, &n, A.memptr(), &lda, &norm_val, &rcond, work.memptr(), rwork.memptr(), &info); + arma_debug_print("lapack::hecon()"); + lapack::hecon(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), &norm_val, &out_rcond, work.memptr(), &info); - if(info != blas_int(0)) { return T(0); } - - calc_ok = true; + if(info != 0) { return T(0); } - return rcond; + return out_rcond; } #else { arma_ignore(A); - calc_ok = false; - arma_stop_logic_error("rcond(): use of LAPACK must be enabled"); + arma_stop_logic_error("rcond_sym(): use of LAPACK must be enabled"); return T(0); } #endif diff --git a/Include/armadillo/armadillo_bits/compiler_check.hpp b/Include/armadillo/armadillo_bits/compiler_check.hpp index 18ea59a19..fcf651ab6 100644 --- a/Include/armadillo/armadillo_bits/compiler_check.hpp +++ b/Include/armadillo/armadillo_bits/compiler_check.hpp @@ -83,10 +83,8 @@ #endif -#if (!defined(ARMA_HAVE_CXX14)) - #if (!defined(ARMA_IGNORE_DEPRECATED_MARKER)) || defined(ARMA_DONT_IGNORE_DEPRECATED_MARKER) || defined(ARMA_DEBUG) - #pragma message ("INFO: support for C++11 is deprecated") - #endif +#if (!defined(ARMA_HAVE_CXX14)) && (!defined(ARMA_IGNORE_DEPRECATED_MARKER)) + #pragma message ("INFO: support for C++11 is deprecated; minimum recommended standard is C++14") #endif diff --git a/Include/armadillo/armadillo_bits/compiler_setup.hpp b/Include/armadillo/armadillo_bits/compiler_setup.hpp index 3978911c1..3176a50c4 100644 --- a/Include/armadillo/armadillo_bits/compiler_setup.hpp +++ b/Include/armadillo/armadillo_bits/compiler_setup.hpp @@ -168,6 +168,10 @@ // gcc 6.1 has proper C++14 support and fixes an OpenMP related bug: // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=57580 + #if (ARMA_GCC_VERSION < 80100) + #pragma message("INFO: support for GCC versions older than 8.1 is deprecated" + #endif + #define ARMA_GOOD_COMPILER #undef arma_hot @@ -310,7 +314,7 @@ #if defined(_MSC_VER) - #if (_MSC_VER < 1900) + #if (_MSC_VER < 1910) #error "*** newer compiler required ***" #endif @@ -477,7 +481,7 @@ // NOTE: option 'ARMA_IGNORE_DEPRECATED_MARKER' will be removed // NOTE: disabling deprecation messages is counter-productive -#if defined(ARMA_IGNORE_DEPRECATED_MARKER) && (!defined(ARMA_DONT_IGNORE_DEPRECATED_MARKER)) && (!defined(ARMA_DEBUG)) +#if defined(ARMA_IGNORE_DEPRECATED_MARKER) #undef arma_deprecated #define arma_deprecated diff --git a/Include/armadillo/armadillo_bits/config.hpp b/Include/armadillo/armadillo_bits/config.hpp index 21cad8791..fb1c51279 100644 --- a/Include/armadillo/armadillo_bits/config.hpp +++ b/Include/armadillo/armadillo_bits/config.hpp @@ -102,11 +102,8 @@ //// Uncomment the above line if your BLAS and LAPACK libraries have function names with a trailing underscore. //// Conversely, comment it out if the function names don't have a trailing underscore. -// #define ARMA_BLAS_LONG -//// Uncomment the above line if your BLAS and LAPACK libraries use "long" instead of "int" - // #define ARMA_BLAS_LONG_LONG -//// Uncomment the above line if your BLAS and LAPACK libraries use "long long" instead of "int" +//// Uncomment the above line if your BLAS and LAPACK libraries use 64 bit integers, ie. "long long" instead of "int" // #define ARMA_BLAS_NOEXCEPT //// Uncomment the above line if you require BLAS functions to have the 'noexcept' specification @@ -131,7 +128,7 @@ //// Uncomment the above line to use Intel MKL types for complex numbers. //// You will need to include appropriate MKL headers before the Armadillo header. //// You may also need to enable or disable the following options: -//// ARMA_BLAS_LONG, ARMA_BLAS_LONG_LONG, ARMA_USE_FORTRAN_HIDDEN_ARGS +//// ARMA_BLAS_LONG_LONG, ARMA_USE_FORTRAN_HIDDEN_ARGS #if !defined(ARMA_USE_OPENMP) // #define ARMA_USE_OPENMP @@ -306,6 +303,12 @@ #undef ARMA_64BIT_WORD #endif +// for compatibility with earlier versions of Armadillo +#if defined(ARMA_BLAS_LONG) || defined(ARMA_BLAS_LONG_LONG) + #undef ARMA_BLAS_64BIT_INT + #define ARMA_BLAS_64BIT_INT +#endif + #if defined(ARMA_DONT_OPTIMISE_BAND) || defined(ARMA_DONT_OPTIMISE_SOLVE_BAND) #undef ARMA_OPTIMISE_BAND #endif @@ -330,6 +333,10 @@ #undef ARMA_CHECK_NONFINITE #endif +#if defined(ARMA_DONT_IGNORE_DEPRECATED_MARKER) + #undef ARMA_IGNORE_DEPRECATED_MARKER +#endif + #if defined(ARMA_NO_DEBUG) #undef ARMA_DEBUG #undef ARMA_EXTRA_DEBUG @@ -347,6 +354,8 @@ #undef ARMA_WARN_LEVEL #define ARMA_WARN_LEVEL 3 + + #undef ARMA_IGNORE_DEPRECATED_MARKER #endif #if defined(ARMA_DONT_PRINT_EXCEPTIONS) @@ -357,6 +366,12 @@ #undef ARMA_CRIPPLED_LAPACK #endif +#if defined(ARMA_CRIPPLED_LAPACK) + #if (!defined(ARMA_IGNORE_DEPRECATED_MARKER)) + #pragma message ("option ARMA_CRIPPLED_LAPACK is deprecated and will be removed") + #endif +#endif + // if Armadillo was installed on this system via CMake and ARMA_USE_WRAPPER is not defined, // ARMA_AUX_LIBS lists the libraries required by Armadillo on this system, and diff --git a/Include/armadillo/armadillo_bits/config.hpp.cmake b/Include/armadillo/armadillo_bits/config.hpp.cmake index 6d7e33de3..82603d1b3 100644 --- a/Include/armadillo/armadillo_bits/config.hpp.cmake +++ b/Include/armadillo/armadillo_bits/config.hpp.cmake @@ -102,11 +102,8 @@ //// Uncomment the above line if your BLAS and LAPACK libraries have function names with a trailing underscore. //// Conversely, comment it out if the function names don't have a trailing underscore. -// #define ARMA_BLAS_LONG -//// Uncomment the above line if your BLAS and LAPACK libraries use "long" instead of "int" - // #define ARMA_BLAS_LONG_LONG -//// Uncomment the above line if your BLAS and LAPACK libraries use "long long" instead of "int" +//// Uncomment the above line if your BLAS and LAPACK libraries use 64 bit integers, ie. "long long" instead of "int" // #define ARMA_BLAS_NOEXCEPT //// Uncomment the above line if you require BLAS functions to have the 'noexcept' specification @@ -131,7 +128,7 @@ //// Uncomment the above line to use Intel MKL types for complex numbers. //// You will need to include appropriate MKL headers before the Armadillo header. //// You may also need to enable or disable the following options: -//// ARMA_BLAS_LONG, ARMA_BLAS_LONG_LONG, ARMA_USE_FORTRAN_HIDDEN_ARGS +//// ARMA_BLAS_LONG_LONG, ARMA_USE_FORTRAN_HIDDEN_ARGS #if !defined(ARMA_USE_OPENMP) // #define ARMA_USE_OPENMP @@ -306,6 +303,12 @@ #undef ARMA_64BIT_WORD #endif +// for compatibility with earlier versions of Armadillo +#if defined(ARMA_BLAS_LONG) || defined(ARMA_BLAS_LONG_LONG) + #undef ARMA_BLAS_64BIT_INT + #define ARMA_BLAS_64BIT_INT +#endif + #if defined(ARMA_DONT_OPTIMISE_BAND) || defined(ARMA_DONT_OPTIMISE_SOLVE_BAND) #undef ARMA_OPTIMISE_BAND #endif @@ -330,6 +333,10 @@ #undef ARMA_CHECK_NONFINITE #endif +#if defined(ARMA_DONT_IGNORE_DEPRECATED_MARKER) + #undef ARMA_IGNORE_DEPRECATED_MARKER +#endif + #if defined(ARMA_NO_DEBUG) #undef ARMA_DEBUG #undef ARMA_EXTRA_DEBUG @@ -347,6 +354,8 @@ #undef ARMA_WARN_LEVEL #define ARMA_WARN_LEVEL 3 + + #undef ARMA_IGNORE_DEPRECATED_MARKER #endif #if defined(ARMA_DONT_PRINT_EXCEPTIONS) @@ -357,6 +366,12 @@ #undef ARMA_CRIPPLED_LAPACK #endif +#if defined(ARMA_CRIPPLED_LAPACK) + #if (!defined(ARMA_IGNORE_DEPRECATED_MARKER)) + #pragma message ("option ARMA_CRIPPLED_LAPACK is deprecated and will be removed") + #endif +#endif + // if Armadillo was installed on this system via CMake and ARMA_USE_WRAPPER is not defined, // ARMA_AUX_LIBS lists the libraries required by Armadillo on this system, and diff --git a/Include/armadillo/armadillo_bits/def_lapack.hpp b/Include/armadillo/armadillo_bits/def_lapack.hpp index 00854ab09..ad4e50c2c 100644 --- a/Include/armadillo/armadillo_bits/def_lapack.hpp +++ b/Include/armadillo/armadillo_bits/def_lapack.hpp @@ -269,6 +269,30 @@ #define arma_cpstrf cpstrf #define arma_zpstrf zpstrf + #define arma_ssytrf ssytrf + #define arma_dsytrf dsytrf + + #define arma_chetrf chetrf + #define arma_zhetrf zhetrf + + #define arma_ssytrs ssytrs + #define arma_dsytrs dsytrs + + #define arma_chetrs chetrs + #define arma_zhetrs zhetrs + + #define arma_ssytri ssytri + #define arma_dsytri dsytri + + #define arma_chetri chetri + #define arma_zhetri zhetri + + #define arma_ssycon ssycon + #define arma_dsycon dsycon + + #define arma_checon checon + #define arma_zhecon zhecon + #else #define arma_sgetrf SGETRF @@ -505,6 +529,30 @@ #define arma_cpstrf CPSTRF #define arma_zpstrf ZPSTRF + #define arma_ssytrf SSYTRF + #define arma_dsytrf DSYTRF + + #define arma_chetrf CHETRF + #define arma_zhetrf ZHETRF + + #define arma_ssytrs SSYTRS + #define arma_dsytrs DSYTRS + + #define arma_chetrs CHETRS + #define arma_zhetrs ZHETRS + + #define arma_ssytri SSYTRI + #define arma_dsytri DSYTRI + + #define arma_chetri CHETRI + #define arma_zhetri ZHETRI + + #define arma_ssycon SSYCON + #define arma_dsycon DSYCON + + #define arma_checon CHECON + #define arma_zhecon ZHECON + #endif @@ -846,6 +894,38 @@ extern "C" void arma_fortran(arma_cpstrf)(const char* uplo, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* piv, blas_int* rank, const float* tol, float* work, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; void arma_fortran(arma_zpstrf)(const char* uplo, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* piv, blas_int* rank, const double* tol, double* work, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + // factorisation of symmetric matrix (real) + void arma_fortran(arma_ssytrf)(const char* uplo, const blas_int* n, float* a, const blas_int* lda, blas_int* ipiv, float* work, const blas_int* lwork, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dsytrf)(const char* uplo, const blas_int* n, double* a, const blas_int* lda, blas_int* ipiv, double* work, const blas_int* lwork, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + + // factorisation of hermitian matrix (complex) + void arma_fortran(arma_chetrf)(const char* uplo, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* ipiv, blas_cxf* work, const blas_int* lwork, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zhetrf)(const char* uplo, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* ipiv, blas_cxd* work, const blas_int* lwork, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + + // solve system using pre-computed factorisation (real) + void arma_fortran(arma_ssytrs)(const char* uplo, const blas_int* n, const blas_int* nrhs, const float* a, const blas_int* lda, const blas_int* ipiv, float* b, const blas_int* ldb, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dsytrs)(const char* uplo, const blas_int* n, const blas_int* nrhs, const double* a, const blas_int* lda, const blas_int* ipiv, double* b, const blas_int* ldb, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + + // solve system using pre-computed factorisation (complex) + void arma_fortran(arma_chetrs)(const char* uplo, const blas_int* n, const blas_int* nrhs, const blas_cxf* a, const blas_int* lda, const blas_int* ipiv, blas_cxf* b, const blas_int* ldb, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zhetrs)(const char* uplo, const blas_int* n, const blas_int* nrhs, const blas_cxd* a, const blas_int* lda, const blas_int* ipiv, blas_cxd* b, const blas_int* ldb, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + + // inverse of symmetric matrix using pre-computed factorisation (real) + void arma_fortran(arma_ssytri)(const char* uplo, const blas_int* n, float* a, const blas_int* lda, const blas_int* ipiv, float* work, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dsytri)(const char* uplo, const blas_int* n, double* a, const blas_int* lda, const blas_int* ipiv, double* work, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + + // inverse of hermitian matrix using pre-computed factorisation (complex) + void arma_fortran(arma_chetri)(const char* uplo, const blas_int* n, blas_cxf* a, const blas_int* lda, const blas_int* ipiv, blas_cxf* work, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zhetri)(const char* uplo, const blas_int* n, blas_cxd* a, const blas_int* lda, const blas_int* ipiv, blas_cxd* work, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + + // rcond of symmetric matrix using pre-computed factorisation (real) + void arma_fortran(arma_ssycon)(const char* uplo, const blas_int* n, float* a, const blas_int* lda, const blas_int* ipiv, const float* anorm, float* rcond, float* work, blas_int* iwork, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dsycon)(const char* uplo, const blas_int* n, double* a, const blas_int* lda, const blas_int* ipiv, const double* anorm, double* rcond, double* work, blas_int* iwork, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + + // rcond of hermitian matrix using pre-computed factorisation (complex) + void arma_fortran(arma_checon)(const char* uplo, const blas_int* n, blas_cxf* a, const blas_int* lda, const blas_int* ipiv, const float* anorm, float* rcond, blas_cxf* work, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zhecon)(const char* uplo, const blas_int* n, blas_cxd* a, const blas_int* lda, const blas_int* ipiv, const double* anorm, double* rcond, blas_cxd* work, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + #else // prototypes without hidden arguments @@ -1170,6 +1250,38 @@ extern "C" void arma_fortran(arma_cpstrf)(const char* uplo, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* piv, blas_int* rank, const float* tol, float* work, blas_int* info) ARMA_NOEXCEPT; void arma_fortran(arma_zpstrf)(const char* uplo, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* piv, blas_int* rank, const double* tol, double* work, blas_int* info) ARMA_NOEXCEPT; + // factorisation of symmetric matrix (real) + void arma_fortran(arma_ssytrf)(const char* uplo, const blas_int* n, float* a, const blas_int* lda, blas_int* ipiv, float* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dsytrf)(const char* uplo, const blas_int* n, double* a, const blas_int* lda, blas_int* ipiv, double* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + + // factorisation of hermitian matrix (complex) + void arma_fortran(arma_chetrf)(const char* uplo, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* ipiv, blas_cxf* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zhetrf)(const char* uplo, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* ipiv, blas_cxd* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + + // solve system using pre-computed factorisation (real) + void arma_fortran(arma_ssytrs)(const char* uplo, const blas_int* n, const blas_int* nrhs, const float* a, const blas_int* lda, const blas_int* ipiv, float* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dsytrs)(const char* uplo, const blas_int* n, const blas_int* nrhs, const double* a, const blas_int* lda, const blas_int* ipiv, double* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + + // solve system using pre-computed factorisation (complex) + void arma_fortran(arma_zhetrs)(const char* uplo, const blas_int* n, const blas_int* nrhs, const blas_cxd* a, const blas_int* lda, const blas_int* ipiv, blas_cxd* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_chetrs)(const char* uplo, const blas_int* n, const blas_int* nrhs, const blas_cxf* a, const blas_int* lda, const blas_int* ipiv, blas_cxf* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + + // inverse of symmetric matrix using pre-computed factorisation (real) + void arma_fortran(arma_ssytri)(const char* uplo, const blas_int* n, float* a, const blas_int* lda, const blas_int* ipiv, float* work, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dsytri)(const char* uplo, const blas_int* n, double* a, const blas_int* lda, const blas_int* ipiv, double* work, blas_int* info) ARMA_NOEXCEPT; + + // inverse of hermitian matrix using pre-computed factorisation (complex) + void arma_fortran(arma_chetri)(const char* uplo, const blas_int* n, blas_cxf* a, const blas_int* lda, const blas_int* ipiv, blas_cxf* work, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zhetri)(const char* uplo, const blas_int* n, blas_cxd* a, const blas_int* lda, const blas_int* ipiv, blas_cxd* work, blas_int* info) ARMA_NOEXCEPT; + + // rcond of symmetric matrix using pre-computed factorisation (real) + void arma_fortran(arma_ssycon)(const char* uplo, const blas_int* n, float* a, const blas_int* lda, const blas_int* ipiv, const float* anorm, float* rcond, float* work, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dsycon)(const char* uplo, const blas_int* n, double* a, const blas_int* lda, const blas_int* ipiv, const double* anorm, double* rcond, double* work, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + + // rcond of hermitian matrix using pre-computed factorisation (complex) + void arma_fortran(arma_checon)(const char* uplo, const blas_int* n, blas_cxf* a, const blas_int* lda, const blas_int* ipiv, const float* anorm, float* rcond, blas_cxf* work, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zhecon)(const char* uplo, const blas_int* n, blas_cxd* a, const blas_int* lda, const blas_int* ipiv, const double* anorm, double* rcond, blas_cxd* work, blas_int* info) ARMA_NOEXCEPT; + #endif } diff --git a/Include/armadillo/armadillo_bits/diagview_bones.hpp b/Include/armadillo/armadillo_bits/diagview_bones.hpp index 5aa4bcee7..4117aa45e 100644 --- a/Include/armadillo/armadillo_bits/diagview_bones.hpp +++ b/Include/armadillo/armadillo_bits/diagview_bones.hpp @@ -108,6 +108,9 @@ class diagview : public Base< eT, diagview > inline static void schur_inplace(Mat& out, const diagview& in); inline static void div_inplace(Mat& out, const diagview& in); + template + inline bool is_alias(const Mat& X) const; + friend class Mat; friend class subview; diff --git a/Include/armadillo/armadillo_bits/diagview_meat.hpp b/Include/armadillo/armadillo_bits/diagview_meat.hpp index c8d741ab7..a33b2ecb2 100644 --- a/Include/armadillo/armadillo_bits/diagview_meat.hpp +++ b/Include/armadillo/armadillo_bits/diagview_meat.hpp @@ -236,13 +236,13 @@ diagview::operator= (const Base& o) "diagview: given object has incompatible size" ); - const bool is_alias = P.is_alias(d_m); + const bool have_alias = P.is_alias(d_m); - if(is_alias) { arma_debug_print("aliasing detected"); } + if(have_alias) { arma_debug_print("aliasing detected"); } - if( (is_Mat::stored_type>::value) || (Proxy::use_at) || (is_alias) ) + if( (is_Mat::stored_type>::value) || (Proxy::use_at) || (have_alias) ) { - const unwrap_check::stored_type> tmp(P.Q, is_alias); + const unwrap_check::stored_type> tmp(P.Q, have_alias); const Mat& x = tmp.M; const eT* x_mem = x.memptr(); @@ -309,13 +309,13 @@ diagview::operator+=(const Base& o) "diagview: given object has incompatible size" ); - const bool is_alias = P.is_alias(d_m); + const bool have_alias = P.is_alias(d_m); - if(is_alias) { arma_debug_print("aliasing detected"); } + if(have_alias) { arma_debug_print("aliasing detected"); } - if( (is_Mat::stored_type>::value) || (Proxy::use_at) || (is_alias) ) + if( (is_Mat::stored_type>::value) || (Proxy::use_at) || (have_alias) ) { - const unwrap_check::stored_type> tmp(P.Q, is_alias); + const unwrap_check::stored_type> tmp(P.Q, have_alias); const Mat& x = tmp.M; const eT* x_mem = x.memptr(); @@ -382,13 +382,13 @@ diagview::operator-=(const Base& o) "diagview: given object has incompatible size" ); - const bool is_alias = P.is_alias(d_m); + const bool have_alias = P.is_alias(d_m); - if(is_alias) { arma_debug_print("aliasing detected"); } + if(have_alias) { arma_debug_print("aliasing detected"); } - if( (is_Mat::stored_type>::value) || (Proxy::use_at) || (is_alias) ) + if( (is_Mat::stored_type>::value) || (Proxy::use_at) || (have_alias) ) { - const unwrap_check::stored_type> tmp(P.Q, is_alias); + const unwrap_check::stored_type> tmp(P.Q, have_alias); const Mat& x = tmp.M; const eT* x_mem = x.memptr(); @@ -455,13 +455,13 @@ diagview::operator%=(const Base& o) "diagview: given object has incompatible size" ); - const bool is_alias = P.is_alias(d_m); + const bool have_alias = P.is_alias(d_m); - if(is_alias) { arma_debug_print("aliasing detected"); } + if(have_alias) { arma_debug_print("aliasing detected"); } - if( (is_Mat::stored_type>::value) || (Proxy::use_at) || (is_alias) ) + if( (is_Mat::stored_type>::value) || (Proxy::use_at) || (have_alias) ) { - const unwrap_check::stored_type> tmp(P.Q, is_alias); + const unwrap_check::stored_type> tmp(P.Q, have_alias); const Mat& x = tmp.M; const eT* x_mem = x.memptr(); @@ -528,13 +528,13 @@ diagview::operator/=(const Base& o) "diagview: given object has incompatible size" ); - const bool is_alias = P.is_alias(d_m); + const bool have_alias = P.is_alias(d_m); - if(is_alias) { arma_debug_print("aliasing detected"); } + if(have_alias) { arma_debug_print("aliasing detected"); } - if( (is_Mat::stored_type>::value) || (Proxy::use_at) || (is_alias) ) + if( (is_Mat::stored_type>::value) || (Proxy::use_at) || (have_alias) ) { - const unwrap_check::stored_type> tmp(P.Q, is_alias); + const unwrap_check::stored_type> tmp(P.Q, have_alias); const Mat& x = tmp.M; const eT* x_mem = x.memptr(); @@ -1022,4 +1022,17 @@ diagview::randn() +template +template +inline +bool +diagview::is_alias(const Mat& X) const + { + arma_debug_sigprint(); + + return m.is_alias(X); + } + + + //! @} diff --git a/Include/armadillo/armadillo_bits/eGlue_bones.hpp b/Include/armadillo/armadillo_bits/eGlue_bones.hpp index 097dc6cb8..a86377a98 100644 --- a/Include/armadillo/armadillo_bits/eGlue_bones.hpp +++ b/Include/armadillo/armadillo_bits/eGlue_bones.hpp @@ -44,6 +44,9 @@ class eGlue : public Base< typename T1::elem_type, eGlue > arma_inline ~eGlue(); arma_inline eGlue(const T1& in_A, const T2& in_B); + template + inline bool is_alias(const Mat& X) const; + arma_inline uword get_n_rows() const; arma_inline uword get_n_cols() const; arma_inline uword get_n_elem() const; diff --git a/Include/armadillo/armadillo_bits/eGlue_meat.hpp b/Include/armadillo/armadillo_bits/eGlue_meat.hpp index 04eb6ba23..4d55bc788 100644 --- a/Include/armadillo/armadillo_bits/eGlue_meat.hpp +++ b/Include/armadillo/armadillo_bits/eGlue_meat.hpp @@ -49,6 +49,17 @@ eGlue::eGlue(const T1& in_A, const T2& in_B) +template +template +inline +bool +eGlue::is_alias(const Mat& X) const + { + return (P1.is_alias(X) || P2.is_alias(X)); + } + + + template arma_inline uword diff --git a/Include/armadillo/armadillo_bits/eOp_bones.hpp b/Include/armadillo/armadillo_bits/eOp_bones.hpp index d32abddbd..a200a6bb6 100644 --- a/Include/armadillo/armadillo_bits/eOp_bones.hpp +++ b/Include/armadillo/armadillo_bits/eOp_bones.hpp @@ -50,6 +50,9 @@ class eOp : public Base< typename T1::elem_type, eOp > inline eOp(const T1& in_m, const uword in_aux_uword_a, const uword in_aux_uword_b); inline eOp(const T1& in_m, const elem_type in_aux, const uword in_aux_uword_a, const uword in_aux_uword_b); + template + inline bool is_alias(const Mat& X) const; + arma_inline uword get_n_rows() const; arma_inline uword get_n_cols() const; arma_inline uword get_n_elem() const; diff --git a/Include/armadillo/armadillo_bits/eOp_meat.hpp b/Include/armadillo/armadillo_bits/eOp_meat.hpp index 75dfec022..57473fe78 100644 --- a/Include/armadillo/armadillo_bits/eOp_meat.hpp +++ b/Include/armadillo/armadillo_bits/eOp_meat.hpp @@ -74,7 +74,20 @@ eOp::~eOp() arma_debug_sigprint(); } + + +template +template +inline +bool +eOp::is_alias(const Mat& X) const + { + arma_debug_sigprint(); + return P.is_alias(X); + } + + template arma_inline diff --git a/Include/armadillo/armadillo_bits/fn_misc.hpp b/Include/armadillo/armadillo_bits/fn_misc.hpp index d441b26da..52d85ac25 100644 --- a/Include/armadillo/armadillo_bits/fn_misc.hpp +++ b/Include/armadillo/armadillo_bits/fn_misc.hpp @@ -147,46 +147,6 @@ logspace(const double A, const double B, const uword N = 50u) -// -// log_exp_add - -template -arma_warn_unused -inline -typename arma_real_only::result -log_add_exp(eT log_a, eT log_b) - { - if(log_a < log_b) - { - std::swap(log_a, log_b); - } - - const eT negdelta = log_b - log_a; - - if( (negdelta < Datum::log_min) || (arma_isfinite(negdelta) == false) ) - { - return log_a; - } - else - { - return (log_a + std::log1p(std::exp(negdelta))); - } - } - - - -// for compatibility with earlier versions -template -arma_warn_unused -inline -typename arma_real_only::result -log_add(eT log_a, eT log_b) - { - return log_add_exp(log_a, log_b); - } - - - //! kept for compatibility with old user code template arma_warn_unused @@ -584,4 +544,54 @@ affmul(const T1& A, const T2& B) +namespace priv + { + // internal use only + template + arma_warn_unused + inline + typename arma_real_only::result + internal_log_add_exp(eT log_a, eT log_b) + { + if(log_a < log_b) { std::swap(log_a, log_b); } + + const eT negdelta = log_b - log_a; + + if( (negdelta < Datum::log_min) || (arma_isfinite(negdelta) == false) ) + { + return log_a; + } + else + { + return (log_a + std::log1p(std::exp(negdelta))); + } + } + } + + + +// DO NOT USE; kept only for compatibility with old user code +template +arma_deprecated +inline +typename arma_real_only::result +log_add_exp(eT log_a, eT log_b) + { + return priv::internal_log_add_exp(log_a, log_b); + } + + + +// DO NOT USE; kept only for compatibility with old user code +template +arma_deprecated +inline +typename arma_real_only::result +log_add(eT log_a, eT log_b) + { + return priv::internal_log_add_exp(log_a, log_b); + } + + + //! @} diff --git a/Include/armadillo/armadillo_bits/glue_solve_bones.hpp b/Include/armadillo/armadillo_bits/glue_solve_bones.hpp index 20c016591..c04b17a6b 100644 --- a/Include/armadillo/armadillo_bits/glue_solve_bones.hpp +++ b/Include/armadillo/armadillo_bits/glue_solve_bones.hpp @@ -140,6 +140,7 @@ namespace solve_opts static constexpr uword flag_refine = uword(1u << 9); static constexpr uword flag_no_trimat = uword(1u << 10); static constexpr uword flag_force_approx = uword(1u << 11); + static constexpr uword flag_force_sym = uword(1u << 12); struct opts_none : public opts { inline constexpr opts_none() : opts(flag_none ) {} }; struct opts_fast : public opts { inline constexpr opts_fast() : opts(flag_fast ) {} }; @@ -154,6 +155,7 @@ namespace solve_opts struct opts_refine : public opts { inline constexpr opts_refine() : opts(flag_refine ) {} }; struct opts_no_trimat : public opts { inline constexpr opts_no_trimat() : opts(flag_no_trimat ) {} }; struct opts_force_approx : public opts { inline constexpr opts_force_approx() : opts(flag_force_approx) {} }; + struct opts_force_sym : public opts { inline constexpr opts_force_sym() : opts(flag_force_sym ) {} }; static constexpr opts_none none; static constexpr opts_fast fast; @@ -168,6 +170,7 @@ namespace solve_opts static constexpr opts_refine refine; static constexpr opts_no_trimat no_trimat; static constexpr opts_force_approx force_approx; + static constexpr opts_force_sym force_sym; } diff --git a/Include/armadillo/armadillo_bits/glue_solve_meat.hpp b/Include/armadillo/armadillo_bits/glue_solve_meat.hpp index f50bcc401..6aeb67298 100644 --- a/Include/armadillo/armadillo_bits/glue_solve_meat.hpp +++ b/Include/armadillo/armadillo_bits/glue_solve_meat.hpp @@ -99,6 +99,7 @@ glue_solve_gen_full::apply(Mat& actual_out, const Base& A_expr, const const bool refine = has_user_flags && bool(flags & solve_opts::flag_refine ); const bool no_trimat = has_user_flags && bool(flags & solve_opts::flag_no_trimat ); const bool force_approx = has_user_flags && bool(flags & solve_opts::flag_force_approx); + const bool force_sym = has_user_flags && bool(flags & solve_opts::flag_force_sym ); if(has_user_flags) { @@ -114,10 +115,11 @@ glue_solve_gen_full::apply(Mat& actual_out, const Base& A_expr, const if(refine ) { arma_debug_print("refine"); } if(no_trimat ) { arma_debug_print("no_trimat"); } if(force_approx) { arma_debug_print("force_approx"); } + if(force_sym ) { arma_debug_print("force_sym"); } - arma_conform_check( (fast && equilibrate ), "solve(): options 'fast' and 'equilibrate' are mutually exclusive" ); - arma_conform_check( (fast && refine ), "solve(): options 'fast' and 'refine' are mutually exclusive" ); - arma_conform_check( (no_sympd && likely_sympd), "solve(): options 'no_sympd' and 'likely_sympd' are mutually exclusive" ); + arma_conform_check( (fast && equilibrate ), "solve(): options 'fast' and 'equilibrate' are mutually exclusive" ); + arma_conform_check( (fast && refine ), "solve(): options 'fast' and 'refine' are mutually exclusive" ); + arma_conform_check( (no_sympd && likely_sympd), "solve(): options 'no_sympd' and 'likely_sympd' are mutually exclusive" ); } Mat A = A_expr.get_ref(); @@ -128,26 +130,34 @@ glue_solve_gen_full::apply(Mat& actual_out, const Base& A_expr, const arma_conform_check( no_approx, "solve(): options 'no_approx' and 'force_approx' are mutually exclusive" ); - if(fast) { arma_warn(2, "solve(): option 'fast' ignored for forced approximate solution" ); } - if(equilibrate) { arma_warn(2, "solve(): option 'equilibrate' ignored for forced approximate solution" ); } - if(refine) { arma_warn(2, "solve(): option 'refine' ignored for forced approximate solution" ); } - if(likely_sympd) { arma_warn(2, "solve(): option 'likely_sympd' ignored for forced approximate solution" ); } + if(fast) { arma_warn(2, "solve(): option 'fast' ignored for forced approximate solution" ); } + if(equilibrate) { arma_warn(2, "solve(): option 'equilibrate' ignored for forced approximate solution" ); } + if(refine) { arma_warn(2, "solve(): option 'refine' ignored for forced approximate solution" ); } + if(likely_sympd) { arma_warn(2, "solve(): option 'likely_sympd' ignored for forced approximate solution" ); } + if(force_sym) { arma_warn(2, "solve(): option 'force_sym' ignored for forced approximate solution" ); } return auxlib::solve_approx_svd(actual_out, A, B_expr.get_ref()); // A is overwritten } + if(force_sym) + { + if((arma_config::check_conform) && (auxlib::rudimentary_sym_check(A) == false)) + { + if(is_cx::no ) { arma_warn(1, "solve(): option 'force_sym' enabled, but given matrix is not symmetric"); } + if(is_cx::yes) { arma_warn(1, "solve(): option 'force_sym' enabled, but given matrix is not hermitian"); } + } + + if(likely_sympd) { arma_warn(2, "solve(): option 'likely_sympd' ignored for forced symmetric solver" ); } + if(equilibrate) { arma_warn(2, "solve(): option 'force_sym' ignored as option 'equilibrate' is enabled (combination not implemented yet)" ); } + if(refine) { arma_warn(2, "solve(): option 'force_sym' ignored as option 'refine' is enabled (combination not implemented yet)" ); } + } + // A_expr and B_expr can be used more than once (sympd optimisation fails or approximate solution required), // so ensure they are not overwritten in case we have aliasing - bool is_alias = true; // assume we have aliasing until we can prove otherwise + const bool is_alias = A_expr.get_ref().is_alias(actual_out) || B_expr.get_ref().is_alias(actual_out); - if(is_Mat::value && is_Mat::value) - { - const quasi_unwrap UA( A_expr.get_ref() ); - const quasi_unwrap UB( B_expr.get_ref() ); - - is_alias = UA.is_alias(actual_out) || UB.is_alias(actual_out); - } + if(is_alias) { arma_debug_print("glue_solve_gen_full::apply(): aliasing detected"); } Mat tmp; Mat& out = (is_alias) ? tmp : actual_out; @@ -162,12 +172,13 @@ glue_solve_gen_full::apply(Mat& actual_out, const Base& A_expr, const uword KL = 0; uword KU = 0; - const bool is_band = arma_config::optimise_band && ((no_band || auxlib::crippled_lapack(A)) ? false : band_helper::is_band(KL, KU, A, uword(32))); + const bool is_band = arma_config::optimise_band && ( (no_band || force_sym || auxlib::crippled_lapack(A)) ? false : band_helper::is_band(KL, KU, A, uword(32)) ); - const bool is_triu = (no_trimat || refine || equilibrate || likely_sympd || is_band ) ? false : trimat_helper::is_triu(A); - const bool is_tril = (no_trimat || refine || equilibrate || likely_sympd || is_band || is_triu) ? false : trimat_helper::is_tril(A); + const bool is_triu = (no_trimat || refine || equilibrate || likely_sympd || force_sym || is_band ) ? false : trimat_helper::is_triu(A); + const bool is_tril = (no_trimat || refine || equilibrate || likely_sympd || force_sym || is_band || is_triu) ? false : trimat_helper::is_tril(A); - const bool try_sympd = arma_config::optimise_sym && ((no_sympd || auxlib::crippled_lapack(A) || is_band || is_triu || is_tril) ? false : (likely_sympd ? true : sym_helper::guess_sympd(A, uword(16)))); + const bool is_sym = arma_config::optimise_sym && ( (refine || equilibrate || likely_sympd || force_sym || is_band || is_triu || is_tril || auxlib::crippled_lapack(A)) ? false : is_sym_expr::eval(A_expr.get_ref()) ); + const bool try_sympd = arma_config::optimise_sym && ( ( no_sympd || is_sym || force_sym || is_band || is_triu || is_tril || auxlib::crippled_lapack(A)) ? false : (likely_sympd ? true : sym_helper::guess_sympd(A, uword(16))) ); if(fast) { @@ -201,6 +212,13 @@ glue_solve_gen_full::apply(Mat& actual_out, const Base& A_expr, const status = auxlib::solve_trimat_fast(out, A, B_expr.get_ref(), layout); } else + if(force_sym || is_sym) + { + arma_debug_print("glue_solve_gen_full::apply(): fast + sym"); + + status = auxlib::solve_sym_fast(out, A, B_expr.get_ref()); // A is overwritten + } + else if(try_sympd) { arma_debug_print("glue_solve_gen_full::apply(): fast + try_sympd"); @@ -238,6 +256,10 @@ glue_solve_gen_full::apply(Mat& actual_out, const Base& A_expr, const status = auxlib::solve_band_refine(out, rcond, A, KL, KU, B_expr, equilibrate); } + // else + // if(force_sym || is_sym) // TODO: implement auxlib::solve_sym_refine() + // { + // } else if(try_sympd) { @@ -287,6 +309,13 @@ glue_solve_gen_full::apply(Mat& actual_out, const Base& A_expr, const status = auxlib::solve_trimat_rcond(out, rcond, A, B_expr.get_ref(), layout); } else + if(force_sym || is_sym) + { + arma_debug_print("glue_solve_gen_full::apply(): rcond + sym"); + + status = auxlib::solve_sym_rcond(out, rcond, A, B_expr.get_ref()); // A is overwritten + } + else if(try_sympd) { bool sympd_state = false; @@ -315,6 +344,7 @@ glue_solve_gen_full::apply(Mat& actual_out, const Base& A_expr, const if(equilibrate) { arma_warn(2, "solve(): option 'equilibrate' ignored for non-square matrix" ); } if(refine) { arma_warn(2, "solve(): option 'refine' ignored for non-square matrix" ); } if(likely_sympd) { arma_warn(2, "solve(): option 'likely_sympd' ignored for non-square matrix" ); } + if(force_sym) { arma_warn(2, "solve(): option 'force_sym' ignored for non-square matrix" ); } if(fast) { @@ -406,14 +436,9 @@ glue_solve_tri_default::apply(Mat& actual_out, const Base& A_expr, co const uword layout = (triu) ? uword(0) : uword(1); - bool is_alias = true; + const bool is_alias = A_expr.get_ref().is_alias(actual_out) || B_expr.get_ref().is_alias(actual_out); - if(is_Mat::value) - { - const quasi_unwrap UB(B_expr.get_ref()); - - is_alias = UA.is_alias(actual_out) || UB.is_alias(actual_out); - } + if(is_alias) { arma_debug_print("glue_solve_tri_default::apply(): aliasing detected"); } T rcond = T(0); bool status = false; @@ -497,6 +522,7 @@ glue_solve_tri_full::apply(Mat& actual_out, const Base& A_expr, const const bool refine = bool(flags & solve_opts::flag_refine ); const bool no_trimat = bool(flags & solve_opts::flag_no_trimat ); const bool force_approx = bool(flags & solve_opts::flag_force_approx); + const bool force_sym = bool(flags & solve_opts::flag_force_sym ); arma_debug_print("glue_solve_tri_full::apply(): enabled flags:"); @@ -510,6 +536,10 @@ glue_solve_tri_full::apply(Mat& actual_out, const Base& A_expr, const if(refine ) { arma_debug_print("refine"); } if(no_trimat ) { arma_debug_print("no_trimat"); } if(force_approx) { arma_debug_print("force_approx"); } + if(force_sym ) { arma_debug_print("force_sym"); } + + arma_conform_check( (likely_sympd), "solve(): option 'likely_sympd' not applicable to triangular matrix" ); + arma_conform_check( (force_sym ), "solve(): option 'force_sym' not applicable to triangular matrix" ); if(no_trimat || equilibrate || refine || force_approx) { @@ -518,8 +548,6 @@ glue_solve_tri_full::apply(Mat& actual_out, const Base& A_expr, const return glue_solve_gen_full::apply(actual_out, ((triu) ? trimatu(A_expr.get_ref()) : trimatl(A_expr.get_ref())), B_expr, (flags & mask)); } - if(likely_sympd) { arma_warn(2, "solve(): option 'likely_sympd' ignored for triangular matrix"); } - const quasi_unwrap UA(A_expr.get_ref()); const Mat& A = UA.M; @@ -527,14 +555,9 @@ glue_solve_tri_full::apply(Mat& actual_out, const Base& A_expr, const const uword layout = (triu) ? uword(0) : uword(1); - bool is_alias = true; + const bool is_alias = A_expr.get_ref().is_alias(actual_out) || B_expr.get_ref().is_alias(actual_out); - if(is_Mat::value) - { - const quasi_unwrap UB(B_expr.get_ref()); - - is_alias = UA.is_alias(actual_out) || UB.is_alias(actual_out); - } + if(is_alias) { arma_debug_print("glue_solve_tri_full::apply(): aliasing detected"); } T rcond = T(0); bool status = false; diff --git a/Include/armadillo/armadillo_bits/glue_times_meat.hpp b/Include/armadillo/armadillo_bits/glue_times_meat.hpp index 9b4020faa..92d695996 100644 --- a/Include/armadillo/armadillo_bits/glue_times_meat.hpp +++ b/Include/armadillo/armadillo_bits/glue_times_meat.hpp @@ -119,7 +119,9 @@ glue_times_redirect2_helper::apply(Mat& out, const arma_conform_assert_mul_size(A, B, "matrix multiplication"); - const bool status = (strip_inv::do_inv_spd) ? auxlib::solve_sympd_fast(out, A, B) : auxlib::solve_square_fast(out, A, B); + const bool is_sym = (strip_inv::do_inv_spd) ? false : ( arma_config::optimise_sym && (auxlib::crippled_lapack(A) == false) && (is_sym_expr::eval(X.A) || sym_helper::is_approx_sym(A, uword(100))) ); + + const bool status = (strip_inv::do_inv_spd) ? auxlib::solve_sympd_fast(out, A, B) : ( (is_sym) ? auxlib::solve_sym_fast(out, A, B) : auxlib::solve_square_fast(out, A, B) ); if(status == false) { @@ -278,7 +280,9 @@ glue_times_redirect3_helper::apply(Mat& out, const if(is_cx::yes) { arma_warn(1, "inv_sympd(): given matrix is not hermitian"); } } - const bool status = (strip_inv::do_inv_spd) ? auxlib::solve_sympd_fast(out, A, BC) : auxlib::solve_square_fast(out, A, BC); + const bool is_sym = (strip_inv::do_inv_spd) ? false : ( arma_config::optimise_sym && (auxlib::crippled_lapack(A) == false) && (is_sym_expr::eval(X.A.A) || sym_helper::is_approx_sym(A, uword(100))) ); + + const bool status = (strip_inv::do_inv_spd) ? auxlib::solve_sympd_fast(out, A, BC) : ( (is_sym) ? auxlib::solve_sym_fast(out, A, BC) : auxlib::solve_square_fast(out, A, BC) ); if(status == false) { @@ -315,7 +319,9 @@ glue_times_redirect3_helper::apply(Mat& out, const Mat solve_result; - const bool status = (strip_inv::do_inv_spd) ? auxlib::solve_sympd_fast(solve_result, B, C) : auxlib::solve_square_fast(solve_result, B, C); + const bool is_sym = (strip_inv::do_inv_spd) ? false : ( arma_config::optimise_sym && (auxlib::crippled_lapack(B) == false) && (is_sym_expr::eval(X.A.B) || sym_helper::is_approx_sym(B, uword(100))) ); + + const bool status = (strip_inv::do_inv_spd) ? auxlib::solve_sympd_fast(solve_result, B, C) : ( (is_sym) ? auxlib::solve_sym_fast(solve_result, B, C) : auxlib::solve_square_fast(solve_result, B, C) ); if(status == false) { @@ -552,7 +558,7 @@ glue_times::apply_inplace_plus(Mat& out, const Glue::do_times || partial_unwrap_check::do_times || (sign < sword(0)); - const eT alpha = use_alpha ? ( tmp1.get_val() * tmp2.get_val() * ( (sign > sword(0)) ? eT(1) : eT(-1) ) ) : eT(0); + const eT alpha = use_alpha ? ( tmp1.get_val() * tmp2.get_val() * ( (sign > sword(0)) ? eT(1) : eT(-1) ) ) : eT(0); arma_conform_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiplication"); diff --git a/Include/armadillo/armadillo_bits/gmm_diag_meat.hpp b/Include/armadillo/armadillo_bits/gmm_diag_meat.hpp index 185f892a3..37048c6d4 100644 --- a/Include/armadillo/armadillo_bits/gmm_diag_meat.hpp +++ b/Include/armadillo/armadillo_bits/gmm_diag_meat.hpp @@ -1070,7 +1070,7 @@ gmm_diag::internal_scalar_log_p(const eT* x) const { const eT tmp = internal_scalar_log_p(x, g) + log_hefts_mem[g]; - log_sum = log_add_exp(log_sum, tmp); + log_sum = priv::internal_log_add_exp(log_sum, tmp); } return log_sum; @@ -2556,7 +2556,7 @@ gmm_diag::em_generate_acc for(uword g=1; g < N_gaus; ++g) { - log_lhood_sum = log_add_exp(log_lhood_sum, gaus_log_lhoods_mem[g]); + log_lhood_sum = priv::internal_log_add_exp(log_lhood_sum, gaus_log_lhoods_mem[g]); } progress_log_lhood += log_lhood_sum; diff --git a/Include/armadillo/armadillo_bits/gmm_full_meat.hpp b/Include/armadillo/armadillo_bits/gmm_full_meat.hpp index 6e9d1f7d1..b81c823c4 100644 --- a/Include/armadillo/armadillo_bits/gmm_full_meat.hpp +++ b/Include/armadillo/armadillo_bits/gmm_full_meat.hpp @@ -1087,7 +1087,7 @@ gmm_full::internal_scalar_log_p(const eT* x) const { const eT log_val = internal_scalar_log_p(x, g) + log_hefts_mem[g]; - log_sum = log_add_exp(log_sum, log_val); + log_sum = priv::internal_log_add_exp(log_sum, log_val); } return log_sum; @@ -2618,7 +2618,7 @@ gmm_full::em_generate_acc for(uword g=1; g < N_gaus; ++g) { - log_lhood_sum = log_add_exp(log_lhood_sum, gaus_log_lhoods_mem[g]); + log_lhood_sum = priv::internal_log_add_exp(log_lhood_sum, gaus_log_lhoods_mem[g]); } progress_log_lhood += log_lhood_sum; diff --git a/Include/armadillo/armadillo_bits/mtGlue_bones.hpp b/Include/armadillo/armadillo_bits/mtGlue_bones.hpp index 5937d89f7..fc15a6f92 100644 --- a/Include/armadillo/armadillo_bits/mtGlue_bones.hpp +++ b/Include/armadillo/armadillo_bits/mtGlue_bones.hpp @@ -37,6 +37,9 @@ class mtGlue : public Base< out_eT, mtGlue > arma_inline mtGlue(const T1& in_A, const T2& in_B, const uword in_aux_uword); arma_inline ~mtGlue(); + template + inline bool is_alias(const Mat& X) const; + arma_aligned const T1& A; //!< first operand; must be derived from Base arma_aligned const T2& B; //!< second operand; must be derived from Base arma_aligned uword aux_uword; //!< storage of auxiliary data, uword format diff --git a/Include/armadillo/armadillo_bits/mtGlue_meat.hpp b/Include/armadillo/armadillo_bits/mtGlue_meat.hpp index 85cc9a219..4a64840c8 100644 --- a/Include/armadillo/armadillo_bits/mtGlue_meat.hpp +++ b/Include/armadillo/armadillo_bits/mtGlue_meat.hpp @@ -53,4 +53,17 @@ mtGlue::~mtGlue() +template +template +inline +bool +mtGlue::is_alias(const Mat& X) const + { + arma_debug_sigprint(); + + return (A.is_alias(X) || B.is_alias(X)); + } + + + //! @} diff --git a/Include/armadillo/armadillo_bits/mtOp_bones.hpp b/Include/armadillo/armadillo_bits/mtOp_bones.hpp index ff0e4c38f..87e74573b 100644 --- a/Include/armadillo/armadillo_bits/mtOp_bones.hpp +++ b/Include/armadillo/armadillo_bits/mtOp_bones.hpp @@ -47,7 +47,10 @@ class mtOp : public Base< out_eT, mtOp > inline mtOp(const mtOp_dual_aux_indicator&, const T1& in_m, const in_eT in_aux_a, const out_eT in_aux_b); inline ~mtOp(); - + + template + inline bool is_alias(const Mat& X) const; + arma_aligned const T1& m; //!< the operand; must be derived from Base arma_aligned in_eT aux; //!< auxiliary data, using the element type as used by T1 diff --git a/Include/armadillo/armadillo_bits/mtOp_meat.hpp b/Include/armadillo/armadillo_bits/mtOp_meat.hpp index 623660c4d..f53816dfa 100644 --- a/Include/armadillo/armadillo_bits/mtOp_meat.hpp +++ b/Include/armadillo/armadillo_bits/mtOp_meat.hpp @@ -101,4 +101,17 @@ mtOp::~mtOp() +template +template +inline +bool +mtOp::is_alias(const Mat& X) const + { + arma_debug_sigprint(); + + return m.is_alias(X); + } + + + //! @} diff --git a/Include/armadillo/armadillo_bits/op_cond_meat.hpp b/Include/armadillo/armadillo_bits/op_cond_meat.hpp index f0ac58298..be1ef0788 100644 --- a/Include/armadillo/armadillo_bits/op_cond_meat.hpp +++ b/Include/armadillo/armadillo_bits/op_cond_meat.hpp @@ -37,19 +37,12 @@ op_cond::apply(const Base& X) if(is_op_diagmat::value || A.is_diagmat()) { - arma_debug_print("op_cond::apply(): detected diagonal matrix"); + arma_debug_print("op_cond::apply(): diag optimisation"); return op_cond::apply_diag(A); } - bool is_approx_sym = false; - bool is_approx_sympd = false; - - sym_helper::analyse_matrix(is_approx_sym, is_approx_sympd, A); - - const bool do_sym = (is_cx::no) ? (is_approx_sym) : (is_approx_sym && is_approx_sympd); - - if(do_sym) + if(sym_helper::is_approx_sym(A)) { arma_debug_print("op_cond: symmetric/hermitian optimisation"); diff --git a/Include/armadillo/armadillo_bits/op_expmat_meat.hpp b/Include/armadillo/armadillo_bits/op_expmat_meat.hpp index 7dd076b86..7278cabca 100644 --- a/Include/armadillo/armadillo_bits/op_expmat_meat.hpp +++ b/Include/armadillo/armadillo_bits/op_expmat_meat.hpp @@ -75,7 +75,7 @@ op_expmat::apply_direct(Mat& out, const Base& out, const Base::no) ? (is_approx_sym) : (is_approx_sym && is_approx_sympd)); - } - - if(do_sym) + if( (arma_config::optimise_sym) && (auxlib::crippled_lapack(A) == false) && sym_helper::is_approx_sym(A) ) { arma_debug_print("op_expmat: symmetric/hermitian optimisation"); @@ -207,7 +195,7 @@ op_expmat_sym::apply_direct(Mat& out, const Base::value || X.is_diagmat()) { - arma_debug_print("op_expmat_sym: detected diagonal matrix"); + arma_debug_print("op_expmat_sym: diag optimisation"); out = X; diff --git a/Include/armadillo/armadillo_bits/op_inv_gen_meat.hpp b/Include/armadillo/armadillo_bits/op_inv_gen_meat.hpp index c08ed33e1..b127eeffe 100644 --- a/Include/armadillo/armadillo_bits/op_inv_gen_meat.hpp +++ b/Include/armadillo/armadillo_bits/op_inv_gen_meat.hpp @@ -185,7 +185,7 @@ op_inv_gen_full::apply_direct(Mat& out, const Base::value || out.is_diagmat()) { - arma_debug_print("op_inv_gen_full: detected diagonal matrix"); + arma_debug_print("op_inv_gen_full: diag optimisation"); eT* colmem = out.memptr(); @@ -216,28 +216,16 @@ op_inv_gen_full::apply_direct(Mat& out, const Base::eval(expr.get_ref()) || sym_helper::is_approx_sym(out, uword(100)) ) ) { - arma_debug_print("op_inv_gen_full: attempting sympd optimisation"); - - Mat tmp = out; - - bool sympd_state = false; - - const bool status = auxlib::inv_sympd(tmp, sympd_state); + arma_debug_print("op_inv_gen_full: symmetric/hermitian optimisation"); - if(status) { out.steal_mem(tmp); return true; } - - if((status == false) && (sympd_state == true)) { return false; } - - arma_debug_print("op_inv_gen_full: sympd optimisation failed"); - - // fallthrough if optimisation failed + return auxlib::inv_sym(out); } return auxlib::inv(out); @@ -348,7 +336,7 @@ op_inv_gen_rcond::apply_direct(Mat& out, op_inv_gen_stat if(is_op_diagmat::value || out.is_diagmat()) { - arma_debug_print("op_inv_gen_rcond: detected diagonal matrix"); + arma_debug_print("op_inv_gen_rcond: diag optimisation"); out_state.is_diag = true; @@ -394,30 +382,18 @@ op_inv_gen_rcond::apply_direct(Mat& out, op_inv_gen_stat if(is_triu_expr || is_tril_expr || is_triu_mat || is_tril_mat) { + arma_debug_print("op_inv_gen_rcond: tri optimisation"); + return auxlib::inv_tr_rcond(out, out_state.rcond, ((is_triu_expr || is_triu_mat) ? uword(0) : uword(1))); } - const bool try_sympd = arma_config::optimise_sym && ((auxlib::crippled_lapack(out)) ? false : sym_helper::guess_sympd(out)); - - if(try_sympd) + if( (arma_config::optimise_sym) && (auxlib::crippled_lapack(out) == false) && ( is_sym_expr::eval(expr.get_ref()) || sym_helper::is_approx_sym(out, uword(100)) ) ) { - arma_debug_print("op_inv_gen_rcond: attempting sympd optimisation"); + arma_debug_print("op_inv_gen_rcond: symmetric/hermitian optimisation"); out_state.is_sym = true; - Mat tmp = out; - - bool sympd_state = false; - - const bool status = auxlib::inv_sympd_rcond(tmp, sympd_state, out_state.rcond); - - if(status) { out.steal_mem(tmp); return true; } - - if((status == false) && (sympd_state == true)) { return false; } - - arma_debug_print("op_inv_gen_rcond: sympd optimisation failed"); - - // fallthrough if optimisation failed + return auxlib::inv_sym_rcond(out, out_state.rcond); } return auxlib::inv_rcond(out, out_state.rcond); diff --git a/Include/armadillo/armadillo_bits/op_inv_spd_meat.hpp b/Include/armadillo/armadillo_bits/op_inv_spd_meat.hpp index 13994e38a..73ed5c791 100644 --- a/Include/armadillo/armadillo_bits/op_inv_spd_meat.hpp +++ b/Include/armadillo/armadillo_bits/op_inv_spd_meat.hpp @@ -191,7 +191,7 @@ op_inv_spd_full::apply_direct(Mat& out, const Base::value || out.is_diagmat()) { - arma_debug_print("op_inv_spd_full: detected diagonal matrix"); + arma_debug_print("op_inv_spd_full: diag optimisation"); eT* colmem = out.memptr(); @@ -300,7 +300,7 @@ op_inv_spd_rcond::apply_direct(Mat& out, op_inv_spd_stat if(is_op_diagmat::value || out.is_diagmat()) { - arma_debug_print("op_inv_spd_rcond: detected diagonal matrix"); + arma_debug_print("op_inv_spd_rcond: diag optimisation"); out_state.is_diag = true; @@ -355,9 +355,7 @@ op_inv_spd_rcond::apply_direct(Mat& out, op_inv_spd_stat return true; } - bool is_sympd_junk = false; - - return auxlib::inv_sympd_rcond(out, is_sympd_junk, out_state.rcond); + return auxlib::inv_sympd_rcond(out, out_state.rcond); } diff --git a/Include/armadillo/armadillo_bits/op_log_det_meat.hpp b/Include/armadillo/armadillo_bits/op_log_det_meat.hpp index 5601597d6..39ef514de 100644 --- a/Include/armadillo/armadillo_bits/op_log_det_meat.hpp +++ b/Include/armadillo/armadillo_bits/op_log_det_meat.hpp @@ -202,7 +202,7 @@ op_log_det_sympd::apply_direct(typename T1::pod_type& out_val, const Base::value || A.is_diagmat()) { - arma_debug_print("op_log_det_sympd: detected diagonal matrix"); + arma_debug_print("op_log_det_sympd: diag optimisation"); eT* colmem = A.memptr(); diff --git a/Include/armadillo/armadillo_bits/op_logmat_meat.hpp b/Include/armadillo/armadillo_bits/op_logmat_meat.hpp index 671366021..43cd9ac53 100644 --- a/Include/armadillo/armadillo_bits/op_logmat_meat.hpp +++ b/Include/armadillo/armadillo_bits/op_logmat_meat.hpp @@ -111,7 +111,7 @@ op_logmat::apply_direct(Mat< std::complex >& out, const if(A.is_diagmat()) { - arma_debug_print("op_logmat: detected diagonal matrix"); + arma_debug_print("op_logmat: diag optimisation"); const uword N = A.n_rows; @@ -292,7 +292,7 @@ op_logmat_cx::apply_direct(Mat& out, const Base& out, const Base::value || X.is_diagmat()) { - arma_debug_print("op_logmat_sympd: detected diagonal matrix"); + arma_debug_print("op_logmat_sympd: diag optimisation"); out = X; diff --git a/Include/armadillo/armadillo_bits/op_pinv_meat.hpp b/Include/armadillo/armadillo_bits/op_pinv_meat.hpp index 9012bc3a0..f88ca05bc 100644 --- a/Include/armadillo/armadillo_bits/op_pinv_meat.hpp +++ b/Include/armadillo/armadillo_bits/op_pinv_meat.hpp @@ -106,7 +106,7 @@ op_pinv::apply_direct(Mat& out, const Base::value || A.is_diagmat()) { - arma_debug_print("op_pinv: detected diagonal matrix"); + arma_debug_print("op_pinv: diag optimisation"); return op_pinv::apply_diag(out, A, tol); } @@ -119,15 +119,7 @@ op_pinv::apply_direct(Mat& out, const Base::eval(expr.get_ref()); - if(do_sym == false) - { - bool is_approx_sym = false; - bool is_approx_sympd = false; - - sym_helper::analyse_matrix(is_approx_sym, is_approx_sympd, A); - - do_sym = ((is_cx::no) ? (is_approx_sym) : (is_approx_sym && is_approx_sympd)); - } + if(do_sym == false) { do_sym = sym_helper::is_approx_sym(A); } } if(do_sym) diff --git a/Include/armadillo/armadillo_bits/op_powmat_meat.hpp b/Include/armadillo/armadillo_bits/op_powmat_meat.hpp index bbd104cf4..c95b68602 100644 --- a/Include/armadillo/armadillo_bits/op_powmat_meat.hpp +++ b/Include/armadillo/armadillo_bits/op_powmat_meat.hpp @@ -96,7 +96,7 @@ op_powmat::apply_direct_positive(Mat& out, const Mat& X, const uword y) if(X.is_diagmat()) { - arma_debug_print("op_powmat: detected diagonal matrix"); + arma_debug_print("op_powmat: diag optimisation"); podarray tmp(N); // use temporary array in case we have aliasing @@ -194,11 +194,11 @@ op_powmat_cx::apply_direct(Mat< std::complex >& out, cons if(A.is_diagmat()) { - arma_debug_print("op_powmat_cx: detected diagonal matrix"); + arma_debug_print("op_powmat_cx: diag optimisation"); podarray tmp(N); // use temporary array in case we have aliasing - for(uword i=0; i(A.at(i,i)), y) ; } + for(uword i=0; i(A.at(i,i)), y ); } out.zeros(N,N); @@ -207,11 +207,11 @@ op_powmat_cx::apply_direct(Mat< std::complex >& out, cons return true; } - const bool try_sympd = arma_config::optimise_sym && sym_helper::guess_sympd(A); + const bool try_sym = arma_config::optimise_sym && sym_helper::is_approx_sym(A); - if(try_sympd) + if(try_sym) { - arma_debug_print("op_powmat_cx: attempting sympd optimisation"); + arma_debug_print("op_powmat_cx: symmetric/hermitian optimisation"); Col eigval; Mat eigvec; @@ -220,16 +220,39 @@ op_powmat_cx::apply_direct(Mat< std::complex >& out, cons if(eig_status) { - eigval = pow(eigval, y); + bool all_pos = true; - const Mat tmp = diagmat(eigval) * eigvec.t(); + for(uword i=0; i >::from(eigvec * tmp); + if(all_pos) + { + arma_debug_print("op_powmat_cx: all_pos = true"); + + eigval = pow(eigval, y); + + const Mat tmp = eigvec * diagmat(eigval); + + out = conv_to< Mat >::from(tmp * eigvec.t()); + } + else + { + arma_debug_print("op_powmat_cx: all_pos = false"); + + Col cx_eigval_pow(N, arma_nozeros_indicator()); + + for(uword i=0; i(eigval[i]), y ); } + + const Mat cx_eigvec = conv_to< Mat >::from(eigvec); + + const Mat tmp = cx_eigvec * diagmat(cx_eigval_pow); + + out = tmp * cx_eigvec.t(); + } return true; } - arma_debug_print("op_powmat_cx: sympd optimisation failed"); + arma_debug_print("op_powmat_cx: symmetric/hermitian optimisation failed"); // fallthrough if optimisation failed } diff --git a/Include/armadillo/armadillo_bits/op_rank_meat.hpp b/Include/armadillo/armadillo_bits/op_rank_meat.hpp index 0a32c6b2c..c7840e575 100644 --- a/Include/armadillo/armadillo_bits/op_rank_meat.hpp +++ b/Include/armadillo/armadillo_bits/op_rank_meat.hpp @@ -37,7 +37,7 @@ op_rank::apply(uword& out, const Base& expr, const ty if(is_op_diagmat::value || A.is_diagmat()) { - arma_debug_print("op_rank::apply(): detected diagonal matrix"); + arma_debug_print("op_rank::apply(): diag optimisation"); return op_rank::apply_diag(out, A, tol); } @@ -50,15 +50,7 @@ op_rank::apply(uword& out, const Base& expr, const ty { do_sym = is_sym_expr::eval(expr.get_ref()); - if(do_sym == false) - { - bool is_approx_sym = false; - bool is_approx_sympd = false; - - sym_helper::analyse_matrix(is_approx_sym, is_approx_sympd, A); - - do_sym = (is_cx::no) ? (is_approx_sym) : (is_approx_sym && is_approx_sympd); - } + if(do_sym == false) { do_sym = sym_helper::is_approx_sym(A); } } if(do_sym) diff --git a/Include/armadillo/armadillo_bits/op_rcond_meat.hpp b/Include/armadillo/armadillo_bits/op_rcond_meat.hpp index ae3aeaa9f..e994e6922 100644 --- a/Include/armadillo/armadillo_bits/op_rcond_meat.hpp +++ b/Include/armadillo/armadillo_bits/op_rcond_meat.hpp @@ -33,6 +33,8 @@ op_rcond::apply(const Base& X) if(strip_trimat::do_trimat) { + arma_debug_print("op_rcond::apply(): tri optimisation"); + const strip_trimat S(X.get_ref()); const quasi_unwrap::stored_type> U(S.M); @@ -52,7 +54,7 @@ op_rcond::apply(const Base& X) if(is_op_diagmat::value || A.is_diagmat()) { - arma_debug_print("op_rcond::apply(): detected diagonal matrix"); + arma_debug_print("op_rcond::apply(): diag optimisation"); const eT* colmem = A.memptr(); const uword N = A.n_rows; @@ -80,29 +82,18 @@ op_rcond::apply(const Base& X) if(is_triu || is_tril) { + arma_debug_print("op_rcond::apply(): tri optimisation"); + const uword layout = (is_triu) ? uword(0) : uword(1); return auxlib::rcond_trimat(A, layout); } - const bool try_sympd = arma_config::optimise_sym && (auxlib::crippled_lapack(A) ? false : sym_helper::guess_sympd(A)); - - if(try_sympd) + if( (arma_config::optimise_sym) && (auxlib::crippled_lapack(A) == false) && ( is_sym_expr::eval(X.get_ref()) || sym_helper::is_approx_sym(A, uword(100)) ) ) { - arma_debug_print("op_rcond::apply(): attempting sympd optimisation"); - - bool calc_ok = false; - - const T out_val = auxlib::rcond_sympd(A, calc_ok); - - if(calc_ok) { return out_val; } - - arma_debug_print("op_rcond::apply(): sympd optimisation failed"); + arma_debug_print("op_rcond::apply(): symmetric/hermitian optimisation"); - // auxlib::rcond_sympd() may have failed because A isn't really sympd - // restore A, as auxlib::rcond_sympd() may have destroyed it - A = X.get_ref(); - // fallthrough to the next return statement + return auxlib::rcond_sym(A); } return auxlib::rcond(A); diff --git a/Include/armadillo/armadillo_bits/op_sqrtmat_meat.hpp b/Include/armadillo/armadillo_bits/op_sqrtmat_meat.hpp index 44017455e..3c2f07d4d 100644 --- a/Include/armadillo/armadillo_bits/op_sqrtmat_meat.hpp +++ b/Include/armadillo/armadillo_bits/op_sqrtmat_meat.hpp @@ -116,7 +116,7 @@ op_sqrtmat::apply_direct(Mat< std::complex >& out, const if(A.is_diagmat()) { - arma_debug_print("op_sqrtmat: detected diagonal matrix"); + arma_debug_print("op_sqrtmat: diag optimisation"); const uword N = A.n_rows; @@ -325,7 +325,7 @@ op_sqrtmat_cx::apply_direct(Mat& out, const Base& out, const Base::value || X.is_diagmat()) { - arma_debug_print("op_sqrtmat_sympd: detected diagonal matrix"); + arma_debug_print("op_sqrtmat_sympd: diag optimisation"); out = X; diff --git a/Include/armadillo/armadillo_bits/operator_minus.hpp b/Include/armadillo/armadillo_bits/operator_minus.hpp index 3cc1bba2f..f8ddf4f43 100644 --- a/Include/armadillo/armadillo_bits/operator_minus.hpp +++ b/Include/armadillo/armadillo_bits/operator_minus.hpp @@ -438,108 +438,6 @@ operator- -// TODO: this is an uncommon use case; remove? -//! multiple applications of add/subtract scalars can be condensed -template -inline -typename -enable_if2 - < - (is_arma_sparse_type::value && - (is_same_type::value || - is_same_type::value)), - const SpToDOp - >::result -operator- - ( - const SpToDOp& x, - const typename T1::elem_type k - ) - { - arma_debug_sigprint(); - - const typename T1::elem_type aux = (is_same_type::value) ? -x.aux : x.aux; - - return SpToDOp(x.m, aux + k); - } - - - -// TODO: this is an uncommon use case; remove? -//! multiple applications of add/subtract scalars can be condensed -template -inline -typename -enable_if2 - < - (is_arma_sparse_type::value && - (is_same_type::value || - is_same_type::value)), - const SpToDOp - >::result -operator- - ( - const typename T1::elem_type k, - const SpToDOp& x - ) - { - arma_debug_sigprint(); - - const typename T1::elem_type aux = (is_same_type::value) ? -x.aux : x.aux; - - return SpToDOp(x.m, k + aux); - } - - - -// TODO: this is an uncommon use case; remove? -//! multiple applications of add/subtract scalars can be condensed -template -inline -typename -enable_if2 - < - (is_arma_sparse_type::value && - is_same_type::value), - const SpToDOp - >::result -operator- - ( - const SpToDOp& x, - const typename T1::elem_type k - ) - { - arma_debug_sigprint(); - - return SpToDOp(x.m, x.aux - k); - } - - - -// TODO: this is an uncommon use case; remove? -//! multiple applications of add/subtract scalars can be condensed -template -inline -typename -enable_if2 - < - (is_arma_sparse_type::value && - is_same_type::value), - const SpToDOp - >::result -operator- - ( - const typename T1::elem_type k, - const SpToDOp& x - ) - { - arma_debug_sigprint(); - - return SpToDOp(x.m, k - x.aux); - } - - - template arma_inline Mat diff --git a/Include/armadillo/armadillo_bits/operator_plus.hpp b/Include/armadillo/armadillo_bits/operator_plus.hpp index fd0603d06..ddf6417de 100644 --- a/Include/armadillo/armadillo_bits/operator_plus.hpp +++ b/Include/armadillo/armadillo_bits/operator_plus.hpp @@ -370,109 +370,6 @@ operator+ -// TODO: this is an uncommon use case; remove? -//! multiple applications of add/subtract scalars can be condensed -template -inline -typename -enable_if2 - < - (is_arma_sparse_type::value && - (is_same_type::value || - is_same_type::value)), - const SpToDOp - >::result -operator+ - ( - const SpToDOp& x, - const typename T1::elem_type k - ) - { - arma_debug_sigprint(); - - const typename T1::elem_type aux = (is_same_type::value) ? x.aux : -x.aux; - - return SpToDOp(x.m, aux + k); - } - - - -// TODO: this is an uncommon use case; remove? -//! multiple applications of add/subtract scalars can be condensed -template -inline -typename -enable_if2 - < - (is_arma_sparse_type::value && - is_same_type::value), - const SpToDOp - >::result -operator+ - ( - const SpToDOp& x, - const typename T1::elem_type k - ) - { - arma_debug_sigprint(); - - return SpToDOp(x.m, x.aux + k); - } - - - -// TODO: this is an uncommon use case; remove? -//! multiple applications of add/subtract scalars can be condensed -template -inline -typename -enable_if2 - < - (is_arma_sparse_type::value && - (is_same_type::value || - is_same_type::value)), - const SpToDOp - >::result -operator+ - ( - const typename T1::elem_type k, - const SpToDOp& x - ) - { - arma_debug_sigprint(); - - const typename T1::elem_type aux = (is_same_type::value) ? x.aux : -x.aux; - - return SpToDOp(x.m, aux + k); - } - - - -// TODO: this is an uncommon use case; remove? -//! multiple applications of add/subtract scalars can be condensed -template -inline -typename -enable_if2 - < - (is_arma_sparse_type::value && - is_same_type::value), - const SpToDOp - >::result -operator+ - ( - const typename T1::elem_type k, - const SpToDOp& x - ) - { - arma_debug_sigprint(); - - return SpToDOp(x.m, x.aux + k); - } - - - - template arma_inline Mat diff --git a/Include/armadillo/armadillo_bits/subview_bones.hpp b/Include/armadillo/armadillo_bits/subview_bones.hpp index 95553a10e..53b4a4211 100644 --- a/Include/armadillo/armadillo_bits/subview_bones.hpp +++ b/Include/armadillo/armadillo_bits/subview_bones.hpp @@ -201,6 +201,9 @@ class subview : public Base< eT, subview > inline void swap_rows(const uword in_row1, const uword in_row2); inline void swap_cols(const uword in_col1, const uword in_col2); + template + inline bool is_alias(const Mat& X) const; + class const_iterator; @@ -428,8 +431,8 @@ class subview_col : public subview arma_warn_unused inline eT min() const; arma_warn_unused inline eT max() const; - inline eT min(uword& index_of_min_val) const; - inline eT max(uword& index_of_max_val) const; + arma_frown("use .index_min() instead") inline eT min(uword& index_of_min_val) const; + arma_frown("use .index_max() instead") inline eT max(uword& index_of_max_val) const; arma_warn_unused inline uword index_min() const; arma_warn_unused inline uword index_max() const; diff --git a/Include/armadillo/armadillo_bits/subview_elem1_bones.hpp b/Include/armadillo/armadillo_bits/subview_elem1_bones.hpp index 2ac3cdad6..37b5d725b 100644 --- a/Include/armadillo/armadillo_bits/subview_elem1_bones.hpp +++ b/Include/armadillo/armadillo_bits/subview_elem1_bones.hpp @@ -99,6 +99,9 @@ class subview_elem1 : public Base< eT, subview_elem1 > inline static void schur_inplace(Mat& out, const subview_elem1& in); inline static void div_inplace(Mat& out, const subview_elem1& in); + template + inline bool is_alias(const Mat& X) const; + friend class Mat; friend class Cube; diff --git a/Include/armadillo/armadillo_bits/subview_elem1_meat.hpp b/Include/armadillo/armadillo_bits/subview_elem1_meat.hpp index 32a06caa1..3f62f104d 100644 --- a/Include/armadillo/armadillo_bits/subview_elem1_meat.hpp +++ b/Include/armadillo/armadillo_bits/subview_elem1_meat.hpp @@ -67,11 +67,10 @@ subview_elem1::inplace_op(const eT val) const unwrap_check_mixed tmp(a.get_ref(), m_local); const umat& aa = tmp.M; - arma_conform_check - ( - ( (aa.is_vec() == false) && (aa.is_empty() == false) ), - "Mat::elem(): given object must be a vector" - ); + if(resolves_to_vector::no) + { + arma_conform_check( ( (aa.is_vec() == false) && (aa.is_empty() == false) ), "Mat::elem(): given object must be a vector" ); + } const uword* aa_mem = aa.memptr(); const uword aa_n_elem = aa.n_elem; @@ -220,11 +219,10 @@ subview_elem1::inplace_op(const Base& x) const unwrap_check_mixed aa_tmp(a.get_ref(), m_local); const umat& aa = aa_tmp.M; - arma_conform_check - ( - ( (aa.is_vec() == false) && (aa.is_empty() == false) ), - "Mat::elem(): given object must be a vector" - ); + if(resolves_to_vector::no) + { + arma_conform_check( ( (aa.is_vec() == false) && (aa.is_empty() == false) ), "Mat::elem(): given object must be a vector" ); + } const uword* aa_mem = aa.memptr(); const uword aa_n_elem = aa.n_elem; @@ -233,9 +231,9 @@ subview_elem1::inplace_op(const Base& x) arma_conform_check( (aa_n_elem != P.get_n_elem()), "Mat::elem(): size mismatch" ); - const bool is_alias = P.is_alias(m); + const bool have_alias = P.is_alias(m); - if( (is_alias == false) && (Proxy::use_at == false) ) + if( (have_alias == false) && (Proxy::use_at == false) ) { typename Proxy::ea_type X = P.get_ea(); @@ -271,7 +269,7 @@ subview_elem1::inplace_op(const Base& x) { arma_debug_print("subview_elem1::inplace_op(): aliasing or use_at detected"); - const unwrap_check::stored_type> tmp(P.Q, is_alias); + const unwrap_check::stored_type> tmp(P.Q, have_alias); const Mat& M = tmp.M; const eT* X = M.memptr(); @@ -358,11 +356,10 @@ subview_elem1::replace(const eT old_val, const eT new_val) const unwrap_check_mixed tmp(a.get_ref(), m_local); const umat& aa = tmp.M; - arma_conform_check - ( - ( (aa.is_vec() == false) && (aa.is_empty() == false) ), - "Mat::elem(): given object must be a vector" - ); + if(resolves_to_vector::no) + { + arma_conform_check( ( (aa.is_vec() == false) && (aa.is_empty() == false) ), "Mat::elem(): given object must be a vector" ); + } const uword* aa_mem = aa.memptr(); const uword aa_n_elem = aa.n_elem; @@ -480,11 +477,10 @@ subview_elem1::randu() const unwrap_check_mixed tmp(a.get_ref(), m_local); const umat& aa = tmp.M; - arma_conform_check - ( - ( (aa.is_vec() == false) && (aa.is_empty() == false) ), - "Mat::elem(): given object must be a vector" - ); + if(resolves_to_vector::no) + { + arma_conform_check( ( (aa.is_vec() == false) && (aa.is_empty() == false) ), "Mat::elem(): given object must be a vector" ); + } const uword* aa_mem = aa.memptr(); const uword aa_n_elem = aa.n_elem; @@ -531,11 +527,10 @@ subview_elem1::randn() const unwrap_check_mixed tmp(a.get_ref(), m_local); const umat& aa = tmp.M; - arma_conform_check - ( - ( (aa.is_vec() == false) && (aa.is_empty() == false) ), - "Mat::elem(): given object must be a vector" - ); + if(resolves_to_vector::no) + { + arma_conform_check( ( (aa.is_vec() == false) && (aa.is_empty() == false) ), "Mat::elem(): given object must be a vector" ); + } const uword* aa_mem = aa.memptr(); const uword aa_n_elem = aa.n_elem; @@ -788,11 +783,10 @@ subview_elem1::extract(Mat& actual_out, const subview_elem1& i const unwrap_check_mixed tmp1(in.a.get_ref(), actual_out); const umat& aa = tmp1.M; - arma_conform_check - ( - ( (aa.is_vec() == false) && (aa.is_empty() == false) ), - "Mat::elem(): given object must be a vector" - ); + if(resolves_to_vector::no) + { + arma_conform_check( ( (aa.is_vec() == false) && (aa.is_empty() == false) ), "Mat::elem(): given object must be a vector" ); + } const uword* aa_mem = aa.memptr(); const uword aa_n_elem = aa.n_elem; @@ -854,11 +848,10 @@ subview_elem1::mat_inplace_op(Mat& out, const subview_elem1& in) const unwrap tmp1(in.a.get_ref()); const umat& aa = tmp1.M; - arma_conform_check - ( - ( (aa.is_vec() == false) && (aa.is_empty() == false) ), - "Mat::elem(): given object must be a vector" - ); + if(resolves_to_vector::no) + { + arma_conform_check( ( (aa.is_vec() == false) && (aa.is_empty() == false) ), "Mat::elem(): given object must be a vector" ); + } const uword* aa_mem = aa.memptr(); const uword aa_n_elem = aa.n_elem; @@ -950,4 +943,17 @@ subview_elem1::div_inplace(Mat& out, const subview_elem1& in) +template +template +inline +bool +subview_elem1::is_alias(const Mat& X) const + { + arma_debug_sigprint(); + + return (m.is_alias(X) || a.get_ref().is_alias(X)); + } + + + //! @} diff --git a/Include/armadillo/armadillo_bits/subview_elem2_bones.hpp b/Include/armadillo/armadillo_bits/subview_elem2_bones.hpp index d4c4cbe6c..bec561d77 100644 --- a/Include/armadillo/armadillo_bits/subview_elem2_bones.hpp +++ b/Include/armadillo/armadillo_bits/subview_elem2_bones.hpp @@ -103,6 +103,9 @@ class subview_elem2 : public Base< eT, subview_elem2 > inline static void schur_inplace(Mat& out, const subview_elem2& in); inline static void div_inplace(Mat& out, const subview_elem2& in); + template + inline bool is_alias(const Mat& X) const; + friend class Mat; }; diff --git a/Include/armadillo/armadillo_bits/subview_elem2_meat.hpp b/Include/armadillo/armadillo_bits/subview_elem2_meat.hpp index a6e9dc23b..4a160642b 100644 --- a/Include/armadillo/armadillo_bits/subview_elem2_meat.hpp +++ b/Include/armadillo/armadillo_bits/subview_elem2_meat.hpp @@ -870,4 +870,17 @@ subview_elem2::div_inplace(Mat& out, const subview_elem2& in) +template +template +inline +bool +subview_elem2::is_alias(const Mat& X) const + { + arma_debug_sigprint(); + + return (m.is_alias(X) || base_ri.get_ref().is_alias(X) || base_ci.get_ref().is_alias(X)); + } + + + //! @} diff --git a/Include/armadillo/armadillo_bits/subview_meat.hpp b/Include/armadillo/armadillo_bits/subview_meat.hpp index 81cc39f98..021f7cb0c 100644 --- a/Include/armadillo/armadillo_bits/subview_meat.hpp +++ b/Include/armadillo/armadillo_bits/subview_meat.hpp @@ -2617,6 +2617,19 @@ subview::swap_cols(const uword in_col1, const uword in_col2) +template +template +inline +bool +subview::is_alias(const Mat& X) const + { + arma_debug_sigprint(); + + return m.is_alias(X); + } + + + template inline typename subview::iterator diff --git a/Include/armadillo/armadillo_bits/sym_helper.hpp b/Include/armadillo/armadillo_bits/sym_helper.hpp index fab7d5532..fa2b2b3ae 100644 --- a/Include/armadillo/armadillo_bits/sym_helper.hpp +++ b/Include/armadillo/armadillo_bits/sym_helper.hpp @@ -117,6 +117,8 @@ guess_sympd_worker(const Mat& A) // NOTE: assuming A is square-sized + // NOTE: the function name is required for overloading, but is a misnomer: it processes complex hermitian matrices + typedef typename get_pod_type::result T; const T tol = T(100) * std::numeric_limits::epsilon(); // allow some leeway @@ -131,12 +133,16 @@ guess_sympd_worker(const Mat& A) for(uword j=0; j < N; ++j) { const eT& A_jj = A_col[j]; - const T A_jj_real = std::real(A_jj); - const T A_jj_imag = std::imag(A_jj); + const T A_jj_r = std::real(A_jj ); + const T A_jj_i = std::imag(A_jj ); + const T A_jj_rabs = std::abs(A_jj_r); + const T A_jj_iabs = std::abs(A_jj_i); - if( (A_jj_real <= T(0)) || (std::abs(A_jj_imag) > tol) ) { return false; } + if(A_jj_r <= T(0) ) { return false; } // real should be positive + if(A_jj_iabs > tol ) { return false; } // imag should be approx zero + if(A_jj_iabs > A_jj_rabs) { return false; } // corner case: real and imag are close to zero, and imag is dominant - max_diag = (A_jj_real > max_diag) ? A_jj_real : max_diag; + max_diag = (A_jj_r > max_diag) ? A_jj_r : max_diag; A_col += N; } @@ -246,14 +252,11 @@ guess_sympd(const Mat& A, const uword min_n_rows) template inline -typename enable_if2::no, void>::result -analyse_matrix_worker(bool& is_approx_sym, bool& is_approx_sympd, const Mat& A) +typename enable_if2::no, bool>::result +is_approx_sym_worker(const Mat& A) { arma_debug_sigprint(); - is_approx_sym = true; - is_approx_sympd = true; - const eT tol = eT(100) * std::numeric_limits::epsilon(); // allow some leeway const uword N = A.n_rows; @@ -261,15 +264,11 @@ analyse_matrix_worker(bool& is_approx_sym, bool& is_approx_sympd, const Mat& const eT* A_mem = A.memptr(); const eT* A_col = A_mem; - eT max_diag = eT(0); - for(uword j=0; j < N; ++j) { - const eT A_jj = A_col[j]; + const eT& A_jj = A_col[j]; - if(A_jj <= eT(0)) { is_approx_sympd = false; } - - max_diag = (A_jj > max_diag) ? A_jj : max_diag; + if(arma_isfinite(A_jj) == false) { return false; } A_col += N; } @@ -277,15 +276,11 @@ analyse_matrix_worker(bool& is_approx_sym, bool& is_approx_sympd, const Mat& A_col = A_mem; const uword Nm1 = N-1; - const uword Np1 = N+1; for(uword j=0; j < Nm1; ++j) { - const eT A_jj = A_col[j]; - const uword jp1 = j+1; - const eT* A_ji_ptr = &(A_mem[j + jp1*N]); // &(A.at(j,jp1)); - const eT* A_ii_ptr = &(A_mem[jp1 + jp1*N]); + const eT* A_ji_ptr = &(A_mem[j + jp1*N]); // &(A.at(j,jp1)); for(uword i=jp1; i < N; ++i) { @@ -298,39 +293,29 @@ analyse_matrix_worker(bool& is_approx_sym, bool& is_approx_sympd, const Mat& const eT A_delta = (std::abs)(A_ij - A_ji); const eT A_abs_max = (std::max)(A_ij_abs, A_ji_abs); - if( (A_delta > tol) && (A_delta > (A_abs_max*tol)) ) { is_approx_sym = false; return; } - - if(is_approx_sympd) - { - // if( (A_ij_abs >= max_diag) || (A_ji_abs >= max_diag) ) { is_approx_sympd = false; } - if(A_ij_abs >= max_diag) { is_approx_sympd = false; } - - const eT A_ii = (*A_ii_ptr); - - if( (A_ij_abs + A_ij_abs) >= (A_ii + A_jj) ) { is_approx_sympd = false; } - } + if( (A_delta > tol) && (A_delta > (A_abs_max*tol)) ) { return false; } A_ji_ptr += N; - A_ii_ptr += Np1; } A_col += N; } + + return true; } template inline -typename enable_if2::yes, void>::result -analyse_matrix_worker(bool& is_approx_sym, bool& is_approx_sympd, const Mat& A) +typename enable_if2::yes, bool>::result +is_approx_sym_worker(const Mat& A) { arma_debug_sigprint(); - typedef typename get_pod_type::result T; + // NOTE: the function name is required for overloading, but is a misnomer: it processes complex hermitian matrices - is_approx_sym = true; - is_approx_sympd = true; + typedef typename get_pod_type::result T; const T tol = T(100) * std::numeric_limits::epsilon(); // allow some leeway @@ -339,37 +324,31 @@ analyse_matrix_worker(bool& is_approx_sym, bool& is_approx_sympd, const Mat& const eT* A_mem = A.memptr(); const eT* A_col = A_mem; - T max_diag = T(0); - + // ensure diagonal has approx real-only elements for(uword j=0; j < N; ++j) { const eT& A_jj = A_col[j]; - const T A_jj_real = std::real(A_jj); - const T A_jj_imag = std::imag(A_jj); + const T A_jj_r = std::real(A_jj ); + const T A_jj_i = std::imag(A_jj ); + const T A_jj_rabs = std::abs(A_jj_r); + const T A_jj_iabs = std::abs(A_jj_i); - if( (A_jj_real <= T(0)) || (std::abs(A_jj_imag) > tol) ) { is_approx_sympd = false; } + if(A_jj_iabs > tol ) { return false; } // imag should be approx zero + if(A_jj_iabs > A_jj_rabs) { return false; } // corner case: real and imag are close to zero, and imag is dominant - max_diag = (A_jj_real > max_diag) ? A_jj_real : max_diag; + if(arma_isfinite(A_jj_r) == false) { return false; } A_col += N; } - const T square_max_diag = max_diag * max_diag; - - if(arma_isfinite(square_max_diag) == false) { is_approx_sympd = false; } - A_col = A_mem; const uword Nm1 = N-1; - const uword Np1 = N+1; for(uword j=0; j < Nm1; ++j) { - const uword jp1 = j+1; - const eT* A_ji_ptr = &(A_mem[j + jp1*N]); // &(A.at(j,jp1)); - const eT* A_ii_ptr = &(A_mem[jp1 + jp1*N]); - - const T A_jj_real = std::real(A_col[j]); + const uword jp1 = j+1; + const eT* A_ji_ptr = &(A_mem[j + jp1*N]); // &(A.at(j,jp1)); for(uword i=jp1; i < N; ++i) { @@ -390,63 +369,58 @@ analyse_matrix_worker(bool& is_approx_sym, bool& is_approx_sympd, const Mat& const T A_real_delta = (std::abs)(A_ij_real - A_ji_real); const T A_real_abs_max = (std::max)(A_ij_real_abs, A_ji_real_abs); - if( (A_real_delta > tol) && (A_real_delta > (A_real_abs_max*tol)) ) { is_approx_sym = false; return; } - + if( (A_real_delta > tol) && (A_real_delta > (A_real_abs_max*tol)) ) { return false; } + const T A_imag_delta = (std::abs)(A_ij_imag + A_ji_imag); // take into account complex conjugate const T A_imag_abs_max = (std::max)(A_ij_imag_abs, A_ji_imag_abs); - if( (A_imag_delta > tol) && (A_imag_delta > (A_imag_abs_max*tol)) ) { is_approx_sym = false; return; } - - if(is_approx_sympd) - { - // avoid using std::abs(), as that is time consuming due to division and std::sqrt() - const T square_A_ij_abs = (A_ij_real * A_ij_real) + (A_ij_imag * A_ij_imag); - - if(arma_isfinite(square_A_ij_abs) == false) - { - is_approx_sympd = false; - } - else - { - const T A_ii_real = std::real(*A_ii_ptr); - - if( (A_ij_real_abs + A_ij_real_abs) >= (A_ii_real + A_jj_real) ) { is_approx_sympd = false; } - - if(square_A_ij_abs >= square_max_diag) { is_approx_sympd = false; } - } - } + if( (A_imag_delta > tol) && (A_imag_delta > (A_imag_abs_max*tol)) ) { return false; } A_ji_ptr += N; - A_ii_ptr += Np1; } A_col += N; } + + return true; } template inline -void -analyse_matrix(bool& is_approx_sym, bool& is_approx_sympd, const Mat& A) +bool +is_approx_sym(const Mat& A) { arma_debug_sigprint(); - if((A.n_rows != A.n_cols) || (A.n_rows < uword(4))) - { - is_approx_sym = false; - is_approx_sympd = false; - return; - } + // analyse matrices with size >= 4x4 + + if((A.n_rows != A.n_cols) || (A.n_rows < uword(4))) { return false; } - analyse_matrix_worker(is_approx_sym, is_approx_sympd, A); + return is_approx_sym_worker(A); + } + + + +template +inline +bool +is_approx_sym(const Mat& A, const uword min_n_rows) + { + arma_debug_sigprint(); + + if((A.n_rows != A.n_cols) || (A.n_rows < min_n_rows)) { return false; } - if(is_approx_sym == false) { is_approx_sympd = false; } + return is_approx_sym_worker(A); } +// + + + template inline bool diff --git a/Include/armadillo/armadillo_bits/traits.hpp b/Include/armadillo/armadillo_bits/traits.hpp index 5318ef9fa..e3bf703e8 100644 --- a/Include/armadillo/armadillo_bits/traits.hpp +++ b/Include/armadillo/armadillo_bits/traits.hpp @@ -1343,5 +1343,29 @@ struct is_sym_expr< Glue< Op, op_htrans>, Mat, glue_times > > } }; +template +struct is_sym_expr< Op > + { + static + arma_inline + bool + eval(const Op&) + { + return true; + } + }; + +template +struct is_sym_expr< Op > + { + static + arma_inline + bool + eval(const Op&) + { + return true; + } + }; + //! @} diff --git a/Include/armadillo/armadillo_bits/translate_lapack.hpp b/Include/armadillo/armadillo_bits/translate_lapack.hpp index 7ed4c0ec9..cdb3aa650 100644 --- a/Include/armadillo/armadillo_bits/translate_lapack.hpp +++ b/Include/armadillo/armadillo_bits/translate_lapack.hpp @@ -1341,6 +1341,150 @@ namespace lapack } + + template + inline + void + sytrf(const char* uplo, const blas_int* n, eT* a, const blas_int* lda, blas_int* ipiv, eT* work, const blas_int* lwork, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float T; arma_fortran(arma_ssytrf)(uplo, n, (T*)a, lda, ipiv, (T*)work, lwork, info, 1); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dsytrf)(uplo, n, (T*)a, lda, ipiv, (T*)work, lwork, info, 1); } + #else + if( is_float::value) { typedef float T; arma_fortran(arma_ssytrf)(uplo, n, (T*)a, lda, ipiv, (T*)work, lwork, info); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dsytrf)(uplo, n, (T*)a, lda, ipiv, (T*)work, lwork, info); } + #endif + } + + + + template + inline + void + hetrf(const char* uplo, const blas_int* n, eT* a, const blas_int* lda, blas_int* ipiv, eT* work, const blas_int* lwork, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_chetrf)(uplo, n, (T*)a, lda, ipiv, (T*)work, lwork, info, 1); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_zhetrf)(uplo, n, (T*)a, lda, ipiv, (T*)work, lwork, info, 1); } + #else + if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_chetrf)(uplo, n, (T*)a, lda, ipiv, (T*)work, lwork, info); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_zhetrf)(uplo, n, (T*)a, lda, ipiv, (T*)work, lwork, info); } + #endif + } + + + + template + inline + void + sytrs(const char* uplo, const blas_int* n, const blas_int* nrhs, const eT* a, const blas_int* lda, const blas_int* ipiv, eT* b, const blas_int* ldb, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float T; arma_fortran(arma_ssytrs)(uplo, n, nrhs, (T*)a, lda, ipiv, (T*)b, ldb, info, 1); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dsytrs)(uplo, n, nrhs, (T*)a, lda, ipiv, (T*)b, ldb, info, 1); } + #else + if( is_float::value) { typedef float T; arma_fortran(arma_ssytrs)(uplo, n, nrhs, (T*)a, lda, ipiv, (T*)b, ldb, info); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dsytrs)(uplo, n, nrhs, (T*)a, lda, ipiv, (T*)b, ldb, info); } + #endif + } + + + + template + inline + void + hetrs(const char* uplo, const blas_int* n, const blas_int* nrhs, const eT* a, const blas_int* lda, const blas_int* ipiv, eT* b, const blas_int* ldb, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_chetrs)(uplo, n, nrhs, (T*)a, lda, ipiv, (T*)b, ldb, info, 1); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_zhetrs)(uplo, n, nrhs, (T*)a, lda, ipiv, (T*)b, ldb, info, 1); } + #else + if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_chetrs)(uplo, n, nrhs, (T*)a, lda, ipiv, (T*)b, ldb, info); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_zhetrs)(uplo, n, nrhs, (T*)a, lda, ipiv, (T*)b, ldb, info); } + #endif + } + + + + template + inline + void + sytri(const char* uplo, const blas_int* n, eT* a, const blas_int* lda, const blas_int* ipiv, eT* work, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float T; arma_fortran(arma_ssytri)(uplo, n, (T*)a, lda, ipiv, (T*)work, info, 1); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dsytri)(uplo, n, (T*)a, lda, ipiv, (T*)work, info, 1); } + #else + if( is_float::value) { typedef float T; arma_fortran(arma_ssytri)(uplo, n, (T*)a, lda, ipiv, (T*)work, info); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dsytri)(uplo, n, (T*)a, lda, ipiv, (T*)work, info); } + #endif + } + + + + template + inline + void + hetri(const char* uplo, const blas_int* n, eT* a, const blas_int* lda, const blas_int* ipiv, eT* work, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_chetri)(uplo, n, (T*)a, lda, ipiv, (T*)work, info, 1); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_zhetri)(uplo, n, (T*)a, lda, ipiv, (T*)work, info, 1); } + #else + if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_chetri)(uplo, n, (T*)a, lda, ipiv, (T*)work, info); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_zhetri)(uplo, n, (T*)a, lda, ipiv, (T*)work, info); } + #endif + } + + + + + template + inline + void + sycon(const char* uplo, const blas_int* n, const eT* a, const blas_int* lda, const blas_int* ipiv, const eT* anorm, eT* rcond, eT* work, blas_int* iwork, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float T; arma_fortran(arma_ssycon)(uplo, n, (T*)a, lda, ipiv, (const T*)anorm, (T*)rcond, (T*)work, iwork, info, 1); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dsycon)(uplo, n, (T*)a, lda, ipiv, (const T*)anorm, (T*)rcond, (T*)work, iwork, info, 1); } + #else + if( is_float::value) { typedef float T; arma_fortran(arma_ssycon)(uplo, n, (T*)a, lda, ipiv, (const T*)anorm, (T*)rcond, (T*)work, iwork, info); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dsycon)(uplo, n, (T*)a, lda, ipiv, (const T*)anorm, (T*)rcond, (T*)work, iwork, info); } + #endif + } + + + + template + inline + void + hecon(const char* uplo, const blas_int* n, const std::complex* a, const blas_int* lda, const blas_int* ipiv, const T* anorm, T* rcond, std::complex* work, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float pod_T; typedef blas_cxf cx_T; arma_fortran(arma_checon)(uplo, n, (cx_T*)a, lda, ipiv, (const pod_T*)anorm, (pod_T*)rcond, (cx_T*)work, info, 1); } + else if(is_double::value) { typedef double pod_T; typedef blas_cxd cx_T; arma_fortran(arma_zhecon)(uplo, n, (cx_T*)a, lda, ipiv, (const pod_T*)anorm, (pod_T*)rcond, (cx_T*)work, info, 1); } + #else + if( is_float::value) { typedef float pod_T; typedef blas_cxf cx_T; arma_fortran(arma_checon)(uplo, n, (cx_T*)a, lda, ipiv, (const pod_T*)anorm, (pod_T*)rcond, (cx_T*)work, info); } + else if(is_double::value) { typedef double pod_T; typedef blas_cxd cx_T; arma_fortran(arma_zhecon)(uplo, n, (cx_T*)a, lda, ipiv, (const pod_T*)anorm, (pod_T*)rcond, (cx_T*)work, info); } + #endif + } + } diff --git a/Include/armadillo/armadillo_bits/typedef_elem.hpp b/Include/armadillo/armadillo_bits/typedef_elem.hpp index be7c75ef9..c5f2dbac4 100644 --- a/Include/armadillo/armadillo_bits/typedef_elem.hpp +++ b/Include/armadillo/armadillo_bits/typedef_elem.hpp @@ -107,12 +107,9 @@ typedef void* void_ptr; // -#if defined(ARMA_BLAS_LONG_LONG) +#if defined(ARMA_BLAS_64BIT_INT) typedef long long blas_int; #define ARMA_MAX_BLAS_INT 0x7fffffffffffffffULL -#elif defined(ARMA_BLAS_LONG) - typedef long blas_int; - #define ARMA_MAX_BLAS_INT 0x7fffffffffffffffUL #else typedef int blas_int; #define ARMA_MAX_BLAS_INT 0x7fffffffU diff --git a/Include/armadillo/armadillo_bits/unwrap.hpp b/Include/armadillo/armadillo_bits/unwrap.hpp index 7e9363179..b6be60bc8 100644 --- a/Include/armadillo/armadillo_bits/unwrap.hpp +++ b/Include/armadillo/armadillo_bits/unwrap.hpp @@ -1075,8 +1075,8 @@ struct unwrap_check_mixed< Mat > template inline unwrap_check_mixed(const Mat& A, const Mat& B) - : M_local( (void_ptr(&A) == void_ptr(&B)) ? new Mat(A) : nullptr ) - , M ( (void_ptr(&A) == void_ptr(&B)) ? (*M_local) : A ) + : M_local( ((is_same_type::yes) && (void_ptr(&A) == void_ptr(&B))) ? new Mat(A) : nullptr ) + , M ( ((is_same_type::yes) && (void_ptr(&A) == void_ptr(&B))) ? (*M_local) : A ) { arma_debug_sigprint(); } @@ -1112,8 +1112,8 @@ struct unwrap_check_mixed< Row > template inline unwrap_check_mixed(const Row& A, const Mat& B) - : M_local( (void_ptr(&A) == void_ptr(&B)) ? new Row(A) : nullptr ) - , M ( (void_ptr(&A) == void_ptr(&B)) ? (*M_local) : A ) + : M_local( ((is_same_type::yes) && (void_ptr(&A) == void_ptr(&B))) ? new Row(A) : nullptr ) + , M ( ((is_same_type::yes) && (void_ptr(&A) == void_ptr(&B))) ? (*M_local) : A ) { arma_debug_sigprint(); } @@ -1150,8 +1150,8 @@ struct unwrap_check_mixed< Col > template inline unwrap_check_mixed(const Col& A, const Mat& B) - : M_local( (void_ptr(&A) == void_ptr(&B)) ? new Col(A) : nullptr ) - , M ( (void_ptr(&A) == void_ptr(&B)) ? (*M_local) : A ) + : M_local( ((is_same_type::yes) && (void_ptr(&A) == void_ptr(&B))) ? new Col(A) : nullptr ) + , M ( ((is_same_type::yes) && (void_ptr(&A) == void_ptr(&B))) ? (*M_local) : A ) { arma_debug_sigprint(); } diff --git a/README.md b/README.md index 0a8d5c3bd..83fb0846b 100644 --- a/README.md +++ b/README.md @@ -295,7 +295,7 @@ Additional libraries used in **suanPan** are listed as follows. - [**VTK**](https://vtk.org/) version 9.2.6 - [**CUDA**](https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/) version 12.5 - [**MAGMA**](https://icl.utk.edu/magma/) version 2.8.0 -- [**Armadillo**](http://arma.sourceforge.net/) version 14.0.2 +- [**Armadillo**](http://arma.sourceforge.net/) version 14.2.0 - [**ensmallen**](https://ensmallen.org/) version 2.21.1 - [**oneMKL**](https://software.intel.com/content/www/us/en/develop/tools/oneapi/components/onemkl.html) version 2024.2.1 - [**Catch2**](https://github.com/catchorg/Catch2) version 3.7.1