Skip to content

Commit

Permalink
Don't start the counter at 0 to avoid hitting empty defaults
Browse files Browse the repository at this point in the history
  • Loading branch information
tmadlener committed Sep 16, 2024
1 parent e1988f5 commit 5ee96af
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 42 deletions.
44 changes: 23 additions & 21 deletions scripts/createEDM4hepFile.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
FRAMES = 3 # How many frames or events will be written
VECTORSIZE = 5 # For vector members, each vector member will have this size

COUNT_START = 42 # Where to the counter from


def create_EventHeaderCollection(vectorsize):
"""Create an EventHeaderCollection"""
counter = count()
counter = count(COUNT_START)
header = edm4hep.EventHeaderCollection()
h = header.create()
h.setEventNumber(next(counter))
Expand All @@ -31,7 +33,7 @@ def create_EventHeaderCollection(vectorsize):

def create_MCParticleCollection():
"""Create an MCParticleCollection"""
counter = count()
counter = count(COUNT_START)
particles = edm4hep.MCParticleCollection()
p_list = []
for i in range(3):
Expand Down Expand Up @@ -65,7 +67,7 @@ def create_MCParticleCollection():

def create_SimTrackerHitCollection(particle):
"""Create a SimTrackerHitCollection"""
counter = count()
counter = count(COUNT_START)
hits = edm4hep.SimTrackerHitCollection()
hit = hits.create()
hit.setCellID(next(counter))
Expand All @@ -81,7 +83,7 @@ def create_SimTrackerHitCollection(particle):

def create_CaloHitContributionCollection(particle):
"""Create a CaloHitContributionCollection"""
counter = count()
counter = count(COUNT_START)
hits = edm4hep.CaloHitContributionCollection()
hit = hits.create()
hit.setPDG(next(counter))
Expand All @@ -94,7 +96,7 @@ def create_CaloHitContributionCollection(particle):

def create_SimCalorimeterHitCollection(calo_contrib):
"""Create a SimCalorimeterHitCollection"""
counter = count()
counter = count(COUNT_START)
hits = edm4hep.SimCalorimeterHitCollection()
hit = hits.create()
hit.setCellID(next(counter))
Expand All @@ -106,7 +108,7 @@ def create_SimCalorimeterHitCollection(calo_contrib):

def create_RawCalorimeterHitCollection():
"""Crate a RawCalorimeterHitCollection"""
counter = count()
counter = count(COUNT_START)
hits = edm4hep.RawCalorimeterHitCollection()
hit = hits.create()
hit.setCellID(next(counter))
Expand All @@ -117,7 +119,7 @@ def create_RawCalorimeterHitCollection():

def create_CalorimeterHitCollection():
"""Create a CalorimeterHitCollection"""
counter = count()
counter = count(COUNT_START)
hits = edm4hep.CalorimeterHitCollection()
hit = hits.create()
hit.setCellID(next(counter))
Expand All @@ -135,7 +137,7 @@ def create_CovMatrixNf(n_dim):
return ValueError(
f"{n_dim} is not a valid dimension for a CovMatrix in EDM4hep. Valid: (2, 3, 4, 6)"
)
counter = count()
counter = count(COUNT_START)
# With the current version of cppyy (from ROOT 6.30.06)
# it's not possible to initialize an std::array from a list
# In future versions (works with 6.32.02) it will be possible
Expand All @@ -147,7 +149,7 @@ def create_CovMatrixNf(n_dim):

def create_ParticleIDCollection(vectorsize):
"""Create a ParticleIDCollection"""
counter = count()
counter = count(COUNT_START)
pids = edm4hep.ParticleIDCollection()
pid = pids.create()
pid.setType(next(counter))
Expand All @@ -162,7 +164,7 @@ def create_ParticleIDCollection(vectorsize):

