Skip to content

Commit

Permalink
fix[next][dace]: Bugfix for neighbors reduction with lift expressions (
Browse files Browse the repository at this point in the history
…#1599)

Addresses the issue reported in spcl/dace#1625

GT4Py was generating a wrong SDFG for icon4py stencil
`calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools`
  • Loading branch information
edopao authored Aug 2, 2024
1 parent 9d1e4e9 commit 5bf0488
Showing 1 changed file with 7 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,13 @@ def _visit_lift_in_neighbors_reduction(

if offset_provider.has_skip_values:
# check neighbor validity on if/else inter-state edge
start_state = lift_context.body.add_state("start", is_start_block=True)
# use one branch for connectivity case
start_state = lift_context.body.add_state_before(
lift_context.body.start_state,
"start",
condition=f"{lifted_index_connectors[0]} != {neighbor_skip_value}",
)
# use the other branch for skip value case
skip_neighbor_state = lift_context.body.add_state("skip_neighbor")
skip_neighbor_state.add_edge(
skip_neighbor_state.add_tasklet(
Expand All @@ -315,11 +321,6 @@ def _visit_lift_in_neighbors_reduction(
skip_neighbor_state,
dace.InterstateEdge(condition=f"{lifted_index_connectors[0]} == {neighbor_skip_value}"),
)
lift_context.body.add_edge(
start_state,
lift_context.state,
dace.InterstateEdge(condition=f"{lifted_index_connectors[0]} != {neighbor_skip_value}"),
)

return [ValueExpr(neighbor_value_node, inner_outputs[0].dtype)]

Expand Down

0 comments on commit 5bf0488

Please sign in to comment.