Skip to content

Commit

Permalink
Merge pull request #448 from ecmwf-ifs/naml-fix-module-enrichment
Browse files Browse the repository at this point in the history
Module: Fix enrichment of type info via `Module` imports
  • Loading branch information
reuterbal authored Nov 25, 2024
2 parents bcef931 + 83222f4 commit b93a01d
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 39 deletions.
3 changes: 2 additions & 1 deletion loki/program_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,8 @@ def enrich(self, definitions, recurse=False):
"""
definitions_map = CaseInsensitiveDict((r.name, r) for r in as_tuple(definitions))

for imprt in self.imports:
# Enrich type info from all known imports (including parent scopes)
for imprt in self.all_imports:
if not (module := definitions_map.get(imprt.module)):
# Skip modules that are not available in the definitions list
continue
Expand Down
129 changes: 91 additions & 38 deletions loki/tests/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,16 @@

import pytest

from loki import (
Module, Subroutine, VariableDeclaration, TypeDef, fexprgen,
BasicType, Assignment, FindNodes, FindInlineCalls, FindTypedSymbols,
Transformer, fgen, SymbolAttributes, Variable, Import, Section, Intrinsic,
Scalar, DeferredTypeSymbol, FindVariables, SubstituteExpressions, Literal
)
from loki import Module, Subroutine, fexprgen, fgen
from loki.build import jit_compile, clean_test
from loki.expression import symbols as sym
from loki.frontend import available_frontends, OMNI
from loki.ir import (
nodes as ir, FindNodes, FindInlineCalls, FindTypedSymbols,
FindVariables, SubstituteExpressions, Transformer
)
from loki.sourcefile import Sourcefile
from loki.types import BasicType, DerivedType, SymbolAttributes


@pytest.mark.parametrize('frontend', available_frontends())
Expand All @@ -40,8 +41,8 @@ def test_module_from_source(frontend, tmp_path):
end module a_module
""".strip()
module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
assert len([o for o in module.spec.body if isinstance(o, VariableDeclaration)]) == 2
assert len([o for o in module.spec.body if isinstance(o, TypeDef)]) == 1
assert len([o for o in module.spec.body if isinstance(o, ir.VariableDeclaration)]) == 2
assert len([o for o in module.spec.body if isinstance(o, ir.TypeDef)]) == 1
assert 'derived_type' in module.typedef_map
assert len(module.routines) == 1
assert module.routines[0].name == 'my_routine'
Expand Down Expand Up @@ -100,7 +101,7 @@ def test_module_external_typedefs_subroutine(frontend, tmp_path):
assert fexprgen(a.shape) == exptected_array_shape

# Check the LHS of the assignment has correct meta-data
stmt = FindNodes(Assignment).visit(routine.body)[0]
stmt = FindNodes(ir.Assignment).visit(routine.body)[0]
pt_ext_arr = stmt.lhs
assert pt_ext_arr.type.dtype == BasicType.REAL
assert fexprgen(pt_ext_arr.shape) == exptected_array_shape
Expand Down Expand Up @@ -177,14 +178,14 @@ def test_module_external_typedefs_type(frontend, tmp_path):

# Verify correct attachment of type information
assert 'ext_type' in module.symbol_attrs
assert isinstance(module.symbol_attrs['ext_type'].dtype.typedef, TypeDef)
assert isinstance(nested.symbol_attrs['ext'].dtype.typedef, TypeDef)
assert isinstance(module['my_routine'].symbol_attrs['pt'].dtype.typedef, TypeDef)
assert isinstance(module['my_routine'].symbol_attrs['pt%ext'].dtype.typedef, TypeDef)
assert isinstance(module.symbol_attrs['ext_type'].dtype.typedef, ir.TypeDef)
assert isinstance(nested.symbol_attrs['ext'].dtype.typedef, ir.TypeDef)
assert isinstance(module['my_routine'].symbol_attrs['pt'].dtype.typedef, ir.TypeDef)
assert isinstance(module['my_routine'].symbol_attrs['pt%ext'].dtype.typedef, ir.TypeDef)
assert 'other_type' in module.symbol_attrs
assert 'other_type' not in module['other_routine'].symbol_attrs
assert isinstance(module.symbol_attrs['other_type'].dtype.typedef, TypeDef)
assert isinstance(module['other_routine'].symbol_attrs['pt'].dtype.typedef, TypeDef)
assert isinstance(module.symbol_attrs['other_type'].dtype.typedef, ir.TypeDef)
assert isinstance(module['other_routine'].symbol_attrs['pt'].dtype.typedef, ir.TypeDef)

