Skip to content

Commit

Permalink
Finished LoopToMap
Browse files Browse the repository at this point in the history
  • Loading branch information
phschaad committed Oct 9, 2024
1 parent 352171a commit bbddc88
Show file tree
Hide file tree
Showing 10 changed files with 78 additions and 70 deletions.
2 changes: 1 addition & 1 deletion dace/sdfg/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,7 @@ def add_edge(self, src: NodeT, dst: NodeT, data: EdgeT = None):

def remove_node(self, node: NodeT):
try:
for edge in itertools.chain(self.in_edges(node), self.out_edges(node)):
for edge in self.all_edges(node):
self.remove_edge(edge)
del self._nodes[node]
self._nx.remove_node(node)
Expand Down
27 changes: 15 additions & 12 deletions dace/sdfg/propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import itertools
import warnings
from collections import deque
from typing import List, Set
from typing import TYPE_CHECKING, List, Set

import sympy
from sympy import Symbol, ceiling
Expand All @@ -22,6 +22,11 @@
from dace.symbolic import issymbolic, pystr_to_symbolic, simplify


if TYPE_CHECKING:
from dace.sdfg import SDFG
from dace.sdfg.state import SDFGState


@registry.make_registry
class MemletPattern(object):
"""
Expand Down Expand Up @@ -561,7 +566,7 @@ def propagate(self, array, expressions, node_range):
return subsets.Range(rng)


def _annotate_loop_ranges(sdfg, unannotated_cycle_states):
def _annotate_loop_ranges(sdfg: 'SDFG', unannotated_cycle_states):
"""
Annotate each valid for loop construct with its loop variable ranges.
Expand Down Expand Up @@ -682,7 +687,7 @@ def _annotate_loop_ranges(sdfg, unannotated_cycle_states):

return condition_edges

def propagate_states(sdfg, concretize_dynamic_unbounded=False) -> None:
def propagate_states(sdfg: 'SDFG', concretize_dynamic_unbounded: bool = False) -> None:
"""
Annotate the states of an SDFG with the number of executions.
Expand Down Expand Up @@ -948,7 +953,7 @@ def propagate_states(sdfg, concretize_dynamic_unbounded=False) -> None:
sdfg.remove_node(temp_exit_state)


def propagate_memlets_nested_sdfg(parent_sdfg, parent_state, nsdfg_node):
def propagate_memlets_nested_sdfg(parent_sdfg: 'SDFG', parent_state: 'SDFGState', nsdfg_node: nodes.NestedSDFG):
"""
Propagate memlets out of a nested sdfg.
Expand Down Expand Up @@ -980,7 +985,7 @@ def propagate_memlets_nested_sdfg(parent_sdfg, parent_state, nsdfg_node):
# the corresponding memlets and use them to calculate the memlet volume and
# subset corresponding to the outside memlet attached to that connector.
# This is passed out via `border_memlets` and propagated along from there.
for state in sdfg.nodes():
for state in sdfg.states():
for node in state.data_nodes():
for direction in border_memlets:
if (node.label not in border_memlets[direction]):
Expand Down Expand Up @@ -1139,34 +1144,32 @@ def propagate_memlets_nested_sdfg(parent_sdfg, parent_state, nsdfg_node):
oedge.data.dynamic = True


def reset_state_annotations(sdfg):
def reset_state_annotations(sdfg: 'SDFG'):
""" Resets the state (loop-related) annotations of an SDFG.
:note: This operation is shallow (does not go into nested SDFGs).
"""
for state in sdfg.nodes():
for state in sdfg.states():
state.executions = 0
state.dynamic_executions = True
state.ranges = {}
state.is_loop_guard = False
state.itervar = None


def propagate_memlets_sdfg(sdfg):
def propagate_memlets_sdfg(sdfg: 'SDFG'):
""" Propagates memlets throughout an entire given SDFG.
:note: This is an in-place operation on the SDFG.
"""
# Reset previous annotations first
reset_state_annotations(sdfg)

for state in sdfg.nodes():
for state in sdfg.states():
propagate_memlets_state(sdfg, state)

propagate_states(sdfg)


