Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add loop regions to the frontend's capabilities #1475

Merged
merged 77 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from 76 commits
Commits
Show all changes
77 commits
Select commit Hold shift + click to select a range
b6e9000
Add loop regions to the frontend's capabilities
phschaad Dec 11, 2023
ced035b
Merge branch 'master' into loop_architecture_pt_3
phschaad Dec 11, 2023
0bb6223
Merge branch 'master' into loop_architecture_pt_3
phschaad Dec 12, 2023
d26c507
Bugfixes
phschaad Dec 12, 2023
b83d05d
Fix data dependent while loop generation
phschaad Dec 13, 2023
55acc41
Merge remote-tracking branch 'origin/master' into loop_architecture_pt_3
phschaad Dec 13, 2023
35e4272
Make state propagation test more robust to SDFG changes
phschaad Dec 13, 2023
6b4d1be
Fixes
phschaad Dec 13, 2023
83cf2d0
newast fixes
phschaad Dec 13, 2023
1393b19
Change property type and add better type hinting
phschaad Dec 14, 2023
864c607
Merge branch 'master' into loop_architecture_pt_3
phschaad Dec 20, 2023
b317a12
Allow orelse and break continue
phschaad Dec 20, 2023
240ff79
Fix free symbols for loops
phschaad Dec 20, 2023
016b21c
Merge branch 'master' into loop_architecture_pt_3
phschaad Jan 17, 2024
28177c7
Merge remote-tracking branch 'origin/master' into loop_architecture_pt_3
phschaad Jan 24, 2024
69eaf92
Provide pass compatibility check for passes and transformations
phschaad Jan 24, 2024
975d79e
Update passes and transformations
phschaad Jan 25, 2024
80e96bb
Make sure auto opt "works"
phschaad Jan 26, 2024
9b7f840
Refactor SDFG List to CFG List
phschaad Jan 29, 2024
bc8679f
Make sure no old style `sdfg_list` calls remain
phschaad Jan 29, 2024
68a6b62
Fix deserializataion for control flow regions
phschaad Jan 29, 2024
40cd861
Fix deserialization
phschaad Jan 29, 2024
482c30f
Remove legacy calls to sdfg_list
phschaad Jan 29, 2024
ae2e068
Merge branch 'refactor_sdfg_list_to_cfg_list' into loop_architecture_…
phschaad Jan 29, 2024
e596496
Fix transformation architecture
phschaad Jan 29, 2024
27a350e
Address review comments, update docs
phschaad Jan 29, 2024
5aeec96
Fix blunder
phschaad Jan 29, 2024
7e2aa45
Merge remote-tracking branch 'origin/refactor_sdfg_list_to_cfg_list' …
phschaad Jan 29, 2024
81b6972
Fix incorrect arg passing
phschaad Jan 29, 2024
4ca1fea
Fix control flow inlining
phschaad Jan 29, 2024
303c605
Fix control flow region traversal
phschaad Jan 29, 2024
c1ec438
Bugfixes
phschaad Jan 29, 2024
a966044
Fix test
phschaad Jan 30, 2024
01b3593
Fix missing reset of cfg list for inlining
phschaad Jan 30, 2024
8af34d5
Fix test
phschaad Jan 30, 2024
03e976c
Added loops to fortran frontend
phschaad Jan 30, 2024
2d3d77e
Ensure compatibility checks
phschaad Jan 30, 2024
a89c64d
Cleanup
phschaad Jan 30, 2024
9c06e06
Cleanup
phschaad Jan 30, 2024
a608779
Cleanup
phschaad Jan 30, 2024
926ad49
Fix codegen bug (for loops)
phschaad Jan 31, 2024
03f0d75
Fix SDFG references for complex loop condition tests
phschaad Jan 31, 2024
aad6c28
Make dreport file sorting based on version instead of state id
phschaad Jan 31, 2024
a04bf0b
Fix dinstr test
phschaad Jan 31, 2024
22b7456
Fix duplicate control flow block naming for while condition checks
phschaad Jan 31, 2024
abcd09e
Workflow debugging
phschaad Feb 2, 2024
3e779fd
pytest debugging
phschaad Feb 2, 2024
1c7a569
Fixes
phschaad Feb 5, 2024
24a3d8a
Revert two changes
phschaad Feb 6, 2024
ed3f36d
Merge branch 'master' into loop_architecture_pt_3
phschaad May 16, 2024
bb6159e
Merge addendum
phschaad May 16, 2024
7455eb2
More robustness to blocksafe wrapper
phschaad May 16, 2024
63df04b
Multistate inline fix
phschaad May 16, 2024
5869236
Cleanup
phschaad May 16, 2024
4bc7c69
Merge branch 'master' into loop_architecture_pt_3
phschaad May 29, 2024
6e77fd0
Temporarily disable tests that cause problems with CF detection
phschaad Jun 11, 2024
012e70a
Merge branch 'master' into loop_architecture_pt_3
phschaad Jun 11, 2024
be4ac47
Add tests and fixes
phschaad Jun 12, 2024
e783717
Update doc
phschaad Jun 12, 2024
8da2a33
Merge branch 'master' into loop_architecture_pt_3
phschaad Jun 14, 2024
2e9c16e
Update copyright year in newast.py
phschaad Jun 18, 2024
0238393
Address comments
phschaad Jun 18, 2024
31818f8
Address more comments
phschaad Jun 18, 2024
f81c156
Fix numpy version to < 2.0
phschaad Jun 18, 2024
12651b2
Merge branch 'numpy_version_fix' into loop_architecture_pt_3
phschaad Jun 18, 2024
67de006
Fix misplaced exception
phschaad Jun 18, 2024
4a24c9c
Address comments
phschaad Jun 19, 2024
be6fe26
Refactor
phschaad Jun 21, 2024
1b8c76d
Fix incompatible types with 3.7, again
phschaad Jun 21, 2024
75f4d64
Fixes
phschaad Jun 21, 2024
d45fd72
Added additional level of backwards compatibility safety for passes
phschaad Jun 24, 2024
f518019
Made map to for loop (legacy version) safer (renaming)
phschaad Jun 24, 2024
3c1a27c
Fix instanceof check
phschaad Jun 24, 2024
d4f79bb
Fix missing import
phschaad Jun 24, 2024
9053acd
More fixes
phschaad Jun 24, 2024
c7517b8
Remove erroneous import
phschaad Jun 24, 2024
cdd3bd8
Address minor comments:
phschaad Jun 26, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dace/codegen/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def generate_code(sdfg, validate=True) -> List[CodeObject]:
# Convert any loop constructs with hierarchical loop regions into simple 1-level state machine loops.
# TODO (later): Adapt codegen to deal with hierarchical CFGs instead.
sdutils.inline_loop_blocks(sdfg)
sdutils.inline_control_flow_regions(sdfg)

