Skip to content

Commit

Permalink
SingleColumn: Demote arrays that are not used at all in the body
Browse files Browse the repository at this point in the history
  • Loading branch information
mlange05 authored and reuterbal committed Mar 26, 2024
1 parent 20edf84 commit 77966f4
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
2 changes: 2 additions & 0 deletions transformations/tests/test_single_column_coalesced.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ def test_scc_demote_transformation(frontend, horizontal):
REAL :: t(nlon,nz)
REAL :: a(nlon)
REAL :: b(nlon,psize)
REAL :: unused(nlon)
INTEGER, PARAMETER :: psize = 3
INTEGER :: jl, jk
REAL :: c
Expand Down Expand Up @@ -280,6 +281,7 @@ def test_scc_demote_transformation(frontend, horizontal):
assert isinstance(kernel.variable_map['c'], Scalar)
assert isinstance(kernel.variable_map['t'], Array)
assert isinstance(kernel.variable_map['q'], Array)
assert isinstance(kernel.variable_map['unused'], Scalar)

# Ensure that parameter-sized array b got demoted only
assert kernel.variable_map['b'].shape == ((3,) if frontend is OMNI else ('psize',))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def _get_local_arrays(section):
return arrays

# Create a list of all local horizontal temporary arrays
candidates = _get_local_arrays(routine.body)
candidates = _get_local_arrays(routine.spec)

# Create an index into all variable uses per vector-level section
vars_per_section = {s: set(v.name.lower() for v in _get_local_arrays(s)) for s in sections}
Expand All @@ -343,8 +343,8 @@ def _get_local_arrays(section):
for arr in candidates:
counts[arr] = sum(1 if arr.name.lower() in v else 0 for v in vars_per_section.values())

# Mark temporaries that are only used in one section for demotion
to_demote = [k for k, v in counts.items() if v == 1]
# Demote temporaries that are only used in one section or not at all
to_demote = [k for k, v in counts.items() if v <= 1]

# Filter out variables that we will pass down the call tree
calls = FindNodes(ir.CallStatement).visit(routine.body)
Expand Down

0 comments on commit 77966f4

Please sign in to comment.