Skip to content

Commit

Permalink
Continued 2: Loki string parser based on pymbolic parser
Browse files Browse the repository at this point in the history
  • Loading branch information
MichaelSt98 committed Apr 8, 2024
1 parent c48db05 commit 013819d
Show file tree
Hide file tree
Showing 2 changed files with 230 additions and 64 deletions.
29 changes: 16 additions & 13 deletions loki/expression/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,23 +27,24 @@ def __init__(self, scope=None):

def map_product(self, expr, *args, **kwargs):
return sym.Product(tuple(self.rec(child, *args, **kwargs) for child in expr.children))
map_Mul = map_product
map_Mul = map_product

def map_sum(self, expr, *args, **kwargs):
return sym.Sum(tuple(self.rec(child, *args, **kwargs) for child in expr.children))
map_Add = map_sum
# map_Add = map_sum

def map_power(self, expr, *args, **kwargs):
return sym.Power(base=self.rec(expr.base),
exponent=self.rec(expr.exponent))
return sym.Power(base=self.rec(expr.base, *args, **kwargs),
exponent=self.rec(expr.exponent, *args, **kwargs))

def map_quotient(self, expr, *args, **kwargs):
return sym.Quotient(numerator=expr.numerator, denominator=expr.denominator)
return sym.Quotient(numerator=self.rec(expr.numerator, *args, **kwargs),
denominator=self.rec(expr.denominator, *args, **kwargs))

def map_comparison(self, expr, *args, **kwargs):
return sym.Comparison(left=self.rec(expr.left),
return sym.Comparison(left=self.rec(expr.left, *args, **kwargs),
operator=expr.operator,
right=self.rec(expr.right))
right=self.rec(expr.right, *args, **kwargs))

def map_logical_and(self, expr, *args, **kwargs):
return sym.LogicalAnd(tuple(self.rec(child, *args, **kwargs) for child in expr.children))
Expand All @@ -52,7 +53,7 @@ def map_logical_or(self, expr, *args, **kwargs):
return sym.LogicalOr(tuple(self.rec(child, *args, **kwargs) for child in expr.children))

def map_logical_not(self, expr, *args, **kwargs):
return sym.LogicalNot(self.rec(expr.child))
return sym.LogicalNot(self.rec(expr.child, *args, **kwargs))

def map_constant(self, expr, *args, **kwargs):
return sym.Literal(expr)
Expand All @@ -67,7 +68,7 @@ def map_constant(self, expr, *args, **kwargs):

def map_meta_symbol(self, expr, *args, **kwargs):
return sym.Variable(name=str(expr.name), scope=self.scope)
map_Symbol = map_meta_symbol
map_Symbol = map_meta_symbol
map_scalar = map_meta_symbol
map_array = map_meta_symbol

Expand All @@ -88,17 +89,19 @@ def map_algebraic_leaf(self, expr, *args, **kwargs):
if isinstance(expr, pmbl.Call):
if self.scope is not None:
if expr.function.name in self.scope.symbol_attrs:
return sym.Variable(name=expr.function.name, scope=self.scope, dimensions=self.rec(expr.parameters))
return expr
return sym.Variable(name=expr.function.name, scope=self.scope,
dimensions=tuple(self.rec(param, *args, **kwargs) for param in expr.parameters))
return sym.InlineCall(function=sym.Variable(name=expr.function.name),
parameters=tuple(self.rec(param, *args, **kwargs) for param in expr.parameters))
# else:
try:
return self.map_variable(expr)
return self.map_variable(expr, *args, **kwargs)
except Exception as e:
print(f"Exception: {e}")
return expr

def map_tuple(self, expr, *args, **kwargs):
return tuple(self.rec(elem) for elem in expr)
return tuple(self.rec(elem, *args, **kwargs) for elem in expr)


class LokiParser(ParserBase):
Expand Down
265 changes: 214 additions & 51 deletions tests/test_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -1552,8 +1552,9 @@ def test_expression_c_de_reference(frontend):
assert '(&renamed_var_reference)=1' in c_str
assert '(*renamed_var_dereference)=2' in c_str

