diff --git a/loki/expression/__init__.py b/loki/expression/__init__.py index 1c08791f4..495515f2c 100644 --- a/loki/expression/__init__.py +++ b/loki/expression/__init__.py @@ -10,3 +10,4 @@ from loki.expression.operations import * # noqa from loki.expression.mappers import * # noqa from loki.expression.symbolic import * # noqa +from loki.expression.parser import * # noqa diff --git a/loki/expression/parser.py b/loki/expression/parser.py new file mode 100644 index 000000000..3a6a05e80 --- /dev/null +++ b/loki/expression/parser.py @@ -0,0 +1,138 @@ +# (C) Copyright 2018- ECMWF. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from sys import intern +import pytools.lex +from pymbolic.parser import Parser as ParserBase +from pymbolic.mapper import Mapper +import pymbolic.primitives as pmbl + +from loki.expression import symbols as sym + +__all__ = ['LokiParser', 'loki_parse'] + + +class PymbolicMapper(Mapper): + """ + A visitor for expressions that returns the combined result of a specified callback function. + """ + # pylint: disable=abstract-method + + def __init__(self, scope=None): + self.scope = scope + super().__init__() + + def map_product(self, expr, *args, **kwargs): + return sym.Product(expr.children) + map_Mul = map_product + + def map_sum(self, expr, *args, **kwargs): + return sym.Sum(tuple(self.rec(child, *args, **kwargs) for child in expr.children)) + map_Add = map_sum + + def map_power(self, expr, *args, **kwargs): + return sym.Power(base=expr.base, exponent=expr.exponent) + + def map_quotient(self, expr, *args, **kwargs): + return sym.Quotient(numerator=expr.numerator, denominator=expr.denominator) + + def map_comparison(self, expr, *args, **kwargs): + # return sym.Comparison(tuple(self.rec(child, *args, **kwargs) for child in expr.children)) + return sym.Comparison(left=self.rec(expr.left), operator=expr.operator, right=self.rec(expr.right)) + + def map_logical_and(self, expr, *args, **kwargs): + return sym.LogicalAnd(expr.children) + + def map_logical_or(self, expr, *args, **kwargs): + return sym.LogicalOr(expr.children) + + def map_logical_not(self, expr, *args, **kwargs): + return sym.LogicalNot(expr.children) + + def map_constant(self, expr, *args, **kwargs): + return sym.Literal(expr) + map_logic_literal = map_constant + map_string_literal = map_constant + map_intrinsic_literal = map_constant + + map_int_literal = map_constant + map_float_literal = map_int_literal + map_variable_symbol = map_constant + map_deferred_type_symbol = map_constant + + def map_meta_symbol(self, expr, *args, **kwargs): + return sym.Variable(str(expr.name)) + map_Symbol = map_meta_symbol + map_scalar = map_meta_symbol + map_array = map_meta_symbol + + def map_slice(self, expr, *args, **kwargs): + return sym.RangeIndex(expr.children) + + map_range = map_slice + map_range_index = map_slice + map_loop_range = map_slice + + def map_variable(self, expr, *args, **kwargs): + return sym.Variable(name=expr.name, scope=self.scope) + + def map_algebraic_leaf(self, expr, *args, **kwargs): + if str(expr).isnumeric(): + return self.map_constant(expr) + if isinstance(expr, pmbl.Call): + if self.scope is not None: + if expr.function.name in self.scope.symbol_attrs: + return sym.Variable(name=expr.function.name, scope=self.scope, dimensions=self.rec(expr.parameters)) + return expr + else: + try: + return self.map_variable(expr) + except: + return expr + + def map_tuple(self, expr, *args, **kwargs): + return tuple(self.rec(elem) for elem in expr) + + +class LokiParser(ParserBase): + + _f_lessequal = intern('_f_lessequal') + _f_less = intern('_f_less') + _f_greaterequal = intern('_f_greaterequal') + _f_greater = intern('_f_greater') + _f_equal = intern('_f_equal') + _f_notequal = intern('_f_notequal') + _f_and = intern("and") + _f_or = intern("or") + # _f_not = intern("not") + + lex_table = [ + (_f_lessequal, pytools.lex.RE(r"\.le\.")), + (_f_less, pytools.lex.RE(r"\.lt\.")), + (_f_greaterequal, pytools.lex.RE(r"\.ge\.")), + (_f_greater, pytools.lex.RE(r"\.gt\.")), + (_f_equal, pytools.lex.RE(r"\.eq\.")), + (_f_notequal, pytools.lex.RE(r"\.ne\.")), + (_f_and, pytools.lex.RE(r"\.and\.")), + (_f_and, pytools.lex.RE(r"\.or\.")), + # (_f_and, pytools.lex.RE(r"\.not\.")), + ] + ParserBase.lex_table + + ParserBase._COMP_TABLE.update({ + _f_lessequal: "<=", + _f_less: "<", + _f_greaterequal: ">=", + _f_greater: ">", + _f_equal: "==", + _f_notequal: "!=" + }) + + def __call__(self, expr_str, scope=None, min_precedence=0): + result = super().__call__(expr_str, min_precedence) + return PymbolicMapper(scope=scope)(result) + +loki_parse = LokiParser() diff --git a/tests/test_expression.py b/tests/test_expression.py index f78d4c111..4c82850a9 100644 --- a/tests/test_expression.py +++ b/tests/test_expression.py @@ -25,7 +25,7 @@ DeferredTypeSymbol, Module, HAVE_FP, FindExpressions, LiteralList, FindInlineCalls, AttachScopesMapper, FindTypedSymbols, Reference, Dereference ) -from loki.expression import symbols +from loki.expression import symbols, loki_parse from loki.tools import gettempdir, filehash # pylint: disable=too-many-lines @@ -1551,3 +1551,65 @@ def test_expression_c_de_reference(frontend): c_str = cgen(routine).replace(' ', '') assert '(&renamed_var_reference)=1' in c_str assert '(*renamed_var_dereference)=2' in c_str + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_parser(frontend): + fcode = """ +subroutine some_routine() +implicit none +integer :: i1, i2, i3, len1, len2, len3 +real :: a, b +real :: arr(len1, len2, len3) +end subroutine some_routine + """.strip() + + routine = Subroutine.from_source(fcode, frontend=frontend) + + print("") + parsed = loki_parse('a + b') + print(f"{parsed} | type: {type(parsed)} | children type(s): {[type(child) for child in parsed.children]}") + parsed = loki_parse('a + b', scope=routine) + print(f"{parsed} | type: {type(parsed)} | children type(s): {[type(child) for child in parsed.children]}") + parsed = loki_parse('a + b + 2 + 10', scope=routine) + print(f"{parsed} | {type(parsed)} | {[type(child) for child in parsed.children]}") + parsed = loki_parse('a - b', scope=routine) + print(f"{parsed} | {type(parsed)}") + parsed = loki_parse('a * b', scope=routine) + print(f"{parsed} | {type(parsed)}") + parsed = loki_parse('a / b', scope=routine) + print(f"{parsed} | {type(parsed)}") + parsed = loki_parse('a ** b', scope=routine) + print(f"{parsed} | {type(parsed)}") + parsed = loki_parse('a:b', scope=routine) + print(f"{parsed} | {type(parsed)}") + parsed = loki_parse('a>b', scope=routine) + print(f"{parsed} | {type(parsed)}") + parsed = loki_parse('a.gt.b', scope=routine) + print(f"{parsed} | {type(parsed)}") + + parsed = loki_parse('arr(i1, i2, i3)') + print(f"{parsed} | {type(parsed)}") #  | shape: {parsed.shape} | dimensions: {parsed.dimensions}") + parsed = loki_parse('arr(i1, i2, i3)', scope=routine) + print(f"{parsed} | {type(parsed)} | shape: {parsed.shape} | dimensions: {parsed.dimensions}") + + parsed = loki_parse('a') + print(f"{parsed} | {type(parsed)} | scope: {parsed.scope} | type: {parsed.type}") + parsed = loki_parse('a', scope=routine) + print(f"{parsed} | {type(parsed)} | scope: {parsed.scope} | type: {parsed.type}") + parsed = loki_parse('3.1415') + print(f"{parsed} | {type(parsed)}") + + parsed = loki_parse('MODULO(A, B)') + print(f"{parsed} | {type(parsed)}") + + # parsed as range index ... + # parsed = loki_parse('integer :: some_integer') + # print(f"{parsed} | {type(parsed)}") + + parsed = loki_parse('a .and. b') + print(f"{parsed} | {type(parsed)} | children: {parsed.children}") + parsed = loki_parse('a .or. b') + print(f"{parsed} | {type(parsed)} | children: {parsed.children}") + + parsed = loki_parse('((a + b)/(a - b))**3 + 3.1415', scope=routine) + print(f"{parsed} | {type(parsed)}")