Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Module: Fix enrichment of type info via Module imports #448

Merged
merged 2 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion loki/program_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,8 @@ def enrich(self, definitions, recurse=False):
"""
definitions_map = CaseInsensitiveDict((r.name, r) for r in as_tuple(definitions))

for imprt in self.imports:
# Enrich type info from all known imports (including parent scopes)
for imprt in self.all_imports:
if not (module := definitions_map.get(imprt.module)):
# Skip modules that are not available in the definitions list
continue
Expand Down
129 changes: 91 additions & 38 deletions loki/tests/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,16 @@

import pytest

from loki import (
Module, Subroutine, VariableDeclaration, TypeDef, fexprgen,
BasicType, Assignment, FindNodes, FindInlineCalls, FindTypedSymbols,
Transformer, fgen, SymbolAttributes, Variable, Import, Section, Intrinsic,
Scalar, DeferredTypeSymbol, FindVariables, SubstituteExpressions, Literal
)
from loki import Module, Subroutine, fexprgen, fgen
from loki.build import jit_compile, clean_test
from loki.expression import symbols as sym
from loki.frontend import available_frontends, OMNI
from loki.ir import (
nodes as ir, FindNodes, FindInlineCalls, FindTypedSymbols,
FindVariables, SubstituteExpressions, Transformer
)
from loki.sourcefile import Sourcefile
from loki.types import BasicType, DerivedType, SymbolAttributes


@pytest.mark.parametrize('frontend', available_frontends())
Expand All @@ -40,8 +41,8 @@ def test_module_from_source(frontend, tmp_path):
end module a_module
""".strip()
module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
assert len([o for o in module.spec.body if isinstance(o, VariableDeclaration)]) == 2
assert len([o for o in module.spec.body if isinstance(o, TypeDef)]) == 1
assert len([o for o in module.spec.body if isinstance(o, ir.VariableDeclaration)]) == 2
assert len([o for o in module.spec.body if isinstance(o, ir.TypeDef)]) == 1
assert 'derived_type' in module.typedef_map
assert len(module.routines) == 1
assert module.routines[0].name == 'my_routine'
Expand Down Expand Up @@ -100,7 +101,7 @@ def test_module_external_typedefs_subroutine(frontend, tmp_path):
assert fexprgen(a.shape) == exptected_array_shape

# Check the LHS of the assignment has correct meta-data
stmt = FindNodes(Assignment).visit(routine.body)[0]
stmt = FindNodes(ir.Assignment).visit(routine.body)[0]
pt_ext_arr = stmt.lhs
assert pt_ext_arr.type.dtype == BasicType.REAL
assert fexprgen(pt_ext_arr.shape) == exptected_array_shape
Expand Down Expand Up @@ -177,14 +178,14 @@ def test_module_external_typedefs_type(frontend, tmp_path):

# Verify correct attachment of type information
assert 'ext_type' in module.symbol_attrs
assert isinstance(module.symbol_attrs['ext_type'].dtype.typedef, TypeDef)
assert isinstance(nested.symbol_attrs['ext'].dtype.typedef, TypeDef)
assert isinstance(module['my_routine'].symbol_attrs['pt'].dtype.typedef, TypeDef)
assert isinstance(module['my_routine'].symbol_attrs['pt%ext'].dtype.typedef, TypeDef)
assert isinstance(module.symbol_attrs['ext_type'].dtype.typedef, ir.TypeDef)
assert isinstance(nested.symbol_attrs['ext'].dtype.typedef, ir.TypeDef)
assert isinstance(module['my_routine'].symbol_attrs['pt'].dtype.typedef, ir.TypeDef)
assert isinstance(module['my_routine'].symbol_attrs['pt%ext'].dtype.typedef, ir.TypeDef)
assert 'other_type' in module.symbol_attrs
assert 'other_type' not in module['other_routine'].symbol_attrs
assert isinstance(module.symbol_attrs['other_type'].dtype.typedef, TypeDef)
assert isinstance(module['other_routine'].symbol_attrs['pt'].dtype.typedef, TypeDef)
assert isinstance(module.symbol_attrs['other_type'].dtype.typedef, ir.TypeDef)
assert isinstance(module['other_routine'].symbol_attrs['pt'].dtype.typedef, ir.TypeDef)

# OMNI resolves explicit shape parameters in the frontend parser
exptected_array_shape = '(2, 3)' if frontend == OMNI else '(x, y)'
Expand All @@ -206,7 +207,7 @@ def test_module_external_typedefs_type(frontend, tmp_path):
assert fexprgen(pt_ext_a.shape) == exptected_array_shape

