From 26d51c91d19d626bd6d96776a35a7af774c49ced Mon Sep 17 00:00:00 2001 From: Thomas Madlener Date: Fri, 1 Dec 2023 19:55:05 +0100 Subject: [PATCH] Add python bindings for the RNTuple backend (#488) * Add RNTupleReader to python bindings and file dispatch * Complete ROOTNTupleReader interface and expose it in dictionaries * Add python bindings for RNTuple writer --- include/podio/ROOTNTupleReader.h | 14 +++++++++ python/podio/reading.py | 31 +++++++++++++++---- python/podio/root_io.py | 29 ++++++++++++++++++ src/ROOTNTupleReader.cc | 10 +++++++ src/ROOTNTupleWriter.cc | 6 +++- src/root_selection.xml | 2 ++ tests/dumpmodel/CMakeLists.txt | 33 +++++++++++++++++++++ tests/root_io/CMakeLists.txt | 10 ++++++- tests/root_io/read_python_frame_rntuple.cpp | 7 +++++ tests/sio_io/CMakeLists.txt | 2 +- tests/write_frame.py | 8 +++-- tools/CMakeLists.txt | 5 ++++ 12 files changed, 145 insertions(+), 12 deletions(-) create mode 100644 tests/root_io/read_python_frame_rntuple.cpp diff --git a/include/podio/ROOTNTupleReader.h b/include/podio/ROOTNTupleReader.h index 672840587..30783c901 100644 --- a/include/podio/ROOTNTupleReader.h +++ b/include/podio/ROOTNTupleReader.h @@ -9,6 +9,7 @@ #include "podio/utilities/DatamodelRegistryIOHelpers.h" #include +#include #include #include @@ -47,6 +48,9 @@ class ROOTNTupleReader { */ std::unique_ptr readEntry(const std::string& name, const unsigned entry); + /// Get the names of all the available Frame categories in the current file(s) + std::vector getAvailableCategories() const; + /// Returns number of entries for the given name unsigned getEntries(const std::string& name); @@ -55,6 +59,16 @@ class ROOTNTupleReader { return m_fileVersion; } + /// Get the datamodel definition for the given name + const std::string_view getDatamodelDefinition(const std::string& name) const { + return m_datamodelHolder.getDatamodelDefinition(name); + } + + /// Get all names of the datamodels that ara available from this reader + std::vector getAvailableDatamodels() const { + return m_datamodelHolder.getAvailableDatamodels(); + } + void closeFile(); private: diff --git a/python/podio/reading.py b/python/podio/reading.py index b4f6401f5..357e37151 100644 --- a/python/podio/reading.py +++ b/python/podio/reading.py @@ -23,11 +23,26 @@ def _is_frame_sio_file(filename): 'or there is a version mismatch') -def _is_frame_root_file(filename): - """Peek into the root file to determine whether this is a legacy file or not.""" +class RootFileFormat: + """Enum to specify the ROOT file format""" + TTREE = 0 # Non-legacy TTree based file + RNTUPLE = 1 # RNTuple based file + LEGACY = 2 # Legacy TTree based file + + +def _determine_root_format(filename): + """Peek into the root file to determine which flavor we have at hand.""" file = TFile.Open(filename) - # The ROOT Frame writer puts a podio_metadata TTree into the file - return bool(file.Get('podio_metadata')) + + metadata = file.Get("podio_metadata") + if not metadata: + return RootFileFormat.LEGACY + + md_class = metadata.IsA().GetName() + if "TTree" in md_class: + return RootFileFormat.TTREE + + return RootFileFormat.RNTUPLE def get_reader(filename): @@ -50,8 +65,12 @@ def get_reader(filename): return sio_io.LegacyReader(filename) if filename.endswith('.root'): - if _is_frame_root_file(filename): + root_flavor = _determine_root_format(filename) + if root_flavor == RootFileFormat.TTREE: return root_io.Reader(filename) - return root_io.LegacyReader(filename) + if root_flavor == RootFileFormat.RNTUPLE: + return root_io.RNTupleReader(filename) + if root_flavor == RootFileFormat.LEGACY: + return root_io.LegacyReader(filename) raise ValueError('file must end on .root or .sio') diff --git a/python/podio/root_io.py b/python/podio/root_io.py index 9623ee24d..6ebfbdac7 100644 --- a/python/podio/root_io.py +++ b/python/podio/root_io.py @@ -27,6 +27,24 @@ def __init__(self, filenames): super().__init__() +class RNTupleReader(BaseReaderMixin): + """Reader class for reading podio RNTuple root files.""" + + def __init__(self, filenames): + """Create an RNTuple reader that reads from the passed file(s). + + Args: + filenames (str or list[str]): file(s) to open and read data from + """ + if isinstance(filenames, str): + filenames = (filenames,) + + self._reader = podio.ROOTNTupleReader() + self._reader.openFiles(filenames) + + super().__init__() + + class LegacyReader(BaseReaderMixin): """Reader class for reading legacy podio root files. @@ -59,3 +77,14 @@ def __init__(self, filename): filename (str): The name of the output file """ self._writer = podio.ROOTFrameWriter(filename) + + +class RNTupleWriter(BaseWriterMixin): + """Writer class for writing podio root files""" + def __init__(self, filename): + """Create a writer for writing files + + Args: + filename (str): The name of the output file + """ + self._writer = podio.ROOTNTupleWriter(filename) diff --git a/src/ROOTNTupleReader.cc b/src/ROOTNTupleReader.cc index 4b020b314..b4db7f991 100644 --- a/src/ROOTNTupleReader.cc +++ b/src/ROOTNTupleReader.cc @@ -87,6 +87,7 @@ void ROOTNTupleReader::openFiles(const std::vector& filenames) { auto edmView = m_metadata->GetView>>(root_utils::edmDefBranchName); auto edm = edmView(0); + m_datamodelHolder = DatamodelDefinitionHolder(std::move(edm)); auto availableCategoriesField = m_metadata->GetView>(root_utils::availableCategories); m_availableCategories = availableCategoriesField(0); @@ -107,6 +108,15 @@ unsigned ROOTNTupleReader::getEntries(const std::string& name) { return m_totalEntries[name]; } +std::vector ROOTNTupleReader::getAvailableCategories() const { + std::vector cats; + cats.reserve(m_availableCategories.size()); + for (const auto& cat : m_availableCategories) { + cats.emplace_back(cat); + } + return cats; +} + std::unique_ptr ROOTNTupleReader::readNextEntry(const std::string& name) { return readEntry(name, m_entries[name]); } diff --git a/src/ROOTNTupleWriter.cc b/src/ROOTNTupleWriter.cc index 9cf3b9d8d..79fc982c9 100644 --- a/src/ROOTNTupleWriter.cc +++ b/src/ROOTNTupleWriter.cc @@ -164,6 +164,10 @@ std::unique_ptr ROOTNTupleWriter::createModels(const std::vector& collections) { auto model = ROOT::Experimental::RNTupleModel::CreateBare(); for (auto& [name, coll] : collections) { + // For the first entry in each category we also record the datamodel + // definition + m_datamodelCollector.registerDatamodelDefinition(coll, name); + const auto collBuffers = coll->getBuffers(); if (collBuffers.vecPtr) { @@ -252,7 +256,7 @@ void ROOTNTupleWriter::finish() { auto edmDefinitions = m_datamodelCollector.getDatamodelDefinitionsToWrite(); auto edmField = m_metadata->MakeField>>(root_utils::edmDefBranchName); - *edmField = edmDefinitions; + *edmField = std::move(edmDefinitions); auto availableCategoriesField = m_metadata->MakeField>(root_utils::availableCategories); for (auto& [c, _] : m_categories) { diff --git a/src/root_selection.xml b/src/root_selection.xml index 886a69e68..afda66935 100644 --- a/src/root_selection.xml +++ b/src/root_selection.xml @@ -3,5 +3,7 @@ + + diff --git a/tests/dumpmodel/CMakeLists.txt b/tests/dumpmodel/CMakeLists.txt index 11c7f5524..6e9dfef1e 100644 --- a/tests/dumpmodel/CMakeLists.txt +++ b/tests/dumpmodel/CMakeLists.txt @@ -57,10 +57,43 @@ if (ENABLE_SIO) ) endif() +set(rntuple_roundtrip_tests "") +if (ENABLE_RNTUPLE) + add_test(NAME datamodel_def_store_roundtrip_rntuple COMMAND + ${PROJECT_SOURCE_DIR}/tests/scripts/dumpModelRoundTrip.sh + ${PROJECT_BINARY_DIR}/tests/root_io/example_rntuple.root + datamodel + ${PROJECT_SOURCE_DIR}/tests + ) + PODIO_SET_TEST_ENV(datamodel_def_store_roundtrip_rntuple) + + add_test(NAME datamodel_def_store_roundtrip_rntuple_extension COMMAND + ${PROJECT_SOURCE_DIR}/tests/scripts/dumpModelRoundTrip.sh + ${PROJECT_BINARY_DIR}/tests/root_io/example_rntuple.root + extension_model + ${PROJECT_SOURCE_DIR}/tests/extension_model + --upstream-edm=datamodel:${PROJECT_SOURCE_DIR}/tests/datalayout.yaml + ) + PODIO_SET_TEST_ENV(datamodel_def_store_roundtrip_rntuple_extension) + + set(rntuple_roundtrip_tests + datamodel_def_store_roundtrip_rntuple + datamodel_def_store_roundtrip_rntuple_extension + ) + + set_tests_properties( + ${rntuple_roundtrip_tests} + PROPERTIES + DEPENDS write_rntuple + ) + +endif() + set_tests_properties( datamodel_def_store_roundtrip_root datamodel_def_store_roundtrip_root_extension ${sio_roundtrip_tests} + ${rntuple_roundtrip_tests} PROPERTIES WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} ) diff --git a/tests/root_io/CMakeLists.txt b/tests/root_io/CMakeLists.txt index 17435bedc..099815022 100644 --- a/tests/root_io/CMakeLists.txt +++ b/tests/root_io/CMakeLists.txt @@ -19,6 +19,7 @@ if(ENABLE_RNTUPLE) ${root_dependent_tests} write_rntuple.cpp read_rntuple.cpp + read_python_frame_rntuple.cpp ) endif() set(root_libs TestDataModelDict ExtensionDataModelDict podio::podioRootIO) @@ -80,6 +81,13 @@ endforeach() #--- Write via python and the ROOT backend and see if we can read it back in in #--- c++ -add_test(NAME write_python_frame_root COMMAND python3 ${PROJECT_SOURCE_DIR}/tests/write_frame.py example_frame_with_py.root) +add_test(NAME write_python_frame_root COMMAND python3 ${PROJECT_SOURCE_DIR}/tests/write_frame.py example_frame_with_py.root root_io.Writer) PODIO_SET_TEST_ENV(write_python_frame_root) set_property(TEST read_python_frame_root PROPERTY DEPENDS write_python_frame_root) + +if (ENABLE_RNTUPLE) + add_test(NAME write_python_frame_rntuple COMMAND python3 ${PROJECT_SOURCE_DIR}/tests/write_frame.py example_frame_with_py_rntuple.root root_io.RNTupleWriter) + PODIO_SET_TEST_ENV(write_python_frame_rntuple) + + set_property(TEST read_python_frame_rntuple PROPERTY DEPENDS write_python_frame_rntuple) +endif() diff --git a/tests/root_io/read_python_frame_rntuple.cpp b/tests/root_io/read_python_frame_rntuple.cpp new file mode 100644 index 000000000..52d7576d8 --- /dev/null +++ b/tests/root_io/read_python_frame_rntuple.cpp @@ -0,0 +1,7 @@ +#include "read_python_frame.h" + +#include "podio/ROOTNTupleReader.h" + +int main() { + return read_frame("example_frame_with_py_rntuple.root"); +} diff --git a/tests/sio_io/CMakeLists.txt b/tests/sio_io/CMakeLists.txt index d48ad1372..2d24ce2f9 100644 --- a/tests/sio_io/CMakeLists.txt +++ b/tests/sio_io/CMakeLists.txt @@ -40,6 +40,6 @@ set_property(TEST check_benchmark_outputs_sio PROPERTY DEPENDS read_timed_sio wr #--- Write via python and the SIO backend and see if we can read it back in in #--- c++ -add_test(NAME write_python_frame_sio COMMAND python3 ${PROJECT_SOURCE_DIR}/tests/write_frame.py example_frame_with_py.sio) +add_test(NAME write_python_frame_sio COMMAND python3 ${PROJECT_SOURCE_DIR}/tests/write_frame.py example_frame_with_py.sio sio_io.Writer) PODIO_SET_TEST_ENV(write_python_frame_sio) set_property(TEST read_python_frame_sio PROPERTY DEPENDS write_python_frame_sio) diff --git a/tests/write_frame.py b/tests/write_frame.py index 05b901f5e..313c72a9e 100644 --- a/tests/write_frame.py +++ b/tests/write_frame.py @@ -54,13 +54,14 @@ def create_frame(): return frame -def write_file(io_backend, filename): +def write_file(writer_type, filename): """Write a file using the given Writer type and put one Frame into it under the events category """ + io_backend, writer_name = writer_type.split(".") io_module = importlib.import_module(f"podio.{io_backend}") - writer = io_module.Writer(filename) + writer = getattr(io_module, writer_name)(filename) event = create_frame() writer.write_frame(event, "events") @@ -70,9 +71,10 @@ def write_file(io_backend, filename): parser = argparse.ArgumentParser() parser.add_argument("outputfile", help="Output file name") + parser.add_argument("writer", help="The writer type to use") args = parser.parse_args() io_format = args.outputfile.split(".")[-1] - write_file(f"{io_format}_io", args.outputfile) + write_file(args.writer, args.outputfile) diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index 1fe5aaf63..06c5118fe 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -37,4 +37,9 @@ if(BUILD_TESTING) CREATE_DUMP_TEST(podio-dump-detailed-sio-legacy "write_sio" --detailed --entries 9 ${PROJECT_BINARY_DIR}/tests/sio_io/example.sio) endif() + if (ENABLE_RNTUPLE) + CREATE_DUMP_TEST(podio-dump-rntuple "write_rntuple" ${PROJECT_BINARY_DIR}/tests/root_io/example_rntuple.root) + CREATE_DUMP_TEST(podio-dump-rntuple-detailed "write_rntuple" --detailed --category events --entries 1:3 ${PROJECT_BINARY_DIR}/tests/root_io/example_rntuple.root) + endif() + endif()