From 3f580b6e9813955f2968c8741ba1c0778b1998a8 Mon Sep 17 00:00:00 2001 From: tmadlener Date: Wed, 24 Jul 2024 11:39:59 +0200 Subject: [PATCH] Add AssociationNavigator utility class and tests --- include/podio/AssociationNavigator.h | 106 ++++++++++++++++++ .../podio/detail/AssociationCollectionImpl.h | 2 + tests/unittests/associations.cpp | 49 ++++++++ 3 files changed, 157 insertions(+) create mode 100644 include/podio/AssociationNavigator.h diff --git a/include/podio/AssociationNavigator.h b/include/podio/AssociationNavigator.h new file mode 100644 index 000000000..fccd2a328 --- /dev/null +++ b/include/podio/AssociationNavigator.h @@ -0,0 +1,106 @@ +#ifndef PODIO_ASSOCIATIONNAVIGATOR_H +#define PODIO_ASSOCIATIONNAVIGATOR_H + +#include "podio/detail/AssociationFwd.h" + +#include +#include +#include +#include + +namespace podio { + +namespace detail::associations { + /// A small struct that simply bundles an object and its weight for a more + /// convenient return value for the AssociationNavigator + /// + /// @note In most uses the names of the members should not really matter as it + /// is possible to us this via structured bindings + template + struct WeightedObject { + WeightedObject(T obj, float w) : o(obj), weight(w) { + } + T o; ///< The object + float weight; ///< The weight in the association + }; +} // namespace detail::associations + +/// A helper class to more easily handle one-to-many associations. +/// +/// Internally simply populates two maps in its constructor and then queries +/// them to retrieve objects that are associated with another. +/// +/// @note There are no guarantees on the order of the objects in these maps. +/// Hence, there are also no guarantees on the order of the returned objects, +/// even if there inherintly is an order to them in the underlying associations +/// collection. +template +class AssociationNavigator { + using FromT = AssociationCollT::from_type; + using ToT = AssociationCollT::to_type; + + template + using WeightedObject = detail::associations::WeightedObject; + +public: + /// Construct a navigator from an association collection + AssociationNavigator(const AssociationCollT& associations); + + /// We do only construct from a collection + AssociationNavigator() = delete; + AssociationNavigator(const AssociationNavigator&) = default; + AssociationNavigator& operator=(const AssociationNavigator&) = default; + AssociationNavigator(AssociationNavigator&&) = default; + AssociationNavigator& operator=(AssociationNavigator&&) = default; + ~AssociationNavigator() = default; + + /// Get all the objects and weights that are associated to the passed object + /// + /// @param object The object that is labeled *to* in the association + /// + /// @returns A vector of all objects and their weights that are associated to + /// the passed object + std::vector> getAssociated(const ToT& object) const { + const auto& [begin, end] = m_to2from.equal_range(object); + std::vector> result; + result.reserve(std::distance(begin, end)); + + for (auto it = begin; it != end; ++it) { + result.emplace_back(it->second); + } + return result; + } + + /// Get all the objects and weights that are associated to the passed object + /// + /// @param object The object that is labeled *from* in the association + /// + /// @returns A vector of all objects and their weights that are associated to + /// the passed object + std::vector> getAssociated(const FromT& object) const { + const auto& [begin, end] = m_from2to.equal_range(object); + std::vector> result; + result.reserve(std::distance(begin, end)); + + for (auto it = begin; it != end; ++it) { + result.emplace_back(it->second); + } + return result; + } + +private: + std::multimap> m_from2to; ///< Map the from to the to objects + std::multimap> m_to2from; ///< Map the to to the from objects +}; + +template +AssociationNavigator::AssociationNavigator(const AssociationCollT& associations) { + for (const auto& [from, to, weight] : associations) { + m_from2to.emplace(std::piecewise_construct, std::forward_as_tuple(from), std::forward_as_tuple(to, weight)); + m_to2from.emplace(std::piecewise_construct, std::forward_as_tuple(to), std::forward_as_tuple(from, weight)); + } +} + +} // namespace podio + +#endif // PODIO_ASSOCIATIONNAVIGATOR_H diff --git a/include/podio/detail/AssociationCollectionImpl.h b/include/podio/detail/AssociationCollectionImpl.h index 3123b9707..481083bf4 100644 --- a/include/podio/detail/AssociationCollectionImpl.h +++ b/include/podio/detail/AssociationCollectionImpl.h @@ -40,6 +40,8 @@ class AssociationCollection : public podio::CollectionBase { using CollectionDataT = podio::AssociationCollectionData; public: + using from_type = FromT; + using to_type = ToT; using value_type = Association; using mutable_type = MutableAssociation; using const_iterator = AssociationCollectionIterator; diff --git a/tests/unittests/associations.cpp b/tests/unittests/associations.cpp index 7ff03d9f8..b651edd47 100644 --- a/tests/unittests/associations.cpp +++ b/tests/unittests/associations.cpp @@ -1,6 +1,7 @@ #include "catch2/catch_test_macros.hpp" #include "podio/AssociationCollection.h" +#include "podio/AssociationNavigator.h" #include "datamodel/ExampleClusterCollection.h" #include "datamodel/ExampleHitCollection.h" @@ -396,3 +397,51 @@ TEST_CASE("AssociationCollection movability", "[associations][move-semantics][co REQUIRE(evenNewerAssocs.isSubsetCollection()); } } + +TEST_CASE("AssociationNavigator basics", "[asssociations]") { + TestAColl coll{}; + std::vector hits(11); + std::vector clusters(3); + + for (size_t i = 0; i < 10; ++i) { + auto a = coll.create(); + a.set(hits[i]); + a.set(clusters[i % 3]); + a.setWeight(i * 0.1f); + } + + auto a = coll.create(); + a.set(hits[10]); + + podio::AssociationNavigator nav{coll}; + + for (size_t i = 0; i < 10; ++i) { + const auto& hit = hits[i]; + const auto assocClusters = nav.getAssociated(hit); + REQUIRE(assocClusters.size() == 1); + const auto& [cluster, weight] = assocClusters[0]; + REQUIRE(cluster == clusters[i % 3]); + REQUIRE(weight == i * 0.1f); + } + + const auto& cluster1 = clusters[0]; + auto assocHits = nav.getAssociated(cluster1); + REQUIRE(assocHits.size() == 4); + for (size_t i = 0; i < 4; ++i) { + const auto& [hit, weight] = assocHits[i]; + REQUIRE(hit == hits[i * 3]); + REQUIRE(weight == i * 3 * 0.1f); + } + + const auto& cluster2 = clusters[1]; + assocHits = nav.getAssociated(cluster2); + REQUIRE(assocHits.size() == 3); + for (size_t i = 0; i < 3; ++i) { + const auto& [hit, weight] = assocHits[i]; + REQUIRE(hit == hits[i * 3 + 1]); + REQUIRE(weight == (i * 3 + 1) * 0.1f); + } + + const auto [noCluster, noWeight] = nav.getAssociated(hits[10])[0]; + REQUIRE_FALSE(noCluster.isAvailable()); +}