diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py index 4d9cc83bb5..552e43059d 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -336,6 +336,7 @@ def local_subtensor_of_dot(fgraph, node): @register_useless @register_canonicalize @register_specialize +@register_stabilize @node_rewriter([Subtensor]) def local_useless_slice(fgraph, node): """ @@ -343,8 +344,13 @@ def local_useless_slice(fgraph, node): 1. X[0, :] -> X[0] 2. X[:] -> X + Also, rewrite Subtensor of the form: + X[0:7:1] -> X[None:None:None] + where X is a vector of length 7 + """ idxs = get_idx_list(node.inputs, node.op.idx_list) + x = node.inputs[0] if not idxs: return [node.inputs[0]] @@ -364,74 +370,63 @@ def local_useless_slice(fgraph, node): last_useless_slice -= 1 else: break - # check if we removed something - if last_useless_slice < len(idxs): - new_idxs = idxs[:last_useless_slice] - - if new_idxs: - new_subtensor = Subtensor(new_idxs) - new_subtensor_inputs = get_slice_elements( - new_idxs, lambda x: isinstance(x, Variable) - ) - out = new_subtensor(node.inputs[0], *new_subtensor_inputs) - # Copy over previous output stacktrace - copy_stack_trace(node.outputs, out) - return [out] - else: - # Subtensor is not needed at all - return [node.inputs[0]] - -@register_useless -@register_canonicalize -@register_stabilize -@register_specialize -@node_rewriter([Subtensor]) -def local_replace_slice(fgraph, node): - """ - Rewrite Subtensor of the form: - X[0:7:1] -> X[None:None:None] - where X is a vector of length 7 - - """ - idxs = get_idx_list(node.inputs, node.op.idx_list) - x = node.inputs[0] - - if not idxs: - return - - new_idxs = list(idxs) - idx_flag = False + new_idxs = list(idxs)[:last_useless_slice] + change_flag = False for dim, s in enumerate(new_idxs): - if not isinstance(s, slice): + if not isinstance(s, slice) or s == slice(None): continue start = s.start stop = s.stop step = s.step - if extract_constant(start, only_process_constants=True) == 0: - idx_flag = True + if ( + start is not None + and extract_constant(start, only_process_constants=True) == 0 + ): + change_flag = True start = None if ( - x.type.shape[dim] is not None + stop is not None + and x.type.shape[dim] is not None and extract_constant(stop, only_process_constants=True) == x.type.shape[dim] ): - idx_flag = True + change_flag = True stop = None - if extract_constant(step, only_process_constants=True) == 1: - idx_flag = True + if ( + step is not None + and extract_constant(step, only_process_constants=True) == 1 + ): + change_flag = True step = None new_idxs[dim] = slice(start, stop, step) - if idx_flag is True: + if change_flag is True or last_useless_slice < len(idxs): out = x[tuple(new_idxs)] # Copy over previous output stacktrace copy_stack_trace(node.outputs, out) return [out] + # elif last_useless_slice >= len(idxs): + # return [x] + # check if we removed something + # if last_useless_slice < len(idxs): + # new_idxs = idxs[:last_useless_slice] + # if new_idxs: + # new_subtensor = Subtensor(new_idxs) + # new_subtensor_inputs = get_slice_elements( + # new_idxs, lambda x: isinstance(x, Variable) + # ) + # out = new_subtensor(node.inputs[0], *new_subtensor_inputs) + # # Copy over previous output stacktrace + # copy_stack_trace(node.outputs, out) + # return [out] + # else: + # # Subtensor is not needed at all + # return [node.inputs[0]] # fast_compile to allow opt subtensor(cast{float32}(make_vector))