Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

F2C: DeReferenceTrafo #273

Merged
merged 4 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion loki/backend/cgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def visit_CallStatement(self, o, **kwargs):
"""
args = self.visit_all(o.arguments, **kwargs)
assert not o.kwarguments
return self.format_line(o.name, '(', self.join_items(args), ');')
return self.format_line(str(o.name), '(', self.join_items(args), ');')

def visit_SymbolAttributes(self, o, **kwargs): # pylint: disable=unused-argument
if isinstance(o.dtype, DerivedType):
Expand Down
43 changes: 39 additions & 4 deletions loki/transform/fortran_c_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
from loki.module import Module
from loki.expression import (
Variable, InlineCall, RangeIndex, Scalar, Array,
ProcedureSymbol, SubstituteExpressions, Dereference,
ProcedureSymbol, SubstituteExpressions, Dereference, Reference,
ExpressionRetriever, SubstituteExpressionsMapper,
)
from loki.expression import symbols as sym
from loki.tools import as_tuple, flatten
Expand Down Expand Up @@ -477,17 +478,51 @@ def generate_c_kernel(self, routine):
convert_to_lower_case(kernel)

# Force pointer on reference-passed arguments
var_map = {}
to_be_dereferenced = []
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Readability of the method might benefit from putting the reference/dereference handling into a utility method maybe, and then simply calling self.convert_args_to_references(kernel)

for arg in kernel.arguments:
if not(arg.type.intent.lower() == 'in' and isinstance(arg, Scalar)):
_type = arg.type.clone(pointer=True)
if isinstance(arg.type.dtype, DerivedType):
# Lower case type names for derived types
typedef = _type.dtype.typedef.clone(name=_type.dtype.typedef.name.lower())
_type = _type.clone(dtype=typedef.dtype)
var_map[arg] = Dereference(arg)
to_be_dereferenced.append(arg.name.lower())
kernel.symbol_attrs[arg.name] = _type
kernel.body = SubstituteExpressions(var_map).visit(kernel.body)

class DeReferenceTrafo(Transformer):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The criterion for a symbol that needs (de)referencing is the same throughout, so maybe it would make sense to capture this as a static method in the transformer:

@staticmethod
def is_reference(symbol):
    return (
        isinstance(symbol, Array) 
        and not (symbol.dimensions is None or all(dim == ':' for dim in symbol.dimensions))
    )


def __init__(self, vars2dereference):
super().__init__()
self.retriever = ExpressionRetriever(lambda e: isinstance(e, (DerivedType, Array, Scalar))\
and e.name.lower() in vars2dereference)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the above method, you could use this directly in the retriever:

self.retriever = ExpressionRetriever(self.is_reference)
self.vars2dereference = vars2dereference


def visit_Expression(self, o, **kwargs):
symbols = self.retriever.retrieve(o)
symbol_map = {}
for symbol in symbols:
if isinstance(symbol, Array) and symbol.dimensions is not None\
and not all(dim == sym.RangeIndex((None, None)) for dim in symbol.dimensions):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tthe creation of the symbol map is then simplified to

symbol_map = {
    symbol: Dereference(symbol.clone()) for symbol in self.retriever.retrieve()
    if symbol.name.lower() in self.vars2dereference
}

continue
symbol_map[symbol] = Dereference(symbol.clone())
return SubstituteExpressionsMapper(symbol_map)(o)

def visit_CallStatement(self, o, **kwargs):
new_args = ()
call_arg_map = dict((v,k) for k,v in o.arg_map.items())
for arg in o.arguments:
if isinstance(arg, Array) and arg.dimensions\
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This conditional would then also become more readable as

if self.is_reference(arg) and (isinstance(call_arg_map[arg], Array) or call_arg_map[arg].type.intent.lower() != 'in'):

and all(dim != sym.RangeIndex((None, None)) for dim in arg.dimensions) \
and (isinstance(call_arg_map[arg], Array) or call_arg_map[arg].type.intent.lower() != 'in'):
new_args += (Reference(arg.clone()),)
else:
if isinstance(arg, Scalar) and call_arg_map[arg].type.intent.lower() != 'in':
new_args += (Reference(arg.clone()),)
else:
new_args += (arg,)
o._update(arguments=new_args)
return o

kernel.body = DeReferenceTrafo(to_be_dereferenced).visit(kernel.body)

symbol_map = {'epsilon': 'DBL_EPSILON'}
function_map = {'min': 'fmin', 'max': 'fmax', 'abs': 'fabs',
Expand Down
54 changes: 54 additions & 0 deletions tests/test_transpile.py
Original file line number Diff line number Diff line change
Expand Up @@ -1004,6 +1004,60 @@ def test_transpile_expressions(here, builder, frontend, use_c_ptr):
f2c.wrapperpath.unlink()
f2c.c_path.unlink()


@pytest.mark.parametrize('use_c_ptr', (False, True))
@pytest.mark.parametrize('frontend', available_frontends())
def test_transpile_call(here, frontend, use_c_ptr):
fcode_module = """
module transpile_call_kernel_mod
implicit none
contains

