Skip to content

Commit

Permalink
Add AssociationNavigator utility class and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tmadlener committed Jul 24, 2024
1 parent d7886b8 commit 3f580b6
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 0 deletions.
106 changes: 106 additions & 0 deletions include/podio/AssociationNavigator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
#ifndef PODIO_ASSOCIATIONNAVIGATOR_H
#define PODIO_ASSOCIATIONNAVIGATOR_H

#include "podio/detail/AssociationFwd.h"

#include <map>
#include <tuple>
#include <utility>
#include <vector>

namespace podio {

namespace detail::associations {
/// A small struct that simply bundles an object and its weight for a more
/// convenient return value for the AssociationNavigator
///
/// @note In most uses the names of the members should not really matter as it
/// is possible to us this via structured bindings
template <typename T>
struct WeightedObject {
WeightedObject(T obj, float w) : o(obj), weight(w) {
}
T o; ///< The object
float weight; ///< The weight in the association
};
} // namespace detail::associations

/// A helper class to more easily handle one-to-many associations.
///
/// Internally simply populates two maps in its constructor and then queries
/// them to retrieve objects that are associated with another.
///
/// @note There are no guarantees on the order of the objects in these maps.
/// Hence, there are also no guarantees on the order of the returned objects,
/// even if there inherintly is an order to them in the underlying associations
/// collection.
template <typename AssociationCollT>
class AssociationNavigator {
using FromT = AssociationCollT::from_type;
using ToT = AssociationCollT::to_type;

template <typename T>
using WeightedObject = detail::associations::WeightedObject<T>;

public:
/// Construct a navigator from an association collection
AssociationNavigator(const AssociationCollT& associations);

/// We do only construct from a collection
AssociationNavigator() = delete;
AssociationNavigator(const AssociationNavigator&) = default;
AssociationNavigator& operator=(const AssociationNavigator&) = default;
AssociationNavigator(AssociationNavigator&&) = default;
AssociationNavigator& operator=(AssociationNavigator&&) = default;
~AssociationNavigator() = default;

/// Get all the objects and weights that are associated to the passed object
///
/// @param object The object that is labeled *to* in the association
///
/// @returns A vector of all objects and their weights that are associated to
/// the passed object
std::vector<WeightedObject<FromT>> getAssociated(const ToT& object) const {
const auto& [begin, end] = m_to2from.equal_range(object);
std::vector<WeightedObject<FromT>> result;
result.reserve(std::distance(begin, end));

for (auto it = begin; it != end; ++it) {
result.emplace_back(it->second);
}
return result;
}

/// Get all the objects and weights that are associated to the passed object
///
/// @param object The object that is labeled *from* in the association
///
/// @returns A vector of all objects and their weights that are associated to
/// the passed object
std::vector<WeightedObject<ToT>> getAssociated(const FromT& object) const {
const auto& [begin, end] = m_from2to.equal_range(object);
std::vector<WeightedObject<ToT>> result;
result.reserve(std::distance(begin, end));

for (auto it = begin; it != end; ++it) {
result.emplace_back(it->second);
}
return result;
}

private:
std::multimap<FromT, WeightedObject<ToT>> m_from2to; ///< Map the from to the to objects
std::multimap<ToT, WeightedObject<FromT>> m_to2from; ///< Map the to to the from objects
};

template <typename AssociationCollT>
AssociationNavigator<AssociationCollT>::AssociationNavigator(const AssociationCollT& associations) {
for (const auto& [from, to, weight] : associations) {
m_from2to.emplace(std::piecewise_construct, std::forward_as_tuple(from), std::forward_as_tuple(to, weight));
m_to2from.emplace(std::piecewise_construct, std::forward_as_tuple(to), std::forward_as_tuple(from, weight));
}
}

} // namespace podio

#endif // PODIO_ASSOCIATIONNAVIGATOR_H
2 changes: 2 additions & 0 deletions include/podio/detail/AssociationCollectionImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ class AssociationCollection : public podio::CollectionBase {
using CollectionDataT = podio::AssociationCollectionData<FromT, ToT>;

public:
using from_type = FromT;
using to_type = ToT;
using value_type = Association<FromT, ToT>;
using mutable_type = MutableAssociation<FromT, ToT>;
using const_iterator = AssociationCollectionIterator<FromT, ToT>;
Expand Down
49 changes: 49 additions & 0 deletions tests/unittests/associations.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "catch2/catch_test_macros.hpp"

#include "podio/AssociationCollection.h"
#include "podio/AssociationNavigator.h"

#include "datamodel/ExampleClusterCollection.h"
#include "datamodel/ExampleHitCollection.h"
Expand Down Expand Up @@ -396,3 +397,51 @@ TEST_CASE("AssociationCollection movability", "[associations][move-semantics][co
REQUIRE(evenNewerAssocs.isSubsetCollection());
}
}

TEST_CASE("AssociationNavigator basics", "[asssociations]") {
TestAColl coll{};
std::vector<ExampleHit> hits(11);
std::vector<ExampleCluster> clusters(3);

for (size_t i = 0; i < 10; ++i) {
auto a = coll.create();
a.set(hits[i]);
a.set(clusters[i % 3]);
a.setWeight(i * 0.1f);
}

auto a = coll.create();
a.set(hits[10]);

podio::AssociationNavigator nav{coll};

for (size_t i = 0; i < 10; ++i) {
const auto& hit = hits[i];
const auto assocClusters = nav.getAssociated(hit);
REQUIRE(assocClusters.size() == 1);
const auto& [cluster, weight] = assocClusters[0];
REQUIRE(cluster == clusters[i % 3]);
REQUIRE(weight == i * 0.1f);
}

const auto& cluster1 = clusters[0];
auto assocHits = nav.getAssociated(cluster1);
REQUIRE(assocHits.size() == 4);
for (size_t i = 0; i < 4; ++i) {
const auto& [hit, weight] = assocHits[i];
REQUIRE(hit == hits[i * 3]);
REQUIRE(weight == i * 3 * 0.1f);
}

const auto& cluster2 = clusters[1];
assocHits = nav.getAssociated(cluster2);
REQUIRE(assocHits.size() == 3);
for (size_t i = 0; i < 3; ++i) {
const auto& [hit, weight] = assocHits[i];
REQUIRE(hit == hits[i * 3 + 1]);
REQUIRE(weight == (i * 3 + 1) * 0.1f);
}

const auto [noCluster, noWeight] = nav.getAssociated(hits[10])[0];
REQUIRE_FALSE(noCluster.isAvailable());
}

0 comments on commit 3f580b6

Please sign in to comment.