def create_ClusterCollection(vectorsize, calo_hit):
"""Create a ClusterCollection"""
counter = count()
counter = count(COUNT_START)
clusters = edm4hep.ClusterCollection()
cluster = clusters.create()
cluster.setType(next(counter))
Expand All @@ -186,7 +188,7 @@ def create_ClusterCollection(vectorsize, calo_hit):

def create_TrackerHit3DCollection():
"""Create a TrackerHit3DCollection"""
counter = count()
counter = count(COUNT_START)
hits = edm4hep.TrackerHit3DCollection()
hit = hits.create()
hit.setCellID(next(counter))
Expand All @@ -202,7 +204,7 @@ def create_TrackerHit3DCollection():

def create_TrackerHitPlaneCollection():
"""Create a TrackerHitPlaneCollection"""
counter = count()
counter = count(COUNT_START)
hits = edm4hep.TrackerHitPlaneCollection()
hit = hits.create()
hit.setCellID(next(counter))
Expand All @@ -222,7 +224,7 @@ def create_TrackerHitPlaneCollection():

def create_RawTimeSeriesCollection(vectorsize):
"""Create a RawTimeSeriesCollection"""
counter = count()
counter = count(COUNT_START)
series = edm4hep.RawTimeSeriesCollection()
serie = series.create()
serie.setCellID(next(counter))
Expand All @@ -237,7 +239,7 @@ def create_RawTimeSeriesCollection(vectorsize):

def create_TrackCollection(vectorsize, tracker_hit):
"""Create a TrackCollection"""
counter = count()
counter = count(COUNT_START)
tracks = edm4hep.TrackCollection()
track = tracks.create()
track.setType(next(counter))
Expand Down Expand Up @@ -268,7 +270,7 @@ def create_TrackCollection(vectorsize, tracker_hit):

def create_VertexCollection(vectorsize):
"""Create a VertexCollection"""
counter = count()
counter = count(COUNT_START)
vertex = edm4hep.VertexCollection()
v = vertex.create()
v.setType(next(counter))
Expand All @@ -284,7 +286,7 @@ def create_VertexCollection(vectorsize):

def create_ReconstructedParticleCollection(vertex, cluster, track):
"""Create a ReconstructedParticleCollection"""
counter = count()
counter = count(COUNT_START)
rparticles = edm4hep.ReconstructedParticleCollection()
rparticle = rparticles.create()
rparticle.setPDG(next(counter))
Expand All @@ -306,7 +308,7 @@ def create_ReconstructedParticleCollection(vertex, cluster, track):

def create_LinkCollection(coll_type, from_el, to_el):
"""Create a LinkCollection of the given type and add one link to it"""
counter = count()
counter = count(COUNT_START)
links = coll_type()
link = links.create()
link.setWeight(next(counter))
Expand All @@ -317,7 +319,7 @@ def create_LinkCollection(coll_type, from_el, to_el):

def create_TimeSeriesCollection(vectorsize):
"""Create a TimeSeriesCollection"""
counter = count()
counter = count(COUNT_START)
timeseries = edm4hep.TimeSeriesCollection()
serie = timeseries.create()
serie.setCellID(next(counter))
Expand All @@ -330,7 +332,7 @@ def create_TimeSeriesCollection(vectorsize):

def create_RecDqdxCollection(track):
"""Create a RecDqdxCollection"""
counter = count()
counter = count(COUNT_START)
recdqdx = edm4hep.RecDqdxCollection()
dqdx = recdqdx.create()
q = edm4hep.Quantity()
Expand All @@ -344,7 +346,7 @@ def create_RecDqdxCollection(track):

def create_GeneratorEventParametersCollection(vectorsize, particle):
"""Create a GeneratorEventParametersCollection"""
counter = count()
counter = count(COUNT_START)
gep_coll = edm4hep.GeneratorEventParametersCollection()
gep = gep_coll.create()
gep.setEventScale(next(counter))
Expand All @@ -361,7 +363,7 @@ def create_GeneratorEventParametersCollection(vectorsize, particle):

