diff --git a/SeQuant/core/optimize/fusion.cpp b/SeQuant/core/optimize/fusion.cpp index 64629828b..23b62892e 100644 --- a/SeQuant/core/optimize/fusion.cpp +++ b/SeQuant/core/optimize/fusion.cpp @@ -49,12 +49,13 @@ ExprPtr Fusion::fuse_left(Product const& lhs, Product const& rhs) { auto lsmand_prod = Product{lsmand.begin(), lsmand.end()}; auto rsmand_prod = Product{rsmand.begin(), rsmand.end()}; - if (lhs.scalar() == rhs.scalar()) - fac_prod.scale(lhs.scalar()); - else { - lsmand_prod.scale(lhs.scalar()); - rsmand_prod.scale(rhs.scalar()); - } + assert(lhs.scalar().imag().is_zero() && rhs.scalar().imag().is_zero() && + "Complex valued gcd not supported"); + auto scalars_fused = fuse_scalar(lhs.scalar().real(), rhs.scalar().real()); + + fac_prod.scale(scalars_fused.at(0)); + lsmand_prod.scale(scalars_fused.at(1)); + rsmand_prod.scale(scalars_fused.at(2)); // f (a + b) @@ -86,12 +87,13 @@ ExprPtr Fusion::fuse_right(Product const& lhs, Product const& rhs) { auto lsmand_prod = Product{lsmand.begin(), lsmand.end()}; auto rsmand_prod = Product{rsmand.begin(), rsmand.end()}; - if (lhs.scalar() == rhs.scalar()) - fac_prod.scale(lhs.scalar()); - else { - lsmand_prod.scale(lhs.scalar()); - rsmand_prod.scale(rhs.scalar()); - } + assert(lhs.scalar().imag().is_zero() && rhs.scalar().imag().is_zero() && + "Complex valued gcd not supported"); + auto scalars_fused = fuse_scalar(lhs.scalar().real(), rhs.scalar().real()); + + fac_prod.scale(scalars_fused.at(0)); + lsmand_prod.scale(scalars_fused.at(1)); + rsmand_prod.scale(scalars_fused.at(2)); // (a + b) f @@ -102,4 +104,29 @@ ExprPtr Fusion::fuse_right(Product const& lhs, Product const& rhs) { return ex(ExprPtrList{ex(ExprPtrList{a, b}), f}); } +rational Fusion::gcd_rational(rational const& left, rational const& right) { + auto&& r1 = left.real(); + auto&& r2 = right.real(); + auto&& n1 = numerator(r1); + auto&& d1 = denominator(r1); + auto&& n2 = numerator(r2); + auto&& d2 = denominator(r2); + + auto num = gcd(n1 * d2, n2 * d1); + return {num, d1 * d2}; +} + +std::array Fusion::fuse_scalar(rational const& left, + rational const& right) { + auto fused = gcd_rational(left, right); + rational left_fused = left / fused; + rational right_fused = right / fused; + if (left < 0 && right < 0) { + fused *= -1; + left_fused *= -1; + right_fused *= -1; + } + return {fused, left_fused, right_fused}; +} + } // namespace sequant::opt diff --git a/SeQuant/core/optimize/fusion.hpp b/SeQuant/core/optimize/fusion.hpp index 97c4cf869..210a9b120 100644 --- a/SeQuant/core/optimize/fusion.hpp +++ b/SeQuant/core/optimize/fusion.hpp @@ -36,6 +36,20 @@ class Fusion { static ExprPtr fuse_right(Product const& lhs, Product const& rhs); + /// + /// Get the greatest common divisor of two rational numbers. + /// + static rational gcd_rational(rational const& left, rational const& right); + + /// + /// Fuse scalars @param left and @param right and the return the result + /// as an array of three elements: first is the greatest common factor, + /// second the fused sub-factor of @param left and the third is that + /// of @param right. + /// + static std::array fuse_scalar(rational const& left, + rational const& right); + private: ExprPtr left_; diff --git a/tests/unit/test_fusion.cpp b/tests/unit/test_fusion.cpp index 6a80c836d..99ee9e823 100644 --- a/tests/unit/test_fusion.cpp +++ b/tests/unit/test_fusion.cpp @@ -1,17 +1,15 @@ #include -#include #include #include #include -#include #include #include #include #include -TEST_CASE("TEST_FUSION", "[Fusion]") { +TEST_CASE("TEST_FUSION", "[optimize]") { using sequant::opt::Fusion; using namespace sequant; std::vector> fused_terms{ @@ -23,7 +21,7 @@ TEST_CASE("TEST_FUSION", "[Fusion]") { {L"1/8 g{a1,a2;a3,a4} t{a3,a4;i1,i2}", L"1/4 g{a1,a2;a3,a4} t{a3;i1} t{a4;i2}", - L"g{a1,a2;a3,a4}(1/8 t{a3,a4;i1,i2} + 1/4 t{a3;i1} t{a4;i2})"}, + L"1/8 g{a1,a2;a3,a4}(t{a3,a4;i1,i2} + 2 t{a3;i1} t{a4;i2})"}, {L"1/4 g{a1,a2;a3,a4} t{a3;i1} t{a4;i2}", L"1/4 g{i3,i4;a3,a4} t{a1;i3} t{a2;i4} t{a3;i1} t{a4;i2}", @@ -32,13 +30,20 @@ TEST_CASE("TEST_FUSION", "[Fusion]") { {L"1/8 g{i3,i4;a3,a4} t{a1;i3} t{a2;i4} t{a3,a4;i1,i2}", L"1/4 g{i3,i4;a3,a4} t{a1;i3} t{a2;i4} t{a3;i1} t{a4;i2}", - L"g{i3,i4;a3,a4} t{a1;i3} t{a2;i4} " - L" (1/8 t{a3,a4;i1,i2} + 1/4 t{a3;i1} t{a4;i2})"}}; + L"1/8 g{i3,i4;a3,a4} t{a1;i3} t{a2;i4} " + L" (t{a3,a4;i1,i2} + 2 t{a3;i1} t{a4;i2})"}, + + {L"-1/8 g{i3,i4;a3,a4} t{a1;i3} t{a2;i4} t{a3,a4;i1,i2}", + L"-1/4 g{i3,i4;a3,a4} t{a1;i3} t{a2;i4} t{a3;i1} t{a4;i2}", + L"-1/8 g{i3,i4;a3,a4} t{a1;i3} t{a2;i4} " + L" (t{a3,a4;i1,i2} + 2 t{a3;i1} t{a4;i2})"} + + }; for (auto&& [l, r, f] : fused_terms) { - auto const le = parse_expr(l, Symmetry::nonsymm); - auto const re = parse_expr(r, Symmetry::nonsymm); - auto const fe = parse_expr(f, Symmetry::nonsymm); + auto const le = parse_expr(l); + auto const re = parse_expr(r); + auto const fe = parse_expr(f); auto fu = Fusion{le->as(), re->as()}; REQUIRE((fu.left() || fu.right()));