diff --git a/python/podio/frame.py b/python/podio/frame.py index e35ad6bac..b5651aa35 100644 --- a/python/podio/frame.py +++ b/python/podio/frame.py @@ -107,6 +107,32 @@ def get(self, name): raise KeyError(f"Collection '{name}' is not available") return collection + def getName(self, token): + """Get the name of the collection from the Frame + + Args: + token (podio.CollectionBase | int): A collection or its ID + + Returns: + str: The name of the collection inside the frame + + Raises: + KeyError: If no collection can be found in the frame + + """ + maybeName = self._frame.getName(token) + if maybeName.has_value(): + return maybeName.value() + + def _get_id(tok): + if isinstance(tok, int): + return f"{tok:0>8x}" + return _get_id(tok.getID()) + + raise KeyError( + f"No collection name can be found in Frame for collection id: {_get_id(token)}" + ) + def put(self, collection, name): """Put the collection into the frame diff --git a/python/podio/test_Frame.py b/python/podio/test_Frame.py index 08b98baea..aa798a8ae 100644 --- a/python/podio/test_Frame.py +++ b/python/podio/test_Frame.py @@ -70,6 +70,17 @@ def test_frame_invalid_access(self): collection = [1, 2, 4] _ = frame.put(collection, "invalid_collection_type") + def test_frame_get_name_invalid_token(self): + """Check that trying to get the collection name raises an exception if + the collection is not known to the frame""" + frame = Frame() + with self.assertRaises(KeyError): + _ = frame.getName(42) + + with self.assertRaises(KeyError): + coll = ExampleHitCollection() + _ = frame.getName(coll) + def test_frame_put_collection(self): """Check that putting a collection works as expected""" frame = Frame() @@ -195,3 +206,9 @@ def test_frame_parameters(self): ) # as_type='float' will also retrieve double values (if the name is unambiguous) self.assertEqual(self.event.get_parameter("SomeVectorData", as_type="float"), [0.0, 0.0]) + + def test_frame_get_name(self): + """Check that retrieving the name of a collection works as expected""" + mc_particles = self.event.get("mcparticles") + self.assertEqual(self.event.getName(mc_particles), "mcparticles") + self.assertEqual(self.event.getName(mc_particles.getID()), "mcparticles")