diff --git a/loki/analyse/analyse_dataflow.py b/loki/analyse/analyse_dataflow.py index 7de456f15..13021b5f8 100644 --- a/loki/analyse/analyse_dataflow.py +++ b/loki/analyse/analyse_dataflow.py @@ -189,8 +189,14 @@ def visit_MultiConditional(self, o, **kwargs): def visit_MaskedStatement(self, o, **kwargs): live = kwargs.pop('live_symbols', set()) conditions = self._symbols_from_expr(o.conditions) - body, defines, uses = self._visit_body(o.bodies, live=live, uses=conditions, **kwargs) - body = tuple(as_tuple(b,) for b in body) + + body = () + defines = set() + uses = set(conditions) + for b in o.bodies: + _b, defines, uses = self._visit_body(b, live=live, uses=uses, defines=defines, **kwargs) + body += (_b,) + default, default_defs, uses = self._visit_body(o.default, live=live, uses=uses, **kwargs) o._update(bodies=body, default=default) return self.visit_Node(o, live_symbols=live, defines_symbols=defines|default_defs, uses_symbols=uses, **kwargs) diff --git a/loki/analyse/tests/test_analyse_dataflow.py b/loki/analyse/tests/test_analyse_dataflow.py index cb7975659..a8e77bfb5 100644 --- a/loki/analyse/tests/test_analyse_dataflow.py +++ b/loki/analyse/tests/test_analyse_dataflow.py @@ -497,6 +497,7 @@ def test_analyse_maskedstatement(frontend): where (mask(:) < -5) vec1(:) = -5.0 + vec1(:) = vec1(:) -5.0 elsewhere (mask(:) > 5) vec1(:) = 5.0 elsewhere @@ -508,12 +509,15 @@ def test_analyse_maskedstatement(frontend): routine = Subroutine.from_source(fcode, frontend=frontend) mask = FindNodes(MaskedStatement).visit(routine.body)[0] + num_bodies = len(mask.bodies) with dataflow_analysis_attached(routine): assert len(mask.uses_symbols) == 1 assert len(mask.defines_symbols) == 1 assert 'mask' in mask.uses_symbols assert 'vec1' in mask.defines_symbols + assert len(mask.bodies) == num_bodies + @pytest.mark.parametrize('frontend', available_frontends()) def test_analyse_whileloop(frontend):