def propagate_memlets_state(sdfg, state):
def propagate_memlets_state(sdfg: 'SDFG', state: 'SDFGState'):
""" Propagates memlets throughout one SDFG state.
:param sdfg: The SDFG in which the state is situated.
Expand Down
51 changes: 37 additions & 14 deletions dace/sdfg/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -2501,13 +2501,14 @@ def sdfg(self) -> 'SDFG':


@make_properties
class ControlFlowRegion(OrderedDiGraph[ControlFlowBlock, 'dace.sdfg.InterstateEdge'], ControlGraphView,
ControlFlowBlock):
class AbstractControlFlowRegion(OrderedDiGraph[ControlFlowBlock, 'dace.sdfg.InterstateEdge'], ControlGraphView,
ControlFlowBlock, abc.ABC):

def __init__(self, label: str = '', sdfg: Optional['SDFG'] = None):
def __init__(self, label: str = '', sdfg: Optional['SDFG'] = None,
parent: Optional['AbstractControlFlowRegion'] = None):
OrderedDiGraph.__init__(self)
ControlGraphView.__init__(self)
ControlFlowBlock.__init__(self, label, sdfg)
ControlFlowBlock.__init__(self, label, sdfg, parent)

self._labels: Set[str] = set()
self._start_block: Optional[int] = None
Expand Down Expand Up @@ -2683,9 +2684,13 @@ def add_node(self,
self._cached_start_block = None
node.parent_graph = self
if isinstance(self, dace.SDFG):
node.sdfg = self
sdfg = self
else:
node.sdfg = self.sdfg
sdfg = self.sdfg
node.sdfg = sdfg
if isinstance(node, AbstractControlFlowRegion):
for n in node.all_control_flow_blocks():
n.sdfg = self.sdfg
start_block = is_start_block
if is_start_state is not None:
warnings.warn('is_start_state is deprecated, use is_start_block instead', DeprecationWarning)
Expand Down Expand Up @@ -2963,6 +2968,13 @@ def start_block(self, block_id):
self._cached_start_block = self.node(block_id)


@make_properties
class ControlFlowRegion(AbstractControlFlowRegion):

def __init__(self, label = '', sdfg = None, parent = None):
super().__init__(label, sdfg, parent)


@make_properties
class LoopRegion(ControlFlowRegion):
"""
Expand Down Expand Up @@ -3244,7 +3256,7 @@ def has_return(self) -> bool:


@make_properties
class ConditionalBlock(ControlFlowBlock, ControlGraphView):
class ConditionalBlock(AbstractControlFlowRegion):

_branches: List[Tuple[Optional[CodeBlock], ControlFlowRegion]]

Expand All @@ -3264,7 +3276,7 @@ def branches(self) -> List[Tuple[Optional[CodeBlock], ControlFlowRegion]]:

def add_branch(self, condition: Optional[CodeBlock], branch: ControlFlowRegion):
self._branches.append([condition, branch])
branch.parent_graph = self.parent_graph
branch.parent_graph = self
branch.sdfg = self.sdfg

def remove_branch(self, branch: ControlFlowRegion):
Expand All @@ -3273,12 +3285,6 @@ def remove_branch(self, branch: ControlFlowRegion):
if b is not branch:
filtered_branches.append((c, b))
self._branches = filtered_branches

def nodes(self) -> List['ControlFlowBlock']:
return [node for _, node in self._branches if node is not None]

def edges(self) -> List[Edge['dace.sdfg.InterstateEdge']]:
return []

