Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

F2C: optional case-sensitivity for variables/symbols #277

Merged
merged 4 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion loki/backend/cgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def visit_Subroutine(self, o, **kwargs):
aptr += ['*']
else:
aptr += ['']
arguments = [f'{self.visit(a.type, **kwargs)} {p}{a.name.lower()}'
arguments = [f'{self.visit(a.type, **kwargs)} {p}{a.name}'
for a, p in zip(o.arguments, aptr)]

# check whether to return something and define function return type accordingly
Expand Down
17 changes: 14 additions & 3 deletions loki/expression/symbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,18 +116,21 @@ class TypedSymbol:
The type of that symbol. Defaults to :any:`BasicType.DEFERRED`.
parent : :any:`Scalar` or :any:`Array`, optional
The derived type variable this variable belongs to.
case_sensitive : bool, optional
Mark the name of this symbol as case-sensitive (default: `False`)
*args : optional
Any other positional arguments for other parent classes
**kwargs : optional
Any other keyword arguments for other parent classes
"""

init_arg_names = ('name', 'scope', 'parent', 'type', )
init_arg_names = ('name', 'scope', 'parent', 'type', 'case_sensitive', )

def __init__(self, *args, **kwargs):
self.name = kwargs['name']
self.parent = kwargs.pop('parent', None)
self.scope = kwargs.pop('scope', None)
self.case_sensitive = kwargs.pop('case_sensitive', config['case-sensitive'])

# Use provided type or try to determine from scope
self._type = None
Expand All @@ -154,7 +157,7 @@ def __getinitargs__(self):
symbol objects. We do not recurse here, since we own the
"name" attribute, which pymbolic will otherwise replicate.
"""
return (self.name, None, self._parent, self._type, )
return (self.name, None, self._parent, self._type, self.case_sensitive, )

@property
def scope(self):
Expand Down Expand Up @@ -339,6 +342,8 @@ def clone(self, **kwargs):
kwargs['type'] = self.type
if 'parent' not in kwargs and self.parent:
kwargs['parent'] = self.parent
if 'case_sensitive' not in kwargs and self.case_sensitive:
kwargs['case_sensitive'] = self.case_sensitive

return Variable(**kwargs)

Expand Down Expand Up @@ -639,6 +644,13 @@ def rescope(self, scope):
"""
return self.symbol.rescope(scope)

@property
def case_sensitive(self):
"""
Property to indicate that the name of this symbol is case-sensitive.
"""
return self.symbol.case_sensitive

def get_derived_type_member(self, name_str):
"""
Resolve type-bound variables of arbitrary nested depth.
Expand Down Expand Up @@ -698,7 +710,6 @@ class Array(MetaSymbol):
def __init__(self, name, scope=None, type=None, dimensions=None, **kwargs):
# Stop complaints about `type` in this function
# pylint: disable=redefined-builtin

symbol = VariableSymbol(name=name, scope=scope, type=type, **kwargs)
if dimensions:
symbol = ArraySubscript(symbol, dimensions)
Expand Down
45 changes: 45 additions & 0 deletions loki/tests/test_transpile.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@

from loki import Subroutine, Module, FortranCTransformation, cgen
from loki.build import jit_compile, jit_compile_lib, clean_test, Builder
import loki.expression.symbols as sym
from loki.frontend import available_frontends, OFP
import loki.ir as ir
from loki.transform import normalize_range_indexing


Expand All @@ -25,6 +27,49 @@ def fixture_builder(here):
return Builder(source_dirs=here, build_dir=here/'build')


@pytest.mark.parametrize('case_sensitive', (False, True))
@pytest.mark.parametrize('frontend', available_frontends())
def test_transpile_case_sensitivity(here, frontend, case_sensitive):
"""
A simple test for testing lowering the case and case-sensitivity
for specific symbols.
"""

fcode = """
subroutine transpile_case_sensitivity(a)
integer, intent(in) :: a

end subroutine transpile_case_sensitivity
"""
def convert_case(_str, case_sensitive):
return _str.lower() if not case_sensitive else _str

routine = Subroutine.from_source(fcode, frontend=frontend)

var_thread_idx = sym.Variable(name="threadIdx", case_sensitive=case_sensitive)
var_x = sym.Variable(name="x", parent=var_thread_idx, case_sensitive=case_sensitive)
assignment = ir.Assignment(lhs=routine.variable_map['a'], rhs=var_x)
routine.arguments=routine.arguments + (routine.arguments[0].clone(name='sOmE_vAr', case_sensitive=case_sensitive),
sym.Variable(name="oTher_VaR", case_sensitive=case_sensitive, type=routine.arguments[0].type.clone()))

call = ir.CallStatement(sym.Variable(name='somE_cALl', case_sensitive=case_sensitive),
arguments=(routine.variable_map['a'],))
inline_call = sym.InlineCall(function=sym.Variable(name='somE_InlINeCaLl', case_sensitive=case_sensitive),
parameters=(sym.IntLiteral(1),))
inline_call_assignment = ir.Assignment(lhs=routine.variable_map['a'], rhs=inline_call)
routine.body = (routine.body, assignment, call, inline_call_assignment)

f2c = FortranCTransformation()
f2c.apply(source=routine, path=here)
ccode = f2c.c_path.read_text().replace(' ', '').replace('\n', ' ').replace('\r', '').replace('\t', '')
assert convert_case('transpile_case_sensitivity_c(inta,intsOmE_vAr,intoTher_VaR)', case_sensitive) in ccode
assert convert_case('a=threadIdx%x;', case_sensitive) in ccode
assert convert_case('somE_cALl(a);', case_sensitive) in ccode
assert convert_case('a=somE_InlINeCaLl(1);', case_sensitive) in ccode

f2c.wrapperpath.unlink()
f2c.c_path.unlink()

@pytest.mark.parametrize('use_c_ptr', (False, True))
@pytest.mark.parametrize('frontend', available_frontends())
def test_transpile_simple_loops(here, builder, frontend, use_c_ptr):
Expand Down
4 changes: 4 additions & 0 deletions loki/transform/fortran_c_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
)
from loki.sourcefile import Sourcefile
from loki.backend import cgen, fgen
from loki.logging import debug
from loki.ir import (
Section, Import, Intrinsic, Interface, CallStatement, VariableDeclaration,
TypeDef, Assignment, Transformer, FindNodes
Expand Down Expand Up @@ -75,6 +76,9 @@ def visit_Expression(self, o, **kwargs):

def visit_CallStatement(self, o, **kwargs):
new_args = ()
if o.routine is BasicType.DEFERRED:
debug(f'DeReferenceTrafo: Skipping call to {o.name!s} due to missing procedure enrichment')
return o
call_arg_map = dict((v,k) for k,v in o.arg_map.items())
for arg in o.arguments:
if not self.is_dereference(arg) and (isinstance(call_arg_map[arg], Array)\
Expand Down
5 changes: 3 additions & 2 deletions loki/transform/transform_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ def convert_to_lower_case(routine):
variables = FindVariables(unique=False).visit(routine.ir)
vmap = {
v: v.clone(name=v.name.lower()) for v in variables
if isinstance(v, (sym.Scalar, sym.Array, sym.DeferredTypeSymbol)) and not v.name.islower()
if isinstance(v, (sym.Scalar, sym.Array, sym.DeferredTypeSymbol)) and not v.name.islower()\
and not v.case_sensitive
}

# Capture nesting by applying map to itself before applying to the routine
Expand All @@ -100,7 +101,7 @@ def convert_to_lower_case(routine):
# Downcase inline calls to, but only after the above has been propagated,
# so that we capture the updates from the variable update in the arguments
mapper = {
c: c.clone(function=c.function.clone(name=c.name.lower()))
c: c.clone(function=c.function.clone(name=c.name.lower() if not c.function.case_sensitive else c.name))
for c in FindInlineCalls().visit(routine.ir) if not c.name.islower()
}
mapper.update(
Expand Down
Loading