subroutine transpile_call_kernel(a, b, c, arr1, len)
integer, intent(inout) :: a, c
integer, intent(in) :: b
integer, intent(in) :: len
integer, intent(inout) :: arr1(len, len)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For my education: What is the intended outcome if we had (another) argument arr(:,:) here? Would this pre-empt the F2C transformation until we figure out the actual shape, or would it simply not be passed as a reference?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we need the actual shape information for flattening (see flatten_arrays() - TypeError: Resolve shapes being of type RangeIndex, e.g., ":" before flattening!, raised if isinstance(shape[-2], sym.RangeIndex))arrays having more than one dimension, the DeReferenceTrafo wouldn't be executed.

However,

subroutine transpile_call_driver(a)
  use transpile_call_kernel_mod, only: transpile_call_kernel
    integer, parameter :: len = 5
    integer, intent(inout) :: arr1(len)
    integer, intent(inout) :: arr2(len)
    call transpile_call_kernel(arr1, arr2, len)
end subroutine transpile_call_driver

  subroutine transpile_call_kernel(arr1, arr2, len)
    integer, intent(in) :: len
    integer, intent(inout) :: arr1(len)
    integer, intent(inout) :: arr2(:)

    arr1(1) = 1
    arr2(1) = 1
  end subroutine transpile_call_kernel

is transformed/transpiled to:

int transpile_call_driver_c() {
  int len = 5;
  transpile_call_kernel(arr1, arr2, len);
  return 0;
}

int transpile_call_kernel_c(int * restrict arr1, int * restrict arr2, int len) {

  int arr1[len];
  int arr2[len];
  arr1[1 - 1] = 1;
  arr2[1 - 1] = 1;
  return 0;
}

if that answers your question?!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does, many thanks!

a = b
c = b
end subroutine transpile_call_kernel
end module transpile_call_kernel_mod
"""

fcode = """
subroutine transpile_call_driver(a)
use transpile_call_kernel_mod, only: transpile_call_kernel
integer, intent(inout) :: a
integer, parameter :: len = 5
integer :: arr1(len, len)
integer :: arr2(len, len)
integer :: b
b = 2 * len
call transpile_call_kernel(a, b, arr2(1, 1), arr1, len)
end subroutine transpile_call_driver
"""
unlink_paths = []
module = Module.from_source(fcode_module, frontend=frontend)
routine = Subroutine.from_source(fcode, frontend=frontend, definitions=module)
f2c = FortranCTransformation(use_c_ptr=use_c_ptr, path=here)
f2c.apply(source=module.subroutine_map['transpile_call_kernel'], path=here, role='kernel')
unlink_paths.extend([f2c.wrapperpath, f2c.c_path])
ccode_kernel = f2c.c_path.read_text().replace(' ', '').replace('\n', '')
f2c.apply(source=routine, path=here, role='kernel')
unlink_paths.extend([f2c.wrapperpath, f2c.c_path])
ccode_driver = f2c.c_path.read_text().replace(' ', '').replace('\n', '')

assert "int*a,intb,int*c" in ccode_kernel
# check for applied Dereference
assert "(*a)=b;" in ccode_kernel
assert "(*c)=b;" in ccode_kernel
# check for applied Reference
assert "transpile_call_kernel((&a),b,(&arr2[" in ccode_driver

for path in unlink_paths:
path.unlink()


@pytest.mark.parametrize('frontend', available_frontends())
def test_transpile_multiconditional(here, builder, frontend):
"""
Expand Down
Loading