@pytest.mark.parametrize('case', ('upper', 'lower', 'random'))
@pytest.mark.parametrize('frontend', available_frontends())
def test_parser(frontend):
def test_parser(frontend, case):
fcode = """
subroutine some_routine()
implicit none
Expand All @@ -1563,61 +1564,223 @@ def test_parser(frontend):
end subroutine some_routine
""".strip()

def convert_to_case(_str, mode='upper'):
if mode == 'upper':
# print(f"{_str.upper()}")
return _str.upper()
if mode == 'lower':
# print(f"{_str.lower()}")
return _str.lower()
if mode == 'random':
# this is obviously not random, but fulfils its purpose ...
result = ''
for i, char in enumerate(_str):
result += char.upper() if i%2==0 and i<3 else char.lower()
# print(f"{result}")
return result
return convert_to_case(_str)


routine = Subroutine.from_source(fcode, frontend=frontend)

print("")
parsed = loki_parse('a + b')
print(f"{parsed} | type: {type(parsed)} | children type(s): {[type(child) for child in parsed.children]}")
parsed = loki_parse('a + b', scope=routine)
print(f"{parsed} | type: {type(parsed)} | children type(s): {[type(child) for child in parsed.children]}")
parsed = loki_parse('a + b + 2 + 10', scope=routine)
print(f"{parsed} | {type(parsed)} | {[type(child) for child in parsed.children]}")
parsed = loki_parse('a - b', scope=routine)
print(f"{parsed} | {type(parsed)}")
parsed = loki_parse('a * b', scope=routine)
print(f"{parsed} | {type(parsed)}")
parsed = loki_parse('a / b', scope=routine)
print(f"{parsed} | {type(parsed)}")
parsed = loki_parse('a ** b', scope=routine)
print(f"{parsed} | {type(parsed)}")
parsed = loki_parse('a:b', scope=routine)
print(f"{parsed} | {type(parsed)}")
parsed = loki_parse('a>b', scope=routine)
print(f"{parsed} | {type(parsed)}")
parsed = loki_parse('a.gt.b', scope=routine)
print(f"{parsed} | {type(parsed)}")

parsed = loki_parse('arr(i1, i2, i3)')
print(f"{parsed} | {type(parsed)}") #  | shape: {parsed.shape} | dimensions: {parsed.dimensions}")
parsed = loki_parse('arr(i1, i2, i3)', scope=routine)
print(f"{parsed} | {type(parsed)} | shape: {parsed.shape} | dimensions: {parsed.dimensions}")

parsed = loki_parse('a')
print(f"{parsed} | {type(parsed)} | scope: {parsed.scope} | type: {parsed.type}")
parsed = loki_parse('a', scope=routine)
print(f"{parsed} | {type(parsed)} | scope: {parsed.scope} | type: {parsed.type}")
parsed = loki_parse('3.1415')
print(f"{parsed} | {type(parsed)}")

parsed = loki_parse('MODULO(A, B)')
print(f"{parsed} | {type(parsed)}")
# print("")
parsed = loki_parse(convert_to_case('a + b', mode=case))
assert isinstance(parsed, symbols.Sum)
assert all(isinstance(_parsed, symbols.DeferredTypeSymbol) for _parsed in parsed.children)
# print(f"{parsed} | type: {type(parsed)} | children type(s): {[type(child) for child in parsed.children]}")

parsed = loki_parse(convert_to_case('a + b', mode=case), scope=routine)
assert isinstance(parsed, symbols.Sum)
assert all(isinstance(_parsed, symbols.Scalar) for _parsed in parsed.children)
assert all(_parsed.scope == routine for _parsed in parsed.children)
# print(f"{parsed} | type: {type(parsed)} | children type(s): {[type(child) for child in parsed.children]}")

parsed = loki_parse(convert_to_case('a + b + 2 + 10', mode=case), scope=routine)
assert isinstance(parsed, symbols.Sum)
assert all(isinstance(_parsed, (symbols.Scalar, symbols.IntLiteral)) for _parsed in parsed.children)
# print(f"{parsed} | {type(parsed)} | {[type(child) for child in parsed.children]}")

