Skip to content

Commit

Permalink
Add python bindings for the RNTuple backend (#488)
Browse files Browse the repository at this point in the history
* Add RNTupleReader to python bindings and file dispatch

* Complete ROOTNTupleReader interface and expose it in dictionaries

* Add python bindings for RNTuple writer
  • Loading branch information
tmadlener authored Dec 1, 2023
1 parent 7a85b3a commit 26d51c9
Show file tree
Hide file tree
Showing 12 changed files with 145 additions and 12 deletions.
14 changes: 14 additions & 0 deletions include/podio/ROOTNTupleReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "podio/utilities/DatamodelRegistryIOHelpers.h"

#include <string>
#include <string_view>
#include <unordered_map>
#include <vector>

Expand Down Expand Up @@ -47,6 +48,9 @@ class ROOTNTupleReader {
*/
std::unique_ptr<podio::ROOTFrameData> readEntry(const std::string& name, const unsigned entry);

/// Get the names of all the available Frame categories in the current file(s)
std::vector<std::string_view> getAvailableCategories() const;

/// Returns number of entries for the given name
unsigned getEntries(const std::string& name);

Expand All @@ -55,6 +59,16 @@ class ROOTNTupleReader {
return m_fileVersion;
}

/// Get the datamodel definition for the given name
///
/// The lookup is forwarded to the datamodel definition holder of this reader.
///
/// @param name The name of the datamodel
/// @returns A view of the definition as handed out by the holder.
///          NOTE(review): the view presumably refers to storage owned by the
///          holder and should not outlive this reader -- confirm
const std::string_view getDatamodelDefinition(const std::string& name) const {
return m_datamodelHolder.getDatamodelDefinition(name);
}

/// Get all names of the datamodels that are available from this reader
///
/// @returns The names as reported by the datamodel definition holder
std::vector<std::string> getAvailableDatamodels() const {
return m_datamodelHolder.getAvailableDatamodels();
}

void closeFile();

private:
Expand Down
31 changes: 25 additions & 6 deletions python/podio/reading.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,26 @@ def _is_frame_sio_file(filename):
'or there is a version mismatch')


def _is_frame_root_file(filename):
"""Peek into the root file to determine whether this is a legacy file or not."""
class RootFileFormat:
    """Enum to specify the ROOT file format

    The values are plain class level integers that are used to tell the
    different flavors of podio ROOT files apart.
    """
    TTREE = 0  # Non-legacy TTree based file
    RNTUPLE = 1  # RNTuple based file
    LEGACY = 2  # Legacy TTree based file


def _determine_root_format(filename):
    """Peek into the root file to determine which flavor we have at hand.

    Args:
        filename (str): The name of the root file to inspect

    Returns:
        int: One of the RootFileFormat values: LEGACY if the file has no
            podio_metadata object, TTREE if the class name of that object
            contains "TTree" and RNTUPLE otherwise
    """
    file = TFile.Open(filename)

    # Frame based files carry a podio_metadata object, legacy files do not
    metadata = file.Get("podio_metadata")
    if not metadata:
        return RootFileFormat.LEGACY

    # The class of the metadata object tells the TTree and RNTuple flavors apart
    md_class = metadata.IsA().GetName()
    if "TTree" in md_class:
        return RootFileFormat.TTREE

    return RootFileFormat.RNTUPLE


def get_reader(filename):
Expand All @@ -50,8 +65,12 @@ def get_reader(filename):
return sio_io.LegacyReader(filename)

if filename.endswith('.root'):
if _is_frame_root_file(filename):
root_flavor = _determine_root_format(filename)
if root_flavor == RootFileFormat.TTREE:
return root_io.Reader(filename)
return root_io.LegacyReader(filename)
if root_flavor == RootFileFormat.RNTUPLE:
return root_io.RNTupleReader(filename)
if root_flavor == RootFileFormat.LEGACY:
return root_io.LegacyReader(filename)

raise ValueError('file must end on .root or .sio')
29 changes: 29 additions & 0 deletions python/podio/root_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,24 @@ def __init__(self, filenames):
super().__init__()


class RNTupleReader(BaseReaderMixin):
    """Reader for podio root files that have been written as RNTuples."""

    def __init__(self, filenames):
        """Open the passed file(s) with an RNTuple reader.

        Args:
            filenames (str or list[str]): file(s) to open and read data from
        """
        # A single file name is wrapped into a tuple so that openFiles always
        # receives a sequence of names
        files = (filenames,) if isinstance(filenames, str) else filenames

        self._reader = podio.ROOTNTupleReader()
        self._reader.openFiles(files)

        super().__init__()


class LegacyReader(BaseReaderMixin):
"""Reader class for reading legacy podio root files.
Expand Down Expand Up @@ -59,3 +77,14 @@ def __init__(self, filename):
filename (str): The name of the output file
"""
self._writer = podio.ROOTFrameWriter(filename)


class RNTupleWriter(BaseWriterMixin):
    """Writer class for writing podio root files with the RNTuple writer"""
    def __init__(self, filename):
        """Create a writer for writing files

        Args:
            filename (str): The name of the output file
        """
        # The underlying podio.ROOTNTupleWriter is kept in _writer
        self._writer = podio.ROOTNTupleWriter(filename)
10 changes: 10 additions & 0 deletions src/ROOTNTupleReader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ void ROOTNTupleReader::openFiles(const std::vector<std::string>& filenames) {

auto edmView = m_metadata->GetView<std::vector<std::tuple<std::string, std::string>>>(root_utils::edmDefBranchName);
auto edm = edmView(0);
m_datamodelHolder = DatamodelDefinitionHolder(std::move(edm));

auto availableCategoriesField = m_metadata->GetView<std::vector<std::string>>(root_utils::availableCategories);
m_availableCategories = availableCategoriesField(0);
Expand All @@ -107,6 +108,15 @@ unsigned ROOTNTupleReader::getEntries(const std::string& name) {
return m_totalEntries[name];
}

// Get the names of all categories that have been read from the metadata. The
// returned views refer to the strings stored in m_availableCategories.
std::vector<std::string_view> ROOTNTupleReader::getAvailableCategories() const {
  // The iterator range constructor converts every stored string into a view.
  // Parentheses are used on purpose to not select an initializer_list
  // constructor
  return std::vector<std::string_view>(m_availableCategories.begin(), m_availableCategories.end());
}

// Read the next entry of the given category.
//
// The entry number is looked up via m_entries[name] and the actual reading is
// delegated to readEntry.
// NOTE(review): the counter in m_entries is presumably advanced inside
// readEntry -- that is not visible here, confirm.
std::unique_ptr<ROOTFrameData> ROOTNTupleReader::readNextEntry(const std::string& name) {
return readEntry(name, m_entries[name]);
}
Expand Down
6 changes: 5 additions & 1 deletion src/ROOTNTupleWriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,10 @@ std::unique_ptr<ROOT::Experimental::RNTupleModel>
ROOTNTupleWriter::createModels(const std::vector<StoreCollection>& collections) {
auto model = ROOT::Experimental::RNTupleModel::CreateBare();
for (auto& [name, coll] : collections) {
// For the first entry in each category we also record the datamodel
// definition
m_datamodelCollector.registerDatamodelDefinition(coll, name);

const auto collBuffers = coll->getBuffers();

if (collBuffers.vecPtr) {
Expand Down Expand Up @@ -252,7 +256,7 @@ void ROOTNTupleWriter::finish() {
auto edmDefinitions = m_datamodelCollector.getDatamodelDefinitionsToWrite();
auto edmField =
m_metadata->MakeField<std::vector<std::tuple<std::string, std::string>>>(root_utils::edmDefBranchName);
*edmField = edmDefinitions;
*edmField = std::move(edmDefinitions);

auto availableCategoriesField = m_metadata->MakeField<std::vector<std::string>>(root_utils::availableCategories);
for (auto& [c, _] : m_categories) {
Expand Down
2 changes: 2 additions & 0 deletions src/root_selection.xml
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,7 @@
<class name="podio::ROOTFrameReader"/>
<class name="podio::ROOTLegacyReader"/>
<class name="podio::ROOTFrameWriter"/>
<class name="podio::ROOTNTupleReader"/>
<class name="podio::ROOTNTupleWriter"/>
</selection>
</lcgdict>
33 changes: 33 additions & 0 deletions tests/dumpmodel/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,43 @@ if (ENABLE_SIO)
)
endif()

# Names of the RNTuple round trip tests. The list stays empty if RNTuple support
# is not enabled, so that it can be used unconditionally in the
# set_tests_properties call at the end
set(rntuple_roundtrip_tests "")
if (ENABLE_RNTUPLE)
# Run the dumpModelRoundTrip.sh script for the "datamodel" model on the RNTuple
# example file
add_test(NAME datamodel_def_store_roundtrip_rntuple COMMAND
${PROJECT_SOURCE_DIR}/tests/scripts/dumpModelRoundTrip.sh
${PROJECT_BINARY_DIR}/tests/root_io/example_rntuple.root
datamodel
${PROJECT_SOURCE_DIR}/tests
)
PODIO_SET_TEST_ENV(datamodel_def_store_roundtrip_rntuple)

# Same script for the "extension_model" model, which additionally gets the
# datamodel layout passed as upstream EDM
add_test(NAME datamodel_def_store_roundtrip_rntuple_extension COMMAND
${PROJECT_SOURCE_DIR}/tests/scripts/dumpModelRoundTrip.sh
${PROJECT_BINARY_DIR}/tests/root_io/example_rntuple.root
extension_model
${PROJECT_SOURCE_DIR}/tests/extension_model
--upstream-edm=datamodel:${PROJECT_SOURCE_DIR}/tests/datalayout.yaml
)
PODIO_SET_TEST_ENV(datamodel_def_store_roundtrip_rntuple_extension)

set(rntuple_roundtrip_tests
datamodel_def_store_roundtrip_rntuple
datamodel_def_store_roundtrip_rntuple_extension
)

# Both tests read example_rntuple.root and are declared to depend on the
# write_rntuple test
set_tests_properties(
${rntuple_roundtrip_tests}
PROPERTIES
DEPENDS write_rntuple
)

endif()

# All round trip tests run in the current binary directory. The sio and rntuple
# lists expand to nothing if the corresponding backend is not enabled
set_tests_properties(
datamodel_def_store_roundtrip_root
datamodel_def_store_roundtrip_root_extension
${sio_roundtrip_tests}
${rntuple_roundtrip_tests}
PROPERTIES
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
)
10 changes: 9 additions & 1 deletion tests/root_io/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ if(ENABLE_RNTUPLE)
${root_dependent_tests}
write_rntuple.cpp
read_rntuple.cpp
read_python_frame_rntuple.cpp
)
endif()
set(root_libs TestDataModelDict ExtensionDataModelDict podio::podioRootIO)
Expand Down Expand Up @@ -80,6 +81,13 @@ endforeach()

#--- Write via python and the ROOT backend and see if we can read it back in in
#--- c++
# write_frame.py requires the writer type as second argument. The test is
# registered exactly once, since add_test does not allow to reuse a test NAME
add_test(NAME write_python_frame_root COMMAND python3 ${PROJECT_SOURCE_DIR}/tests/write_frame.py example_frame_with_py.root root_io.Writer)
PODIO_SET_TEST_ENV(write_python_frame_root)
# The c++ read test needs the file that is produced by the python write test
set_property(TEST read_python_frame_root PROPERTY DEPENDS write_python_frame_root)

# Write a file via python using root_io.RNTupleWriter. This is only done if
# RNTuple support is enabled
if (ENABLE_RNTUPLE)
add_test(NAME write_python_frame_rntuple COMMAND python3 ${PROJECT_SOURCE_DIR}/tests/write_frame.py example_frame_with_py_rntuple.root root_io.RNTupleWriter)
PODIO_SET_TEST_ENV(write_python_frame_rntuple)

# The read_python_frame_rntuple test is declared to depend on the python write
# test
set_property(TEST read_python_frame_rntuple PROPERTY DEPENDS write_python_frame_rntuple)
endif()
7 changes: 7 additions & 0 deletions tests/root_io/read_python_frame_rntuple.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#include "read_python_frame.h"

#include "podio/ROOTNTupleReader.h"

// Run read_frame with the ROOTNTupleReader on example_frame_with_py_rntuple.root
// and use its return value as exit code of the program
int main() {
return read_frame<podio::ROOTNTupleReader>("example_frame_with_py_rntuple.root");
}
2 changes: 1 addition & 1 deletion tests/sio_io/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,6 @@ set_property(TEST check_benchmark_outputs_sio PROPERTY DEPENDS read_timed_sio wr

#--- Write via python and the SIO backend and see if we can read it back in in
#--- c++
# write_frame.py requires the writer type as second argument. The test is
# registered exactly once, since add_test does not allow to reuse a test NAME
add_test(NAME write_python_frame_sio COMMAND python3 ${PROJECT_SOURCE_DIR}/tests/write_frame.py example_frame_with_py.sio sio_io.Writer)
PODIO_SET_TEST_ENV(write_python_frame_sio)
# The c++ read test needs the file that is produced by the python write test
set_property(TEST read_python_frame_sio PROPERTY DEPENDS write_python_frame_sio)
8 changes: 5 additions & 3 deletions tests/write_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,14 @@ def create_frame():
return frame


def write_file(writer_type, filename):
    """Write a file using the given Writer type and put one Frame into it under
    the events category

    Args:
        writer_type (str): The writer to use in the form <module>.<class>,
            where <module> is imported from the podio package and <class> is
            looked up in that module (e.g. root_io.Writer)
        filename (str): The name of the output file
    """
    io_backend, writer_name = writer_type.split(".")
    io_module = importlib.import_module(f"podio.{io_backend}")

    # Only the requested writer is created, so that there is exactly one writer
    # for the output file
    writer = getattr(io_module, writer_name)(filename)
    event = create_frame()
    writer.write_frame(event, "events")

Expand All @@ -70,9 +71,10 @@ def write_file(io_backend, filename):

parser = argparse.ArgumentParser()
parser.add_argument("outputfile", help="Output file name")
parser.add_argument("writer", help="The writer type to use")

args = parser.parse_args()

io_format = args.outputfile.split(".")[-1]

write_file(f"{io_format}_io", args.outputfile)
write_file(args.writer, args.outputfile)
5 changes: 5 additions & 0 deletions tools/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,9 @@ if(BUILD_TESTING)
CREATE_DUMP_TEST(podio-dump-detailed-sio-legacy "write_sio" --detailed --entries 9 ${PROJECT_BINARY_DIR}/tests/sio_io/example.sio)
endif()

# podio-dump tests that run on the RNTuple example file. This is only done if
# RNTuple support is enabled.
# NOTE(review): the second argument of CREATE_DUMP_TEST is presumably the test
# that produces the input file -- confirm against the macro definition
if (ENABLE_RNTUPLE)
CREATE_DUMP_TEST(podio-dump-rntuple "write_rntuple" ${PROJECT_BINARY_DIR}/tests/root_io/example_rntuple.root)
CREATE_DUMP_TEST(podio-dump-rntuple-detailed "write_rntuple" --detailed --category events --entries 1:3 ${PROJECT_BINARY_DIR}/tests/root_io/example_rntuple.root)
endif()

endif()

0 comments on commit 26d51c9

Please sign in to comment.