From b06201fadaa7102242e0c76ec0b577d4f1ad4f85 Mon Sep 17 00:00:00 2001 From: tmadlener Date: Fri, 24 May 2024 15:04:15 +0200 Subject: [PATCH 1/6] Add python bindings and more overloads for getName --- include/podio/Frame.h | 22 ++++++++++++++++++++++ python/podio/frame.py | 32 ++++++++++++++++++++++++++++++++ python/podio/test_Frame.py | 23 +++++++++++++++++++++++ tests/unittests/frame.cpp | 1 + 4 files changed, 78 insertions(+) diff --git a/include/podio/Frame.h b/include/podio/Frame.h index ce459e119..28d18145b 100644 --- a/include/podio/Frame.h +++ b/include/podio/Frame.h @@ -338,6 +338,28 @@ class Frame { return m_self->getIDTable().name(collectionID); } + /// Get the name of the collection for the passed object ID + /// + /// @param objId The objectID of an element of a collection for which the name + /// should be obtained + /// @returns The name of the collection or an empty optional if this + /// objectID is not known to the Frame + inline std::optional getName(const podio::ObjectID& objId) const { + return getName(objId.collectionID); + } + + /// Get the name of the collection to which this element belongs + /// + /// @tparam ElemT A datatype of a podio generated datamodel + /// @param elem The element of a collection for which the name should be + /// obtained + /// @returns The name of the collection or an empty optional if this + /// element is not known to the Frame + template >> + inline std::optional getName(const ElemT& elem) const { + return getName(elem.id().collectionID); + } + // Interfaces for writing below /// Get a collection for writing. diff --git a/python/podio/frame.py b/python/podio/frame.py index e35ad6bac..c2ce68575 100644 --- a/python/podio/frame.py +++ b/python/podio/frame.py @@ -2,6 +2,7 @@ """Module for the python bindings of the podio::Frame""" import cppyy +from copy import deepcopy import ROOT @@ -107,6 +108,37 @@ def get(self, name): raise KeyError(f"Collection '{name}' is not available") return collection + def getName(self, token): + """Get the name of the collection from the Frame + + Args: + token (podio.CollectionBase | int | podio.ObjectID | podio generated + datatype): Something that that can be used to get the name of a + collection. Can either be the collection itself, a collectionID + or an object ID of an element of a collection + + Returns: + str: The name of the collection inside the frame + + Raises: + KeyError: If no collection can be found in the frame + + """ + maybeName = self._frame.getName(token) + if maybeName.has_value(): + return deepcopy(maybeName.value()) + + def _get_id(tok): + if isinstance(tok, int): + return f"{tok:0>8x}" + if _is_collection_base(token): + return _get_id(tok.getID()) + if isinstance(tok, podio.ObjectID): + return f"({tok.collectionID:0>8x}: {tok.index})" + return _get_id(tok.id()) + + raise KeyError(f"No collection name can be found in Frame for {_get_id(token)}") + def put(self, collection, name): """Put the collection into the frame diff --git a/python/podio/test_Frame.py b/python/podio/test_Frame.py index 08b98baea..07f65bb8b 100644 --- a/python/podio/test_Frame.py +++ b/python/podio/test_Frame.py @@ -70,6 +70,22 @@ def test_frame_invalid_access(self): collection = [1, 2, 4] _ = frame.put(collection, "invalid_collection_type") + def test_frame_get_name_invalid_token(self): + """Check that trying to get the collection name raises an exception if + the collection is not known to the frame""" + frame = Frame() + with self.assertRaises(KeyError): + _ = frame.getName(42) + + with self.assertRaises(KeyError): + coll = ExampleHitCollection() + _ = frame.getName(coll) + + with self.assertRaises(KeyError): + coll = ExampleHitCollection() + hit = coll.create() + _ = frame.getName(hit) + def test_frame_put_collection(self): """Check that putting a collection works as expected""" frame = Frame() @@ -195,3 +211,10 @@ def test_frame_parameters(self): ) # as_type='float' will also retrieve double values (if the name is unambiguous) self.assertEqual(self.event.get_parameter("SomeVectorData", as_type="float"), [0.0, 0.0]) + + def test_frame_get_name(self): + """Check that retrieving the name of a collection works as expected""" + mc_particles = self.event.get("mcparticles") + self.assertEqual(self.event.getName(mc_particles), "mcparticles") + self.assertEqual(self.event.getName(mc_particles.getID()), "mcparticles") + self.assertEqual(self.event.getName(mc_particles[0]), "mcparticles") diff --git a/tests/unittests/frame.cpp b/tests/unittests/frame.cpp index a5159de72..48774e869 100644 --- a/tests/unittests/frame.cpp +++ b/tests/unittests/frame.cpp @@ -418,6 +418,7 @@ TEST_CASE("Frame getName", "[frame][basics]") { const auto& hits = frame.get("hits"); REQUIRE(frame.getName(hits).value() == "hits"); + REQUIRE(frame.getName(hits[0]).value() == "hits"); REQUIRE_FALSE(frame.getName(0xfffffff).has_value()); } From 787edf0c7a30535a47b438c92af16f1802527893 Mon Sep 17 00:00:00 2001 From: tmadlener Date: Fri, 24 May 2024 15:32:18 +0200 Subject: [PATCH 2/6] Remove unnecessary deepcopy --- python/podio/frame.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/podio/frame.py b/python/podio/frame.py index c2ce68575..a44154c69 100644 --- a/python/podio/frame.py +++ b/python/podio/frame.py @@ -2,7 +2,6 @@ """Module for the python bindings of the podio::Frame""" import cppyy -from copy import deepcopy import ROOT @@ -126,7 +125,7 @@ def getName(self, token): """ maybeName = self._frame.getName(token) if maybeName.has_value(): - return deepcopy(maybeName.value()) + return maybeName.value() def _get_id(tok): if isinstance(tok, int): From 4c07994885ffdca8b766d3c845cf9605cd3e097b Mon Sep 17 00:00:00 2001 From: tmadlener Date: Wed, 5 Jun 2024 15:58:37 +0200 Subject: [PATCH 3/6] Remove unnecessary overloads again --- include/podio/Frame.h | 22 ---------------------- tests/unittests/frame.cpp | 1 - 2 files changed, 23 deletions(-) diff --git a/include/podio/Frame.h b/include/podio/Frame.h index 28d18145b..ce459e119 100644 --- a/include/podio/Frame.h +++ b/include/podio/Frame.h @@ -338,28 +338,6 @@ class Frame { return m_self->getIDTable().name(collectionID); } - /// Get the name of the collection for the passed object ID - /// - /// @param objId The objectID of an element of a collection for which the name - /// should be obtained - /// @returns The name of the collection or an empty optional if this - /// objectID is not known to the Frame - inline std::optional getName(const podio::ObjectID& objId) const { - return getName(objId.collectionID); - } - - /// Get the name of the collection to which this element belongs - /// - /// @tparam ElemT A datatype of a podio generated datamodel - /// @param elem The element of a collection for which the name should be - /// obtained - /// @returns The name of the collection or an empty optional if this - /// element is not known to the Frame - template >> - inline std::optional getName(const ElemT& elem) const { - return getName(elem.id().collectionID); - } - // Interfaces for writing below /// Get a collection for writing. diff --git a/tests/unittests/frame.cpp b/tests/unittests/frame.cpp index 48774e869..a5159de72 100644 --- a/tests/unittests/frame.cpp +++ b/tests/unittests/frame.cpp @@ -418,7 +418,6 @@ TEST_CASE("Frame getName", "[frame][basics]") { const auto& hits = frame.get("hits"); REQUIRE(frame.getName(hits).value() == "hits"); - REQUIRE(frame.getName(hits[0]).value() == "hits"); REQUIRE_FALSE(frame.getName(0xfffffff).has_value()); } From c25cd89351a13262bcbda0f1614cf7ce9e065abc Mon Sep 17 00:00:00 2001 From: tmadlener Date: Wed, 5 Jun 2024 16:07:05 +0200 Subject: [PATCH 4/6] Remove the handling of unnecessary overloads from python --- python/podio/frame.py | 11 ++--------- python/podio/test_Frame.py | 6 ------ 2 files changed, 2 insertions(+), 15 deletions(-) diff --git a/python/podio/frame.py b/python/podio/frame.py index a44154c69..10a216f65 100644 --- a/python/podio/frame.py +++ b/python/podio/frame.py @@ -111,10 +111,7 @@ def getName(self, token): """Get the name of the collection from the Frame Args: - token (podio.CollectionBase | int | podio.ObjectID | podio generated - datatype): Something that that can be used to get the name of a - collection. Can either be the collection itself, a collectionID - or an object ID of an element of a collection + token (podio.CollectionBase | int): A collection or its ID Returns: str: The name of the collection inside the frame @@ -130,11 +127,7 @@ def getName(self, token): def _get_id(tok): if isinstance(tok, int): return f"{tok:0>8x}" - if _is_collection_base(token): - return _get_id(tok.getID()) - if isinstance(tok, podio.ObjectID): - return f"({tok.collectionID:0>8x}: {tok.index})" - return _get_id(tok.id()) + return _get_id(tok.getID()) raise KeyError(f"No collection name can be found in Frame for {_get_id(token)}") diff --git a/python/podio/test_Frame.py b/python/podio/test_Frame.py index 07f65bb8b..aa798a8ae 100644 --- a/python/podio/test_Frame.py +++ b/python/podio/test_Frame.py @@ -81,11 +81,6 @@ def test_frame_get_name_invalid_token(self): coll = ExampleHitCollection() _ = frame.getName(coll) - with self.assertRaises(KeyError): - coll = ExampleHitCollection() - hit = coll.create() - _ = frame.getName(hit) - def test_frame_put_collection(self): """Check that putting a collection works as expected""" frame = Frame() @@ -217,4 +212,3 @@ def test_frame_get_name(self): mc_particles = self.event.get("mcparticles") self.assertEqual(self.event.getName(mc_particles), "mcparticles") self.assertEqual(self.event.getName(mc_particles.getID()), "mcparticles") - self.assertEqual(self.event.getName(mc_particles[0]), "mcparticles") From c0290e7f31c52466d42dcbafff4150672f3e29fe Mon Sep 17 00:00:00 2001 From: Thomas Madlener Date: Wed, 5 Jun 2024 16:10:10 +0200 Subject: [PATCH 5/6] Improve error message --- python/podio/frame.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/podio/frame.py b/python/podio/frame.py index 10a216f65..fac12be33 100644 --- a/python/podio/frame.py +++ b/python/podio/frame.py @@ -129,7 +129,7 @@ def _get_id(tok): return f"{tok:0>8x}" return _get_id(tok.getID()) - raise KeyError(f"No collection name can be found in Frame for {_get_id(token)}") + raise KeyError(f"No collection name can be found in Frame for collection id: {_get_id(token)}") def put(self, collection, name): """Put the collection into the frame From 1fe61520f82171ea9e542b9eedce88acf8f0cdca Mon Sep 17 00:00:00 2001 From: Thomas Madlener Date: Wed, 5 Jun 2024 17:00:29 +0200 Subject: [PATCH 6/6] Fix style --- python/podio/frame.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/podio/frame.py b/python/podio/frame.py index fac12be33..b5651aa35 100644 --- a/python/podio/frame.py +++ b/python/podio/frame.py @@ -129,7 +129,9 @@ def _get_id(tok): return f"{tok:0>8x}" return _get_id(tok.getID()) - raise KeyError(f"No collection name can be found in Frame for collection id: {_get_id(token)}") + raise KeyError( + f"No collection name can be found in Frame for collection id: {_get_id(token)}" + ) def put(self, collection, name): """Put the collection into the frame