Skip to content

Commit

Permalink
Merge pull request #303 from ecmwf-ifs/naan-block-index-inject
Browse files Browse the repository at this point in the history
Block-index injection transformations
  • Loading branch information
reuterbal authored May 28, 2024
2 parents f3e7d90 + 6707094 commit 77114a9
Show file tree
Hide file tree
Showing 11 changed files with 913 additions and 25 deletions.
2 changes: 1 addition & 1 deletion loki/backend/maxgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ class <name> extends Kernel {

# Class signature
if is_manager:
if is_interface:
if is_interface: # pylint: disable=possibly-used-before-assignment
header += [self.format_line(
'public interface ', o.name, ' extends ManagerPCIe, ManagerKernel {')]
else:
Expand Down
18 changes: 17 additions & 1 deletion loki/dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,19 @@ class Dimension:
bounds_aliases : list or tuple of strings
String representations of alternative bounds variables that are
used to define loop ranges.
index_aliases : list or tuple of strings
String representations of alternative loop index variables associated
with this dimension.
"""

def __init__(self, name=None, index=None, bounds=None, size=None, aliases=None,
bounds_aliases=None):
bounds_aliases=None, index_aliases=None):
self.name = name
self._index = index
self._bounds = as_tuple(bounds)
self._size = size
self._aliases = as_tuple(aliases)
self._index_aliases = as_tuple(index_aliases)

if bounds_aliases:
if len(bounds_aliases) != 2:
Expand Down Expand Up @@ -118,3 +122,15 @@ def bounds_expressions(self):
exprs = [expr + (b,) for expr, b in zip(exprs, self._bounds_aliases)]

return as_tuple(exprs)

@property
def index_expressions(self):
"""
A list of all expression strings representing the index expression of an iteration space (loop).
"""

exprs = [self.index,]
if self._index_aliases:
exprs += list(self._index_aliases)

return as_tuple(exprs)
1 change: 1 addition & 0 deletions loki/frontend/fparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,7 @@ def visit_Char_Selector(self, o, **kwargs):
* some scalar expression for the kind
"""
length = None
kind = None
if o.children[0] is not None:
length = self.visit(o.children[0], **kwargs)
if o.children[1] is not None:
Expand Down
9 changes: 8 additions & 1 deletion loki/program_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def from_source(cls, source, definitions=None, preprocess=False,
if frontend == Frontend.OFP:
ast = parse_ofp_source(source)
return cls.from_ofp(ast=ast, raw_source=source, definitions=definitions,
pp_info=pp_info, parent=parent)
pp_info=pp_info, parent=parent) # pylint: disable=possibly-used-before-assignment

if frontend == Frontend.FP:
ast = parse_fparser_source(source)
Expand Down Expand Up @@ -361,6 +361,13 @@ def enrich(self, definitions, recurse=False):
updated_symbol_attrs[local_name] = symbol.type.clone(
dtype=remote_node.dtype, imported=True, module=module
)
# Update dtype for local variables using this type
variables_with_this_type = {
name: type_.clone(dtype=remote_node.dtype)
for name, type_ in self.symbol_attrs.items()
if getattr(type_.dtype, 'name') == remote_node.dtype.name
}
updated_symbol_attrs.update(variables_with_this_type)
elif hasattr(remote_node, 'type'):
# This is a global variable or interface import
updated_symbol_attrs[local_name] = remote_node.type.clone(
Expand Down
99 changes: 79 additions & 20 deletions loki/tests/test_subroutine.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,16 @@

from loki import (
Sourcefile, Module, Subroutine, FindVariables, FindNodes, Section,
CallStatement, BasicType, Array, Scalar, Variable,
Array, Scalar, Variable,
SymbolAttributes, StringLiteral, fgen, fexprgen,
VariableDeclaration, Transformer, FindTypedSymbols,
ProcedureSymbol, ProcedureType, StatementFunction,
normalize_range_indexing, DeferredTypeSymbol, Assignment,
Interface
ProcedureSymbol, StatementFunction,
normalize_range_indexing, DeferredTypeSymbol
)
from loki.build import jit_compile, jit_compile_lib, clean_test
from loki.frontend import available_frontends, OFP, OMNI, REGEX
from loki.types import BasicType, DerivedType, ProcedureType
from loki.ir import nodes as ir


@pytest.fixture(scope='module', name='here')
Expand Down Expand Up @@ -767,7 +768,7 @@ def test_routine_call_arrays(header_path, frontend):
"""
header = Sourcefile.from_file(header_path, frontend=frontend)['header']
routine = Subroutine.from_source(fcode, frontend=frontend, definitions=header)
call = FindNodes(CallStatement).visit(routine.body)[0]
call = FindNodes(ir.CallStatement).visit(routine.body)[0]

assert str(call.arguments[0]) == 'x'
assert str(call.arguments[1]) == 'y'
Expand Down Expand Up @@ -797,7 +798,7 @@ def test_call_no_arg(frontend):
call abort
end subroutine routine_call_no_arg
""")
calls = FindNodes(CallStatement).visit(routine.body)
calls = FindNodes(ir.CallStatement).visit(routine.body)
assert len(calls) == 1
assert calls[0].arguments == ()
assert calls[0].kwarguments == ()
Expand All @@ -813,7 +814,7 @@ def test_call_kwargs(frontend):
call mpl_init(kprocs=kprocs, cdstring='routine_call_kwargs')
end subroutine routine_call_kwargs
""")
calls = FindNodes(CallStatement).visit(routine.body)
calls = FindNodes(ir.CallStatement).visit(routine.body)
assert len(calls) == 1
assert calls[0].name == 'mpl_init'

Expand All @@ -838,7 +839,7 @@ def test_call_args_kwargs(frontend):
call mpl_send(pbuf, ktag, kdest, cdstring='routine_call_args_kwargs')
end subroutine routine_call_args_kwargs
""")
calls = FindNodes(CallStatement).visit(routine.body)
calls = FindNodes(ir.CallStatement).visit(routine.body)
assert len(calls) == 1
assert calls[0].name == 'mpl_send'
assert len(calls[0].arguments) == 3
Expand Down Expand Up @@ -1520,7 +1521,7 @@ def test_subroutine_stmt_func(here, frontend):
routine.name += f'_{frontend!s}'