parsed = loki_parse(convert_to_case('a - b', mode=case), scope=routine)
assert isinstance(parsed, symbols.Sum)
# assert all(isinstance(_parsed, symbols.Scalar) for _parsed in parsed.children)
assert isinstance(parsed.children[0], symbols.Scalar)
assert isinstance(parsed.children[1], symbols.Product)
assert isinstance(parsed.children[1].children[0], symbols.IntLiteral)
assert isinstance(parsed.children[1].children[1], symbols.Scalar)
# print(f"{parsed} | {type(parsed)}")

parsed = loki_parse(convert_to_case('a * b', mode=case), scope=routine)
assert isinstance(parsed, symbols.Product)
assert all(isinstance(_parsed, symbols.Scalar) for _parsed in parsed.children)
assert all(_parsed.scope == routine for _parsed in parsed.children)
# print(f"{parsed} | {type(parsed)}")

parsed = loki_parse(convert_to_case('a / b', mode=case), scope=routine)
assert isinstance(parsed, symbols.Quotient)
assert all(isinstance(_parsed, symbols.Scalar) for _parsed in [parsed.numerator, parsed.denominator])
assert all(_parsed.scope == routine for _parsed in [parsed.numerator, parsed.denominator])
# print(f"{parsed} | {type(parsed)}")

parsed = loki_parse(convert_to_case('a ** b', mode=case), scope=routine)
assert isinstance(parsed, symbols.Power)
assert all(isinstance(_parsed, symbols.Scalar) for _parsed in [parsed.base, parsed.exponent])
assert all(_parsed.scope == routine for _parsed in [parsed.base, parsed.exponent])
# print(f"{parsed} | {type(parsed)}")

parsed = loki_parse(convert_to_case('a:b', mode=case), scope=routine)
assert isinstance(parsed, symbols.RangeIndex)
assert all(isinstance(_parsed, symbols.Scalar) for _parsed in [parsed.lower, parsed.upper])
assert all(_parsed.scope == routine for _parsed in [parsed.lower, parsed.upper])
# print(f"{parsed} | {type(parsed)}")

parsed = loki_parse(convert_to_case('a:b:5', mode=case), scope=routine)
assert isinstance(parsed, symbols.RangeIndex)
assert all(isinstance(_parsed, (symbols.Scalar, symbols.IntLiteral))
for _parsed in [parsed.lower, parsed.upper, parsed.step])
# print(f"{parsed} | {type(parsed)}")

parsed = loki_parse(convert_to_case('a == b', mode=case), scope=routine)
assert parsed.operator == '=='
assert isinstance(parsed, symbols.Comparison)
assert all(isinstance(_parsed, symbols.Scalar) for _parsed in [parsed.left, parsed.right])
assert all(_parsed.scope == routine for _parsed in [parsed.left, parsed.right])
# print(f"{parsed} | {type(parsed)}")
parsed = loki_parse(convert_to_case('a.eq.b', mode=case), scope=routine)
assert parsed.operator == '=='
assert isinstance(parsed, symbols.Comparison)
assert all(isinstance(_parsed, symbols.Scalar) for _parsed in [parsed.left, parsed.right])
assert all(_parsed.scope == routine for _parsed in [parsed.left, parsed.right])
# print(f"{parsed} | {type(parsed)}")

parsed = loki_parse(convert_to_case('a!=b', mode=case), scope=routine)
assert parsed.operator == '!='
assert isinstance(parsed, symbols.Comparison)
assert all(isinstance(_parsed, symbols.Scalar) for _parsed in [parsed.left, parsed.right])
assert all(_parsed.scope == routine for _parsed in [parsed.left, parsed.right])
# print(f"{parsed} | {type(parsed)}")
parsed = loki_parse(convert_to_case('a.ne.b', mode=case), scope=routine)
assert parsed.operator == '!='
assert isinstance(parsed, symbols.Comparison)
assert all(isinstance(_parsed, symbols.Scalar) for _parsed in [parsed.left, parsed.right])
assert all(_parsed.scope == routine for _parsed in [parsed.left, parsed.right])
# print(f"{parsed} | {type(parsed)}")

