Skip to content

Commit

Permalink
Add a working example of reading several collections at the same time
Browse files Browse the repository at this point in the history
  • Loading branch information
jmcarcell committed Sep 25, 2023
1 parent 7dbda71 commit 3725a3f
Show file tree
Hide file tree
Showing 5 changed files with 285 additions and 11 deletions.
163 changes: 161 additions & 2 deletions k4FWCore/components/PodioInput.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,154 @@ void PodioInput::fillReaders() {
[&](std::string_view collName) {
maybeRead<podio::UserDataCollection<uint64_t>>(collName);
};
m_readers["std::vector<edm4hep::MCParticleCollection>"] =
[&](std::string_view collName) {
maybeRead<std::vector<edm4hep::MCParticleCollection*>>(collName);
};
m_readers["std::vector<edm4hep::SimTrackerHitCollection>"] =
[&](std::string_view collName) {
maybeRead<std::vector<edm4hep::SimTrackerHitCollection*>>(collName);
};
m_readers["std::vector<edm4hep::CaloHitContributionCollection>"] =
[&](std::string_view collName) {
maybeRead<std::vector<edm4hep::CaloHitContributionCollection*>>(collName);
};
m_readers["std::vector<edm4hep::SimCalorimeterHitCollection>"] =
[&](std::string_view collName) {
maybeRead<std::vector<edm4hep::SimCalorimeterHitCollection*>>(collName);
};
m_readers["std::vector<edm4hep::RawCalorimeterHitCollection>"] =
[&](std::string_view collName) {
maybeRead<std::vector<edm4hep::RawCalorimeterHitCollection*>>(collName);
};
m_readers["std::vector<edm4hep::CalorimeterHitCollection>"] =
[&](std::string_view collName) {
maybeRead<std::vector<edm4hep::CalorimeterHitCollection*>>(collName);
};
m_readers["std::vector<edm4hep::ParticleIDCollection>"] =
[&](std::string_view collName) {
maybeRead<std::vector<edm4hep::ParticleIDCollection*>>(collName);
};
m_readers["std::vector<edm4hep::ClusterCollection>"] =
[&](std::string_view collName) {
maybeRead<std::vector<edm4hep::ClusterCollection*>>(collName);
};
m_readers["std::vector<edm4hep::TrackerHitCollection>"] =
[&](std::string_view collName) {
maybeRead<std::vector<edm4hep::TrackerHitCollection*>>(collName);
};
m_readers["std::vector<edm4hep::TrackerHitPlaneCollection>"] =
[&](std::string_view collName) {
maybeRead<std::vector<edm4hep::TrackerHitPlaneCollection*>>(collName);
};
m_readers["std::vector<edm4hep::RawTimeSeriesCollection>"] =
[&](std::string_view collName) {
maybeRead<std::vector<edm4hep::RawTimeSeriesCollection*>>(collName);
};
m_readers["std::vector<edm4hep::TrackCollection>"] =
[&](std::string_view collName) {
maybeRead<std::vector<edm4hep::TrackCollection*>>(collName);
};
m_readers["std::vector<edm4hep::VertexCollection>"] =
[&](std::string_view collName) {
maybeRead<std::vector<edm4hep::VertexCollection*>>(collName);
};
m_readers["std::vector<edm4hep::ReconstructedParticleCollection>"] =
[&](std::string_view collName) {
maybeRead<std::vector<edm4hep::ReconstructedParticleCollection*>>(collName);
};
m_readers["std::vector<edm4hep::MCRecoParticleAssociationCollection>"] =
[&](std::string_view collName) {
maybeRead<std::vector<edm4hep::MCRecoParticleAssociationCollection*>>(collName);
};
m_readers["std::vector<edm4hep::MCRecoCaloAssociationCollection>"] =
[&](std::string_view collName) {
maybeRead<std::vector<edm4hep::MCRecoCaloAssociationCollection*>>(collName);
};
m_readers["std::vector<edm4hep::MCRecoTrackerAssociationCollection>"] =
[&](std::string_view collName) {
maybeRead<std::vector<edm4hep::MCRecoTrackerAssociationCollection*>>(collName);
};
m_readers["std::vector<edm4hep::MCRecoTrackerHitPlaneAssociationCollection>"] =
[&](std::string_view collName) {
maybeRead<std::vector<edm4hep::MCRecoTrackerHitPlaneAssociationCollection*>>(collName);
};
m_readers["std::vector<edm4hep::MCRecoClusterParticleAssociationCollection>"] =
[&](std::string_view collName) {
maybeRead<std::vector<edm4hep::MCRecoClusterParticleAssociationCollection*>>(collName);
};
m_readers["std::vector<edm4hep::MCRecoTrackParticleAssociationCollection>"] =
[&](std::string_view collName) {
maybeRead<std::vector<edm4hep::MCRecoTrackParticleAssociationCollection*>>(collName);
};
m_readers["std::vector<edm4hep::RecoParticleVertexAssociationCollection>"] =
[&](std::string_view collName) {
maybeRead<std::vector<edm4hep::RecoParticleVertexAssociationCollection*>>(collName);
};
m_readers["std::vector<edm4hep::SimPrimaryIonizationClusterCollection>"] =
[&](std::string_view collName) {
maybeRead<std::vector<edm4hep::SimPrimaryIonizationClusterCollection*>>(collName);
};
m_readers["std::vector<edm4hep::TrackerPulseCollection>"] =
[&](std::string_view collName) {
maybeRead<std::vector<edm4hep::TrackerPulseCollection*>>(collName);
};
m_readers["std::vector<edm4hep::RecIonizationClusterCollection>"] =
[&](std::string_view collName) {
maybeRead<std::vector<edm4hep::RecIonizationClusterCollection*>>(collName);
};
m_readers["std::vector<edm4hep::TimeSeriesCollection>"] =
[&](std::string_view collName) {
maybeRead<std::vector<edm4hep::TimeSeriesCollection*>>(collName);
};
m_readers["std::vector<edm4hep::RecDqdxCollection>"] =
[&](std::string_view collName) {
maybeRead<std::vector<edm4hep::RecDqdxCollection*>>(collName);
};
m_readers["std::vector<podio::UserDataCollection<int>>"] =
[&](std::string_view collName) {
maybeRead<std::vector<podio::UserDataCollection<int>*>>(collName);
};
m_readers["std::vector<podio::UserDataCollection<float>>"] =
[&](std::string_view collName) {
maybeRead<std::vector<podio::UserDataCollection<float>*>>(collName);
};
m_readers["std::vector<podio::UserDataCollection<double>>"] =
[&](std::string_view collName) {
maybeRead<std::vector<podio::UserDataCollection<double>*>>(collName);
};
m_readers["std::vector<podio::UserDataCollection<int8_t>>"] =
[&](std::string_view collName) {
maybeRead<std::vector<podio::UserDataCollection<int8_t>*>>(collName);
};
m_readers["std::vector<podio::UserDataCollection<int16_t>>"] =
[&](std::string_view collName) {
maybeRead<std::vector<podio::UserDataCollection<int16_t>*>>(collName);
};
m_readers["std::vector<podio::UserDataCollection<int32_t>>"] =
[&](std::string_view collName) {
maybeRead<std::vector<podio::UserDataCollection<int32_t>*>>(collName);
};
m_readers["std::vector<podio::UserDataCollection<int64_t>>"] =
[&](std::string_view collName) {
maybeRead<std::vector<podio::UserDataCollection<int64_t>*>>(collName);
};
m_readers["std::vector<podio::UserDataCollection<uint8_t>>"] =
[&](std::string_view collName) {
maybeRead<std::vector<podio::UserDataCollection<uint8_t>*>>(collName);
};
m_readers["std::vector<podio::UserDataCollection<uint16_t>>"] =
[&](std::string_view collName) {
maybeRead<std::vector<podio::UserDataCollection<uint16_t>*>>(collName);
};
m_readers["std::vector<podio::UserDataCollection<uint32_t>>"] =
[&](std::string_view collName) {
maybeRead<std::vector<podio::UserDataCollection<uint32_t>*>>(collName);
};
m_readers["std::vector<podio::UserDataCollection<uint64_t>>"] =
[&](std::string_view collName) {
maybeRead<std::vector<podio::UserDataCollection<uint64_t>*>>(collName);
};
}