# Make sure the statement function injection doesn't invalidate source
for assignment in FindNodes(Assignment).visit(routine.body):
for assignment in FindNodes(ir.Assignment).visit(routine.body):
assert assignment.source is not None

# OMNI inlines statement functions, so we can only check correct representation
Expand Down Expand Up @@ -1958,7 +1959,7 @@ def test_subroutine_clone_contained(frontend):
kernels = driver.subroutines

def _verify_call_enrichment(driver_, kernels_):
calls = FindNodes(CallStatement).visit(driver_.body)
calls = FindNodes(ir.CallStatement).visit(driver_.body)
assert len(calls) == 2

for call in calls:
Expand Down Expand Up @@ -2048,12 +2049,12 @@ def test_enrich_explicit_interface(frontend):
driver.enrich(kernel)

# check if call is enriched correctly
calls = FindNodes(CallStatement).visit(driver.body)
calls = FindNodes(ir.CallStatement).visit(driver.body)
assert calls[0].routine is kernel

# check if the procedure symbol in the interface block has been removed from
# driver's symbol table
intfs = FindNodes(Interface).visit(driver.spec)
intfs = FindNodes(ir.Interface).visit(driver.spec)
assert not intfs[0].body[0].parent

# check that call still points to correct subroutine
Expand All @@ -2065,6 +2066,64 @@ def test_enrich_explicit_interface(frontend):
assert calls[0].routine is kernel


@pytest.mark.parametrize('frontend', available_frontends())
def test_enrich_derived_types(tmp_path, frontend):
fcode = """
subroutine enrich_derived_types_routine(yda_array)
use field_array_module, only : field_3rb_array
implicit none
type(field_3rb_array), intent(inout) :: yda_array
yda_array%p = 0.
end subroutine enrich_derived_types_routine
""".strip()

fcode_module = """
module field_array_module
implicit none
type field_3rb_array
real, pointer :: p(:,:,:)
end type field_3rb_array
end module field_array_module
""".strip()

module = Module.from_source(fcode_module, frontend=frontend, xmods=[tmp_path])
routine = Subroutine.from_source(fcode, frontend=frontend, xmods=[tmp_path])

