diff --git a/dace/codegen/codegen.py b/dace/codegen/codegen.py index 6e2786660f..f73e3f8d11 100644 --- a/dace/codegen/codegen.py +++ b/dace/codegen/codegen.py @@ -189,6 +189,7 @@ def generate_code(sdfg, validate=True) -> List[CodeObject]: # Convert any loop constructs with hierarchical loop regions into simple 1-level state machine loops. # TODO (later): Adapt codegen to deal with hierarchical CFGs instead. sdutils.inline_loop_blocks(sdfg) + sdutils.inline_control_flow_regions(sdfg) # Before generating the code, run type inference on the SDFG connectors infer_types.infer_connector_types(sdfg) diff --git a/dace/codegen/control_flow.py b/dace/codegen/control_flow.py index 2460816793..9f7e19ea9a 100644 --- a/dace/codegen/control_flow.py +++ b/dace/codegen/control_flow.py @@ -390,7 +390,8 @@ def as_cpp(self, codegen, symbols) -> str: update = '' if self.update is not None: - update = f'{self.itervar} = {self.update}' + cppupdate = unparse_interstate_edge(self.update, sdfg, codegen=codegen) + update = f'{self.itervar} = {cppupdate}' expr = f'{preinit}\nfor ({init}; {cond}; {update}) {{\n' expr += _clean_loop_body(self.body.as_cpp(codegen, symbols)) diff --git a/dace/frontend/fortran/fortran_parser.py b/dace/frontend/fortran/fortran_parser.py index 6870b29b07..28143f715a 100644 --- a/dace/frontend/fortran/fortran_parser.py +++ b/dace/frontend/fortran/fortran_parser.py @@ -9,12 +9,13 @@ import dace.frontend.fortran.ast_transforms as ast_transforms import dace.frontend.fortran.ast_utils as ast_utils import dace.frontend.fortran.ast_internal_classes as ast_internal_classes -from typing import List, Tuple, Set +from typing import List, Optional, Tuple, Set from dace import dtypes from dace import Language as lang from dace import data as dat from dace import SDFG, InterstateEdge, Memlet, pointer, nodes from dace import symbolic as sym +from dace.sdfg.state import ControlFlowRegion, LoopRegion from copy import deepcopy as dpcp from dace.properties import CodeBlock @@ -28,7 +29,7 @@ class AST_translator: """ This class is responsible for translating the internal AST into a SDFG. """ - def __init__(self, ast: ast_components.InternalFortranAst, source: str): + def __init__(self, ast: ast_components.InternalFortranAst, source: str, use_experimental_cfg_blocks: bool = False): """ :ast: The internal fortran AST to be used for translation :source: The source file name from which the AST was generated @@ -68,6 +69,7 @@ def __init__(self, ast: ast_components.InternalFortranAst, source: str): ast_internal_classes.Allocate_Stmt_Node: self.allocate2sdfg, ast_internal_classes.Break_Node: self.break2sdfg, } + self.use_experimental_cfg_blocks = use_experimental_cfg_blocks def get_dace_type(self, type): """ @@ -119,7 +121,7 @@ def get_memlet_range(self, sdfg: SDFG, variables: List[ast_internal_classes.FNod if o_v.name == var_name_tasklet: return ast_utils.generate_memlet(o_v, sdfg, self) - def translate(self, node: ast_internal_classes.FNode, sdfg: SDFG): + def translate(self, node: ast_internal_classes.FNode, sdfg: SDFG, cfg: Optional[ControlFlowRegion] = None): """ This function is responsible for translating the AST into a SDFG. :param node: The node to be translated @@ -128,15 +130,17 @@ def translate(self, node: ast_internal_classes.FNode, sdfg: SDFG): :note: This function will call the appropriate function for the node type :note: The dictionary ast_elements, part of the class itself contains all functions that are called for the different node types """ + if not cfg: + cfg = sdfg if node.__class__ in self.ast_elements: - self.ast_elements[node.__class__](node, sdfg) + self.ast_elements[node.__class__](node, sdfg, cfg) elif isinstance(node, list): for i in node: - self.translate(i, sdfg) + self.translate(i, sdfg, cfg) else: warnings.warn(f"WARNING: {node.__class__.__name__}") - def ast2sdfg(self, node: ast_internal_classes.Program_Node, sdfg: SDFG): + def ast2sdfg(self, node: ast_internal_classes.Program_Node, sdfg: SDFG, cfg: ControlFlowRegion): """ This function is responsible for translating the Fortran AST into a SDFG. :param node: The node to be translated @@ -148,27 +152,27 @@ def ast2sdfg(self, node: ast_internal_classes.Program_Node, sdfg: SDFG): self.globalsdfg = sdfg for i in node.modules: for j in i.specification_part.typedecls: - self.translate(j, sdfg) + self.translate(j, sdfg, cfg) for k in j.vardecl: self.module_vars.append((k.name, i.name)) for j in i.specification_part.symbols: - self.translate(j, sdfg) + self.translate(j, sdfg, cfg) for k in j.vardecl: self.module_vars.append((k.name, i.name)) for j in i.specification_part.specifications: - self.translate(j, sdfg) + self.translate(j, sdfg, cfg) for k in j.vardecl: self.module_vars.append((k.name, i.name)) for i in node.main_program.specification_part.typedecls: - self.translate(i, sdfg) + self.translate(i, sdfg, cfg) for i in node.main_program.specification_part.symbols: - self.translate(i, sdfg) + self.translate(i, sdfg, cfg) for i in node.main_program.specification_part.specifications: - self.translate(i, sdfg) - self.translate(node.main_program.execution_part.execution, sdfg) + self.translate(i, sdfg, cfg) + self.translate(node.main_program.execution_part.execution, sdfg, cfg) - def basicblock2sdfg(self, node: ast_internal_classes.Execution_Part_Node, sdfg: SDFG): + def basicblock2sdfg(self, node: ast_internal_classes.Execution_Part_Node, sdfg: SDFG, cfg: ControlFlowRegion): """ This function is responsible for translating Fortran basic blocks into a SDFG. :param node: The node to be translated @@ -176,9 +180,9 @@ def basicblock2sdfg(self, node: ast_internal_classes.Execution_Part_Node, sdfg: """ for i in node.execution: - self.translate(i, sdfg) + self.translate(i, sdfg, cfg) - def allocate2sdfg(self, node: ast_internal_classes.Allocate_Stmt_Node, sdfg: SDFG): + def allocate2sdfg(self, node: ast_internal_classes.Allocate_Stmt_Node, sdfg: SDFG, cfg: ControlFlowRegion): """ This function is responsible for translating Fortran allocate statements into a SDFG. :param node: The node to be translated @@ -215,11 +219,11 @@ def allocate2sdfg(self, node: ast_internal_classes.Allocate_Stmt_Node, sdfg: SDF transient=transient) - def write2sdfg(self, node: ast_internal_classes.Write_Stmt_Node, sdfg: SDFG): + def write2sdfg(self, node: ast_internal_classes.Write_Stmt_Node, sdfg: SDFG, cfg: ControlFlowRegion): #TODO implement raise NotImplementedError("Fortran write statements are not implemented yet") - def ifstmt2sdfg(self, node: ast_internal_classes.If_Stmt_Node, sdfg: SDFG): + def ifstmt2sdfg(self, node: ast_internal_classes.If_Stmt_Node, sdfg: SDFG, cfg: ControlFlowRegion): """ This function is responsible for translating Fortran if statements into a SDFG. :param node: The node to be translated @@ -227,85 +231,117 @@ def ifstmt2sdfg(self, node: ast_internal_classes.If_Stmt_Node, sdfg: SDFG): """ name = f"If_l_{str(node.line_number[0])}_c_{str(node.line_number[1])}" - begin_state = ast_utils.add_simple_state_to_sdfg(self, sdfg, f"Begin{name}") - guard_substate = sdfg.add_state(f"Guard{name}") - sdfg.add_edge(begin_state, guard_substate, InterstateEdge()) + begin_state = ast_utils.add_simple_state_to_sdfg(self, cfg, f"Begin{name}") + guard_substate = cfg.add_state(f"Guard{name}") + cfg.add_edge(begin_state, guard_substate, InterstateEdge()) condition = ast_utils.ProcessedWriter(sdfg, self.name_mapping).write_code(node.cond) - body_ifstart_state = sdfg.add_state(f"BodyIfStart{name}") - self.last_sdfg_states[sdfg] = body_ifstart_state - self.translate(node.body, sdfg) - final_substate = sdfg.add_state(f"MergeState{name}") + body_ifstart_state = cfg.add_state(f"BodyIfStart{name}") + self.last_sdfg_states[cfg] = body_ifstart_state + self.translate(node.body, sdfg, cfg) + final_substate = cfg.add_state(f"MergeState{name}") - sdfg.add_edge(guard_substate, body_ifstart_state, InterstateEdge(condition)) + cfg.add_edge(guard_substate, body_ifstart_state, InterstateEdge(condition)) - if self.last_sdfg_states[sdfg] not in [ - self.last_loop_breaks.get(sdfg), - self.last_loop_continues.get(sdfg), - self.last_returns.get(sdfg) + if self.last_sdfg_states[cfg] not in [ + self.last_loop_breaks.get(cfg), + self.last_loop_continues.get(cfg), + self.last_returns.get(cfg) ]: - body_ifend_state = ast_utils.add_simple_state_to_sdfg(self, sdfg, f"BodyIfEnd{name}") - sdfg.add_edge(body_ifend_state, final_substate, InterstateEdge()) + body_ifend_state = ast_utils.add_simple_state_to_sdfg(self, cfg, f"BodyIfEnd{name}") + cfg.add_edge(body_ifend_state, final_substate, InterstateEdge()) if len(node.body_else.execution) > 0: name_else = f"Else_l_{str(node.line_number[0])}_c_{str(node.line_number[1])}" - body_elsestart_state = sdfg.add_state("BodyElseStart" + name_else) - self.last_sdfg_states[sdfg] = body_elsestart_state - self.translate(node.body_else, sdfg) - body_elseend_state = ast_utils.add_simple_state_to_sdfg(self, sdfg, f"BodyElseEnd{name_else}") - sdfg.add_edge(guard_substate, body_elsestart_state, InterstateEdge("not (" + condition + ")")) - sdfg.add_edge(body_elseend_state, final_substate, InterstateEdge()) + body_elsestart_state = cfg.add_state("BodyElseStart" + name_else) + self.last_sdfg_states[cfg] = body_elsestart_state + self.translate(node.body_else, sdfg, cfg) + body_elseend_state = ast_utils.add_simple_state_to_sdfg(self, cfg, f"BodyElseEnd{name_else}") + cfg.add_edge(guard_substate, body_elsestart_state, InterstateEdge("not (" + condition + ")")) + cfg.add_edge(body_elseend_state, final_substate, InterstateEdge()) else: - sdfg.add_edge(guard_substate, final_substate, InterstateEdge("not (" + condition + ")")) - self.last_sdfg_states[sdfg] = final_substate + cfg.add_edge(guard_substate, final_substate, InterstateEdge("not (" + condition + ")")) + self.last_sdfg_states[cfg] = final_substate - def forstmt2sdfg(self, node: ast_internal_classes.For_Stmt_Node, sdfg: SDFG): + def forstmt2sdfg(self, node: ast_internal_classes.For_Stmt_Node, sdfg: SDFG, cfg: ControlFlowRegion): """ This function is responsible for translating Fortran for statements into a SDFG. :param node: The node to be translated :param sdfg: The SDFG to which the node should be translated """ - declloop = False - name = "FOR_l_" + str(node.line_number[0]) + "_c_" + str(node.line_number[1]) - begin_state = ast_utils.add_simple_state_to_sdfg(self, sdfg, "Begin" + name) - guard_substate = sdfg.add_state("Guard" + name) - final_substate = sdfg.add_state("Merge" + name) - self.last_sdfg_states[sdfg] = final_substate - decl_node = node.init - entry = {} - if isinstance(decl_node, ast_internal_classes.BinOp_Node): - if sdfg.symbols.get(decl_node.lval.name) is not None: - iter_name = decl_node.lval.name - elif self.name_mapping[sdfg].get(decl_node.lval.name) is not None: - iter_name = self.name_mapping[sdfg][decl_node.lval.name] - else: - raise ValueError("Unknown variable " + decl_node.lval.name) - entry[iter_name] = ast_utils.ProcessedWriter(sdfg, self.name_mapping).write_code(decl_node.rval) - - sdfg.add_edge(begin_state, guard_substate, InterstateEdge(assignments=entry)) - - condition = ast_utils.ProcessedWriter(sdfg, self.name_mapping).write_code(node.cond) - - increment = "i+0+1" - if isinstance(node.iter, ast_internal_classes.BinOp_Node): - increment = ast_utils.ProcessedWriter(sdfg, self.name_mapping).write_code(node.iter.rval) - entry = {iter_name: increment} - - begin_loop_state = sdfg.add_state("BeginLoop" + name) - end_loop_state = sdfg.add_state("EndLoop" + name) - self.last_sdfg_states[sdfg] = begin_loop_state - self.last_loop_continues[sdfg] = final_substate - self.translate(node.body, sdfg) - - sdfg.add_edge(self.last_sdfg_states[sdfg], end_loop_state, InterstateEdge()) - sdfg.add_edge(guard_substate, begin_loop_state, InterstateEdge(condition)) - sdfg.add_edge(end_loop_state, guard_substate, InterstateEdge(assignments=entry)) - sdfg.add_edge(guard_substate, final_substate, InterstateEdge(f"not ({condition})")) - self.last_sdfg_states[sdfg] = final_substate - - def symbol2sdfg(self, node: ast_internal_classes.Symbol_Decl_Node, sdfg: SDFG): + if not self.use_experimental_cfg_blocks: + declloop = False + name = "FOR_l_" + str(node.line_number[0]) + "_c_" + str(node.line_number[1]) + begin_state = ast_utils.add_simple_state_to_sdfg(self, cfg, "Begin" + name) + guard_substate = cfg.add_state("Guard" + name) + final_substate = cfg.add_state("Merge" + name) + self.last_sdfg_states[cfg] = final_substate + decl_node = node.init + entry = {} + if isinstance(decl_node, ast_internal_classes.BinOp_Node): + if sdfg.symbols.get(decl_node.lval.name) is not None: + iter_name = decl_node.lval.name + elif self.name_mapping[sdfg].get(decl_node.lval.name) is not None: + iter_name = self.name_mapping[sdfg][decl_node.lval.name] + else: + raise ValueError("Unknown variable " + decl_node.lval.name) + entry[iter_name] = ast_utils.ProcessedWriter(sdfg, self.name_mapping).write_code(decl_node.rval) + + cfg.add_edge(begin_state, guard_substate, InterstateEdge(assignments=entry)) + + condition = ast_utils.ProcessedWriter(sdfg, self.name_mapping).write_code(node.cond) + + increment = "i+0+1" + if isinstance(node.iter, ast_internal_classes.BinOp_Node): + increment = ast_utils.ProcessedWriter(sdfg, self.name_mapping).write_code(node.iter.rval) + entry = {iter_name: increment} + + begin_loop_state = cfg.add_state("BeginLoop" + name) + end_loop_state = cfg.add_state("EndLoop" + name) + self.last_sdfg_states[cfg] = begin_loop_state + self.last_loop_continues[cfg] = final_substate + self.translate(node.body, sdfg, cfg) + + cfg.add_edge(self.last_sdfg_states[cfg], end_loop_state, InterstateEdge()) + cfg.add_edge(guard_substate, begin_loop_state, InterstateEdge(condition)) + cfg.add_edge(end_loop_state, guard_substate, InterstateEdge(assignments=entry)) + cfg.add_edge(guard_substate, final_substate, InterstateEdge(f"not ({condition})")) + self.last_sdfg_states[cfg] = final_substate + else: + name = "FOR_l_" + str(node.line_number[0]) + "_c_" + str(node.line_number[1]) + decl_node = node.init + entry = {} + if isinstance(decl_node, ast_internal_classes.BinOp_Node): + if sdfg.symbols.get(decl_node.lval.name) is not None: + iter_name = decl_node.lval.name + elif self.name_mapping[sdfg].get(decl_node.lval.name) is not None: + iter_name = self.name_mapping[sdfg][decl_node.lval.name] + else: + raise ValueError("Unknown variable " + decl_node.lval.name) + entry[iter_name] = ast_utils.ProcessedWriter(sdfg, self.name_mapping).write_code(decl_node.rval) + + condition = ast_utils.ProcessedWriter(sdfg, self.name_mapping).write_code(node.cond) + + increment = "i+0+1" + if isinstance(node.iter, ast_internal_classes.BinOp_Node): + increment = ast_utils.ProcessedWriter(sdfg, self.name_mapping).write_code(node.iter.rval) + + loop_region = LoopRegion(name, condition, iter_name, f"{iter_name} = {entry[iter_name]}", + f"{iter_name} = {increment}") + is_start = self.last_sdfg_states.get(cfg) is None + cfg.add_node(loop_region, is_start_block=is_start) + if not is_start: + cfg.add_edge(self.last_sdfg_states[cfg], loop_region, InterstateEdge()) + self.last_sdfg_states[cfg] = loop_region + + begin_loop_state = loop_region.add_state("BeginLoop" + name, is_start_block=True) + self.last_sdfg_states[loop_region] = begin_loop_state + + self.translate(node.body, sdfg, loop_region) + + def symbol2sdfg(self, node: ast_internal_classes.Symbol_Decl_Node, sdfg: SDFG, cfg: ControlFlowRegion): """ This function is responsible for translating Fortran symbol declarations into a SDFG. :param node: The node to be translated @@ -323,24 +359,25 @@ def symbol2sdfg(self, node: ast_internal_classes.Symbol_Decl_Node, sdfg: SDFG): datatype = self.get_dace_type(node.type) if node.name not in sdfg.symbols: sdfg.add_symbol(node.name, datatype) - if self.last_sdfg_states.get(sdfg) is None: - bstate = sdfg.add_state("SDFGbegin", is_start_state=True) - self.last_sdfg_states[sdfg] = bstate + if self.last_sdfg_states.get(cfg) is None: + bstate = cfg.add_state("SDFGbegin", is_start_state=True) + self.last_sdfg_states[cfg] = bstate if node.init is not None: - substate = sdfg.add_state(f"Dummystate_{node.name}") + substate = cfg.add_state(f"Dummystate_{node.name}") increment = ast_utils.TaskletWriter([], [], sdfg, self.name_mapping).write_code(node.init) entry = {node.name: increment} - sdfg.add_edge(self.last_sdfg_states[sdfg], substate, InterstateEdge(assignments=entry)) - self.last_sdfg_states[sdfg] = substate + cfg.add_edge(self.last_sdfg_states[cfg], substate, InterstateEdge(assignments=entry)) + self.last_sdfg_states[cfg] = substate - def symbolarray2sdfg(self, node: ast_internal_classes.Symbol_Array_Decl_Node, sdfg: SDFG): + def symbolarray2sdfg(self, node: ast_internal_classes.Symbol_Array_Decl_Node, sdfg: SDFG, cfg: ControlFlowRegion): return NotImplementedError( "Symbol_Decl_Node not implemented. This should be done via a transformation that itemizes the constant array." ) - def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, sdfg: SDFG): + def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, sdfg: SDFG, + cfg: ControlFlowRegion): """ This function is responsible for translating Fortran subroutine declarations into a SDFG. :param node: The node to be translated @@ -364,7 +401,7 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, parameters = node.args.copy() new_sdfg = SDFG(node.name.name) - substate = ast_utils.add_simple_state_to_sdfg(self, sdfg, "state" + node.name.name) + substate = ast_utils.add_simple_state_to_sdfg(self, cfg, "state" + node.name.name) variables_in_call = [] if self.last_call_expression.get(sdfg) is not None: variables_in_call = self.last_call_expression[sdfg] @@ -763,12 +800,12 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, pass for j in node.specification_part.specifications: - self.declstmt2sdfg(j, new_sdfg) + self.declstmt2sdfg(j, new_sdfg, new_sdfg) for i in assigns: - self.translate(i, new_sdfg) - self.translate(node.execution_part, new_sdfg) + self.translate(i, new_sdfg, new_sdfg) + self.translate(node.execution_part, new_sdfg, new_sdfg) - def binop2sdfg(self, node: ast_internal_classes.BinOp_Node, sdfg: SDFG): + def binop2sdfg(self, node: ast_internal_classes.BinOp_Node, sdfg: SDFG, cfg: ControlFlowRegion): """ This parses binary operations to tasklets in a new state or creates a function call with a nested SDFG if the operation is a function @@ -784,7 +821,7 @@ def binop2sdfg(self, node: ast_internal_classes.BinOp_Node, sdfg: SDFG): if augmented_call.name.name not in ["sqrt", "exp", "pow", "max", "min", "abs", "tanh", "__dace_epsilon"]: augmented_call.args.append(node.lval) augmented_call.hasret = True - self.call2sdfg(augmented_call, sdfg) + self.call2sdfg(augmented_call, sdfg, cfg) return outputnodefinder = ast_transforms.FindOutputs() @@ -818,7 +855,7 @@ def binop2sdfg(self, node: ast_internal_classes.BinOp_Node, sdfg: SDFG): input_names_tasklet.append(i.name + "_" + str(count) + "_in") substate = ast_utils.add_simple_state_to_sdfg( - self, sdfg, "_state_l" + str(node.line_number[0]) + "_c" + str(node.line_number[1])) + self, cfg, "_state_l" + str(node.line_number[0]) + "_c" + str(node.line_number[1])) output_names_changed = [o_t + "_out" for o_t in output_names] @@ -840,7 +877,7 @@ def binop2sdfg(self, node: ast_internal_classes.BinOp_Node, sdfg: SDFG): text = tw.write_code(node) tasklet.code = CodeBlock(text, lang.Python) - def call2sdfg(self, node: ast_internal_classes.Call_Expr_Node, sdfg: SDFG): + def call2sdfg(self, node: ast_internal_classes.Call_Expr_Node, sdfg: SDFG, cfg: ControlFlowRegion): """ This parses function calls to a nested SDFG or creates a tasklet with an external library call. @@ -855,20 +892,20 @@ def call2sdfg(self, node: ast_internal_classes.Call_Expr_Node, sdfg: SDFG): if node.name in self.functions_and_subroutines: for i in self.top_level.function_definitions: if i.name == node.name: - self.function2sdfg(i, sdfg) + self.function2sdfg(i, sdfg, cfg) return for i in self.top_level.subroutine_definitions: if i.name == node.name: - self.subroutine2sdfg(i, sdfg) + self.subroutine2sdfg(i, sdfg, cfg) return for j in self.top_level.modules: for i in j.function_definitions: if i.name == node.name: - self.function2sdfg(i, sdfg) + self.function2sdfg(i, sdfg, cfg) return for i in j.subroutine_definitions: if i.name == node.name: - self.subroutine2sdfg(i, sdfg) + self.subroutine2sdfg(i, sdfg, cfg) return else: # This part handles the case that it's an external library call @@ -923,7 +960,7 @@ def call2sdfg(self, node: ast_internal_classes.Call_Expr_Node, sdfg: SDFG): else: text = tw.write_code(node) - substate = ast_utils.add_simple_state_to_sdfg(self, sdfg, "_state" + str(node.line_number[0])) + substate = ast_utils.add_simple_state_to_sdfg(self, cfg, "_state" + str(node.line_number[0])) tasklet = ast_utils.add_tasklet(substate, str(node.line_number[0]), { **input_names_tasklet, @@ -952,7 +989,7 @@ def call2sdfg(self, node: ast_internal_classes.Call_Expr_Node, sdfg: SDFG): setattr(tasklet, "code", CodeBlock(text, lang.Python)) - def declstmt2sdfg(self, node: ast_internal_classes.Decl_Stmt_Node, sdfg: SDFG): + def declstmt2sdfg(self, node: ast_internal_classes.Decl_Stmt_Node, sdfg: SDFG, cfg: ControlFlowRegion): """ This function translates a variable declaration statement to an access node on the sdfg :param node: The node to translate @@ -960,9 +997,9 @@ def declstmt2sdfg(self, node: ast_internal_classes.Decl_Stmt_Node, sdfg: SDFG): :note This function is the top level of the declaration, most implementation is in vardecl2sdfg """ for i in node.vardecl: - self.translate(i, sdfg) + self.translate(i, sdfg, cfg) - def vardecl2sdfg(self, node: ast_internal_classes.Var_Decl_Node, sdfg: SDFG): + def vardecl2sdfg(self, node: ast_internal_classes.Var_Decl_Node, sdfg: SDFG, cfg: ControlFlowRegion): """ This function translates a variable declaration to an access node on the sdfg :param node: The node to translate @@ -1016,10 +1053,10 @@ def vardecl2sdfg(self, node: ast_internal_classes.Var_Decl_Node, sdfg: SDFG): if node.name not in self.contexts[sdfg.name].containers: self.contexts[sdfg.name].containers.append(node.name) - def break2sdfg(self, node: ast_internal_classes.Break_Node, sdfg: SDFG): + def break2sdfg(self, node: ast_internal_classes.Break_Node, sdfg: SDFG, cfg: ControlFlowRegion): - self.last_loop_breaks[sdfg] = self.last_sdfg_states[sdfg] - sdfg.add_edge(self.last_sdfg_states[sdfg], self.last_loop_continues.get(sdfg), InterstateEdge()) + self.last_loop_breaks[cfg] = self.last_sdfg_states[cfg] + cfg.add_edge(self.last_sdfg_states[cfg], self.last_loop_continues.get(cfg), InterstateEdge()) def create_ast_from_string( source_string: str, @@ -1063,7 +1100,8 @@ def create_ast_from_string( def create_sdfg_from_string( source_string: str, sdfg_name: str, - normalize_offsets: bool = False + normalize_offsets: bool = False, + use_experimental_cfg_blocks: bool = False ): """ Creates an SDFG from a fortran file in a string @@ -1092,7 +1130,7 @@ def create_sdfg_from_string( program = ast_transforms.ForDeclarer().visit(program) program = ast_transforms.IndexExtractor(program, normalize_offsets).visit(program) - ast2sdfg = AST_translator(own_ast, __file__) + ast2sdfg = AST_translator(own_ast, __file__, use_experimental_cfg_blocks) sdfg = SDFG(sdfg_name) ast2sdfg.top_level = program ast2sdfg.globalsdfg = sdfg @@ -1107,10 +1145,11 @@ def create_sdfg_from_string( sdfg.parent_sdfg = None sdfg.parent_nsdfg_node = None sdfg.reset_cfg_list() + sdfg.using_experimental_blocks = use_experimental_cfg_blocks return sdfg -def create_sdfg_from_fortran_file(source_string: str): +def create_sdfg_from_fortran_file(source_string: str, use_experimental_cfg_blocks: bool = False): """ Creates an SDFG from a fortran file :param source_string: The fortran file name @@ -1137,10 +1176,11 @@ def create_sdfg_from_fortran_file(source_string: str): program = ast_transforms.ForDeclarer().visit(program) program = ast_transforms.IndexExtractor(program).visit(program) - ast2sdfg = AST_translator(own_ast, __file__) + ast2sdfg = AST_translator(own_ast, __file__, use_experimental_cfg_blocks) sdfg = SDFG(source_string) ast2sdfg.top_level = program ast2sdfg.globalsdfg = sdfg ast2sdfg.translate(program, sdfg) + sdfg.using_experimental_blocks = use_experimental_cfg_blocks return sdfg diff --git a/dace/frontend/python/nested_call.py b/dace/frontend/python/nested_call.py index c5691dc75d..2495a20dce 100644 --- a/dace/frontend/python/nested_call.py +++ b/dace/frontend/python/nested_call.py @@ -1,6 +1,12 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. import dace from dace.sdfg import SDFG, SDFGState +from typing import Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from dace.frontend.python.newast import ProgramVisitor +else: + ProgramVisitor = 'dace.frontend.python.newast.ProgramVisitor' class NestedCall(): @@ -18,7 +24,13 @@ def _cos_then_max(pv, sdfg, state, a: str): # return a tuple of the nest object and the result return nest, result """ - def __init__(self, pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState): + state: SDFGState + last_state: Optional[SDFGState] + pv: ProgramVisitor + sdfg: SDFG + count: int + + def __init__(self, pv: ProgramVisitor, sdfg: SDFG, state: SDFGState): self.pv = pv self.sdfg = sdfg self.state = state diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index fda2bd2e23..5269f1cf83 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import ast from collections import OrderedDict import copy @@ -32,6 +32,7 @@ from dace.memlet import Memlet from dace.properties import LambdaProperty, CodeBlock from dace.sdfg import SDFG, SDFGState +from dace.sdfg.state import BreakBlock, ContinueBlock, ControlFlowBlock, LoopRegion, ControlFlowRegion from dace.sdfg.replace import replace_datadesc_names from dace.symbolic import pystr_to_symbolic, inequal_symbols @@ -1072,6 +1073,12 @@ class ProgramVisitor(ExtNodeVisitor): progress_bar = None start_time: float = 0 + sdfg: SDFG + last_block: ControlFlowBlock + cfg_target: ControlFlowRegion + last_cfg_target: ControlFlowRegion + current_state: SDFGState + def __init__(self, name: str, filename: str, @@ -1147,7 +1154,10 @@ def __init__(self, if sym.name not in self.sdfg.symbols: self.sdfg.add_symbol(sym.name, sym.dtype) self.sdfg._temp_transients = tmp_idx - self.last_state = self.sdfg.add_state('init', is_start_state=True) + self.cfg_target = self.sdfg + self.current_state = self.sdfg.add_state('init', is_start_state=True) + self.last_block = self.current_state + self.last_cfg_target = self.sdfg self.inputs: DependencyType = {} self.outputs: DependencyType = {} @@ -1167,11 +1177,6 @@ def __init__(self, for stmt in _DISALLOWED_STMTS: setattr(self, 'visit_' + stmt, lambda n: _disallow_stmt(self, n)) - # Loop status - self.loop_idx = -1 - self.continue_states = [] - self.break_states = [] - # Tmp fix for missing state symbol propagation self.symbols = dict() @@ -1296,7 +1301,7 @@ def _views_to_data(state: SDFGState, nodes: List[dace.nodes.AccessNode]) -> List return new_nodes # Map view access nodes to their respective data - for state in self.sdfg.nodes(): + for state in self.sdfg.states(): # NOTE: We need to support views of views nodes = list(state.data_nodes()) while nodes: @@ -1349,13 +1354,34 @@ def defined(self): return result - def _add_state(self, label=None): - state = self.sdfg.add_state(label) - if self.last_state is not None: - self.sdfg.add_edge(self.last_state, state, dace.InterstateEdge()) - self.last_state = state + def _on_block_added(self, block: ControlFlowBlock): + if self.last_block is not None and self.last_cfg_target == self.cfg_target: + self.cfg_target.add_edge(self.last_block, block, dace.InterstateEdge()) + self.last_block = block + + self.last_cfg_target = self.cfg_target + if not isinstance(block, SDFGState): + self.current_state = None + else: + self.current_state = block + + def _add_state(self, label=None, is_start=False) -> SDFGState: + state = self.cfg_target.add_state(label, is_start_block=is_start) + self._on_block_added(state) return state + def _add_loop_region(self, + condition_expr: str, + label: str = 'loop', + loop_var: Optional[str] = None, + init_expr: Optional[str] = None, + update_expr: Optional[str] = None, + inverted: bool = False) -> LoopRegion: + loop_region = LoopRegion(label, condition_expr, loop_var, init_expr, update_expr, inverted) + self.cfg_target.add_node(loop_region) + self._on_block_added(loop_region) + return loop_region + def _parse_arg(self, arg: Any, as_list=True): """ Parse possible values to slices or objects that can be used in the SDFG API. """ @@ -2023,7 +2049,7 @@ def _add_dependencies(self, else: name = memlet.data vname = "{c}_in_from_{s}{n}".format(c=conn, - s=self.sdfg.nodes().index(state), + s=self.sdfg.states().index(state), n=('_%s' % state.node_id(entry_node) if entry_node else '')) self.accesses[(name, scope_memlet.subset, 'r')] = (vname, orng) orig_shape = orng.size() @@ -2113,7 +2139,7 @@ def _add_dependencies(self, else: name = memlet.data vname = "{c}_out_of_{s}{n}".format(c=conn, - s=self.sdfg.nodes().index(state), + s=self.sdfg.states().index(state), n=('_%s' % state.node_id(exit_node) if exit_node else '')) self.accesses[(name, scope_memlet.subset, 'w')] = (vname, orng) orig_shape = orng.size() @@ -2170,15 +2196,21 @@ def _recursive_visit(self, body: List[ast.AST], name: str, lineno: int, - last_state=True, + parent: ControlFlowRegion, + unconnected_last_block=True, extra_symbols=None) -> Tuple[SDFGState, SDFGState, SDFGState, bool]: """ Visits a subtree of the AST, creating special states before and after the visit. Returns the previous state, and the first and last internal states of the recursive visit. Also returns a boolean value indicating whether a return statement was met or not. This value can be used by other visitor methods, e.g., visit_If, to generate correct control flow. """ - before_state = self.last_state - self.last_state = None - first_internal_state = self._add_state('%s_%d' % (name, lineno)) + previous_last_cfg_target = self.last_cfg_target + previous_last_block = self.last_block + previous_target = self.cfg_target + + self.last_block = None + self.cfg_target = parent + + first_inner_block = self._add_state('%s_%d' % (name, lineno)) # Add iteration variables to recursive visit if extra_symbols: @@ -2190,20 +2222,26 @@ def _recursive_visit(self, return_stmt = False for stmt in body: self.visit_TopLevel(stmt) - if isinstance(stmt, ast.Return): + if isinstance(stmt, ast.Return) or isinstance(stmt, ast.Break) or isinstance(stmt, ast.Continue): return_stmt = True # Create the next state - last_internal_state = self.last_state - if last_state: - self.last_state = None + last_inner_block = self.last_block + if unconnected_last_block: + self.last_block = None self._add_state('end%s_%d' % (name, lineno)) # Revert new symbols if extra_symbols: self.globals = old_globals - return before_state, first_internal_state, last_internal_state, return_stmt + # Restore previous target + self.cfg_target = previous_target + self.last_cfg_target = previous_last_cfg_target + if not unconnected_last_block: + self.last_block = previous_last_block + + return previous_last_block, first_inner_block, last_inner_block, return_stmt def _replace_with_global_symbols(self, expr: sympy.Expr) -> sympy.Expr: repldict = dict() @@ -2319,24 +2357,20 @@ def visit_For(self, node: ast.For): if (astr not in self.sdfg.symbols and not (astr in self.variables or astr in self.sdfg.arrays)): self.sdfg.add_symbol(astr, atom.dtype) - # Add an initial loop state with a None last_state (so as to not - # create an interstate edge) - self.loop_idx += 1 - self.continue_states.append([]) - self.break_states.append([]) - laststate, first_loop_state, last_loop_state, _ = self._recursive_visit(node.body, - 'for', - node.lineno, - extra_symbols=extra_syms) - end_loop_state = self.last_state - # Add loop to SDFG loop_cond = '>' if ((pystr_to_symbolic(ranges[0][2]) < 0) == True) else '<' - incr = {indices[0]: '%s + %s' % (indices[0], astutils.unparse(ast_ranges[0][2]))} - _, loop_guard, loop_end = self.sdfg.add_loop( - laststate, first_loop_state, end_loop_state, indices[0], astutils.unparse(ast_ranges[0][0]), - '%s %s %s' % (indices[0], loop_cond, astutils.unparse(ast_ranges[0][1])), incr[indices[0]], - last_loop_state) + loop_cond_expr = '%s %s %s' % (indices[0], loop_cond, astutils.unparse(ast_ranges[0][1])) + incr = {indices[0]: '%s = %s + %s' % (indices[0], indices[0], astutils.unparse(ast_ranges[0][2]))} + loop_region = self._add_loop_region(loop_cond_expr, + label=f'for_{node.lineno}', + loop_var=indices[0], + init_expr='%s = %s' % (indices[0], astutils.unparse(ast_ranges[0][0])), + update_expr=incr[indices[0]], + inverted=False) + _, first_subblock, _, _ = self._recursive_visit(node.body, f'for_{node.lineno}', node.lineno, + extra_symbols=extra_syms, parent=loop_region, + unconnected_last_block=False) + loop_region.start_block = loop_region.node_id(first_subblock) # Handle else clause if node.orelse: @@ -2345,32 +2379,16 @@ def visit_For(self, node: ast.For): self.visit(stmt) # The state that all "break" edges go to - loop_end = self._add_state(f'postloop_{node.lineno}') - - body_states = list( - sdutil.dfs_conditional(self.sdfg, - sources=[first_loop_state], - condition=lambda p, c: c is not loop_guard)) - - continue_states = self.continue_states.pop() - while continue_states: - next_state = continue_states.pop() - out_edges = self.sdfg.out_edges(next_state) - for e in out_edges: - self.sdfg.remove_edge(e) - self.sdfg.add_edge(next_state, loop_guard, dace.InterstateEdge(assignments=incr)) - break_states = self.break_states.pop() - while break_states: - next_state = break_states.pop() - out_edges = self.sdfg.out_edges(next_state) - for e in out_edges: - self.sdfg.remove_edge(e) - self.sdfg.add_edge(next_state, loop_end, dace.InterstateEdge()) - self.loop_idx -= 1 - - for state in body_states: - if not nx.has_path(self.sdfg.nx, loop_guard, state): - self.sdfg.remove_node(state) + state = self.cfg_target.add_state(f'postloop_{node.lineno}') + if self.last_block is not None: + self.cfg_target.add_edge(self.last_block, state, dace.InterstateEdge()) + self.last_block = state + + self._generate_orelse(loop_region, state) + + return state + + self.last_block = loop_region else: raise DaceSyntaxError(self, node, 'Unsupported for-loop iterator "%s"' % iterator) @@ -2389,42 +2407,81 @@ def _is_test_simple(self, node: ast.AST): return all(self._is_test_simple(value) for value in node.values) return is_test_simple - def _visit_test(self, node: ast.Expr): + def _visit_complex_test(self, node: ast.Expr): + test_region = ControlFlowRegion('%s_%s' % ('cond_prep', node.lineno), self.sdfg) + inner_start = test_region.add_state('%s_start_%s' % ('cond_prep', node.lineno)) + + p_last_cfg_target, p_last_block, p_target = self.last_cfg_target, self.last_block, self.cfg_target + self.cfg_target, self.last_block, self.last_cfg_target = test_region, inner_start, test_region + + parsed_node = self.visit(node) + if isinstance(parsed_node, (list, tuple)) and len(parsed_node) == 1: + parsed_node = parsed_node[0] + if isinstance(parsed_node, str) and parsed_node in self.sdfg.arrays: + datadesc = self.sdfg.arrays[parsed_node] + if isinstance(datadesc, data.Array): + parsed_node += '[0]' + + self.last_cfg_target, self.last_block, self.cfg_target = p_last_cfg_target, p_last_block, p_target + + return parsed_node, test_region + + def _visit_test(self, node: ast.Expr) -> Tuple[str, str, Optional[ControlFlowRegion]]: is_test_simple = self._is_test_simple(node) # Visit test-condition if not is_test_simple: - parsed_node = self.visit(node) - if isinstance(parsed_node, (list, tuple)) and len(parsed_node) == 1: - parsed_node = parsed_node[0] - if isinstance(parsed_node, str) and parsed_node in self.sdfg.arrays: - datadesc = self.sdfg.arrays[parsed_node] - if isinstance(datadesc, data.Array): - parsed_node += '[0]' + parsed_node, test_region = self._visit_complex_test(node) + self.cfg_target.add_node(test_region) + self._on_block_added(test_region) else: parsed_node = astutils.unparse(node) + test_region = None # Generate conditions cond = astutils.unparse(parsed_node) cond_else = astutils.unparse(astutils.negate_expr(parsed_node)) - return cond, cond_else + return cond, cond_else, test_region def visit_While(self, node: ast.While): - # Get loop condition expression - begin_guard = self._add_state("while_guard") - loop_cond, _ = self._visit_test(node.test) - end_guard = self.last_state + # Get loop condition expression and create the necessary states for it. + loop_cond, _, test_region = self._visit_test(node.test) + loop_region = self._add_loop_region(loop_cond, label=f'while_{node.lineno}', inverted=False) # Parse body - self.loop_idx += 1 - self.continue_states.append([]) - self.break_states.append([]) - laststate, first_loop_state, last_loop_state, _ = \ - self._recursive_visit(node.body, 'while', node.lineno) - end_loop_state = self.last_state - - assert (laststate == end_guard) + self._recursive_visit(node.body, f'while_{node.lineno}', node.lineno, parent=loop_region, + unconnected_last_block=False) + + if test_region is not None: + iter_end_blocks = set() + for n in loop_region.nodes(): + if isinstance(n, ContinueBlock): + # If it needs to be connected back to the test region, it does no longer need to be handled + # specially and thus is no longer a special continue state. Add an empty state and redirect the + # edges leading into the continue into it. + replacer_state = loop_region.add_state() + iter_end_blocks.add(replacer_state) + for ie in loop_region.in_edges(n): + loop_region.add_edge(ie.src, replacer_state, ie.data) + loop_region.remove_edge(ie) + loop_region.remove_node(n) + for inner_node in loop_region.nodes(): + if loop_region.out_degree(inner_node) == 0: + iter_end_blocks.add(inner_node) + + test_region_copy = copy.deepcopy(test_region) + loop_region.add_node(test_region_copy) + + # Make sure the entire sub-graph of the test_region copy has proper sdfg references and that each block has + # a unique name in the SDFG. + loop_region.sdfg._labels = set(s.label for s in loop_region.sdfg.all_control_flow_blocks()) + for block in test_region_copy.all_control_flow_blocks(): + block.sdfg = loop_region.sdfg + block.label = data.find_new_name(block.label, loop_region.sdfg._labels) + + for block in iter_end_blocks: + loop_region.add_edge(block, test_region_copy, dace.InterstateEdge()) # Add symbols from test as necessary symcond = pystr_to_symbolic(loop_cond) @@ -2439,24 +2496,6 @@ def visit_While(self, node: ast.While): if (astr not in self.sdfg.symbols and astr not in self.variables): self.sdfg.add_symbol(astr, atom.dtype) - # Add loop to SDFG - _, loop_guard, loop_end = self.sdfg.add_loop(laststate, first_loop_state, end_loop_state, None, None, loop_cond, - None, last_loop_state) - - # Connect the correct while-guard state - # Current state: - # begin_guard -> ... -> end_guard/laststate -> loop_guard -> first_loop - # Desired state: - # begin_guard -> ... -> end_guard/laststate -> first_loop - for e in list(self.sdfg.in_edges(loop_guard)): - if e.src != laststate: - self.sdfg.add_edge(e.src, begin_guard, e.data) - self.sdfg.remove_edge(e) - for e in list(self.sdfg.out_edges(loop_guard)): - self.sdfg.add_edge(end_guard, e.dst, e.data) - self.sdfg.remove_edge(e) - self.sdfg.remove_node(loop_guard) - # Handle else clause if node.orelse: # Continue visiting body @@ -2464,80 +2503,83 @@ def visit_While(self, node: ast.While): self.visit(stmt) # The state that all "break" edges go to - loop_end = self._add_state(f'postwhile_{node.lineno}') - - body_states = list( - sdutil.dfs_conditional(self.sdfg, sources=[first_loop_state], condition=lambda p, c: c is not loop_guard)) - - continue_states = self.continue_states.pop() - while continue_states: - next_state = continue_states.pop() - out_edges = self.sdfg.out_edges(next_state) - for e in out_edges: - self.sdfg.remove_edge(e) - self.sdfg.add_edge(next_state, begin_guard, dace.InterstateEdge()) - break_states = self.break_states.pop() - while break_states: - next_state = break_states.pop() - out_edges = self.sdfg.out_edges(next_state) - for e in out_edges: - self.sdfg.remove_edge(e) - self.sdfg.add_edge(next_state, loop_end, dace.InterstateEdge()) - self.loop_idx -= 1 - - for state in body_states: - if not nx.has_path(self.sdfg.nx, end_guard, state): - self.sdfg.remove_node(state) + self._add_state(f'postwhile_{node.lineno}') + + postloop_block = self.last_block + self._generate_orelse(loop_region, postloop_block) + + self.last_block = loop_region + + def _generate_orelse(self, loop_region: LoopRegion, postloop_block: ControlFlowBlock): + did_break_symbol = 'did_break_' + loop_region.label + self.sdfg.add_symbol(did_break_symbol, dace.int32) + for n in loop_region.nodes(): + if isinstance(n, BreakBlock): + for iedge in loop_region.in_edges(n): + iedge.data.assignments[did_break_symbol] = '1' + for iedge in self.cfg_target.in_edges(loop_region): + iedge.data.assignments[did_break_symbol] = '0' + oedges = self.cfg_target.out_edges(loop_region) + if len(oedges) > 1: + raise DaceSyntaxError('Multiple exits to a loop with for-else syntax') + + intermediate = self.cfg_target.add_state(f'{loop_region.label}_normal_exit') + self.cfg_target.add_edge(loop_region, intermediate, + dace.InterstateEdge(condition=f"(not {did_break_symbol} == 1)")) + oedge = oedges[0] + self.cfg_target.add_edge(intermediate, oedge.dst, copy.deepcopy(oedge.data)) + self.cfg_target.remove_edge(oedge) + self.cfg_target.add_edge(loop_region, postloop_block, dace.InterstateEdge(condition=f"{did_break_symbol} == 1")) def visit_Break(self, node: ast.Break): - if self.loop_idx < 0: - error_msg = "'break' is only supported inside for and while loops " + if isinstance(self.cfg_target, LoopRegion): + self._on_block_added(self.cfg_target.add_break(f'break_{self.cfg_target.label}_{node.lineno}')) + else: + error_msg = "'break' is only supported inside loops " if self.nested: - error_msg += ("('break' is not supported in Maps and cannot be " - " used in nested DaCe program calls to break out " - " of loops of outer scopes)") + error_msg += ("('break' is not supported in Maps and cannot be used in nested DaCe program calls to " + " break out of loops of outer scopes)") raise DaceSyntaxError(self, node, error_msg) - self.break_states[self.loop_idx].append(self.last_state) def visit_Continue(self, node: ast.Continue): - if self.loop_idx < 0: - error_msg = ("'continue' is only supported inside for and while loops ") + if isinstance(self.cfg_target, LoopRegion): + self._on_block_added(self.cfg_target.add_continue(f'continue_{self.cfg_target.label}_{node.lineno}')) + else: + error_msg = ("'continue' is only supported inside loops ") if self.nested: - error_msg += ("('continue' is not supported in Maps and cannot " - " be used in nested DaCe program calls to " + error_msg += ("('continue' is not supported in Maps and cannot be used in nested DaCe program calls to " " continue loops of outer scopes)") raise DaceSyntaxError(self, node, error_msg) - self.continue_states[self.loop_idx].append(self.last_state) def visit_If(self, node: ast.If): # Add a guard state self._add_state('if_guard') - self.last_state.debuginfo = self.current_lineinfo + self.last_block.debuginfo = self.current_lineinfo # Generate conditions - cond, cond_else = self._visit_test(node.test) + cond, cond_else, _ = self._visit_test(node.test) # Visit recursively laststate, first_if_state, last_if_state, return_stmt = \ - self._recursive_visit(node.body, 'if', node.lineno) - end_if_state = self.last_state + self._recursive_visit(node.body, 'if', node.lineno, self.cfg_target, True) + end_if_state = self.last_block # Connect the states - self.sdfg.add_edge(laststate, first_if_state, dace.InterstateEdge(cond)) - self.sdfg.add_edge(last_if_state, end_if_state, dace.InterstateEdge(condition=f"{not return_stmt}")) + self.cfg_target.add_edge(laststate, first_if_state, dace.InterstateEdge(cond)) + self.cfg_target.add_edge(last_if_state, end_if_state, dace.InterstateEdge(condition=f"{not return_stmt}")) # Process 'else'/'elif' statements if len(node.orelse) > 0: # Visit recursively _, first_else_state, last_else_state, return_stmt = \ - self._recursive_visit(node.orelse, 'else', node.lineno, False) + self._recursive_visit(node.orelse, 'else', node.lineno, self.cfg_target, False) # Connect the states - self.sdfg.add_edge(laststate, first_else_state, dace.InterstateEdge(cond_else)) - self.sdfg.add_edge(last_else_state, end_if_state, dace.InterstateEdge(condition=f"{not return_stmt}")) - self.last_state = end_if_state + self.cfg_target.add_edge(laststate, first_else_state, dace.InterstateEdge(cond_else)) + self.cfg_target.add_edge(last_else_state, end_if_state, dace.InterstateEdge(condition=f"{not return_stmt}")) else: - self.sdfg.add_edge(laststate, end_if_state, dace.InterstateEdge(cond_else)) + self.cfg_target.add_edge(laststate, end_if_state, dace.InterstateEdge(cond_else)) + self.last_block = end_if_state def _parse_tasklet(self, state: SDFGState, node: TaskletType, name=None): @@ -3133,7 +3175,7 @@ def _add_access( inner_indices = set(non_squeezed) - state = self.last_state + state = self.current_state new_memlet = None if has_indirection: @@ -3443,9 +3485,9 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False): view = self.sdfg.arrays[result] cname, carr = self.sdfg.add_transient(result, view.shape, view.dtype, find_new_name=True) self._add_state(f'copy_from_view_{node.lineno}') - rnode = self.last_state.add_read(result, debuginfo=self.current_lineinfo) - wnode = self.last_state.add_read(cname, debuginfo=self.current_lineinfo) - self.last_state.add_nedge(rnode, wnode, Memlet.from_array(cname, carr)) + rnode = self.current_state.add_read(result, debuginfo=self.current_lineinfo) + wnode = self.current_state.add_read(cname, debuginfo=self.current_lineinfo) + self.current_state.add_nedge(rnode, wnode, Memlet.from_array(cname, carr)) result = cname # Strict independent access check for augmented assignments @@ -3466,7 +3508,7 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False): # Handle output indirection output_indirection = None if _subset_has_indirection(rng, self): - output_indirection = self.sdfg.add_state('wslice_%s_%d' % (new_name, node.lineno)) + output_indirection = self.cfg_target.add_state('wslice_%s_%d' % (new_name, node.lineno)) wnode = output_indirection.add_write(new_name, debuginfo=self.current_lineinfo) memlet = Memlet.simple(new_name, str(rng)) # Dependent augmented assignments need WCR in the @@ -3496,10 +3538,10 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False): if op and independent: if _subset_has_indirection(rng, self): self._add_state('rslice_%s_%d' % (new_name, node.lineno)) - rnode = self.last_state.add_read(new_name, debuginfo=self.current_lineinfo) + rnode = self.current_state.add_read(new_name, debuginfo=self.current_lineinfo) memlet = Memlet.simple(new_name, str(rng)) tmp = self.sdfg.temp_data_name() - ind_name = add_indirection_subgraph(self.sdfg, self.last_state, rnode, None, memlet, tmp, self) + ind_name = add_indirection_subgraph(self.sdfg, self.current_state, rnode, None, memlet, tmp, self) rtarget = ind_name else: rtarget = (new_name, new_rng) @@ -3512,8 +3554,8 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False): # Connect states properly when there is output indirection if output_indirection: - self.sdfg.add_edge(self.last_state, output_indirection, dace.sdfg.InterstateEdge()) - self.last_state = output_indirection + self.cfg_target.add_edge(self.last_block, output_indirection, dace.sdfg.InterstateEdge()) + self.last_block = output_indirection def visit_AugAssign(self, node: ast.AugAssign): self._visit_assign(node, node.target, augassign_ops[type(node.op).__name__]) @@ -3929,7 +3971,7 @@ def _parse_sdfg_call(self, funcname: str, func: Union[SDFG, SDFGConvertible], no output_slices = set() for arg in itertools.chain(node.args, [kw.value for kw in node.keywords]): if isinstance(arg, ast.Subscript): - slice_state = self.last_state + slice_state = self.current_state break # Make sure that any scope vars in the arguments are substituted @@ -3956,8 +3998,8 @@ def _parse_sdfg_call(self, funcname: str, func: Union[SDFG, SDFGConvertible], no for sym, local in mapping.items(): if isinstance(local, str) and local in self.sdfg.arrays: # Add assignment state and inter-state edge - symassign_state = self.sdfg.add_state_before(state) - isedge = self.sdfg.edges_between(symassign_state, state)[0] + symassign_state = self.cfg_target.add_state_before(state) + isedge = self.cfg_target.edges_between(symassign_state, state)[0] newsym = self.sdfg.find_new_symbol(f'sym_{local}') desc = self.sdfg.arrays[local] self.sdfg.add_symbol(newsym, desc.dtype) @@ -4021,7 +4063,7 @@ def _parse_sdfg_call(self, funcname: str, func: Union[SDFG, SDFGConvertible], no # Delete the old read descriptor if not isinput: conn_used = False - for s in self.sdfg.nodes(): + for s in self.sdfg.states(): for n in s.data_nodes(): if n.data == aname: conn_used = True @@ -4335,11 +4377,11 @@ def parse_target(t: Union[ast.Name, ast.Subscript]): # Create a state with a tasklet and the right arguments self._add_state('callback_%d' % node.lineno) - self.last_state.set_default_lineinfo(self.current_lineinfo) + self.last_block.set_default_lineinfo(self.current_lineinfo) if callback_type.is_scalar_function() and len(callback_type.return_types) > 0: call_args = ', '.join(str(s) for s in allargs[:-1]) - tasklet = self.last_state.add_tasklet(f'callback_{node.lineno}', {f'__in_{name}' + tasklet = self.last_block.add_tasklet(f'callback_{node.lineno}', {f'__in_{name}' for name in args} | {'__istate'}, {f'__out_{name}' for name in outargs} | {'__ostate'}, @@ -4347,7 +4389,7 @@ def parse_target(t: Union[ast.Name, ast.Subscript]): side_effects=True) else: call_args = ', '.join(str(s) for s in allargs) - tasklet = self.last_state.add_tasklet(f'callback_{node.lineno}', {f'__in_{name}' + tasklet = self.last_block.add_tasklet(f'callback_{node.lineno}', {f'__in_{name}' for name in args} | {'__istate'}, {f'__out_{name}' for name in outargs} | {'__ostate'}, @@ -4361,15 +4403,15 @@ def parse_target(t: Union[ast.Name, ast.Subscript]): # Setup arguments in graph for arg in dtypes.deduplicate(args): - r = self.last_state.add_read(arg) - self.last_state.add_edge(r, None, tasklet, f'__in_{arg}', Memlet(arg)) + r = self.current_state.add_read(arg) + self.current_state.add_edge(r, None, tasklet, f'__in_{arg}', Memlet(arg)) for arg in dtypes.deduplicate(outargs): - w = self.last_state.add_write(arg) - self.last_state.add_edge(tasklet, f'__out_{arg}', w, None, Memlet(arg)) + w = self.current_state.add_write(arg) + self.current_state.add_edge(tasklet, f'__out_{arg}', w, None, Memlet(arg)) # Connect Python state - self._connect_pystate(tasklet, self.last_state, '__istate', '__ostate') + self._connect_pystate(tasklet, self.current_state, '__istate', '__ostate') if return_type is None: return [] @@ -4555,17 +4597,18 @@ def visit_Call(self, node: ast.Call, create_callbacks=False): keywords = {arg.arg: self._parse_function_arg(arg.value) for arg in node.keywords} self._add_state('call_%d' % node.lineno) - self.last_state.set_default_lineinfo(self.current_lineinfo) + self.last_block.set_default_lineinfo(self.current_lineinfo) if found_ufunc: - result = func(self, node, self.sdfg, self.last_state, ufunc_name, args, keywords) + result = func(self, node, self.sdfg, self.last_block, ufunc_name, args, keywords) else: - result = func(self, self.sdfg, self.last_state, *args, **keywords) + result = func(self, self.sdfg, self.last_block, *args, **keywords) - self.last_state.set_default_lineinfo(None) + self.last_block.set_default_lineinfo(None) if isinstance(result, tuple) and type(result[0]) is nested_call.NestedCall: - self.last_state = result[0].last_state + nc: nested_call.NestedCall = result[0] + self.last_block = nc.last_state result = result[1] if not isinstance(result, (tuple, list)): @@ -4645,6 +4688,10 @@ def visit_Return(self, node: ast.Return): ast_name = ast.copy_location(ast.Name(id='__return'), node) self._visit_assign(new_node, ast_name, None, is_return=True) + if not isinstance(self.cfg_target, SDFG): + # In a nested control flow region, a return needs to be explicitly marked with a return block. + self._on_block_added(self.cfg_target.add_return(f'return_{self.cfg_target.label}_{node.lineno}')) + def visit_With(self, node, is_async=False): # "with dace.tasklet" syntax if len(node.items) == 1: @@ -4768,9 +4815,9 @@ def visit_Attribute(self, node: ast.Attribute): if func is not None: # A new state is likely needed here, e.g., for transposition (ndarray.T) self._add_state('%s_%d' % (type(node).__name__, node.lineno)) - self.last_state.set_default_lineinfo(self.current_lineinfo) - result = func(self, self.sdfg, self.last_state, result) - self.last_state.set_default_lineinfo(None) + self.last_block.set_default_lineinfo(self.current_lineinfo) + result = func(self, self.sdfg, self.last_block, result) + self.last_block.set_default_lineinfo(None) return result # Otherwise, try to find compile-time attribute (such as shape) @@ -4879,9 +4926,9 @@ def _visit_op(self, node: Union[ast.UnaryOp, ast.BinOp, ast.BoolOp], op1: ast.AS raise DaceSyntaxError(self, node, f'Operator {opname} is not defined for types {op1name} and {op2name}') self._add_state('%s_%d' % (type(node).__name__, node.lineno)) - self.last_state.set_default_lineinfo(self.current_lineinfo) + self.last_block.set_default_lineinfo(self.current_lineinfo) try: - result = func(self, self.sdfg, self.last_state, operand1, operand2) + result = func(self, self.sdfg, self.last_block, operand1, operand2) except SyntaxError as ex: raise DaceSyntaxError(self, node, str(ex)) if not isinstance(result, (list, tuple)): @@ -4894,7 +4941,7 @@ def _visit_op(self, node: Union[ast.UnaryOp, ast.BinOp, ast.BoolOp], op1: ast.AS raise DaceSyntaxError(self, node, "Variable {v} has been already defined".format(v=r)) self.variables[r] = r - self.last_state.set_default_lineinfo(None) + self.last_block.set_default_lineinfo(None) return result @@ -4938,7 +4985,7 @@ def _add_read_slice(self, array: str, node: ast.Subscript, expr: MemletExpr): self._add_state('slice_%s_%d' % (array.replace('.', '_'), node.lineno)) if has_array_indirection: # Make copy slicing state - rnode = self.last_state.add_read(array, debuginfo=self.current_lineinfo) + rnode = self.current_state.add_read(array, debuginfo=self.current_lineinfo) return self._array_indirection_subgraph(rnode, expr) else: is_index = False @@ -4982,11 +5029,11 @@ def _add_read_slice(self, array: str, node: ast.Subscript, expr: MemletExpr): wcr=expr.wcr)) self.variables[tmp] = tmp if not isinstance(tmparr, data.View): - rnode = self.last_state.add_read(array, debuginfo=self.current_lineinfo) - wnode = self.last_state.add_write(tmp, debuginfo=self.current_lineinfo) + rnode = self.current_state.add_read(array, debuginfo=self.current_lineinfo) + wnode = self.current_state.add_write(tmp, debuginfo=self.current_lineinfo) # NOTE: We convert the subsets to string because keeping the original symbolic information causes # equality check failures, e.g., in LoopToMap. - self.last_state.add_nedge( + self.current_state.add_nedge( rnode, wnode, Memlet(data=array, subset=str(expr.subset), @@ -5024,7 +5071,7 @@ def _promote(node: ast.AST) -> Union[Any, str, symbolic.symbol]: # `not sym` returns True. This exception is benign. pass state = self._add_state(f'promote_{scalar}_to_{str(sym)}') - edge = self.sdfg.in_edges(state)[0] + edge = state.parent_graph.in_edges(state)[0] edge.data.assignments = {str(sym): scalar} return sym return scalar @@ -5213,17 +5260,17 @@ def make_slice(self, arrname: str, rng: subsets.Range): # Add slicing state # TODO: naming issue, we don't have the linenumber here self._add_state('slice_%s' % (array)) - rnode = self.last_state.add_read(array, debuginfo=self.current_lineinfo) + rnode = self.current_state.add_read(array, debuginfo=self.current_lineinfo) other_subset = copy.deepcopy(rng) other_subset.squeeze() if _subset_has_indirection(rng, self): memlet = Memlet.simple(array, rng) tmp = self.sdfg.temp_data_name() - tmp = add_indirection_subgraph(self.sdfg, self.last_state, rnode, None, memlet, tmp, self) + tmp = add_indirection_subgraph(self.sdfg, self.current_state, rnode, None, memlet, tmp, self) else: tmp, tmparr = self.sdfg.add_temp_transient(other_subset.size(), arrobj.dtype, arrobj.storage) - wnode = self.last_state.add_write(tmp, debuginfo=self.current_lineinfo) - self.last_state.add_nedge( + wnode = self.current_state.add_write(tmp, debuginfo=self.current_lineinfo) + self.current_state.add_nedge( rnode, wnode, Memlet.simple(array, rng, num_accesses=rng.num_elements(), other_subset_str=other_subset)) return tmp, other_subset @@ -5292,7 +5339,7 @@ def _array_indirection_subgraph(self, rnode: nodes.AccessNode, expr: MemletExpr) # output shape dimensions are len(output_shape) # Make map with output shape - state: SDFGState = self.last_state + state = self.current_state wnode = state.add_write(outname) maprange = [(f'__i{i}', f'0:{s}') for i, s in enumerate(output_shape)] me, mx = state.add_map('indirect_slice', maprange, debuginfo=self.current_lineinfo) diff --git a/dace/frontend/python/parser.py b/dace/frontend/python/parser.py index 34cb8fb4ad..e55829933c 100644 --- a/dace/frontend/python/parser.py +++ b/dace/frontend/python/parser.py @@ -13,7 +13,7 @@ from dace import data, dtypes, hooks, symbolic from dace.config import Config from dace.frontend.python import (newast, common as pycommon, cached_program, preprocessing) -from dace.sdfg import SDFG +from dace.sdfg import SDFG, utils as sdutils from dace.data import create_datadescriptor, Data try: @@ -152,7 +152,8 @@ def __init__(self, regenerate_code: bool = True, recompile: bool = True, distributed_compilation: bool = False, - method: bool = False): + method: bool = False, + use_experimental_cfg_blocks: bool = False): from dace.codegen import compiled_sdfg # Avoid import loops self.f = f @@ -172,6 +173,7 @@ def __init__(self, self.recreate_sdfg = recreate_sdfg self.regenerate_code = regenerate_code self.recompile = recompile + self.use_experimental_cfg_blocks = use_experimental_cfg_blocks self.distributed_compilation = distributed_compilation self.global_vars = _get_locals_and_globals(f) @@ -491,6 +493,11 @@ def _parse(self, args, kwargs, simplify=None, save=False, validate=False) -> SDF # Obtain DaCe program as SDFG sdfg, cached = self._generate_pdp(args, kwargs, simplify=simplify) + if not self.use_experimental_cfg_blocks: + sdutils.inline_loop_blocks(sdfg) + sdutils.inline_control_flow_regions(sdfg) + sdfg.using_experimental_blocks = self.use_experimental_cfg_blocks + # Apply simplification pass automatically if not cached and (simplify == True or (simplify is None and Config.get_bool('optimizer', 'automatic_simplification'))): @@ -801,7 +808,8 @@ def get_program_hash(self, *args, **kwargs) -> cached_program.ProgramCacheKey: _, key = self._load_sdfg(None, *args, **kwargs) return key - def _generate_pdp(self, args: Tuple[Any], kwargs: Dict[str, Any], simplify: Optional[bool] = None) -> SDFG: + def _generate_pdp(self, args: Tuple[Any], kwargs: Dict[str, Any], + simplify: Optional[bool] = None) -> Tuple[SDFG, bool]: """ Generates the parsed AST representation of a DaCe program. :param args: The given arguments to the program. diff --git a/dace/frontend/python/preprocessing.py b/dace/frontend/python/preprocessing.py index 420346ca88..bb2c70f6c0 100644 --- a/dace/frontend/python/preprocessing.py +++ b/dace/frontend/python/preprocessing.py @@ -935,6 +935,9 @@ def _add_exits(self, until_loop_end: bool, only_one: bool = False) -> List[ast.A for stmt in reversed(self.with_statements): if until_loop_end and not isinstance(stmt, (ast.With, ast.AsyncWith)): break + elif not until_loop_end and isinstance(stmt, (ast.For, ast.While)): + break + for mgrname, mgr in reversed(self.context_managers[stmt]): # Call __exit__ (without exception management all three arguments are set to None) exit_call = ast.copy_location(ast.parse(f'{mgrname}.__exit__(None, None, None)').body[0], stmt) diff --git a/dace/frontend/python/replacements.py b/dace/frontend/python/replacements.py index 8bca373b02..8c123f6bfe 100644 --- a/dace/frontend/python/replacements.py +++ b/dace/frontend/python/replacements.py @@ -8,7 +8,7 @@ import warnings from functools import reduce from numbers import Number, Integral -from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union, TYPE_CHECKING import dace from dace.codegen.tools import type_inference @@ -28,7 +28,10 @@ Size = Union[int, dace.symbolic.symbol] Shape = Sequence[Size] -ProgramVisitor = 'dace.frontend.python.newast.ProgramVisitor' +if TYPE_CHECKING: + from dace.frontend.python.newast import ProgramVisitor +else: + ProgramVisitor = 'dace.frontend.python.newast.ProgramVisitor' def normalize_axes(axes: Tuple[int], max_dim: int) -> List[int]: @@ -971,8 +974,8 @@ def _pymax(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, a: Union[str, Numbe for i, b in enumerate(args): if i > 0: pv._add_state('__min2_%d' % i) - pv.last_state.set_default_lineinfo(pv.current_lineinfo) - current_state = pv.last_state + pv.last_block.set_default_lineinfo(pv.current_lineinfo) + current_state = pv.last_block left_arg = _minmax2(pv, sdfg, current_state, left_arg, b, ismin=False) return left_arg @@ -986,8 +989,8 @@ def _pymin(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, a: Union[str, Numbe for i, b in enumerate(args): if i > 0: pv._add_state('__min2_%d' % i) - pv.last_state.set_default_lineinfo(pv.current_lineinfo) - current_state = pv.last_state + pv.last_block.set_default_lineinfo(pv.current_lineinfo) + current_state = pv.last_block left_arg = _minmax2(pv, sdfg, current_state, left_arg, b) return left_arg @@ -3355,7 +3358,7 @@ def _create_subgraph(visitor: ProgramVisitor, cond_state.add_nedge(r, w, dace.Memlet("{}[0]".format(r))) true_state = sdfg.add_state(label=cond_state.label + '_true') state = true_state - visitor.last_state = state + visitor.last_block = state cond = name cond_else = 'not ({})'.format(cond) sdfg.add_edge(cond_state, true_state, dace.InterstateEdge(cond)) @@ -3374,7 +3377,7 @@ def _create_subgraph(visitor: ProgramVisitor, dace.Memlet.from_array(arg, sdfg.arrays[arg])) if has_where and isinstance(where, str) and where in sdfg.arrays.keys(): visitor._add_state(label=cond_state.label + '_true') - sdfg.add_edge(cond_state, visitor.last_state, dace.InterstateEdge(cond_else)) + sdfg.add_edge(cond_state, visitor.last_block, dace.InterstateEdge(cond_else)) else: # Map needed if has_where: diff --git a/dace/sdfg/infer_types.py b/dace/sdfg/infer_types.py index 9a42203eed..cf58cf76cc 100644 --- a/dace/sdfg/infer_types.py +++ b/dace/sdfg/infer_types.py @@ -61,7 +61,7 @@ def infer_connector_types(sdfg: SDFG): :param sdfg: The SDFG to infer. """ # Loop over states, and in a topological sort over each state's nodes - for state in sdfg.nodes(): + for state in sdfg.states(): for node in dfs_topological_sort(state): # Try to infer input connector type from node type or previous edges for e in state.in_edges(node): @@ -168,7 +168,7 @@ def set_default_schedule_and_storage_types(scope: Union[SDFG, SDFGState, nodes.E if isinstance(scope, SDFG): # Set device for default top-level schedules and storages - for state in scope.nodes(): + for state in scope.states(): set_default_schedule_and_storage_types(state, parent_schedules, use_parent_schedule=use_parent_schedule, diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index b43ff2a7bf..82d98c1e18 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -30,7 +30,7 @@ from dace.frontend.python import astutils, wrappers from dace.sdfg import nodes as nd from dace.sdfg.graph import OrderedDiGraph, Edge, SubgraphView -from dace.sdfg.state import SDFGState, ControlFlowRegion +from dace.sdfg.state import ControlFlowBlock, SDFGState, ControlFlowRegion from dace.sdfg.propagation import propagate_memlets_sdfg from dace.distr_types import ProcessGrid, SubArray, RedistrArray from dace.dtypes import validate_name @@ -183,7 +183,7 @@ class InterstateEdge(object): desc="Assignments to perform upon transition (e.g., 'x=x+1; y = 0')") condition = CodeProperty(desc="Transition condition", default=CodeBlock("1")) - def __init__(self, condition: CodeBlock = None, assignments=None): + def __init__(self, condition: Optional[Union[CodeBlock, str, ast.AST, list]] = None, assignments=None): if condition is None: condition = CodeBlock("1") @@ -452,6 +452,9 @@ class SDFG(ControlFlowRegion): desc='Mapping between callback name and its original callback ' '(for when the same callback is used with a different signature)') + using_experimental_blocks = Property(dtype=bool, default=False, + desc="Whether the SDFG contains experimental control flow blocks") + def __init__(self, name: str, constants: Dict[str, Tuple[dt.Data, Any]] = None, @@ -509,6 +512,8 @@ def __init__(self, self._orig_name = name self._num = 0 + self._sdfg = self + def __deepcopy__(self, memo): cls = self.__class__ result = cls.__new__(cls) @@ -2220,6 +2225,7 @@ def compile(self, output_file=None, validate=True) -> 'CompiledSDFG': # Convert any loop constructs with hierarchical loop regions into simple 1-level state machine loops. # TODO (later): Adapt codegen to deal with hierarchical CFGs instead. sdutils.inline_loop_blocks(sdfg) + sdutils.inline_control_flow_regions(sdfg) # Rename SDFG to avoid runtime issues with clashing names index = 0 @@ -2680,3 +2686,15 @@ def make_array_memlet(self, array: str): :return: a Memlet that fully transfers array """ return dace.Memlet.from_array(array, self.data(array)) + + def recheck_using_experimental_blocks(self) -> bool: + found_experimental_block = False + for node, graph in self.root_sdfg.all_nodes_recursive(): + if isinstance(graph, ControlFlowRegion) and not isinstance(graph, SDFG): + found_experimental_block = True + break + if isinstance(node, ControlFlowBlock) and not isinstance(node, SDFGState): + found_experimental_block = True + break + self.root_sdfg.using_experimental_blocks = found_experimental_block + return found_experimental_block diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 429fbbd690..736a4799df 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Contains classes of a single SDFG state and dataflow subgraphs. """ import ast @@ -8,7 +8,8 @@ import inspect import itertools import warnings -from typing import TYPE_CHECKING, Any, AnyStr, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union, overload +from typing import (TYPE_CHECKING, Any, AnyStr, Callable, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union, + overload) import dace import dace.serialize @@ -19,7 +20,7 @@ from dace import subsets as sbs from dace import symbolic from dace.properties import (CodeBlock, DictProperty, EnumProperty, Property, SubsetProperty, SymbolicProperty, - CodeProperty, make_properties, SetProperty) + CodeProperty, make_properties) from dace.sdfg import nodes as nd from dace.sdfg.graph import MultiConnectorEdge, OrderedMultiDiConnectorGraph, SubgraphView, OrderedDiGraph, Edge from dace.sdfg.propagation import propagate_memlet @@ -30,7 +31,6 @@ import dace.sdfg.scope from dace.sdfg import SDFG - NodeT = Union[nd.Node, 'ControlFlowBlock'] EdgeT = Union[MultiConnectorEdge[mm.Memlet], Edge['dace.sdfg.InterstateEdge']] GraphT = Union['ControlFlowRegion', 'SDFGState'] @@ -80,7 +80,6 @@ class BlockGraphView(object): creation, queries, and replacements. ``ControlFlowBlock`` and ``StateSubgraphView`` inherit from this class to share methods. """ - ################################################################### # Typing overrides @@ -109,15 +108,21 @@ def sdfg(self) -> 'SDFG': # Traversal methods @abc.abstractmethod - def all_nodes_recursive(self) -> Iterator[Tuple[NodeT, GraphT]]: + def all_nodes_recursive( + self, + predicate: Optional[Callable[[NodeT, GraphT], bool]] = None) -> Iterator[Tuple[NodeT, GraphT]]: """ Iterate over all nodes in this graph or subgraph. This includes control flow blocks, nodes in those blocks, and recursive control flow blocks and nodes within nested SDFGs. It returns tuples of the form (node, parent), where the node is either a dataflow node, in which case the parent is an SDFG state, or a control flow block, in which case the parent is a control flow graph (i.e., an SDFG or a scope block). + + :param predicate: An optional predicate function that decides on whether the traversal should recurse or not. + If the predicate returns False, traversal is not recursed any further into the graph found under NodeT for + a given [NodeT, GraphT] pair. """ - raise NotImplementedError() + return [] @abc.abstractmethod def all_edges_recursive(self) -> Iterator[Tuple[EdgeT, GraphT]]: @@ -127,7 +132,7 @@ def all_edges_recursive(self) -> Iterator[Tuple[EdgeT, GraphT]]: the form (edge, parent), where the edge is either a dataflow edge, in which case the parent is an SDFG state, or an inter-stte edge, in which case the parent is a control flow graph (i.e., an SDFG or a scope block). """ - raise NotImplementedError() + return [] @abc.abstractmethod def data_nodes(self) -> List[nd.AccessNode]: @@ -135,17 +140,17 @@ def data_nodes(self) -> List[nd.AccessNode]: Returns all data nodes (i.e., AccessNodes, arrays) present in this graph or subgraph. Note: This does not recurse into nested SDFGs. """ - raise NotImplementedError() + return [] @abc.abstractmethod - def entry_node(self, node: nd.Node) -> nd.EntryNode: + def entry_node(self, node: nd.Node) -> Optional[nd.EntryNode]: """ Returns the entry node that wraps the current node, or None if it is top-level in a state. """ - raise NotImplementedError() + return None @abc.abstractmethod - def exit_node(self, entry_node: nd.EntryNode) -> nd.ExitNode: + def exit_node(self, entry_node: nd.EntryNode) -> Optional[nd.ExitNode]: """ Returns the exit node leaving the context opened by the given entry node. """ - raise NotImplementedError() + raise None ################################################################### # Memlet-tracking methods @@ -208,7 +213,7 @@ def edges_by_connector(self, node: nd.Node, connector: AnyStr) -> Iterable[Multi # Query, subgraph, and replacement methods @abc.abstractmethod - def used_symbols(self, all_symbols: bool, keep_defined_in_mapping: bool=False) -> Set[str]: + def used_symbols(self, all_symbols: bool, keep_defined_in_mapping: bool = False) -> Set[str]: """ Returns a set of symbol names that are used in the graph. @@ -216,8 +221,8 @@ def used_symbols(self, all_symbols: bool, keep_defined_in_mapping: bool=False) - :param keep_defined_in_mapping: If True, symbols defined in inter-state edges that are in the symbol mapping will be removed from the set of defined symbols. """ - raise NotImplementedError() - + return set() + @property def free_symbols(self) -> Set[str]: """ @@ -237,13 +242,13 @@ def read_and_write_sets(self) -> Tuple[Set[AnyStr], Set[AnyStr]]: :return: A two-tuple of sets of things denoting ({data read}, {data written}). """ - raise NotImplementedError() + return set(), set() @abc.abstractmethod def unordered_arglist(self, defined_syms=None, shared_transients=None) -> Tuple[Dict[str, dt.Data], Dict[str, dt.Data]]: - raise NotImplementedError() + return {}, {} def arglist(self, defined_syms=None, shared_transients=None) -> Dict[str, dt.Data]: """ @@ -288,12 +293,12 @@ def signature_arglist(self, with_types=True, for_call=False): @abc.abstractmethod def top_level_transients(self) -> Set[str]: """Iterate over top-level transients of this graph.""" - raise NotImplementedError() + return set() @abc.abstractmethod def all_transients(self) -> List[str]: """Iterate over all transients in this graph.""" - raise NotImplementedError() + return [] @abc.abstractmethod def replace(self, name: str, new_name: str): @@ -303,7 +308,7 @@ def replace(self, name: str, new_name: str): :param name: Name to find. :param new_name: Name to replace. """ - raise NotImplementedError() + pass @abc.abstractmethod def replace_dict(self, @@ -315,7 +320,7 @@ def replace_dict(self, :param repl: Mapping from names to replacements. :param symrepl: Optional symbolic version of ``repl``. """ - raise NotImplementedError() + pass @make_properties @@ -338,11 +343,12 @@ def edges(self) -> List[MultiConnectorEdge[mm.Memlet]]: ################################################################### # Traversal methods - def all_nodes_recursive(self) -> Iterator[Tuple[NodeT, GraphT]]: + def all_nodes_recursive(self, predicate = None) -> Iterator[Tuple[NodeT, GraphT]]: for node in self.nodes(): yield node, self if isinstance(node, nd.NestedSDFG): - yield from node.sdfg.all_nodes_recursive() + if predicate is None or predicate(node, self): + yield from node.sdfg.all_nodes_recursive() def all_edges_recursive(self) -> Iterator[Tuple[EdgeT, GraphT]]: for e in self.edges(): @@ -637,7 +643,7 @@ def is_leaf_memlet(self, e): return False return True - def used_symbols(self, all_symbols: bool, keep_defined_in_mapping: bool=False) -> Set[str]: + def used_symbols(self, all_symbols: bool, keep_defined_in_mapping: bool = False) -> Set[str]: state = self.graph if isinstance(self, SubgraphView) else self sdfg = state.sdfg new_symbols = set() @@ -955,10 +961,11 @@ def edges(self) -> List[Edge['dace.sdfg.InterstateEdge']]: ################################################################### # Traversal methods - def all_nodes_recursive(self) -> Iterator[Tuple[NodeT, GraphT]]: + def all_nodes_recursive(self, predicate = None) -> Iterator[Tuple[NodeT, GraphT]]: for node in self.nodes(): yield node, self - yield from node.all_nodes_recursive() + if predicate is None or predicate(node, self): + yield from node.all_nodes_recursive() def all_edges_recursive(self) -> Iterator[Tuple[EdgeT, GraphT]]: for e in self.edges(): @@ -1028,7 +1035,7 @@ def _used_symbols_internal(self, keep_defined_in_mapping: bool = False) -> Tuple[Set[str], Set[str], Set[str]]: raise NotImplementedError() - def used_symbols(self, all_symbols: bool, keep_defined_in_mapping: bool=False) -> Set[str]: + def used_symbols(self, all_symbols: bool, keep_defined_in_mapping: bool = False) -> Set[str]: return self._used_symbols_internal(all_symbols, keep_defined_in_mapping=keep_defined_in_mapping)[0] def read_and_write_sets(self) -> Tuple[Set[AnyStr], Set[AnyStr]]: @@ -1072,7 +1079,8 @@ def replace(self, name: str, new_name: str): def replace_dict(self, repl: Dict[str, str], symrepl: Optional[Dict[symbolic.SymbolicType, symbolic.SymbolicType]] = None, - replace_in_graph: bool = True, replace_keys: bool = False): + replace_in_graph: bool = True, + replace_keys: bool = False): symrepl = symrepl or { symbolic.symbol(k): symbolic.pystr_to_symbolic(v) if isinstance(k, str) else v for k, v in repl.items() @@ -1087,6 +1095,7 @@ def replace_dict(self, for state in self.nodes(): state.replace_dict(repl, symrepl) + @make_properties class ControlFlowBlock(BlockGraphView, abc.ABC): @@ -1098,10 +1107,7 @@ class ControlFlowBlock(BlockGraphView, abc.ABC): _label: str - def __init__(self, - label: str='', - sdfg: Optional['SDFG'] = None, - parent: Optional['ControlFlowRegion'] = None): + def __init__(self, label: str = '', sdfg: Optional['SDFG'] = None, parent: Optional['ControlFlowRegion'] = None): super(ControlFlowBlock, self).__init__() self._label = label self._default_lineinfo = None @@ -1112,6 +1118,12 @@ def __init__(self, self.post_conditions = {} self.invariant_conditions = {} + def nodes(self): + return [] + + def edges(self): + return [] + def set_default_lineinfo(self, lineinfo: dace.dtypes.DebugInfo): """ Sets the default source line information to be lineinfo, or None to @@ -1134,6 +1146,23 @@ def __str__(self): def __repr__(self) -> str: return f'ControlFlowBlock ({self.label})' + def __deepcopy__(self, memo): + cls = self.__class__ + result = cls.__new__(cls) + memo[id(self)] = result + for k, v in self.__dict__.items(): + if k in ('_parent_graph', '_sdfg'): # Skip derivative attributes + continue + setattr(result, k, copy.deepcopy(v, memo)) + + for k in ('_parent_graph', '_sdfg'): + if id(getattr(self, k)) in memo: + setattr(result, k, memo[id(getattr(self, k))]) + else: + setattr(result, k, None) + + return result + @property def label(self) -> str: return self._label @@ -1209,7 +1238,6 @@ def __init__(self, label=None, sdfg=None, debuginfo=None, location=None): :param sdfg: A reference to the parent SDFG. :param debuginfo: Source code locator for debugging. """ - from dace.sdfg.sdfg import SDFG # Avoid import loop OrderedMultiDiConnectorGraph.__init__(self) ControlFlowBlock.__init__(self, label, sdfg) super(SDFGState, self).__init__() @@ -1221,31 +1249,6 @@ def __init__(self, label=None, sdfg=None, debuginfo=None, location=None): self.location = location if location is not None else {} self._default_lineinfo = None - def __deepcopy__(self, memo): - cls = self.__class__ - result = cls.__new__(cls) - memo[id(self)] = result - for k, v in self.__dict__.items(): - if k in ('_parent_graph', '_sdfg'): # Skip derivative attributes - continue - setattr(result, k, copy.deepcopy(v, memo)) - - for k in ('_parent_graph', '_sdfg'): - if id(getattr(self, k)) in memo: - setattr(result, k, memo[id(getattr(self, k))]) - else: - setattr(result, k, None) - - for node in result.nodes(): - if isinstance(node, nd.NestedSDFG): - try: - node.sdfg.parent = result - except AttributeError: - # NOTE: There are cases where a NestedSDFG does not have `sdfg` attribute. - # TODO: Investigate why this happens. - pass - return result - @property def parent(self): """ Returns the parent SDFG of this state. """ @@ -1410,6 +1413,19 @@ def _repr_html_(self): return sdfg._repr_html_() + def __deepcopy__(self, memo): + result: SDFGState = ControlFlowBlock.__deepcopy__(self, memo) + + for node in result.nodes(): + if isinstance(node, nd.NestedSDFG): + try: + node.sdfg.parent = result + except AttributeError: + # NOTE: There are cases where a NestedSDFG does not have `sdfg` attribute. + # TODO: Investigate why this happens. + pass + return result + def symbols_defined_at(self, node: nd.Node) -> Dict[str, dtypes.typeclass]: """ Returns all symbols available to a given node. @@ -2378,6 +2394,27 @@ def fill_scope_connectors(self): node.add_in_connector(edge.dst_conn) +class ContinueBlock(ControlFlowBlock): + """ Special control flow block to represent a continue inside of loops. """ + + def __repr__(self): + return f'ContinueBlock ({self.label})' + + +class BreakBlock(ControlFlowBlock): + """ Special control flow block to represent a continue inside of loops or switch / select blocks. """ + + def __repr__(self): + return f'BreakBlock ({self.label})' + + +class ReturnBlock(ControlFlowBlock): + """ Special control flow block to represent an early return out of the SDFG or a nested procedure / SDFG. """ + + def __repr__(self): + return f'ReturnBlock ({self.label})' + + class StateSubgraphView(SubgraphView, DataflowGraphView): """ A read-only subgraph view of an SDFG state. """ @@ -2394,7 +2431,7 @@ def sdfg(self) -> 'SDFG': class ControlFlowRegion(OrderedDiGraph[ControlFlowBlock, 'dace.sdfg.InterstateEdge'], ControlGraphView, ControlFlowBlock): - def __init__(self, label: str='', sdfg: Optional['SDFG'] = None): + def __init__(self, label: str = '', sdfg: Optional['SDFG'] = None): OrderedDiGraph.__init__(self) ControlGraphView.__init__(self) ControlFlowBlock.__init__(self, label, sdfg) @@ -2404,6 +2441,13 @@ def __init__(self, label: str='', sdfg: Optional['SDFG'] = None): self._cached_start_block: Optional[ControlFlowBlock] = None self._cfg_list: List['ControlFlowRegion'] = [self] + @property + def root_sdfg(self) -> 'SDFG': + from dace.sdfg.sdfg import SDFG # Avoid import loop + if not isinstance(self.cfg_list[0], SDFG): + raise RuntimeError('Root CFG is not of type SDFG') + return self.cfg_list[0] + def reset_cfg_list(self) -> List['ControlFlowRegion']: """ Reset the CFG list when changes have been made to the SDFG's CFG tree. @@ -2448,6 +2492,65 @@ def update_cfg_list(self, cfg_list): else: self._cfg_list = sub_cfg_list + def inline(self) -> Tuple[bool, Any]: + """ + Inlines the control flow region into its parent control flow region (if it exists). + + :return: True if the inlining succeeded, false otherwise. + """ + parent = self.parent_graph + if parent: + end_state = parent.add_state(self.label + '_end') + + # Add all region states and make sure to keep track of all the ones that need to be connected in the end. + to_connect: Set[SDFGState] = set() + block_to_state_map: Dict[ControlFlowBlock, SDFGState] = dict() + for node in self.nodes(): + node.label = self.label + '_' + node.label + parent.add_node(node, ensure_unique_name=True) + if isinstance(node, ReturnBlock) and isinstance(parent, dace.SDFG): + # If a return block is being inlined into an SDFG, convert it into a regular state. Otherwise it + # remains as-is. + newnode = parent.add_state(node.label) + block_to_state_map[node] = newnode + elif self.out_degree(node) == 0: + to_connect.add(node) + + # Add all region edges. + for edge in self.edges(): + src = block_to_state_map[edge.src] if edge.src in block_to_state_map else edge.src + dst = block_to_state_map[edge.dst] if edge.dst in block_to_state_map else edge.dst + parent.add_edge(src, dst, edge.data) + + # Redirect all edges to the region to the internal start state. + for b_edge in parent.in_edges(self): + parent.add_edge(b_edge.src, self.start_block, b_edge.data) + parent.remove_edge(b_edge) + # Redirect all edges exiting the region to instead exit the end state. + for a_edge in parent.out_edges(self): + parent.add_edge(end_state, a_edge.dst, a_edge.data) + parent.remove_edge(a_edge) + + for node in to_connect: + parent.add_edge(node, end_state, dace.InterstateEdge()) + + # Remove the original control flow region (self) from the parent graph. + parent.remove_node(self) + + sdfg = parent if isinstance(parent, dace.SDFG) else parent.sdfg + sdfg.reset_cfg_list() + + return True, end_state + + return False, None + + def add_return(self, label=None) -> ReturnBlock: + label = self._ensure_unique_block_name(label) + block = ReturnBlock(label) + self._labels.add(label) + self.add_node(block) + return block + def add_edge(self, src: ControlFlowBlock, dst: ControlFlowBlock, data: 'dace.sdfg.InterstateEdge'): """ Adds a new edge to the graph. Must be an InterstateEdge or a subclass thereof. @@ -2465,9 +2568,23 @@ def add_edge(self, src: ControlFlowBlock, dst: ControlFlowBlock, data: 'dace.sdf self._cached_start_block = None return super().add_edge(src, dst, data) - def add_node(self, node, is_start_block=False, *, is_start_state: bool=None): + def _ensure_unique_block_name(self, proposed: Optional[str] = None) -> str: + if self._labels is None or len(self._labels) != self.number_of_nodes(): + self._labels = set(s.label for s in self.nodes()) + return dt.find_new_name(proposed or 'block', self._labels) + + def add_node(self, + node, + is_start_block: bool = False, + ensure_unique_name: bool = False, + *, + is_start_state: bool = None): if not isinstance(node, ControlFlowBlock): raise TypeError('Expected ControlFlowBlock, got ' + str(type(node))) + + if ensure_unique_name: + node.label = self._ensure_unique_block_name(node.label) + super().add_node(node) self._cached_start_block = None node.parent_graph = self @@ -2484,12 +2601,8 @@ def add_node(self, node, is_start_block=False, *, is_start_state: bool=None): self.start_block = len(self.nodes()) - 1 self._cached_start_block = node - def add_state(self, label=None, is_start_block=False, *, is_start_state: bool=None) -> SDFGState: - if self._labels is None or len(self._labels) != self.number_of_nodes(): - self._labels = set(s.label for s in self.nodes()) - label = label or 'state' - existing_labels = self._labels - label = dt.find_new_name(label, existing_labels) + def add_state(self, label=None, is_start_block=False, *, is_start_state: bool = None) -> SDFGState: + label = self._ensure_unique_block_name(label) state = SDFGState(label) self._labels.add(label) start_block = is_start_block @@ -2506,7 +2619,7 @@ def add_state_before(self, condition: CodeBlock = None, assignments=None, *, - is_start_state: bool=None) -> SDFGState: + is_start_state: bool = None) -> SDFGState: """ Adds a new SDFG state before an existing state, reconnecting predecessors to it instead. :param state: The state to prepend the new state before. @@ -2532,7 +2645,7 @@ def add_state_after(self, condition: CodeBlock = None, assignments=None, *, - is_start_state: bool=None) -> SDFGState: + is_start_state: bool = None) -> SDFGState: """ Adds a new SDFG state after an existing state, reconnecting it to the successors instead. :param state: The state to append the new state after. @@ -2551,7 +2664,6 @@ def add_state_after(self, self.add_edge(state, new_state, dace.sdfg.InterstateEdge(condition=condition, assignments=assignments)) return new_state - @abc.abstractmethod def _used_symbols_internal(self, all_symbols: bool, defined_syms: Optional[Set] = None, @@ -2586,9 +2698,9 @@ def _used_symbols_internal(self, # compute the symbols that are used before being assigned. efsyms = e.data.used_symbols(all_symbols) # collect symbols representing data containers - dsyms = {sym for sym in efsyms if sym in self.arrays} + dsyms = {sym for sym in efsyms if sym in self.sdfg.arrays} for d in dsyms: - efsyms |= {str(sym) for sym in self.arrays[d].used_symbols(all_symbols)} + efsyms |= {str(sym) for sym in self.sdfg.arrays[d].used_symbols(all_symbols)} defined_syms |= set(e.data.assignments.keys()) - (efsyms | state_symbols) used_before_assignment.update(efsyms - defined_syms) free_syms |= efsyms @@ -2767,16 +2879,19 @@ class LoopRegion(ControlFlowRegion): present). """ - update_statement = CodeProperty(optional=True, allow_none=True, default=None, + update_statement = CodeProperty(optional=True, + allow_none=True, + default=None, desc='The loop update statement. May be None if the update happens elsewhere.') - init_statement = CodeProperty(optional=True, allow_none=True, default=None, + init_statement = CodeProperty(optional=True, + allow_none=True, + default=None, desc='The loop init statement. May be None if the initialization happens elsewhere.') loop_condition = CodeProperty(allow_none=True, default=None, desc='The loop condition') - inverted = Property(dtype=bool, default=False, + inverted = Property(dtype=bool, + default=False, desc='If True, the loop condition is checked after the first iteration.') loop_variable = Property(dtype=str, default='', desc='The loop variable, if given') - break_states = SetProperty(element_type=int, desc='States that when reached break out of the loop') - continue_states = SetProperty(element_type=int, desc='States that when reached directly execute the next iteration') def __init__(self, label: str, @@ -2805,12 +2920,132 @@ def __init__(self, self.loop_variable = loop_var or '' self.inverted = inverted + def inline(self) -> Tuple[bool, Any]: + """ + Inlines the loop region into its parent control flow region. + + :return: True if the inlining succeeded, false otherwise. + """ + parent = self.parent_graph + if not parent: + raise RuntimeError('No top-level SDFG present to inline into') + + # Avoid circular imports + from dace.frontend.python import astutils + + # Check that the loop initialization and update statements each only contain assignments, if the loop has any. + if self.init_statement is not None: + if isinstance(self.init_statement.code, list): + for stmt in self.init_statement.code: + if not isinstance(stmt, astutils.ast.Assign): + return False, None + if self.update_statement is not None: + if isinstance(self.update_statement.code, list): + for stmt in self.update_statement.code: + if not isinstance(stmt, astutils.ast.Assign): + return False, None + + # First recursively inline any other contained control flow regions other than loops to ensure break, continue, + # and return are inlined correctly. + def recursive_inline_cf_regions(region: ControlFlowRegion) -> None: + for block in region.nodes(): + if isinstance(block, ControlFlowRegion) and not isinstance(block, LoopRegion): + recursive_inline_cf_regions(block) + block.inline() + recursive_inline_cf_regions(self) + + # Add all boilerplate loop states necessary for the structure. + init_state = parent.add_state(self.label + '_init') + guard_state = parent.add_state(self.label + '_guard') + end_state = parent.add_state(self.label + '_end') + loop_latch_state = parent.add_state(self.label + '_latch') + + # Add all loop states and make sure to keep track of all the ones that need to be connected in the end. + # Return blocks are inlined as-is. If the parent graph is an SDFG, they are converted to states, otherwise + # they are left as explicit exit blocks. + connect_to_latch: Set[SDFGState] = set() + connect_to_end: Set[SDFGState] = set() + block_to_state_map: Dict[ControlFlowBlock, SDFGState] = dict() + for node in self.nodes(): + node.label = self.label + '_' + node.label + if isinstance(node, BreakBlock): + newnode = parent.add_state(node.label) + connect_to_end.add(newnode) + block_to_state_map[node] = newnode + elif isinstance(node, ContinueBlock): + newnode = parent.add_state(node.label) + connect_to_latch.add(newnode) + block_to_state_map[node] = newnode + elif isinstance(node, ReturnBlock) and isinstance(parent, dace.SDFG): + newnode = parent.add_state(node.label) + block_to_state_map[node] = newnode + else: + if self.out_degree(node) == 0: + connect_to_latch.add(node) + parent.add_node(node, ensure_unique_name=True) + + # Add all internal loop edges. + for edge in self.edges(): + src = block_to_state_map[edge.src] if edge.src in block_to_state_map else edge.src + dst = block_to_state_map[edge.dst] if edge.dst in block_to_state_map else edge.dst + parent.add_edge(src, dst, edge.data) + + # Redirect all edges to the loop to the init state. + for b_edge in parent.in_edges(self): + parent.add_edge(b_edge.src, init_state, b_edge.data) + parent.remove_edge(b_edge) + # Redirect all edges exiting the loop to instead exit the end state. + for a_edge in parent.out_edges(self): + parent.add_edge(end_state, a_edge.dst, a_edge.data) + parent.remove_edge(a_edge) + + # Add an initialization edge that initializes the loop variable if applicable. + init_edge = dace.InterstateEdge() + if self.init_statement is not None: + init_edge.assignments = {} + for stmt in self.init_statement.code: + assign: astutils.ast.Assign = stmt + init_edge.assignments[assign.targets[0].id] = astutils.unparse(assign.value) + if self.inverted: + parent.add_edge(init_state, self.start_block, init_edge) + else: + parent.add_edge(init_state, guard_state, init_edge) + + # Connect the loop tail. + update_edge = dace.InterstateEdge() + if self.update_statement is not None: + update_edge.assignments = {} + for stmt in self.update_statement.code: + assign: astutils.ast.Assign = stmt + update_edge.assignments[assign.targets[0].id] = astutils.unparse(assign.value) + parent.add_edge(loop_latch_state, guard_state, update_edge) + + # Add condition checking edges and connect the guard state. + cond_expr = self.loop_condition.code + parent.add_edge(guard_state, end_state, + dace.InterstateEdge(CodeBlock(astutils.negate_expr(cond_expr)).code)) + parent.add_edge(guard_state, self.start_block, dace.InterstateEdge(CodeBlock(cond_expr).code)) + + # Connect any end states from the loop's internal state machine to the tail state so they end a + # loop iteration. Do the same for any continue states, and connect any break states to the end of the loop. + for node in connect_to_latch: + parent.add_edge(node, loop_latch_state, dace.InterstateEdge()) + for node in connect_to_end: + parent.add_edge(node, end_state, dace.InterstateEdge()) + + parent.remove_node(self) + + sdfg = parent if isinstance(parent, dace.SDFG) else parent.sdfg + sdfg.reset_cfg_list() + + return True, (init_state, guard_state, end_state) + def _used_symbols_internal(self, all_symbols: bool, - defined_syms: Optional[Set]=None, - free_syms: Optional[Set]=None, - used_before_assignment: Optional[Set]=None, - keep_defined_in_mapping: bool=False) -> Tuple[Set[str], Set[str], Set[str]]: + defined_syms: Optional[Set] = None, + free_syms: Optional[Set] = None, + used_before_assignment: Optional[Set] = None, + keep_defined_in_mapping: bool = False) -> Tuple[Set[str], Set[str], Set[str]]: defined_syms = set() if defined_syms is None else defined_syms free_syms = set() if free_syms is None else free_syms used_before_assignment = set() if used_before_assignment is None else used_before_assignment @@ -2823,20 +3058,21 @@ def _used_symbols_internal(self, free_syms |= self.loop_condition.get_free_symbols() b_free_symbols, b_defined_symbols, b_used_before_assignment = super()._used_symbols_internal( - all_symbols, keep_defined_in_mapping=keep_defined_in_mapping - ) + all_symbols, keep_defined_in_mapping=keep_defined_in_mapping) free_syms |= b_free_symbols defined_syms |= b_defined_symbols - used_before_assignment |= b_used_before_assignment + used_before_assignment |= (b_used_before_assignment - {self.loop_variable}) defined_syms -= used_before_assignment free_syms -= defined_syms return free_syms, defined_syms, used_before_assignment - def replace_dict(self, repl: Dict[str, str], + def replace_dict(self, + repl: Dict[str, str], symrepl: Optional[Dict[symbolic.SymbolicType, symbolic.SymbolicType]] = None, - replace_in_graph: bool = True, replace_keys: bool = True): + replace_in_graph: bool = True, + replace_keys: bool = True): if replace_keys: from dace.sdfg.replace import replace_properties_dict replace_properties_dict(self, repl, symrepl) @@ -2849,22 +3085,37 @@ def replace_dict(self, repl: Dict[str, str], def to_json(self, parent=None): return super().to_json(parent) - def _add_node_internal(self, node, is_continue=False, is_break=False): - if is_continue: - if is_break: - raise ValueError('Cannot set both is_continue and is_break') - self.continue_states.add(self.node_id(node)) - if is_break: - if is_continue: - raise ValueError('Cannot set both is_continue and is_break') - self.break_states.add(self.node_id(node)) - - def add_node(self, node, is_start_block=False, is_continue=False, is_break=False, *, is_start_state: bool = None): - super().add_node(node, is_start_block, is_start_state=is_start_state) - self._add_node_internal(node, is_continue, is_break) - - def add_state(self, label=None, is_start_block=False, is_continue=False, is_break=False, *, - is_start_state: bool = None) -> SDFGState: - state = super().add_state(label, is_start_block, is_start_state=is_start_state) - self._add_node_internal(state, is_continue, is_break) - return state + def add_break(self, label=None) -> BreakBlock: + label = self._ensure_unique_block_name(label) + block = BreakBlock(label) + self._labels.add(label) + self.add_node(block) + return block + + def add_continue(self, label=None) -> ContinueBlock: + label = self._ensure_unique_block_name(label) + block = ContinueBlock(label) + self._labels.add(label) + self.add_node(block) + return block + + @property + def has_continue(self) -> bool: + for node, _ in self.all_nodes_recursive(lambda n, _: not isinstance(n, (LoopRegion, SDFGState))): + if isinstance(node, ContinueBlock): + return True + return False + + @property + def has_break(self) -> bool: + for node, _ in self.all_nodes_recursive(lambda n, _: not isinstance(n, (LoopRegion, SDFGState))): + if isinstance(node, BreakBlock): + return True + return False + + @property + def has_return(self) -> bool: + for node, _ in self.all_nodes_recursive(lambda n, _: not isinstance(n, (LoopRegion, SDFGState))): + if isinstance(node, ReturnBlock): + return True + return False diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index 7311f4f028..12f66db85f 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -13,12 +13,11 @@ from dace.sdfg.graph import MultiConnectorEdge from dace.sdfg.sdfg import SDFG from dace.sdfg.nodes import Node, NestedSDFG -from dace.sdfg.state import SDFGState, StateSubgraphView, LoopRegion, ControlFlowBlock, GraphT +from dace.sdfg.state import SDFGState, StateSubgraphView, LoopRegion, ControlFlowRegion from dace.sdfg.scope import ScopeSubgraphView from dace.sdfg import nodes as nd, graph as gr, propagation -from dace import config, data as dt, dtypes, memlet as mm, subsets as sbs, symbolic +from dace import config, data as dt, dtypes, memlet as mm, subsets as sbs from dace.cli.progress import optional_progressbar -from string import ascii_uppercase from typing import Any, Callable, Dict, Generator, List, Optional, Set, Sequence, Tuple, Union @@ -1218,8 +1217,6 @@ def fuse_states(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> start = time.time() for sd in sdfg.all_sdfgs_recursive(): - id = sd.cfg_id - for cfg in sd.all_control_flow_regions(): while True: edges = list(cfg.nx.edges) @@ -1235,7 +1232,7 @@ def fuse_states(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> continue candidate = {StateFusion.first_state: u, StateFusion.second_state: v} sf = StateFusion() - sf.setup_match(cfg, id, -1, candidate, 0, override=True) + sf.setup_match(cfg, cfg.cfg_id, -1, candidate, 0, override=True) if sf.can_be_applied(cfg, 0, sd, permissive=permissive): sf.apply(cfg, sd) applied += 1 @@ -1252,31 +1249,30 @@ def fuse_states(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> def inline_loop_blocks(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> int: - # Avoid import loops - from dace.transformation.interstate import LoopRegionInline + blocks = [n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, LoopRegion)] + count = 0 - counter = 0 - blocks = [(n, p) for n, p in sdfg.all_nodes_recursive() if isinstance(n, LoopRegion)] + for _block in optional_progressbar(reversed(blocks), title='Inlining Loops', + n=len(blocks), progress=progress): + block: LoopRegion = _block + if block.inline()[0]: + count += 1 - for _block, _graph in optional_progressbar(reversed(blocks), title='Inlining Loops', - n=len(blocks), progress=progress): - block: ControlFlowBlock = _block - graph: GraphT = _graph - id = block.sdfg.cfg_id + return count - # We have to reevaluate every time due to changing IDs - block_id = graph.node_id(block) - candidate = { - LoopRegionInline.loop: block, - } - inliner = LoopRegionInline() - inliner.setup_match(graph, id, block_id, candidate, 0, override=True) - if inliner.can_be_applied(graph, 0, block.sdfg, permissive=permissive): - inliner.apply(graph, block.sdfg) - counter += 1 +def inline_control_flow_regions(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> int: + blocks = [n for n, _ in sdfg.all_nodes_recursive() + if isinstance(n, ControlFlowRegion) and not isinstance(n, (LoopRegion, SDFG))] + count = 0 - return counter + for _block in optional_progressbar(reversed(blocks), title='Inlining control flow blocks', + n=len(blocks), progress=progress): + block: ControlFlowRegion = _block + if block.inline()[0]: + count += 1 + + return count def inline_sdfgs(sdfg: SDFG, permissive: bool = False, progress: bool = None, multistate: bool = True) -> int: @@ -1303,9 +1299,10 @@ def inline_sdfgs(sdfg: SDFG, permissive: bool = False, progress: bool = None, mu for nsdfg_node in optional_progressbar(reversed(nsdfgs), title='Inlining SDFGs', n=len(nsdfgs), progress=progress): # We have to reevaluate every time due to changing IDs # e.g., InlineMultistateSDFG may fission states - parent_state = nsdfg_node.sdfg.parent - parent_sdfg = parent_state.parent - parent_state_id = parent_sdfg.node_id(parent_state) + nsdfg: SDFG = nsdfg_node.sdfg + parent_state = nsdfg.parent + parent_sdfg = parent_state.sdfg + parent_state_id = parent_state.block_id if multistate: candidate = { @@ -1313,7 +1310,7 @@ def inline_sdfgs(sdfg: SDFG, permissive: bool = False, progress: bool = None, mu } inliner = InlineMultistateSDFG() inliner.setup_match(sdfg=parent_sdfg, - cfg_id=parent_sdfg.sdfg_id, + cfg_id=parent_state.parent_graph.cfg_id, state_id=parent_state_id, subgraph=candidate, expr_index=0, @@ -1328,7 +1325,7 @@ def inline_sdfgs(sdfg: SDFG, permissive: bool = False, progress: bool = None, mu } inliner = InlineSDFG() inliner.setup_match(sdfg=parent_sdfg, - cfg_id=parent_sdfg.sdfg_id, + cfg_id=parent_state.parent_graph.cfg_id, state_id=parent_state_id, subgraph=candidate, expr_index=0, @@ -1495,31 +1492,25 @@ def _traverse(scope: Node, symbols: Dict[str, dtypes.typeclass]): yield from _traverse(None, symbols) -def traverse_sdfg_with_defined_symbols( +def _tswds_cf_region( sdfg: SDFG, + region: ControlFlowRegion, + symbols: Dict[str, dtypes.typeclass], recursive: bool = False) -> Generator[Tuple[SDFGState, Node, Dict[str, dtypes.typeclass]], None, None]: - """ - Traverses the SDFG, its states and nodes, yielding the defined symbols and their types at each node. - - :return: A generator that yields tuples of (state, node in state, currently-defined symbols) - """ - # Start with global symbols - symbols = copy.copy(sdfg.symbols) - symbols.update({k: dt.create_datadescriptor(v).dtype for k, v in sdfg.constants.items()}) - for desc in sdfg.arrays.values(): - symbols.update({str(s): s.dtype for s in desc.free_symbols}) - # Add symbols from inter-state edges along the state machine - start_state = sdfg.start_state + start_region = region.start_block visited = set() visited_edges = set() - for edge in sdfg.dfs_edges(start_state): + for edge in region.dfs_edges(start_region): # Source -> inter-state definition -> Destination visited_edges.add(edge) # Source if edge.src not in visited: visited.add(edge.src) - yield from _tswds_state(sdfg, edge.src, symbols, recursive) + if isinstance(edge.src, SDFGState): + yield from _tswds_state(sdfg, edge.src, {}, recursive) + elif isinstance(edge.src, ControlFlowRegion): + yield from _tswds_cf_region(sdfg, edge.src, symbols, recursive) # Add edge symbols into defined symbols issyms = edge.data.new_symbols(sdfg, symbols) @@ -1528,11 +1519,34 @@ def traverse_sdfg_with_defined_symbols( # Destination if edge.dst not in visited: visited.add(edge.dst) - yield from _tswds_state(sdfg, edge.dst, symbols, recursive) + if isinstance(edge.dst, SDFGState): + yield from _tswds_state(sdfg, edge.dst, symbols, recursive) + elif isinstance(edge.dst, ControlFlowRegion): + yield from _tswds_cf_region(sdfg, edge.dst, symbols, recursive) # If there is only one state, the DFS will miss it - if start_state not in visited: - yield from _tswds_state(sdfg, start_state, symbols, recursive) + if start_region not in visited: + if isinstance(start_region, SDFGState): + yield from _tswds_state(sdfg, start_region, symbols, recursive) + elif isinstance(start_region, ControlFlowRegion): + yield from _tswds_cf_region(sdfg, start_region, symbols, recursive) + + +def traverse_sdfg_with_defined_symbols( + sdfg: SDFG, + recursive: bool = False) -> Generator[Tuple[SDFGState, Node, Dict[str, dtypes.typeclass]], None, None]: + """ + Traverses the SDFG, its states and nodes, yielding the defined symbols and their types at each node. + + :return: A generator that yields tuples of (state, node in state, currently-defined symbols) + """ + # Start with global symbols + symbols = copy.copy(sdfg.symbols) + symbols.update({k: dt.create_datadescriptor(v).dtype for k, v in sdfg.constants.items()}) + for desc in sdfg.arrays.values(): + symbols.update({str(s): s.dtype for s in desc.free_symbols}) + + yield from _tswds_cf_region(sdfg, sdfg, symbols, recursive) def is_fpga_kernel(sdfg, state): diff --git a/dace/sdfg/validation.py b/dace/sdfg/validation.py index 660e45e574..480fb9c262 100644 --- a/dace/sdfg/validation.py +++ b/dace/sdfg/validation.py @@ -13,6 +13,7 @@ from dace.sdfg import SDFG from dace.sdfg import graph as gr from dace.memlet import Memlet + from dace.sdfg.state import ControlFlowRegion ########################################### # Validation @@ -28,13 +29,13 @@ def validate(graph: 'dace.sdfg.graph.SubgraphView'): validate_state(graph) -def validate_control_flow_region(sdfg: 'dace.sdfg.SDFG', - region: 'dace.sdfg.state.ControlFlowRegion', +def validate_control_flow_region(sdfg: 'SDFG', + region: 'ControlFlowRegion', initialized_transients: Set[str], symbols: dict, references: Set[int] = None, **context: bool): - from dace.sdfg import SDFGState + from dace.sdfg.state import SDFGState, ControlFlowRegion from dace.sdfg.scope import is_in_scope if len(region.source_nodes()) > 1 and region.start_block is None: @@ -70,7 +71,7 @@ def validate_control_flow_region(sdfg: 'dace.sdfg.SDFG', if isinstance(edge.src, SDFGState): validate_state(edge.src, region.node_id(edge.src), sdfg, symbols, initialized_transients, references, **context) - else: + elif isinstance(edge.src, ControlFlowRegion): validate_control_flow_region(sdfg, edge.src, initialized_transients, symbols, references, **context) ########################################## @@ -118,7 +119,7 @@ def validate_control_flow_region(sdfg: 'dace.sdfg.SDFG', if isinstance(edge.dst, SDFGState): validate_state(edge.dst, region.node_id(edge.dst), sdfg, symbols, initialized_transients, references, **context) - else: + elif isinstance(edge.dst, ControlFlowRegion): validate_control_flow_region(sdfg, edge.dst, initialized_transients, symbols, references, **context) # End of block DFS @@ -127,7 +128,7 @@ def validate_control_flow_region(sdfg: 'dace.sdfg.SDFG', if isinstance(start_block, SDFGState): validate_state(start_block, region.node_id(start_block), sdfg, symbols, initialized_transients, references, **context) - else: + elif isinstance(start_block, ControlFlowRegion): validate_control_flow_region(sdfg, start_block, initialized_transients, symbols, references, **context) # Validate all inter-state edges (including self-loops not found by DFS) @@ -201,9 +202,10 @@ def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context if not dtypes.validate_name(sdfg.name): raise InvalidSDFGError("Invalid name", sdfg, None) - all_blocks = set(sdfg.all_control_flow_blocks()) - if len(all_blocks) != len(set([s.label for s in all_blocks])): - raise InvalidSDFGError('Found multiple blocks with the same name', sdfg, None) + for cfg in sdfg.all_control_flow_regions(): + blocks = cfg.nodes() + if len(blocks) != len(set([s.label for s in blocks])): + raise InvalidSDFGError('Found multiple blocks with the same name in ' + cfg.name, sdfg, None) # Validate data descriptors for name, desc in sdfg._arrays.items(): diff --git a/dace/transformation/__init__.py b/dace/transformation/__init__.py index 13649d8727..3a4c65efa3 100644 --- a/dace/transformation/__init__.py +++ b/dace/transformation/__init__.py @@ -1,3 +1,3 @@ from .transformation import (PatternTransformation, SingleStateTransformation, MultiStateTransformation, - SubgraphTransformation, ExpandTransformation) + SubgraphTransformation, ExpandTransformation, experimental_cfg_block_compatible) from .pass_pipeline import Pass, Pipeline, FixedPointPipeline diff --git a/dace/transformation/auto/auto_optimize.py b/dace/transformation/auto/auto_optimize.py index 60a35c565d..7bced3bec9 100644 --- a/dace/transformation/auto/auto_optimize.py +++ b/dace/transformation/auto/auto_optimize.py @@ -4,7 +4,7 @@ import dace import sympy from dace.sdfg import infer_types -from dace.sdfg.state import SDFGState +from dace.sdfg.state import SDFGState, ControlFlowRegion from dace.sdfg.graph import SubgraphView from dace.sdfg.propagation import propagate_states from dace.sdfg.scope import is_devicelevel_gpu_kernel @@ -29,7 +29,7 @@ # FPGA AutoOpt from dace.transformation.auto import fpga as fpga_auto_opt -GraphViewType = Union[SDFG, SDFGState, gr.SubgraphView] +GraphViewType = Union[SDFG, SDFGState, gr.SubgraphView, ControlFlowRegion] def greedy_fuse(graph_or_subgraph: GraphViewType, @@ -53,22 +53,24 @@ def greedy_fuse(graph_or_subgraph: GraphViewType, :param expand_reductions: Expand all reduce nodes before fusion """ debugprint = config.Config.get_bool('debugprint') - if isinstance(graph_or_subgraph, SDFG): - # If we have an SDFG, recurse into graphs - graph_or_subgraph.simplify(validate_all=validate_all) - # MapFusion for trivial cases - graph_or_subgraph.apply_transformations_repeated(MapFusion, validate_all=validate_all) + if isinstance(graph_or_subgraph, ControlFlowRegion): + if isinstance(graph_or_subgraph, SDFG): + # If we have an SDFG, recurse into graphs + graph_or_subgraph.simplify(validate_all=validate_all) + # MapFusion for trivial cases + graph_or_subgraph.apply_transformations_repeated(MapFusion, validate_all=validate_all) + # recurse into graphs for graph in graph_or_subgraph.nodes(): - - greedy_fuse(graph, - validate_all=validate_all, - device=device, - recursive=recursive, - stencil=stencil, - stencil_tile=stencil_tile, - permutations_only=permutations_only, - expand_reductions=expand_reductions) + if isinstance(graph, (SDFGState, ControlFlowRegion)): + greedy_fuse(graph, + validate_all=validate_all, + device=device, + recursive=recursive, + stencil=stencil, + stencil_tile=stencil_tile, + permutations_only=permutations_only, + expand_reductions=expand_reductions) else: # we are in graph or subgraph sdfg, graph, subgraph = None, None, None @@ -107,7 +109,7 @@ def greedy_fuse(graph_or_subgraph: GraphViewType, fusion_condition.allow_tiling = False # expand reductions if expand_reductions: - for graph in sdfg.nodes(): + for graph in sdfg.states(): for node in graph.nodes(): if isinstance(node, dace.libraries.standard.nodes.Reduce): try: @@ -190,12 +192,14 @@ def tile_wcrs(graph_or_subgraph: GraphViewType, validate_all: bool, prefer_parti graph = graph_or_subgraph if isinstance(graph_or_subgraph, gr.SubgraphView): graph = graph_or_subgraph.graph - if isinstance(graph, SDFG): - for state in graph_or_subgraph.nodes(): - tile_wcrs(state, validate_all) + if isinstance(graph, ControlFlowRegion): + for block in graph_or_subgraph.nodes(): + if isinstance(block, SDFGState): + tile_wcrs(block, validate_all) return + if not isinstance(graph, SDFGState): - raise TypeError('Graph must be a state, an SDFG, or a subgraph of either') + raise TypeError('Graph must be a state, an SDFG, a control flow region, or a subgraph of either') sdfg = graph.parent edges_to_consider: Set[Tuple[gr.MultiConnectorEdge[Memlet], nodes.MapEntry]] = set() @@ -393,7 +397,7 @@ def set_fast_implementations(sdfg: SDFG, device: dtypes.DeviceType, blocklist: L # specialized nodes: pre-expand for current_sdfg in sdfg.all_sdfgs_recursive(): - for state in current_sdfg.nodes(): + for state in current_sdfg.states(): for node in state.nodes(): if isinstance(node, nodes.LibraryNode): if (node.default_implementation == 'specialize' @@ -461,7 +465,7 @@ def make_transients_persistent(sdfg: SDFG, persistent: Set[str] = set() not_persistent: Set[str] = set() - for state in nsdfg.nodes(): + for state in nsdfg.states(): for dnode in state.data_nodes(): if dnode.data in not_persistent: continue @@ -507,10 +511,9 @@ def make_transients_persistent(sdfg: SDFG, if device == dtypes.DeviceType.GPU: # Reset nonatomic WCR edges - for n, _ in sdfg.all_nodes_recursive(): - if isinstance(n, SDFGState): - for edge in n.edges(): - edge.data.wcr_nonatomic = False + for state in sdfg.states(): + for edge in state.edges(): + edge.data.wcr_nonatomic = False return result @@ -519,7 +522,7 @@ def apply_gpu_storage(sdfg: SDFG) -> None: """ Changes the storage of the SDFG's input and output data to GPU global memory. """ written_scalars = set() - for state in sdfg.nodes(): + for state in sdfg.states(): for node in state.data_nodes(): desc = node.desc(sdfg) if isinstance(desc, dt.Scalar) and not desc.transient and state.in_degree(node) > 0: diff --git a/dace/transformation/dataflow/__init__.py b/dace/transformation/dataflow/__init__.py index 303f1d0a64..db4c928481 100644 --- a/dace/transformation/dataflow/__init__.py +++ b/dace/transformation/dataflow/__init__.py @@ -5,7 +5,7 @@ from .mapreduce import MapReduceFusion, MapWCRFusion from .map_expansion import MapExpansion from .map_collapse import MapCollapse -from .map_for_loop import MapToForLoop +from .map_for_loop import MapToForLoop, MapToForLoopRegion from .map_interchange import MapInterchange from .map_dim_shuffle import MapDimShuffle from .map_fusion import MapFusion diff --git a/dace/transformation/dataflow/buffer_tiling.py b/dace/transformation/dataflow/buffer_tiling.py index 2cf4bfa989..a418e167d8 100644 --- a/dace/transformation/dataflow/buffer_tiling.py +++ b/dace/transformation/dataflow/buffer_tiling.py @@ -7,7 +7,6 @@ from dace.transformation import transformation from dace.transformation.dataflow import MapTiling, MapTilingWithOverlap, MapFusion, TrivialMapElimination - @make_properties class BufferTiling(transformation.SingleStateTransformation): """ Implements the buffer tiling transformation. diff --git a/dace/transformation/dataflow/copy_to_device.py b/dace/transformation/dataflow/copy_to_device.py index 7421b9396e..28ce4dea59 100644 --- a/dace/transformation/dataflow/copy_to_device.py +++ b/dace/transformation/dataflow/copy_to_device.py @@ -4,13 +4,13 @@ from copy import deepcopy as dcpy from dace import data, properties, symbolic, dtypes -from dace.sdfg import graph, nodes +from dace.sdfg import nodes, SDFG from dace.sdfg import utils as sdutil from dace.transformation import transformation -def change_storage(sdfg, storage): - for state in sdfg.nodes(): +def change_storage(sdfg: SDFG, storage: dtypes.StorageType): + for state in sdfg.states(): for node in state.nodes(): if isinstance(node, nodes.AccessNode): node.desc(sdfg).storage = storage diff --git a/dace/transformation/dataflow/dedup_access.py b/dace/transformation/dataflow/dedup_access.py index 45955ac7af..0a0755049c 100644 --- a/dace/transformation/dataflow/dedup_access.py +++ b/dace/transformation/dataflow/dedup_access.py @@ -3,13 +3,11 @@ from collections import defaultdict import copy -import itertools -from typing import List, Set +from typing import List -from dace import data, dtypes, sdfg as sd, subsets, symbolic +from dace import sdfg as sd, subsets from dace.memlet import Memlet from dace.sdfg import nodes, graph as gr -from dace.sdfg import utils as sdutil from dace.transformation import transformation as xf import dace.transformation.helpers as helpers diff --git a/dace/transformation/dataflow/map_for_loop.py b/dace/transformation/dataflow/map_for_loop.py index b1d81e20a8..4295e8a0eb 100644 --- a/dace/transformation/dataflow/map_for_loop.py +++ b/dace/transformation/dataflow/map_for_loop.py @@ -3,23 +3,26 @@ """ import dace -from dace import data, registry, symbolic +from dace import symbolic from dace.sdfg import SDFG, SDFGState from dace.sdfg import nodes from dace.sdfg import utils as sdutil +from dace.sdfg.state import LoopRegion from dace.transformation import transformation -from typing import Tuple +from typing import Tuple, Optional -class MapToForLoop(transformation.SingleStateTransformation): +class MapToForLoopRegion(transformation.SingleStateTransformation): """ Implements the Map to for-loop transformation. - Takes a map and enforces a sequential schedule by transforming it into - a state-machine of a for-loop. Creates a nested SDFG, if necessary. + Takes a map and enforces a sequential schedule by transforming it into a loop region. Creates a nested SDFG, if + necessary. """ map_entry = transformation.PatternNode(nodes.MapEntry) + loop_region: Optional[LoopRegion] = None + @staticmethod def annotates_memlets(): return True @@ -79,11 +82,14 @@ def replace_param(param): # End of dynamic input range # Create a loop inside the nested SDFG - loop_result = nsdfg.add_loop(None, nstate, None, loop_idx, replace_param(loop_from), - '%s < %s' % (loop_idx, replace_param(loop_to + 1)), - '%s + %s' % (loop_idx, replace_param(loop_step))) - # store as object fields for external access - self.before_state, self.guard, self.after_state = loop_result + loop_region = LoopRegion('loop_' + map_entry.map.label, '%s < %s' % (loop_idx, replace_param(loop_to + 1)), + loop_idx, '%s = %s' % (loop_idx, replace_param(loop_from)), + '%s = %s + %s' % (loop_idx, loop_idx, replace_param(loop_step))) + nsdfg.add_node(loop_region, is_start_block=True) + nsdfg.remove_node(nstate) + loop_region.add_node(nstate, is_start_block=True) + # store as object field for external access + self.loop_region = loop_region # Skip map in input edges for edge in nstate.out_edges(map_entry): src_node = nstate.memlet_path(edge)[0].src @@ -104,4 +110,28 @@ def replace_param(param): # create object field for external nsdfg access self.nsdfg = nsdfg + sdfg.reset_cfg_list() + sdfg.root_sdfg.using_experimental_blocks = True + + return node, nstate + + +class MapToForLoop(MapToForLoopRegion): + """ Implements the Map to for-loop transformation. + + Takes a map and enforces a sequential schedule by transforming it into + a state-machine of a for-loop. Creates a nested SDFG, if necessary. + """ + + before_state: SDFGState + guard: SDFGState + after_state: SDFGState + + def apply(self, graph: SDFGState, sdfg: SDFG) -> Tuple[nodes.NestedSDFG, SDFGState]: + node, nstate = super().apply(graph, sdfg) + _, (self.before_state, self.guard, self.after_state) = self.loop_region.inline() + + sdfg.reset_cfg_list() + sdfg.recheck_using_experimental_blocks() + return node, nstate diff --git a/dace/transformation/dataflow/map_fusion.py b/dace/transformation/dataflow/map_fusion.py index 186ea32acc..a6762d45c4 100644 --- a/dace/transformation/dataflow/map_fusion.py +++ b/dace/transformation/dataflow/map_fusion.py @@ -84,7 +84,7 @@ def find_permutation(first_map: nodes.Map, second_map: nodes.Map) -> Union[List[ return result - def can_be_applied(self, graph, expr_index, sdfg, permissive=False): + def can_be_applied(self, graph, expr_index, sdfg: SDFG, permissive=False): first_map_exit = self.first_map_exit first_map_entry = graph.entry_node(first_map_exit) second_map_entry = self.second_map_entry @@ -105,9 +105,7 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): intermediate_data.add(dst.data) # If array is used anywhere else in this state. - num_occurrences = len([ - n for s in sdfg.nodes() for n in s.nodes() if isinstance(n, nodes.AccessNode) and n.data == dst.data - ]) + num_occurrences = len([n for n in sdfg.data_nodes() if n.data == dst.data]) if num_occurrences > 1: return False else: @@ -430,7 +428,7 @@ def apply(self, graph: SDFGState, sdfg: SDFG): # Fix scope exit to point to the right map second_exit.map = first_entry.map - def fuse_nodes(self, sdfg, graph, edge, new_dst, new_dst_conn, other_edges=None): + def fuse_nodes(self, sdfg: SDFG, graph: SDFGState, edge, new_dst, new_dst_conn, other_edges=None): """ Fuses two nodes via memlets and possibly transient arrays. """ other_edges = other_edges or [] memlet_path = graph.memlet_path(edge) diff --git a/dace/transformation/dataflow/mapreduce.py b/dace/transformation/dataflow/mapreduce.py index d111cc32b6..0eef39c3cb 100644 --- a/dace/transformation/dataflow/mapreduce.py +++ b/dace/transformation/dataflow/mapreduce.py @@ -133,7 +133,7 @@ def apply(self, graph: SDFGState, sdfg: SDFG): # Add initialization state as necessary if not self.no_init and reduce_node.identity is not None: - init_state = sdfg.add_state_before(graph) + init_state = graph.parent_graph.add_state_before(graph) init_state.add_mapped_tasklet( 'freduce_init', [('o%d' % i, '%s:%s:%s' % (r[0], r[1] + 1, r[2])) for i, r in enumerate(array_edge.data.subset)], {}, diff --git a/dace/transformation/dataflow/otf_map_fusion.py b/dace/transformation/dataflow/otf_map_fusion.py index 0ff55213d7..a793d1e679 100644 --- a/dace/transformation/dataflow/otf_map_fusion.py +++ b/dace/transformation/dataflow/otf_map_fusion.py @@ -159,7 +159,7 @@ def apply(self, graph: SDFGState, sdfg: SDFG): xform = InLocalStorage() xform._sdfg = sdfg - xform.state_id = sdfg.node_id(graph) + xform.state_id = graph.parent_graph.node_id(graph) xform.node_a = edge.src xform.node_b = edge.dst xform.array = intermediate_access_node.data @@ -177,7 +177,7 @@ def apply(self, graph: SDFGState, sdfg: SDFG): if edge.data.wcr is None: xform = OutLocalStorage() xform._sdfg = sdfg - xform.state_id = sdfg.node_id(graph) + xform.state_id = graph.parent_graph.node_id(graph) xform.node_a = edge.src xform.node_b = edge.dst xform.array = intermediate_access_node.data @@ -192,7 +192,7 @@ def apply(self, graph: SDFGState, sdfg: SDFG): else: xform = AccumulateTransient() xform._sdfg = sdfg - xform.state_id = sdfg.node_id(graph) + xform.state_id = graph.parent_graph.node_id(graph) xform.map_exit = edge.src xform.outer_map_exit = edge.dst xform.array = intermediate_access_node.data diff --git a/dace/transformation/dataflow/prune_connectors.py b/dace/transformation/dataflow/prune_connectors.py index 36352fef0d..a2b48ec595 100644 --- a/dace/transformation/dataflow/prune_connectors.py +++ b/dace/transformation/dataflow/prune_connectors.py @@ -57,7 +57,7 @@ def apply(self, state: SDFGState, sdfg: SDFG): nsdfg = self.nsdfg # Fission subgraph around nsdfg into its own state to avoid data races - nsdfg_state = helpers.state_fission_after(sdfg, state, nsdfg) + nsdfg_state = helpers.state_fission_after(state, nsdfg) read_set, write_set = nsdfg.sdfg.read_and_write_sets() prune_in = nsdfg.in_connectors.keys() - read_set @@ -142,7 +142,7 @@ def _candidates(nsdfg: nodes.NestedSDFG) -> Set[str]: # Any symbol that is set in all outgoing edges is ignored from # this point local_ignore = None - for e in nsdfg.sdfg.out_edges(nstate): + for e in nstate.parent_graph.out_edges(nstate): # Look for symbols in condition candidates -= (set(map(str, symbolic.symbols_in_ast(e.data.condition.code[0]))) - ignore) @@ -226,7 +226,7 @@ def _candidates(cls, nsdfg: nodes.NestedSDFG) -> Tuple[Set[str], Set[Tuple[SDFGS return set(), set() # Remove candidates that are used in the nested SDFG - for nstate in nsdfg.sdfg.nodes(): + for nstate in nsdfg.sdfg.states(): for node in nstate.data_nodes(): if node.data in candidates: # If used in nested SDFG @@ -243,7 +243,7 @@ def _candidates(cls, nsdfg: nodes.NestedSDFG) -> Tuple[Set[str], Set[Tuple[SDFGS candidate_nodes.add((nstate, node)) # Any array that is used in interstate edges is removed - for e in nsdfg.sdfg.edges(): + for e in nsdfg.sdfg.all_interstate_edges(): candidates -= (set(map(str, symbolic.symbols_in_ast(e.data.condition.code[0])))) for assign in e.data.assignments.values(): candidates -= (symbolic.free_symbols_and_functions(assign)) diff --git a/dace/transformation/dataflow/reduce_expansion.py b/dace/transformation/dataflow/reduce_expansion.py index 7be35b2914..5d3bcb594c 100644 --- a/dace/transformation/dataflow/reduce_expansion.py +++ b/dace/transformation/dataflow/reduce_expansion.py @@ -16,11 +16,6 @@ from dace.sdfg.propagation import propagate_memlets_scope from copy import deepcopy as dcpy -from typing import List - -import numpy as np - -import timeit @make_properties @@ -229,8 +224,7 @@ def expand(self, sdfg: SDFG, graph: SDFGState, reduce_node): # inline fuse back our nested SDFG from dace.transformation.interstate import InlineSDFG inline_sdfg = InlineSDFG() - inline_sdfg.setup_match(sdfg, sdfg.cfg_id, sdfg.node_id(graph), {InlineSDFG.nested_sdfg: graph.node_id(nsdfg)}, - 0) + inline_sdfg.setup_match(sdfg, sdfg.cfg_id, graph.block_id, {InlineSDFG.nested_sdfg: graph.node_id(nsdfg)}, 0) inline_sdfg.apply(graph, sdfg) new_schedule = dtypes.ScheduleType.Default diff --git a/dace/transformation/dataflow/redundant_array.py b/dace/transformation/dataflow/redundant_array.py index 680936dc70..1cffa1ed59 100644 --- a/dace/transformation/dataflow/redundant_array.py +++ b/dace/transformation/dataflow/redundant_array.py @@ -368,11 +368,8 @@ def can_be_applied(self, graph: SDFGState, expr_index, sdfg, permissive=False): return True # Find occurrences in this and other states - occurrences = [] - for state in sdfg.nodes(): - occurrences.extend( - [n for n in state.nodes() if isinstance(n, nodes.AccessNode) and n.data == in_array.data]) - for isedge in sdfg.edges(): + occurrences = [n for n in sdfg.data_nodes() if n.data == in_array.data] + for isedge in sdfg.all_interstate_edges(): if in_array.data in isedge.data.free_symbols: occurrences.append(isedge) @@ -811,11 +808,8 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): return False # Find occurrences in this and other states - occurrences = [] - for state in sdfg.nodes(): - occurrences.extend( - [n for n in state.nodes() if isinstance(n, nodes.AccessNode) and n.data == out_array.data]) - for isedge in sdfg.edges(): + occurrences = [n for n in sdfg.data_nodes() if n.data == out_array.data] + for isedge in sdfg.all_interstate_edges(): if out_array.data in isedge.data.free_symbols: occurrences.append(isedge) diff --git a/dace/transformation/dataflow/stream_transient.py b/dace/transformation/dataflow/stream_transient.py index 2c9f9febd5..b8c0f5820c 100644 --- a/dace/transformation/dataflow/stream_transient.py +++ b/dace/transformation/dataflow/stream_transient.py @@ -189,15 +189,13 @@ def apply(self, graph: SDFGState, sdfg: SDFG): warnings.warn('AccumulateTransient did not properly initialize ' 'newly-created transient!') return - sdfg_state: SDFGState = sdfg.node(self.state_id) - - map_entry = sdfg_state.entry_node(map_exit) + map_entry = graph.entry_node(map_exit) nested_sdfg: NestedSDFG = nest_state_subgraph(sdfg=sdfg, - state=sdfg_state, + state=graph, subgraph=SubgraphView( - sdfg_state, {map_entry, map_exit} - | sdfg_state.all_nodes_between(map_entry, map_exit))) + graph, {map_entry, map_exit} + | graph.all_nodes_between(map_entry, map_exit))) nested_sdfg_state: SDFGState = nested_sdfg.sdfg.nodes()[0] diff --git a/dace/transformation/dataflow/streaming_memory.py b/dace/transformation/dataflow/streaming_memory.py index 4cf40b30bf..2c5e31e8e4 100644 --- a/dace/transformation/dataflow/streaming_memory.py +++ b/dace/transformation/dataflow/streaming_memory.py @@ -234,7 +234,7 @@ def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissi # Check if map has the right access pattern # Stride 1 access by innermost loop, innermost loop counter has to be divisible by vector size # Same code as in apply - state = sdfg.node(self.state_id) + state = graph dnode: nodes.AccessNode = self.access if self.expr_index == 0: edges = state.out_edges(dnode) @@ -705,7 +705,7 @@ def apply(self, state: SDFGState, sdfg: SDFG) -> nodes.AccessNode: find_new_name=True) # Remove transient array if possible - for ostate in sdfg.nodes(): + for ostate in sdfg.states(): if ostate is state: continue if any(n.data == access.data for n in ostate.data_nodes()): diff --git a/dace/transformation/dataflow/strip_mining.py b/dace/transformation/dataflow/strip_mining.py index 48703126cd..fafcd4585d 100644 --- a/dace/transformation/dataflow/strip_mining.py +++ b/dace/transformation/dataflow/strip_mining.py @@ -466,7 +466,7 @@ def _stripmine(self, sdfg: SDFG, graph: SDFGState, map_entry: nodes.MapEntry): # Skew if necessary if self.skew: - xfh.offset_map(sdfg, graph, map_entry, dim_idx, td_rng[0]) + xfh.offset_map(graph, map_entry, dim_idx, td_rng[0]) # Return strip-mined dimension. return target_dim, new_dim, new_map diff --git a/dace/transformation/dataflow/sve/infer_types.py b/dace/transformation/dataflow/sve/infer_types.py index 7cbef36f96..fcb16cce0a 100644 --- a/dace/transformation/dataflow/sve/infer_types.py +++ b/dace/transformation/dataflow/sve/infer_types.py @@ -169,7 +169,7 @@ def infer_connector_types(sdfg: SDFG, raise ValueError('No SDFG was provided') if state is None and graph is None: - for state in sdfg.nodes(): + for state in sdfg.states(): for node in dfs_topological_sort(state): infer_node_connectors(sdfg, state, node, inferred) diff --git a/dace/transformation/dataflow/tiling_with_overlap.py b/dace/transformation/dataflow/tiling_with_overlap.py index 1af3586c39..e7fda71e82 100644 --- a/dace/transformation/dataflow/tiling_with_overlap.py +++ b/dace/transformation/dataflow/tiling_with_overlap.py @@ -2,10 +2,8 @@ """ This module contains classes and functions that implement the orthogonal tiling with overlap transformation. """ -from dace import registry from dace.properties import make_properties, ShapeProperty from dace.transformation.dataflow import MapTiling -from dace.sdfg import nodes from dace.symbolic import pystr_to_symbolic diff --git a/dace/transformation/dataflow/warp_tiling.py b/dace/transformation/dataflow/warp_tiling.py index 211910eebf..362b51d9ac 100644 --- a/dace/transformation/dataflow/warp_tiling.py +++ b/dace/transformation/dataflow/warp_tiling.py @@ -123,7 +123,7 @@ def apply(self, graph: SDFGState, sdfg: SDFG) -> nodes.MapEntry: write = nstate.add_write(name) edge = nstate.add_nedge(read, write, copy.deepcopy(out_edge.data)) edge.data.wcr = None - xfh.state_fission(nsdfg, SubgraphView(nstate, [read, write])) + xfh.state_fission(SubgraphView(nstate, [read, write])) newnode = nstate.add_access(name) nstate.remove_edge(out_edge) diff --git a/dace/transformation/dataflow/wcr_conversion.py b/dace/transformation/dataflow/wcr_conversion.py index 1a0ecf6bc4..60da5d3939 100644 --- a/dace/transformation/dataflow/wcr_conversion.py +++ b/dace/transformation/dataflow/wcr_conversion.py @@ -150,7 +150,7 @@ def apply(self, state: SDFGState, sdfg: SDFG): # If state fission is necessary to keep semantics, do it first if state.in_degree(input) > 0: - new_state = helpers.state_fission_after(sdfg, state, tasklet) + new_state = helpers.state_fission_after(state, tasklet) else: new_state = state diff --git a/dace/transformation/helpers.py b/dace/transformation/helpers.py index cd73b96a68..cef0ca0fc6 100644 --- a/dace/transformation/helpers.py +++ b/dace/transformation/helpers.py @@ -647,7 +647,7 @@ def nest_state_subgraph(sdfg: SDFG, return nested_sdfg -def state_fission(sdfg: SDFG, subgraph: graph.SubgraphView, label: Optional[str] = None) -> SDFGState: +def state_fission(subgraph: graph.SubgraphView, label: Optional[str] = None) -> SDFGState: """ Given a subgraph, adds a new SDFG state before the state that contains it, removes the subgraph from the original state, and connects the two states. @@ -657,7 +657,7 @@ def state_fission(sdfg: SDFG, subgraph: graph.SubgraphView, label: Optional[str] """ state: SDFGState = subgraph.graph - newstate = sdfg.add_state_before(state, label=label) + newstate = state.parent_graph.add_state_before(state, label=label) # Save edges before removing nodes orig_edges = subgraph.edges() @@ -687,10 +687,10 @@ def state_fission(sdfg: SDFG, subgraph: graph.SubgraphView, label: Optional[str] return newstate -def state_fission_after(sdfg: SDFG, state: SDFGState, node: nodes.Node, label: Optional[str] = None) -> SDFGState: +def state_fission_after(state: SDFGState, node: nodes.Node, label: Optional[str] = None) -> SDFGState: """ """ - newstate = sdfg.add_state_after(state, label=label) + newstate = state.parent_graph.add_state_after(state, label=label) # Bookkeeping nodes_to_move = set([node]) @@ -930,8 +930,7 @@ def replicate_scope(sdfg: SDFG, state: SDFGState, scope: ScopeSubgraphView) -> S return ScopeSubgraphView(state, new_nodes, new_entry) -def offset_map(sdfg: SDFG, - state: SDFGState, +def offset_map(state: SDFGState, entry: nodes.MapEntry, dim: int, offset: symbolic.SymbolicType, @@ -939,7 +938,6 @@ def offset_map(sdfg: SDFG, """ Offsets a map parameter and its contents by a value. - :param sdfg: The SDFG in which the map resides. :param state: The state in which the map resides. :param entry: The map entry node. :param dim: The map dimension to offset. diff --git a/dace/transformation/interstate/__init__.py b/dace/transformation/interstate/__init__.py index b60b1891b1..b8bcc716e6 100644 --- a/dace/transformation/interstate/__init__.py +++ b/dace/transformation/interstate/__init__.py @@ -1,7 +1,6 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. """ This module initializes the inter-state transformations package.""" -from .control_flow_inline import LoopRegionInline from .state_fusion import StateFusion from .state_fusion_with_happens_before import StateFusionExtended from .state_elimination import (EndStateElimination, StartStateElimination, StateAssignElimination, diff --git a/dace/transformation/interstate/control_flow_inline.py b/dace/transformation/interstate/control_flow_inline.py deleted file mode 100644 index b86317b8ed..0000000000 --- a/dace/transformation/interstate/control_flow_inline.py +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. -""" Inline control flow regions in SDFGs. """ - -from typing import Set, Optional - -from dace.frontend.python import astutils -from dace.sdfg import SDFG, InterstateEdge, SDFGState -from dace.sdfg import utils as sdutil -from dace.sdfg.nodes import CodeBlock -from dace.sdfg.state import ControlFlowRegion, LoopRegion -from dace.transformation import transformation - - -class LoopRegionInline(transformation.MultiStateTransformation): - """ - Inlines a loop regions into a single state machine. - """ - - loop = transformation.PatternNode(LoopRegion) - - @staticmethod - def annotates_memlets(): - return False - - @classmethod - def expressions(cls): - return [sdutil.node_path_graph(cls.loop)] - - def can_be_applied(self, graph: ControlFlowRegion, expr_index: int, sdfg: SDFG, permissive: bool = False) -> bool: - # Check that the loop initialization and update statements each only contain assignments, if the loop has any. - if self.loop.init_statement is not None: - if isinstance(self.loop.init_statement.code, list): - for stmt in self.loop.init_statement.code: - if not isinstance(stmt, astutils.ast.Assign): - return False - if self.loop.update_statement is not None: - if isinstance(self.loop.update_statement.code, list): - for stmt in self.loop.update_statement.code: - if not isinstance(stmt, astutils.ast.Assign): - return False - return True - - def apply(self, graph: ControlFlowRegion, sdfg: SDFG) -> Optional[int]: - parent: ControlFlowRegion = graph - - internal_start = self.loop.start_block - - # Add all boilerplate loop states necessary for the structure. - init_state = parent.add_state(self.loop.label + '_init') - guard_state = parent.add_state(self.loop.label + '_guard') - end_state = parent.add_state(self.loop.label + '_end') - loop_tail_state = parent.add_state(self.loop.label + '_tail') - - # Add all loop states and make sure to keep track of all the ones that need to be connected in the end. - to_connect: Set[SDFGState] = set() - for node in self.loop.nodes(): - parent.add_node(node) - if self.loop.out_degree(node) == 0: - to_connect.add(node) - - # Handle break and continue. - for continue_state_id in self.loop.continue_states: - continue_state = self.loop.node(continue_state_id) - to_connect.add(continue_state) - for break_state_id in self.loop.break_states: - break_state = self.loop.node(break_state_id) - parent.add_edge(break_state, end_state, InterstateEdge()) - - # Add all internal loop edges. - for edge in self.loop.edges(): - parent.add_edge(edge.src, edge.dst, edge.data) - - # Redirect all edges to the loop to the init state. - for b_edge in parent.in_edges(self.loop): - parent.add_edge(b_edge.src, init_state, b_edge.data) - parent.remove_edge(b_edge) - # Redirect all edges exiting the loop to instead exit the end state. - for a_edge in parent.out_edges(self.loop): - parent.add_edge(end_state, a_edge.dst, a_edge.data) - parent.remove_edge(a_edge) - - # Add an initialization edge that initializes the loop variable if applicable. - init_edge = InterstateEdge() - if self.loop.init_statement is not None: - init_edge.assignments = {} - for stmt in self.loop.init_statement.code: - assign: astutils.ast.Assign = stmt - init_edge.assignments[assign.targets[0].id] = astutils.unparse(assign.value) - if self.loop.inverted: - parent.add_edge(init_state, internal_start, init_edge) - else: - parent.add_edge(init_state, guard_state, init_edge) - - # Connect the loop tail. - update_edge = InterstateEdge() - if self.loop.update_statement is not None: - update_edge.assignments = {} - for stmt in self.loop.update_statement.code: - assign: astutils.ast.Assign = stmt - update_edge.assignments[assign.targets[0].id] = astutils.unparse(assign.value) - parent.add_edge(loop_tail_state, guard_state, update_edge) - - # Add condition checking edges and connect the guard state. - cond_expr = self.loop.loop_condition.code - parent.add_edge(guard_state, end_state, - InterstateEdge(CodeBlock(astutils.negate_expr(cond_expr)).code)) - parent.add_edge(guard_state, internal_start, InterstateEdge(CodeBlock(cond_expr).code)) - - # Connect any end states from the loop's internal state machine to the tail state so they end a - # loop iteration. Do the same for any continue states. - for node in to_connect: - parent.add_edge(node, loop_tail_state, InterstateEdge()) - - # Remove the original loop. - parent.remove_node(self.loop) diff --git a/dace/transformation/interstate/fpga_transform_sdfg.py b/dace/transformation/interstate/fpga_transform_sdfg.py index 954c88d726..ac4672d892 100644 --- a/dace/transformation/interstate/fpga_transform_sdfg.py +++ b/dace/transformation/interstate/fpga_transform_sdfg.py @@ -8,6 +8,7 @@ @properties.make_properties +@transformation.single_level_sdfg_only class FPGATransformSDFG(transformation.MultiStateTransformation): """ Implements the FPGATransformSDFG transformation, which takes an entire SDFG and transforms it into an FPGA-capable SDFG. """ diff --git a/dace/transformation/interstate/fpga_transform_state.py b/dace/transformation/interstate/fpga_transform_state.py index dbf5c8d24d..60a2a33001 100644 --- a/dace/transformation/interstate/fpga_transform_state.py +++ b/dace/transformation/interstate/fpga_transform_state.py @@ -29,6 +29,7 @@ def fpga_update(sdfg, state, depth): fpga_update(node.sdfg, s, depth + 1) +@transformation.single_level_sdfg_only class FPGATransformState(transformation.MultiStateTransformation): """ Implements the FPGATransformState transformation. """ diff --git a/dace/transformation/interstate/gpu_transform_sdfg.py b/dace/transformation/interstate/gpu_transform_sdfg.py index c33fd6ae29..844651b071 100644 --- a/dace/transformation/interstate/gpu_transform_sdfg.py +++ b/dace/transformation/interstate/gpu_transform_sdfg.py @@ -83,6 +83,7 @@ def _recursive_in_check(node, state, gpu_scalars): @make_properties +@transformation.single_level_sdfg_only class GPUTransformSDFG(transformation.MultiStateTransformation): """ Implements the GPUTransformSDFG transformation. diff --git a/dace/transformation/interstate/loop_detection.py b/dace/transformation/interstate/loop_detection.py index 274aed485f..da225232fe 100644 --- a/dace/transformation/interstate/loop_detection.py +++ b/dace/transformation/interstate/loop_detection.py @@ -8,10 +8,12 @@ from dace import sdfg as sd, symbolic from dace.sdfg import graph as gr, utils as sdutil +from dace.sdfg.state import ControlFlowRegion from dace.transformation import transformation # NOTE: This class extends PatternTransformation directly in order to not show up in the matches +@transformation.experimental_cfg_block_compatible class DetectLoop(transformation.PatternTransformation): """ Detects a for-loop construct from an SDFG. """ @@ -64,8 +66,8 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): return False # All nodes inside loop must be dominated by loop guard - dominators = nx.dominance.immediate_dominators(sdfg.nx, sdfg.start_state) - loop_nodes = sdutil.dfs_conditional(sdfg, sources=[begin], condition=lambda _, child: child != guard) + dominators = nx.dominance.immediate_dominators(graph.nx, graph.start_block) + loop_nodes = sdutil.dfs_conditional(graph, sources=[begin], condition=lambda _, child: child != guard) backedge = None for node in loop_nodes: for e in graph.out_edges(node): @@ -101,7 +103,7 @@ def apply(self, _, sdfg): def find_for_loop( - sdfg: sd.SDFG, + graph: ControlFlowRegion, guard: sd.SDFGState, entry: sd.SDFGState, itervar: Optional[str] = None @@ -119,8 +121,8 @@ def find_for_loop( """ # Extract state transition edge information - guard_inedges = sdfg.in_edges(guard) - condition_edge = sdfg.edges_between(guard, entry)[0] + guard_inedges = graph.in_edges(guard) + condition_edge = graph.edges_between(guard, entry)[0] # All incoming edges to the guard must set the same variable if itervar is None: diff --git a/dace/transformation/interstate/loop_peeling.py b/dace/transformation/interstate/loop_peeling.py index 02d64a8829..5dc998c724 100644 --- a/dace/transformation/interstate/loop_peeling.py +++ b/dace/transformation/interstate/loop_peeling.py @@ -5,15 +5,18 @@ from typing import Optional from dace import sdfg as sd +from dace.sdfg.state import ControlFlowRegion from dace.properties import Property, make_properties, CodeBlock from dace.sdfg import graph as gr from dace.sdfg import utils as sdutil from dace.symbolic import pystr_to_symbolic from dace.transformation.interstate.loop_detection import (DetectLoop, find_for_loop) from dace.transformation.interstate.loop_unroll import LoopUnroll +from dace.transformation.transformation import experimental_cfg_block_compatible @make_properties +@experimental_cfg_block_compatible class LoopPeeling(LoopUnroll): """ Splits the first `count` iterations of a state machine for-loop into @@ -73,7 +76,7 @@ def _modify_cond(self, condition, var, step): res = str(itersym) + op + str(end) return res - def apply(self, _, sdfg: sd.SDFG): + def apply(self, graph: ControlFlowRegion, sdfg: sd.SDFG): #################################################################### # Obtain loop information guard: sd.SDFGState = self.loop_guard @@ -81,16 +84,16 @@ def apply(self, _, sdfg: sd.SDFG): after_state: sd.SDFGState = self.exit_state # Obtain iteration variable, range, and stride - condition_edge = sdfg.edges_between(guard, begin)[0] - not_condition_edge = sdfg.edges_between(guard, after_state)[0] - itervar, rng, loop_struct = find_for_loop(sdfg, guard, begin) + condition_edge = graph.edges_between(guard, begin)[0] + not_condition_edge = graph.edges_between(guard, after_state)[0] + itervar, rng, loop_struct = find_for_loop(graph, guard, begin) # Get loop states - loop_states = list(sdutil.dfs_conditional(sdfg, sources=[begin], condition=lambda _, child: child != guard)) + loop_states = list(sdutil.dfs_conditional(graph, sources=[begin], condition=lambda _, child: child != guard)) first_id = loop_states.index(begin) last_state = loop_struct[1] last_id = loop_states.index(last_state) - loop_subgraph = gr.SubgraphView(sdfg, loop_states) + loop_subgraph = gr.SubgraphView(graph, loop_states) #################################################################### # Transform @@ -101,7 +104,7 @@ def apply(self, _, sdfg: sd.SDFG): init_edges = [] before_states = loop_struct[0] for before_state in before_states: - init_edge = sdfg.edges_between(before_state, guard)[0] + init_edge = graph.edges_between(before_state, guard)[0] init_edge.data.assignments[itervar] = str(rng[0] + self.count * rng[2]) init_edges.append(init_edge) append_states = before_states @@ -122,15 +125,15 @@ def apply(self, _, sdfg: sd.SDFG): # Connect states to before the loop with unconditional edges for append_state in append_states: - sdfg.add_edge(append_state, new_states[first_id], sd.InterstateEdge()) + graph.add_edge(append_state, new_states[first_id], sd.InterstateEdge()) append_states = [new_states[last_id]] # Reconnect edge to guard state from last peeled iteration for append_state in append_states: if append_state not in before_states: for init_edge in init_edges: - sdfg.remove_edge(init_edge) - sdfg.add_edge(append_state, guard, init_edges[0].data) + graph.remove_edge(init_edge) + graph.add_edge(append_state, guard, init_edges[0].data) else: # If begin, change initialization assignment and prepend states before # guard @@ -155,10 +158,10 @@ def apply(self, _, sdfg: sd.SDFG): ) # Connect states to before the loop with unconditional edges - sdfg.add_edge(new_states[last_id], prepend_state, sd.InterstateEdge()) + graph.add_edge(new_states[last_id], prepend_state, sd.InterstateEdge()) prepend_state = new_states[first_id] # Reconnect edge to guard state from last peeled iteration if prepend_state != after_state: - sdfg.remove_edge(not_condition_edge) - sdfg.add_edge(guard, prepend_state, not_condition_edge.data) + graph.remove_edge(not_condition_edge) + graph.add_edge(guard, prepend_state, not_condition_edge.data) diff --git a/dace/transformation/interstate/loop_to_map.py b/dace/transformation/interstate/loop_to_map.py index 8fb6600b76..7df057f1aa 100644 --- a/dace/transformation/interstate/loop_to_map.py +++ b/dace/transformation/interstate/loop_to_map.py @@ -75,6 +75,7 @@ def _sanitize_by_index(indices: Set[int], subset: subsets.Subset) -> subsets.Ran @make_properties +@xf.single_level_sdfg_only class LoopToMap(DetectLoop, xf.MultiStateTransformation): """Convert a control flow loop into a dataflow map. Currently only supports the simple case where there is no overlap between inputs and outputs in diff --git a/dace/transformation/interstate/loop_unroll.py b/dace/transformation/interstate/loop_unroll.py index b1dbfdd5c9..e6592b5519 100644 --- a/dace/transformation/interstate/loop_unroll.py +++ b/dace/transformation/interstate/loop_unroll.py @@ -8,11 +8,13 @@ from dace.properties import Property, make_properties from dace.sdfg import graph as gr from dace.sdfg import utils as sdutil +from dace.sdfg.state import ControlFlowRegion from dace.frontend.python.astutils import ASTFindReplace from dace.transformation.interstate.loop_detection import (DetectLoop, find_for_loop) from dace.transformation import transformation as xf @make_properties +@xf.experimental_cfg_block_compatible class LoopUnroll(DetectLoop, xf.MultiStateTransformation): """ Unrolls a state machine for-loop into multiple states """ @@ -45,7 +47,7 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): return False return True - def apply(self, _, sdfg): + def apply(self, graph: ControlFlowRegion, sdfg): # Obtain loop information guard: sd.SDFGState = self.loop_guard begin: sd.SDFGState = self.loop_begin @@ -53,18 +55,18 @@ def apply(self, _, sdfg): # Obtain iteration variable, range, and stride, together with the last # state(s) before the loop and the last loop state. - itervar, rng, loop_struct = find_for_loop(sdfg, guard, begin) + itervar, rng, loop_struct = find_for_loop(graph, guard, begin) # Loop must be fully unrollable for now. if self.count != 0: raise NotImplementedError # TODO(later) # Get loop states - loop_states = list(sdutil.dfs_conditional(sdfg, sources=[begin], condition=lambda _, child: child != guard)) + loop_states = list(sdutil.dfs_conditional(graph, sources=[begin], condition=lambda _, child: child != guard)) first_id = loop_states.index(begin) last_state = loop_struct[1] last_id = loop_states.index(last_state) - loop_subgraph = gr.SubgraphView(sdfg, loop_states) + loop_subgraph = gr.SubgraphView(graph, loop_states) try: start, end, stride = (r for r in rng) @@ -84,22 +86,22 @@ def apply(self, _, sdfg): # Connect iterations with unconditional edges if len(unrolled_states) > 0: - sdfg.add_edge(unrolled_states[-1][1], new_states[first_id], sd.InterstateEdge()) + graph.add_edge(unrolled_states[-1][1], new_states[first_id], sd.InterstateEdge()) unrolled_states.append((new_states[first_id], new_states[last_id])) # Get any assignments that might be on the edge to the after state - after_assignments = (sdfg.edges_between(guard, after_state)[0].data.assignments) + after_assignments = (graph.edges_between(guard, after_state)[0].data.assignments) # Connect new states to before and after states without conditions if unrolled_states: before_states = loop_struct[0] for before_state in before_states: - sdfg.add_edge(before_state, unrolled_states[0][0], sd.InterstateEdge()) - sdfg.add_edge(unrolled_states[-1][1], after_state, sd.InterstateEdge(assignments=after_assignments)) + graph.add_edge(before_state, unrolled_states[0][0], sd.InterstateEdge()) + graph.add_edge(unrolled_states[-1][1], after_state, sd.InterstateEdge(assignments=after_assignments)) # Remove old states from SDFG - sdfg.remove_nodes_from([guard] + loop_states) + graph.remove_nodes_from([guard] + loop_states) def instantiate_loop( self, @@ -119,6 +121,7 @@ def instantiate_loop( state.label = state.label + '_' + itervar + '_' + (state_suffix if state_suffix is not None else str(value)) state.replace(itervar, value) + graph = loop_states[0].parent_graph # Add subgraph to original SDFG for edge in loop_subgraph.edges(): src = new_states[loop_states.index(edge.src)] @@ -126,9 +129,9 @@ def instantiate_loop( # Replace conditions in subgraph edges data: sd.InterstateEdge = copy.deepcopy(edge.data) - if data.condition: + if not data.is_unconditional(): ASTFindReplace({itervar: str(value)}).visit(data.condition) - sdfg.add_edge(src, dst, data) + graph.add_edge(src, dst, data) return new_states diff --git a/dace/transformation/interstate/move_assignment_outside_if.py b/dace/transformation/interstate/move_assignment_outside_if.py index 3d4db9ae25..3b101818ca 100644 --- a/dace/transformation/interstate/move_assignment_outside_if.py +++ b/dace/transformation/interstate/move_assignment_outside_if.py @@ -13,6 +13,7 @@ from dace.transformation import transformation +@transformation.single_level_sdfg_only class MoveAssignmentOutsideIf(transformation.MultiStateTransformation): if_guard = transformation.PatternNode(sd.SDFGState) diff --git a/dace/transformation/interstate/move_loop_into_map.py b/dace/transformation/interstate/move_loop_into_map.py index 20c7b36e0f..916f9c5e41 100644 --- a/dace/transformation/interstate/move_loop_into_map.py +++ b/dace/transformation/interstate/move_loop_into_map.py @@ -23,6 +23,7 @@ def offset(memlet_subset_ranges, value): return (memlet_subset_ranges[0] + value, memlet_subset_ranges[1] + value, memlet_subset_ranges[2]) +@transformation.single_level_sdfg_only class MoveLoopIntoMap(DetectLoop, transformation.MultiStateTransformation): """ Moves a loop around a map into the map diff --git a/dace/transformation/interstate/multistate_inline.py b/dace/transformation/interstate/multistate_inline.py index 0e4f1b4852..42dccd8616 100644 --- a/dace/transformation/interstate/multistate_inline.py +++ b/dace/transformation/interstate/multistate_inline.py @@ -1,29 +1,24 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. """ Inline multi-state SDFGs. """ -import ast -from collections import defaultdict from copy import deepcopy as dc -from dace.frontend.python.ndloop import ndrange import itertools -import networkx as nx -from typing import Callable, Dict, Iterable, List, Set, Optional, Tuple, Union -import warnings - -from dace import memlet, registry, sdfg as sd, Memlet, symbolic, dtypes, subsets -from dace.frontend.python import astutils -from dace.sdfg import nodes, propagation -from dace.sdfg.graph import MultiConnectorEdge, SubgraphView +from typing import Dict, List + +from dace import Memlet, symbolic, dtypes, subsets +from dace.sdfg import nodes +from dace.sdfg.graph import MultiConnectorEdge from dace.sdfg import InterstateEdge, SDFG, SDFGState -from dace.sdfg import utils as sdutil, infer_types, propagation +from dace.sdfg import utils as sdutil, infer_types from dace.sdfg.replace import replace_datadesc_names from dace.transformation import transformation, helpers -from dace.properties import make_properties, Property +from dace.properties import make_properties from dace import data from dace.sdfg.state import StateSubgraphView @make_properties +@transformation.single_level_sdfg_only class InlineMultistateSDFG(transformation.SingleStateTransformation): """ Inlines a multi-state nested SDFG into a top-level SDFG. This only happens @@ -163,14 +158,14 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): # Isolate nsdfg in a separate state # 1. Push nsdfg node plus dependencies down into new state - nsdfg_state = helpers.state_fission_after(sdfg, outer_state, nsdfg_node) + nsdfg_state = helpers.state_fission_after(outer_state, nsdfg_node) # 2. Push successors of nsdfg node into a later state direct_subgraph = set() direct_subgraph.add(nsdfg_node) direct_subgraph.update(nsdfg_state.predecessors(nsdfg_node)) direct_subgraph.update(nsdfg_state.successors(nsdfg_node)) direct_subgraph = StateSubgraphView(nsdfg_state, direct_subgraph) - nsdfg_state = helpers.state_fission(sdfg, direct_subgraph) + nsdfg_state = helpers.state_fission(direct_subgraph) # Find original source/destination edges (there is only one edge per # connector, according to match) diff --git a/dace/transformation/interstate/sdfg_nesting.py b/dace/transformation/interstate/sdfg_nesting.py index b362856bee..622dfe5595 100644 --- a/dace/transformation/interstate/sdfg_nesting.py +++ b/dace/transformation/interstate/sdfg_nesting.py @@ -2,13 +2,10 @@ """ SDFG nesting transformation. """ import ast -from collections import defaultdict from copy import deepcopy as dc -from dace.frontend.python.ndloop import ndrange import itertools import networkx as nx from typing import Callable, Dict, Iterable, List, Set, Tuple, Union -import warnings from functools import reduce import operator import copy @@ -25,6 +22,7 @@ @make_properties +@transformation.single_level_sdfg_only class InlineSDFG(transformation.SingleStateTransformation): """ Inlines a single-state nested SDFG into a top-level SDFG. @@ -565,7 +563,7 @@ def apply(self, state: SDFGState, sdfg: SDFG): # Fission state if necessary cc = utils.weakly_connected_component(state, node) if not any(n in cc for n in subgraph.nodes()): - helpers.state_fission(state.parent, cc) + helpers.state_fission(cc) for edge in removed_out_edges: # Find last access node that refers to this edge try: @@ -580,7 +578,7 @@ def apply(self, state: SDFGState, sdfg: SDFG): cc = utils.weakly_connected_component(state, node) if not any(n in cc for n in subgraph.nodes()): cc2 = SubgraphView(state, [n for n in state.nodes() if n not in cc]) - state = helpers.state_fission(sdfg, cc2) + state = helpers.state_fission(cc2) ####################################################### # Remove nested SDFG node @@ -736,6 +734,7 @@ def _modify_reshape_data(self, reshapes: Set[str], repldict: Dict[str, str], new @make_properties +@transformation.single_level_sdfg_only class InlineTransients(transformation.SingleStateTransformation): """ Inlines all transient arrays that are not used anywhere else into a @@ -879,6 +878,7 @@ def visit_Subscript(self, node: ast.Subscript) -> ast.Subscript: @make_properties +@transformation.single_level_sdfg_only class RefineNestedAccess(transformation.SingleStateTransformation): """ Reduces memlet shape when a memlet is connected to a nested SDFG, but not @@ -1102,6 +1102,7 @@ def _offset_refine(torefine: Dict[str, Tuple[Memlet, Set[int]]], @make_properties +@transformation.single_level_sdfg_only class NestSDFG(transformation.MultiStateTransformation): """ Implements SDFG Nesting, taking an SDFG as an input and creating a nested SDFG node from it. """ diff --git a/dace/transformation/interstate/state_elimination.py b/dace/transformation/interstate/state_elimination.py index cbb5d7b957..2640e30ccc 100644 --- a/dace/transformation/interstate/state_elimination.py +++ b/dace/transformation/interstate/state_elimination.py @@ -2,16 +2,17 @@ """ State elimination transformations """ import networkx as nx -from typing import Dict, List, Set +from typing import Dict, Set -from dace import data as dt, dtypes, registry, sdfg, symbolic +from dace import data as dt, sdfg, symbolic from dace.properties import CodeBlock -from dace.sdfg import nodes, SDFG, SDFGState, InterstateEdge +from dace.sdfg import nodes, SDFG, SDFGState from dace.sdfg import utils as sdutil +from dace.sdfg.state import ControlFlowRegion from dace.transformation import transformation -from dace.sdfg.analysis import cfg +@transformation.experimental_cfg_block_compatible class EndStateElimination(transformation.MultiStateTransformation): """ End-state elimination removes a redundant state that has one incoming edge @@ -47,18 +48,19 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): return True - def apply(self, _, sdfg): + def apply(self, graph, sdfg): state = self.end_state # Handle orphan symbols (due to the deletion the incoming edge) - edge = sdfg.in_edges(state)[0] + edge = graph.in_edges(state)[0] sym_assign = edge.data.assignments.keys() - sdfg.remove_node(state) + graph.remove_node(state) # Remove orphan symbols for sym in sym_assign: if sym in sdfg.free_symbols: sdfg.remove_symbol(sym) +@transformation.experimental_cfg_block_compatible class StartStateElimination(transformation.MultiStateTransformation): """ Start-state elimination removes a redundant state that has one outgoing edge @@ -102,14 +104,14 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): return True - def apply(self, _, sdfg): + def apply(self, graph, sdfg): state = self.start_state # Move assignments to the nested SDFG node's symbol mappings node = sdfg.parent_nsdfg_node - edge = sdfg.out_edges(state)[0] + edge = graph.out_edges(state)[0] for k, v in edge.data.assignments.items(): node.symbol_mapping[k] = v - sdfg.remove_node(state) + graph.remove_node(state) def _assignments_to_consider(sdfg, edge, is_constant=False): @@ -131,6 +133,7 @@ def _assignments_to_consider(sdfg, edge, is_constant=False): return assignments_to_consider +@transformation.experimental_cfg_block_compatible class StateAssignElimination(transformation.MultiStateTransformation): """ State assign elimination removes all assignments into the final state @@ -166,14 +169,14 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): # Otherwise, ensure the symbols are never set/used again in edges akeys = set(assignments_to_consider.keys()) - for e in sdfg.edges(): + for e in sdfg.all_interstate_edges(): if e is edge: continue if e.data.free_symbols & akeys: return False # If used in any state that is not the current one, fail - for s in sdfg.nodes(): + for s in sdfg.states(): if s is state: continue if s.free_symbols & akeys: @@ -181,9 +184,9 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): return True - def apply(self, _, sdfg): + def apply(self, graph, sdfg): state = self.end_state - edge = sdfg.in_edges(state)[0] + edge = graph.in_edges(state)[0] # Since inter-state assignments that use an assigned value leads to # undefined behavior (e.g., {m: n, n: m}), we can replace each # assignment separately. @@ -199,7 +202,7 @@ def apply(self, _, sdfg): # Remove assignments from edge del edge.data.assignments[varname] - for e in sdfg.edges(): + for e in sdfg.all_interstate_edges(): if varname in e.data.free_symbols: break else: @@ -227,6 +230,7 @@ def _alias_assignments(sdfg, edge): return assignments_to_consider +@transformation.single_level_sdfg_only class SymbolAliasPromotion(transformation.MultiStateTransformation): """ SymbolAliasPromotion moves inter-state assignments that create symbolic @@ -331,6 +335,7 @@ def apply(self, _, sdfg): in_edge.assignments[k] = v +@transformation.single_level_sdfg_only class HoistState(transformation.SingleStateTransformation): """ Move a state out of a nested SDFG """ nsdfg = transformation.PatternNode(nodes.NestedSDFG) @@ -484,6 +489,7 @@ def replfunc(m): nsdfg.sdfg.start_state = nsdfg.sdfg.node_id(nisedge.dst) +@transformation.experimental_cfg_block_compatible class TrueConditionElimination(transformation.MultiStateTransformation): """ If a state transition condition is always true, removes condition from edge. @@ -512,13 +518,14 @@ def can_be_applied(self, graph: SDFG, expr_index, sdfg: SDFG, permissive=False): return False - def apply(self, _, sdfg: SDFG): + def apply(self, graph: ControlFlowRegion, sdfg: SDFG): a: SDFGState = self.state_a b: SDFGState = self.state_b - edge = sdfg.edges_between(a, b)[0] + edge = graph.edges_between(a, b)[0] edge.data.condition = CodeBlock("1") +@transformation.experimental_cfg_block_compatible class FalseConditionElimination(transformation.MultiStateTransformation): """ If a state transition condition is always false, removes edge. @@ -556,8 +563,8 @@ def can_be_applied(self, graph: SDFG, expr_index, sdfg: SDFG, permissive=False): return False - def apply(self, _, sdfg: SDFG): + def apply(self, graph: ControlFlowRegion, sdfg: SDFG): a: SDFGState = self.state_a b: SDFGState = self.state_b - edge = sdfg.edges_between(a, b)[0] + edge = graph.edges_between(a, b)[0] sdfg.remove_edge(edge) diff --git a/dace/transformation/interstate/state_fusion.py b/dace/transformation/interstate/state_fusion.py index 6db62a097e..3abbe085f5 100644 --- a/dace/transformation/interstate/state_fusion.py +++ b/dace/transformation/interstate/state_fusion.py @@ -32,6 +32,7 @@ def top_level_nodes(state: SDFGState): return state.scope_children()[None] +@transformation.experimental_cfg_block_compatible class StateFusion(transformation.MultiStateTransformation): """ Implements the state-fusion transformation. @@ -458,29 +459,31 @@ def apply(self, _, sdfg): first_state: SDFGState = self.first_state second_state: SDFGState = self.second_state + graph = first_state.parent_graph + # Remove interstate edge(s) - edges = sdfg.edges_between(first_state, second_state) + edges = graph.edges_between(first_state, second_state) for edge in edges: if edge.data.assignments: - for src, dst, other_data in sdfg.in_edges(first_state): + for src, dst, other_data in graph.in_edges(first_state): other_data.assignments.update(edge.data.assignments) - sdfg.remove_edge(edge) + graph.remove_edge(edge) # Special case 1: first state is empty if first_state.is_empty(): - sdutil.change_edge_dest(sdfg, first_state, second_state) - sdfg.remove_node(first_state) - if sdfg.start_state == first_state: - sdfg.start_state = sdfg.node_id(second_state) + sdutil.change_edge_dest(graph, first_state, second_state) + graph.remove_node(first_state) + if graph.start_block == first_state: + graph.start_block = graph.node_id(second_state) return # Special case 2: second state is empty if second_state.is_empty(): - sdutil.change_edge_src(sdfg, second_state, first_state) - sdutil.change_edge_dest(sdfg, second_state, first_state) - sdfg.remove_node(second_state) - if sdfg.start_state == second_state: - sdfg.start_state = sdfg.node_id(first_state) + sdutil.change_edge_src(graph, second_state, first_state) + sdutil.change_edge_dest(graph, second_state, first_state) + graph.remove_node(second_state) + if graph.start_block == second_state: + graph.start_block = graph.node_id(first_state) return # Normal case: both states are not empty @@ -562,7 +565,7 @@ def apply(self, _, sdfg): merged_nodes.add(n) # Redirect edges and remove second state - sdutil.change_edge_src(sdfg, second_state, first_state) - sdfg.remove_node(second_state) - if sdfg.start_state == second_state: - sdfg.start_state = sdfg.node_id(first_state) + sdutil.change_edge_src(graph, second_state, first_state) + graph.remove_node(second_state) + if graph.start_block == second_state: + graph.start_block = graph.node_id(first_state) diff --git a/dace/transformation/interstate/state_fusion_with_happens_before.py b/dace/transformation/interstate/state_fusion_with_happens_before.py index 4c6ad3c992..408f5a76f2 100644 --- a/dace/transformation/interstate/state_fusion_with_happens_before.py +++ b/dace/transformation/interstate/state_fusion_with_happens_before.py @@ -5,7 +5,7 @@ import networkx as nx -from dace import data as dt, dtypes, registry, sdfg, subsets, memlet +from dace import data as dt, sdfg, subsets, memlet from dace.config import Config from dace.sdfg import nodes from dace.sdfg import utils as sdutil @@ -31,6 +31,7 @@ def top_level_nodes(state: SDFGState): return state.scope_children()[None] +@transformation.single_level_sdfg_only class StateFusionExtended(transformation.MultiStateTransformation): """ Implements the state-fusion transformation extended to fuse states with RAW and WAW dependencies. An empty memlet is used to represent a dependency between two subgraphs with RAW and WAW dependencies. diff --git a/dace/transformation/interstate/trivial_loop_elimination.py b/dace/transformation/interstate/trivial_loop_elimination.py index d4c8b13553..d214cb5343 100644 --- a/dace/transformation/interstate/trivial_loop_elimination.py +++ b/dace/transformation/interstate/trivial_loop_elimination.py @@ -7,6 +7,7 @@ from dace.transformation.interstate.loop_detection import (DetectLoop, find_for_loop) +@transformation.single_level_sdfg_only class TrivialLoopElimination(DetectLoop, transformation.MultiStateTransformation): """ Eliminates loops with a single loop iteration. diff --git a/dace/transformation/pass_pipeline.py b/dace/transformation/pass_pipeline.py index 4e16bb6207..494f9c39ae 100644 --- a/dace/transformation/pass_pipeline.py +++ b/dace/transformation/pass_pipeline.py @@ -2,6 +2,7 @@ """ API for SDFG analysis and manipulation Passes, as well as Pipelines that contain multiple dependent passes. """ +import warnings from dace import properties, serialize from dace.sdfg import SDFG, SDFGState, graph as gr, nodes, utils as sdutil @@ -492,9 +493,35 @@ def apply_subpass(self, sdfg: SDFG, p: Pass, state: Dict[str, Any]) -> Optional[ :param state: The pipeline results state. :return: The pass return value. """ + if sdfg.root_sdfg.using_experimental_blocks: + if (not hasattr(p, '__experimental_cfg_block_compatible__') or + p.__experimental_cfg_block_compatible__ == False): + warnings.warn(p.__class__.__name__ + ' is not being applied due to incompatibility with ' + + 'experimental control flow blocks. If the SDFG does not contain experimental blocks, ' + + 'ensure the top level SDFG does not have `SDFG.using_experimental_blocks` set to ' + + 'True. If ' + p.__class__.__name__ + ' is compatible with experimental blocks, ' + + 'please annotate it with the class decorator ' + + '`@dace.transformation.experimental_cfg_block_compatible`. see ' + + '`https://github.com/spcl/dace/wiki/Experimental-Control-Flow-Blocks` ' + + 'for more information.') + return None + return p.apply_pass(sdfg, state) def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[Dict[str, Any]]: + if sdfg.root_sdfg.using_experimental_blocks: + if (not hasattr(self, '__experimental_cfg_block_compatible__') or + self.__experimental_cfg_block_compatible__ == False): + warnings.warn('Pipeline ' + self.__class__.__name__ + ' is being skipped due to incompatibility with ' + + 'experimental control flow blocks. If the SDFG does not contain experimental blocks, ' + + 'ensure the top level SDFG does not have `SDFG.using_experimental_blocks` set to ' + + 'True. If ' + self.__class__.__name__ + ' is compatible with experimental blocks, ' + + 'please annotate it with the class decorator ' + + '`@dace.transformation.experimental_cfg_block_compatible`. see ' + + '`https://github.com/spcl/dace/wiki/Experimental-Control-Flow-Blocks` ' + + 'for more information.') + return None + state = pipeline_results retval = {} self._modified = Modifies.Nothing diff --git a/dace/transformation/passes/analysis.py b/dace/transformation/passes/analysis.py index 82cae6e470..c8bb0b7a9c 100644 --- a/dace/transformation/passes/analysis.py +++ b/dace/transformation/passes/analysis.py @@ -156,7 +156,7 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[SDFGState, Tuple[Set[s top_result: Dict[int, Dict[SDFGState, Tuple[Set[str], Set[str]]]] = {} for sdfg in top_sdfg.all_sdfgs_recursive(): result: Dict[SDFGState, Tuple[Set[str], Set[str]]] = {} - for state in sdfg.nodes(): + for state in sdfg.states(): readset, writeset = set(), set() for anode in state.data_nodes(): if state.in_degree(anode) > 0: diff --git a/dace/transformation/passes/array_elimination.py b/dace/transformation/passes/array_elimination.py index 6e1253ec3a..a25858b0d6 100644 --- a/dace/transformation/passes/array_elimination.py +++ b/dace/transformation/passes/array_elimination.py @@ -5,7 +5,7 @@ from dace import SDFG, SDFGState, data, properties from dace.sdfg import nodes from dace.sdfg.analysis import cfg -from dace.transformation import pass_pipeline as ppl +from dace.transformation import pass_pipeline as ppl, transformation from dace.transformation.dataflow import (RedundantArray, RedundantReadSlice, RedundantSecondArray, RedundantWriteSlice, SqueezeViewRemove, UnsqueezeViewRemove, RemoveSliceView) from dace.transformation.passes import analysis as ap @@ -13,6 +13,7 @@ @properties.make_properties +@transformation.single_level_sdfg_only class ArrayElimination(ppl.Pass): """ Merges and removes arrays and their corresponding accesses. This includes redundant array copies, unnecessary views, diff --git a/dace/transformation/passes/consolidate_edges.py b/dace/transformation/passes/consolidate_edges.py index 148998c28c..5b1aae2621 100644 --- a/dace/transformation/passes/consolidate_edges.py +++ b/dace/transformation/passes/consolidate_edges.py @@ -5,8 +5,11 @@ from dace import SDFG, properties from typing import Optional +from dace.transformation.transformation import experimental_cfg_block_compatible + @properties.make_properties +@experimental_cfg_block_compatible class ConsolidateEdges(ppl.Pass): """ Removes extraneous edges with memlets that refer to the same data containers within the same scope. diff --git a/dace/transformation/passes/constant_propagation.py b/dace/transformation/passes/constant_propagation.py index 50aac77ae4..b0a20f70d6 100644 --- a/dace/transformation/passes/constant_propagation.py +++ b/dace/transformation/passes/constant_propagation.py @@ -6,7 +6,7 @@ from dace.sdfg.analysis import cfg from dace.sdfg.sdfg import InterstateEdge from dace.sdfg import nodes, utils as sdutil -from dace.transformation import pass_pipeline as ppl +from dace.transformation import pass_pipeline as ppl, transformation from dace.cli.progress import optional_progressbar from dace import data, SDFG, SDFGState, dtypes, symbolic, properties from typing import Any, Dict, Set, Optional, Tuple @@ -19,6 +19,7 @@ class _UnknownValue: @dataclass(unsafe_hash=True) @properties.make_properties +@transformation.single_level_sdfg_only class ConstantPropagation(ppl.Pass): """ Propagates constants and symbols that were assigned to one value forward through the SDFG, reducing diff --git a/dace/transformation/passes/dead_dataflow_elimination.py b/dace/transformation/passes/dead_dataflow_elimination.py index 9a09119825..fe181d01b4 100644 --- a/dace/transformation/passes/dead_dataflow_elimination.py +++ b/dace/transformation/passes/dead_dataflow_elimination.py @@ -11,7 +11,7 @@ from dace.sdfg import utils as sdutil from dace.sdfg.analysis import cfg from dace.sdfg import infer_types -from dace.transformation import pass_pipeline as ppl +from dace.transformation import pass_pipeline as ppl, transformation from dace.transformation.passes import analysis as ap PROTECTED_NAMES = {'__pystate'} #: A set of names that are not allowed to be erased @@ -19,6 +19,7 @@ @dataclass(unsafe_hash=True) @properties.make_properties +@transformation.single_level_sdfg_only class DeadDataflowElimination(ppl.Pass): """ Removes unused computations from SDFG states. diff --git a/dace/transformation/passes/dead_state_elimination.py b/dace/transformation/passes/dead_state_elimination.py index a5ff0ba71a..43239fe9af 100644 --- a/dace/transformation/passes/dead_state_elimination.py +++ b/dace/transformation/passes/dead_state_elimination.py @@ -8,10 +8,11 @@ from dace.properties import CodeBlock from dace.sdfg.graph import Edge from dace.sdfg.validation import InvalidSDFGInterstateEdgeError -from dace.transformation import pass_pipeline as ppl +from dace.transformation import pass_pipeline as ppl, transformation @properties.make_properties +@transformation.single_level_sdfg_only class DeadStateElimination(ppl.Pass): """ Removes all unreachable states (e.g., due to a branch that will never be taken) from an SDFG. diff --git a/dace/transformation/passes/fusion_inline.py b/dace/transformation/passes/fusion_inline.py index 93764670e8..9a97afb569 100644 --- a/dace/transformation/passes/fusion_inline.py +++ b/dace/transformation/passes/fusion_inline.py @@ -10,10 +10,12 @@ from dace.sdfg import nodes from dace.sdfg.utils import fuse_states, inline_sdfgs from dace.transformation import pass_pipeline as ppl +from dace.transformation.transformation import experimental_cfg_block_compatible @dataclass(unsafe_hash=True) @properties.make_properties +@experimental_cfg_block_compatible class FuseStates(ppl.Pass): """ Fuses all possible states of an SDFG (and all sub-SDFGs). @@ -87,6 +89,7 @@ def report(self, pass_retval: int) -> str: @dataclass(unsafe_hash=True) @properties.make_properties +@experimental_cfg_block_compatible class FixNestedSDFGReferences(ppl.Pass): """ Fixes nested SDFG references to parent state/SDFG/node diff --git a/dace/transformation/passes/optional_arrays.py b/dace/transformation/passes/optional_arrays.py index e43448415f..f52ee5af43 100644 --- a/dace/transformation/passes/optional_arrays.py +++ b/dace/transformation/passes/optional_arrays.py @@ -5,10 +5,11 @@ from dace import SDFG, SDFGState, data, properties from dace.sdfg import nodes from dace.sdfg import utils as sdutil -from dace.transformation import pass_pipeline as ppl +from dace.transformation import pass_pipeline as ppl, transformation @properties.make_properties +@transformation.single_level_sdfg_only class OptionalArrayInference(ppl.Pass): """ Infers the ``optional`` property of arrays, i.e., if they can be given None, throughout the SDFG and all nested diff --git a/dace/transformation/passes/pattern_matching.py b/dace/transformation/passes/pattern_matching.py index 31b68057c3..a046a557ce 100644 --- a/dace/transformation/passes/pattern_matching.py +++ b/dace/transformation/passes/pattern_matching.py @@ -4,11 +4,13 @@ import collections from dataclasses import dataclass import time +import warnings from dace import properties from dace.config import Config from dace.sdfg import SDFG, SDFGState from dace.sdfg import graph as gr, nodes as nd +from dace.sdfg.state import ControlFlowRegion import networkx as nx from networkx.algorithms import isomorphism as iso from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Type, Union @@ -96,6 +98,20 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[str, # For every transformation in the list, find first match and apply for xform in self.transformations: + if sdfg.root_sdfg.using_experimental_blocks: + if (not hasattr(xform, '__experimental_cfg_block_compatible__') or + xform.__experimental_cfg_block_compatible__ == False): + warnings.warn('Pattern matching is skipping transformation ' + xform.__class__.__name__ + + ' due to incompatibility with experimental control flow blocks. If the ' + + 'SDFG does not contain experimental blocks, ensure the top level SDFG does ' + + 'not have `SDFG.using_experimental_blocks` set to True. If ' + + xform.__class__.__name__ + ' is compatible with experimental blocks, ' + + 'please annotate it with the class decorator ' + + '`@dace.transformation.experimental_cfg_block_compatible`. see ' + + '`https://github.com/spcl/dace/wiki/Experimental-Control-Flow-Blocks` ' + + 'for more information.') + continue + # Find only the first match try: match = next(m for m in match_patterns( @@ -103,13 +119,13 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[str, except StopIteration: continue - tsdfg = sdfg.cfg_list[match.cfg_id] - graph = tsdfg.node(match.state_id) if match.state_id >= 0 else tsdfg + tcfg = sdfg.cfg_list[match.cfg_id] + graph = tcfg.node(match.state_id) if match.state_id >= 0 else tcfg # Set previous pipeline results match._pipeline_results = pipeline_results - result = match.apply(graph, tsdfg) + result = match.apply(graph, tcfg.sdfg) applied_transformations[type(match).__name__].append(result) if self.validate_all: sdfg.validate() @@ -156,16 +172,16 @@ def __init__(self, # Helper function for applying and validating a transformation def _apply_and_validate(self, match: xf.PatternTransformation, sdfg: SDFG, start: float, pipeline_results: Dict[str, Any], applied_transformations: Dict[str, Any]): - tsdfg = sdfg.cfg_list[match.cfg_id] - graph = tsdfg.node(match.state_id) if match.state_id >= 0 else tsdfg + tcfg = sdfg.cfg_list[match.cfg_id] + graph = tcfg.node(match.state_id) if match.state_id >= 0 else tcfg # Set previous pipeline results match._pipeline_results = pipeline_results if self.validate_all: - match_name = match.print_match(tsdfg) + match_name = match.print_match(tcfg) - applied_transformations[type(match).__name__].append(match.apply(graph, tsdfg)) + applied_transformations[type(match).__name__].append(match.apply(graph, tcfg.sdfg)) if self.progress or (self.progress is None and (time.time() - start) > 5): print('Applied {}.\r'.format(', '.join(['%d %s' % (len(v), k) for k, v in applied_transformations.items()])), @@ -200,6 +216,20 @@ def _apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any], apply_once: while applied_anything: applied_anything = False for xform in xforms: + if sdfg.root_sdfg.using_experimental_blocks: + if (not hasattr(xform, '__experimental_cfg_block_compatible__') or + xform.__experimental_cfg_block_compatible__ == False): + warnings.warn('Pattern matching is skipping transformation ' + xform.__class__.__name__ + + ' due to incompatibility with experimental control flow blocks. If the ' + + 'SDFG does not contain experimental blocks, ensure the top level SDFG does ' + + 'not have `SDFG.using_experimental_blocks` set to True. If ' + + xform.__class__.__name__ + ' is compatible with experimental blocks, ' + + 'please annotate it with the class decorator ' + + '`@dace.transformation.experimental_cfg_block_compatible`. see ' + + '`https://github.com/spcl/dace/wiki/Experimental-Control-Flow-Blocks` ' + + 'for more information.') + continue + applied = True while applied: applied = False @@ -350,8 +380,9 @@ def type_or_class_match(node_a, node_b): return isinstance(node_a['node'], type(node_b['node'])) -def _try_to_match_transformation(graph: Union[SDFG, SDFGState], collapsed_graph: nx.DiGraph, subgraph: Dict[int, int], - sdfg: SDFG, xform: Union[xf.PatternTransformation, Type[xf.PatternTransformation]], +def _try_to_match_transformation(graph: Union[ControlFlowRegion, SDFGState], collapsed_graph: nx.DiGraph, + subgraph: Dict[int, int], sdfg: SDFG, + xform: Union[xf.PatternTransformation, Type[xf.PatternTransformation]], expr_idx: int, nxpattern: nx.DiGraph, state_id: int, permissive: bool, options: Dict[str, Any]) -> Optional[xf.PatternTransformation]: """ @@ -377,7 +408,22 @@ def _try_to_match_transformation(graph: Union[SDFG, SDFGState], collapsed_graph: for oname, oval in opts.items(): setattr(match, oname, oval) - match.setup_match(sdfg, sdfg.cfg_id, state_id, subgraph, expr_idx, options=options) + if sdfg.root_sdfg.using_experimental_blocks: + if (not hasattr(match, '__experimental_cfg_block_compatible__') or + match.__experimental_cfg_block_compatible__ == False): + warnings.warn('Pattern matching is skipping transformation ' + match.__class__.__name__ + + ' due to incompatibility with experimental control flow blocks. If the ' + + 'SDFG does not contain experimental blocks, ensure the top level SDFG does ' + + 'not have `SDFG.using_experimental_blocks` set to True. If ' + + match.__class__.__name__ + ' is compatible with experimental blocks, ' + + 'please annotate it with the class decorator ' + + '`@dace.transformation.experimental_cfg_block_compatible`. see ' + + '`https://github.com/spcl/dace/wiki/Experimental-Control-Flow-Blocks` ' + + 'for more information.') + return None + + cfg_id = graph.parent_graph.cfg_id if isinstance(graph, SDFGState) else graph.cfg_id + match.setup_match(sdfg, cfg_id, state_id, subgraph, expr_idx, options=options) match_found = match.can_be_applied(graph, expr_idx, sdfg, permissive=permissive) except Exception as e: if Config.get_bool('optimizer', 'match_exception'): @@ -513,19 +559,19 @@ def match_patterns(sdfg: SDFG, (interstate_transformations, singlestate_transformations) = get_transformation_metadata(patterns, options) # Collect SDFG and nested SDFGs - sdfgs = sdfg.all_sdfgs_recursive() + cfrs = sdfg.all_control_flow_regions(recursive=True) # Try to find transformations on each SDFG - for tsdfg in sdfgs: + for cfr in cfrs: ################################### # Match inter-state transformations if len(interstate_transformations) > 0: # Collapse multigraph into directed graph in order to use VF2 - digraph = collapse_multigraph_to_nx(tsdfg) + digraph = collapse_multigraph_to_nx(cfr) for xform, expr_idx, nxpattern, matcher, opts in interstate_transformations: for subgraph in matcher(digraph, nxpattern, node_match, edge_match): - match = _try_to_match_transformation(tsdfg, digraph, subgraph, tsdfg, xform, expr_idx, nxpattern, -1, + match = _try_to_match_transformation(cfr, digraph, subgraph, cfr.sdfg, xform, expr_idx, nxpattern, -1, permissive, opts) if match is not None: yield match @@ -534,8 +580,8 @@ def match_patterns(sdfg: SDFG, # Match single-state transformations if len(singlestate_transformations) == 0: continue - for state_id, state in enumerate(tsdfg.nodes()): - if states is not None and state not in states: + for state_id, state in enumerate(cfr.nodes()): + if not isinstance(state, SDFGState) or (states is not None and state not in states): continue # Collapse multigraph into directed graph in order to use VF2 @@ -543,7 +589,7 @@ def match_patterns(sdfg: SDFG, for xform, expr_idx, nxpattern, matcher, opts in singlestate_transformations: for subgraph in matcher(digraph, nxpattern, node_match, edge_match): - match = _try_to_match_transformation(state, digraph, subgraph, tsdfg, xform, expr_idx, nxpattern, + match = _try_to_match_transformation(state, digraph, subgraph, cfr.sdfg, xform, expr_idx, nxpattern, state_id, permissive, opts) if match is not None: yield match diff --git a/dace/transformation/passes/prune_symbols.py b/dace/transformation/passes/prune_symbols.py index 336ac4b428..3b3940f804 100644 --- a/dace/transformation/passes/prune_symbols.py +++ b/dace/transformation/passes/prune_symbols.py @@ -6,11 +6,12 @@ from dace import SDFG, dtypes, properties, symbolic from dace.sdfg import nodes -from dace.transformation import pass_pipeline as ppl +from dace.transformation import pass_pipeline as ppl, transformation @dataclass(unsafe_hash=True) @properties.make_properties +@transformation.single_level_sdfg_only class RemoveUnusedSymbols(ppl.Pass): """ Prunes unused symbols from the SDFG symbol repository (``sdfg.symbols``) and interstate edges. diff --git a/dace/transformation/passes/reference_reduction.py b/dace/transformation/passes/reference_reduction.py index 0bccb4ea54..5bee098c55 100644 --- a/dace/transformation/passes/reference_reduction.py +++ b/dace/transformation/passes/reference_reduction.py @@ -6,11 +6,12 @@ from dace import SDFG, SDFGState, data, properties, Memlet from dace.sdfg import nodes from dace.sdfg.analysis import cfg -from dace.transformation import pass_pipeline as ppl +from dace.transformation import pass_pipeline as ppl, transformation from dace.transformation.passes import analysis as ap @properties.make_properties +@transformation.single_level_sdfg_only class ReferenceToView(ppl.Pass): """ Replaces Reference data descriptors that are only set to one source with views. diff --git a/dace/transformation/passes/scalar_fission.py b/dace/transformation/passes/scalar_fission.py index eb8faf33e6..f691a861d7 100644 --- a/dace/transformation/passes/scalar_fission.py +++ b/dace/transformation/passes/scalar_fission.py @@ -4,10 +4,11 @@ from dace import SDFG, InterstateEdge from dace.sdfg import nodes as nd -from dace.transformation import pass_pipeline as ppl +from dace.transformation import pass_pipeline as ppl, transformation from dace.transformation.passes import analysis as ap +@transformation.single_level_sdfg_only class ScalarFission(ppl.Pass): """ Fission transient scalars or arrays of size 1 that are dominated by a write into separate data containers. diff --git a/dace/transformation/passes/scalar_to_symbol.py b/dace/transformation/passes/scalar_to_symbol.py index 124efdaae1..8b4f2a9be3 100644 --- a/dace/transformation/passes/scalar_to_symbol.py +++ b/dace/transformation/passes/scalar_to_symbol.py @@ -23,6 +23,7 @@ from dace.sdfg.sdfg import InterstateEdge from dace.transformation import helpers as xfh from dace.transformation import pass_pipeline as passes +from dace.transformation.transformation import experimental_cfg_block_compatible class AttributedCallDetector(ast.NodeVisitor): @@ -95,7 +96,7 @@ def find_promotable_scalars(sdfg: sd.SDFG, transients_only: bool = True, integer # Check all occurrences of candidates in SDFG and filter out candidates_seen: Set[str] = set() - for state in sdfg.nodes(): + for state in sdfg.states(): candidates_in_state: Set[str] = set() for node in state.nodes(): @@ -225,7 +226,7 @@ def find_promotable_scalars(sdfg: sd.SDFG, transients_only: bool = True, integer # Filter out non-integral symbols that do not appear in inter-state edges interstate_symbols = set() - for edge in sdfg.edges(): + for edge in sdfg.all_interstate_edges(): interstate_symbols |= edge.data.free_symbols for candidate in (candidates - interstate_symbols): if integers_only and sdfg.arrays[candidate].dtype not in dtypes.INTEGER_TYPES: @@ -508,7 +509,7 @@ def remove_scalar_reads(sdfg: sd.SDFG, array_names: Dict[str, str]): replacement symbol name. :note: Operates in-place on the SDFG. """ - for state in sdfg.nodes(): + for state in sdfg.states(): scalar_nodes = [n for n in state.nodes() if isinstance(n, nodes.AccessNode) and n.data in array_names] for node in scalar_nodes: symname = array_names[node.data] @@ -585,6 +586,7 @@ def translate_cpp_tasklet_to_python(code: str): @dataclass(unsafe_hash=True) @props.make_properties +@experimental_cfg_block_compatible class ScalarToSymbolPromotion(passes.Pass): CATEGORY: str = 'Simplification' @@ -633,7 +635,7 @@ def apply_pass(self, sdfg: SDFG, _: Dict[Any, Any]) -> Set[str]: if len(to_promote) == 0: return None - for state in sdfg.nodes(): + for state in sdfg.states(): scalar_nodes = [n for n in state.nodes() if isinstance(n, nodes.AccessNode) and n.data in to_promote] # Step 2: Assignment tasklets for node in scalar_nodes: @@ -645,8 +647,8 @@ def apply_pass(self, sdfg: SDFG, _: Dict[Any, Any]) -> Set[str]: # There is only zero or one incoming edges by definition tasklet_inputs = [e.src for e in state.in_edges(input)] # Step 2.1 - new_state = xfh.state_fission(sdfg, gr.SubgraphView(state, set([input, node] + tasklet_inputs))) - new_isedge: sd.InterstateEdge = sdfg.out_edges(new_state)[0] + new_state = xfh.state_fission(gr.SubgraphView(state, set([input, node] + tasklet_inputs))) + new_isedge: sd.InterstateEdge = new_state.parent_graph.out_edges(new_state)[0] # Step 2.2 node: nodes.AccessNode = new_state.sink_nodes()[0] input = new_state.in_edges(node)[0].src @@ -683,7 +685,7 @@ def apply_pass(self, sdfg: SDFG, _: Dict[Any, Any]) -> Set[str]: remove_scalar_reads(sdfg, {k: k for k in to_promote}) # Step 4: Isolated nodes - for state in sdfg.nodes(): + for state in sdfg.states(): scalar_nodes = [n for n in state.nodes() if isinstance(n, nodes.AccessNode) and n.data in to_promote] state.remove_nodes_from([n for n in scalar_nodes if len(state.all_edges(n)) == 0]) @@ -699,7 +701,7 @@ def apply_pass(self, sdfg: SDFG, _: Dict[Any, Any]) -> Set[str]: # Step 6: Inter-state edge cleanup cleanup_re = {s: re.compile(fr'\b{re.escape(s)}\[.*?\]') for s in to_promote} promo = TaskletPromoterDict({k: k for k in to_promote}) - for edge in sdfg.edges(): + for edge in sdfg.all_interstate_edges(): ise: InterstateEdge = edge.data # Condition if not edge.data.is_unconditional(): diff --git a/dace/transformation/passes/simplify.py b/dace/transformation/passes/simplify.py index 2b1411396c..81e8e88362 100644 --- a/dace/transformation/passes/simplify.py +++ b/dace/transformation/passes/simplify.py @@ -1,9 +1,10 @@ # Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. from dataclasses import dataclass from typing import Any, Dict, Optional, Set +import warnings from dace import SDFG, config, properties -from dace.transformation import helpers as xfh +from dace.transformation import helpers as xfh, transformation from dace.transformation import pass_pipeline as ppl from dace.transformation.passes.array_elimination import ArrayElimination from dace.transformation.passes.consolidate_edges import ConsolidateEdges @@ -42,6 +43,7 @@ @dataclass(unsafe_hash=True) @properties.make_properties +@transformation.experimental_cfg_block_compatible class SimplifyPass(ppl.FixedPointPipeline): """ A pipeline that simplifies an SDFG by applying a series of simplification passes. @@ -79,6 +81,19 @@ def apply_subpass(self, sdfg: SDFG, p: ppl.Pass, state: Dict[str, Any]): """ Apply a pass from the pipeline. This method is meant to be overridden by subclasses. """ + if sdfg.root_sdfg.using_experimental_blocks: + if (not hasattr(p, '__experimental_cfg_block_compatible__') or + p.__experimental_cfg_block_compatible__ == False): + warnings.warn(p.__class__.__name__ + ' is not being applied due to incompatibility with ' + + 'experimental control flow blocks. If the SDFG does not contain experimental blocks, ' + + 'ensure the top level SDFG does not have `SDFG.using_experimental_blocks` set to ' + + 'True. If ' + p.__class__.__name__ + ' is compatible with experimental blocks, ' + + 'please annotate it with the class decorator ' + + '`@dace.transformation.experimental_cfg_block_compatible`. see ' + + '`https://github.com/spcl/dace/wiki/Experimental-Control-Flow-Blocks` ' + + 'for more information.') + return None + if type(p) in _nonrecursive_passes: # If pass needs to run recursively, do so and modify return value ret: Dict[int, Any] = {} for sd in sdfg.all_sdfgs_recursive(): diff --git a/dace/transformation/passes/symbol_ssa.py b/dace/transformation/passes/symbol_ssa.py index 6f0f4485b0..fa59f88df7 100644 --- a/dace/transformation/passes/symbol_ssa.py +++ b/dace/transformation/passes/symbol_ssa.py @@ -3,10 +3,11 @@ from typing import Any, Dict, Optional, Set from dace import SDFG, SDFGState -from dace.transformation import pass_pipeline as ppl +from dace.transformation import pass_pipeline as ppl, transformation from dace.transformation.passes import analysis as ap +@transformation.single_level_sdfg_only class StrictSymbolSSA(ppl.Pass): """ Perform an SSA transformation on all symbols in the SDFG in a strict manner, i.e., without introducing phi nodes. diff --git a/dace/transformation/passes/transient_reuse.py b/dace/transformation/passes/transient_reuse.py index ed26cbfa57..0eacec1cf0 100644 --- a/dace/transformation/passes/transient_reuse.py +++ b/dace/transformation/passes/transient_reuse.py @@ -6,9 +6,11 @@ from dace import SDFG, properties from dace.sdfg import nodes from dace.transformation import pass_pipeline as ppl +from dace.transformation.transformation import experimental_cfg_block_compatible @properties.make_properties +@experimental_cfg_block_compatible class TransientReuse(ppl.Pass): """ Reduces memory consumption by reusing allocated transient array memory. Only modifies arrays that can safely be @@ -44,7 +46,7 @@ def apply_pass(self, sdfg: SDFG, _) -> Optional[Set[str]]: if arrays[a] == 1: transients.add(a) - for state in sdfg.nodes(): + for state in sdfg.states(): # Copy the whole graph G = nx.MultiDiGraph() for n in state.nodes(): diff --git a/dace/transformation/subgraph/composite.py b/dace/transformation/subgraph/composite.py index 41d145aaa3..e25ccd192a 100644 --- a/dace/transformation/subgraph/composite.py +++ b/dace/transformation/subgraph/composite.py @@ -3,17 +3,14 @@ Subgraph Fusion - Stencil Tiling Transformation """ -import dace -from dace.transformation.subgraph import stencil_tiling - -import dace.transformation.transformation as transformation from dace.transformation.subgraph import SubgraphFusion, MultiExpansion from dace.transformation.subgraph.stencil_tiling import StencilTiling from dace.transformation.subgraph import helpers +from dace.transformation import transformation -from dace import dtypes, registry, symbolic, subsets, data +from dace import dtypes from dace.properties import EnumProperty, make_properties, Property, ShapeProperty -from dace.sdfg import SDFG, SDFGState +from dace.sdfg import SDFG from dace.sdfg.graph import SubgraphView import copy @@ -21,6 +18,7 @@ @make_properties +@transformation.single_level_sdfg_only class CompositeFusion(transformation.SubgraphTransformation): """ MultiExpansion + SubgraphFusion in one Transformation Additional StencilTiling is also possible as a canonicalizing diff --git a/dace/transformation/subgraph/stencil_tiling.py b/dace/transformation/subgraph/stencil_tiling.py index 6b03b2adba..1ba86252c4 100644 --- a/dace/transformation/subgraph/stencil_tiling.py +++ b/dace/transformation/subgraph/stencil_tiling.py @@ -584,7 +584,7 @@ def apply(self, sdfg): DetectLoop.exit_state: nsdfg.node_id(end) } transformation = LoopUnroll() - transformation.setup_match(nsdfg, 0, -1, subgraph, 0) + transformation.setup_match(nsdfg, nsdfg.cfg_id, -1, subgraph, 0) transformation.apply(nsdfg, nsdfg) elif self.unroll_loops: diff --git a/dace/transformation/transformation.py b/dace/transformation/transformation.py index 8b87939ca8..bb4a730e24 100644 --- a/dace/transformation/transformation.py +++ b/dace/transformation/transformation.py @@ -23,11 +23,18 @@ from dace import dtypes, serialize from dace.dtypes import ScheduleType from dace.sdfg import SDFG, SDFGState +from dace.sdfg.state import ControlFlowRegion from dace.sdfg import nodes as nd, graph as gr, utils as sdutil, propagation, infer_types, state as st from dace.properties import make_properties, Property, DictProperty, SetProperty from dace.transformation import pass_pipeline as ppl -from typing import Any, Dict, Generic, List, Optional, Set, Type, TypeVar, Union +from typing import Any, Dict, Generic, List, Optional, Set, Type, TypeVar, Union, Callable import pydoc +import warnings + + +def experimental_cfg_block_compatible(cls: ppl.Pass): + cls.__experimental_cfg_block_compatible__ = True + return cls class TransformationBase(ppl.Pass): @@ -108,15 +115,15 @@ def expressions(cls) -> List[gr.SubgraphView]: raise NotImplementedError def can_be_applied(self, - graph: Union[SDFG, SDFGState], + graph: Union[ControlFlowRegion, SDFGState], expr_index: int, sdfg: SDFG, permissive: bool = False) -> bool: """ Returns True if this transformation can be applied on the candidate matched subgraph. - :param graph: SDFGState object if this transformation is - single-state, or SDFG object otherwise. + :param graph: SDFGState object if this transformation is single-state, or ControlFlowRegion object + otherwise. :param expr_index: The list index from `PatternTransformation.expressions` that was matched. :param sdfg: If `graph` is an SDFGState, its parent SDFG. Otherwise @@ -126,7 +133,7 @@ def can_be_applied(self, """ raise NotImplementedError - def apply(self, graph: Union[SDFG, SDFGState], sdfg: SDFG) -> Union[Any, None]: + def apply(self, graph: Union[ControlFlowRegion, SDFGState], sdfg: SDFG) -> Union[Any, None]: """ Applies this transformation instance on the matched pattern graph. @@ -142,7 +149,7 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[A self._pipeline_results = pipeline_results return self.apply_pattern() - def match_to_str(self, graph: Union[SDFG, SDFGState]) -> str: + def match_to_str(self, graph: Union[ControlFlowRegion, SDFGState]) -> str: """ Returns a string representation of the pattern match on the candidate subgraph. Used when identifying matches in the console UI. @@ -364,16 +371,16 @@ def apply_to(cls, def __str__(self) -> str: return type(self).__name__ - def print_match(self, sdfg: SDFG) -> str: + def print_match(self, cfg: ControlFlowRegion) -> str: """ Returns a string representation of the pattern match on the - given SDFG. Used for printing matches in the console UI. + given Control Flow Region. Used for printing matches in the console UI. """ - if not isinstance(sdfg, SDFG): - raise TypeError("Expected SDFG, got: {}".format(type(sdfg).__name__)) + if not isinstance(cfg, ControlFlowRegion): + raise TypeError("Expected ControlFlowRegion, got: {}".format(type(cfg).__name__)) if self.state_id == -1: - graph = sdfg + graph = cfg else: - graph = sdfg.nodes()[self.state_id] + graph = cfg.nodes()[self.state_id] string = type(self).__name__ + ' in ' string += self.match_to_str(graph) return string @@ -402,6 +409,7 @@ def from_json(json_obj: Dict[str, Any], context: Dict[str, Any] = None) -> 'Patt @make_properties +@experimental_cfg_block_compatible class SingleStateTransformation(PatternTransformation, abc.ABC): """ Base class for pattern-matching transformations that find matches within a single SDFG state. @@ -497,7 +505,7 @@ def expressions(cls) -> List[gr.SubgraphView]: pass @abc.abstractmethod - def can_be_applied(self, graph: SDFG, expr_index: int, sdfg: SDFG, permissive: bool = False) -> bool: + def can_be_applied(self, graph: ControlFlowRegion, expr_index: int, sdfg: SDFG, permissive: bool = False) -> bool: """ Returns True if this transformation can be applied on the candidate matched subgraph. :param graph: SDFG object in which the match was found. @@ -553,16 +561,18 @@ def __get__(self, instance: Optional[PatternTransformation], owner) -> T: # If an instance is used, we return the matched node node_id: int = instance.subgraph[self] state_id: int = instance.state_id + t_graph: ControlFlowRegion = instance._sdfg.cfg_list[instance.cfg_id] if not isinstance(node_id, int): # Node ID is already an object return node_id # Inter-state transformation if state_id == -1: - return instance._sdfg.node(node_id) + return t_graph.node(node_id) # Single-state transformation - return instance._sdfg.node(state_id).node(node_id) + state: SDFGState = t_graph.node(state_id) + return state.node(node_id) @make_properties @@ -706,7 +716,7 @@ def setup_match(self, subgraph: Union[Set[int], gr.SubgraphView], cfg_id: int = if isinstance(subgraph.graph, SDFGState): sdfg = subgraph.graph.parent self.cfg_id = sdfg.cfg_id - self.state_id = sdfg.node_id(subgraph.graph) + self.state_id = subgraph.graph.block_id elif isinstance(subgraph.graph, SDFG): self.cfg_id = subgraph.graph.cfg_id self.state_id = -1 @@ -866,3 +876,62 @@ def from_json(json_obj: Dict[str, Any], context: Dict[str, Any] = None) -> 'Subg context['transformation'] = ret serialize.set_properties_from_json(ret, json_obj, context=context, ignore_properties={'transformation', 'type'}) return ret + + +def _make_function_blocksafe(cls: ppl.Pass, function_name: str, get_sdfg_arg: Callable[[Any], Optional[SDFG]]): + if hasattr(cls, function_name): + vanilla_method = getattr(cls, function_name) + def blocksafe_wrapper(tgt, *args, **kwargs): + if isinstance(tgt, SDFG): + sdfg = tgt + elif kwargs and 'sdfg' in kwargs: + sdfg = kwargs['sdfg'] + else: + sdfg = get_sdfg_arg(tgt, *args) + if sdfg and isinstance(sdfg, SDFG): + root_sdfg: SDFG = sdfg.cfg_list[0] + if not root_sdfg.using_experimental_blocks: + return vanilla_method(tgt, *args, **kwargs) + else: + warnings.warn('Skipping ' + function_name + ' from ' + cls.__name__ + + ' due to incompatibility with experimental control flow blocks') + setattr(cls, function_name, blocksafe_wrapper) + + +def _subgraph_transformation_extract_sdfg_arg(*args) -> SDFG: + subgraph = args[1] + if isinstance(subgraph, SDFG): + return subgraph + elif isinstance(subgraph, SDFGState): + return subgraph.sdfg + elif isinstance(subgraph, gr.SubgraphView): + if isinstance(subgraph.graph, SDFGState): + return subgraph.graph.sdfg + elif isinstance(subgraph.graph, SDFG): + return subgraph.graph + raise TypeError('Unrecognized graph type "%s"' % type(subgraph.graph).__name__) + raise TypeError('Unrecognized graph type "%s"' % type(subgraph).__name__) + + +def single_level_sdfg_only(cls: ppl.Pass): + + for function_name in ['apply_pass', 'apply_to']: + _make_function_blocksafe(cls, function_name, lambda *args: args[1]) + + if issubclass(cls, SubgraphTransformation): + _make_function_blocksafe(cls, 'apply', lambda *args: args[1]) + _make_function_blocksafe(cls, 'can_be_applied', lambda *args: args[1]) + _make_function_blocksafe(cls, 'setup_match', _subgraph_transformation_extract_sdfg_arg) + elif issubclass(cls, ppl.StatePass): + _make_function_blocksafe(cls, 'apply', lambda *args: args[1].sdfg) + elif issubclass(cls, ppl.ScopePass): + _make_function_blocksafe(cls, 'apply', lambda *args: args[2].sdfg) + else: + _make_function_blocksafe(cls, 'apply', lambda *args: args[2]) + _make_function_blocksafe(cls, 'can_be_applied', lambda *args: args[3]) + _make_function_blocksafe(cls, 'setup_match', lambda *args: args[1]) + + if issubclass(cls, PatternTransformation): + _make_function_blocksafe(cls, 'apply_pattern', lambda *args: args[0]._sdfg) + + return cls diff --git a/doc/frontend/parsing.rst b/doc/frontend/parsing.rst index 856c376b01..7adc415497 100644 --- a/doc/frontend/parsing.rst +++ b/doc/frontend/parsing.rst @@ -76,14 +76,15 @@ Abstract Syntax sub-Tree. The :class:`~dace.frontend.python.newast.ProgramVisito - ``annotated_types``: A dictionary from Python variables to Data-Centric datatypes. Used when variables are explicitly type-annotated in the Python code. - ``map_symbols``: The :class:`~dace.sdfg.nodes.Map` symbols defined in the :class:`~dace.sdfg.sdfg.SDFG`. Useful when deciding when an augmented assignment should be implemented with WCR or not. - ``sdfg``: The generated :class:`~dace.sdfg.sdfg.SDFG` object. -- ``last_state``: The (current) last :class:`~dace.sdfg.state.SDFGState` object created and added to the :class:`~dace.sdfg.sdfg.SDFG`. +- ``last_block``: The (current) last :class:`~dace.sdfg.state.ControlFlowBlock` object created and added to the current :class:`~dace.sdfg.state.ControlFlowRegion`. +- ``current_state``: The (current) last :class:`~dace.sdfg.state.SDFGState` object created and added to the current :class:`~dace.sdfg.state.ControlFlowRegion`, similar to `last_block`, but only tracking states. +- ``sdfg``: The current :class:`~dace.sdfg.sdfg.SDFG` being worked on. +- ``cfg_target``: The current :class:`~dace.sdfg.state.ControlFlowRegion` being worked on (may be the current :class:`~dace.sdfg.sdfg.SDFG` or a sub-region, such as a :class:`~dace.sdfg.state.LoopRegion`). +- ``last_cfg_target``: The previous :class:`~dace.sdfg.state.ControlFlowRegion` that blocks were being added to. - ``inputs``: The input connectors of the generated :class:`~dace.sdfg.nodes.NestedSDFG` and a :class:`~dace.memlet.Memlet`-like representation of the corresponding Data subsets read. - ``outputs``: The output connectors of the generated :class:`~dace.sdfg.nodes.NestedSDFG` and a :class:`~dace.memlet.Memlet`-like representation of the corresponding Data subsets written. - ``current_lineinfo``: The current :class:`~dace.dtypes.DebugInfo`. Used for debugging. - ``modules``: The modules imported in the file of the top-level Data-Centric Python program. Produced by filtering `globals`. -- ``loop_idx``: The current scope-depth in a nested loop construct. -- ``continue_states``: The generated :class:`~dace.sdfg.state.SDFGState` objects corresponding to Python `continue `_ statements. Useful for generating proper nested loop control-flow. -- ``break_states``: The generated :class:`~dace.sdfg.state.SDFGState` objects corresponding to Python `break `_ statements. Useful for generating proper nested loop control-flow. - ``symbols``: The loop symbols defined in the :class:`~dace.sdfg.sdfg.SDFG` object. Useful for memlet/state propagation when multiple loops use the same iteration variable but with different ranges. - ``indirections``: A dictionary from Python code indirection expressions to Data-Centric symbols. @@ -167,6 +168,10 @@ Example: :align: center :alt: Generated SDFG for-loop for the above Data-Centric Python program +If the :class:`~dace.frontend.python.parser.DaceProgram`'s +:attr:`~dace.frontend.python.parser.DaceProgram.use_experimental_cfg_blocks` attribute is set to true, this will utilize +:class:`~dace.sdfg.state.LoopRegion`s instead of the explicit state machine depicted above. + :func:`~dace.frontend.python.newast.ProgramVisitor.visit_While` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -185,6 +190,10 @@ Parses `while `_ statement :align: center :alt: Generated SDFG while-loop for the above Data-Centric Python program +If the :class:`~dace.frontend.python.parser.DaceProgram`'s +:attr:`~dace.frontend.python.parser.DaceProgram.use_experimental_cfg_blocks` attribute is set to true, this will utilize +:class:`~dace.sdfg.state.LoopRegion`s instead of the explicit state machine depicted above. + :func:`~dace.frontend.python.newast.ProgramVisitor.visit_Break` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -204,6 +213,11 @@ behaves as an if-else statement. This is also evident from the generated dataflo :align: center :alt: Generated SDFG for-loop with a break statement for the above Data-Centric Python program +If the :class:`~dace.frontend.python.parser.DaceProgram`'s +:attr:`~dace.frontend.python.parser.DaceProgram.use_experimental_cfg_blocks` attribute is set to true, loops are +represented with :class:`~dace.sdfg.state.LoopRegion`s, and a break is represented with a special +:class:`~dace.sdfg.state.LoopRegion.BreakState`. + :func:`~dace.frontend.python.newast.ProgramVisitor.visit_Continue` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -223,6 +237,11 @@ of `continue` makes the ``A[i] = i`` statement unreachable. This is also evident :align: center :alt: Generated SDFG for-loop with a continue statement for the above Data-Centric Python program +If the :class:`~dace.frontend.python.parser.DaceProgram`'s +:attr:`~dace.frontend.python.parser.DaceProgram.use_experimental_cfg_blocks` attribute is set to true, loops are +represented with :class:`~dace.sdfg.state.LoopRegion`s, and a continue is represented with a special +:class:`~dace.sdfg.state.LoopRegion.ContinueState`. + :func:`~dace.frontend.python.newast.ProgramVisitor.visit_If` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/setup.py b/setup.py index f0ecba933b..d385abb9e1 100644 --- a/setup.py +++ b/setup.py @@ -73,7 +73,7 @@ }, include_package_data=True, install_requires=[ - 'numpy', 'networkx >= 2.5', 'astunparse', 'sympy >= 1.9', 'pyyaml', 'ply', 'websockets', 'jinja2', + 'numpy < 2.0', 'networkx >= 2.5', 'astunparse', 'sympy >= 1.9', 'pyyaml', 'ply', 'websockets', 'jinja2', 'fparser >= 0.1.3', 'aenum >= 3.1', 'dataclasses; python_version < "3.7"', 'dill', 'pyreadline;platform_system=="Windows"', 'typing-compat; python_version < "3.8"' ] + cmake_requires, diff --git a/tests/codegen/data_instrumentation_test.py b/tests/codegen/data_instrumentation_test.py index 3c0a6605d8..b254a204b5 100644 --- a/tests/codegen/data_instrumentation_test.py +++ b/tests/codegen/data_instrumentation_test.py @@ -318,8 +318,11 @@ def dinstr(A: dace.float64[20]): assert len(dreport.keys()) == 1 assert 'i' in dreport.keys() assert len(dreport['i']) == 22 - desired = [0] + list(range(0, 20)) - assert np.allclose(dreport['i'][:21], desired) + desired = list(range(1, 19)) + s_idx = dreport['i'].index(1) + e_idx = dreport['i'].index(18) + assert np.allclose(dreport['i'][s_idx:e_idx+1], desired) + assert 19 in dreport['i'] @pytest.mark.datainstrument diff --git a/tests/fortran/fortran_loops_test.py b/tests/fortran/fortran_loops_test.py new file mode 100644 index 0000000000..4d4c259f07 --- /dev/null +++ b/tests/fortran/fortran_loops_test.py @@ -0,0 +1,45 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. + +import numpy as np + +from dace.frontend.fortran import fortran_parser + +def test_fortran_frontend_loop_region_basic_loop(): + test_name = "loop_test" + test_string = """ + PROGRAM loop_test_program + implicit none + double precision a(10,10) + double precision b(10,10) + double precision c(10,10) + + CALL loop_test_function(a,b,c) + end + + SUBROUTINE loop_test_function(a,b,c) + double precision :: a(10,10) + double precision :: b(10,10) + double precision :: c(10,10) + + INTEGER :: JK,JL + DO JK=1,10 + DO JL=1,10 + c(JK,JL) = a(JK,JL) + b(JK,JL) + ENDDO + ENDDO + end SUBROUTINE loop_test_function + """ + sdfg = fortran_parser.create_sdfg_from_string(test_string, test_name, use_experimental_cfg_blocks=True) + + a_test = np.full([10, 10], 2, order="F", dtype=np.float64) + b_test = np.full([10, 10], 3, order="F", dtype=np.float64) + c_test = np.zeros([10, 10], order="F", dtype=np.float64) + sdfg(a=a_test, b=b_test, c=c_test) + + validate = np.full([10, 10], 5, order="F", dtype=np.float64) + + assert np.allclose(c_test, validate) + + +if __name__ == '__main__': + test_fortran_frontend_loop_region_basic_loop() diff --git a/tests/passes/scalar_to_symbol_test.py b/tests/passes/scalar_to_symbol_test.py index 02cc57a204..140ec105f7 100644 --- a/tests/passes/scalar_to_symbol_test.py +++ b/tests/passes/scalar_to_symbol_test.py @@ -263,7 +263,7 @@ def test_promote_loop(): def testprog8(A: dace.float32[20, 20]): i = dace.ndarray([1], dtype=dace.int32) i = 0 - while i[0] < N: + while i < N: A += i i += 2 diff --git a/tests/python_frontend/loop_regions_test.py b/tests/python_frontend/loop_regions_test.py new file mode 100644 index 0000000000..b6509bb0c3 --- /dev/null +++ b/tests/python_frontend/loop_regions_test.py @@ -0,0 +1,635 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +import pytest +import dace +import numpy as np + +from dace.frontend.python.common import DaceSyntaxError +from dace.sdfg.state import LoopRegion + +# NOTE: Some tests have been disabled due to issues with our control flow detection during codegen. +# The issue is documented in #1586, and in parts in #635. The problem causes the listed tests to fail when +# automatic simplification is turned off ONLY. There are several active efforts to address this issue. +# For one, there are fixes being made to the control flow detection itself (commits da7af41 and c830f92 +# are the start of that). Additionally, codegen is being adapted (in a separate, following PR) to make use +# of the control flow region constructs directly, circumventing this issue entirely. +# As such, disabling these tests is a very temporary solution that should not be longer lived than +# a few weeks at most. +# TODO: Re-enable after issues are addressed. + +@dace.program +def for_loop(): + A = dace.ndarray([10], dtype=dace.int32) + A[:] = 0 + for i in range(0, 10, 2): + A[i] = i + return A + + +def test_for_loop(): + for_loop.use_experimental_cfg_blocks = True + + sdfg = for_loop.to_sdfg() + assert any(isinstance(x, LoopRegion) for x in sdfg.nodes()) + + A = sdfg() + A_ref = np.array([0, 0, 2, 0, 4, 0, 6, 0, 8, 0], dtype=np.int32) + assert (np.array_equal(A, A_ref)) + + +@dace.program +def for_loop_with_break_continue(): + A = dace.ndarray([10], dtype=dace.int32) + A[:] = 0 + for i in range(20): + if i >= 10: + break + if i % 2 == 1: + continue + A[i] = i + return A + + +@pytest.mark.skip(reason='Control flow detection issues through extraneous states, needs control flow detection fix') +def test_for_loop_with_break_continue(): + for_loop_with_break_continue.use_experimental_cfg_blocks = True + + sdfg = for_loop_with_break_continue.to_sdfg() + assert any(isinstance(x, LoopRegion) for x in sdfg.nodes()) + + A = sdfg() + A_ref = np.array([0, 0, 2, 0, 4, 0, 6, 0, 8, 0], dtype=np.int32) + assert (np.array_equal(A, A_ref)) + + +@dace.program +def nested_for_loop(): + A = dace.ndarray([10, 10], dtype=dace.int32) + A[:] = 0 + for i in range(20): + if i >= 10: + break + if i % 2 == 1: + continue + for j in range(20): + if j >= 10: + break + if j % 2 == 1: + continue + A[i, j] = j + return A + + +@pytest.mark.skip(reason='Control flow detection issues through extraneous states, needs control flow detection fix') +def test_nested_for_loop(): + nested_for_loop.use_experimental_cfg_blocks = True + + sdfg = nested_for_loop.to_sdfg() + assert any(isinstance(x, LoopRegion) for x in sdfg.nodes()) + + A = sdfg() + A_ref = np.zeros([10, 10], dtype=np.int32) + for i in range(0, 10, 2): + A_ref[i] = [0, 0, 2, 0, 4, 0, 6, 0, 8, 0] + assert (np.array_equal(A, A_ref)) + + +@dace.program +def while_loop(): + A = dace.ndarray([10], dtype=dace.int32) + A[:] = 0 + i = 0 + while (i < 10): + A[i] = i + i += 2 + return A + + +def test_while_loop(): + while_loop.use_experimental_cfg_blocks = True + + sdfg = while_loop.to_sdfg() + assert any(isinstance(x, LoopRegion) for x in sdfg.nodes()) + + A = sdfg() + A_ref = np.array([0, 0, 2, 0, 4, 0, 6, 0, 8, 0], dtype=np.int32) + assert (np.array_equal(A, A_ref)) + + +@dace.program +def while_loop_with_break_continue(): + A = dace.ndarray([10], dtype=dace.int32) + A[:] = 0 + i = -1 + while i < 20: + i += 1 + if i >= 10: + break + if i % 2 == 1: + continue + A[i] = i + return A + + +def test_while_loop_with_break_continue(): + while_loop_with_break_continue.use_experimental_cfg_blocks = True + + sdfg = while_loop_with_break_continue.to_sdfg() + assert any(isinstance(x, LoopRegion) for x in sdfg.nodes()) + + A = sdfg() + A_ref = np.array([0, 0, 2, 0, 4, 0, 6, 0, 8, 0], dtype=np.int32) + assert (np.array_equal(A, A_ref)) + + +@dace.program +def nested_while_loop(): + A = dace.ndarray([10, 10], dtype=dace.int32) + A[:] = 0 + i = -1 + while i < 20: + i += 1 + if i >= 10: + break + if i % 2 == 1: + continue + j = -1 + while j < 20: + j += 1 + if j >= 10: + break + if j % 2 == 1: + continue + A[i, j] = j + return A + + +def test_nested_while_loop(): + nested_while_loop.use_experimental_cfg_blocks = True + + sdfg = nested_while_loop.to_sdfg() + assert any(isinstance(x, LoopRegion) for x in sdfg.nodes()) + + A = sdfg() + A_ref = np.zeros([10, 10], dtype=np.int32) + for i in range(0, 10, 2): + A_ref[i] = [0, 0, 2, 0, 4, 0, 6, 0, 8, 0] + assert (np.array_equal(A, A_ref)) + + +@dace.program +def nested_for_while_loop(): + A = dace.ndarray([10, 10], dtype=dace.int32) + A[:] = 0 + for i in range(20): + if i >= 10: + break + if i % 2 == 1: + continue + j = -1 + while j < 20: + j += 1 + if j >= 10: + break + if j % 2 == 1: + continue + A[i, j] = j + return A + + +@pytest.mark.skip(reason='Control flow detection issues through extraneous states, needs control flow detection fix') +def test_nested_for_while_loop(): + nested_for_while_loop.use_experimental_cfg_blocks = True + + sdfg = nested_for_while_loop.to_sdfg() + assert any(isinstance(x, LoopRegion) for x in sdfg.nodes()) + + A = sdfg() + A_ref = np.zeros([10, 10], dtype=np.int32) + for i in range(0, 10, 2): + A_ref[i] = [0, 0, 2, 0, 4, 0, 6, 0, 8, 0] + assert (np.array_equal(A, A_ref)) + + +@dace.program +def nested_while_for_loop(): + A = dace.ndarray([10, 10], dtype=dace.int32) + A[:] = 0 + i = -1 + while i < 20: + i += 1 + if i >= 10: + break + if i % 2 == 1: + continue + for j in range(20): + if j >= 10: + break + if j % 2 == 1: + continue + A[i, j] = j + return A + + +@pytest.mark.skip(reason='Control flow detection issues through extraneous states, needs control flow detection fix') +def test_nested_while_for_loop(): + nested_while_for_loop.use_experimental_cfg_blocks = True + + sdfg = nested_while_for_loop.to_sdfg() + assert any(isinstance(x, LoopRegion) for x in sdfg.nodes()) + + A = sdfg() + A_ref = np.zeros([10, 10], dtype=np.int32) + for i in range(0, 10, 2): + A_ref[i] = [0, 0, 2, 0, 4, 0, 6, 0, 8, 0] + assert (np.array_equal(A, A_ref)) + + +@dace.program +def map_with_break_continue(): + A = dace.ndarray([10], dtype=dace.int32) + A[:] = 0 + for i in dace.map[0:20]: + if i >= 10: + break + if i % 2 == 1: + continue + A[i] = i + return A + + +def test_map_with_break_continue(): + try: + map_with_break_continue.use_experimental_cfg_blocks = True + map_with_break_continue() + except Exception as e: + if isinstance(e, DaceSyntaxError): + return 0 + assert (False) + + +@dace.program +def nested_map_for_loop(): + A = np.ndarray([10, 10], dtype=np.int64) + for i in dace.map[0:10]: + for j in range(10): + A[i, j] = i * 10 + j + return A + + +def test_nested_map_for_loop(): + ref = np.zeros([10, 10], dtype=np.int64) + for i in range(10): + for j in range(10): + ref[i, j] = i * 10 + j + nested_map_for_loop.use_experimental_cfg_blocks = True + val = nested_map_for_loop() + assert (np.array_equal(val, ref)) + + +@dace.program +def nested_map_for_for_loop(): + A = np.ndarray([10, 10, 10], dtype=np.int64) + for i in dace.map[0:10]: + for j in range(10): + for k in range(10): + A[i, j, k] = i * 100 + j * 10 + k + return A + + +def test_nested_map_for_for_loop(): + ref = np.zeros([10, 10, 10], dtype=np.int64) + for i in range(10): + for j in range(10): + for k in range(10): + ref[i, j, k] = i * 100 + j * 10 + k + nested_map_for_for_loop.use_experimental_cfg_blocks = True + val = nested_map_for_for_loop() + assert (np.array_equal(val, ref)) + + +@dace.program +def nested_for_map_for_loop(): + A = np.ndarray([10, 10, 10], dtype=np.int64) + for i in range(10): + for j in dace.map[0:10]: + for k in range(10): + A[i, j, k] = i * 100 + j * 10 + k + return A + + +def test_nested_for_map_for_loop(): + ref = np.zeros([10, 10, 10], dtype=np.int64) + for i in range(10): + for j in range(10): + for k in range(10): + ref[i, j, k] = i * 100 + j * 10 + k + nested_for_map_for_loop.use_experimental_cfg_blocks = True + val = nested_for_map_for_loop() + assert (np.array_equal(val, ref)) + + +@dace.program +def nested_map_for_loop_with_tasklet(): + A = np.ndarray([10, 10], dtype=np.int64) + for i in dace.map[0:10]: + for j in range(10): + + @dace.tasklet + def comp(): + out >> A[i, j] + out = i * 10 + j + + return A + + +def test_nested_map_for_loop_with_tasklet(): + ref = np.zeros([10, 10], dtype=np.int64) + for i in range(10): + for j in range(10): + ref[i, j] = i * 10 + j + nested_map_for_loop_with_tasklet.use_experimental_cfg_blocks = True + val = nested_map_for_loop_with_tasklet() + assert (np.array_equal(val, ref)) + + +@dace.program +def nested_map_for_for_loop_with_tasklet(): + A = np.ndarray([10, 10, 10], dtype=np.int64) + for i in dace.map[0:10]: + for j in range(10): + for k in range(10): + + @dace.tasklet + def comp(): + out >> A[i, j, k] + out = i * 100 + j * 10 + k + + return A + + +def test_nested_map_for_for_loop_with_tasklet(): + ref = np.zeros([10, 10, 10], dtype=np.int64) + for i in range(10): + for j in range(10): + for k in range(10): + ref[i, j, k] = i * 100 + j * 10 + k + nested_map_for_for_loop_with_tasklet.use_experimental_cfg_blocks = True + val = nested_map_for_for_loop_with_tasklet() + assert (np.array_equal(val, ref)) + + +@dace.program +def nested_for_map_for_loop_with_tasklet(): + A = np.ndarray([10, 10, 10], dtype=np.int64) + for i in range(10): + for j in dace.map[0:10]: + for k in range(10): + + @dace.tasklet + def comp(): + out >> A[i, j, k] + out = i * 100 + j * 10 + k + + return A + + +def test_nested_for_map_for_loop_with_tasklet(): + ref = np.zeros([10, 10, 10], dtype=np.int64) + for i in range(10): + for j in range(10): + for k in range(10): + ref[i, j, k] = i * 100 + j * 10 + k + nested_for_map_for_loop_with_tasklet.use_experimental_cfg_blocks = True + val = nested_for_map_for_loop_with_tasklet() + assert (np.array_equal(val, ref)) + + +@dace.program +def nested_map_for_loop_2(B: dace.int64[10, 10]): + A = np.ndarray([10, 10], dtype=np.int64) + for i in dace.map[0:10]: + for j in range(10): + A[i, j] = 2 * B[i, j] + i * 10 + j + return A + + +def test_nested_map_for_loop_2(): + B = np.ones([10, 10], dtype=np.int64) + ref = np.zeros([10, 10], dtype=np.int64) + for i in range(10): + for j in range(10): + ref[i, j] = 2 + i * 10 + j + nested_map_for_loop_2.use_experimental_cfg_blocks = True + val = nested_map_for_loop_2(B) + assert (np.array_equal(val, ref)) + + +@dace.program +def nested_map_for_loop_with_tasklet_2(B: dace.int64[10, 10]): + A = np.ndarray([10, 10], dtype=np.int64) + for i in dace.map[0:10]: + for j in range(10): + + @dace.tasklet + def comp(): + inp << B[i, j] + out >> A[i, j] + out = 2 * inp + i * 10 + j + + return A + + +def test_nested_map_for_loop_with_tasklet_2(): + B = np.ones([10, 10], dtype=np.int64) + ref = np.zeros([10, 10], dtype=np.int64) + for i in range(10): + for j in range(10): + ref[i, j] = 2 + i * 10 + j + nested_map_for_loop_with_tasklet_2.use_experimental_cfg_blocks = True + val = nested_map_for_loop_with_tasklet_2(B) + assert (np.array_equal(val, ref)) + + +@dace.program +def nested_map_with_symbol(): + A = np.zeros([10, 10], dtype=np.int64) + for i in dace.map[0:10]: + for j in dace.map[i:10]: + A[i, j] = i * 10 + j + return A + + +def test_nested_map_with_symbol(): + ref = np.zeros([10, 10], dtype=np.int64) + for i in range(10): + for j in range(i, 10): + ref[i, j] = i * 10 + j + nested_map_with_symbol.use_experimental_cfg_blocks = True + val = nested_map_with_symbol() + assert (np.array_equal(val, ref)) + + +@pytest.mark.skip(reason='Control flow detection issues through extraneous states, needs control flow detection fix') +def test_for_else(): + + @dace.program + def for_else(A: dace.float64[20]): + for i in range(1, 20): + if A[i] >= 10: + A[0] = i + break + if i % 2 == 1: + continue + A[i] = i + else: + A[0] = -1.0 + + A = np.random.rand(20) + A_2 = np.copy(A) + expected_1 = np.copy(A) + expected_2 = np.copy(A) + + expected_2[6] = 20.0 + for_else.f(expected_1) + for_else.f(expected_2) + + for_else.use_experimental_cfg_blocks = True + + for_else(A) + assert np.allclose(A, expected_1) + + A_2[6] = 20.0 + for_else(A_2) + assert np.allclose(A_2, expected_2) + + +def test_while_else(): + + @dace.program + def while_else(A: dace.float64[2]): + while A[0] < 5.0: + if A[1] < 0.0: + A[0] = -1.0 + break + A[0] += 1.0 + else: + A[1] = 1.0 + A[1] = 1.0 + + while_else.use_experimental_cfg_blocks = True + + A = np.array([0.0, 0.0]) + expected = np.array([5.0, 1.0]) + while_else(A) + assert np.allclose(A, expected) + + A = np.array([0.0, -1.0]) + expected = np.array([-1.0, -1.0]) + while_else(A) + assert np.allclose(A, expected) + + +@dace.program +def branch_in_for(cond: dace.int32): + for i in range(10): + if cond > 0: + break + else: + continue + + +def test_branch_in_for(): + branch_in_for.use_experimental_cfg_blocks = True + sdfg = branch_in_for.to_sdfg(simplify=False) + assert len(sdfg.source_nodes()) == 1 + + +@dace.program +def branch_in_while(cond: dace.int32): + i = 0 + while i < 10: + if cond > 0: + break + else: + i += 1 + continue + + +def test_branch_in_while(): + branch_in_while.use_experimental_cfg_blocks = True + sdfg = branch_in_while.to_sdfg(simplify=False) + assert len(sdfg.source_nodes()) == 1 + +def test_for_with_return(): + + @dace.program + def for_with_return(A: dace.int32[10]): + for i in range(10): + if A[i] < 0: + return 1 + return 0 + + for_with_return.use_experimental_cfg_blocks = True + sdfg = for_with_return.to_sdfg() + + A = np.full((10,), 1).astype(np.int32) + A2 = np.full((10,), 1).astype(np.int32) + A2[5] = -1 + rval1 = sdfg(A) + expected1 = for_with_return.f(A) + rval2 = sdfg(A2) + expected2 = for_with_return.f(A2) + assert rval1 == expected1 + assert rval2 == expected2 + +def test_for_while_with_return(): + + @dace.program + def for_while_with_return(A: dace.int32[10, 10]): + for i in range(10): + j = 0 + while (j < 10): + if A[i,j] < 0: + return 1 + j += 1 + return 0 + + for_while_with_return.use_experimental_cfg_blocks = True + sdfg = for_while_with_return.to_sdfg() + + A = np.full((10,10), 1).astype(np.int32) + A2 = np.full((10,10), 1).astype(np.int32) + A2[5,5] = -1 + rval1 = sdfg(A) + expected1 = for_while_with_return.f(A) + rval2 = sdfg(A2) + expected2 = for_while_with_return.f(A2) + assert rval1 == expected1 + assert rval2 == expected2 + + +if __name__ == "__main__": + test_for_loop() + test_for_loop_with_break_continue() + test_nested_for_loop() + test_while_loop() + test_while_loop_with_break_continue() + test_nested_while_loop() + test_nested_for_while_loop() + test_nested_while_for_loop() + test_map_with_break_continue() + test_nested_map_for_loop() + test_nested_map_for_for_loop() + test_nested_for_map_for_loop() + test_nested_map_for_loop_with_tasklet() + test_nested_map_for_for_loop_with_tasklet() + test_nested_for_map_for_loop_with_tasklet() + test_nested_map_for_loop_2() + test_nested_map_for_loop_with_tasklet_2() + test_nested_map_with_symbol() + test_for_else() + test_while_else() + test_branch_in_for() + test_branch_in_while() + test_for_with_return() + test_for_while_with_return() \ No newline at end of file diff --git a/tests/python_frontend/loops_test.py b/tests/python_frontend/loops_test.py index ecbfdd6cc0..952d69b8fb 100644 --- a/tests/python_frontend/loops_test.py +++ b/tests/python_frontend/loops_test.py @@ -1,9 +1,19 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +import pytest import dace import numpy as np from dace.frontend.python.common import DaceSyntaxError +# NOTE: Some tests have been disabled due to issues with our control flow detection during codegen. +# The issue is documented in #1586, and in parts in #635. The problem causes the listed tests to fail when +# automatic simplification is turned off ONLY. There are several active efforts to address this issue. +# For one, there are fixes being made to the control flow detection itself (commits da7af41 and c830f92 +# are the start of that). Additionally, codegen is being adapted (in a separate, following PR) to make use +# of the control flow region constructs directly, circumventing this issue entirely. +# As such, disabling these tests is a very temporary solution that should not be longer lived than +# a few weeks at most. +# TODO: Re-enable after issues are addressed. @dace.program def for_loop(): @@ -33,6 +43,8 @@ def for_loop_with_break_continue(): return A +@pytest.mark.skipif(dace.Config.get_bool('optimizer', 'automatic_simplification') == False, + reason='Control flow detection issues through extraneous states, needs control flow detection fix') def test_for_loop_with_break_continue(): A = for_loop_with_break_continue() A_ref = np.array([0, 0, 2, 0, 4, 0, 6, 0, 8, 0], dtype=np.int32) @@ -57,6 +69,8 @@ def nested_for_loop(): return A +@pytest.mark.skipif(dace.Config.get_bool('optimizer', 'automatic_simplification') == False, + reason='Control flow detection issues through extraneous states, needs control flow detection fix') def test_nested_for_loop(): A = nested_for_loop() A_ref = np.zeros([10, 10], dtype=np.int32) @@ -153,6 +167,8 @@ def nested_for_while_loop(): return A +@pytest.mark.skipif(dace.Config.get_bool('optimizer', 'automatic_simplification') == False, + reason='Control flow detection issues through extraneous states, needs control flow detection fix') def test_nested_for_while_loop(): A = nested_for_while_loop() A_ref = np.zeros([10, 10], dtype=np.int32) @@ -181,6 +197,8 @@ def nested_while_for_loop(): return A +@pytest.mark.skipif(dace.Config.get_bool('optimizer', 'automatic_simplification') == False, + reason='Control flow detection issues through extraneous states, needs control flow detection fix') def test_nested_while_for_loop(): A = nested_while_for_loop() A_ref = np.zeros([10, 10], dtype=np.int32) @@ -404,6 +422,8 @@ def test_nested_map_with_symbol(): assert (np.array_equal(val, ref)) +@pytest.mark.skipif(dace.Config.get_bool('optimizer', 'automatic_simplification') == False, + reason='Control flow detection issues through extraneous states, needs control flow detection fix') def test_for_else(): @dace.program diff --git a/tests/transformations/control_flow_inline_test.py b/tests/sdfg/control_flow_inline_test.py similarity index 94% rename from tests/transformations/control_flow_inline_test.py rename to tests/sdfg/control_flow_inline_test.py index 106a955143..87af09b9c4 100644 --- a/tests/transformations/control_flow_inline_test.py +++ b/tests/sdfg/control_flow_inline_test.py @@ -189,9 +189,9 @@ def test_loop_inlining_for_continue_break(): update_expr='i = i + 1', inverted=False) sdfg.add_node(loop1) state1 = loop1.add_state('state1', is_start_block=True) - state2 = loop1.add_state('state2') + state2 = loop1.add_continue('state2') state3 = loop1.add_state('state3') - state4 = loop1.add_state('state4') + state4 = loop1.add_break('state4') state5 = loop1.add_state('state5') state6 = loop1.add_state('state6') loop1.add_edge(state1, state2, dace.InterstateEdge(condition='i < 5')) @@ -199,8 +199,6 @@ def test_loop_inlining_for_continue_break(): loop1.add_edge(state3, state4, dace.InterstateEdge(condition='i < 6')) loop1.add_edge(state3, state5, dace.InterstateEdge(condition='i >= 6')) loop1.add_edge(state5, state6, dace.InterstateEdge()) - loop1.continue_states = {loop1.node_id(state2)} - loop1.break_states = {loop1.node_id(state4)} sdfg.add_edge(state0, loop1, dace.InterstateEdge()) state7 = sdfg.add_state('state7') sdfg.add_edge(loop1, state7, dace.InterstateEdge()) @@ -211,15 +209,21 @@ def test_loop_inlining_for_continue_break(): assert len(states) == 12 assert not any(isinstance(s, LoopRegion) for s in states) end_state = None - tail_state = None + latch_state = None + break_state = None + continue_state = None for state in states: if state.label == 'loop1_end': end_state = state - elif state.label == 'loop1_tail': - tail_state = state + elif state.label == 'loop1_latch': + latch_state = state + elif state.label == 'loop1_state2': + continue_state = state + elif state.label == 'loop1_state4': + break_state = state assert end_state is not None - assert len(sdfg.edges_between(state4, end_state)) == 1 - assert len(sdfg.edges_between(state2, tail_state)) == 1 + assert len(sdfg.edges_between(break_state, end_state)) == 1 + assert len(sdfg.edges_between(continue_state, latch_state)) == 1 def test_loop_inlining_multi_assignments(): @@ -247,18 +251,18 @@ def test_loop_inlining_multi_assignments(): guard_state = None init_state = None - tail_state = None + latch_state = None for state in sdfg.states(): if state.label == 'loop1_guard': guard_state = state elif state.label == 'loop1_init': init_state = state - elif state.label == 'loop1_tail': - tail_state = state + elif state.label == 'loop1_latch': + latch_state = state init_edge = sdfg.edges_between(init_state, guard_state)[0] assert 'i' in init_edge.data.assignments assert 'j' in init_edge.data.assignments - update_edge = sdfg.edges_between(tail_state, guard_state)[0] + update_edge = sdfg.edges_between(latch_state, guard_state)[0] assert 'i' in update_edge.data.assignments assert 'j' in update_edge.data.assignments diff --git a/tests/state_propagation_test.py b/tests/state_propagation_test.py index ac4393a58d..226775a0e7 100644 --- a/tests/state_propagation_test.py +++ b/tests/state_propagation_test.py @@ -1,7 +1,7 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. from dace.dtypes import Language -from dace.properties import CodeProperty +from dace.properties import CodeProperty, CodeBlock from dace.sdfg.sdfg import InterstateEdge import dace from dace.sdfg.propagation import propagate_states @@ -47,203 +47,147 @@ def test_conditional_fake_merge(): def test_conditional_full_merge(): - @dace.program(dace.int32, dace.int32, dace.int32) - def conditional_full_merge(a, b, c): - if a < 10: - if b < 10: - c = 0 - else: - c = 1 - c += 1 - - sdfg = conditional_full_merge.to_sdfg(simplify=False) + sdfg = dace.SDFG('conditional_full_merge') + + sdfg.add_scalar('a', dace.int32) + sdfg.add_scalar('b', dace.int32) + + init_state = sdfg.add_state('init_state') + if_guard_1 = sdfg.add_state('if_guard_1') + l_branch_1 = sdfg.add_state('l_branch_1') + if_guard_2 = sdfg.add_state('if_guard_2') + l_branch = sdfg.add_state('l_branch') + r_branch = sdfg.add_state('r_branch') + if_merge_1 = sdfg.add_state('if_merge_1') + if_merge_2 = sdfg.add_state('if_merge_2') + + sdfg.add_edge(init_state, if_guard_1, dace.InterstateEdge()) + sdfg.add_edge(if_guard_1, l_branch_1, dace.InterstateEdge(condition=CodeBlock('a < 10'))) + sdfg.add_edge(l_branch_1, if_guard_2, dace.InterstateEdge()) + sdfg.add_edge(if_guard_1, if_merge_1, dace.InterstateEdge(condition=CodeBlock('not (a < 10)'))) + sdfg.add_edge(if_guard_2, l_branch, dace.InterstateEdge(condition=CodeBlock('b < 10'))) + sdfg.add_edge(if_guard_2, r_branch, dace.InterstateEdge(condition=CodeBlock('not (b < 10)'))) + sdfg.add_edge(l_branch, if_merge_2, dace.InterstateEdge()) + sdfg.add_edge(r_branch, if_merge_2, dace.InterstateEdge()) + sdfg.add_edge(if_merge_2, if_merge_1, dace.InterstateEdge()) + propagate_states(sdfg) # Check start state. - state = sdfg.start_state - state_check_executions(state, 1) + state_check_executions(init_state, 1) # Check the first if guard, `a < 10`. - state = sdfg.out_edges(state)[0].dst - state_check_executions(state, 1) - # Get edges to the true and fals branches. - oedges = sdfg.out_edges(state) - true_branch_edge = None - false_branch_edge = None - for edge in oedges: - if edge.data.label == '(a < 10)': - true_branch_edge = edge - elif edge.data.label == '(not (a < 10))': - false_branch_edge = edge - if false_branch_edge is None or true_branch_edge is None: - raise RuntimeError('Couldn\'t identify guard edges') + state_check_executions(if_guard_1, 1) # Check the true branch. - state = true_branch_edge.dst - state_check_executions(state, 1, expected_dynamic=True) + state_check_executions(l_branch_1, 1, expected_dynamic=True) # Check the next if guard, `b < 20` - state = sdfg.out_edges(state)[0].dst - state_check_executions(state, 1, expected_dynamic=True) - # Get edges to the true and fals branches. - oedges = sdfg.out_edges(state) - true_branch_edge = None - false_branch_edge = None - for edge in oedges: - if edge.data.label == '(b < 10)': - true_branch_edge = edge - elif edge.data.label == '(not (b < 10))': - false_branch_edge = edge - if false_branch_edge is None or true_branch_edge is None: - raise RuntimeError('Couldn\'t identify guard edges') + state_check_executions(if_guard_2, 1, expected_dynamic=True) # Check the true branch. - state = true_branch_edge.dst - state_check_executions(state, 1, expected_dynamic=True) - state = sdfg.out_edges(state)[0].dst - state_check_executions(state, 1, expected_dynamic=True) + state_check_executions(l_branch_1, 1, expected_dynamic=True) # Check the false branch. - state = false_branch_edge.dst - state_check_executions(state, 1, expected_dynamic=True) - state = sdfg.out_edges(state)[0].dst - state_check_executions(state, 1, expected_dynamic=True) - + state_check_executions(r_branch, 1, expected_dynamic=True) # Check the first branch merge state. - state = sdfg.out_edges(state)[0].dst - state_check_executions(state, 1, expected_dynamic=True) - + state_check_executions(if_merge_2, 1, expected_dynamic=True) # Check the second branch merge state. - state = sdfg.out_edges(state)[0].dst - state_check_executions(state, 1) - - state = sdfg.out_edges(state)[0].dst - state_check_executions(state, 1) + state_check_executions(if_merge_1, 1) def test_while_inside_for(): - @dace.program(dace.int32) - def while_inside_for(a): - for i in range(20): - j = 0 - while j < 20: - a += 5 - - sdfg = while_inside_for.to_sdfg(simplify=False) + sdfg = dace.SDFG('while_inside_for') + + sdfg.add_symbol('i', dace.int32) + sdfg.add_symbol('j', dace.int32) + + init_state = sdfg.add_state('init') + guard_1 = sdfg.add_state('guard_1') + loop_1 = sdfg.add_state('loop_1') + end_1 = sdfg.add_state('end_1') + guard_2 = sdfg.add_state('guard_2') + loop_2 = sdfg.add_state('loop_2') + end_2 = sdfg.add_state('end_2') + + sdfg.add_edge(init_state, guard_1, dace.InterstateEdge(assignments={'i': 0})) + sdfg.add_edge(guard_1, end_1, dace.InterstateEdge(condition=CodeBlock('not (i < 20)'))) + sdfg.add_edge(guard_1, loop_1, dace.InterstateEdge(condition=CodeBlock('i < 20'))) + sdfg.add_edge(loop_1, guard_2, dace.InterstateEdge()) + sdfg.add_edge(end_2, guard_1, dace.InterstateEdge(assignments={'i': 'i + 1'})) + + sdfg.add_edge(guard_2, end_2, dace.InterstateEdge(condition=CodeBlock('not (j < 20)'))) + sdfg.add_edge(guard_2, loop_2, dace.InterstateEdge(condition=CodeBlock('j < 20'))) + sdfg.add_edge(loop_2, guard_2, dace.InterstateEdge()) + propagate_states(sdfg) # Check start state. - state = sdfg.start_state - state_check_executions(state, 1) + state_check_executions(init_state, 1) # Check the for loop guard, `i in range(20)`. - state = sdfg.out_edges(state)[0].dst - state_check_executions(state, 21) - # Get edges to inside and outside the loop. - oedges = sdfg.out_edges(state) - end_branch_edge = None - for_branch_edge = None - for edge in oedges: - if edge.data.label == '(i < 20)': - for_branch_edge = edge - elif edge.data.label == '(not (i < 20))': - end_branch_edge = edge - if end_branch_edge is None or for_branch_edge is None: - raise RuntimeError('Couldn\'t identify guard edges') + state_check_executions(guard_1, 21) # Check loop-end branch. - state = end_branch_edge.dst - state_check_executions(state, 1) + state_check_executions(end_1, 1) # Check inside the loop. - state = for_branch_edge.dst - state_check_executions(state, 20) - state = sdfg.out_edges(state)[0].dst - state_check_executions(state, 20) + state_check_executions(loop_1, 20) # Check the while guard, `j < 20`. - state = sdfg.out_edges(state)[0].dst - state_check_executions(state, 0, expected_dynamic=True) - # Get edges to inside and outside the loop. - oedges = sdfg.out_edges(state) - end_branch_edge = None - for_branch_edge = None - for edge in oedges: - if edge.data.label == '(j < 20)': - for_branch_edge = edge - elif edge.data.label == '(not (j < 20))': - end_branch_edge = edge - if end_branch_edge is None or for_branch_edge is None: - raise RuntimeError('Couldn\'t identify guard edges') + state_check_executions(guard_2, 0, expected_dynamic=True) # Check loop-end branch. - state = end_branch_edge.dst - state_check_executions(state, 20) + state_check_executions(end_2, 20) # Check inside the loop. - state = for_branch_edge.dst - state_check_executions(state, 0, expected_dynamic=True) - state = sdfg.out_edges(state)[0].dst - state_check_executions(state, 0, expected_dynamic=True) + state_check_executions(loop_2, 0, expected_dynamic=True) def test_for_with_nested_full_merge_branch(): - @dace.program(dace.int32) - def for_with_nested_full_merge_branch(a): - for i in range(20): - if i < 10: - a += 2 - else: - a += 1 - - sdfg = for_with_nested_full_merge_branch.to_sdfg(simplify=False) + sdfg = dace.SDFG('for_full_merge') + + sdfg.add_symbol('i', dace.int32) + sdfg.add_scalar('a', dace.int32) + + init_state = sdfg.add_state('init') + guard_1 = sdfg.add_state('guard_1') + if_guard = sdfg.add_state('if_guard') + l_branch = sdfg.add_state('l_branch') + r_branch = sdfg.add_state('r_branch') + if_merge = sdfg.add_state('if_merge') + end_1 = sdfg.add_state('end_1') + + lra = l_branch.add_access('a') + lt = l_branch.add_tasklet('t1', {'i1'}, {'o1'}, 'o1 = i1 + 5') + lwa = l_branch.add_access('a') + l_branch.add_edge(lra, None, lt, 'i1', dace.Memlet('a[0]')) + l_branch.add_edge(lt, 'o1', lwa, None, dace.Memlet('a[0]')) + + rra = r_branch.add_access('a') + rt = r_branch.add_tasklet('t2', {'i1'}, {'o1'}, 'o1 = i1 + 10') + rwa = r_branch.add_access('a') + r_branch.add_edge(rra, None, rt, 'i1', dace.Memlet('a[0]')) + r_branch.add_edge(rt, 'o1', rwa, None, dace.Memlet('a[0]')) + + sdfg.add_edge(init_state, guard_1, dace.InterstateEdge(assignments={'i': 0})) + sdfg.add_edge(guard_1, end_1, dace.InterstateEdge(condition=CodeBlock('not (i < 20)'))) + sdfg.add_edge(guard_1, if_guard, dace.InterstateEdge(condition=CodeBlock('i < 20'))) + sdfg.add_edge(if_guard, l_branch, dace.InterstateEdge(condition=CodeBlock('not (a < 10)'))) + sdfg.add_edge(if_guard, r_branch, dace.InterstateEdge(condition=CodeBlock('a < 10'))) + sdfg.add_edge(l_branch, if_merge, dace.InterstateEdge()) + sdfg.add_edge(r_branch, if_merge, dace.InterstateEdge()) + sdfg.add_edge(if_merge, guard_1, dace.InterstateEdge(assignments={'i': 'i + 1'})) + propagate_states(sdfg) # Check start state. - state = sdfg.start_state - state_check_executions(state, 1) + state_check_executions(init_state, 1) # For loop, check loop guard, `for i in range(20)`. - state = sdfg.out_edges(state)[0].dst - state_check_executions(state, 21) - # Get edges to inside and outside the loop. - oedges = sdfg.out_edges(state) - end_branch_edge = None - for_branch_edge = None - for edge in oedges: - if edge.data.label == '(i < 20)': - for_branch_edge = edge - elif edge.data.label == '(not (i < 20))': - end_branch_edge = edge - if end_branch_edge is None or for_branch_edge is None: - raise RuntimeError('Couldn\'t identify guard edges') + state_check_executions(guard_1, 21) # Check loop-end branch. - state = end_branch_edge.dst - state_check_executions(state, 1) + state_check_executions(end_1, 1) # Check inside the loop. - state = for_branch_edge.dst - state_check_executions(state, 20) - - # Check the branch guard, `if i < 10`. - state = sdfg.out_edges(state)[0].dst - state_check_executions(state, 20) - # Get edges to both sides of the conditional split. - oedges = sdfg.out_edges(state) - condition_met_edge = None - condition_broken_edge = None - for edge in oedges: - if edge.data.label == '(i < 10)': - condition_met_edge = edge - elif edge.data.label == '(not (i < 10))': - condition_broken_edge = edge - if condition_met_edge is None or condition_broken_edge is None: - raise RuntimeError('Couldn\'t identify conditional guard edges') + state_check_executions(if_guard, 20) # Check the 'true' branch. - state = condition_met_edge.dst - state_check_executions(state, 20, expected_dynamic=True) - state = sdfg.out_edges(state)[0].dst - state_check_executions(state, 20, expected_dynamic=True) + state_check_executions(r_branch, 20, expected_dynamic=True) # Check the 'false' branch. - state = condition_broken_edge.dst - state_check_executions(state, 20, expected_dynamic=True) - state = sdfg.out_edges(state)[0].dst - state_check_executions(state, 20, expected_dynamic=True) - + state_check_executions(l_branch, 20, expected_dynamic=True) # Check where the branches meet again. - state = sdfg.out_edges(state)[0].dst - state_check_executions(state, 20) + state_check_executions(if_merge, 20) def test_for_inside_branch(): @@ -322,70 +266,56 @@ def test_full_merge_inside_loop(): def test_while_with_nested_full_merge_branch(): - @dace.program(dace.int32) - def while_with_nested_full_merge_branch(a): - while a < 20: - if a < 10: - a += 2 - else: - a += 1 - - sdfg = while_with_nested_full_merge_branch.to_sdfg(simplify=False) + sdfg = dace.SDFG('while_full_merge') + + sdfg.add_scalar('a', dace.int32) + + init_state = sdfg.add_state('init') + guard_1 = sdfg.add_state('guard_1') + if_guard = sdfg.add_state('if_guard') + l_branch = sdfg.add_state('l_branch') + r_branch = sdfg.add_state('r_branch') + if_merge = sdfg.add_state('if_merge') + end_1 = sdfg.add_state('end_1') + + lra = l_branch.add_access('a') + lt = l_branch.add_tasklet('t1', {'i1'}, {'o1'}, 'o1 = i1 + 5') + lwa = l_branch.add_access('a') + l_branch.add_edge(lra, None, lt, 'i1', dace.Memlet('a[0]')) + l_branch.add_edge(lt, 'o1', lwa, None, dace.Memlet('a[0]')) + + rra = r_branch.add_access('a') + rt = r_branch.add_tasklet('t2', {'i1'}, {'o1'}, 'o1 = i1 + 10') + rwa = r_branch.add_access('a') + r_branch.add_edge(rra, None, rt, 'i1', dace.Memlet('a[0]')) + r_branch.add_edge(rt, 'o1', rwa, None, dace.Memlet('a[0]')) + + sdfg.add_edge(init_state, guard_1, dace.InterstateEdge()) + sdfg.add_edge(guard_1, end_1, dace.InterstateEdge(condition=CodeBlock('not (a < 20)'))) + sdfg.add_edge(guard_1, if_guard, dace.InterstateEdge(condition=CodeBlock('a < 20'))) + sdfg.add_edge(if_guard, l_branch, dace.InterstateEdge(condition=CodeBlock('not (a < 10)'))) + sdfg.add_edge(if_guard, r_branch, dace.InterstateEdge(condition=CodeBlock('a < 10'))) + sdfg.add_edge(l_branch, if_merge, dace.InterstateEdge()) + sdfg.add_edge(r_branch, if_merge, dace.InterstateEdge()) + sdfg.add_edge(if_merge, guard_1, dace.InterstateEdge()) + propagate_states(sdfg) # Check start state. - state = sdfg.start_state - state_check_executions(state, 1) + state_check_executions(init_state, 1) # While loop, check loop guard, `while a < N`. Must be dynamic unbounded. - state = sdfg.out_edges(state)[0].dst - state_check_executions(state, 0, expected_dynamic=True) - # Get edges to inside and outside the loop. - oedges = sdfg.out_edges(state) - end_branch_edge = None - for_branch_edge = None - for edge in oedges: - if edge.data.label == '(a < 20)': - for_branch_edge = edge - elif edge.data.label == '(not (a < 20))': - end_branch_edge = edge - if end_branch_edge is None or for_branch_edge is None: - raise RuntimeError('Couldn\'t identify guard edges') + state_check_executions(guard_1, 0, expected_dynamic=True) # Check loop-end branch. - state = end_branch_edge.dst - state_check_executions(state, 1) + state_check_executions(end_1, 1) # Check inside the loop. - state = for_branch_edge.dst - state_check_executions(state, 0, expected_dynamic=True) - - # Check the branch guard, `if a < 10`. - state = sdfg.out_edges(state)[0].dst - state_check_executions(state, 0, expected_dynamic=True) - # Get edges to both sides of the conditional split. - oedges = sdfg.out_edges(state) - condition_met_edge = None - condition_broken_edge = None - for edge in oedges: - if edge.data.label == '(a < 10)': - condition_met_edge = edge - elif edge.data.label == '(not (a < 10))': - condition_broken_edge = edge - if condition_met_edge is None or condition_broken_edge is None: - raise RuntimeError('Couldn\'t identify conditional guard edges') + state_check_executions(if_guard, 0, expected_dynamic=True) # Check the 'true' branch. - state = condition_met_edge.dst - state_check_executions(state, 0, expected_dynamic=True) - state = sdfg.out_edges(state)[0].dst - state_check_executions(state, 0, expected_dynamic=True) + state_check_executions(r_branch, 0, expected_dynamic=True) # Check the 'false' branch. - state = condition_broken_edge.dst - state_check_executions(state, 0, expected_dynamic=True) - state = sdfg.out_edges(state)[0].dst - state_check_executions(state, 0, expected_dynamic=True) - + state_check_executions(l_branch, 0, expected_dynamic=True) # Check where the branches meet again. - state = sdfg.out_edges(state)[0].dst - state_check_executions(state, 0, expected_dynamic=True) + state_check_executions(if_merge, 0, expected_dynamic=True) def test_3_fold_nested_loop_with_symbolic_bounds(): @@ -393,165 +323,123 @@ def test_3_fold_nested_loop_with_symbolic_bounds(): M = dace.symbol('M') K = dace.symbol('K') - @dace.program(dace.int32) - def nested_3_symbolic(a): - for i in range(N): - for j in range(M): - for k in range(K): - a += 5 + sdfg = dace.SDFG('nest_3_symbolic') + + sdfg.add_symbol('i', dace.int32) + sdfg.add_symbol('j', dace.int32) + sdfg.add_symbol('k', dace.int32) + + init_state = sdfg.add_state('init') + guard_1 = sdfg.add_state('guard_1') + loop_1 = sdfg.add_state('loop_1') + end_1 = sdfg.add_state('end_1') + guard_2 = sdfg.add_state('guard_2') + loop_2 = sdfg.add_state('loop_2') + end_2 = sdfg.add_state('end_2') + guard_3 = sdfg.add_state('guard_3') + end_3 = sdfg.add_state('end_3') + loop_3 = sdfg.add_state('loop_3') + + sdfg.add_edge(init_state, guard_1, dace.InterstateEdge(assignments={'i': 0})) + sdfg.add_edge(guard_1, end_1, dace.InterstateEdge(condition=CodeBlock('not (i < N)'))) + sdfg.add_edge(guard_1, loop_1, dace.InterstateEdge(condition=CodeBlock('i < N'))) + sdfg.add_edge(loop_1, guard_2, dace.InterstateEdge(assignments={'j': 0})) + sdfg.add_edge(end_2, guard_1, dace.InterstateEdge(assignments={'i': 'i + 1'})) + + sdfg.add_edge(guard_2, end_2, dace.InterstateEdge(condition=CodeBlock('not (j < M)'))) + sdfg.add_edge(guard_2, loop_2, dace.InterstateEdge(condition=CodeBlock('j < M'))) + sdfg.add_edge(loop_2, guard_3, dace.InterstateEdge(assignments={'k': 0})) + sdfg.add_edge(end_3, guard_2, dace.InterstateEdge(assignments={'j': 'j + 1'})) + + sdfg.add_edge(guard_3, end_3, dace.InterstateEdge(condition=CodeBlock('not (k < K)'))) + sdfg.add_edge(guard_3, loop_3, dace.InterstateEdge(condition=CodeBlock('k < K'))) + sdfg.add_edge(loop_3, guard_3, dace.InterstateEdge(assignments={'k': 'k + 1'})) - sdfg = nested_3_symbolic.to_sdfg(simplify=False) propagate_states(sdfg) # Check start state. - state = sdfg.start_state - state_check_executions(state, 1) + state_check_executions(init_state, 1) - # 1st level loop, check loop guard, `for i in range(20)`. - state = sdfg.out_edges(state)[0].dst - state_check_executions(state, N + 1) - # Get edges to inside and outside the loop. - oedges = sdfg.out_edges(state) - end_branch_edge = None - for_branch_edge = None - for edge in oedges: - if edge.data.label == '(i < N)': - for_branch_edge = edge - elif edge.data.label == '(not (i < N))': - end_branch_edge = edge - if end_branch_edge is None or for_branch_edge is None: - raise RuntimeError('Couldn\'t identify guard edges') + # 1st level loop, check loop guard, `for i in range(N)`. + state_check_executions(guard_1, N + 1) # Check loop-end branch. - state = end_branch_edge.dst - state_check_executions(state, 1) + state_check_executions(end_1, 1) # Check inside the loop. - state = for_branch_edge.dst - state_check_executions(state, N) + state_check_executions(loop_1, N) - # 2nd level nested loop, check loog guard, `for j in range(i, 20)`. - state = sdfg.out_edges(state)[0].dst - state_check_executions(state, M * N + N) - # Get edges to inside and outside the loop. - oedges = sdfg.out_edges(state) - end_branch_edge = None - for_branch_edge = None - for edge in oedges: - if edge.data.label == '(j < M)': - for_branch_edge = edge - elif edge.data.label == '(not (j < M))': - end_branch_edge = edge - if end_branch_edge is None or for_branch_edge is None: - raise RuntimeError('Couldn\'t identify guard edges') + # 2nd level nested loop, check loog guard, `for j in range(M)`. + state_check_executions(guard_2, M * N + N) # Check loop-end branch. - state = end_branch_edge.dst - state_check_executions(state, N) + state_check_executions(end_2, N) # Check inside the loop. - state = for_branch_edge.dst - state_check_executions(state, M * N) - - # 3rd level nested loop, check loog guard, `for k in range(i, j)`. - state = sdfg.out_edges(state)[0].dst - state_check_executions(state, M * N * K + M * N) - # Get edges to inside and outside the loop. - oedges = sdfg.out_edges(state) - end_branch_edge = None - for_branch_edge = None - for edge in oedges: - if edge.data.label == '(k < K)': - for_branch_edge = edge - elif edge.data.label == '(not (k < K))': - end_branch_edge = edge - if end_branch_edge is None or for_branch_edge is None: - raise RuntimeError('Couldn\'t identify guard edges') + state_check_executions(loop_2, M * N) + + # 3rd level nested loop, check loop guard, `for k in range(K)`. + state_check_executions(guard_3, M * N * K + M * N) # Check loop-end branch. - state = end_branch_edge.dst - state_check_executions(state, M * N) + state_check_executions(end_3, M * N) # Check inside the loop. - state = for_branch_edge.dst - state_check_executions(state, M * N * K) - state = sdfg.out_edges(state)[0].dst - state_check_executions(state, M * N * K) + state_check_executions(loop_3, M * N * K) def test_3_fold_nested_loop(): - @dace.program(dace.int32[20, 20]) - def nested_3(A): - for i in range(20): - for j in range(i, 20): - for k in range(i, j): - A[k, j] += 5 - - sdfg = nested_3.to_sdfg(simplify=False) + sdfg = dace.SDFG('nest_3') + + sdfg.add_symbol('i', dace.int32) + sdfg.add_symbol('j', dace.int32) + sdfg.add_symbol('k', dace.int32) + + init_state = sdfg.add_state('init') + guard_1 = sdfg.add_state('guard_1') + loop_1 = sdfg.add_state('loop_1') + end_1 = sdfg.add_state('end_1') + guard_2 = sdfg.add_state('guard_2') + loop_2 = sdfg.add_state('loop_2') + end_2 = sdfg.add_state('end_2') + guard_3 = sdfg.add_state('guard_3') + end_3 = sdfg.add_state('end_3') + loop_3 = sdfg.add_state('loop_3') + + sdfg.add_edge(init_state, guard_1, dace.InterstateEdge(assignments={'i': 0})) + sdfg.add_edge(guard_1, end_1, dace.InterstateEdge(condition=CodeBlock('not (i < 20)'))) + sdfg.add_edge(guard_1, loop_1, dace.InterstateEdge(condition=CodeBlock('i < 20'))) + sdfg.add_edge(loop_1, guard_2, dace.InterstateEdge(assignments={'j': 'i'})) + sdfg.add_edge(end_2, guard_1, dace.InterstateEdge(assignments={'i': 'i + 1'})) + + sdfg.add_edge(guard_2, end_2, dace.InterstateEdge(condition=CodeBlock('not (j < 20)'))) + sdfg.add_edge(guard_2, loop_2, dace.InterstateEdge(condition=CodeBlock('j < 20'))) + sdfg.add_edge(loop_2, guard_3, dace.InterstateEdge(assignments={'k': 'i'})) + sdfg.add_edge(end_3, guard_2, dace.InterstateEdge(assignments={'j': 'j + 1'})) + + sdfg.add_edge(guard_3, end_3, dace.InterstateEdge(condition=CodeBlock('not (k < j)'))) + sdfg.add_edge(guard_3, loop_3, dace.InterstateEdge(condition=CodeBlock('k < j'))) + sdfg.add_edge(loop_3, guard_3, dace.InterstateEdge(assignments={'k': 'k + 1'})) + propagate_states(sdfg) # Check start state. - state = sdfg.start_state - state_check_executions(state, 1) + state_check_executions(init_state, 1) # 1st level loop, check loop guard, `for i in range(20)`. - state = sdfg.out_edges(state)[0].dst - state_check_executions(state, 21) - # Get edges to inside and outside the loop. - oedges = sdfg.out_edges(state) - end_branch_edge = None - for_branch_edge = None - for edge in oedges: - if edge.data.label == '(i < 20)': - for_branch_edge = edge - elif edge.data.label == '(not (i < 20))': - end_branch_edge = edge - if end_branch_edge is None or for_branch_edge is None: - raise RuntimeError('Couldn\'t identify guard edges') + state_check_executions(guard_1, 21) # Check loop-end branch. - state = end_branch_edge.dst - state_check_executions(state, 1) + state_check_executions(end_1, 1) # Check inside the loop. - state = for_branch_edge.dst - state_check_executions(state, 20) + state_check_executions(loop_1, 20) # 2nd level nested loop, check loog guard, `for j in range(i, 20)`. - state = sdfg.out_edges(state)[0].dst - state_check_executions(state, 230) - # Get edges to inside and outside the loop. - oedges = sdfg.out_edges(state) - end_branch_edge = None - for_branch_edge = None - for edge in oedges: - if edge.data.label == '(j < 20)': - for_branch_edge = edge - elif edge.data.label == '(not (j < 20))': - end_branch_edge = edge - if end_branch_edge is None or for_branch_edge is None: - raise RuntimeError('Couldn\'t identify guard edges') + state_check_executions(guard_2, 230) # Check loop-end branch. - state = end_branch_edge.dst - state_check_executions(state, 20) + state_check_executions(end_2, 20) # Check inside the loop. - state = for_branch_edge.dst - state_check_executions(state, 210) - - # 3rd level nested loop, check loog guard, `for k in range(i, j)`. - state = sdfg.out_edges(state)[0].dst - state_check_executions(state, 1540) - # Get edges to inside and outside the loop. - oedges = sdfg.out_edges(state) - end_branch_edge = None - for_branch_edge = None - for edge in oedges: - if edge.data.label == '(k < j)': - for_branch_edge = edge - elif edge.data.label == '(not (k < j))': - end_branch_edge = edge - if end_branch_edge is None or for_branch_edge is None: - raise RuntimeError('Couldn\'t identify guard edges') + state_check_executions(loop_2, 210) + + # 3rd level nested loop, check loop guard, `for k in range(i, j)`. + state_check_executions(guard_3, 1540) # Check loop-end branch. - state = end_branch_edge.dst - state_check_executions(state, 210) + state_check_executions(end_3, 210) # Check inside the loop. - state = for_branch_edge.dst - state_check_executions(state, 1330) - state = sdfg.out_edges(state)[0].dst - state_check_executions(state, 1330) + state_check_executions(loop_3, 1330) if __name__ == "__main__": diff --git a/tests/transformations/loop_to_map_test.py b/tests/transformations/loop_to_map_test.py index 7c556362e4..8cd6947bb5 100644 --- a/tests/transformations/loop_to_map_test.py +++ b/tests/transformations/loop_to_map_test.py @@ -11,6 +11,7 @@ import dace from dace.sdfg import nodes, propagation from dace.transformation.interstate import LoopToMap +from dace.transformation.interstate.loop_detection import DetectLoop def make_sdfg(with_wcr, map_in_guard, reverse_loop, use_variable, assign_after, log_path): diff --git a/tests/transformations/state_fission_test.py b/tests/transformations/state_fission_test.py index 7c03fbed89..37bd375590 100644 --- a/tests/transformations/state_fission_test.py +++ b/tests/transformations/state_fission_test.py @@ -127,7 +127,7 @@ def test_state_fission(): vec_add1 = state.nodes()[3] subg = dace.sdfg.graph.SubgraphView(state, [node_x, node_y, vec_add1, node_z]) - helpers.state_fission(sdfg, subg) + helpers.state_fission(subg) sdfg.validate() assert (len(sdfg.states()) == 2)