# Check the LHS of the assignment has correct meta-data
stmt = FindNodes(Assignment).visit(routine.body)[0]
stmt = FindNodes(ir.Assignment).visit(routine.body)[0]
pt_ext_arr = stmt.lhs
assert pt_ext_arr.type.dtype == BasicType.REAL
assert fexprgen(pt_ext_arr.shape) == exptected_array_shape
Expand Down Expand Up @@ -412,9 +413,9 @@ def test_module_variables_add_remove(frontend, tmp_path):
x = module.variable_map['x'] # That's the symbol for variable 'x'
real_type = SymbolAttributes('real', kind=module.variable_map['jprb'])
int_type = SymbolAttributes('integer')
a = Variable(name='a', type=real_type, scope=module)
b = Variable(name='b', dimensions=(x, ), type=real_type, scope=module)
c = Variable(name='c', type=int_type, scope=module)
a = sym.Variable(name='a', type=real_type, scope=module)
b = sym.Variable(name='b', dimensions=(x, ), type=real_type, scope=module)
c = sym.Variable(name='c', type=int_type, scope=module)

# Add new variables and check that they are all in the module spec
module.variables += (a, b, c)
Expand Down Expand Up @@ -554,22 +555,22 @@ def test_module_deep_clone(frontend, tmp_path):
new_module = module.clone()

n = [v for v in FindVariables().visit(new_module.spec) if v.name == 'n'][0]
n_decl = FindNodes(VariableDeclaration).visit(new_module.spec)[0]
n_decl = FindNodes(ir.VariableDeclaration).visit(new_module.spec)[0]

# Remove the declaration of `n` and replace it with `3`
new_module.spec = Transformer({n_decl: None}).visit(new_module.spec)
new_module.spec = SubstituteExpressions({n: Literal(3)}).visit(new_module.spec)
new_module.spec = SubstituteExpressions({n: sym.Literal(3)}).visit(new_module.spec)

# Check the new module has been changed
assert len(FindNodes(VariableDeclaration).visit(new_module.spec)) == 1
new_type_decls = FindNodes(VariableDeclaration).visit(new_module['my_type'].body)
assert len(FindNodes(ir.VariableDeclaration).visit(new_module.spec)) == 1
new_type_decls = FindNodes(ir.VariableDeclaration).visit(new_module['my_type'].body)
assert len(new_type_decls) == 2
assert new_type_decls[0].symbols[0] == 'vector(3)'
assert new_type_decls[1].symbols[0] == 'matrix(3, 3)'

# Check the old one has not changed
assert len(FindNodes(VariableDeclaration).visit(module.spec)) == 2
type_decls = FindNodes(VariableDeclaration).visit(module['my_type'].body)
assert len(FindNodes(ir.VariableDeclaration).visit(module.spec)) == 2
type_decls = FindNodes(ir.VariableDeclaration).visit(module['my_type'].body)
assert len(type_decls) == 2
assert type_decls[0].symbols[0] == 'vector(n)'
assert type_decls[1].symbols[0] == 'matrix(n, n)'
Expand Down Expand Up @@ -831,7 +832,7 @@ def test_module_rename_imports_with_definitions(frontend, tmp_path):
assert mod3.symbol_attrs[s].compare(mod2.symbol_attrs[use_name or s], ignore=('imported', 'module', 'use_name'))

# Verify Import IR node
for imprt in FindNodes(Import).visit(mod3.spec):
for imprt in FindNodes(ir.Import).visit(mod3.spec):
if imprt.module == 'test_rename_mod':
assert imprt.rename_list
assert not imprt.symbols
Expand Down Expand Up @@ -915,7 +916,7 @@ def test_module_rename_imports_no_definitions(frontend, tmp_path):
assert mod3.symbol_attrs[s].use_name == use_name

# Verify Import IR node
for imprt in FindNodes(Import).visit(mod3.spec):
for imprt in FindNodes(ir.Import).visit(mod3.spec):
if imprt.module == 'test_rename_mod':
assert imprt.rename_list
assert not imprt.symbols
Expand Down Expand Up @@ -969,7 +970,7 @@ def test_module_use_module_nature(frontend, tmp_path):

