Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make the TTree and RNTuple based backends write the GenericParameters the same way #625

Merged
merged 16 commits into from
Jun 20, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions include/podio/GenericParameters.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ class RNTupleWriter;

namespace podio {

#if !defined(__CLING__)
// cling doesn't really deal well (i.e. at all in this case) with the forward
// declaration here and errors out, breaking e.g. python bindings.
class ROOTReader;
#endif

/// The types which are supported in the GenericParameters
using SupportedGenericDataTypes = std::tuple<int, float, std::string, double>;

Expand Down Expand Up @@ -134,6 +140,10 @@ class GenericParameters {
friend RNTupleWriter;
#endif

#if !defined(__CLING__)
friend ROOTReader;
#endif

/// Get a reference to the internal map for a given type
template <typename T>
const MapType<detail::GetVectorType<T>>& getMap() const {
Expand Down
4 changes: 4 additions & 0 deletions include/podio/ROOTReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,10 @@ class ROOTReader {
/// Read the parameters for the entry specified in the passed CategoryInfo
GenericParameters readEntryParameters(CategoryInfo& catInfo, bool reloadBranches, unsigned int localEntry);

template <typename T>
static void readParams(CategoryInfo& catInfo, podio::GenericParameters& params, bool reloadBranches,
unsigned int localEntry);

/// Read the data entry specified in the passed CategoryInfo, and increase the
/// counter afterwards. In case the requested entry is larger than the
/// available number of entries, return a nullptr.
Expand Down
14 changes: 11 additions & 3 deletions include/podio/ROOTWriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,13 @@ class ROOTWriter {
std::vector<root_utils::CollectionWriteInfoT> collInfo{}; ///< Collection info for this category
podio::CollectionIDTable idTable{}; ///< The collection id table for this category
std::vector<std::string> collsToWrite{}; ///< The collections to write for this category

// Storage for the keys & values of the parameters of this category (more
// precisely: of the entry that is currently being written)
root_utils::ParamStorage<int> intParams{};
root_utils::ParamStorage<float> floatParams{};
root_utils::ParamStorage<double> doubleParams{};
root_utils::ParamStorage<std::string> stringParams{};
};

/// Initialize the branches for this category
Expand All @@ -117,9 +124,10 @@ class ROOTWriter {
/// Get the (potentially uninitialized) category information for this category
CategoryInfo& getCategoryInfo(const std::string& category);

static void resetBranches(std::vector<root_utils::CollectionBranches>& branches,
const std::vector<root_utils::StoreCollection>& collections,
/*const*/ podio::GenericParameters* parameters);
static void resetBranches(CategoryInfo& categoryInfo, const std::vector<root_utils::StoreCollection>& collections);

/// Fill the parameter keys and values into the CategoryInfo storage
static void fillParams(CategoryInfo& catInfo, const GenericParameters& params);

std::unique_ptr<TFile> m_file{nullptr}; ///< The storage file
std::unordered_map<std::string, CategoryInfo> m_categories{}; ///< All categories
Expand Down
39 changes: 38 additions & 1 deletion include/podio/utilities/RootHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,16 @@ namespace root_utils {
/// write a collection. Needed to cache the branch pointers and avoid having to
/// get them from a TTree/TChain for every event.
struct CollectionBranches {
CollectionBranches() = default;
~CollectionBranches() = default;
CollectionBranches(const CollectionBranches&) = delete;
CollectionBranches& operator=(const CollectionBranches&) = delete;
CollectionBranches(CollectionBranches&&) = default;
CollectionBranches& operator=(CollectionBranches&&) = default;

CollectionBranches(TBranch* dataBranch) : data(dataBranch) {
}

TBranch* data{nullptr};
std::vector<TBranch*> refs{};
std::vector<TBranch*> vecs{};
Expand All @@ -36,7 +46,34 @@ namespace root_utils {
/// Pair of keys and values for one type of the ones that can be stored in
/// GenericParameters
template <typename T>
struct ParamStorage {
  ParamStorage() = default;
  ~ParamStorage() = default;
  ParamStorage(const ParamStorage&) = delete;
  ParamStorage& operator=(const ParamStorage&) = delete;

  /// Take over the keys and values of another storage.
  ///
  /// The cached self-pointers are deliberately NOT taken over: they have to
  /// refer to the members of *this* object, never to those of @p other.
  ParamStorage(ParamStorage&& other) noexcept {
    keys.swap(other.keys);
    values.swap(other.values);
  }

  /// Take over the keys and values of another storage.
  ///
  /// The cached self-pointers are left untouched, so an address previously
  /// obtained from keysPtr() / valuesPtr() still leads to the members of this
  /// object after the storage has been refilled via assignment.
  ParamStorage& operator=(ParamStorage&& other) noexcept {
    keys.swap(other.keys);
    values.swap(other.values);
    return *this;
  }

  /// Construct from copies of the passed keys and values
  ParamStorage(const std::vector<std::string>& ks, const std::vector<std::vector<T>>& vs) : keys(ks), values(vs) {
  }

  /// Get the address of a pointer to the keys vector (i.e. a pointer to
  /// pointer). The internal pointer is (re-)seated on every call.
  auto keysPtr() {
    m_keysPtr = &keys;
    return &m_keysPtr;
  }

  /// Get the address of a pointer to the values vector (i.e. a pointer to
  /// pointer). The internal pointer is (re-)seated on every call.
  auto valuesPtr() {
    m_valuesPtr = &values;
    return &m_valuesPtr;
  }

  std::vector<std::string> keys{};      ///< The parameter keys
  std::vector<std::vector<T>> values{}; ///< The parameter values, one vector per key

private:
  std::vector<std::string>* m_keysPtr{nullptr};      ///< Pointer to keys, set by keysPtr()
  std::vector<std::vector<T>>* m_valuesPtr{nullptr}; ///< Pointer to values, set by valuesPtr()
};

} // namespace root_utils
} // namespace podio
Expand Down
4 changes: 3 additions & 1 deletion src/RNTupleWriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ root_utils::ParamStorage<T>& RNTupleWriter::getParamStorage(CategoryInfo& catInf
template <typename T>
void RNTupleWriter::fillParams(const GenericParameters& params, CategoryInfo& catInfo,
ROOT::Experimental::REntry* entry) {
auto& [keys, values] = getParamStorage<T>(catInfo);
auto& paramStorage = getParamStorage<T>(catInfo);
auto& keys = paramStorage.keys;
auto& values = paramStorage.values;
keys = params.getKeys<T>();
values = params.getValues<T>();
#if ROOT_VERSION_CODE >= ROOT_VERSION(6, 31, 0)
Expand Down
2 changes: 1 addition & 1 deletion src/ROOTLegacyReader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ void ROOTLegacyReader::createCollectionBranches(const std::vector<root_utils::Co

m_storedClasses.emplace_back(name, std::make_tuple(collType, isSubsetColl, collSchemaVersion, collectionIndex++));

m_collectionBranches.push_back(branches);
m_collectionBranches.emplace_back(std::move(branches));
}
}

Expand Down
112 changes: 93 additions & 19 deletions src/ROOTReader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@
#include "podio/CollectionIDTable.h"
#include "podio/DatamodelRegistry.h"
#include "podio/GenericParameters.h"
#include "podio/utilities/RootHelpers.h"
#include "rootUtils.h"

// ROOT specific includes
#include "TChain.h"
#include "TClass.h"
#include "TFile.h"
#include "TTree.h"
#include "TTreeCache.h"

#include <stdexcept>
Expand All @@ -27,22 +26,85 @@ std::tuple<std::vector<root_utils::CollectionBranches>, std::vector<std::pair<st
createCollectionBranchesIndexBased(TChain* chain, const podio::CollectionIDTable& idTable,
const std::vector<root_utils::CollectionWriteInfoT>& collInfo);

/// Maps each generic parameter type to the position of its key and value
/// branches, counted backwards from the end of the branches vector (an offset
/// of 1 denotes the last element). For every type the values branch sits
/// directly behind the keys branch.
template <typename T>
struct TypeToBranchIndexOffset;

/// int parameters: first pair of the parameter branches
template <>
struct TypeToBranchIndexOffset<int> {
  static constexpr int keys{8};
  static constexpr int values{keys - 1};
};

/// float parameters: second pair of the parameter branches
template <>
struct TypeToBranchIndexOffset<float> {
  static constexpr int keys{6};
  static constexpr int values{keys - 1};
};

/// double parameters: third pair of the parameter branches
template <>
struct TypeToBranchIndexOffset<double> {
  static constexpr int keys{4};
  static constexpr int values{keys - 1};
};

/// std::string parameters: last pair of the parameter branches
template <>
struct TypeToBranchIndexOffset<std::string> {
  static constexpr int keys{2};
  static constexpr int values{keys - 1};
};

/// Read the keys and values of all parameters of type T for the given entry
/// and insert them into the passed GenericParameters.
///
/// @param catInfo        the category whose parameter branches are read
/// @param params         the parameters that receive the key / value pairs
/// @param reloadBranches if true, the key and value branch pointers are
///                       fetched again from the chain before reading
/// @param localEntry     the entry that is passed on to TBranch::GetEntry
template <typename T>
void ROOTReader::readParams(ROOTReader::CategoryInfo& catInfo, podio::GenericParameters& params, bool reloadBranches,
                            unsigned int localEntry) {
  using Offsets = TypeToBranchIndexOffset<T>;

  // The parameter branches are addressed relative to the end of the vector
  const auto branchCount = catInfo.branches.size();
  auto& keySlot = catInfo.branches[branchCount - Offsets::keys].data;
  auto& valueSlot = catInfo.branches[branchCount - Offsets::values].data;

  if (reloadBranches) {
    keySlot = root_utils::getBranch(catInfo.chain.get(), root_utils::getGPKeyName<T>());
    valueSlot = root_utils::getBranch(catInfo.chain.get(), root_utils::getGPValueName<T>());
  }

  // Temporary buffers the branches read into
  root_utils::ParamStorage<T> buffer;
  keySlot->SetAddress(buffer.keysPtr());
  keySlot->GetEntry(localEntry);
  valueSlot->SetAddress(buffer.valuesPtr());
  valueSlot->GetEntry(localEntry);

  // Hand the read contents over to the map without copying them
  auto& target = params.getMap<T>();
  const auto nKeys = buffer.keys.size();
  for (size_t idx = 0; idx < nKeys; ++idx) {
    target.emplace(std::move(buffer.keys[idx]), std::move(buffer.values[idx]));
  }
}

/// Read the parameters for the entry specified in the passed CategoryInfo.
///
/// Depending on the version of the file being read, the parameters are either
/// read as one GenericParameters object from a single branch (versions before
/// 0.99.99) or assembled from the key / value branches of each supported type.
///
/// @param catInfo        the category to read the parameters from
/// @param reloadBranches if true, branch pointers are fetched again from the
///                       chain before reading
/// @param localEntry     the entry that is passed on to TBranch::GetEntry
/// @returns the parameters that have been read
GenericParameters ROOTReader::readEntryParameters(ROOTReader::CategoryInfo& catInfo, bool reloadBranches,
                                                  unsigned int localEntry) {
  GenericParameters params;

  if (m_fileVersion < podio::version::Version{0, 99, 99}) {
    // Parameter branch is always the last one
    auto& paramBranches = catInfo.branches.back();

    // Make sure to have a valid branch pointer after switching trees in the chain
    // as well as on the first event
    if (reloadBranches) {
      paramBranches.data = root_utils::getBranch(catInfo.chain.get(), root_utils::paramBranchName);
    }
    auto* branch = paramBranches.data;

    // The branch is handed the address of a pointer to the object to fill
    auto* emd = &params;
    branch->SetAddress(&emd);
    branch->GetEntry(localEntry);
  } else {
    readParams<int>(catInfo, params, reloadBranches, localEntry);
    readParams<float>(catInfo, params, reloadBranches, localEntry);
    readParams<double>(catInfo, params, reloadBranches, localEntry);
    readParams<std::string>(catInfo, params, reloadBranches, localEntry);
  }

  return params;
}

Expand Down Expand Up @@ -162,13 +224,25 @@ void ROOTReader::initCategory(CategoryInfo& catInfo, const std::string& category
std::tie(catInfo.branches, catInfo.storedClasses) =
createCollectionBranches(catInfo.chain.get(), *catInfo.table, *collInfo);
}

delete collInfo;

// Finally set up the branches for the parameters
root_utils::CollectionBranches paramBranches{};
paramBranches.data = root_utils::getBranch(catInfo.chain.get(), root_utils::paramBranchName);
catInfo.branches.push_back(paramBranches);
if (m_fileVersion < podio::version::Version{0, 99, 99}) {
root_utils::CollectionBranches paramBranches{};
catInfo.branches.emplace_back(root_utils::getBranch(catInfo.chain.get(), root_utils::paramBranchName));
} else {
catInfo.branches.emplace_back(root_utils::getBranch(catInfo.chain.get(), root_utils::intKeyName));
catInfo.branches.emplace_back(root_utils::getBranch(catInfo.chain.get(), root_utils::intValueName));

catInfo.branches.emplace_back(root_utils::getBranch(catInfo.chain.get(), root_utils::floatKeyName));
catInfo.branches.emplace_back(root_utils::getBranch(catInfo.chain.get(), root_utils::floatValueName));

catInfo.branches.emplace_back(root_utils::getBranch(catInfo.chain.get(), root_utils::doubleKeyName));
catInfo.branches.emplace_back(root_utils::getBranch(catInfo.chain.get(), root_utils::doubleValueName));

catInfo.branches.emplace_back(root_utils::getBranch(catInfo.chain.get(), root_utils::stringKeyName));
catInfo.branches.emplace_back(root_utils::getBranch(catInfo.chain.get(), root_utils::stringValueName));
}
}

std::vector<std::string> getAvailableCategories(TChain* metaChain) {
Expand Down Expand Up @@ -300,7 +374,7 @@ createCollectionBranchesIndexBased(TChain* chain, const podio::CollectionIDTable
collBranches.emplace_back(std::move(branches));
}

return {collBranches, storedClasses};
return {std::move(collBranches), storedClasses};
}

std::tuple<std::vector<root_utils::CollectionBranches>, std::vector<std::pair<std::string, detail::CollectionInfo>>>
Expand Down Expand Up @@ -346,7 +420,7 @@ createCollectionBranches(TChain* chain, const podio::CollectionIDTable& idTable,
collBranches.emplace_back(std::move(branches));
}

return {collBranches, storedClasses};
return {std::move(collBranches), storedClasses};
}

} // namespace podio
51 changes: 39 additions & 12 deletions src/ROOTWriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "rootUtils.h"

#include "TTree.h"
#include <tuple>

namespace podio {

Expand Down Expand Up @@ -61,7 +62,8 @@ void ROOTWriter::writeFrame(const podio::Frame& frame, const std::string& catego
throw std::runtime_error("Trying to write category '" + category + "' with inconsistent collection content. " +
root_utils::getInconsistentCollsMsg(catInfo.collsToWrite, collsToWrite));
}
resetBranches(catInfo.branches, collections, &const_cast<podio::GenericParameters&>(frame.getParameters()));
fillParams(catInfo, frame.getParameters());
resetBranches(catInfo, collections);
}

catInfo.tree->Fill();
Expand All @@ -78,7 +80,8 @@ ROOTWriter::CategoryInfo& ROOTWriter::getCategoryInfo(const std::string& categor

void ROOTWriter::initBranches(CategoryInfo& catInfo, const std::vector<root_utils::StoreCollection>& collections,
/*const*/ podio::GenericParameters& parameters) {
catInfo.branches.reserve(collections.size() + 1); // collections + parameters
catInfo.branches.reserve(collections.size() +
std::tuple_size_v<podio::SupportedGenericDataTypes> * 2); // collections + parameters

// First collections
for (auto& [name, coll] : collections) {
Expand Down Expand Up @@ -117,28 +120,45 @@ void ROOTWriter::initBranches(CategoryInfo& catInfo, const std::vector<root_util
}
}

catInfo.branches.push_back(branches);
catInfo.branches.emplace_back(std::move(branches));
catInfo.collInfo.emplace_back(catInfo.idTable.collectionID(name).value(), std::string(coll->getTypeName()),
coll->isSubsetCollection(), coll->getSchemaVersion());
}

// Also make branches for the parameters
root_utils::CollectionBranches branches;
branches.data = catInfo.tree->Branch(root_utils::paramBranchName, &parameters);
catInfo.branches.push_back(branches);
fillParams(catInfo, parameters);
catInfo.branches.emplace_back(catInfo.tree->Branch(root_utils::intKeyName, &catInfo.intParams.keys));
catInfo.branches.emplace_back(catInfo.tree->Branch(root_utils::intValueName, &catInfo.intParams.values));

catInfo.branches.emplace_back(catInfo.tree->Branch(root_utils::floatKeyName, &catInfo.floatParams.keys));
catInfo.branches.emplace_back(catInfo.tree->Branch(root_utils::floatValueName, &catInfo.floatParams.values));

catInfo.branches.emplace_back(catInfo.tree->Branch(root_utils::doubleKeyName, &catInfo.doubleParams.keys));
catInfo.branches.emplace_back(catInfo.tree->Branch(root_utils::doubleValueName, &catInfo.doubleParams.values));

catInfo.branches.emplace_back(catInfo.tree->Branch(root_utils::stringKeyName, &catInfo.stringParams.keys));
catInfo.branches.emplace_back(catInfo.tree->Branch(root_utils::stringValueName, &catInfo.stringParams.values));
}

void ROOTWriter::resetBranches(std::vector<root_utils::CollectionBranches>& branches,
const std::vector<root_utils::StoreCollection>& collections,
/*const*/ podio::GenericParameters* parameters) {
void ROOTWriter::resetBranches(CategoryInfo& categoryInfo,
const std::vector<root_utils::StoreCollection>& collections) {
size_t iColl = 0;
for (auto& [_, coll] : collections) {
const auto& collBranches = branches[iColl];
const auto& collBranches = categoryInfo.branches[iColl];
root_utils::setCollectionAddresses(coll->getBuffers(), collBranches);
iColl++;
}

branches.back().data->SetAddress(&parameters);
categoryInfo.branches[iColl].data->SetAddress(categoryInfo.intParams.keysPtr());
categoryInfo.branches[iColl + 1].data->SetAddress(categoryInfo.intParams.valuesPtr());

categoryInfo.branches[iColl + 2].data->SetAddress(categoryInfo.floatParams.keysPtr());
categoryInfo.branches[iColl + 3].data->SetAddress(categoryInfo.floatParams.valuesPtr());

categoryInfo.branches[iColl + 4].data->SetAddress(categoryInfo.doubleParams.keysPtr());
categoryInfo.branches[iColl + 5].data->SetAddress(categoryInfo.doubleParams.valuesPtr());

categoryInfo.branches[iColl + 6].data->SetAddress(categoryInfo.stringParams.keysPtr());
categoryInfo.branches[iColl + 7].data->SetAddress(categoryInfo.stringParams.valuesPtr());
tmadlener marked this conversation as resolved.
Show resolved Hide resolved
}

void ROOTWriter::finish() {
Expand Down Expand Up @@ -175,4 +195,11 @@ ROOTWriter::checkConsistency(const std::vector<std::string>& collsToWrite, const
return {std::vector<std::string>{}, collsToWrite};
}

/// Fill the parameter keys and values into the CategoryInfo storage.
///
/// The keys and values are assigned to the members of the existing storage
/// objects directly (as is done in RNTupleWriter::fillParams) instead of
/// replacing the storage objects as a whole.
///
/// @param catInfo the category whose parameter storage is filled
/// @param params  the parameters to take the keys and values from
void ROOTWriter::fillParams(CategoryInfo& catInfo, const GenericParameters& params) {
  catInfo.intParams.keys = params.getKeys<int>();
  catInfo.intParams.values = params.getValues<int>();

  catInfo.floatParams.keys = params.getKeys<float>();
  catInfo.floatParams.values = params.getValues<float>();

  catInfo.doubleParams.keys = params.getKeys<double>();
  catInfo.doubleParams.values = params.getValues<double>();

  catInfo.stringParams.keys = params.getKeys<std::string>();
  catInfo.stringParams.values = params.getValues<std::string>();
}

} // namespace podio
Loading