Skip to content

Commit

Permalink
#2716 WIP exploring options
Browse files Browse the repository at this point in the history
  • Loading branch information
arporter committed Sep 27, 2024
1 parent dd29ca3 commit 4383c7c
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -145,14 +145,17 @@ def validate(self, node, options=None):

# Check that the PSyIR of the routine/kernel can be retrieved.
try:
_, kernel_schedule = (
_, kernels, _ = (
KernelModuleInlineTrans._get_psyir_to_inline(node))
except Exception as error:
raise TransformationError(
f"{self.name} failed to retrieve PSyIR for {kern_or_call} "
f"'{kname}' due to: {error}"
) from error

# TODO ARPDBG - need to examine every kernel implementation, not just
# the first one.
kernel_schedule = kernels[0]
# We do not support kernels that use symbols representing data
# declared in their own parent module (we would need to new imports
# from this module for those, and we don't do this yet).
Expand Down Expand Up @@ -218,7 +221,7 @@ def validate(self, node, options=None):
f"subroutines is not supported yet.")

@staticmethod
def _prepare_code_to_inline(code_to_inline):
def _prepare_code_to_inline(routines_to_inline, symbol):
'''Prepare the PSyIR tree to inline by bringing in to the subroutine
all referenced symbols so that the implementation is self contained.
Expand All @@ -234,65 +237,66 @@ def _prepare_code_to_inline(code_to_inline):
'''
# pylint: disable=too-many-branches
source_container = code_to_inline.ancestor(Container)

# First make a set with all symbols used inside the subroutine
all_symbols = set()
for scope in code_to_inline.walk(ScopingNode):
for symbol in scope.symbol_table.symbols:
all_symbols.add(symbol)
for reference in code_to_inline.walk(Reference):
all_symbols.add(reference.symbol)
for literal in code_to_inline.walk(Literal):
# Literals may reference symbols in their precision
if isinstance(literal.datatype.precision, Symbol):
all_symbols.add(literal.datatype.precision)
for caller in code_to_inline.walk(Call):
all_symbols.add(caller.routine.symbol)
for cblock in code_to_inline.walk(CodeBlock):
for name in cblock.get_symbol_names():
all_symbols.add(cblock.scope.symbol_table.lookup(name))

# Then decide which symbols need to be brought inside the subroutine
symbols_to_bring_in = set()
for symbol in all_symbols:
if symbol.is_unresolved or symbol.is_import:
# This symbol is already in the symbol table, but adding it
# to the 'symbols_to_bring_in' will make the next step bring
# into the subroutine all modules that it could come from.
symbols_to_bring_in.add(symbol)
if isinstance(symbol, DataSymbol):
# DataTypes can reference other symbols
if isinstance(symbol.datatype, DataTypeSymbol):
symbols_to_bring_in.add(symbol.datatype)
elif hasattr(symbol.datatype, 'precision'):
if isinstance(symbol.datatype.precision, Symbol):
symbols_to_bring_in.add(symbol.datatype.precision)

# Bring the selected symbols inside the subroutine
for symbol in symbols_to_bring_in:
if symbol.name not in code_to_inline.symbol_table:
code_to_inline.symbol_table.add(symbol)
# And when necessary the modules where they come from
if symbol.is_unresolved:
# We don't know where this comes from, we need to bring
# in all top-level imports with wildcard imports
for mod in source_container.symbol_table.containersymbols:
if mod.wildcard_import:
if mod.name not in code_to_inline.symbol_table:
code_to_inline.symbol_table.add(mod)
else:
code_to_inline.symbol_table.lookup(mod.name).\
wildcard_import = True
elif symbol.is_import:
module_symbol = symbol.interface.container_symbol
if module_symbol.name not in code_to_inline.symbol_table:
code_to_inline.symbol_table.add(module_symbol)
else:
# If it already exists, we know its a container (from the
# validation) so we just need to point to it
symbol.interface.container_symbol = \
code_to_inline.symbol_table.lookup(module_symbol.name)
source_container = routines_to_inline[0].ancestor(Container)

for code_to_inline in routines_to_inline:
# First make a set with all symbols used inside the subroutine
all_symbols = set()
for scope in code_to_inline.walk(ScopingNode):
for symbol in scope.symbol_table.symbols:
all_symbols.add(symbol)
for reference in code_to_inline.walk(Reference):
all_symbols.add(reference.symbol)
for literal in code_to_inline.walk(Literal):
# Literals may reference symbols in their precision
if isinstance(literal.datatype.precision, Symbol):
all_symbols.add(literal.datatype.precision)
for caller in code_to_inline.walk(Call):
all_symbols.add(caller.routine.symbol)
for cblock in code_to_inline.walk(CodeBlock):
for name in cblock.get_symbol_names():
all_symbols.add(cblock.scope.symbol_table.lookup(name))

# Then decide which symbols need to be brought inside the subroutine
symbols_to_bring_in = set()
for symbol in all_symbols:
if symbol.is_unresolved or symbol.is_import:
# This symbol is already in the symbol table, but adding it
# to the 'symbols_to_bring_in' will make the next step bring
# into the subroutine all modules that it could come from.
symbols_to_bring_in.add(symbol)
if isinstance(symbol, DataSymbol):
# DataTypes can reference other symbols
if isinstance(symbol.datatype, DataTypeSymbol):
symbols_to_bring_in.add(symbol.datatype)
elif hasattr(symbol.datatype, 'precision'):
if isinstance(symbol.datatype.precision, Symbol):
symbols_to_bring_in.add(symbol.datatype.precision)