def _used_symbols_internal(self,
all_symbols: bool,
Expand Down Expand Up @@ -3397,6 +3403,23 @@ def inline(self) -> Tuple[bool, Any]:

return True, (guard_state, end_state)

# Graph API overrides.

def nodes(self) -> List['ControlFlowBlock']:
return [node for _, node in self._branches if node is not None]

def edges(self) -> List[Edge['dace.sdfg.InterstateEdge']]:
return []

def in_edges(self, _: 'ControlFlowBlock') -> List[Edge['dace.sdfg.InterstateEdge']]:
return []

def out_edges(self, _: 'ControlFlowBlock') -> List[Edge['dace.sdfg.InterstateEdge']]:
return []

def all_edges(self, _: 'ControlFlowBlock') -> List[Edge['dace.sdfg.InterstateEdge']]:
return []


@make_properties
class NamedRegion(ControlFlowRegion):
Expand Down
2 changes: 1 addition & 1 deletion dace/transformation/interstate/loop_lifting.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def apply(self, graph: ControlFlowRegion, sdfg: SDFG):
if k != itvar:
left_over_incr_assignments[k] = incr_edge.data.assignments[k]

if inverted and incr_edge is cond_edge:
if (inverted or self.expr_index == 4) and incr_edge is cond_edge:
update_before_condition = False
else:
update_before_condition = True
Expand Down
12 changes: 0 additions & 12 deletions dace/transformation/interstate/loop_to_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,16 +556,4 @@ def apply(self, graph: ControlFlowRegion, sdfg: sd.SDFG):
if itervar in sdfg.free_symbols:
sdfg.remove_symbol(itervar)

# Reset all nested SDFG parent pointers
if nsdfg is not None:
if isinstance(nsdfg, nodes.NestedSDFG):
nsdfg = nsdfg.sdfg

for nstate in nsdfg.nodes():
for nnode in nstate.nodes():
if isinstance(nnode, nodes.NestedSDFG):
nnode.sdfg.parent_nsdfg_node = nnode
nnode.sdfg.parent = nstate
nnode.sdfg.parent_sdfg = nsdfg

sdfg.reset_cfg_list()
2 changes: 1 addition & 1 deletion dace/transformation/passes/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[SDFGState, Tuple[Set[s

# Edges that read from arrays add to both ends' access sets
anames = sdfg.arrays.keys()
for e in sdfg.edges():
for e in sdfg.all_interstate_edges():
fsyms = e.data.free_symbols & anames
if fsyms:
result[e.src][0].update(fsyms)
Expand Down
2 changes: 1 addition & 1 deletion dace/transformation/passes/dead_dataflow_elimination.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def modifies(self) -> ppl.Modifies:

def should_reapply(self, modified: ppl.Modifies) -> bool:
# If dataflow or states changed, new dead code may be exposed
return modified & (ppl.Modifies.Nodes | ppl.Modifies.Edges | ppl.Modifies.States)
return modified & (ppl.Modifies.Nodes | ppl.Modifies.Edges | ppl.Modifies.CFG)

def depends_on(self) -> Set[Type[ppl.Pass]]:
return {ap.StateReachability, ap.AccessSets}
Expand Down
4 changes: 3 additions & 1 deletion dace/transformation/passes/simplify.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved.
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Set
import warnings
Expand All @@ -16,10 +16,12 @@
from dace.transformation.passes.scalar_to_symbol import ScalarToSymbolPromotion
from dace.transformation.passes.prune_symbols import RemoveUnusedSymbols
from dace.transformation.passes.reference_reduction import ReferenceToView
from dace.transformation.passes.simplification.control_flow_raising import ControlFlowRaising

SIMPLIFY_PASSES = [
InlineSDFGs,
ScalarToSymbolPromotion,
ControlFlowRaising,
FuseStates,
OptionalArrayInference,
ConstantPropagation,
Expand Down
8 changes: 4 additions & 4 deletions dace/transformation/transformation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved.
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
"""
This file contains classes that describe data-centric transformations.
Expand All @@ -20,10 +20,10 @@

import abc
import copy
from dace import dtypes, serialize
from dace import serialize
from dace.dtypes import ScheduleType
from dace.sdfg import SDFG, SDFGState
from dace.sdfg.state import ControlFlowRegion
from dace.sdfg.state import ControlFlowBlock, ControlFlowRegion
from dace.sdfg import nodes as nd, graph as gr, utils as sdutil, propagation, infer_types, state as st
from dace.properties import make_properties, Property, DictProperty, SetProperty
from dace.transformation import pass_pipeline as ppl
Expand Down Expand Up @@ -339,7 +339,7 @@ def _can_be_applied_and_apply(
# Check that all keyword arguments are nodes and if interstate or not
sample_node = next(iter(where.values()))

if isinstance(sample_node, SDFGState):
if isinstance(sample_node, ControlFlowBlock):
graph = sample_node.parent_graph
state_id = -1
cfg_id = graph.cfg_id
Expand Down
38 changes: 15 additions & 23 deletions tests/transformations/loop_to_map_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@
import copy
import os
import tempfile
from typing import Tuple

import numpy as np
import pytest

import dace
from dace.sdfg import nodes, propagation
from dace.sdfg import nodes
from dace.sdfg.state import LoopRegion
from dace.transformation.interstate import LoopToMap, StateFusion
from dace.transformation.interstate.loop_detection import DetectLoop
from dace.transformation.interstate.loop_lifting import LoopLifting


Expand Down Expand Up @@ -653,34 +652,25 @@ def nested_loops(A: dace.int32[10, 10, 10], l: dace.int32):

sdfg = nested_loops.to_sdfg()

def find_loop(sdfg: dace.SDFG, itervar: str) -> Tuple[dace.SDFGState, dace.SDFGState, dace.SDFGState]:

guard, begin, fexit = None, None, None
for e in sdfg.edges():
if itervar in e.data.assignments and e.data.assignments[itervar] == '0':
guard = e.dst
elif e.data.condition.as_string in (f'({itervar} >= 10)', f'(not ({itervar} < 10))'):
fexit = e.dst
assert all(s is not None for s in (guard, fexit))

begin = next((e for e in sdfg.out_edges(guard) if e.dst != fexit)).dst

return guard, begin, fexit
def find_loop(sdfg: dace.SDFG, itervar: str) -> LoopRegion:
for cfg in sdfg.all_control_flow_regions():
if isinstance(cfg, LoopRegion) and cfg.loop_variable == itervar:
return cfg

sdfg0 = copy.deepcopy(sdfg)
i_guard, i_begin, i_exit = find_loop(sdfg0, 'i')
LoopToMap.apply_to(sdfg0, loop_guard=i_guard, loop_begin=i_begin, exit_state=i_exit)
i_loop = find_loop(sdfg0, 'i')
LoopToMap.apply_to(sdfg0, loop=i_loop)
nsdfg = next((sd for sd in sdfg0.all_sdfgs_recursive() if sd.parent is not None))
j_guard, j_begin, j_exit = find_loop(nsdfg, 'j')
LoopToMap.apply_to(nsdfg, loop_guard=j_guard, loop_begin=j_begin, exit_state=j_exit)
j_loop = find_loop(nsdfg, 'j')
LoopToMap.apply_to(nsdfg, loop=j_loop)

val = np.arange(1000, dtype=np.int32).reshape(10, 10, 10).copy()
sdfg(A=val, l=5)

assert np.allclose(ref, val)

j_guard, j_begin, j_exit = find_loop(sdfg, 'j')
LoopToMap.apply_to(sdfg, loop_guard=j_guard, loop_begin=j_begin, exit_state=j_exit)
j_loop = find_loop(sdfg, 'j')
LoopToMap.apply_to(sdfg, loop=j_loop)
# NOTE: The following fails to apply because of subset A[0:i+1], which is overapproximated.
# i_guard, i_begin, i_exit = find_loop(sdfg, 'i')
# LoopToMap.apply_to(sdfg, loop_guard=i_guard, loop_begin=i_begin, exit_state=i_exit)
Expand Down Expand Up @@ -720,7 +710,7 @@ def internal_write(inp0: dace.int32[10], inp1: dace.int32[10], out: dace.int32[1
val = np.empty((10, ), dtype=np.int32)

internal_write.f(inp0, inp1, ref)
internal_write(inp0, inp1, val)
sdfg(inp0, inp1, val)

assert np.array_equal(val, ref)

Expand Down Expand Up @@ -782,6 +772,8 @@ def test_self_loop_to_map():
body.add_edge(body.add_read('A'), None, t, 'inp', dace.Memlet('A[i]'))
body.add_edge(t, 'out', body.add_write('A'), None, dace.Memlet('A[i]'))

sdfg.apply_transformations_repeated([LoopLifting])

assert sdfg.apply_transformations_repeated(LoopToMap) == 1

a = np.random.rand(20)
Expand Down

0 comments on commit bbddc88

Please sign in to comment.