# Before generating the code, run type inference on the SDFG connectors
infer_types.infer_connector_types(sdfg)
Expand Down
3 changes: 2 additions & 1 deletion dace/codegen/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,8 @@ def as_cpp(self, codegen, symbols) -> str:

update = ''
if self.update is not None:
update = f'{self.itervar} = {self.update}'
cppupdate = unparse_interstate_edge(self.update, sdfg, codegen=codegen)
update = f'{self.itervar} = {cppupdate}'

expr = f'{preinit}\nfor ({init}; {cond}; {update}) {{\n'
expr += _clean_loop_body(self.body.as_cpp(codegen, symbols))
Expand Down
262 changes: 151 additions & 111 deletions dace/frontend/fortran/fortran_parser.py

Large diffs are not rendered by default.

14 changes: 13 additions & 1 deletion dace/frontend/python/nested_call.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
import dace
from dace.sdfg import SDFG, SDFGState
from typing import Optional, TYPE_CHECKING

if TYPE_CHECKING:
from dace.frontend.python.newast import ProgramVisitor
else:
ProgramVisitor = 'dace.frontend.python.newast.ProgramVisitor'


class NestedCall():
Expand All @@ -18,7 +24,13 @@ def _cos_then_max(pv, sdfg, state, a: str):
# return a tuple of the nest object and the result
return nest, result
"""
def __init__(self, pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState):
state: SDFGState
last_state: Optional[SDFGState]
pv: ProgramVisitor
sdfg: SDFG
count: int

def __init__(self, pv: ProgramVisitor, sdfg: SDFG, state: SDFGState):
self.pv = pv
self.sdfg = sdfg
self.state = state
Expand Down
433 changes: 240 additions & 193 deletions dace/frontend/python/newast.py

Large diffs are not rendered by default.

14 changes: 11 additions & 3 deletions dace/frontend/python/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from dace import data, dtypes, hooks, symbolic
from dace.config import Config
from dace.frontend.python import (newast, common as pycommon, cached_program, preprocessing)
from dace.sdfg import SDFG
from dace.sdfg import SDFG, utils as sdutils
from dace.data import create_datadescriptor, Data

try:
Expand Down Expand Up @@ -152,7 +152,8 @@ def __init__(self,
regenerate_code: bool = True,
recompile: bool = True,
distributed_compilation: bool = False,
method: bool = False):
method: bool = False,
use_experimental_cfg_blocks: bool = False):
from dace.codegen import compiled_sdfg # Avoid import loops

self.f = f
Expand All @@ -172,6 +173,7 @@ def __init__(self,
self.recreate_sdfg = recreate_sdfg
self.regenerate_code = regenerate_code
self.recompile = recompile
self.use_experimental_cfg_blocks = use_experimental_cfg_blocks
self.distributed_compilation = distributed_compilation

self.global_vars = _get_locals_and_globals(f)
Expand Down Expand Up @@ -491,6 +493,11 @@ def _parse(self, args, kwargs, simplify=None, save=False, validate=False) -> SDF
# Obtain DaCe program as SDFG
sdfg, cached = self._generate_pdp(args, kwargs, simplify=simplify)

if not self.use_experimental_cfg_blocks:
sdutils.inline_loop_blocks(sdfg)
sdutils.inline_control_flow_regions(sdfg)
sdfg.using_experimental_blocks = self.use_experimental_cfg_blocks
phschaad marked this conversation as resolved.
Show resolved Hide resolved

# Apply simplification pass automatically
if not cached and (simplify == True or
(simplify is None and Config.get_bool('optimizer', 'automatic_simplification'))):
Expand Down Expand Up @@ -801,7 +808,8 @@ def get_program_hash(self, *args, **kwargs) -> cached_program.ProgramCacheKey:
_, key = self._load_sdfg(None, *args, **kwargs)
return key

def _generate_pdp(self, args: Tuple[Any], kwargs: Dict[str, Any], simplify: Optional[bool] = None) -> SDFG:
def _generate_pdp(self, args: Tuple[Any], kwargs: Dict[str, Any],
simplify: Optional[bool] = None) -> Tuple[SDFG, bool]:
""" Generates the parsed AST representation of a DaCe program.