PodioInput::PodioInput(const std::string& name, ISvcLocator* svcLoc) : Consumer(name, svcLoc) {
Expand All @@ -222,13 +370,24 @@ PodioInput::PodioInput(const std::string& name, ISvcLocator* svcLoc) : Consumer(

void PodioInput::operator()() const {
for (auto& collName : m_collectionNames) {
debug() << "Registering collection to read " << collName << endmsg;
auto type = m_podioDataSvc->getCollectionType(collName);
info() << "Registering collection to read " << collName << endmsg;

// We use the space as a separator for when reading multiple collections
std::string type;
std::string name;
if (collName.find(" ") != std::string::npos) {
auto first = collName.substr(0, collName.find(" "));
type = "std::vector<" + std::string(m_podioDataSvc->getCollectionType(first)) + ">";
}
else {
type = m_podioDataSvc->getCollectionType(collName);
}
if (m_readers.find(type) != m_readers.end()) {
m_readers[type](collName);
} else {
maybeRead<podio::CollectionBase>(collName);
}

}

// Tell data service that we are done with requested collections
Expand Down
51 changes: 42 additions & 9 deletions k4FWCore/include/k4FWCore/PodioDataSvc.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,50 @@ class PodioDataSvc : public DataSvc {

const std::string_view getCollectionType(const std::string& collName);

template<typename T>
struct is_vector
{
static constexpr bool value = false;
};

template<template<typename...> class C, typename U>
struct is_vector<C<U>>
{
static constexpr bool value =
std::is_same<C<U>,std::vector<U>>::value;
};

template <typename T>
StatusCode readCollection(const std::string& collName) {
const T* collection(nullptr);
collection = static_cast<const T*>(m_eventframe.get(collName));
if (collection == nullptr) {
error() << "Collection " << collName << " does not exist." << endmsg;
std::enable_if_t<!is_vector<T>::value,StatusCode>
readCollection(const std::string& collName) {
const T* collection = static_cast<const T*>(m_eventframe.get(collName));
if (!collection) {
error() << "Collection " << collName << " does not exist." << endmsg;
}
auto wrapper = new DataWrapper<T>;
wrapper->setData(collection);
m_podio_datawrappers.push_back(wrapper);
return DataSvc::registerObject("/Event", "/" + collName, wrapper);
}
auto wrapper = new DataWrapper<T>;
wrapper->setData(collection);
m_podio_datawrappers.push_back(wrapper);
return DataSvc::registerObject("/Event", "/" + collName, wrapper);

template <typename T>
std::enable_if_t<is_vector<T>::value,StatusCode>
readCollection(const std::string& collName) {
std::istringstream iss(collName);
std::string token;
auto vec = new std::vector<typename T::value_type>();
// Assume collName is a space-separated list of collection names
while (iss >> token) {
auto collection = dynamic_cast<typename T::value_type>(const_cast<podio::CollectionBase*>(m_eventframe.get(token)));
if (!collection) {
error() << "Collection " << token << " does not exist." << endmsg;
}
vec->push_back(collection);
}
auto wrapper = new DataWrapper<T>;
wrapper->setData(vec);
m_podio_datawrappers.push_back(wrapper);
return DataSvc::registerObject("/Event", "/" + collName, wrapper);
}

const podio::Frame& getEventFrame() const { return m_eventframe; }
Expand Down
8 changes: 8 additions & 0 deletions test/k4FWCoreTest/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ set(k4fwcoretest_plugin_sources
src/components/ExampleFunctionalProducerMultiple.cpp
src/components/ExampleFunctionalConsumerMultiple.cpp
src/components/ExampleFunctionalTransformerMultiple.cpp
src/components/ExampleFunctionalConsumerSeveralColls.cpp
)

gaudi_add_module(k4FWCoreTestPlugins
Expand Down Expand Up @@ -243,3 +244,10 @@ add_test(NAME ExampleFunctionalTransformerMultiple
set_test_env(ExampleFunctionalTransformerMultiple)
set_tests_properties(ExampleFunctionalTransformerMultiple PROPERTIES
DEPENDS "ExampleFunctionalProducerMultiple;ExampleFunctionalConsumerMultiple")

add_test(NAME ExampleFunctionalConsumerSeveralColls
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}
COMMAND ${K4RUN} options/runExampleFunctionalConsumerSeveralColls.py)
set_test_env(ExampleFunctionalConsumerSeveralColls)
set_tests_properties(ExampleFunctionalConsumerSeveralColls PROPERTIES
DEPENDS "ExampleFunctionalProducerMultiple")
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from Gaudi.Configuration import INFO
from Configurables import ExampleFunctionalConsumerSeveralColls
from Configurables import ApplicationMgr
from Configurables import k4DataSvc
from Configurables import PodioInput

podioevent = k4DataSvc("EventDataSvc")
podioevent.input = "output_k4test_exampledata_producer_multiple.root"

inp = PodioInput()
# We pass a space-separated list of collections to PodioInput to make sure
# they are pushed to the TES
inp.collections = [
"MCParticles1 MCParticles2",
]

consumer = ExampleFunctionalConsumerSeveralColls("ExampleFunctionalConsumerSeveralColls",
InputCollection="MCParticles1 MCParticles2",
)

ApplicationMgr(TopAlg=[inp, consumer],
EvtSel="NONE",
EvtMax=10,
ExtSvc=[podioevent],
OutputLevel=INFO,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#include "Gaudi/Property.h"
#include "GaudiAlg/Consumer.h"

#include "edm4hep/MCParticleCollection.h"

// Define BaseClass_t
#include "k4FWCore/BaseClass.h"

#include <string>

// Which type of collection we are reading
// When reading multiple collections
using colltype = std::vector<edm4hep::MCParticleCollection*>;

struct ExampleFunctionalConsumerSeveralColls final : Gaudi::Functional::Consumer<void(const colltype& input), BaseClass_t> {
// The pair in KeyValue can be changed from python and it corresponds
// to the name of the input collection
ExampleFunctionalConsumerSeveralColls(const std::string& name, ISvcLocator* svcLoc)
: Consumer(name, svcLoc,
KeyValue("InputCollection", "MCParticles1 MCParticles 2")
) {}

// This is the function that will be called to transform the data
// Note that the function has to be const, as well as the collections
// we get from the input
void operator()(const colltype& input) const override {
int i = 0;
for (auto ptr : input) {
if (i == 0) {
for (const auto& particle : *ptr) {
if ((particle.getPDG() != 1 + i) || (particle.getGeneratorStatus() != 2 + i) ||
(particle.getSimulatorStatus() != 3 + i) || (particle.getCharge() != 4 + i) ||
(particle.getTime() != 5 + i) || (particle.getMass() != 6 + i)) {
fatal() << "Wrong data in MCParticle collection";
}
i++;
}
}
else {
if (ptr->size() != 0) {
fatal() << "Wrong data in MCParticle collection";
}
}
}
}
};

DECLARE_COMPONENT(ExampleFunctionalConsumerSeveralColls)

0 comments on commit 3725a3f

Please sign in to comment.