Skip to content

Commit

Permalink
Ensure failing eval node tests print something useful
Browse files Browse the repository at this point in the history
  • Loading branch information
Krzmbrzl committed Feb 6, 2024
1 parent 113ef4b commit 5dc009d
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 42 deletions.
79 changes: 37 additions & 42 deletions tests/unit/test_eval_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,9 @@
#include <SeQuant/core/eval_node.hpp>
#include <SeQuant/core/parse_expr.hpp>

namespace {
#include "utils.hpp"

// validates if x is constructible from tspec using parse_expr
auto validate_tensor = [](const auto& x, std::wstring_view tspec) -> bool {
return x.to_latex() ==
sequant::parse_expr(tspec, sequant::Symmetry::antisymm)->to_latex();
};
namespace {

auto eval_node(sequant::ExprPtr const& expr) {
return sequant::eval_node<sequant::EvalExpr>(expr);
Expand Down Expand Up @@ -60,24 +56,21 @@ TEST_CASE("TEST EVAL_NODE", "[EvalNode]") {

auto node1 = eval_node(p1);

REQUIRE(validate_tensor(node(node1, {}).as_tensor(), L"I_{a1,a2}^{i1,i2}"));
REQUIRE_TENSOR_EQUAL(node(node1, {}).as_tensor(), L"I_{a1,a2}^{i1,i2}");

REQUIRE(node(node1, {R}).as_constant() == Constant{rational{1, 16}});

REQUIRE(
validate_tensor(node(node1, {L}).as_tensor(), L"I_{a1,a2}^{i1,i2}"));
REQUIRE_TENSOR_EQUAL(node(node1, {L}).as_tensor(), L"I_{a1,a2}^{i1,i2}");

REQUIRE(
validate_tensor(node(node1, {L, L}).as_tensor(), L"I_{a1,a2}^{a3,a4}"));
REQUIRE_TENSOR_EQUAL(node(node1, {L, L}).as_tensor(), L"I_{a1,a2}^{a3,a4}");

REQUIRE(
validate_tensor(node(node1, {L, R}).as_tensor(), L"t_{a3,a4}^{i1,i2}"));
REQUIRE_TENSOR_EQUAL(node(node1, {L, R}).as_tensor(), L"t_{a3,a4}^{i1,i2}");

REQUIRE(validate_tensor(node(node1, {L, L, L}).as_tensor(),
L"g_{i3,i4}^{a3,a4}"));
REQUIRE_TENSOR_EQUAL(node(node1, {L, L, L}).as_tensor(),
L"g_{i3,i4}^{a3,a4}");

REQUIRE(validate_tensor(node(node1, {L, L, R}).as_tensor(),
L"t_{a1,a2}^{i3,i4}"));
REQUIRE_TENSOR_EQUAL(node(node1, {L, L, R}).as_tensor(),
L"t_{a1,a2}^{i3,i4}");

// 1/16 * A * (B * C)
auto node2p = Product{p1->as<Product>().scalar(), {}};
Expand All @@ -87,24 +80,24 @@ TEST_CASE("TEST EVAL_NODE", "[EvalNode]") {

auto const node2 = eval_node(ex<Product>(node2p));

REQUIRE(validate_tensor(node(node2, {}).as_tensor(), L"I_{a1,a2}^{i1,i2}"));
REQUIRE_TENSOR_EQUAL(node(node2, {}).as_tensor(), L"I_{a1,a2}^{i1,i2}");

REQUIRE(
validate_tensor(node(node2, {L}).as_tensor(), L"I_{a1,a2}^{i1,i2}"));
REQUIRE_TENSOR_EQUAL(
node(node2, {L}).as_tensor(), L"I_{a1,a2}^{i1,i2}");

REQUIRE(node(node2, {R}).as_constant() == Constant{rational{1, 16}});

REQUIRE(
validate_tensor(node(node2, {L, L}).as_tensor(), L"g{i3,i4; a3,a4}"));
REQUIRE_TENSOR_EQUAL(
node(node2, {L, L}).as_tensor(), L"g{i3,i4; a3,a4}");

REQUIRE(validate_tensor(node(node2, {L, R}).as_tensor(),
L"I{a1,a2,a3,a4;i3,i4,i1,i2}"));
REQUIRE_TENSOR_EQUAL(node(node2, {L, R}).as_tensor(),
L"I{a1,a2,a3,a4;i3,i4,i1,i2}");

REQUIRE(
validate_tensor(node(node2, {L, R, L}).as_tensor(), L"t{a1,a2;i3,i4}"));
REQUIRE_TENSOR_EQUAL(
node(node2, {L, R, L}).as_tensor(), L"t{a1,a2;i3,i4}");

REQUIRE(
validate_tensor(node(node2, {L, R, R}).as_tensor(), L"t{a3,a4;i1,i2}"));
REQUIRE_TENSOR_EQUAL(
node(node2, {L, R, R}).as_tensor(), L"t{a3,a4;i1,i2}");
}

SECTION("sum") {
Expand All @@ -116,20 +109,22 @@ TEST_CASE("TEST EVAL_NODE", "[EvalNode]") {
auto const node1 = eval_node(sum1);
REQUIRE(node1->op_type() == EvalOp::Sum);
REQUIRE(node1.left()->op_type() == EvalOp::Sum);
REQUIRE(validate_tensor(node1.left()->as_tensor(), L"I^{i1,i2}_{a1,a2}"));
REQUIRE(validate_tensor(node1.left().left()->as_tensor(),
L"X^{i1,i2}_{a1,a2}"));
REQUIRE(validate_tensor(node1.left().right()->as_tensor(),
L"Y^{i1,i2}_{a1,a2}"));
REQUIRE_TENSOR_EQUAL(node1.left()->as_tensor(), L"I^{i1,i2}_{a1,a2}");
REQUIRE_TENSOR_EQUAL(node1.left().left()->as_tensor(),
L"X^{i1,i2}_{a1,a2}");
REQUIRE_TENSOR_EQUAL(node1.left().right()->as_tensor(),
L"Y^{i1,i2}_{a1,a2}");

REQUIRE(node1.right()->op_type() == EvalOp::Prod);
REQUIRE(
(validate_tensor(node1.right()->as_tensor(), L"I_{a2,a1}^{i1,i2}") ||
validate_tensor(node1.right()->as_tensor(), L"I_{a1,a2}^{i2,i1}")));
REQUIRE(validate_tensor(node1.right().left()->as_tensor(),
L"g_{i3,a1}^{i1,i2}"));
REQUIRE(
validate_tensor(node1.right().right()->as_tensor(), L"t_{a2}^{i3}"));
if constexpr (hash_version() == hash::Impl::BoostPre181) {
REQUIRE_TENSOR_EQUAL(node1.right()->as_tensor(), L"I_{a2,a1}^{i1,i2}");
} else {
REQUIRE_TENSOR_EQUAL(node1.right()->as_tensor(), L"I_{a1,a2}^{i1,i2}");
}
REQUIRE_TENSOR_EQUAL(node1.right().left()->as_tensor(),
L"g_{i3,a1}^{i1,i2}");
REQUIRE_TENSOR_EQUAL(
node1.right().right()->as_tensor(), L"t_{a2}^{i3}");
}

SECTION("variable") {
Expand Down Expand Up @@ -161,8 +156,8 @@ TEST_CASE("TEST EVAL_NODE", "[EvalNode]") {

auto prod2 = parse_expr(L"a * t{i1;a1}");
auto node3 = eval_node(prod2);
REQUIRE(validate_tensor(node(node3, {}), L"I{i1;a1}"));
REQUIRE(validate_tensor(node(node3, {R}), L"t{i1;a1}"));
REQUIRE_TENSOR_EQUAL(node(node3, {}), L"I{i1;a1}");
REQUIRE_TENSOR_EQUAL(node(node3, {R}), L"t{i1;a1}");
REQUIRE(node(node3, {L}).as_variable() == Variable{L"a"});
}

Expand Down
5 changes: 5 additions & 0 deletions tests/unit/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,9 @@
} \
REQUIRE(to_latex(sum) == std::wstring(str))

#define REQUIRE_TENSOR_EQUAL(tensor, spec) \
REQUIRE(sequant::to_latex(tensor) == \
sequant::to_latex( \
sequant::parse_expr(spec, sequant::Symmetry::antisymm)));

#endif

0 comments on commit 5dc009d

Please sign in to comment.