# OMNI resolves explicit shape parameters in the frontend parser
exptected_array_shape = '(2, 3)' if frontend == OMNI else '(x, y)'
Expand All @@ -206,7 +207,7 @@ def test_module_external_typedefs_type(frontend, tmp_path):
assert fexprgen(pt_ext_a.shape) == exptected_array_shape

# Check the LHS of the assignment has correct meta-data
stmt = FindNodes(Assignment).visit(routine.body)[0]
stmt = FindNodes(ir.Assignment).visit(routine.body)[0]
pt_ext_arr = stmt.lhs
assert pt_ext_arr.type.dtype == BasicType.REAL
assert fexprgen(pt_ext_arr.shape) == exptected_array_shape
Expand Down Expand Up @@ -412,9 +413,9 @@ def test_module_variables_add_remove(frontend, tmp_path):
x = module.variable_map['x'] # That's the symbol for variable 'x'
real_type = SymbolAttributes('real', kind=module.variable_map['jprb'])
int_type = SymbolAttributes('integer')
a = Variable(name='a', type=real_type, scope=module)
b = Variable(name='b', dimensions=(x, ), type=real_type, scope=module)
c = Variable(name='c', type=int_type, scope=module)
a = sym.Variable(name='a', type=real_type, scope=module)
b = sym.Variable(name='b', dimensions=(x, ), type=real_type, scope=module)
c = sym.Variable(name='c', type=int_type, scope=module)

# Add new variables and check that they are all in the module spec
module.variables += (a, b, c)
Expand Down Expand Up @@ -554,22 +555,22 @@ def test_module_deep_clone(frontend, tmp_path):
new_module = module.clone()

n = [v for v in FindVariables().visit(new_module.spec) if v.name == 'n'][0]
n_decl = FindNodes(VariableDeclaration).visit(new_module.spec)[0]
n_decl = FindNodes(ir.VariableDeclaration).visit(new_module.spec)[0]

# Remove the declaration of `n` and replace it with `3`
new_module.spec = Transformer({n_decl: None}).visit(new_module.spec)
new_module.spec = SubstituteExpressions({n: Literal(3)}).visit(new_module.spec)
new_module.spec = SubstituteExpressions({n: sym.Literal(3)}).visit(new_module.spec)

# Check the new module has been changed
assert len(FindNodes(VariableDeclaration).visit(new_module.spec)) == 1
new_type_decls = FindNodes(VariableDeclaration).visit(new_module['my_type'].body)
assert len(FindNodes(ir.VariableDeclaration).visit(new_module.spec)) == 1
new_type_decls = FindNodes(ir.VariableDeclaration).visit(new_module['my_type'].body)
assert len(new_type_decls) == 2
assert new_type_decls[0].symbols[0] == 'vector(3)'
assert new_type_decls[1].symbols[0] == 'matrix(3, 3)'

# Check the old one has not changed
assert len(FindNodes(VariableDeclaration).visit(module.spec)) == 2
type_decls = FindNodes(VariableDeclaration).visit(module['my_type'].body)
assert len(FindNodes(ir.VariableDeclaration).visit(module.spec)) == 2
type_decls = FindNodes(ir.VariableDeclaration).visit(module['my_type'].body)
assert len(type_decls) == 2
assert type_decls[0].symbols[0] == 'vector(n)'
assert type_decls[1].symbols[0] == 'matrix(n, n)'
Expand Down Expand Up @@ -831,7 +832,7 @@ def test_module_rename_imports_with_definitions(frontend, tmp_path):
assert mod3.symbol_attrs[s].compare(mod2.symbol_attrs[use_name or s], ignore=('imported', 'module', 'use_name'))

# Verify Import IR node
for imprt in FindNodes(Import).visit(mod3.spec):
for imprt in FindNodes(ir.Import).visit(mod3.spec):
if imprt.module == 'test_rename_mod':
assert imprt.rename_list
assert not imprt.symbols
Expand Down Expand Up @@ -915,7 +916,7 @@ def test_module_rename_imports_no_definitions(frontend, tmp_path):
assert mod3.symbol_attrs[s].use_name == use_name

# Verify Import IR node
for imprt in FindNodes(Import).visit(mod3.spec):
for imprt in FindNodes(ir.Import).visit(mod3.spec):
if imprt.module == 'test_rename_mod':
assert imprt.rename_list
assert not imprt.symbols
Expand Down Expand Up @@ -969,7 +970,7 @@ def test_module_use_module_nature(frontend, tmp_path):