:param args: The given arguments to the program.
Expand Down
3 changes: 3 additions & 0 deletions dace/frontend/python/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,6 +935,9 @@ def _add_exits(self, until_loop_end: bool, only_one: bool = False) -> List[ast.A
for stmt in reversed(self.with_statements):
if until_loop_end and not isinstance(stmt, (ast.With, ast.AsyncWith)):
break
elif not until_loop_end and isinstance(stmt, (ast.For, ast.While)):
break

for mgrname, mgr in reversed(self.context_managers[stmt]):
# Call __exit__ (without exception management all three arguments are set to None)
exit_call = ast.copy_location(ast.parse(f'{mgrname}.__exit__(None, None, None)').body[0], stmt)
Expand Down
19 changes: 11 additions & 8 deletions dace/frontend/python/replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import warnings
from functools import reduce
from numbers import Number, Integral
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union, TYPE_CHECKING

import dace
from dace.codegen.tools import type_inference
Expand All @@ -28,7 +28,10 @@

Size = Union[int, dace.symbolic.symbol]
Shape = Sequence[Size]
ProgramVisitor = 'dace.frontend.python.newast.ProgramVisitor'
if TYPE_CHECKING:
from dace.frontend.python.newast import ProgramVisitor
else:
ProgramVisitor = 'dace.frontend.python.newast.ProgramVisitor'


def normalize_axes(axes: Tuple[int], max_dim: int) -> List[int]:
Expand Down Expand Up @@ -971,8 +974,8 @@ def _pymax(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, a: Union[str, Numbe
for i, b in enumerate(args):
if i > 0:
pv._add_state('__min2_%d' % i)
pv.last_state.set_default_lineinfo(pv.current_lineinfo)
current_state = pv.last_state
pv.last_block.set_default_lineinfo(pv.current_lineinfo)
current_state = pv.last_block
left_arg = _minmax2(pv, sdfg, current_state, left_arg, b, ismin=False)
return left_arg

Expand All @@ -986,8 +989,8 @@ def _pymin(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, a: Union[str, Numbe
for i, b in enumerate(args):
if i > 0:
pv._add_state('__min2_%d' % i)
pv.last_state.set_default_lineinfo(pv.current_lineinfo)
current_state = pv.last_state
pv.last_block.set_default_lineinfo(pv.current_lineinfo)
current_state = pv.last_block
left_arg = _minmax2(pv, sdfg, current_state, left_arg, b)
return left_arg

Expand Down Expand Up @@ -3355,7 +3358,7 @@ def _create_subgraph(visitor: ProgramVisitor,
cond_state.add_nedge(r, w, dace.Memlet("{}[0]".format(r)))
true_state = sdfg.add_state(label=cond_state.label + '_true')
state = true_state
visitor.last_state = state
visitor.last_block = state
cond = name
cond_else = 'not ({})'.format(cond)
sdfg.add_edge(cond_state, true_state, dace.InterstateEdge(cond))
Expand All @@ -3374,7 +3377,7 @@ def _create_subgraph(visitor: ProgramVisitor,
dace.Memlet.from_array(arg, sdfg.arrays[arg]))
if has_where and isinstance(where, str) and where in sdfg.arrays.keys():
visitor._add_state(label=cond_state.label + '_true')
sdfg.add_edge(cond_state, visitor.last_state, dace.InterstateEdge(cond_else))
sdfg.add_edge(cond_state, visitor.last_block, dace.InterstateEdge(cond_else))
else:
# Map needed
if has_where:
Expand Down
4 changes: 2 additions & 2 deletions dace/sdfg/infer_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def infer_connector_types(sdfg: SDFG):
:param sdfg: The SDFG to infer.
"""
# Loop over states, and in a topological sort over each state's nodes
for state in sdfg.nodes():
for state in sdfg.states():
for node in dfs_topological_sort(state):
# Try to infer input connector type from node type or previous edges
for e in state.in_edges(node):
Expand Down Expand Up @@ -168,7 +168,7 @@ def set_default_schedule_and_storage_types(scope: Union[SDFG, SDFGState, nodes.E

if isinstance(scope, SDFG):
# Set device for default top-level schedules and storages
for state in scope.nodes():
for state in scope.states():
set_default_schedule_and_storage_types(state,
parent_schedules,
use_parent_schedule=use_parent_schedule,
Expand Down
22 changes: 20 additions & 2 deletions dace/sdfg/sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from dace.frontend.python import astutils, wrappers
from dace.sdfg import nodes as nd
from dace.sdfg.graph import OrderedDiGraph, Edge, SubgraphView
from dace.sdfg.state import SDFGState, ControlFlowRegion
from dace.sdfg.state import ControlFlowBlock, SDFGState, ControlFlowRegion
from dace.sdfg.propagation import propagate_memlets_sdfg
from dace.distr_types import ProcessGrid, SubArray, RedistrArray
from dace.dtypes import validate_name
Expand Down Expand Up @@ -183,7 +183,7 @@ class InterstateEdge(object):
desc="Assignments to perform upon transition (e.g., 'x=x+1; y = 0')")
condition = CodeProperty(desc="Transition condition", default=CodeBlock("1"))

def __init__(self, condition: CodeBlock = None, assignments=None):
def __init__(self, condition: Optional[Union[CodeBlock, str, ast.AST, list]] = None, assignments=None):
if condition is None:
condition = CodeBlock("1")

Expand Down Expand Up @@ -452,6 +452,9 @@ class SDFG(ControlFlowRegion):
desc='Mapping between callback name and its original callback '
'(for when the same callback is used with a different signature)')

using_experimental_blocks = Property(dtype=bool, default=False,
desc="Whether the SDFG contains experimental control flow blocks")

def __init__(self,
name: str,
constants: Dict[str, Tuple[dt.Data, Any]] = None,
Expand Down Expand Up @@ -509,6 +512,8 @@ def __init__(self,
self._orig_name = name
self._num = 0

self._sdfg = self

def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
Expand Down Expand Up @@ -2220,6 +2225,7 @@ def compile(self, output_file=None, validate=True) -> 'CompiledSDFG':
# Convert any loop constructs with hierarchical loop regions into simple 1-level state machine loops.
# TODO (later): Adapt codegen to deal with hierarchical CFGs instead.
sdutils.inline_loop_blocks(sdfg)
sdutils.inline_control_flow_regions(sdfg)

# Rename SDFG to avoid runtime issues with clashing names
index = 0
Expand Down Expand Up @@ -2680,3 +2686,15 @@ def make_array_memlet(self, array: str):
:return: a Memlet that fully transfers array
"""
return dace.Memlet.from_array(array, self.data(array))

def recheck_using_experimental_blocks(self) -> bool:
found_experimental_block = False
for node, graph in self.root_sdfg.all_nodes_recursive():
if isinstance(graph, ControlFlowRegion) and not isinstance(graph, SDFG):
found_experimental_block = True
break
if isinstance(node, ControlFlowBlock) and not isinstance(node, SDFGState):
found_experimental_block = True
break
self.root_sdfg.using_experimental_blocks = found_experimental_block
return found_experimental_block
Loading
Loading