parsed = loki_parse(convert_to_case('a>b', mode=case), scope=routine)
assert parsed.operator == '>'
assert isinstance(parsed, symbols.Comparison)
assert all(isinstance(_parsed, symbols.Scalar) for _parsed in [parsed.left, parsed.right])
assert all(_parsed.scope == routine for _parsed in [parsed.left, parsed.right])
# print(f"{parsed} | {type(parsed)}")
parsed = loki_parse(convert_to_case('a.gt.b', mode=case), scope=routine)
assert parsed.operator == '>'
assert isinstance(parsed, symbols.Comparison)
assert all(isinstance(_parsed, symbols.Scalar) for _parsed in [parsed.left, parsed.right])
assert all(_parsed.scope == routine for _parsed in [parsed.left, parsed.right])
# print(f"{parsed} | {type(parsed)}")

parsed = loki_parse(convert_to_case('a>=b', mode=case), scope=routine)
assert parsed.operator == '>='
assert isinstance(parsed, symbols.Comparison)
assert all(isinstance(_parsed, symbols.Scalar) for _parsed in [parsed.left, parsed.right])
assert all(_parsed.scope == routine for _parsed in [parsed.left, parsed.right])
# print(f"{parsed} | {type(parsed)}")
parsed = loki_parse(convert_to_case('a.ge.b', mode=case), scope=routine)
assert parsed.operator == '>='
assert isinstance(parsed, symbols.Comparison)
assert all(isinstance(_parsed, symbols.Scalar) for _parsed in [parsed.left, parsed.right])
assert all(_parsed.scope == routine for _parsed in [parsed.left, parsed.right])
# print(f"{parsed} | {type(parsed)}")

parsed = loki_parse(convert_to_case('a<b', mode=case), scope=routine)
assert parsed.operator == '<'
assert isinstance(parsed, symbols.Comparison)
assert all(isinstance(_parsed, symbols.Scalar) for _parsed in [parsed.left, parsed.right])
assert all(_parsed.scope == routine for _parsed in [parsed.left, parsed.right])
# print(f"{parsed} | {type(parsed)}")
parsed = loki_parse(convert_to_case('a.lt.b', mode=case), scope=routine)
assert parsed.operator == '<'
assert isinstance(parsed, symbols.Comparison)
assert all(isinstance(_parsed, symbols.Scalar) for _parsed in [parsed.left, parsed.right])
assert all(_parsed.scope == routine for _parsed in [parsed.left, parsed.right])
# print(f"{parsed} | {type(parsed)}")

parsed = loki_parse(convert_to_case('a<=b', mode=case), scope=routine)
assert parsed.operator == '<='
assert isinstance(parsed, symbols.Comparison)
assert all(isinstance(_parsed, symbols.Scalar) for _parsed in [parsed.left, parsed.right])
# print(f"{parsed} | {type(parsed)}")
parsed = loki_parse(convert_to_case('a.le.b', mode=case), scope=routine)
assert parsed.operator == '<='
assert isinstance(parsed, symbols.Comparison)
assert all(isinstance(_parsed, symbols.Scalar) for _parsed in [parsed.left, parsed.right])
assert all(_parsed.scope == routine for _parsed in [parsed.left, parsed.right])
# print(f"{parsed} | {type(parsed)}")

parsed = loki_parse(convert_to_case('arr(i1, i2, i3)', mode=case))
assert isinstance(parsed, symbols.InlineCall)
assert all(isinstance(_parsed, symbols.DeferredTypeSymbol) for _parsed in parsed.parameters)
# assert all(isinstance(_parsed, symbols.Scalar) for _parsed in parsed.shape)
# print(f"{parsed} | {type(parsed)}") #  | shape: {parsed.shape} | dimensions: {parsed.dimensions}")
parsed = loki_parse(convert_to_case('arr(i1, i2, i3)', mode=case), scope=routine)
assert isinstance(parsed, symbols.Array)
assert all(isinstance(_parsed, symbols.Scalar) for _parsed in parsed.dimensions)
assert all(_parsed.scope == routine for _parsed in parsed.dimensions)
assert all(isinstance(_parsed, symbols.Scalar) for _parsed in parsed.shape)
assert all(_parsed.scope == routine for _parsed in parsed.shape)
# print(f"{parsed} | {type(parsed)} | shape: {parsed.shape} | dimensions: {parsed.dimensions}")