# Check properties on the Import IR node in the external module
assert ext_mod.imported_symbols == ('int16',)
imprt = FindNodes(Import).visit(ext_mod.spec)[0]
imprt = FindNodes(ir.Import).visit(ext_mod.spec)[0]
assert imprt.nature.lower() == 'intrinsic'
assert imprt.module.lower() == 'iso_c_binding'
assert ext_mod.imported_symbol_map['int16'].type.imported is True
Expand All @@ -988,8 +989,8 @@ def test_module_use_module_nature(frontend, tmp_path):
assert set(my_kinds.imported_symbols) == {'int8', 'int16'}
assert set(kinds.imported_symbols) == {'int8', 'int16'}

my_import_map = {s.name: imprt for imprt in FindNodes(Import).visit(my_kinds.spec) for s in imprt.symbols}
import_map = {s.name: imprt for imprt in FindNodes(Import).visit(kinds.spec) for s in imprt.symbols}
my_import_map = {s.name: imprt for imprt in FindNodes(ir.Import).visit(my_kinds.spec) for s in imprt.symbols}
import_map = {s.name: imprt for imprt in FindNodes(ir.Import).visit(kinds.spec) for s in imprt.symbols}

assert my_import_map['int8'] is my_import_map['int16']
assert import_map['int8'] is import_map['int16']
Expand Down Expand Up @@ -1194,13 +1195,13 @@ def test_module_contains_auto_insert(frontend, tmp_path):
assert routine1.contains is None

routine1 = routine1.clone(contains=routine2)
assert isinstance(routine1.contains, Section)
assert isinstance(routine1.contains.body[0], Intrinsic)
assert isinstance(routine1.contains, ir.Section)
assert isinstance(routine1.contains.body[0], ir.Intrinsic)
assert routine1.contains.body[0].text == 'CONTAINS'

module = module.clone(contains=routine1)
assert isinstance(module.contains, Section)
assert isinstance(module.contains.body[0], Intrinsic)
assert isinstance(module.contains, ir.Section)
assert isinstance(module.contains.body[0], ir.Intrinsic)
assert module.contains.body[0].text == 'CONTAINS'


Expand Down Expand Up @@ -1243,14 +1244,14 @@ def test_module_missing_imported_symbol(frontend, only_list, complete_tree, tmp_
b = driver.symbol_map['b']

if complete_tree:
assert isinstance(a, Scalar)
assert isinstance(a, sym.Scalar)
assert a.type.dtype is BasicType.INTEGER
assert isinstance(b, Scalar)
assert isinstance(b, sym.Scalar)
assert b.type.dtype is BasicType.INTEGER
else:
assert isinstance(a, DeferredTypeSymbol)
assert isinstance(a, sym.DeferredTypeSymbol)
assert a.type.dtype is BasicType.DEFERRED
assert isinstance(b, DeferredTypeSymbol)
assert isinstance(b, sym.DeferredTypeSymbol)
assert b.type.dtype is BasicType.DEFERRED

assert a.type.imported
Expand Down Expand Up @@ -1371,3 +1372,55 @@ def test_module_enrichment_within_file(frontend, tmp_path):
assert calls[0].arguments[0].type.parameter
assert calls[0].arguments[0].type.initial == 16
assert calls[0].arguments[0].type.module is source['foo']


@pytest.mark.parametrize('frontend', available_frontends())
def test_module_enrichment_typdefs(frontend, tmp_path):
""" Test that module-level enrihcment is propagated correctly """

fcode_state_mod = """
module state_type_mod
implicit none
type state_type
real, pointer, dimension(:,:) :: a
end type state_type
end module state_type_mod
"""

fcode_driver_mod = """
module driver_mod
use state_type_mod, only: state_type
implicit none
contains
subroutine driver_routine(state)
type(state_type), intent(inout) :: state
state%a = 1
end subroutine driver_routine
end module driver_mod
"""
state_mod = Sourcefile.from_source(fcode_state_mod, frontend=frontend, xmods=[tmp_path])['state_type_mod']
driver_mod = Sourcefile.from_source(fcode_driver_mod, frontend=frontend, xmods=[tmp_path])['driver_mod']
driver = driver_mod['driver_routine']

state = driver.variable_map['state']
assert isinstance(state.type.dtype, DerivedType)
assert state.type.dtype.typedef == BasicType.DEFERRED

# Enrich typedef on the outer module Import
driver_mod.enrich([state_mod], recurse=True)

state = driver.variable_map['state']

# Ensure type info has been propagated to inner subroutine
assert isinstance(state.type.dtype, DerivedType)
assert isinstance(state.type.dtype.typedef, ir.TypeDef)

assigns = FindNodes(ir.Assignment).visit(driver.body)
assert len(assigns) == 1
assert assigns[0].lhs.type.dtype == BasicType.REAL
assert assigns[0].lhs.type.shape == (':', ':')

0 comments on commit b93a01d

Please sign in to comment.