Skip to content

Commit

Permalink
fix error in tracking of available outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
harrypuuter committed Jan 19, 2024
1 parent c7aa063 commit 850beb9
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 24 deletions.
5 changes: 5 additions & 0 deletions code_generation/analysis_template.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,11 @@ int main(int argc, char *argv[]) {
cutflow.SetTitle("cutflow");
// iterate through the cutflow vector and fill the histogram with the
// .GetPass() values
if (scope_counter >= cutReports.size()) {
Logger::get("main")->critical(
"Cutflow vector is too small, this should not happen");
return 1;
}
for (auto cut = cutReports[scope_counter].begin();
cut != cutReports[scope_counter].end(); cut++) {
cutflow.SetBinContent(
Expand Down
2 changes: 1 addition & 1 deletion code_generation/code_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,7 +764,7 @@ def set_quantities_shift_map(self) -> str:
global_commands = []
outputset = list(set(self.output_commands[scope] + global_commands))
# now split by __ and get a set of all the shifts per variable
for i, output in enumerate(outputset):
for output in outputset:
try:
quantity, shift = output.split("__")
except ValueError:
Expand Down
42 changes: 23 additions & 19 deletions code_generation/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __init__(
self.available_sample_types = set(available_sample_types)
self.available_eras = set(available_eras)
self.available_scopes = set(available_scopes)
self.available_outputs: QuantitiesStore = {}
self.available_outputs: Dict[str, QuantitiesStore] = {}
self.available_shifts: Dict[str, Set[str]] = {}
self.global_scope = "global"

Expand Down Expand Up @@ -184,7 +184,9 @@ def setup_defaults(self) -> None:
self.unpacked_producers[scope] = {}
self.outputs[scope] = set()
self.shifts[scope] = {}
self.available_outputs[scope] = set()
self.available_outputs[scope] = {}
for sampletype in self.available_sample_types:
self.available_outputs[scope][sampletype] = set()
self.config_parameters[scope] = {}
self.available_shifts[scope] = set()
self._set_sample_parameters()
Expand Down Expand Up @@ -229,9 +231,10 @@ def add_producers(
producers = [producers]
for scope in scopes:
self.producers[scope].extend(producers)
self.available_outputs[scope].update(
CollectProducersOutput(producers, scope)
)
for sampletype in self.available_sample_types:
self.available_outputs[scope][sampletype].update(
CollectProducersOutput(producers, scope)
)
self.unpack_producergroups(scope, producers)

def unpack_producergroups(
Expand Down Expand Up @@ -409,7 +412,6 @@ def _is_valid_shift(
shift.shiftname, scope, self.available_scopes
)
)
return False
if len(self.selected_shifts) == 1 and "all" in self.selected_shifts:
return True
elif len(self.selected_shifts) == 1 and "none" in self.selected_shifts:
Expand Down Expand Up @@ -520,7 +522,8 @@ def _remove_empty_scopes(self) -> None:
del self.outputs[scope]
del self.shifts[scope]
del self.config_parameters[scope]
del self.available_outputs[scope]
for sampletype in self.available_sample_types:
del self.available_outputs[scope][sampletype]

def _apply_rules(self) -> None:
"""
Expand All @@ -532,16 +535,17 @@ def _apply_rules(self) -> None:
rule.apply(
self.sample, self.producers, self.unpacked_producers, self.outputs
)
# also update the set of available outputs
for scope in rule.affected_scopes():
if isinstance(rule, RemoveProducer):
self.available_outputs[scope] - CollectProducersOutput(
rule.affected_producers(), scope
)
else:
self.available_outputs[scope].update(
CollectProducersOutput(rule.affected_producers(), scope)
)
# also update the set of available outputs if the affected sample is the current sample
if self.sample in rule.samples:
for scope in rule.affected_scopes():
if isinstance(rule, RemoveProducer):
self.available_outputs[scope][
self.sample
] -= CollectProducersOutput(rule.affected_producers(), scope)
else:
self.available_outputs[scope][self.sample].update(
CollectProducersOutput(rule.affected_producers(), scope)
)

def optimize(self) -> None:
"""
Expand Down Expand Up @@ -588,8 +592,8 @@ def _validate_outputs(self) -> None:
)
# merge the two sets of outputs
provided_outputs = (
self.available_outputs[scope]
| self.available_outputs[self.global_scope]
self.available_outputs[scope][self.sample]
| self.available_outputs[self.global_scope][self.sample]
)
missing_outputs = required_outputs - provided_outputs
if len(missing_outputs) > 0:
Expand Down
15 changes: 11 additions & 4 deletions code_generation/friend_trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,9 +303,12 @@ def _shift_producer_inputs(
# only shift if necessary
if shift in self.input_quantities_mapping[scope].keys():
inputs_to_shift = []
for input in inputs:
if input.name in self.input_quantities_mapping[scope][shift]:
inputs_to_shift.append(input)
for input_quantity in inputs:
if (
input_quantity.name
in self.input_quantities_mapping[scope][shift]
):
inputs_to_shift.append(input_quantity)
if len(inputs_to_shift) > 0:
log.debug("Adding shift %s to producer %s", shift, producer)
producer.shift(shiftname, scope)
Expand Down Expand Up @@ -335,7 +338,7 @@ def _validate_outputs(self) -> None:
for scope in [scope for scope in self.scopes]:
required_outputs = set(output for output in self.outputs[scope])
# merge the two sets of outputs
provided_outputs = self.available_outputs[scope]
provided_outputs = self.available_outputs[scope][self.sample]
missing_outputs = required_outputs - provided_outputs
if len(missing_outputs) > 0:
raise InvalidOutputError(scope, missing_outputs)
Expand Down Expand Up @@ -415,6 +418,10 @@ def expanded_configuration(self) -> Configuration:
expanded_configuration[scope] = {}
if self.run_nominal:
log.debug("Adding nominal in scope {}".format(scope))
if scope not in self.config_parameters.keys():

Check failure on line 421 in code_generation/friend_trees.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

code_generation/friend_trees.py#L421

Access to member 'config_parameters' before its definition line 448
raise ConfigurationError(
"Scope {} not found in configuration parameters".format(scope)
)
expanded_configuration[scope]["nominal"] = self.config_parameters[scope]
if len(self.shifts[scope]) > 0:
for shift in self.shifts[scope]:
Expand Down
1 change: 1 addition & 0 deletions code_generation/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def apply(
outputs_to_be_updated: QuantitiesStore,
) -> None:
if self.is_applicable(sample):
log.critical(f"Applying rule {self} for sample {sample}")
log.debug("For sample {}, applying >> {} ".format(sample, self))
self.update_producers(producers_to_be_updated, unpacked_producers)
self.update_outputs(outputs_to_be_updated)
Expand Down

0 comments on commit 850beb9

Please sign in to comment.