# The derived type is a dangling import
field_3rb_symbol = routine.symbol_map['field_3rb_array']
assert field_3rb_symbol.type.imported
assert field_3rb_symbol.type.module is None
assert field_3rb_symbol.type.dtype is BasicType.DEFERRED

# The variable type is recognized as a derived type but without enrichment
yda_array = routine.variable_map['yda_array']
assert isinstance(yda_array.type.dtype, DerivedType)
assert routine.variable_map['yda_array'].type.dtype.typedef is BasicType.DEFERRED

# The pointer member has no type information
yda_array_p = routine.resolve_typebound_var('yda_array%p')
assert yda_array_p.type.dtype is BasicType.DEFERRED
assert yda_array_p.type.shape is None

# Pick out the typedef (before enrichment to validate object consistency)
field_3rb_tdef = module['field_3rb_array']
assert isinstance(field_3rb_tdef, ir.TypeDef)

# Enrich the routine with module definitions
routine.enrich(module)

# Ensure the imported type symbol is correctly enriched
assert field_3rb_symbol.type.imported
assert field_3rb_symbol.type.module is module
assert isinstance(field_3rb_symbol.type.dtype, DerivedType)

# Ensure the information has been propagated to other variables
assert isinstance(yda_array.type.dtype, DerivedType)
assert yda_array.type.dtype.typedef is field_3rb_tdef
assert yda_array_p.type.dtype is BasicType.REAL
assert yda_array_p.type.shape == (':', ':', ':')


@pytest.mark.parametrize('frontend', available_frontends(
xfail=[(OMNI, 'OMNI cannot handle external type defs without source')]
))
Expand Down Expand Up @@ -2099,15 +2158,15 @@ def test_subroutine_deep_clone(frontend):

# Replace all assignments with dummy calls
map_nodes={}
for assign in FindNodes(Assignment).visit(new_routine.body):
map_nodes[assign] = CallStatement(
for assign in FindNodes(ir.Assignment).visit(new_routine.body):
map_nodes[assign] = ir.CallStatement(
name=DeferredTypeSymbol(name='testcall'), arguments=(assign.lhs,), scope=new_routine
)
new_routine.body = Transformer(map_nodes).visit(new_routine.body)

# Ensure that the original copy of the routine remains unaffected
assert len(FindNodes(Assignment).visit(routine.body)) == 3
assert len(FindNodes(Assignment).visit(new_routine.body)) == 0
assert len(FindNodes(ir.Assignment).visit(routine.body)) == 3
assert len(FindNodes(ir.Assignment).visit(new_routine.body)) == 0

@pytest.mark.parametrize('frontend', available_frontends())
def test_call_args_kwargs_conversion(frontend):
Expand Down Expand Up @@ -2162,20 +2221,20 @@ def test_call_args_kwargs_conversion(frontend):
len_kwargs = (0, 7, 7, 2)

# sort kwargs
for i_call, call in enumerate(FindNodes(CallStatement).visit(driver.body)):
for i_call, call in enumerate(FindNodes(ir.CallStatement).visit(driver.body)):
assert call.check_kwarguments_order() == kwargs_in_order[i_call]
call.sort_kwarguments()

# check calls with sorted kwargs
for i_call, call in enumerate(FindNodes(CallStatement).visit(driver.body)):
for i_call, call in enumerate(FindNodes(ir.CallStatement).visit(driver.body)):
assert tuple(arg[1].name for arg in call.arg_iter()) == call_args
assert len(call.kwarguments) == len_kwargs[i_call]

# kwarg to arg conversion
for call in FindNodes(CallStatement).visit(driver.body):
for call in FindNodes(ir.CallStatement).visit(driver.body):
call.convert_kwargs_to_args()

# check calls with kwargs converted to args
for call in FindNodes(CallStatement).visit(driver.body):
for call in FindNodes(ir.CallStatement).visit(driver.body):
assert tuple(arg.name for arg in call.arguments) == call_args
assert call.kwarguments == ()
1 change: 1 addition & 0 deletions loki/transformations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,4 @@
from loki.transformations.transform_region import * # noqa
from loki.transformations.pool_allocator import * # noqa
from loki.transformations.utilities import * # noqa
from loki.transformations.block_index_transformations import * # noqa
Loading

0 comments on commit 77114a9

Please sign in to comment.