Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a LinkNavigator utility #646

Merged
merged 15 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions doc/links.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,44 @@ and have that compiled into the library. This is necessary if you want to use
the python bindings, since they rely on dynamically loading the datamodel
libraries.

## The `LinkNavigator` utility

`podio::LinkCollection`s store each link separately even if a given object is
present in several links. Additionally, they don't offer any really easy way to
look up objects that are linked (apart from manually looping and comparing
elements). To alleviate these issues, we provide the `podio::LinkNavigator`
utility class that facilitates navigating links and lookups. It can be
constructed from any `podio::LinkCollection` and can then be used to retrieve
linked objects. E.g.

```cpp
const auto& recoMcLinks = event.get<edm4hep::RecoMCParticleLinkCollection>("RecoMCLinks");
const auto linkNavigator = podio::LinkNavigator(recoMcLinks);

// For podio::LinkCollections with disparate types just use getLinked
const auto linkedRecs = linkNavigator.getLinked(mcParticle);
```

If you want to be explicit about the lookup direction, e.g. in case you have a
link that has the same `From` and `To` type, you can use the overloads that take
a second *tag argument*:
```cpp
const auto linkedMCs = linkNavigator.getLinked(recoParticle, podio::ReturnTo);
```

The return type of all methods is a `std::vector<WeightedObject>`, where the
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
The return type of all methods is a `std::vector<WeightedObject>`, where the
The return type of all methods is a `std::vector<podio::detail::links::WeightedObject>`, where the

With the implicit question whether we should lift this object out of the (implicitly private) detail namespace, since it is in principle user facing.

`WeightedObject` is a simple template class that wraps the object and its
weight. It supports structured bindings, so you can e.g. do the following

```cpp
for (const auto& [reco, weight] : linkedRecs) {
// do something with the reco particle and its weight
}
```

Alternatively, you can access the object via the `o` member and the weight via
the `weight` member.

## Implementation details

In order to give a slightly easier entry to the details of the implementation
Expand Down
171 changes: 171 additions & 0 deletions include/podio/LinkNavigator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
#ifndef PODIO_LINKNAVIGATOR_H
#define PODIO_LINKNAVIGATOR_H

#include <map>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

