Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race conditions in Constant Propagation and Reference-To-View #1679

Merged
merged 5 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions dace/codegen/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,8 @@ def dispatch_copy(self, src_node: nodes.Node, dst_node: nodes.Node, edge: MultiC
cfg: ControlFlowRegion, dfg: StateSubgraphView, state_id: int, function_stream: CodeIOStream,
output_stream: CodeIOStream) -> None:
""" Dispatches a code generator for a memory copy operation. """
if edge.data.is_empty():
return
state = cfg.state(state_id)
target = self.get_copy_dispatcher(src_node, dst_node, edge, sdfg, state)
if target is None:
Expand All @@ -616,6 +618,9 @@ def dispatch_output_definition(self, src_node: nodes.Node, dst_node: nodes.Node,
"""
state = cfg.state(state_id)
target = self.get_copy_dispatcher(src_node, dst_node, edge, sdfg, state)
if target is None:
raise ValueError(
f'Could not dispatch copy code generator for {src_node} -> {dst_node} in state {state.label}')

# Dispatch
self._used_targets.add(target)
Expand Down
16 changes: 15 additions & 1 deletion dace/transformation/passes/constant_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def _add_nested_datanames(name: str, desc: data.Structure):
# TODO: How are we handling this?
pass
arrays.add(f'{name}.{k}')

for name, desc in sdfg.arrays.items():
if isinstance(desc, data.Structure):
_add_nested_datanames(name, desc)
Expand Down Expand Up @@ -222,6 +222,20 @@ def _add_nested_datanames(name: str, desc: data.Structure):
else:
assignments[aname] = aval

for edge in sdfg.out_edges(state):
for aname, aval in assignments.items():
# If the specific replacement would result in the value
# being both used and reassigned on the same inter-state
# edge, remove it from consideration.
replacements = symbolic.free_symbols_and_functions(aval)
used_in_assignments = {
k
for k, v in edge.data.assignments.items() if aname in symbolic.free_symbols_and_functions(v)
}
reassignments = replacements & edge.data.assignments.keys()
if reassignments and (used_in_assignments - reassignments):
assignments[aname] = _UnknownValue

if state not in result: # Condition may evaluate to False when state is the start-state
result[state] = {}
redo |= self._propagate(result[state], assignments)
Expand Down
37 changes: 22 additions & 15 deletions dace/transformation/passes/reference_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,21 +166,28 @@ def remove_refsets(
affected_nodes = set()
for e in state.in_edges_by_connector(node, 'set'):
# This is a reference set edge. Consider scope and neighbors and remove set
edges_to_remove.add(e)
affected_nodes.add(e.src)
affected_nodes.add(e.dst)

# If source node does not have any other neighbors, it can be removed
if all(ee is e or ee.data.is_empty() for ee in state.all_edges(e.src)):
nodes_to_remove.add(e.src)
# If set reference does not have any other neighbors, it can be removed
if all(ee is e or ee.data.is_empty() for ee in state.all_edges(node)):
nodes_to_remove.add(node)

# If in a scope, ensure reference node will not be disconnected
scope = state.entry_node(node)
if scope is not None and node not in nodes_to_remove:
edges_to_add.append((scope, None, node, None, Memlet()))
if state.out_degree(e.dst) == 0:
edges_to_remove.add(e)
affected_nodes.add(e.src)
affected_nodes.add(e.dst)

# If source node does not have any other neighbors, it can be removed
if all(ee is e or ee.data.is_empty() for ee in state.all_edges(e.src)):
nodes_to_remove.add(e.src)
# If set reference does not have any other neighbors, it can be removed
if all(ee is e or ee.data.is_empty() for ee in state.all_edges(node)):
nodes_to_remove.add(node)

# If in a scope, ensure reference node will not be disconnected
scope = state.entry_node(node)
if scope is not None and node not in nodes_to_remove:
edges_to_add.append((scope, None, node, None, Memlet()))
else: # Node has other neighbors, modify edge to become an empty memlet instead
e.dst_conn = None
e.dst.remove_in_connector('set')
e.data = Memlet()



# Modify the state graph as necessary
for e in edges_to_remove:
Expand Down
55 changes: 55 additions & 0 deletions tests/passes/constant_propagation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,59 @@ def test_dependency_change():
assert a[0] == ref


@pytest.mark.parametrize('extra_state', (False, True))
def test_dependency_change_same_edge(extra_state):
"""
Tests a regression in constant propagation that stems from a variable's
dependency being set in the same edge where the pre-propagated symbol was
also a right-hand side expression. In this case, ``i61`` is incorrectly
propagated to ``i60`` and ``i17`` is set to ``i61``, which is also updated
on the same inter-state edge.
"""

sdfg = dace.SDFG('tester')
sdfg.add_symbol('N', dace.int64)
sdfg.add_array('a', [1], dace.int64)
sdfg.add_scalar('cont', dace.int64, transient=True)
init = sdfg.add_state()
entry = sdfg.add_state('entry')
body = sdfg.add_state('body')
latch = sdfg.add_state('latch')
final = sdfg.add_state('final')

sdfg.add_edge(init, entry, dace.InterstateEdge(assignments=dict(i60='0')))
sdfg.add_edge(entry, body, dace.InterstateEdge(assignments=dict(i61='i60 + 1', i17='i60 * 12')))
sdfg.add_edge(body, final, dace.InterstateEdge('cont'))
sdfg.add_edge(body, latch, dace.InterstateEdge('not cont', dict(i60='i61')))
if not extra_state:
sdfg.add_edge(latch, body, dace.InterstateEdge(assignments=dict(i61='i60 + 1', i17='i60 * 12')))
else:
# Test that the multi-value definition is not propagated to following edges
extra = sdfg.add_state('extra')
sdfg.add_edge(latch, extra, dace.InterstateEdge(assignments=dict(i61='i60 + 1', i17='i60 * 12')))
sdfg.add_edge(extra, body, dace.InterstateEdge(assignments=dict(i18='i60 + i61')))

t = body.add_tasklet('add', {'inp'}, {'out', 'c'}, 'out = inp + i17; c = i61 == 10')
body.add_edge(body.add_read('a'), None, t, 'inp', dace.Memlet('a[0]'))
body.add_edge(t, 'out', body.add_write('a'), None, dace.Memlet('a[0]'))
body.add_edge(t, 'c', body.add_write('cont'), None, dace.Memlet('cont[0]'))

ConstantPropagation().apply_pass(sdfg, {})

sdfg.validate()

# Python code equivalent of the above SDFG
ref = 0
i60 = 0
for i60 in range(0, 10):
i17 = i60 * 12
ref += i17

a = np.zeros([1], np.int64)
sdfg(a=a)
assert a[0] == ref


if __name__ == '__main__':
test_simple_constants()
test_nested_constants()
Expand All @@ -592,3 +645,5 @@ def test_dependency_change():
test_for_with_external_init_nested_start_with_guard()
test_skip_branch()
test_dependency_change()
test_dependency_change_same_edge(False)
test_dependency_change_same_edge(True)
47 changes: 47 additions & 0 deletions tests/sdfg/reference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from dace.transformation.passes.reference_reduction import ReferenceToView
import numpy as np
import pytest
import networkx as nx


def test_unset_reference():
Expand Down Expand Up @@ -636,6 +637,51 @@ def test_ref2view_refset_in_scope(array_outside_scope, depends_on_iterate):
assert np.allclose(B, ref)


def test_ref2view_reconnection():
"""
Tests a regression in which ReferenceToView disconnects an existing weakly-connected state
and thus creating a race condition.
"""
sdfg = dace.SDFG('reftest')
sdfg.add_array('A', [2], dace.float64)
sdfg.add_array('B', [1], dace.float64)
sdfg.add_reference('ref', [1], dace.float64)

state = sdfg.add_state()
a2 = state.add_access('A')
ref = state.add_access('ref')
b = state.add_access('B')

t2 = state.add_tasklet('addone', {'inp'}, {'out'}, 'out = inp + 1')
state.add_edge(ref, None, t2, 'inp', dace.Memlet('ref[0]'))
state.add_edge(t2, 'out', b, None, dace.Memlet('B[0]'))
state.add_edge(a2, None, ref, 'set', dace.Memlet('A[1]'))

t1 = state.add_tasklet('addone', {'inp'}, {'out'}, 'out = inp + 1')
a1 = state.add_access('A')
state.add_edge(a1, None, t1, 'inp', dace.Memlet('A[1]'))
state.add_edge(t1, 'out', a2, None, dace.Memlet('A[1]'))

# Test correctness before pass
A = np.random.rand(2)
B = np.random.rand(1)
ref = (A[1] + 2)
sdfg(A=A, B=B)
assert np.allclose(B, ref)

# Test reference-to-view
result = Pipeline([ReferenceToView()]).apply_pass(sdfg, {})
assert result['ReferenceToView'] == {'ref'}

# Pass should not break order
assert len(list(nx.weakly_connected_components(state.nx))) == 1

# Test correctness after pass
ref = (A[1] + 2)
sdfg(A=A, B=B)
assert np.allclose(B, ref)


if __name__ == '__main__':
test_unset_reference()
test_reference_branch()
Expand All @@ -662,3 +708,4 @@ def test_ref2view_refset_in_scope(array_outside_scope, depends_on_iterate):
test_ref2view_refset_in_scope(False, True)
test_ref2view_refset_in_scope(True, False)
test_ref2view_refset_in_scope(True, True)
test_ref2view_reconnection()
Loading