Skip to content

Commit

Permalink
Fix race conditions in Constant Propagation and Reference-To-View (#1679
Browse files Browse the repository at this point in the history
)

* Fixes a case where constant propagation would cause an inter-state
edge assignment race condition
* Fixes reference-to-view disconnecting a state graph and causing a race
condition
* More informative error message in code generation for copy dispatching
  • Loading branch information
tbennun authored Oct 11, 2024
1 parent 6525bc5 commit e6440a6
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 16 deletions.
5 changes: 5 additions & 0 deletions dace/codegen/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,8 @@ def dispatch_copy(self, src_node: nodes.Node, dst_node: nodes.Node, edge: MultiC
cfg: ControlFlowRegion, dfg: StateSubgraphView, state_id: int, function_stream: CodeIOStream,
output_stream: CodeIOStream) -> None:
""" Dispatches a code generator for a memory copy operation. """
if edge.data.is_empty():
return
state = cfg.state(state_id)
target = self.get_copy_dispatcher(src_node, dst_node, edge, sdfg, state)
if target is None:
Expand All @@ -616,6 +618,9 @@ def dispatch_output_definition(self, src_node: nodes.Node, dst_node: nodes.Node,
"""
state = cfg.state(state_id)
target = self.get_copy_dispatcher(src_node, dst_node, edge, sdfg, state)
if target is None:
raise ValueError(
f'Could not dispatch copy code generator for {src_node} -> {dst_node} in state {state.label}')

# Dispatch
self._used_targets.add(target)
Expand Down
16 changes: 15 additions & 1 deletion dace/transformation/passes/constant_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def _add_nested_datanames(name: str, desc: data.Structure):
# TODO: How are we handling this?
pass
arrays.add(f'{name}.{k}')

for name, desc in sdfg.arrays.items():
if isinstance(desc, data.Structure):
_add_nested_datanames(name, desc)
Expand Down Expand Up @@ -222,6 +222,20 @@ def _add_nested_datanames(name: str, desc: data.Structure):
else:
assignments[aname] = aval

for edge in sdfg.out_edges(state):
for aname, aval in assignments.items():
# If the specific replacement would result in the value
# being both used and reassigned on the same inter-state
# edge, remove it from consideration.
replacements = symbolic.free_symbols_and_functions(aval)
used_in_assignments = {
k
for k, v in edge.data.assignments.items() if aname in symbolic.free_symbols_and_functions(v)
}
reassignments = replacements & edge.data.assignments.keys()
if reassignments and (used_in_assignments - reassignments):
assignments[aname] = _UnknownValue

if state not in result: # Condition may evaluate to False when state is the start-state
result[state] = {}
redo |= self._propagate(result[state], assignments)
Expand Down
37 changes: 22 additions & 15 deletions dace/transformation/passes/reference_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,21 +166,28 @@ def remove_refsets(
affected_nodes = set()
for e in state.in_edges_by_connector(node, 'set'):
# This is a reference set edge. Consider scope and neighbors and remove set
edges_to_remove.add(e)
affected_nodes.add(e.src)
affected_nodes.add(e.dst)

# If source node does not have any other neighbors, it can be removed
if all(ee is e or ee.data.is_empty() for ee in state.all_edges(e.src)):
nodes_to_remove.add(e.src)
# If set reference does not have any other neighbors, it can be removed
if all(ee is e or ee.data.is_empty() for ee in state.all_edges(node)):
nodes_to_remove.add(node)

# If in a scope, ensure reference node will not be disconnected
scope = state.entry_node(node)
if scope is not None and node not in nodes_to_remove:
edges_to_add.append((scope, None, node, None, Memlet()))
if state.out_degree(e.dst) == 0:
edges_to_remove.add(e)
affected_nodes.add(e.src)
affected_nodes.add(e.dst)

# If source node does not have any other neighbors, it can be removed
if all(ee is e or ee.data.is_empty() for ee in state.all_edges(e.src)):
nodes_to_remove.add(e.src)
# If set reference does not have any other neighbors, it can be removed
if all(ee is e or ee.data.is_empty() for ee in state.all_edges(node)):
nodes_to_remove.add(node)

# If in a scope, ensure reference node will not be disconnected
scope = state.entry_node(node)
if scope is not None and node not in nodes_to_remove:
edges_to_add.append((scope, None, node, None, Memlet()))
else: # Node has other neighbors, modify edge to become an empty memlet instead
e.dst_conn = None
e.dst.remove_in_connector('set')
e.data = Memlet()



# Modify the state graph as necessary
for e in edges_to_remove:
Expand Down
55 changes: 55 additions & 0 deletions tests/passes/constant_propagation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,59 @@ def test_dependency_change():
assert a[0] == ref


@pytest.mark.parametrize('extra_state', (False, True))
def test_dependency_change_same_edge(extra_state):
"""
Tests a regression in constant propagation that stems from a variable's
dependency being set in the same edge where the pre-propagated symbol was
also a right-hand side expression. In this case, ``i61`` is incorrectly
propagated to ``i60`` and ``i17`` is set to ``i61``, which is also updated
on the same inter-state edge.
"""

sdfg = dace.SDFG('tester')
sdfg.add_symbol('N', dace.int64)
sdfg.add_array('a', [1], dace.int64)
sdfg.add_scalar('cont', dace.int64, transient=True)
init = sdfg.add_state()
entry = sdfg.add_state('entry')
body = sdfg.add_state('body')
latch = sdfg.add_state('latch')
final = sdfg.add_state('final')

sdfg.add_edge(init, entry, dace.InterstateEdge(assignments=dict(i60='0')))
sdfg.add_edge(entry, body, dace.InterstateEdge(assignments=dict(i61='i60 + 1', i17='i60 * 12')))
sdfg.add_edge(body, final, dace.InterstateEdge('cont'))
sdfg.add_edge(body, latch, dace.InterstateEdge('not cont', dict(i60='i61')))
if not extra_state:
sdfg.add_edge(latch, body, dace.InterstateEdge(assignments=dict(i61='i60 + 1', i17='i60 * 12')))
else:
# Test that the multi-value definition is not propagated to following edges
extra = sdfg.add_state('extra')
sdfg.add_edge(latch, extra, dace.InterstateEdge(assignments=dict(i61='i60 + 1', i17='i60 * 12')))
sdfg.add_edge(extra, body, dace.InterstateEdge(assignments=dict(i18='i60 + i61')))

t = body.add_tasklet('add', {'inp'}, {'out', 'c'}, 'out = inp + i17; c = i61 == 10')
body.add_edge(body.add_read('a'), None, t, 'inp', dace.Memlet('a[0]'))
body.add_edge(t, 'out', body.add_write('a'), None, dace.Memlet('a[0]'))
body.add_edge(t, 'c', body.add_write('cont'), None, dace.Memlet('cont[0]'))

ConstantPropagation().apply_pass(sdfg, {})

sdfg.validate()

# Python code equivalent of the above SDFG
ref = 0
i60 = 0
for i60 in range(0, 10):
i17 = i60 * 12
ref += i17

a = np.zeros([1], np.int64)
sdfg(a=a)
assert a[0] == ref


if __name__ == '__main__':
test_simple_constants()
test_nested_constants()
Expand All @@ -592,3 +645,5 @@ def test_dependency_change():
test_for_with_external_init_nested_start_with_guard()
test_skip_branch()
test_dependency_change()
test_dependency_change_same_edge(False)
test_dependency_change_same_edge(True)
47 changes: 47 additions & 0 deletions tests/sdfg/reference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from dace.transformation.passes.reference_reduction import ReferenceToView
import numpy as np
import pytest
import networkx as nx


def test_unset_reference():
Expand Down Expand Up @@ -636,6 +637,51 @@ def test_ref2view_refset_in_scope(array_outside_scope, depends_on_iterate):
assert np.allclose(B, ref)


def test_ref2view_reconnection():
"""
Tests a regression in which ReferenceToView disconnects an existing weakly-connected state
and thus creating a race condition.
"""
sdfg = dace.SDFG('reftest')
sdfg.add_array('A', [2], dace.float64)
sdfg.add_array('B', [1], dace.float64)
sdfg.add_reference('ref', [1], dace.float64)

state = sdfg.add_state()
a2 = state.add_access('A')
ref = state.add_access('ref')
b = state.add_access('B')

t2 = state.add_tasklet('addone', {'inp'}, {'out'}, 'out = inp + 1')
state.add_edge(ref, None, t2, 'inp', dace.Memlet('ref[0]'))
state.add_edge(t2, 'out', b, None, dace.Memlet('B[0]'))
state.add_edge(a2, None, ref, 'set', dace.Memlet('A[1]'))

t1 = state.add_tasklet('addone', {'inp'}, {'out'}, 'out = inp + 1')
a1 = state.add_access('A')
state.add_edge(a1, None, t1, 'inp', dace.Memlet('A[1]'))
state.add_edge(t1, 'out', a2, None, dace.Memlet('A[1]'))

# Test correctness before pass
A = np.random.rand(2)
B = np.random.rand(1)
ref = (A[1] + 2)
sdfg(A=A, B=B)
assert np.allclose(B, ref)

# Test reference-to-view
result = Pipeline([ReferenceToView()]).apply_pass(sdfg, {})
assert result['ReferenceToView'] == {'ref'}

# Pass should not break order
assert len(list(nx.weakly_connected_components(state.nx))) == 1

# Test correctness after pass
ref = (A[1] + 2)
sdfg(A=A, B=B)
assert np.allclose(B, ref)


if __name__ == '__main__':
test_unset_reference()
test_reference_branch()
Expand All @@ -662,3 +708,4 @@ def test_ref2view_refset_in_scope(array_outside_scope, depends_on_iterate):
test_ref2view_refset_in_scope(False, True)
test_ref2view_refset_in_scope(True, False)
test_ref2view_refset_in_scope(True, True)
test_ref2view_reconnection()

0 comments on commit e6440a6

Please sign in to comment.