Skip to content

Commit

Permalink
Add python bindings for Frame::getName (#608)
Browse files Browse the repository at this point in the history
  • Loading branch information
tmadlener authored Jun 5, 2024
1 parent 7db028f commit f9ad017
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 0 deletions.
26 changes: 26 additions & 0 deletions python/podio/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,32 @@ def get(self, name):
raise KeyError(f"Collection '{name}' is not available")
return collection

def getName(self, token):
"""Get the name of the collection from the Frame
Args:
token (podio.CollectionBase | int): A collection or its ID
Returns:
str: The name of the collection inside the frame
Raises:
KeyError: If no collection can be found in the frame
"""
maybeName = self._frame.getName(token)
if maybeName.has_value():
return maybeName.value()

def _get_id(tok):
if isinstance(tok, int):
return f"{tok:0>8x}"
return _get_id(tok.getID())

raise KeyError(
f"No collection name can be found in Frame for collection id: {_get_id(token)}"
)

def put(self, collection, name):
"""Put the collection into the frame
Expand Down
17 changes: 17 additions & 0 deletions python/podio/test_Frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,17 @@ def test_frame_invalid_access(self):
collection = [1, 2, 4]
_ = frame.put(collection, "invalid_collection_type")

def test_frame_get_name_invalid_token(self):
"""Check that trying to get the collection name raises an exception if
the collection is not known to the frame"""
frame = Frame()
with self.assertRaises(KeyError):
_ = frame.getName(42)

with self.assertRaises(KeyError):
coll = ExampleHitCollection()
_ = frame.getName(coll)

def test_frame_put_collection(self):
"""Check that putting a collection works as expected"""
frame = Frame()
Expand Down Expand Up @@ -195,3 +206,9 @@ def test_frame_parameters(self):
)
# as_type='float' will also retrieve double values (if the name is unambiguous)
self.assertEqual(self.event.get_parameter("SomeVectorData", as_type="float"), [0.0, 0.0])

def test_frame_get_name(self):
"""Check that retrieving the name of a collection works as expected"""
mc_particles = self.event.get("mcparticles")
self.assertEqual(self.event.getName(mc_particles), "mcparticles")
self.assertEqual(self.event.getName(mc_particles.getID()), "mcparticles")

0 comments on commit f9ad017

Please sign in to comment.