From f350a8fbdd2e1f6797e46900fe905292ae55d7c9 Mon Sep 17 00:00:00 2001 From: hwse <43220762+hwse@users.noreply.github.com> Date: Wed, 21 Aug 2024 08:43:13 +0200 Subject: [PATCH] AVRO-3984 [C++] Improve code generated for unions (#3047) * AVRO-3984 [C++] Getters created by avrogencpp return a reference instead of a value to avoid calling copy constructor of large classes * AVRO-3984 [C++] Add getter for generated unions that returns a mutable reference. This allows the user to modify values in union branches after creation (#3047) * AVRO-3984 [C++] Add move setters for generated unions to provide a more efficient way to set a value (#3047) * AVRO-3984 [C++] Use std::move in decode implementation of codec_traits for unions to avoid a copy (#3047) * AVRO-3984 [C++] Generate an enum for each union type that maps the branch names to the corresponding index. This allows the user to avoid checks against "magic numbers" (#3047) * AVRO-3984 [C++] Add additional checks for the union branch in testUnionMethods test (#3047) * AVRO-3984 [C++] Add additional branch() method that returns the Branch enum directly, this avoids a manual static_cast (#3047) --------- Co-authored-by: hwse --- lang/c++/CMakeLists.txt | 3 +- lang/c++/impl/avrogencpp.cc | 60 ++++++++++++++--- lang/c++/jsonschemas/big_union | 101 ++++++++++++++++++++++++++++ lang/c++/test/AvrogencppTests.cc | 110 +++++++++++++++++++++++++++++++ 4 files changed, 265 insertions(+), 9 deletions(-) create mode 100644 lang/c++/jsonschemas/big_union diff --git a/lang/c++/CMakeLists.txt b/lang/c++/CMakeLists.txt index 1b234e0411c..d4c494cefa5 100644 --- a/lang/c++/CMakeLists.txt +++ b/lang/c++/CMakeLists.txt @@ -181,6 +181,7 @@ gen (crossref cr) gen (primitivetypes pt) gen (cpp_reserved_words cppres) gen (cpp_reserved_words_union_typedef cppres_union) +gen (big_union big_union) add_executable (avrogencpp impl/avrogencpp.cc) target_link_libraries (avrogencpp avrocpp_s) @@ -226,7 +227,7 @@ add_dependencies (AvrogencppTests bigrecord_hh bigrecord_r_hh bigrecord2_hh union_array_union_hh union_map_union_hh union_conflict_hh recursive_hh reuse_hh circulardep_hh tree1_hh tree2_hh crossref_hh primitivetypes_hh empty_record_hh cpp_reserved_words_union_typedef_hh - union_empty_record_hh) + union_empty_record_hh big_union_hh) include (InstallRequiredSystemLibraries) diff --git a/lang/c++/impl/avrogencpp.cc b/lang/c++/impl/avrogencpp.cc index 39da7af3539..c02ea554e39 100644 --- a/lang/c++/impl/avrogencpp.cc +++ b/lang/c++/impl/avrogencpp.cc @@ -313,12 +313,21 @@ static void generateGetterAndSetter(ostream &os, os << "inline\n"; - os << type << sn << "get_" << name << "() const {\n" + os << "const " << type << "&" << sn << "get_" << name << "() const {\n" << " if (idx_ != " << idx << ") {\n" << " throw avro::Exception(\"Invalid type for " << "union " << structName << "\");\n" << " }\n" - << " return std::any_cast<" << type << " >(value_);\n" + << " return *std::any_cast<" << type << " >(&value_);\n" + << "}\n\n"; + + os << "inline\n" + << type << "&" << sn << "get_" << name << "() {\n" + << " if (idx_ != " << idx << ") {\n" + << " throw avro::Exception(\"Invalid type for " + << "union " << structName << "\");\n" + << " }\n" + << " return *std::any_cast<" << type << " >(&value_);\n" << "}\n\n"; os << "inline\n" @@ -327,6 +336,13 @@ static void generateGetterAndSetter(ostream &os, << " idx_ = " << idx << ";\n" << " value_ = v;\n" << "}\n\n"; + + os << "inline\n" + << "void" << sn << "set_" << name + << "(" << type << "&& v) {\n" + << " idx_ = " << idx << ";\n" + << " value_ = std::move(v);\n" + << "}\n\n"; } static void generateConstructor(ostream &os, @@ -376,8 +392,33 @@ string CodeGen::generateUnionType(const NodePtr &n) { << "private:\n" << " size_t idx_;\n" << " std::any value_;\n" - << "public:\n" - << " size_t idx() const { return idx_; }\n"; + << "public:\n"; + + os_ << " /** enum representing union branches as returned by the idx() function */\n" + << " enum class Branch: size_t {\n"; + + // generate a enum that maps the branch name to the corresponding index (as returned by idx()) + std::set used_branch_names; + for (size_t i = 0; i < c; ++i) { + // escape reserved literals for c++ + auto branch_name = decorate(names[i]); + // avoid rare collisions, e.g. somone might name their struct int_ + if (used_branch_names.find(branch_name) != used_branch_names.end()) { + size_t postfix = 2; + std::string escaped_name = branch_name + "_" + std::to_string(postfix); + while (used_branch_names.find(escaped_name) != used_branch_names.end()) { + ++postfix; + escaped_name = branch_name + "_" + std::to_string(postfix); + } + branch_name = escaped_name; + } + os_ << " " << branch_name << " = " << i << ",\n"; + used_branch_names.insert(branch_name); + } + os_ << " };\n"; + + os_ << " size_t idx() const { return idx_; }\n"; + os_ << " Branch branch() const { return static_cast(idx_); }\n"; for (size_t i = 0; i < c; ++i) { const NodePtr &nn = n->leafAt(i); @@ -392,9 +433,11 @@ string CodeGen::generateUnionType(const NodePtr &n) { } else { const string &type = types[i]; const string &name = names[i]; - os_ << " " << type << " get_" << name << "() const;\n" - " void set_" - << name << "(const " << type << "& v);\n"; + os_ << " " + << "const " << type << "& get_" << name << "() const;\n" + << " " << type << "& get_" << name << "();\n" + << " void set_" << name << "(const " << type << "& v);\n" + << " void set_" << name << "(" << type << "&& v);\n"; pendingGettersAndSetters.emplace_back(result, type, name, i); } } @@ -645,7 +688,7 @@ void CodeGen::generateUnionTraits(const NodePtr &n) { os_ << " {\n" << " " << cppTypeOf(nn) << " vv;\n" << " avro::decode(d, vv);\n" - << " v.set_" << cppNameOf(nn) << "(vv);\n" + << " v.set_" << cppNameOf(nn) << "(std::move(vv));\n" << " }\n"; } os_ << " break;\n"; @@ -730,6 +773,7 @@ void CodeGen::generate(const ValidSchema &schema) { os_ << "#include \n" << "#include \n" + << "#include \n" << "#include \"" << includePrefix_ << "Specific.hh\"\n" << "#include \"" << includePrefix_ << "Encoder.hh\"\n" << "#include \"" << includePrefix_ << "Decoder.hh\"\n" diff --git a/lang/c++/jsonschemas/big_union b/lang/c++/jsonschemas/big_union new file mode 100644 index 00000000000..34cced4493b --- /dev/null +++ b/lang/c++/jsonschemas/big_union @@ -0,0 +1,101 @@ +{ + "type": "record", + "doc": "Top level Doc.", + "name": "RootRecord", + "fields": [ + { + "name": "big_union", + "doc": "A large union containing the primitive types, a array, a map and records.", + "type": [ + "null", + "boolean", + "int", + "long", + "float", + "double", + { + "type": "fixed", + "size": 16, + "name": "MD5" + }, + "string", + { + "type": "record", + "name": "Vec2", + "fields": [ + { + "name": "x", + "type": "long" + }, + { + "name": "y", + "type": "long" + } + ] + }, + { + "type": "record", + "name": "Vec3", + "fields": [ + { + "name": "x", + "type": "long" + }, + { + "name": "y", + "type": "long" + }, + { + "name": "z", + "type": "long" + } + ] + }, + { + "type": "enum", + "name": "Suit", + "symbols": [ + "SPADES", + "HEARTS", + "DIAMONDS", + "CLUBS" + ] + }, + { + "type": "array", + "items": "string", + "default": [] + }, + { + "type": "map", + "values": "long", + "default": {} + }, + { + "type": "record", + "name": "int_", + "doc": "try to force a collision with int", + "fields": [] + }, + { + "type": "record", + "name": "int__", + "doc": "try to force a collision with int", + "fields": [] + }, + { + "type": "record", + "name": "Int", + "doc": "name similar to primitive name", + "fields": [] + }, + { + "type": "record", + "name": "_Int", + "doc": "name with underscore as prefix", + "fields": [] + } + ] + } + ] +} diff --git a/lang/c++/test/AvrogencppTests.cc b/lang/c++/test/AvrogencppTests.cc index d393e373dc8..e7d5df92726 100644 --- a/lang/c++/test/AvrogencppTests.cc +++ b/lang/c++/test/AvrogencppTests.cc @@ -17,6 +17,7 @@ */ #include "Compiler.hh" +#include "big_union.hh" #include "bigrecord.hh" #include "bigrecord_r.hh" #include "tweet.hh" @@ -132,6 +133,14 @@ void checkDefaultValues(const testgen_r::RootRecord &r) { BOOST_CHECK_EQUAL(r.byteswithDefaultValue.get_bytes()[1], 0xaa); } +// enable use of BOOST_CHECK_EQUAL +template<> +struct boost::test_tools::tt_detail::print_log_value { + void operator()(std::ostream &stream, const big_union::RootRecord::big_union_t::Branch &branch) const { + stream << "big_union_t::Branch{" << static_cast(branch) << "}"; + } +}; + void testEncoding() { ValidSchema s; ifstream ifs("jsonschemas/bigrecord"); @@ -300,6 +309,105 @@ void testEmptyRecord() { BOOST_CHECK_EQUAL(calc2.stack[2].idx(), 2); } +void testUnionMethods() { + ValidSchema schema; + ifstream ifs_w("jsonschemas/bigrecord"); + compileJsonSchema(ifs_w, schema); + + testgen::RootRecord record; + // initialize the map and set values with getter + record.myunion.set_map({}); + record.myunion.get_map()["zero"] = 0; + record.myunion.get_map()["one"] = 1; + + std::vector bytes{1, 2, 3, 4}; + record.anotherunion.set_bytes(std::move(bytes)); + // after move assignment the local variable should be empty + BOOST_CHECK(bytes.empty()); + + unique_ptr out_stream = memoryOutputStream(); + EncoderPtr encoder = validatingEncoder(schema, binaryEncoder()); + encoder->init(*out_stream); + avro::encode(*encoder, record); + encoder->flush(); + + DecoderPtr decoder = validatingDecoder(schema, binaryDecoder()); + unique_ptr is = memoryInputStream(*out_stream); + decoder->init(*is); + testgen::RootRecord decoded_record; + avro::decode(*decoder, decoded_record); + + // check that a reference can be obtained from a union + BOOST_CHECK(decoded_record.myunion.branch() == testgen::RootRecord::myunion_t::Branch::map); + const std::map &read_map = decoded_record.myunion.get_map(); + BOOST_CHECK_EQUAL(read_map.size(), 2); + BOOST_CHECK_EQUAL(read_map.at("zero"), 0); + BOOST_CHECK_EQUAL(read_map.at("one"), 1); + + BOOST_CHECK(decoded_record.anotherunion.branch() == testgen::RootRecord::anotherunion_t::Branch::bytes); + const std::vector read_bytes = decoded_record.anotherunion.get_bytes(); + const std::vector expected_bytes{1, 2, 3, 4}; + BOOST_CHECK_EQUAL_COLLECTIONS(read_bytes.begin(), read_bytes.end(), expected_bytes.begin(), expected_bytes.end()); +} + +void testUnionBranchEnum() { + big_union::RootRecord record; + + using Branch = big_union::RootRecord::big_union_t::Branch; + + BOOST_CHECK_EQUAL(record.big_union.branch(), Branch::null); + record.big_union.set_null(); + BOOST_CHECK_EQUAL(record.big_union.branch(), Branch::null); + + record.big_union.set_bool(false); + BOOST_CHECK_EQUAL(record.big_union.branch(), Branch::bool_); + + record.big_union.set_int(123); + BOOST_CHECK_EQUAL(record.big_union.branch(), Branch::int_); + + record.big_union.set_long(456); + BOOST_CHECK_EQUAL(record.big_union.branch(), Branch::long_); + + record.big_union.set_float(555.555f); + BOOST_CHECK_EQUAL(record.big_union.branch(), Branch::float_); + + record.big_union.set_double(777.777); + BOOST_CHECK_EQUAL(record.big_union.branch(), Branch::double_); + + record.big_union.set_MD5({}); + BOOST_CHECK_EQUAL(record.big_union.branch(), Branch::MD5); + + record.big_union.set_string("test"); + BOOST_CHECK_EQUAL(record.big_union.branch(), Branch::string); + + record.big_union.set_Vec2({}); + BOOST_CHECK_EQUAL(record.big_union.branch(), Branch::Vec2); + + record.big_union.set_Vec3({}); + BOOST_CHECK_EQUAL(record.big_union.branch(), Branch::Vec3); + + record.big_union.set_Suit(big_union::Suit::CLUBS); + BOOST_CHECK_EQUAL(record.big_union.branch(), Branch::Suit); + + record.big_union.set_array({}); + BOOST_CHECK_EQUAL(record.big_union.branch(), Branch::array); + + record.big_union.set_map({}); + BOOST_CHECK_EQUAL(record.big_union.branch(), Branch::map); + + record.big_union.set_int_({}); + BOOST_CHECK_EQUAL(record.big_union.branch(), Branch::int__2); + + record.big_union.set_int__({}); + BOOST_CHECK_EQUAL(record.big_union.branch(), Branch::int__); + + record.big_union.set_Int({}); + BOOST_CHECK_EQUAL(record.big_union.branch(), Branch::Int); + + record.big_union.set__Int({}); + BOOST_CHECK_EQUAL(record.big_union.branch(), Branch::_Int); +} + boost::unit_test::test_suite *init_unit_test_suite(int /*argc*/, char * /*argv*/[]) { auto *ts = BOOST_TEST_SUITE("Code generator tests"); ts->add(BOOST_TEST_CASE(testEncoding)); @@ -308,5 +416,7 @@ boost::unit_test::test_suite *init_unit_test_suite(int /*argc*/, char * /*argv*/ ts->add(BOOST_TEST_CASE(testEncoding2)); ts->add(BOOST_TEST_CASE(testNamespace)); ts->add(BOOST_TEST_CASE(testEmptyRecord)); + ts->add(BOOST_TEST_CASE(testUnionMethods)); + ts->add(BOOST_TEST_CASE(testUnionBranchEnum)); return ts; }