Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make the TTree and RNTuple based backends write the GenericParameters the same way #625

Merged
merged 16 commits into from
Jun 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 36 additions & 4 deletions include/podio/GenericParameters.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ class RNTupleWriter;

namespace podio {

#if !defined(__CLING__)
// cling doesn't really deal well (i.e. at all in this case) with the forward
// declaration here and errors out, breaking e.g. python bindings.
class ROOTReader;
#endif

/// The types which are supported in the GenericParameters
using SupportedGenericDataTypes = std::tuple<int, float, std::string, double>;

Expand Down Expand Up @@ -98,6 +104,10 @@ class GenericParameters {
set<std::vector<T>>(key, std::move(values));
}

/// Load multiple key value pairs simultaneously
template <typename T, template <typename...> typename VecLike>
void loadFrom(VecLike<std::string> keys, VecLike<std::vector<T>> values);

/// Get the number of elements stored under the given key for a type
template <typename T, typename = EnableIfValidGenericDataType<T>>
size_t getN(const std::string& key) const;
Expand All @@ -108,7 +118,7 @@ class GenericParameters {

/// Get all the available values for a given type
template <typename T, typename = EnableIfValidGenericDataType<T>>
std::vector<std::vector<T>> getValues() const;
std::tuple<std::vector<std::string>, std::vector<std::vector<T>>> getKeysAndValues() const;

/// erase all elements
void clear() {
Expand All @@ -134,6 +144,10 @@ class GenericParameters {
friend RNTupleWriter;
#endif

#if !defined(__CLING__)
friend ROOTReader;
#endif

/// Get a reference to the internal map for a given type
template <typename T>
const MapType<detail::GetVectorType<T>>& getMap() const {
Expand Down Expand Up @@ -249,18 +263,36 @@ std::vector<std::string> GenericParameters::getKeys() const {
}

template <typename T, typename>
std::vector<std::vector<T>> GenericParameters::getValues() const {
std::tuple<std::vector<std::string>, std::vector<std::vector<T>>> GenericParameters::getKeysAndValues() const {
std::vector<std::vector<T>> values;
std::vector<std::string> keys;
auto& mtx = getMutex<T>();
const auto& map = getMap<T>();
{
// Lock to avoid concurrent changes to the map while we get the stored
// values
std::lock_guard lock{mtx};
values.reserve(map.size());
std::transform(map.begin(), map.end(), std::back_inserter(values), [](const auto& pair) { return pair.second; });
keys.reserve(map.size());

for (const auto& [k, v] : map) {
keys.emplace_back(k);
values.emplace_back(v);
}
}
return values;
return {keys, values};
}

template <typename T, template <typename...> typename VecLike>
void GenericParameters::loadFrom(VecLike<std::string> keys, VecLike<std::vector<T>> values) {
auto& map = getMap<T>();
auto& mtx = getMutex<T>();

std::lock_guard lock{mtx};
for (size_t i = 0; i < keys.size(); ++i) {
map.emplace(std::move(keys[i]), std::move(values[i]));
}
}

} // namespace podio
#endif
4 changes: 4 additions & 0 deletions include/podio/ROOTReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,10 @@ class ROOTReader {
/// Read the parameters for the entry specified in the passed CategoryInfo
GenericParameters readEntryParameters(CategoryInfo& catInfo, bool reloadBranches, unsigned int localEntry);

template <typename T>
static void readParams(CategoryInfo& catInfo, podio::GenericParameters& params, bool reloadBranches,
unsigned int localEntry);

/// Read the data entry specified in the passed CategoryInfo, and increase the
/// counter afterwards. In case the requested entry is larger than the
/// available number of entries, return a nullptr.
Expand Down
14 changes: 11 additions & 3 deletions include/podio/ROOTWriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,13 @@ class ROOTWriter {
std::vector<root_utils::CollectionWriteInfoT> collInfo{}; ///< Collection info for this category
podio::CollectionIDTable idTable{}; ///< The collection id table for this category
std::vector<std::string> collsToWrite{}; ///< The collections to write for this category

// Storage for the keys & values of all the parameters of this category
// (resp. at least the current entry)
root_utils::ParamStorage<int> intParams{};
root_utils::ParamStorage<float> floatParams{};
root_utils::ParamStorage<double> doubleParams{};
root_utils::ParamStorage<std::string> stringParams{};
};

/// Initialize the branches for this category
Expand All @@ -117,9 +124,10 @@ class ROOTWriter {
/// Get the (potentially uninitialized category information for this category)
CategoryInfo& getCategoryInfo(const std::string& category);

static void resetBranches(std::vector<root_utils::CollectionBranches>& branches,
const std::vector<root_utils::StoreCollection>& collections,
/*const*/ podio::GenericParameters* parameters);
static void resetBranches(CategoryInfo& categoryInfo, const std::vector<root_utils::StoreCollection>& collections);

/// Fill the parameter keys and values into the CategoryInfo storage
static void fillParams(CategoryInfo& catInfo, const GenericParameters& params);

std::unique_ptr<TFile> m_file{nullptr}; ///< The storage file
std::unordered_map<std::string, CategoryInfo> m_categories{}; ///< All categories
Expand Down
51 changes: 50 additions & 1 deletion include/podio/utilities/RootHelpers.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#ifndef PODIO_UTILITIES_ROOTHELPERS_H
#define PODIO_UTILITIES_ROOTHELPERS_H

#include "podio/GenericParameters.h"

#include "ROOT/RVec.hxx"
#include "TBranch.h"

#include <string>
Expand All @@ -26,6 +29,16 @@ namespace root_utils {
/// write a collection. Needed to cache the branch pointers and avoid having to
/// get them from a TTree/TChain for every event.
struct CollectionBranches {
CollectionBranches() = default;
~CollectionBranches() = default;
CollectionBranches(const CollectionBranches&) = delete;
CollectionBranches& operator=(const CollectionBranches&) = delete;
CollectionBranches(CollectionBranches&&) = default;
CollectionBranches& operator=(CollectionBranches&&) = default;

CollectionBranches(TBranch* dataBranch) : data(dataBranch) {
}

TBranch* data{nullptr};
std::vector<TBranch*> refs{};
std::vector<TBranch*> vecs{};
Expand All @@ -36,7 +49,43 @@ namespace root_utils {
/// Pair of keys and values for one type of the ones that can be stored in
/// GenericParameters
template <typename T>
using ParamStorage = std::tuple<std::vector<std::string>, std::vector<std::vector<T>>>;
struct ParamStorage {
ParamStorage() = default;
~ParamStorage() = default;
ParamStorage(const ParamStorage&) = delete;
ParamStorage& operator=(const ParamStorage&) = delete;
ParamStorage(ParamStorage&&) = default;
ParamStorage& operator=(ParamStorage&&) = default;

ParamStorage(std::tuple<std::vector<std::string>, std::vector<std::vector<T>>> keysValues) :
keys(std::move(std::get<0>(keysValues))), values(std::move(std::get<1>(keysValues))) {
}

/// Get a pointer to the stored keys for binding it to a TBranch
auto keysPtr() {
m_keysPtr = &keys;
return &m_keysPtr;
}

/// Get a pointer to the stored vectors for binding it to a TBranch
auto valuesPtr() {
m_valuesPtr = &values;
return &m_valuesPtr;
}

std::vector<std::string> keys{}; ///< The keys for this type
std::vector<std::vector<T>> values{}; ///< The values for this type

private:
std::vector<std::string>* m_keysPtr{nullptr};
std::vector<std::vector<T>>* m_valuesPtr{nullptr};
};

GenericParameters
loadParamsFrom(ROOT::VecOps::RVec<std::string> intKeys, ROOT::VecOps::RVec<std::vector<int>> intValues,
ROOT::VecOps::RVec<std::string> floatKeys, ROOT::VecOps::RVec<std::vector<float>> floatValues,
ROOT::VecOps::RVec<std::string> doubleKeys, ROOT::VecOps::RVec<std::vector<double>> doubleValues,
ROOT::VecOps::RVec<std::string> stringKeys, ROOT::VecOps::RVec<std::vector<std::string>> stringValues);

} // namespace root_utils
} // namespace podio
Expand Down
3 changes: 2 additions & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ SET(root_sources
ROOTReader.cc
ROOTLegacyReader.cc
ROOTFrameData.cc
RootHelpers.cc
)
if(ENABLE_RNTUPLE)
list(APPEND root_sources
Expand All @@ -104,7 +105,7 @@ if(ENABLE_RNTUPLE)
endif()

PODIO_ADD_LIB_AND_DICT(podioRootIO "${root_headers}" "${root_sources}" root_selection.xml)
target_link_libraries(podioRootIO PUBLIC podio::podio ROOT::Core ROOT::RIO ROOT::Tree)
target_link_libraries(podioRootIO PUBLIC podio::podio ROOT::Core ROOT::RIO ROOT::Tree ROOT::ROOTVecOps)
if(ENABLE_RNTUPLE)
target_link_libraries(podioRootIO PUBLIC ROOT::ROOTNTuple)
target_compile_definitions(podioRootIO PUBLIC PODIO_ENABLE_RNTUPLE=1)
Expand Down
4 changes: 1 addition & 3 deletions src/RNTupleReader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@ void RNTupleReader::readParams(const std::string& name, unsigned entNum, Generic
auto keyView = m_readers[name][0]->GetView<std::vector<std::string>>(root_utils::getGPKeyName<T>());
auto valueView = m_readers[name][0]->GetView<std::vector<std::vector<T>>>(root_utils::getGPValueName<T>());

for (size_t i = 0; i < keyView(entNum).size(); ++i) {
params.getMap<T>().emplace(std::move(keyView(entNum)[i]), std::move(valueView(entNum)[i]));
}
params.loadFrom(keyView(entNum), valueView(entNum));
}

GenericParameters RNTupleReader::readEventMetaData(const std::string& name, unsigned entNum) {
Expand Down
13 changes: 6 additions & 7 deletions src/RNTupleWriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,14 @@ root_utils::ParamStorage<T>& RNTupleWriter::getParamStorage(CategoryInfo& catInf
template <typename T>
void RNTupleWriter::fillParams(const GenericParameters& params, CategoryInfo& catInfo,
ROOT::Experimental::REntry* entry) {
auto& [keys, values] = getParamStorage<T>(catInfo);
keys = params.getKeys<T>();
values = params.getValues<T>();
auto& paramStorage = getParamStorage<T>(catInfo);
paramStorage = params.getKeysAndValues<T>();
#if ROOT_VERSION_CODE >= ROOT_VERSION(6, 31, 0)
entry->BindRawPtr(root_utils::getGPKeyName<T>(), &keys);
entry->BindRawPtr(root_utils::getGPValueName<T>(), &values);
entry->BindRawPtr(root_utils::getGPKeyName<T>(), &paramStorage.keys);
entry->BindRawPtr(root_utils::getGPValueName<T>(), &paramStorage.values);
#else
entry->CaptureValueUnsafe(root_utils::getGPKeyName<T>(), &keys);
entry->CaptureValueUnsafe(root_utils::getGPValueName<T>(), &values);
entry->CaptureValueUnsafe(root_utils::getGPKeyName<T>(), &paramStorage.keys);
entry->CaptureValueUnsafe(root_utils::getGPValueName<T>(), &paramStorage.values);
#endif
}

Expand Down
2 changes: 1 addition & 1 deletion src/ROOTLegacyReader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ void ROOTLegacyReader::createCollectionBranches(const std::vector<root_utils::Co

m_storedClasses.emplace_back(name, std::make_tuple(collType, isSubsetColl, collSchemaVersion, collectionIndex++));

m_collectionBranches.push_back(branches);
m_collectionBranches.emplace_back(std::move(branches));
}
}

Expand Down
83 changes: 64 additions & 19 deletions src/ROOTReader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@
#include "podio/CollectionIDTable.h"
#include "podio/DatamodelRegistry.h"
#include "podio/GenericParameters.h"
#include "podio/utilities/RootHelpers.h"
#include "rootUtils.h"

// ROOT specific includes
#include "TChain.h"
#include "TClass.h"
#include "TFile.h"
#include "TTree.h"
#include "TTreeCache.h"

#include <stdexcept>
Expand All @@ -27,22 +26,56 @@ std::tuple<std::vector<root_utils::CollectionBranches>, std::vector<std::pair<st
createCollectionBranchesIndexBased(TChain* chain, const podio::CollectionIDTable& idTable,
const std::vector<root_utils::CollectionWriteInfoT>& collInfo);

GenericParameters ROOTReader::readEntryParameters(ROOTReader::CategoryInfo& catInfo, bool reloadBranches,
unsigned int localEntry) {
// Parameter branch is always the last one
auto& paramBranches = catInfo.branches.back();
template <typename T>
void ROOTReader::readParams(ROOTReader::CategoryInfo& catInfo, podio::GenericParameters& params, bool reloadBranches,
unsigned int localEntry) {
const auto collBranchIdx = catInfo.branches.size() - root_utils::nParamBranches - 1;
constexpr auto brOffset = root_utils::getGPBranchOffsets<T>();

// Make sure to have a valid branch pointer after switching trees in the chain
// as well as on the first event
if (reloadBranches) {
paramBranches.data = root_utils::getBranch(catInfo.chain.get(), root_utils::paramBranchName);
auto& keyBranch = catInfo.branches[collBranchIdx + brOffset.keys].data;
keyBranch = root_utils::getBranch(catInfo.chain.get(), root_utils::getGPKeyName<T>());
auto& valueBranch = catInfo.branches[collBranchIdx + brOffset.values].data;
valueBranch = root_utils::getBranch(catInfo.chain.get(), root_utils::getGPValueName<T>());
}
auto* branch = paramBranches.data;

auto keyBranch = catInfo.branches[collBranchIdx + brOffset.keys].data;
auto valueBranch = catInfo.branches[collBranchIdx + brOffset.values].data;

root_utils::ParamStorage<T> storage;
keyBranch->SetAddress(storage.keysPtr());
keyBranch->GetEntry(localEntry);
valueBranch->SetAddress(storage.valuesPtr());
valueBranch->GetEntry(localEntry);

params.loadFrom(std::move(storage.keys), std::move(storage.values));
}

GenericParameters ROOTReader::readEntryParameters(ROOTReader::CategoryInfo& catInfo, bool reloadBranches,
unsigned int localEntry) {
GenericParameters params;
auto* emd = &params;
branch->SetAddress(&emd);
branch->GetEntry(localEntry);

if (m_fileVersion < podio::version::Version{0, 99, 99}) {
// Parameter branch is always the last one
auto& paramBranches = catInfo.branches.back();

// Make sure to have a valid branch pointer after switching trees in the chain
// as well as on the first event
if (reloadBranches) {
paramBranches.data = root_utils::getBranch(catInfo.chain.get(), root_utils::paramBranchName);
}
auto* branch = paramBranches.data;

auto* emd = &params;
branch->SetAddress(&emd);
branch->GetEntry(localEntry);
} else {
readParams<int>(catInfo, params, reloadBranches, localEntry);
readParams<float>(catInfo, params, reloadBranches, localEntry);
readParams<double>(catInfo, params, reloadBranches, localEntry);
readParams<std::string>(catInfo, params, reloadBranches, localEntry);
}

return params;
}

Expand Down Expand Up @@ -162,13 +195,25 @@ void ROOTReader::initCategory(CategoryInfo& catInfo, const std::string& category
std::tie(catInfo.branches, catInfo.storedClasses) =
createCollectionBranches(catInfo.chain.get(), *catInfo.table, *collInfo);
}

delete collInfo;

// Finally set up the branches for the parameters
root_utils::CollectionBranches paramBranches{};
paramBranches.data = root_utils::getBranch(catInfo.chain.get(), root_utils::paramBranchName);
catInfo.branches.push_back(paramBranches);
if (m_fileVersion < podio::version::Version{0, 99, 99}) {
root_utils::CollectionBranches paramBranches{};
catInfo.branches.emplace_back(root_utils::getBranch(catInfo.chain.get(), root_utils::paramBranchName));
} else {
catInfo.branches.emplace_back(root_utils::getBranch(catInfo.chain.get(), root_utils::intKeyName));
catInfo.branches.emplace_back(root_utils::getBranch(catInfo.chain.get(), root_utils::intValueName));

catInfo.branches.emplace_back(root_utils::getBranch(catInfo.chain.get(), root_utils::floatKeyName));
catInfo.branches.emplace_back(root_utils::getBranch(catInfo.chain.get(), root_utils::floatValueName));

catInfo.branches.emplace_back(root_utils::getBranch(catInfo.chain.get(), root_utils::doubleKeyName));
catInfo.branches.emplace_back(root_utils::getBranch(catInfo.chain.get(), root_utils::doubleValueName));

catInfo.branches.emplace_back(root_utils::getBranch(catInfo.chain.get(), root_utils::stringKeyName));
catInfo.branches.emplace_back(root_utils::getBranch(catInfo.chain.get(), root_utils::stringValueName));
}
}

std::vector<std::string> getAvailableCategories(TChain* metaChain) {
Expand Down Expand Up @@ -300,7 +345,7 @@ createCollectionBranchesIndexBased(TChain* chain, const podio::CollectionIDTable
collBranches.emplace_back(std::move(branches));
}

return {collBranches, storedClasses};
return {std::move(collBranches), storedClasses};
}

std::tuple<std::vector<root_utils::CollectionBranches>, std::vector<std::pair<std::string, detail::CollectionInfo>>>
Expand Down Expand Up @@ -346,7 +391,7 @@ createCollectionBranches(TChain* chain, const podio::CollectionIDTable& idTable,
collBranches.emplace_back(std::move(branches));
}

return {collBranches, storedClasses};
return {std::move(collBranches), storedClasses};
}

} // namespace podio
Loading