namespace podio {

namespace detail::links {
/// A small struct that simply bundles an object and its weight for a more
/// convenient return value for the LinkNavigator
///
/// @note In most uses the names of the members should not really matter as it
/// is possible to us this via structured bindings
template <typename T>
struct WeightedObject {
WeightedObject(T obj, float w) : o(obj), weight(w) {
}
T o; ///< The object
float weight; ///< The weight in the link

bool operator==(const WeightedObject<T>& other) const {
return other.o == o && other.weight == weight;
}
};

/// Simple struct tag for overload selection in LinkNavigator below
struct ReturnFromTag {};
/// Simple struct tag for overload selection in LinkNavigator below
struct ReturnToTag {};
} // namespace detail::links

/// Tag variable to select the lookup of *From* objects have links with a *To*
/// object in podio::LinkNavigator::getLinked
static constexpr detail::links::ReturnFromTag ReturnFrom;
/// Tag variable to select the lookup of *To* objects that have links with a
/// *From* object in podio::LinkNavigator::getLinked
static constexpr detail::links::ReturnToTag ReturnTo;

/// A helper class to more easily handle one-to-many links.
///
/// Internally simply populates two maps in its constructor and then queries
/// them to retrieve objects that are linked with another.
///
/// @note There are no guarantees on the order of the objects in these maps.
/// Hence, there are also no guarantees on the order of the returned objects,
/// even if there inherintly is an order to them in the underlying links
/// collection.
template <typename LinkCollT>
class LinkNavigator {
using FromT = typename LinkCollT::from_type;
using ToT = typename LinkCollT::to_type;

template <typename T>
using WeightedObject = detail::links::WeightedObject<T>;

public:
/// Construct a navigator from an link collection
LinkNavigator(const LinkCollT& links);

/// We do only construct from a collection
LinkNavigator() = delete;
LinkNavigator(const LinkNavigator&) = default;
LinkNavigator& operator=(const LinkNavigator&) = default;
LinkNavigator(LinkNavigator&&) = default;
LinkNavigator& operator=(LinkNavigator&&) = default;
~LinkNavigator() = default;

/// Get all the *From* objects and weights that have links with the passed
/// object
///
/// You will get this overload if you pass the podio::LookupFrom tag as second
tmadlener marked this conversation as resolved.
Show resolved Hide resolved
/// argument
///
/// @note This overload works always, even if the LinkCollection that was used
/// to construct this instance of the LinkNavigator has the same From and To
/// types.
///
/// @param object The object that is labeled *To* in the link
/// @param . tag variable for selecting this overload
///
/// @returns A vector of all objects and their weights that have links with
/// the passed object
std::vector<WeightedObject<FromT>> getLinked(const ToT& object, podio::detail::links::ReturnFromTag) const {
const auto& [begin, end] = m_to2from.equal_range(object);
std::vector<WeightedObject<FromT>> result;
result.reserve(std::distance(begin, end));

for (auto it = begin; it != end; ++it) {
result.emplace_back(it->second);
}
return result;
}

/// Get all the *From* objects and weights that have links with the passed
/// object
///
/// @note This overload will automatically do the right thing (TM) in case the
/// LinkCollection that has been passed to construct this LinkNavigator has
/// different From and To types.
///
/// @param object The object that is labeled *To* in the link
///
/// @returns A vector of all objects and their weights that have links with
/// the passed object
template <typename ToU = ToT>
std::enable_if_t<!std::is_same_v<FromT, ToU>, std::vector<WeightedObject<FromT>>> getLinked(const ToT& object) const {
return getLinked(object, podio::ReturnFrom);
}

/// Get all the *To* objects and weights that have links with the passed
/// object
///
/// You will get this overload if you pass the podio::LookupTo tag as second
tmadlener marked this conversation as resolved.
Show resolved Hide resolved
/// argument
///
/// @note This overload works always, even if the LinkCollection that was used
/// to construct this instance of the LinkNavigator has the same From and To
/// types.
///
/// @param object The object that is labeled *From* in the link
/// @param . tag variable for selecting this overload
///
/// @returns A vector of all objects and their weights that have links with
/// the passed object
std::vector<WeightedObject<ToT>> getLinked(const FromT& object, podio::detail::links::ReturnToTag) const {
const auto& [begin, end] = m_from2to.equal_range(object);
std::vector<WeightedObject<ToT>> result;
result.reserve(std::distance(begin, end));

for (auto it = begin; it != end; ++it) {
result.emplace_back(it->second);
}
return result;
}

/// Get all the *To* objects and weights that have links with the passed
/// object
///
/// @note This overload will automatically do the right thing (TM) in case the
/// LinkCollection that has been passed to construct this LinkNavigator has
/// different From and To types.
///
/// @param object The object that is labeled *From* in the link
///
/// @returns A vector of all objects and their weights that have links with
/// the passed object
template <typename FromU = FromT>
std::enable_if_t<!std::is_same_v<FromU, ToT>, std::vector<WeightedObject<ToT>>> getLinked(const FromT& object) const {
return getLinked(object, podio::ReturnTo);
}

private:
std::multimap<FromT, WeightedObject<ToT>> m_from2to{}; ///< Map the from to the to objects
std::multimap<ToT, WeightedObject<FromT>> m_to2from{}; ///< Map the to to the from objects
};

template <typename LinkCollT>
LinkNavigator<LinkCollT>::LinkNavigator(const LinkCollT& links) {
for (const auto& [from, to, weight] : links) {
m_from2to.emplace(std::piecewise_construct, std::forward_as_tuple(from), std::forward_as_tuple(to, weight));
m_to2from.emplace(std::piecewise_construct, std::forward_as_tuple(to), std::forward_as_tuple(from, weight));
}
}

} // namespace podio