# Bring the selected symbols inside the subroutine
for symbol in symbols_to_bring_in:
if symbol.name not in code_to_inline.symbol_table:
code_to_inline.symbol_table.add(symbol)
# And when necessary the modules where they come from
if symbol.is_unresolved:
# We don't know where this comes from, we need to bring
# in all top-level imports with wildcard imports
for mod in source_container.symbol_table.containersymbols:
if mod.wildcard_import:
if mod.name not in code_to_inline.symbol_table:
code_to_inline.symbol_table.add(mod)
else:
code_to_inline.symbol_table.lookup(mod.name).\
wildcard_import = True
elif symbol.is_import:
module_symbol = symbol.interface.container_symbol
if module_symbol.name not in code_to_inline.symbol_table:
code_to_inline.symbol_table.add(module_symbol)
else:
# If it already exists, we know its a container (from the
# validation) so we just need to point to it
symbol.interface.container_symbol = \
code_to_inline.symbol_table.lookup(module_symbol.name)

@staticmethod
def _get_psyir_to_inline(node):
Expand All @@ -306,7 +310,8 @@ def _get_psyir_to_inline(node):
:returns: the name of the routine as seen by the caller and the
PSyIR of the routine implementation.
:rtype: Tuple(str, :py:class:`psyclone.psyir.nodes.Call`)
:rtype: Tuple(str, list[:py:class:`psyclone.psyir.nodes.Routine`],
:py:class:`psyclone.psyir.symbols.Symbol`)
:raises TransformationError: if we have a call to a language-level
Routine that maps to an Interface block as this is not yet
Expand All @@ -321,6 +326,7 @@ def _get_psyir_to_inline(node):
# interface matching the arguments in the call.
routines = [node.get_kernel_schedule()]
caller_name = node.name.lower()
interface_sym = None
else:
# We have a generic routine call.
routines = node.get_callees()
Expand All @@ -334,7 +340,7 @@ def _get_psyir_to_inline(node):
f"The target of the call to '{caller_name}' cannot be "
f"inserted because multiple implementations were found: "
f"{[rout.name for rout in routines]}. TODO #924")
return (caller_name, routines[0])
return (caller_name, routines, interface_sym)

def apply(self, node, options=None):
''' Bring the kernel subroutine into this Container.
Expand Down Expand Up @@ -364,16 +370,19 @@ def apply(self, node, options=None):
# may already be in use, but the equality check below guarantees
# that if it exists it is only valid when it references the exact same
# implementation.
caller_name, code_to_inline = (
caller_name, code_to_inline, interface_sym = (
KernelModuleInlineTrans._get_psyir_to_inline(node))
callee_name = code_to_inline.name
if interface_sym:
callee_name = interface_sym.name
else:
callee_name = code_to_inline[0].name

try:
existing_symbol = node.scope.symbol_table.lookup(callee_name)
except KeyError:
existing_symbol = None

self._prepare_code_to_inline(code_to_inline)
self._prepare_code_to_inline(code_to_inline, interface_sym)

container = node.ancestor(Container)
if not existing_symbol:
Expand All @@ -384,7 +393,8 @@ def apply(self, node, options=None):
visibility=Symbol.Visibility.PRIVATE)
container.symbol_table.add(routine_symbol)
# 2) Insert the relevant code into the tree.
container.addchild(code_to_inline.detach())
for routine in code_to_inline:
container.addchild(routine.detach())
else:
if existing_symbol.is_import:
# The RoutineSymbol is in the table but that is because it is
Expand All @@ -400,7 +410,8 @@ def apply(self, node, options=None):
existing_symbol.visibility = Symbol.Visibility.PRIVATE
if remove_csym:
ctable.remove(csym)
container.addchild(code_to_inline.detach())
for routine in code_to_inline:
container.addchild(routine.detach())
else:
# The routine symbol already exists, and we know from the
# validation that it's a Routine. Now check if they are
Expand Down
4 changes: 4 additions & 0 deletions src/psyclone/domain/lfric/lfric_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,10 @@ def _create_generic_scalars():
'''
GenericScalar = namedtuple('GenericScalar', ["name", "intrinsic",
"precision"])
api_config = Config.get().api_conf("lfric")

lfric_kinds = list(api_config.precision_map.keys())

generic_scalar_datatypes = [
GenericScalar("LFRicIntegerScalar", ScalarType.Intrinsic.INTEGER,
LFRicTypes("I_DEF")),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -743,11 +743,12 @@ def test_module_inline_with_interfaces(tmpdir):
kern_call = invoke.schedule.walk(CodedKern)[0]
inline_trans = KernelModuleInlineTrans()
inline_trans.apply(kern_call)
gen = str(psy.gen)
# Both the caller and the callee are in the file and use the specialized
# implementation name.
assert "CALL mixed_code_64(" in gen
assert "SUBROUTINE mixed_code_64(" in gen
gen = str(psy.gen).lower()
# Both the caller and the callee are in the file and use the interface
# name.
assert "call mixed_code(" in gen
assert "subroutine mixed_code_64(" in gen
assert "subroutine mixed_code_32(" in gen

# And it is valid code
assert LFRicBuild(tmpdir).code_compiles(psy)
Expand Down

0 comments on commit 4383c7c

Please sign in to comment.