def create_GeneratorPdfInfoCollection():
"""Create a GeneratorPdfInfoCollection"""
counter = count()
counter = count(COUNT_START)
gpi_coll = edm4hep.GeneratorPdfInfoCollection()
gpi = gpi_coll.create()
# Doesn't work with ROOT 6.30.06
Expand Down
43 changes: 22 additions & 21 deletions test/test_EDM4hepFile.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# For now simply copy these from createEDM4hepFile.py
FRAMES = 3
VECTORSIZE = 5
COUNT_START = 42 # Where to the counter from


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -61,7 +62,7 @@ def track(event):

def check_cov_matrix(cov_matrix, n_dim):
"""Check the contents of the passed covariance matrix"""
counter = count()
counter = count(COUNT_START)
# for val in cov_matrix: # Doesn't work as expected with root
for i in range(n_dim * (n_dim + 1) // 2):
assert cov_matrix[i] == next(counter)
Expand All @@ -74,7 +75,7 @@ def test_basic_file_contents(events):

def test_EventHeaderCollection(event):
"""Check an EventHeaderCollection"""
counter = count()
counter = count(COUNT_START)
header = event.get("EventHeader")
assert len(header) == 1
h = header[0]
Expand All @@ -89,7 +90,7 @@ def test_EventHeaderCollection(event):

def test_MCParticleCollection(event):
"""Check the MCParticleCollection"""
counter = count()
counter = count(COUNT_START)
particles = event.get("MCParticleCollection")
assert len(particles) == 3
for i in range(3):
Expand Down Expand Up @@ -123,7 +124,7 @@ def test_MCParticleCollection(event):

def test_SimTrackerHitCollection(event, particle):
"""Check the SimTrackerHitCollection"""
counter = count()
counter = count(COUNT_START)
hits = event.get("SimTrackerHitCollection")
assert len(hits) == 1
hit = hits[0]
Expand All @@ -144,7 +145,7 @@ def test_SimTrackerHitCollection(event, particle):

def test_CaloHitContributionCollection(event, particle):
"""Check the CaloHitContributionCollection"""
counter = count()
counter = count(COUNT_START)
hits = event.get("CaloHitContributionCollection")
assert len(hits) == 1
hit = hits[0]
Expand All @@ -160,7 +161,7 @@ def test_CaloHitContributionCollection(event, particle):

def test_SimCalorimeterHitCollection(event):
"""Check the SimCalorimeterHitCollection"""
counter = count()
counter = count(COUNT_START)
hits = event.get("SimCalorimeterHitCollection")
assert len(hits) == 1
hit = hits[0]
Expand All @@ -177,7 +178,7 @@ def test_SimCalorimeterHitCollection(event):

def test_RawCalorimeterHitCollection(event):
"""Check the RawCalorimeterHitCollection"""
counter = count()
counter = count(COUNT_START)
hits = event.get("RawCalorimeterHitCollection")
assert len(hits) == 1
hit = hits[0]
Expand All @@ -188,7 +189,7 @@ def test_RawCalorimeterHitCollection(event):

def test_CalorimeterHitCollection(event):
"""Check the CalorimeterHitCollection"""
counter = count()
counter = count(COUNT_START)
hits = event.get("CalorimeterHitCollection")
assert len(hits) == 1
hit = hits[0]
Expand All @@ -204,7 +205,7 @@ def test_CalorimeterHitCollection(event):

def test_ParticleIDCollection(event, reco_particle):
"""Check the ParticleIDCollection"""
counter = count()
counter = count(COUNT_START)
pids = event.get("ParticleIDCollection")
assert len(pids) == 1
pid = pids[0]
Expand All @@ -221,7 +222,7 @@ def test_ParticleIDCollection(event, reco_particle):

def test_ClusterCollection(event):
"""Check the ClusterCollection"""
counter = count()
counter = count(COUNT_START)
clusters = event.get("ClusterCollection")
assert len(clusters) == 1
cluster = clusters[0]
Expand Down Expand Up @@ -255,7 +256,7 @@ def test_ClusterCollection(event):

def test_TrackerHit3DCollection(event):
"""Check the TrackerHit3DCollection"""
counter = count()
counter = count(COUNT_START)
hits = event.get("TrackerHit3DCollection")
assert len(hits) == 1
hit = hits[0]
Expand All @@ -273,7 +274,7 @@ def test_TrackerHit3DCollection(event):

def test_TrackerHitPlaneCollection(event):
"""Check the TrackerHitPlaneCollection"""
counter = count()
counter = count(COUNT_START)
hits = event.get("TrackerHitPlaneCollection")
assert len(hits) == 1
hit = hits[0]
Expand All @@ -295,7 +296,7 @@ def test_TrackerHitPlaneCollection(event):

def test_RawTimeSeriesCollection(event):
"""Check a RawTimeSeriesCollection"""
counter = count()
counter = count(COUNT_START)
series = event.get("RawTimeSeriesCollection")
assert len(series) == 1
serie = series[0]
Expand All @@ -311,7 +312,7 @@ def test_RawTimeSeriesCollection(event):

def test_TrackCollection(event):
"""Check the TrackCollection"""
counter = count()
counter = count(COUNT_START)
tracks = event.get("TrackCollection")
assert len(tracks) == 1
track = tracks[0]
Expand Down Expand Up @@ -355,7 +356,7 @@ def test_TrackCollection(event):

def test_VertexCollection(event, reco_particle):
"""Check the VertexCollection"""
counter = count()
counter = count(COUNT_START)
vertex = event.get("VertexCollection")
assert len(vertex) == 1
v = vertex[0]
Expand All @@ -377,7 +378,7 @@ def test_VertexCollection(event, reco_particle):

def test_ReconstructedParticleCollection(event, track):
"""Check the ReconstructedParticleCollection"""
counter = count()
counter = count(COUNT_START)
rparticles = event.get("ReconstructedParticleCollection")
assert len(rparticles) == 1
rparticle = rparticles[0]
Expand Down Expand Up @@ -410,7 +411,7 @@ def test_ReconstructedParticleCollection(event, track):

def test_TimeSeriesCollection(event):
"""Check the TimeSeriesCollection"""
counter = count()
counter = count(COUNT_START)
timeseries = event.get("TimeSeriesCollection")
assert len(timeseries) == 1
serie = timeseries[0]
Expand All @@ -424,7 +425,7 @@ def test_TimeSeriesCollection(event):

def test_RecDqdxCollection(event, track):
"""Check the RecDqdxCollection"""
counter = count()
counter = count(COUNT_START)
recdqdx = event.get("RecDqdxCollection")
assert len(recdqdx) == 1
dqdx = recdqdx[0]
Expand All @@ -437,7 +438,7 @@ def test_RecDqdxCollection(event, track):

def test_GeneratorEventParametersCollection(event, particle):
"""Check the GeneratorEventParametersCollection"""
counter = count()
counter = count(COUNT_START)
gep_coll = event.get("GeneratorEventParametersCollection")
assert len(gep_coll) == 1
gep = gep_coll[0]
Expand All @@ -460,7 +461,7 @@ def test_GeneratorEventParametersCollection(event, particle):

def test_GeneratorPdfInfoCollection(event):
"""Check the GeneratorPdfInfoCollection"""
counter = count()
counter = count(COUNT_START)
gpi_coll = event.get("GeneratorPdfInfoCollection")
assert len(gpi_coll) == 1
gpi = gpi_coll[0]
Expand All @@ -478,7 +479,7 @@ def test_GeneratorPdfInfoCollection(event):

def check_LinkCollection(event, coll_type, from_el, to_el):
"""Check a single link collection of a given type"""
counter = count()
counter = count(COUNT_START)
coll_name = f"{coll_type}Collection"
link_coll = event.get(coll_name)
assert len(link_coll) == 1
Expand Down

0 comments on commit 5ee96af

Please sign in to comment.