-
Notifications
You must be signed in to change notification settings - Fork 13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
F2C: DeReferenceTrafo
#273
Changes from 2 commits
96a812b
512859e
b2fca8e
92dcaf5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,7 +31,8 @@ | |
from loki.module import Module | ||
from loki.expression import ( | ||
Variable, InlineCall, RangeIndex, Scalar, Array, | ||
ProcedureSymbol, SubstituteExpressions, Dereference, | ||
ProcedureSymbol, SubstituteExpressions, Dereference, Reference, | ||
ExpressionRetriever, SubstituteExpressionsMapper, | ||
) | ||
from loki.expression import symbols as sym | ||
from loki.tools import as_tuple, flatten | ||
|
@@ -477,17 +478,51 @@ def generate_c_kernel(self, routine): | |
convert_to_lower_case(kernel) | ||
|
||
# Force pointer on reference-passed arguments | ||
var_map = {} | ||
to_be_dereferenced = [] | ||
for arg in kernel.arguments: | ||
if not(arg.type.intent.lower() == 'in' and isinstance(arg, Scalar)): | ||
_type = arg.type.clone(pointer=True) | ||
if isinstance(arg.type.dtype, DerivedType): | ||
# Lower case type names for derived types | ||
typedef = _type.dtype.typedef.clone(name=_type.dtype.typedef.name.lower()) | ||
_type = _type.clone(dtype=typedef.dtype) | ||
var_map[arg] = Dereference(arg) | ||
to_be_dereferenced.append(arg.name.lower()) | ||
kernel.symbol_attrs[arg.name] = _type | ||
kernel.body = SubstituteExpressions(var_map).visit(kernel.body) | ||
|
||
class DeReferenceTrafo(Transformer): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The criterion for a symbol that needs (de)referencing is the same throughout, so maybe it would make sense to capture this as a static method in the transformer: @staticmethod
def is_reference(symbol):
return (
isinstance(symbol, Array)
and not (symbol.dimensions is None or all(dim == ':' for dim in symbol.dimensions))
) |
||
|
||
def __init__(self, vars2dereference): | ||
super().__init__() | ||
self.retriever = ExpressionRetriever(lambda e: isinstance(e, (DerivedType, Array, Scalar))\ | ||
and e.name.lower() in vars2dereference) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. With the above method, you could use this directly in the retriever: self.retriever = ExpressionRetriever(self.is_reference)
self.vars2dereference = vars2dereference |
||
|
||
def visit_Expression(self, o, **kwargs): | ||
symbols = self.retriever.retrieve(o) | ||
symbol_map = {} | ||
for symbol in symbols: | ||
if isinstance(symbol, Array) and symbol.dimensions is not None\ | ||
and not all(dim == sym.RangeIndex((None, None)) for dim in symbol.dimensions): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Tthe creation of the symbol map is then simplified to symbol_map = {
symbol: Dereference(symbol.clone()) for symbol in self.retriever.retrieve()
if symbol.name.lower() in self.vars2dereference
} |
||
continue | ||
symbol_map[symbol] = Dereference(symbol.clone()) | ||
return SubstituteExpressionsMapper(symbol_map)(o) | ||
|
||
def visit_CallStatement(self, o, **kwargs): | ||
new_args = () | ||
call_arg_map = dict((v,k) for k,v in o.arg_map.items()) | ||
for arg in o.arguments: | ||
if isinstance(arg, Array) and arg.dimensions\ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This conditional would then also become more readable as if self.is_reference(arg) and (isinstance(call_arg_map[arg], Array) or call_arg_map[arg].type.intent.lower() != 'in'): |
||
and all(dim != sym.RangeIndex((None, None)) for dim in arg.dimensions) \ | ||
and (isinstance(call_arg_map[arg], Array) or call_arg_map[arg].type.intent.lower() != 'in'): | ||
new_args += (Reference(arg.clone()),) | ||
else: | ||
if isinstance(arg, Scalar) and call_arg_map[arg].type.intent.lower() != 'in': | ||
new_args += (Reference(arg.clone()),) | ||
else: | ||
new_args += (arg,) | ||
o._update(arguments=new_args) | ||
return o | ||
|
||
kernel.body = DeReferenceTrafo(to_be_dereferenced).visit(kernel.body) | ||
|
||
symbol_map = {'epsilon': 'DBL_EPSILON'} | ||
function_map = {'min': 'fmin', 'max': 'fmax', 'abs': 'fabs', | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1004,6 +1004,60 @@ def test_transpile_expressions(here, builder, frontend, use_c_ptr): | |
f2c.wrapperpath.unlink() | ||
f2c.c_path.unlink() | ||
|
||
|
||
@pytest.mark.parametrize('use_c_ptr', (False, True)) | ||
@pytest.mark.parametrize('frontend', available_frontends()) | ||
def test_transpile_call(here, frontend, use_c_ptr): | ||
fcode_module = """ | ||
module transpile_call_kernel_mod | ||
implicit none | ||
contains | ||
|
||
subroutine transpile_call_kernel(a, b, c, arr1, len) | ||
integer, intent(inout) :: a, c | ||
integer, intent(in) :: b | ||
integer, intent(in) :: len | ||
integer, intent(inout) :: arr1(len, len) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For my education: What is the intended outcome if we had (another) argument There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since we need the actual However, subroutine transpile_call_driver(a)
use transpile_call_kernel_mod, only: transpile_call_kernel
integer, parameter :: len = 5
integer, intent(inout) :: arr1(len)
integer, intent(inout) :: arr2(len)
call transpile_call_kernel(arr1, arr2, len)
end subroutine transpile_call_driver
subroutine transpile_call_kernel(arr1, arr2, len)
integer, intent(in) :: len
integer, intent(inout) :: arr1(len)
integer, intent(inout) :: arr2(:)
arr1(1) = 1
arr2(1) = 1
end subroutine transpile_call_kernel is transformed/transpiled to: int transpile_call_driver_c() {
int len = 5;
transpile_call_kernel(arr1, arr2, len);
return 0;
}
int transpile_call_kernel_c(int * restrict arr1, int * restrict arr2, int len) {
int arr1[len];
int arr2[len];
arr1[1 - 1] = 1;
arr2[1 - 1] = 1;
return 0;
} if that answers your question?! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It does, many thanks! |
||
a = b | ||
c = b | ||
end subroutine transpile_call_kernel | ||
end module transpile_call_kernel_mod | ||
""" | ||
|
||
fcode = """ | ||
subroutine transpile_call_driver(a) | ||
use transpile_call_kernel_mod, only: transpile_call_kernel | ||
integer, intent(inout) :: a | ||
integer, parameter :: len = 5 | ||
integer :: arr1(len, len) | ||
integer :: arr2(len, len) | ||
integer :: b | ||
b = 2 * len | ||
call transpile_call_kernel(a, b, arr2(1, 1), arr1, len) | ||
end subroutine transpile_call_driver | ||
""" | ||
unlink_paths = [] | ||
module = Module.from_source(fcode_module, frontend=frontend) | ||
routine = Subroutine.from_source(fcode, frontend=frontend, definitions=module) | ||
f2c = FortranCTransformation(use_c_ptr=use_c_ptr, path=here) | ||
f2c.apply(source=module.subroutine_map['transpile_call_kernel'], path=here, role='kernel') | ||
unlink_paths.extend([f2c.wrapperpath, f2c.c_path]) | ||
ccode_kernel = f2c.c_path.read_text().replace(' ', '').replace('\n', '') | ||
f2c.apply(source=routine, path=here, role='kernel') | ||
unlink_paths.extend([f2c.wrapperpath, f2c.c_path]) | ||
ccode_driver = f2c.c_path.read_text().replace(' ', '').replace('\n', '') | ||
|
||
assert "int*a,intb,int*c" in ccode_kernel | ||
# check for applied Dereference | ||
assert "(*a)=b;" in ccode_kernel | ||
assert "(*c)=b;" in ccode_kernel | ||
# check for applied Reference | ||
assert "transpile_call_kernel((&a),b,(&arr2[" in ccode_driver | ||
|
||
for path in unlink_paths: | ||
path.unlink() | ||
|
||
|
||
@pytest.mark.parametrize('frontend', available_frontends()) | ||
def test_transpile_multiconditional(here, builder, frontend): | ||
""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Readability of the method might benefit from putting the reference/dereference handling into a utility method maybe, and then simply calling
self.convert_args_to_references(kernel)