Skip to content

Commit

Permalink
Fix ref2view for connected reference sets
Browse files Browse the repository at this point in the history
  • Loading branch information
tbennun committed Oct 11, 2024
1 parent e7bb16f commit 581e4eb
Showing 1 changed file with 22 additions and 15 deletions.
37 changes: 22 additions & 15 deletions dace/transformation/passes/reference_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,21 +166,28 @@ def remove_refsets(
affected_nodes = set()
for e in state.in_edges_by_connector(node, 'set'):
# This is a reference set edge. Consider scope and neighbors and remove set
edges_to_remove.add(e)
affected_nodes.add(e.src)
affected_nodes.add(e.dst)

# If source node does not have any other neighbors, it can be removed
if all(ee is e or ee.data.is_empty() for ee in state.all_edges(e.src)):
nodes_to_remove.add(e.src)
# If set reference does not have any other neighbors, it can be removed
if all(ee is e or ee.data.is_empty() for ee in state.all_edges(node)):
nodes_to_remove.add(node)

# If in a scope, ensure reference node will not be disconnected
scope = state.entry_node(node)
if scope is not None and node not in nodes_to_remove:
edges_to_add.append((scope, None, node, None, Memlet()))
if state.out_degree(e.dst) == 0:
edges_to_remove.add(e)
affected_nodes.add(e.src)
affected_nodes.add(e.dst)

# If source node does not have any other neighbors, it can be removed
if all(ee is e or ee.data.is_empty() for ee in state.all_edges(e.src)):
nodes_to_remove.add(e.src)
# If set reference does not have any other neighbors, it can be removed
if all(ee is e or ee.data.is_empty() for ee in state.all_edges(node)):
nodes_to_remove.add(node)

# If in a scope, ensure reference node will not be disconnected
scope = state.entry_node(node)
if scope is not None and node not in nodes_to_remove:
edges_to_add.append((scope, None, node, None, Memlet()))
else: # Node has other neighbors, modify edge to become an empty memlet instead
e.dst_conn = None
e.dst.remove_in_connector('set')
e.data = Memlet()



# Modify the state graph as necessary
for e in edges_to_remove:
Expand Down

0 comments on commit 581e4eb

Please sign in to comment.