From 903ed2acb95eff18d704aa8d09d637fd06fc2b37 Mon Sep 17 00:00:00 2001 From: MikePopoloski Date: Sat, 5 Oct 2024 15:25:23 -0400 Subject: [PATCH] Allow unpacked unions in equality and conditional expressions --- source/ast/expressions/OperatorExpressions.cpp | 14 ++++++++++++-- tests/unittests/ast/EvalTests.cpp | 3 +++ tests/unittests/ast/ExpressionTests.cpp | 11 ++++------- 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/source/ast/expressions/OperatorExpressions.cpp b/source/ast/expressions/OperatorExpressions.cpp index 4d18ff71c..d578ebadf 100644 --- a/source/ast/expressions/OperatorExpressions.cpp +++ b/source/ast/expressions/OperatorExpressions.cpp @@ -847,7 +847,7 @@ Expression& BinaryExpression::fromComponents(Expression& lhs, Expression& rhs, B contextDetermined(context, result->right_, result, compilation.getStringType(), opRange); } - else if (lt->isAggregate() && lt->isEquivalent(*rt) && !lt->isUnpackedUnion()) { + else if (lt->isAggregate() && lt->isEquivalent(*rt)) { good = !isWildcard; result->type = singleBitType(compilation, lt, rt); } @@ -1342,7 +1342,7 @@ Expression& ConditionalExpression::fromSyntax(Compilation& comp, else good = false; } - else if (lt->isEquivalent(*rt) && !lt->isUnpackedUnion()) { + else if (lt->isEquivalent(*rt)) { result->type = lt; } else if (left.isImplicitString() && right.isImplicitString()) { @@ -2615,6 +2615,16 @@ ConstantValue Expression::evalBinaryOperator(BinaryOperator op, const ConstantVa SLANG_UNREACHABLE; } } + else if (cvl.isUnion() && cvr.isUnion()) { + switch (op) { + OP(Equality, SVInt(cvl == cvr)); + OP(Inequality, SVInt(cvl != cvr)); + OP(CaseEquality, SVInt(cvl == cvr)); + OP(CaseInequality, SVInt(cvl != cvr)); + default: + SLANG_UNREACHABLE; + } + } #undef OP SLANG_UNREACHABLE; diff --git a/tests/unittests/ast/EvalTests.cpp b/tests/unittests/ast/EvalTests.cpp index 233a3e551..48833babb 100644 --- a/tests/unittests/ast/EvalTests.cpp +++ b/tests/unittests/ast/EvalTests.cpp @@ -841,6 +841,9 @@ union { session.eval("baz.c = 123;"); CHECK(session.eval("baz.a.s1").integer() == 123); + CHECK(session.eval("foo == bar").integer() == 1); + CHECK(session.eval("1 ? foo : bar").toString() == "(0) [3,4,42]"); + NO_SESSION_ERRORS; } diff --git a/tests/unittests/ast/ExpressionTests.cpp b/tests/unittests/ast/ExpressionTests.cpp index a98f19a97..204a8107a 100644 --- a/tests/unittests/ast/ExpressionTests.cpp +++ b/tests/unittests/ast/ExpressionTests.cpp @@ -336,12 +336,12 @@ TEST_CASE("Expression types") { // Unpacked unions declare("union { int i; real r; } uu1, uu2;"); - CHECK(typeof("uu1 == uu2") == ""); - CHECK(typeof("uu1 !== uu2") == ""); - CHECK(typeof("1 ? uu1 : uu2") == ""); + CHECK(typeof("uu1 == uu2") == "bit"); + CHECK(typeof("uu1 !== uu2") == "bit"); + CHECK(typeof("1 ? uu1 : uu2") == "union{int i;real r;}u$3"); auto diags = filterWarnings(compilation.getAllDiagnostics()); - REQUIRE(diags.size() == 11); + REQUIRE(diags.size() == 8); CHECK(diags[0].code == diag::BadUnaryExpression); CHECK(diags[1].code == diag::BadBinaryExpression); CHECK(diags[2].code == diag::BadBinaryExpression); @@ -350,9 +350,6 @@ TEST_CASE("Expression types") { CHECK(diags[5].code == diag::BadBinaryExpression); CHECK(diags[6].code == diag::BadConditionalExpression); CHECK(diags[7].code == diag::NotBooleanConvertible); - CHECK(diags[8].code == diag::BadBinaryExpression); - CHECK(diags[9].code == diag::BadBinaryExpression); - CHECK(diags[10].code == diag::BadConditionalExpression); } TEST_CASE("Expression - bad name references") {