Skip to content

Commit

Permalink
SanitiseTransform: Create "housekeeping" transformation for scheduler
Browse files Browse the repository at this point in the history
This so far includes resolving associates and sequence associations,
but could be extended into a general-purpose input unification pipeline.
  • Loading branch information
mlange05 committed Jan 16, 2024
1 parent 1ef01c5 commit 67b54a9
Show file tree
Hide file tree
Showing 8 changed files with 205 additions and 188 deletions.
3 changes: 1 addition & 2 deletions loki/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from loki.transform.transformation import * # noqa
from loki.transform.transform_utilities import * # noqa
from loki.transform.transform_array_indexing import * # noqa
from loki.transform.transform_associates import * # noqa
from loki.transform.transform_inline import * # noqa
from loki.transform.transform_loop import * # noqa
from loki.transform.transform_region import * # noqa
Expand All @@ -20,5 +19,5 @@
from loki.transform.transform_hoist_variables import * # noqa
from loki.transform.transform_parametrise import * # noqa
from loki.transform.transform_extract_contained_procedures import * # noqa
from loki.transform.transform_sequence_association import * # noqa
from loki.transform.transform_dead_code import * # noqa
from loki.transform.transform_sanitise import * # noqa
2 changes: 1 addition & 1 deletion loki/transform/fortran_c_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
resolve_vector_notation, normalize_array_shape_and_access,
flatten_arrays
)
from loki.transform.transform_associates import resolve_associates
from loki.transform.transform_sanitise import resolve_associates
from loki.transform.transform_utilities import (
convert_to_lower_case, replace_intrinsics, sanitise_imports
)
Expand Down
2 changes: 1 addition & 1 deletion loki/transform/fortran_python_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from loki.transform.transform_array_indexing import (
shift_to_zero_indexing, invert_array_indices, normalize_range_indexing
)
from loki.transform.transform_associates import resolve_associates
from loki.transform.transform_sanitise import resolve_associates
from loki.transform.transform_utilities import (
convert_to_lower_case, replace_intrinsics
)
Expand Down
69 changes: 0 additions & 69 deletions loki/transform/transform_associates.py

This file was deleted.

19 changes: 0 additions & 19 deletions loki/transform/transform_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@