parsed = loki_parse(convert_to_case('a', mode=case))
assert isinstance(parsed, symbols.DeferredTypeSymbol)
# print(f"{parsed} | {type(parsed)} | scope: {parsed.scope} | type: {parsed.type}")
parsed = loki_parse(convert_to_case('a', mode=case), scope=routine)
assert isinstance(parsed, symbols.Scalar)
assert parsed.scope == routine
# print(f"{parsed} | {type(parsed)} | scope: {parsed.scope} | type: {parsed.type}")
parsed = loki_parse(convert_to_case('3.1415', mode=case))
assert isinstance(parsed, symbols.FloatLiteral)
# print(f"{parsed} | {type(parsed)}")

parsed = loki_parse(convert_to_case('MODULO(A, B)', mode=case), scope=routine)
assert isinstance(parsed, symbols.InlineCall)
assert all(isinstance(_parsed, symbols.Scalar) for _parsed in parsed.parameters)
assert all(_parsed.scope == routine for _parsed in parsed.parameters)
# print(f"{parsed} | {type(parsed)}")

# parsed as range index ...
# parsed = loki_parse('integer :: some_integer')
# print(f"{parsed} | {type(parsed)}")

parsed = loki_parse('a .and. b')
print(f"{parsed} | {type(parsed)} | children: {parsed.children}")
parsed = loki_parse('a .or. b')
print(f"{parsed} | {type(parsed)} | children: {parsed.children}")
parsed = loki_parse(convert_to_case('a .and. b', mode=case))
assert isinstance(parsed, symbols.LogicalAnd)
assert all(isinstance(_parsed, symbols.DeferredTypeSymbol) for _parsed in parsed.children)
# print(f"{parsed} | {type(parsed)} | children: {parsed.children}")
parsed = loki_parse(convert_to_case('a .and. b', mode=case), scope=routine)
assert isinstance(parsed, symbols.LogicalAnd)
assert all(isinstance(_parsed, symbols.Scalar) for _parsed in parsed.children)
assert all(_parsed.scope == routine for _parsed in parsed.children)
# print(f"{parsed} | {type(parsed)} | children: {parsed.children}")
parsed = loki_parse(convert_to_case('a .or. b', mode=case))
assert isinstance(parsed, symbols.LogicalOr)
assert all(isinstance(_parsed, symbols.DeferredTypeSymbol) for _parsed in parsed.children)
# print(f"{parsed} | {type(parsed)} | children: {parsed.children}")
# parsed = loki_parse('.not. a')
parsed = loki_parse('a .or. .not. b')
print(f"{parsed} | {type(parsed)}") # | children: {parsed.child}")
print(f" {fgen(parsed)}")
parsed = loki_parse('a .OR. .NOT. b')
print(f"{parsed} | {type(parsed)}") # | children: {parsed.child}")
print(f" {fgen(parsed)}")

parsed = loki_parse('((a + b)/(a - b))**3 + 3.1415', scope=routine)
print(f"{parsed} | {type(parsed)}")
print(f" {fgen(parsed)}")
parsed = loki_parse(convert_to_case('a .or. .not. b', mode=case))
assert isinstance(parsed, symbols.LogicalOr)
assert isinstance(parsed.children[0], symbols.DeferredTypeSymbol)
assert isinstance(parsed.children[1], symbols.LogicalNot)
# print(f"{parsed} | {type(parsed)}") # | children: {parsed.child}")
# print(f" {fgen(parsed)}")

parsed = loki_parse(convert_to_case('((a + b)/(a - b))**3 + 3.1415', mode=case), scope=routine)
assert isinstance(parsed, symbols.Sum)
assert isinstance(parsed.children[0], symbols.Power)
assert isinstance(parsed.children[0].base, symbols.Quotient)
assert isinstance(parsed.children[0].base.numerator, symbols.Sum)
assert isinstance(parsed.children[0].base.denominator, symbols.Sum)
assert isinstance(parsed.children[1], symbols.FloatLiteral)
parsed_vars = FindVariables().visit(parsed)
assert parsed_vars == ('a', 'b', 'a', 'b')
assert all(parsed_var.scope == routine for parsed_var in parsed_vars)

# print(f"{parsed} | {type(parsed)}")
# print(f" {fgen(parsed)}")

0 comments on commit 013819d

Please sign in to comment.