Skip to content

Commit

Permalink
Merge Canonicalize slice and useless slice rewrites
Browse files Browse the repository at this point in the history
  • Loading branch information
Dhruvanshu-Joshi committed May 29, 2024
1 parent a681a0e commit edc5ddd
Showing 1 changed file with 40 additions and 45 deletions.
85 changes: 40 additions & 45 deletions pytensor/tensor/rewriting/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,15 +336,21 @@ def local_subtensor_of_dot(fgraph, node):
@register_useless
@register_canonicalize
@register_specialize
@register_stabilize
@node_rewriter([Subtensor])
def local_useless_slice(fgraph, node):
"""
Remove Subtensor of the form:
1. X[0, :] -> X[0]
2. X[:] -> X
Also, rewrite Subtensor of the form:
X[0:7:1] -> X[None:None:None]
where X is a vector of length 7
"""
idxs = get_idx_list(node.inputs, node.op.idx_list)
x = node.inputs[0]

if not idxs:
return [node.inputs[0]]
Expand All @@ -364,74 +370,63 @@ def local_useless_slice(fgraph, node):
last_useless_slice -= 1
else:
break
# check if we removed something
if last_useless_slice < len(idxs):
new_idxs = idxs[:last_useless_slice]

if new_idxs:
new_subtensor = Subtensor(new_idxs)
new_subtensor_inputs = get_slice_elements(
new_idxs, lambda x: isinstance(x, Variable)
)
out = new_subtensor(node.inputs[0], *new_subtensor_inputs)
# Copy over previous output stacktrace
copy_stack_trace(node.outputs, out)
return [out]
else:
# Subtensor is not needed at all
return [node.inputs[0]]


@register_useless
@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([Subtensor])
def local_replace_slice(fgraph, node):
"""
Rewrite Subtensor of the form:
X[0:7:1] -> X[None:None:None]
where X is a vector of length 7
"""
idxs = get_idx_list(node.inputs, node.op.idx_list)
x = node.inputs[0]

if not idxs:
return

new_idxs = list(idxs)
idx_flag = False
new_idxs = list(idxs)[:last_useless_slice]
change_flag = False
for dim, s in enumerate(new_idxs):
if not isinstance(s, slice):
if not isinstance(s, slice) or s == slice(None):
continue

start = s.start
stop = s.stop
step = s.step
if extract_constant(start, only_process_constants=True) == 0:
idx_flag = True
if (
start is not None
and extract_constant(start, only_process_constants=True) == 0
):
change_flag = True
start = None

if (
x.type.shape[dim] is not None
stop is not None
and x.type.shape[dim] is not None
and extract_constant(stop, only_process_constants=True) == x.type.shape[dim]
):
idx_flag = True
change_flag = True
stop = None

if extract_constant(step, only_process_constants=True) == 1:
idx_flag = True
if (
step is not None
and extract_constant(step, only_process_constants=True) == 1
):
change_flag = True
step = None

new_idxs[dim] = slice(start, stop, step)

if idx_flag is True:
if change_flag is True or last_useless_slice < len(idxs):
out = x[tuple(new_idxs)]
# Copy over previous output stacktrace
copy_stack_trace(node.outputs, out)

return [out]
# elif last_useless_slice >= len(idxs):
# return [x]
# check if we removed something
# if last_useless_slice < len(idxs):
# new_idxs = idxs[:last_useless_slice]
# if new_idxs:
# new_subtensor = Subtensor(new_idxs)
# new_subtensor_inputs = get_slice_elements(
# new_idxs, lambda x: isinstance(x, Variable)
# )
# out = new_subtensor(node.inputs[0], *new_subtensor_inputs)
# # Copy over previous output stacktrace
# copy_stack_trace(node.outputs, out)
# return [out]
# else:
# # Subtensor is not needed at all
# return [node.inputs[0]]


# fast_compile to allow opt subtensor(cast{float32}(make_vector))
Expand Down

0 comments on commit edc5ddd

Please sign in to comment.