Skip to content

Commit

Permalink
Merge pull request #273 from ecmwf-ifs/nams_de_reference_trafo_f2c
Browse files Browse the repository at this point in the history
F2C: `DeReferenceTrafo`
  • Loading branch information
reuterbal authored Apr 11, 2024
2 parents f4a0c33 + 92dcaf5 commit e5eae78
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 6 deletions.
2 changes: 1 addition & 1 deletion loki/backend/cgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def visit_CallStatement(self, o, **kwargs):
"""
args = self.visit_all(o.arguments, **kwargs)
assert not o.kwarguments
return self.format_line(o.name, '(', self.join_items(args), ');')
return self.format_line(str(o.name), '(', self.join_items(args), ');')

def visit_SymbolAttributes(self, o, **kwargs): # pylint: disable=unused-argument
if isinstance(o.dtype, DerivedType):
Expand Down
72 changes: 67 additions & 5 deletions loki/transform/fortran_c_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
from loki.module import Module
from loki.expression import (
Variable, InlineCall, RangeIndex, Scalar, Array,
ProcedureSymbol, SubstituteExpressions, Dereference,
ProcedureSymbol, SubstituteExpressions, Dereference, Reference,
ExpressionRetriever, SubstituteExpressionsMapper,
)
from loki.expression import symbols as sym
from loki.tools import as_tuple, flatten
Expand All @@ -40,6 +41,54 @@
__all__ = ['FortranCTransformation']


class DeReferenceTrafo(Transformer):
"""
Transformation to apply/insert Dereference = `*` and
Reference/*address-of* = `&` operators.
Parameters
----------
vars2dereference : list
Variables to be dereferenced. Ususally the arguments except
for scalars with `intent=in`.
"""
# pylint: disable=unused-argument

def __init__(self, vars2dereference):
super().__init__()
self.retriever = ExpressionRetriever(self.is_dereference)
self.vars2dereference = vars2dereference

@staticmethod
def is_dereference(symbol):
return isinstance(symbol, (DerivedType, Array, Scalar)) and not (
isinstance(symbol, Array) and symbol.dimensions is not None
and not all(dim == sym.RangeIndex((None, None)) for dim in symbol.dimensions)
)

def visit_Expression(self, o, **kwargs):
symbol_map = {
symbol: Dereference(symbol.clone()) for symbol in self.retriever.retrieve(o)
if symbol.name.lower() in self.vars2dereference
}
return SubstituteExpressionsMapper(symbol_map)(o)

def visit_CallStatement(self, o, **kwargs):
new_args = ()
call_arg_map = dict((v,k) for k,v in o.arg_map.items())
for arg in o.arguments:
if not self.is_dereference(arg) and (isinstance(call_arg_map[arg], Array)\
or call_arg_map[arg].type.intent.lower() != 'in'):
new_args += (Reference(arg.clone()),)
else:
if isinstance(arg, Scalar) and call_arg_map[arg].type.intent.lower() != 'in':
new_args += (Reference(arg.clone()),)
else:
new_args += (arg,)
o._update(arguments=new_args)
return o


class FortranCTransformation(Transformation):
"""
Fortran-to-C transformation that translates the given routine
Expand Down Expand Up @@ -401,6 +450,19 @@ def generate_c_header(self, module, **kwargs):
header_module.rescope_symbols()
return header_module

@staticmethod
def apply_de_reference(routine):
"""
Utility method to apply/insert Dereference = `*` and
Reference/*address-of* = `&` operators.
"""
to_be_dereferenced = []
for arg in routine.arguments:
if not(arg.type.intent.lower() == 'in' and isinstance(arg, Scalar)):
to_be_dereferenced.append(arg.name.lower())

routine.body = DeReferenceTrafo(to_be_dereferenced).visit(routine.body)

def generate_c_kernel(self, routine):
"""
Re-generate the C kernel and insert wrapper-specific peculiarities,
Expand Down Expand Up @@ -476,18 +538,18 @@ def generate_c_kernel(self, routine):
# Force all variables to lower-caps, as C/C++ is case-sensitive
convert_to_lower_case(kernel)

# Force pointer on reference-passed arguments
var_map = {}
# Force pointer on reference-passed arguments (and lower case type names for derived types)
for arg in kernel.arguments:
if not(arg.type.intent.lower() == 'in' and isinstance(arg, Scalar)):
_type = arg.type.clone(pointer=True)
if isinstance(arg.type.dtype, DerivedType):
# Lower case type names for derived types
typedef = _type.dtype.typedef.clone(name=_type.dtype.typedef.name.lower())
_type = _type.clone(dtype=typedef.dtype)
var_map[arg] = Dereference(arg)
kernel.symbol_attrs[arg.name] = _type
kernel.body = SubstituteExpressions(var_map).visit(kernel.body)

# apply dereference and reference where necessary
self.apply_de_reference(kernel)

symbol_map = {'epsilon': 'DBL_EPSILON'}
function_map = {'min': 'fmin', 'max': 'fmax', 'abs': 'fabs',
Expand Down
53 changes: 53 additions & 0 deletions tests/test_transpile.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,6 +1005,59 @@ def test_transpile_expressions(here, builder, frontend, use_c_ptr):
f2c.c_path.unlink()


@pytest.mark.parametrize('use_c_ptr', (False, True))
@pytest.mark.parametrize('frontend', available_frontends())
def test_transpile_call(here, frontend, use_c_ptr):
fcode_module = """
module transpile_call_kernel_mod
implicit none
contains
subroutine transpile_call_kernel(a, b, c, arr1, len)
integer, intent(inout) :: a, c
integer, intent(in) :: b
integer, intent(in) :: len
integer, intent(inout) :: arr1(len, len)
a = b
c = b
end subroutine transpile_call_kernel
end module transpile_call_kernel_mod
"""

fcode = """
subroutine transpile_call_driver(a)
use transpile_call_kernel_mod, only: transpile_call_kernel
integer, intent(inout) :: a
integer, parameter :: len = 5
integer :: arr1(len, len)
integer :: arr2(len, len)
integer :: b
b = 2 * len
call transpile_call_kernel(a, b, arr2(1, 1), arr1, len)
end subroutine transpile_call_driver
"""
unlink_paths = []
module = Module.from_source(fcode_module, frontend=frontend)
routine = Subroutine.from_source(fcode, frontend=frontend, definitions=module)
f2c = FortranCTransformation(use_c_ptr=use_c_ptr, path=here)
f2c.apply(source=module.subroutine_map['transpile_call_kernel'], path=here, role='kernel')
unlink_paths.extend([f2c.wrapperpath, f2c.c_path])
ccode_kernel = f2c.c_path.read_text().replace(' ', '').replace('\n', '')
f2c.apply(source=routine, path=here, role='kernel')
unlink_paths.extend([f2c.wrapperpath, f2c.c_path])
ccode_driver = f2c.c_path.read_text().replace(' ', '').replace('\n', '')

assert "int*a,intb,int*c" in ccode_kernel
# check for applied Dereference
assert "(*a)=b;" in ccode_kernel
assert "(*c)=b;" in ccode_kernel
# check for applied Reference
assert "transpile_call_kernel((&a),b,(&arr2[" in ccode_driver

for path in unlink_paths:
path.unlink()


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('f_type', ['integer', 'real'])
def test_transpile_inline_functions(here, frontend, f_type):
Expand Down

0 comments on commit e5eae78

Please sign in to comment.