#endif // PODIO_LINKNAVIGATOR_H
2 changes: 2 additions & 0 deletions include/podio/detail/LinkCollectionImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ class LinkCollection : public podio::CollectionBase {
using CollectionDataT = podio::LinkCollectionData<FromT, ToT>;

public:
using from_type = FromT;
using to_type = ToT;
using value_type = Link<FromT, ToT>;
using mutable_type = MutableLink<FromT, ToT>;
using const_iterator = LinkCollectionIterator<FromT, ToT>;
Expand Down
2 changes: 1 addition & 1 deletion tests/unittests/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
if(CMAKE_CXX_STANDARD GREATER_EQUAL 20)
set(CATCH2_MIN_VERSION 3.4)
else()
set(CATCH2_MIN_VERSION 3.1)
set(CATCH2_MIN_VERSION 3.3)
endif()
if(USE_EXTERNAL_CATCH2)
if (USE_EXTERNAL_CATCH2 STREQUAL AUTO)
Expand Down
87 changes: 87 additions & 0 deletions tests/unittests/links.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include "catch2/catch_test_macros.hpp"
#include "catch2/matchers/catch_matchers_vector.hpp"

#include "podio/LinkCollection.h"
#include "podio/LinkNavigator.h"

#include "datamodel/ExampleClusterCollection.h"
#include "datamodel/ExampleHitCollection.h"
Expand Down Expand Up @@ -473,3 +475,88 @@ TEST_CASE("Link JSON conversion", "[links][json]") {
}

#endif

TEST_CASE("LinkNavigator basics", "[links]") {
TestLColl coll{};
std::vector<ExampleHit> hits(11);
std::vector<ExampleCluster> clusters(3);

for (size_t i = 0; i < 10; ++i) {
auto a = coll.create();
a.set(hits[i]);
a.set(clusters[i % 3]);
a.setWeight(i * 0.1f);
}

auto a = coll.create();
a.set(hits[10]);

podio::LinkNavigator nav{coll};

for (size_t i = 0; i < 10; ++i) {
const auto& hit = hits[i];
const auto linkedClusters = nav.getLinked(hit);
REQUIRE(linkedClusters.size() == 1);
const auto& [cluster, weight] = linkedClusters[0];
REQUIRE(cluster == clusters[i % 3]);
REQUIRE(weight == i * 0.1f);
}

using Catch::Matchers::UnorderedEquals;
using podio::detail::links::WeightedObject;
using WeightedHits = std::vector<WeightedObject<ExampleHit>>;

auto linkedHits = nav.getLinked(clusters[0]);
REQUIRE_THAT(linkedHits,
UnorderedEquals(WeightedHits{WeightedObject{hits[0], 0.f}, WeightedObject{hits[3], 3 * 0.1f},
WeightedObject{hits[6], 6 * 0.1f}, WeightedObject{hits[9], 9 * 0.1f}}));

linkedHits = nav.getLinked(clusters[1]);
REQUIRE_THAT(linkedHits,
UnorderedEquals(WeightedHits{WeightedObject{hits[1], 0.1f}, WeightedObject{hits[4], 0.4f},
WeightedObject{hits[7], 0.7f}}));

const auto [noCluster, noWeight] = nav.getLinked(hits[10])[0];
REQUIRE_FALSE(noCluster.isAvailable());
}

TEST_CASE("LinkNavigator same types", "[links]") {
std::vector<ExampleCluster> clusters(3);
auto linkColl = podio::LinkCollection<ExampleCluster, ExampleCluster>{};
auto link = linkColl.create();
link.setFrom(clusters[0]);
link.setTo(clusters[1]);
link.setWeight(0.5f);

link = linkColl.create();
link.setFrom(clusters[0]);
link.setTo(clusters[2]);
link.setWeight(0.25f);

link = linkColl.create();
link.setFrom(clusters[1]);
link.setTo(clusters[2]);
link.setWeight(0.66f);

auto navigator = podio::LinkNavigator{linkColl};
auto linkedClusters = navigator.getLinked(clusters[1], podio::ReturnTo);
REQUIRE(linkedClusters.size() == 1);
REQUIRE(linkedClusters[0].o == clusters[2]);
REQUIRE(linkedClusters[0].weight == 0.66f);

linkedClusters = navigator.getLinked(clusters[1], podio::ReturnFrom);
REQUIRE(linkedClusters.size() == 1);
REQUIRE(linkedClusters[0].o == clusters[0]);
REQUIRE(linkedClusters[0].weight == 0.5f);

using Catch::Matchers::UnorderedEquals;
using podio::detail::links::WeightedObject;
using WeightedObjVec = std::vector<WeightedObject<ExampleCluster>>;
linkedClusters = navigator.getLinked(clusters[0], podio::ReturnTo);
REQUIRE_THAT(linkedClusters,
UnorderedEquals(WeightedObjVec{WeightedObject(clusters[1], 0.5f), WeightedObject{clusters[2], 0.25f}}));

linkedClusters = navigator.getLinked(clusters[2], podio::ReturnFrom);
REQUIRE_THAT(linkedClusters,
UnorderedEquals(WeightedObjVec{WeightedObject{clusters[0], 0.25f}, WeightedObject{clusters[1], 0.66f}}));
}
Loading