diff --git a/tests/unit/test_eval_node.cpp b/tests/unit/test_eval_node.cpp index c475afed2..36cb6df7c 100644 --- a/tests/unit/test_eval_node.cpp +++ b/tests/unit/test_eval_node.cpp @@ -3,13 +3,9 @@ #include #include -namespace { +#include "utils.hpp" -// validates if x is constructible from tspec using parse_expr -auto validate_tensor = [](const auto& x, std::wstring_view tspec) -> bool { - return x.to_latex() == - sequant::parse_expr(tspec, sequant::Symmetry::antisymm)->to_latex(); -}; +namespace { auto eval_node(sequant::ExprPtr const& expr) { return sequant::eval_node(expr); @@ -60,24 +56,21 @@ TEST_CASE("TEST EVAL_NODE", "[EvalNode]") { auto node1 = eval_node(p1); - REQUIRE(validate_tensor(node(node1, {}).as_tensor(), L"I_{a1,a2}^{i1,i2}")); + REQUIRE_TENSOR_EQUAL(node(node1, {}).as_tensor(), L"I_{a1,a2}^{i1,i2}"); REQUIRE(node(node1, {R}).as_constant() == Constant{rational{1, 16}}); - REQUIRE( - validate_tensor(node(node1, {L}).as_tensor(), L"I_{a1,a2}^{i1,i2}")); + REQUIRE_TENSOR_EQUAL(node(node1, {L}).as_tensor(), L"I_{a1,a2}^{i1,i2}"); - REQUIRE( - validate_tensor(node(node1, {L, L}).as_tensor(), L"I_{a1,a2}^{a3,a4}")); + REQUIRE_TENSOR_EQUAL(node(node1, {L, L}).as_tensor(), L"I_{a1,a2}^{a3,a4}"); - REQUIRE( - validate_tensor(node(node1, {L, R}).as_tensor(), L"t_{a3,a4}^{i1,i2}")); + REQUIRE_TENSOR_EQUAL(node(node1, {L, R}).as_tensor(), L"t_{a3,a4}^{i1,i2}"); - REQUIRE(validate_tensor(node(node1, {L, L, L}).as_tensor(), - L"g_{i3,i4}^{a3,a4}")); + REQUIRE_TENSOR_EQUAL(node(node1, {L, L, L}).as_tensor(), + L"g_{i3,i4}^{a3,a4}"); - REQUIRE(validate_tensor(node(node1, {L, L, R}).as_tensor(), - L"t_{a1,a2}^{i3,i4}")); + REQUIRE_TENSOR_EQUAL(node(node1, {L, L, R}).as_tensor(), + L"t_{a1,a2}^{i3,i4}"); // 1/16 * A * (B * C) auto node2p = Product{p1->as().scalar(), {}}; @@ -87,24 +80,24 @@ TEST_CASE("TEST EVAL_NODE", "[EvalNode]") { auto const node2 = eval_node(ex(node2p)); - REQUIRE(validate_tensor(node(node2, {}).as_tensor(), L"I_{a1,a2}^{i1,i2}")); + REQUIRE_TENSOR_EQUAL(node(node2, {}).as_tensor(), L"I_{a1,a2}^{i1,i2}"); - REQUIRE( - validate_tensor(node(node2, {L}).as_tensor(), L"I_{a1,a2}^{i1,i2}")); + REQUIRE_TENSOR_EQUAL( + node(node2, {L}).as_tensor(), L"I_{a1,a2}^{i1,i2}"); REQUIRE(node(node2, {R}).as_constant() == Constant{rational{1, 16}}); - REQUIRE( - validate_tensor(node(node2, {L, L}).as_tensor(), L"g{i3,i4; a3,a4}")); + REQUIRE_TENSOR_EQUAL( + node(node2, {L, L}).as_tensor(), L"g{i3,i4; a3,a4}"); - REQUIRE(validate_tensor(node(node2, {L, R}).as_tensor(), - L"I{a1,a2,a3,a4;i3,i4,i1,i2}")); + REQUIRE_TENSOR_EQUAL(node(node2, {L, R}).as_tensor(), + L"I{a1,a2,a3,a4;i3,i4,i1,i2}"); - REQUIRE( - validate_tensor(node(node2, {L, R, L}).as_tensor(), L"t{a1,a2;i3,i4}")); + REQUIRE_TENSOR_EQUAL( + node(node2, {L, R, L}).as_tensor(), L"t{a1,a2;i3,i4}"); - REQUIRE( - validate_tensor(node(node2, {L, R, R}).as_tensor(), L"t{a3,a4;i1,i2}")); + REQUIRE_TENSOR_EQUAL( + node(node2, {L, R, R}).as_tensor(), L"t{a3,a4;i1,i2}"); } SECTION("sum") { @@ -116,20 +109,22 @@ TEST_CASE("TEST EVAL_NODE", "[EvalNode]") { auto const node1 = eval_node(sum1); REQUIRE(node1->op_type() == EvalOp::Sum); REQUIRE(node1.left()->op_type() == EvalOp::Sum); - REQUIRE(validate_tensor(node1.left()->as_tensor(), L"I^{i1,i2}_{a1,a2}")); - REQUIRE(validate_tensor(node1.left().left()->as_tensor(), - L"X^{i1,i2}_{a1,a2}")); - REQUIRE(validate_tensor(node1.left().right()->as_tensor(), - L"Y^{i1,i2}_{a1,a2}")); + REQUIRE_TENSOR_EQUAL(node1.left()->as_tensor(), L"I^{i1,i2}_{a1,a2}"); + REQUIRE_TENSOR_EQUAL(node1.left().left()->as_tensor(), + L"X^{i1,i2}_{a1,a2}"); + REQUIRE_TENSOR_EQUAL(node1.left().right()->as_tensor(), + L"Y^{i1,i2}_{a1,a2}"); REQUIRE(node1.right()->op_type() == EvalOp::Prod); - REQUIRE( - (validate_tensor(node1.right()->as_tensor(), L"I_{a2,a1}^{i1,i2}") || - validate_tensor(node1.right()->as_tensor(), L"I_{a1,a2}^{i2,i1}"))); - REQUIRE(validate_tensor(node1.right().left()->as_tensor(), - L"g_{i3,a1}^{i1,i2}")); - REQUIRE( - validate_tensor(node1.right().right()->as_tensor(), L"t_{a2}^{i3}")); + if constexpr (hash_version() == hash::Impl::BoostPre181) { + REQUIRE_TENSOR_EQUAL(node1.right()->as_tensor(), L"I_{a2,a1}^{i1,i2}"); + } else { + REQUIRE_TENSOR_EQUAL(node1.right()->as_tensor(), L"I_{a1,a2}^{i1,i2}"); + } + REQUIRE_TENSOR_EQUAL(node1.right().left()->as_tensor(), + L"g_{i3,a1}^{i1,i2}"); + REQUIRE_TENSOR_EQUAL( + node1.right().right()->as_tensor(), L"t_{a2}^{i3}"); } SECTION("variable") { @@ -161,8 +156,8 @@ TEST_CASE("TEST EVAL_NODE", "[EvalNode]") { auto prod2 = parse_expr(L"a * t{i1;a1}"); auto node3 = eval_node(prod2); - REQUIRE(validate_tensor(node(node3, {}), L"I{i1;a1}")); - REQUIRE(validate_tensor(node(node3, {R}), L"t{i1;a1}")); + REQUIRE_TENSOR_EQUAL(node(node3, {}), L"I{i1;a1}"); + REQUIRE_TENSOR_EQUAL(node(node3, {R}), L"t{i1;a1}"); REQUIRE(node(node3, {L}).as_variable() == Variable{L"a"}); } diff --git a/tests/unit/utils.hpp b/tests/unit/utils.hpp index 8b1aa4483..27c2e4f38 100644 --- a/tests/unit/utils.hpp +++ b/tests/unit/utils.hpp @@ -24,4 +24,9 @@ } \ REQUIRE(to_latex(sum) == std::wstring(str)) +#define REQUIRE_TENSOR_EQUAL(tensor, spec) \ + REQUIRE(sequant::to_latex(tensor) == \ + sequant::to_latex( \ + sequant::parse_expr(spec, sequant::Symmetry::antisymm))); + #endif