# Check properties on the Import IR node in the external module
assert ext_mod.imported_symbols == ('int16',)
imprt = FindNodes(Import).visit(ext_mod.spec)[0]
imprt = FindNodes(ir.Import).visit(ext_mod.spec)[0]
assert imprt.nature.lower() == 'intrinsic'
assert imprt.module.lower() == 'iso_c_binding'
assert ext_mod.imported_symbol_map['int16'].type.imported is True
Expand All @@ -988,8 +989,8 @@ def test_module_use_module_nature(frontend, tmp_path):
assert set(my_kinds.imported_symbols) == {'int8', 'int16'}
assert set(kinds.imported_symbols) == {'int8', 'int16'}

my_import_map = {s.name: imprt for imprt in FindNodes(Import).visit(my_kinds.spec) for s in imprt.symbols}
import_map = {s.name: imprt for imprt in FindNodes(Import).visit(kinds.spec) for s in imprt.symbols}
my_import_map = {s.name: imprt for imprt in FindNodes(ir.Import).visit(my_kinds.spec) for s in imprt.symbols}
import_map = {s.name: imprt for imprt in FindNodes(ir.Import).visit(kinds.spec) for s in imprt.symbols}

assert my_import_map['int8'] is my_import_map['int16']
assert import_map['int8'] is import_map['int16']
Expand Down Expand Up @@ -1194,13 +1195,13 @@ def test_module_contains_auto_insert(frontend, tmp_path):
assert routine1.contains is None

routine1 = routine1.clone(contains=routine2)
assert isinstance(routine1.contains, Section)
assert isinstance(routine1.contains.body[0], Intrinsic)
assert isinstance(routine1.contains, ir.Section)
assert isinstance(routine1.contains.body[0], ir.Intrinsic)
assert routine1.contains.body[0].text == 'CONTAINS'

module = module.clone(contains=routine1)
assert isinstance(module.contains, Section)
assert isinstance(module.contains.body[0], Intrinsic)
assert isinstance(module.contains, ir.Section)
assert isinstance(module.contains.body[0], ir.Intrinsic)
assert module.contains.body[0].text == 'CONTAINS'


Expand Down Expand Up @@ -1243,14 +1244,14 @@ def test_module_missing_imported_symbol(frontend, only_list, complete_tree, tmp_
b = driver.symbol_map['b']

if complete_tree:
assert isinstance(a, Scalar)
assert isinstance(a, sym.Scalar)
assert a.type.dtype is BasicType.INTEGER
assert isinstance(b, Scalar)
assert isinstance(b, sym.Scalar)
assert b.type.dtype is BasicType.INTEGER
else:
assert isinstance(a, DeferredTypeSymbol)
assert isinstance(a, sym.DeferredTypeSymbol)
assert a.type.dtype is BasicType.DEFERRED
assert isinstance(b, DeferredTypeSymbol)
assert isinstance(b, sym.DeferredTypeSymbol)
assert b.type.dtype is BasicType.DEFERRED

assert a.type.imported
Expand Down Expand Up @@ -1371,3 +1372,55 @@ def test_module_enrichment_within_file(frontend, tmp_path):
assert calls[0].arguments[0].type.parameter
assert calls[0].arguments[0].type.initial == 16
assert calls[0].arguments[0].type.module is source['foo']


@pytest.mark.parametrize('frontend', available_frontends())
def test_module_enrichment_typdefs(frontend, tmp_path):
""" Test that module-level enrihcment is propagated correctly """

fcode_state_mod = """
module state_type_mod
implicit none

type state_type
real, pointer, dimension(:,:) :: a
end type state_type

end module state_type_mod
"""

fcode_driver_mod = """
module driver_mod
use state_type_mod, only: state_type
implicit none

contains
subroutine driver_routine(state)
type(state_type), intent(inout) :: state

state%a = 1

end subroutine driver_routine
end module driver_mod
"""
state_mod = Sourcefile.from_source(fcode_state_mod, frontend=frontend, xmods=[tmp_path])['state_type_mod']
driver_mod = Sourcefile.from_source(fcode_driver_mod, frontend=frontend, xmods=[tmp_path])['driver_mod']
driver = driver_mod['driver_routine']

state = driver.variable_map['state']
assert isinstance(state.type.dtype, DerivedType)
assert state.type.dtype.typedef == BasicType.DEFERRED

# Enrich typedef on the outer module Import
driver_mod.enrich([state_mod], recurse=True)

state = driver.variable_map['state']

# Ensure type info has been propagated to inner subroutine
assert isinstance(state.type.dtype, DerivedType)
assert isinstance(state.type.dtype.typedef, ir.TypeDef)

assigns = FindNodes(ir.Assignment).visit(driver.body)
assert len(assigns) == 1
assert assigns[0].lhs.type.dtype == BasicType.REAL
assert assigns[0].lhs.type.shape == (':', ':')
Loading