Skip to content

Commit

Permalink
#1823 Imporve module-inline transformation and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sergisiso committed Oct 7, 2022
1 parent 4732036 commit ab3811f
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -161,66 +161,66 @@ def validate(self, node, options=None):
f"subroutines is not supported yet.")

@staticmethod
def _prepare_code_to_inline(node):
def _prepare_code_to_inline(code_to_inline):
''' Prepare the PSyIR tree to inline by brining in the subroutine all
referenced symbols so that the implementation is self contained.
:param node: the kernel to module-inline.
:type node: :py:class:`psyclone.psyGen.CodedKern`
:returns: a self contained version of the subroutine to inline.
:rtype: :py:class:`psyclone.psyir.nodes.Routine`
:raise InternalError: unexpected PSyIR.
:param code_to_inline: the subroutine to module-inline.
:type code_to_inline: :py:class:`psyclone.psyir.node.Routine`
'''
code_to_inline = node.get_kernel_schedule()
source_container = code_to_inline.ancestor(Container)
symbols_to_bring_in = set()

# Find all symbols that have to be brought inside the subroutine
# First make a set with all symbols used inside the subroutine
all_symbols = set()
for scope in code_to_inline.walk(ScopingNode):
for symbol in scope.symbol_table.symbols:
if symbol.is_unresolved:
# We don't know where this comes from, we need to bring
# in all top-level imports
for mod in source_container.symbol_table.containersymbols:
symbols_to_bring_in.add(mod)
elif symbol.is_import:
# Add to symbols_to_bring_in
symbols_to_bring_in.add(symbol)
if isinstance(symbol, DataSymbol):
# DataTypes can reference other symbols
if isinstance(symbol.datatype, DataTypeSymbol):
symbols_to_bring_in.add(symbol.datatype)
elif hasattr(symbol.datatype, 'precision'):
if isinstance(symbol.datatype.precision, Symbol):
symbols_to_bring_in.add(symbol.datatype.precision)

# Literals can also reference symbols in they precision
all_symbols.add(symbol)
for reference in code_to_inline.walk(Reference):
all_symbols.add(reference.symbol)
for literal in code_to_inline.walk(Literal):
# Literals may reference symbols in they precision
if isinstance(literal.datatype.precision, Symbol):
symbols_to_bring_in.add(literal.datatype.precision)

# Calls also refer to symbols
all_symbols.add(literal.datatype.precision)
for caller in code_to_inline.walk(Call):
# TODO #1366: We still need a solution for intrinsics that
# currently are parsed into Calls/RoutineSymbols, for the
# moment here we skip the ones causing issues.
if caller.routine.name not in ("random_number", ):
symbols_to_bring_in.add(caller.routine)
all_symbols.add(caller.routine)

# Then decide which symbols need to be brought inside the subroutine
symbols_to_bring_in = set()
for symbol in all_symbols:
if symbol.is_unresolved:
# We don't know where this comes from, we need to bring
# in all top-level imports
for mod in source_container.symbol_table.containersymbols:
symbols_to_bring_in.add(mod)
elif symbol.is_import:
# Add to symbols_to_bring_in
symbols_to_bring_in.add(symbol)
if isinstance(symbol, DataSymbol):
# DataTypes can reference other symbols
if isinstance(symbol.datatype, DataTypeSymbol):
symbols_to_bring_in.add(symbol.datatype)
elif hasattr(symbol.datatype, 'precision'):
if isinstance(symbol.datatype.precision, Symbol):
symbols_to_bring_in.add(symbol.datatype.precision)

# Bring the selected symbols inside the subroutine
for symbol in symbols_to_bring_in:
if symbol.name not in code_to_inline.symbol_table:
code_to_inline.symbol_table.add(symbol)
# And when necessary the modules where they come from
if symbol.is_import:
module_symbol = symbol.interface.container_symbol
if module_symbol.name not in code_to_inline.symbol_table:
code_to_inline.symbol_table.add(module_symbol)

return code_to_inline
# And when necessary the modules where they come from
if symbol.is_import:
module_symbol = symbol.interface.container_symbol
if module_symbol.name not in code_to_inline.symbol_table:
code_to_inline.symbol_table.add(module_symbol)
elif symbol.is_unresolved:
for mod in source_container.symbol_table.containersymbols:
if mod.name not in code_to_inline.symbol_table:
code_to_inline.symbol_table.add(mod)

def apply(self, node, options=None):
''' Bring the kernel subroutine in this Container.
Expand All @@ -242,7 +242,8 @@ def apply(self, node, options=None):
except KeyError:
existing_symbol = None

code_to_inline = self._prepare_code_to_inline(node)
code_to_inline = node.get_kernel_schedule()
self._prepare_code_to_inline(code_to_inline)

if not existing_symbol:
# If it doesn't exist already, module-inline the subroutine by:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,95 @@ def test_module_inline_apply_same_kernel(tmpdir):
assert GOceanBuild(tmpdir).code_compiles(psy)


def test_module_inline_apply_bring_in_non_local_symbols(
fortran_reader, fortran_writer):
''' Test that when the inlined routine uses non local symbols, it brings
them inside the subroutine when feasible. '''

inline_trans = KernelModuleInlineTrans()

# Bring all imports when we can't guarantee where symbols come from
psyir = fortran_reader.psyir_from_source('''
module my_mod
use external_mod1
use external_mod2
implicit none
contains
subroutine code()
a = b + c
end subroutine code
end module my_mod
''')

routine = psyir.walk(Routine)[0]
inline_trans._prepare_code_to_inline(routine)
result = fortran_writer(routine)
assert "use external_mod1" in result
assert "use external_mod2" in result

# Also, if they are in datatype precision expressions
psyir = fortran_reader.psyir_from_source('''
module my_mod
use external_mod1, only: r_def
use external_mod2, only: my_user_type
use not_needed
implicit none
contains
subroutine code()
real(kind=r_def) :: a,b
type(my_user_type) :: x
a = b + x%data
end subroutine code
end module my_mod
''')

routine = psyir.walk(Routine)[0]
inline_trans._prepare_code_to_inline(routine)
result = fortran_writer(routine)
assert "use external_mod1, only : r_def" in result
assert "use external_mod2, only : my_user_type" in result
assert "use not_needed" not in result

# Also, if they are literal precision expressions
psyir = fortran_reader.psyir_from_source('''
module my_mod
use external_mod1, only: r_def
use not_needed
implicit none
contains
subroutine code()
real :: a,b
a = b + 1.0_r_def
end subroutine code
end module my_mod
''')

routine = psyir.walk(Routine)[0]
inline_trans._prepare_code_to_inline(routine)
result = fortran_writer(routine)
assert "use external_mod1, only : r_def" in result
assert "use not_needed" not in result

# Also, if they are routine names
psyir = fortran_reader.psyir_from_source('''
module my_mod
use external_mod1, only: my_sub
implicit none
contains
subroutine code()
real :: a
call random_number(a) !intrinsic
call my_sub(a)
end subroutine code
end module my_mod
''')

routine = psyir.walk(Routine)[0]
inline_trans._prepare_code_to_inline(routine)
result = fortran_writer(routine)
assert "use external_mod1, only : my_sub" in result


def test_module_inline_dynamo(monkeypatch, annexed, dist_mem):
'''Tests that correct results are obtained when a kernel is inlined
into the psy-layer in the dynamo0.3 API. All previous tests use GOcean
Expand Down

0 comments on commit ab3811f

Please sign in to comment.