from loki.transform.transformation import Transformation
from loki.transform.transform_dead_code import dead_code_elimination
from loki.transform.transform_sequence_association import transform_sequence_association
from loki.transform.transform_associates import resolve_associates
from loki.transform.transform_utilities import (
single_variable_declaration,
recursive_expression_map_update
Expand Down Expand Up @@ -60,11 +58,6 @@ class InlineTransformation(Transformation):
inline_marked : bool
Inline :any:`Subroutine` objects marked by pragma annotations;
default: True.
resolve_associate_mappings : bool
Resolve ASSOCIATE mappings in body of processed subroutines; default: True.
resolve_sequence_association : bool
Replace scalars that are passed to array arguments with array
ranges; default: False.
eliminate_dead_code : bool
Perform dead code elimination, where unreachable branches are
trimmed from the code; default@ True
Expand All @@ -84,7 +77,6 @@ class InlineTransformation(Transformation):
def __init__(
self, inline_constants=False, inline_elementals=True,
inline_internals=False, inline_marked=True,
resolve_associate_mappings=True, resolve_sequence_association=False,
eliminate_dead_code=True, allowed_aliases=None,
remove_imports=True, external_only=True
):
Expand All @@ -93,8 +85,6 @@ def __init__(
self.inline_internals = inline_internals
self.inline_marked = inline_marked

self.resolve_associate_mappings = resolve_associate_mappings
self.resolve_sequence_association = resolve_sequence_association
self.eliminate_dead_code = eliminate_dead_code

self.allowed_aliases = allowed_aliases
Expand All @@ -103,15 +93,6 @@ def __init__(

def transform_subroutine(self, routine, **kwargs):

# Associates at the highest level, so they don't interfere
# with the sections we need to do for detecting subroutine calls
if self.resolve_associate_mappings:
resolve_associates(routine)

# Transform arrays passed with scalar syntax to array syntax
if self.resolve_sequence_association:
transform_sequence_association(routine)

# Replace constant parameter variables with explicit values
if self.inline_constants:
inline_constant_parameters(routine, external_only=self.external_only)
Expand Down
193 changes: 193 additions & 0 deletions loki/transform/transform_sanitise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


"""
A small selection of utility transformations that resolve certain code
constructs to unify code structure and make reasoning about Fortran
code easier.
"""

from loki.expression import FindVariables, SubstituteExpressions, Array, RangeIndex
from loki.ir import CallStatement
from loki.tools import as_tuple, CaseInsensitiveDict
from loki.types import BasicType
from loki.visitors import FindNodes, Transformer, NestedTransformer

from loki.transform.transform_utilities import recursive_expression_map_update
from loki.transform.transformation import Transformation


__all__ = [
'SanitiseTransformation', 'resolve_associates',
'ResolveAssociatesTransformer', 'transform_sequence_association'
]


class SanitiseTransformation(Transformation):
"""
:any:`Transformation` object to apply several code sanitisation
steps when batch-processing large source trees via the :any:`Scheduler`.
Parameters
----------
resolve_associate_mappings : bool
Resolve ASSOCIATE mappings in body of processed subroutines; default: True.
resolve_sequence_association : bool
Replace scalars that are passed to array arguments with array
ranges; default: False.
"""

def __init__(
self, resolve_associate_mappings=True, resolve_sequence_association=False
):
self.resolve_associate_mappings = resolve_associate_mappings
self.resolve_sequence_association = resolve_sequence_association

def transform_subroutine(self, routine, **kwargs):

# Associates at the highest level, so they don't interfere
# with the sections we need to do for detecting subroutine calls
if self.resolve_associate_mappings:
resolve_associates(routine)

# Transform arrays passed with scalar syntax to array syntax
if self.resolve_sequence_association:
transform_sequence_association(routine)


def resolve_associates(routine):
"""
Resolve :any:`Associate` mappings in the body of a given routine.
Parameters
----------
routine : :any:`Subroutine`
The subroutine for which to resolve all associate blocks.
"""
routine.body = ResolveAssociatesTransformer().visit(routine.body)

# Ensure that all symbols have the appropriate scope attached.
# This is needed, as the parent of a symbol might have changed,
# which affects the symbol's type-defining scope.
routine.rescope_symbols()


class ResolveAssociatesTransformer(NestedTransformer):
"""
:any:`Transformer` class to resolve :any:`Associate` nodes in IR trees
This will replace each :any:`Associate` node with its own body,
where all `identifier` symbols have been replaced with the
corresponding `selector` expression defined in ``associations``.
"""

def visit_Associate(self, o, **kwargs):
# First head-recurse, so that all associate blocks beneath are resolved
body = self.visit(o.body, **kwargs)

# Create an inverse association map to look up replacements
invert_assoc = CaseInsensitiveDict({v.name: k for k, v in o.associations})

# Build the expression substitution map
vmap = {}
for v in FindVariables().visit(body):
if v.name in invert_assoc:
# Clone the expression to update its parentage and scoping
inv = invert_assoc[v.name]
if hasattr(v, 'dimensions'):
vmap[v] = inv.clone(dimensions=v.dimensions)
else:
vmap[v] = inv

# Apply the expression substitution map to itself to handle nested expressions
vmap = recursive_expression_map_update(vmap)

# Mark the associate block for replacement with its body, with all expressions replaced
self.mapper[o] = SubstituteExpressions(vmap).visit(body)

# Return the original object unchanged and let the tuple injection mechanism take care
# of replacing it by its body - otherwise we would end up with nested tuples
return o


def check_if_scalar_syntax(arg, dummy):
"""
Check if an array argument, arg,
is passed to an array dummy argument, dummy,
using scalar syntax. i.e. arg(1,1) -> d(m,n)
Parameters
----------
arg: variable
dummy: variable
"""
if isinstance(arg, Array) and isinstance(dummy, Array):
if arg.dimensions:
if not any(isinstance(d, RangeIndex) for d in arg.dimensions):
return True
return False


def transform_sequence_association(routine):
"""
Housekeeping routine to replace scalar syntax when passing arrays as arguments
For example, a call like
real :: a(m,n)
call myroutine(a(i,j))
where myroutine looks like
subroutine myroutine(a)
real :: a(5)
end subroutine myroutine
should be changed to
call myroutine(a(i:m,j)
Parameters
----------
routine : :any:`Subroutine`
The subroutine where calls will be changed
"""

#List calls in routine, but make sure we have the called routine definition
calls = (c for c in FindNodes(CallStatement).visit(routine.body) if not c.procedure_type is BasicType.DEFERRED)
call_map = {}

for call in calls:

new_args = []

found_scalar = False
for dummy, arg in call.arg_map.items():
if check_if_scalar_syntax(arg, dummy):
found_scalar = True

n_dims = len(dummy.shape)
new_dims = []
for s, lower in zip(arg.shape[:n_dims], arg.dimensions[:n_dims]):

if isinstance(s, RangeIndex):
new_dims += [RangeIndex((lower, s.stop))]
else:
new_dims += [RangeIndex((lower, s))]

if len(arg.dimensions) > n_dims:
new_dims += arg.dimensions[len(dummy.shape):]
new_args += [arg.clone(dimensions=as_tuple(new_dims)),]
else:
new_args += [arg,]

if found_scalar:
call_map[call] = call.clone(arguments = as_tuple(new_args))

if call_map:
routine.body = Transformer(call_map).visit(routine.body)
Loading

0 comments on commit 67b54a9

Please sign in to comment.