Skip to content

Commit

Permalink
fixed cms-l1t-offline#176: read file-by-file
Browse files Browse the repository at this point in the history
  • Loading branch information
kreczko committed Aug 2, 2019
1 parent 8091e21 commit ac03dce
Showing 1 changed file with 44 additions and 27 deletions.
71 changes: 44 additions & 27 deletions cmsl1t/io/eventreader.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,41 +65,40 @@ def __init__(self, input_files, ntuple_map, nevents=-1, vectorized=False, batch_
self.nevents = nevents
self._trees = {}

self._used_arrays = False
self._used_trees = False
self._used_arrays = vectorized
self._used_trees = not vectorized
self._passed_events = 0
self._batch_size = batch_size

if vectorized:
self._load_arrays()
else:
self._load_trees()
self._numentries = uproot.numentries(
self.input_files,
list(self._treeNames)[0],
total=False,
)

def _load_trees(self):
def _load_trees(self, input_file):
for treeName in self._treeNames:
try:
self._trees[treeName] = TreeChain(
treeName,
self.input_files,
[input_file],
cache=True,
events=self.nevents,
)
logger.debug("Successfully loaded {0}".format(treeName))
except RuntimeError as e:
logger.warning(
"Cannot find tree: {0} in input file".format(treeName))
logger.warning("Cannot find tree: {0} in input file".format(treeName))
logger.error(e)
if treeName in self._trees:
logger.warning('DELETING TREE')
del self._trees[treeName]
continue
self._used_trees = True

def _load_arrays(self):
def _load_arrays(self, input_file):
for treeName in self._treeNames:
try:
self._trees[treeName] = uproot.iterate(
self.input_files,
[input_file],
treeName,
entrysteps=self._batch_size,
)
Expand All @@ -112,27 +111,45 @@ def _load_arrays(self):
logger.warning('DELETING TREE')
del self._trees[treeName]
continue
self._used_arrays = True

def __contains__(self, name):
return name in self._aliasMap.keys()

def __iter__(self):
# event loop
try:
if self._used_trees:
for trees in six.moves.zip(*six.itervalues(self._trees)):
yield Event(self._trees, self._aliasMap)
for input_file in self.input_files:
nevents = self._numentries[input_file]
logger.info('Opening file {} ({} events)'.format(input_file, nevents))
if self._used_arrays:
for treeGen in six.moves.zip(*six.itervalues(self._trees)):
data = dict(six.moves.zip(self._trees, treeGen))
yield UprootEvent(data, self._aliasMap, batch_size=self._batch_size)
self._passed_events += self._batch_size
if self.nevents > 0 and self._passed_events >= self.nevents:
break
except Exception as e:
logger.critical("Error when reading data from ROOT file: {}".format(e))
sys.exit(-1)
self._load_arrays(input_file)
else:
self._load_trees(input_file)
try:
eventGenerator = self.new_event
if self._used_arrays:
eventGenerator = self.new_uproot_event
for event in eventGenerator():
yield event
except Exception as e:
logger.critical("Error when reading data from ROOT file: {}".format(e))
sys.exit(-1)
logger.info('Closing file {}'.format(input_file))

def new_event(self):
for trees in six.moves.zip(*six.itervalues(self._trees)):
yield Event(self._trees, self._aliasMap)

def new_uproot_event(self):
for treeGen in six.moves.zip(*six.itervalues(self._trees)):
data = dict(six.moves.zip(self._trees, treeGen))
yield UprootEvent(data, self._aliasMap, batch_size=self._batch_size)
self._passed_events += self._batch_size
if self.nevents > 0 and self._passed_events >= self.nevents:
break

@property
def numentries(self):
return sum([x for x in self._numentries.values()])


class UprootEvent(object):
Expand Down

0 comments on commit ac03dce

Please sign in to comment.