From 80d6640e0a2969bfdee9b1ff233f491179e37e9a Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Mon, 24 Jun 2024 02:33:55 -0400 Subject: [PATCH 01/12] Add a strict argument to all zips --- pytensor/compile/builders.py | 32 +++++--- pytensor/compile/debugmode.py | 16 ++-- pytensor/compile/function/pfunc.py | 6 +- pytensor/compile/function/types.py | 35 +++++---- pytensor/d3viz/formatting.py | 4 +- pytensor/gradient.py | 26 ++++--- pytensor/graph/basic.py | 18 ++--- pytensor/graph/op.py | 4 +- pytensor/graph/replace.py | 6 +- pytensor/graph/rewriting/basic.py | 40 ++++++---- pytensor/ifelse.py | 16 ++-- pytensor/link/basic.py | 26 ++++--- pytensor/link/c/basic.py | 26 ++++--- pytensor/link/c/cmodule.py | 2 +- pytensor/link/c/op.py | 2 +- pytensor/link/c/params_type.py | 2 +- pytensor/link/jax/dispatch/scan.py | 15 ++-- pytensor/link/jax/dispatch/shape.py | 2 +- pytensor/link/jax/dispatch/tensor_basic.py | 3 +- pytensor/link/jax/linker.py | 6 +- pytensor/link/numba/dispatch/basic.py | 4 +- .../link/numba/dispatch/cython_support.py | 7 +- pytensor/link/numba/dispatch/elemwise.py | 10 ++- pytensor/link/numba/dispatch/extra_ops.py | 2 +- pytensor/link/numba/dispatch/scalar.py | 4 +- pytensor/link/numba/dispatch/scan.py | 6 +- pytensor/link/numba/dispatch/slinalg.py | 2 +- pytensor/link/numba/dispatch/subtensor.py | 4 +- pytensor/link/numba/dispatch/tensor_basic.py | 8 +- .../link/numba/dispatch/vectorize_codegen.py | 35 +++++---- pytensor/link/utils.py | 8 +- pytensor/link/vm.py | 12 +-- pytensor/misc/check_blas.py | 2 +- pytensor/printing.py | 6 +- pytensor/scalar/basic.py | 20 ++--- pytensor/scalar/loop.py | 12 +-- pytensor/scan/basic.py | 8 +- pytensor/scan/op.py | 64 ++++++++++----- pytensor/scan/rewriting.py | 53 +++++++------ pytensor/scan/utils.py | 2 +- pytensor/sparse/basic.py | 4 +- pytensor/tensor/basic.py | 22 +++--- pytensor/tensor/blockwise.py | 54 ++++++++----- pytensor/tensor/conv/abstract_conv.py | 2 +- pytensor/tensor/elemwise.py | 78 ++++++++++++------- pytensor/tensor/elemwise_cgen.py | 30 +++++-- pytensor/tensor/extra_ops.py | 7 +- pytensor/tensor/functional.py | 6 +- pytensor/tensor/nlinalg.py | 4 +- pytensor/tensor/random/basic.py | 2 +- pytensor/tensor/random/op.py | 14 ++-- pytensor/tensor/random/rewriting/basic.py | 8 +- pytensor/tensor/random/utils.py | 22 ++++-- pytensor/tensor/rewriting/basic.py | 6 +- pytensor/tensor/rewriting/blas.py | 2 +- pytensor/tensor/rewriting/blockwise.py | 9 ++- pytensor/tensor/rewriting/elemwise.py | 14 ++-- pytensor/tensor/rewriting/math.py | 8 +- pytensor/tensor/rewriting/shape.py | 18 +++-- pytensor/tensor/rewriting/subtensor.py | 7 +- pytensor/tensor/shape.py | 12 ++- pytensor/tensor/slinalg.py | 2 +- pytensor/tensor/subtensor.py | 16 ++-- pytensor/tensor/type.py | 7 +- pytensor/tensor/utils.py | 2 +- pytensor/tensor/variable.py | 4 +- tests/compile/function/test_types.py | 8 +- tests/compile/test_builders.py | 2 +- tests/d3viz/test_formatting.py | 2 +- tests/graph/test_fg.py | 17 +++- tests/graph/utils.py | 4 +- tests/link/jax/test_basic.py | 2 +- tests/link/jax/test_random.py | 8 +- tests/link/numba/test_basic.py | 2 +- tests/link/numba/test_scan.py | 2 +- tests/link/pytorch/test_basic.py | 2 +- tests/link/test_link.py | 2 +- tests/scan/test_basic.py | 9 ++- tests/scan/test_printing.py | 16 ++-- tests/scan/test_utils.py | 16 ++-- tests/sparse/test_basic.py | 12 +-- tests/tensor/conv/test_abstract_conv.py | 53 +++++++++---- tests/tensor/random/rewriting/test_basic.py | 2 +- tests/tensor/random/test_utils.py | 2 +- 
tests/tensor/rewriting/test_elemwise.py | 12 +-- tests/tensor/rewriting/test_subtensor.py | 6 +- tests/tensor/test_basic.py | 35 +++++---- tests/tensor/test_blas.py | 4 +- tests/tensor/test_blockwise.py | 2 +- tests/tensor/test_casting.py | 1 + tests/tensor/test_elemwise.py | 3 + tests/tensor/test_extra_ops.py | 17 ++-- tests/tensor/test_nlinalg.py | 2 +- tests/tensor/test_subtensor.py | 16 ++-- tests/tensor/utils.py | 4 +- tests/test_gradient.py | 17 ++-- tests/test_ifelse.py | 4 +- tests/test_printing.py | 4 +- tests/typed_list/test_basic.py | 4 +- tests/unittest_tools.py | 4 +- 100 files changed, 746 insertions(+), 466 deletions(-) diff --git a/pytensor/compile/builders.py b/pytensor/compile/builders.py index 759c9b09bb..4379919f7a 100644 --- a/pytensor/compile/builders.py +++ b/pytensor/compile/builders.py @@ -43,7 +43,7 @@ def infer_shape(outs, inputs, input_shapes): # TODO: ShapeFeature should live elsewhere from pytensor.tensor.rewriting.shape import ShapeFeature - for inp, inp_shp in zip(inputs, input_shapes): + for inp, inp_shp in zip(inputs, input_shapes, strict=True): if inp_shp is not None and len(inp_shp) != inp.type.ndim: assert len(inp_shp) == inp.type.ndim @@ -51,7 +51,7 @@ def infer_shape(outs, inputs, input_shapes): shape_feature.on_attach(FunctionGraph([], [])) # Initialize shape_of with the input shapes - for inp, inp_shp in zip(inputs, input_shapes): + for inp, inp_shp in zip(inputs, input_shapes, strict=True): shape_feature.set_shape(inp, inp_shp) def local_traverse(out): @@ -108,7 +108,9 @@ def construct_nominal_fgraph( replacements = dict( zip( - inputs + implicit_shared_inputs, dummy_inputs + dummy_implicit_shared_inputs + inputs + implicit_shared_inputs, + dummy_inputs + dummy_implicit_shared_inputs, + strict=True, ) ) @@ -138,7 +140,7 @@ def construct_nominal_fgraph( NominalVariable(n, var.type) for n, var in enumerate(local_inputs) ) - fgraph.replace_all(zip(local_inputs, nominal_local_inputs)) + fgraph.replace_all(zip(local_inputs, nominal_local_inputs, strict=True)) for i, inp in enumerate(fgraph.inputs): nom_inp = nominal_local_inputs[i] @@ -557,7 +559,9 @@ def lop_overrides(inps, grads): # compute non-overriding downsteam grads from upstreams grads # it's normal some input may be disconnected, thus the 'ignore' wrt = [ - lin for lin, gov in zip(inner_inputs, custom_input_grads) if gov is None + lin + for lin, gov in zip(inner_inputs, custom_input_grads, strict=True) + if gov is None ] default_input_grads = fn_grad(wrt=wrt) if wrt else [] input_grads = self._combine_list_overrides( @@ -648,7 +652,7 @@ def _build_and_cache_rop_op(self): f = [ output for output, custom_output_grad in zip( - inner_outputs, custom_output_grads + inner_outputs, custom_output_grads, strict=True ) if custom_output_grad is None ] @@ -728,18 +732,24 @@ def make_node(self, *inputs): non_shared_inputs = [ inp_t.filter_variable(inp) - for inp, inp_t in zip(non_shared_inputs, self.input_types) + for inp, inp_t in zip(non_shared_inputs, self.input_types, strict=True) ] new_shared_inputs = inputs[num_expected_inps:] - inner_and_input_shareds = list(zip(self.shared_inputs, new_shared_inputs)) + inner_and_input_shareds = list( + zip(self.shared_inputs, new_shared_inputs, strict=True) + ) if not all(inp_s == inn_s for inn_s, inp_s in inner_and_input_shareds): # The shared variables are not equal to the original shared # variables, so we construct a new `Op` that uses the new shared # variables instead. 
replace = dict( - zip(self.inner_inputs[num_expected_inps:], new_shared_inputs) + zip( + self.inner_inputs[num_expected_inps:], + new_shared_inputs, + strict=True, + ) ) # If the new shared variables are inconsistent with the inner-graph, @@ -806,7 +816,7 @@ def infer_shape(self, fgraph, node, shapes): # each shape call. PyTensor optimizer will clean this up later, but this # will make extra work for the optimizer. - repl = dict(zip(self.inner_inputs, node.inputs)) + repl = dict(zip(self.inner_inputs, node.inputs, strict=True)) clone_out_shapes = [s for s in out_shapes if isinstance(s, tuple)] cloned = clone_replace(sum(clone_out_shapes, ()), replace=repl) ret = [] @@ -848,5 +858,5 @@ def clone(self): def perform(self, node, inputs, outputs): variables = self.fn(*inputs) assert len(variables) == len(outputs) - for output, variable in zip(outputs, variables): + for output, variable in zip(outputs, variables, strict=True): output[0] = variable diff --git a/pytensor/compile/debugmode.py b/pytensor/compile/debugmode.py index bfcaf1ecf0..cc1a5b225a 100644 --- a/pytensor/compile/debugmode.py +++ b/pytensor/compile/debugmode.py @@ -865,7 +865,7 @@ def _get_preallocated_maps( # except if broadcastable, or for dimensions above # config.DebugMode__check_preallocated_output_ndim buf_shape = [] - for s, b in zip(r_vals[r].shape, r.broadcastable): + for s, b in zip(r_vals[r].shape, r.broadcastable, strict=True): if b or ((r.ndim - len(buf_shape)) > check_ndim): buf_shape.append(s) else: @@ -943,7 +943,7 @@ def _get_preallocated_maps( r_shape_diff = shape_diff[: r.ndim] new_buf_shape = [ max((s + sd), 0) - for s, sd in zip(r_vals[r].shape, r_shape_diff) + for s, sd in zip(r_vals[r].shape, r_shape_diff, strict=True) ] new_buf = np.empty(new_buf_shape, dtype=r.type.dtype) new_buf[...] 
= np.asarray(def_val).astype(r.type.dtype) @@ -1575,7 +1575,7 @@ def f(): # try: # compute the value of all variables for i, (thunk_py, thunk_c, node) in enumerate( - zip(thunks_py, thunks_c, order) + zip(thunks_py, thunks_c, order, strict=True) ): _logger.debug(f"{i} - starting node {i} {node}") @@ -1855,7 +1855,7 @@ def thunk(): assert s[0] is None # store our output variables to their respective storage lists - for output, storage in zip(fgraph.outputs, output_storage): + for output, storage in zip(fgraph.outputs, output_storage, strict=True): storage[0] = r_vals[output] # transfer all inputs back to their respective storage lists @@ -1931,11 +1931,11 @@ def deco(): f, [ Container(input, storage, readonly=False) - for input, storage in zip(fgraph.inputs, input_storage) + for input, storage in zip(fgraph.inputs, input_storage, strict=True) ], [ Container(output, storage, readonly=True) - for output, storage in zip(fgraph.outputs, output_storage) + for output, storage in zip(fgraph.outputs, output_storage, strict=True) ], thunks_py, order, @@ -2122,7 +2122,9 @@ def __init__( no_borrow = [ output - for output, spec in zip(fgraph.outputs, outputs + additional_outputs) + for output, spec in zip( + fgraph.outputs, outputs + additional_outputs, strict=True + ) if not spec.borrow ] if no_borrow: diff --git a/pytensor/compile/function/pfunc.py b/pytensor/compile/function/pfunc.py index 49a6840719..935c77219a 100644 --- a/pytensor/compile/function/pfunc.py +++ b/pytensor/compile/function/pfunc.py @@ -603,7 +603,7 @@ def construct_pfunc_ins_and_outs( new_inputs = [] - for i, iv in zip(inputs, input_variables): + for i, iv in zip(inputs, input_variables, strict=True): new_i = copy(i) new_i.variable = iv @@ -637,13 +637,13 @@ def construct_pfunc_ins_and_outs( assert len(fgraph.inputs) == len(inputs) assert len(fgraph.outputs) == len(outputs) - for fg_inp, inp in zip(fgraph.inputs, inputs): + for fg_inp, inp in zip(fgraph.inputs, inputs, strict=True): if fg_inp != getattr(inp, "variable", inp): raise ValueError( f"`fgraph`'s input does not match the provided input: {fg_inp}, {inp}" ) - for fg_out, out in zip(fgraph.outputs, outputs): + for fg_out, out in zip(fgraph.outputs, outputs, strict=True): if fg_out != getattr(out, "variable", out): raise ValueError( f"`fgraph`'s output does not match the provided output: {fg_out}, {out}" diff --git a/pytensor/compile/function/types.py b/pytensor/compile/function/types.py index 43199328a3..e34bef35c6 100644 --- a/pytensor/compile/function/types.py +++ b/pytensor/compile/function/types.py @@ -241,7 +241,7 @@ def std_fgraph( fgraph.attach_feature( Supervisor( input - for spec, input in zip(input_specs, fgraph.inputs) + for spec, input in zip(input_specs, fgraph.inputs, strict=True) if not ( spec.mutable or (hasattr(fgraph, "destroyers") and fgraph.has_destroyers([input])) @@ -422,7 +422,7 @@ def distribute(indices, cs, value): # this loop works by modifying the elements (as variable c) of # self.input_storage inplace. for i, ((input, indices, sinputs), (required, refeed, value)) in enumerate( - zip(self.indices, defaults) + zip(self.indices, defaults, strict=True) ): if indices is None: # containers is being used as a stack. 
Here we pop off @@ -651,7 +651,7 @@ def checkSV(sv_ori, sv_rpl): else: outs = list(map(SymbolicOutput, fg_cpy.outputs)) - for out_ori, out_cpy in zip(maker.outputs, outs): + for out_ori, out_cpy in zip(maker.outputs, outs, strict=False): out_cpy.borrow = out_ori.borrow # swap SharedVariable @@ -664,7 +664,7 @@ def checkSV(sv_ori, sv_rpl): raise ValueError(f"SharedVariable: {sv.name} not found") # Swap SharedVariable in fgraph and In instances - for index, (i, in_v) in enumerate(zip(ins, fg_cpy.inputs)): + for index, (i, in_v) in enumerate(zip(ins, fg_cpy.inputs, strict=True)): # Variables in maker.inputs are defined by user, therefore we # use them to make comparison and do the mapping. # Otherwise we don't touch them. @@ -688,7 +688,7 @@ def checkSV(sv_ori, sv_rpl): # Delete update if needed rev_update_mapping = {v: k for k, v in fg_cpy.update_mapping.items()} - for n, (inp, in_var) in enumerate(zip(ins, fg_cpy.inputs)): + for n, (inp, in_var) in enumerate(zip(ins, fg_cpy.inputs, strict=True)): inp.variable = in_var if not delete_updates and inp.update is not None: out_idx = rev_update_mapping[n] @@ -748,7 +748,11 @@ def checkSV(sv_ori, sv_rpl): ).create(input_storage, storage_map=new_storage_map) for in_ori, in_cpy, ori, cpy in zip( - maker.inputs, f_cpy.maker.inputs, self.input_storage, f_cpy.input_storage + maker.inputs, + f_cpy.maker.inputs, + self.input_storage, + f_cpy.input_storage, + strict=True, ): # Share immutable ShareVariable and constant input's storage swapped = swap is not None and in_ori.variable in swap @@ -908,6 +912,7 @@ def restore_defaults(): self.input_storage[k].storage[0] for k in args_share_memory[j] ], + strict=True, ) if any( ( @@ -1000,7 +1005,7 @@ def restore_defaults(): if getattr(self.vm, "allow_gc", False): assert len(self.output_storage) == len(self.maker.fgraph.outputs) for o_container, o_variable in zip( - self.output_storage, self.maker.fgraph.outputs + self.output_storage, self.maker.fgraph.outputs, strict=True ): if o_variable.owner is not None: # this node is the variable of computation @@ -1012,7 +1017,7 @@ def restore_defaults(): if getattr(self.vm, "need_update_inputs", True): # Update the inputs that have an update function for input, storage in reversed( - list(zip(self.maker.expanded_inputs, self.input_storage)) + list(zip(self.maker.expanded_inputs, self.input_storage, strict=True)) ): if input.update is not None: storage.data = outputs.pop() @@ -1047,7 +1052,7 @@ def restore_defaults(): assert len(self.output_keys) == len(outputs) if output_subset is None: - return dict(zip(self.output_keys, outputs)) + return dict(zip(self.output_keys, outputs, strict=True)) else: return { self.output_keys[index]: outputs[index] @@ -1115,7 +1120,7 @@ def _pickle_Function(f): input_storage = [] for (input, indices, inputs), (required, refeed, default) in zip( - f.indices, f.defaults + f.indices, f.defaults, strict=True ): input_storage.append(ins[0]) del ins[0] @@ -1157,7 +1162,7 @@ def _constructor_Function(maker, input_storage, inputs_data, trust_input=False): f = maker.create(input_storage) assert len(f.input_storage) == len(inputs_data) - for container, x in zip(f.input_storage, inputs_data): + for container, x in zip(f.input_storage, inputs_data, strict=True): assert ( (container.data is x) or (isinstance(x, np.ndarray) and (container.data == x).all()) @@ -1191,7 +1196,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs): reason = "insert_deepcopy" updated_fgraph_inputs = { fgraph_i - for i, fgraph_i in zip(wrapped_inputs, fgraph.inputs) + 
for i, fgraph_i in zip(wrapped_inputs, fgraph.inputs, strict=True) if getattr(i, "update", False) } @@ -1528,7 +1533,9 @@ def __init__( # return the internal storage pointer. no_borrow = [ output - for output, spec in zip(fgraph.outputs, outputs + found_updates) + for output, spec in zip( + fgraph.outputs, outputs + found_updates, strict=True + ) if not spec.borrow ] @@ -1595,7 +1602,7 @@ def create(self, input_storage=None, storage_map=None): # defaults lists. assert len(self.indices) == len(input_storage) for i, ((input, indices, subinputs), input_storage_i) in enumerate( - zip(self.indices, input_storage) + zip(self.indices, input_storage, strict=True) ): # Replace any default value given as a variable by its # container. Note that this makes sense only in the diff --git a/pytensor/d3viz/formatting.py b/pytensor/d3viz/formatting.py index 80936a513d..b9fb8ee5a5 100644 --- a/pytensor/d3viz/formatting.py +++ b/pytensor/d3viz/formatting.py @@ -244,14 +244,14 @@ def format_map(m): ext_inputs = [self.__node_id(x) for x in node.inputs] int_inputs = [gf.__node_id(x) for x in node.op.inner_inputs] assert len(ext_inputs) == len(int_inputs) - h = format_map(zip(ext_inputs, int_inputs)) + h = format_map(zip(ext_inputs, int_inputs, strict=True)) pd_node.get_attributes()["subg_map_inputs"] = h # Outputs mapping ext_outputs = [self.__node_id(x) for x in node.outputs] int_outputs = [gf.__node_id(x) for x in node.op.inner_outputs] assert len(ext_outputs) == len(int_outputs) - h = format_map(zip(int_outputs, ext_outputs)) + h = format_map(zip(int_outputs, ext_outputs, strict=True)) pd_node.get_attributes()["subg_map_outputs"] = h return graph diff --git a/pytensor/gradient.py b/pytensor/gradient.py index abf80bff43..65301dbcca 100644 --- a/pytensor/gradient.py +++ b/pytensor/gradient.py @@ -213,7 +213,7 @@ def Rop( # Check that each element of wrt corresponds to an element # of eval_points with the same dimensionality. 
- for i, (wrt_elem, eval_point) in enumerate(zip(_wrt, _eval_points)): + for i, (wrt_elem, eval_point) in enumerate(zip(_wrt, _eval_points, strict=True)): try: if wrt_elem.type.ndim != eval_point.type.ndim: raise ValueError( @@ -262,7 +262,7 @@ def _traverse(node): seen_nodes[inp.owner][inp.owner.outputs.index(inp)] ) same_type_eval_points = [] - for x, y in zip(inputs, local_eval_points): + for x, y in zip(inputs, local_eval_points, strict=True): if y is not None: if not isinstance(x, Variable): x = pytensor.tensor.as_tensor_variable(x) @@ -399,7 +399,7 @@ def Lop( _wrt = [pytensor.tensor.as_tensor_variable(x) for x in wrt] assert len(_f) == len(grads) - known = dict(zip(_f, grads)) + known = dict(zip(_f, grads, strict=True)) ret = grad( cost=None, @@ -778,7 +778,7 @@ def subgraph_grad(wrt, end, start=None, cost=None, details=False): for i in range(len(grads)): grads[i] += cost_grads[i] - pgrads = dict(zip(params, grads)) + pgrads = dict(zip(params, grads, strict=True)) # separate wrt from end grads: wrt_grads = [pgrads[k] for k in wrt] end_grads = [pgrads[k] for k in end] @@ -1044,7 +1044,7 @@ def access_term_cache(node): any( input_to_output and output_to_cost for input_to_output, output_to_cost in zip( - input_to_outputs, outputs_connected + input_to_outputs, outputs_connected, strict=True ) ) ) @@ -1069,7 +1069,7 @@ def access_term_cache(node): not any( in_to_out and out_to_cost and not out_nan for in_to_out, out_to_cost, out_nan in zip( - in_to_outs, outputs_connected, ograd_is_nan + in_to_outs, outputs_connected, ograd_is_nan, strict=True ) ) ) @@ -1129,7 +1129,7 @@ def try_to_copy_if_needed(var): # DO NOT force integer variables to have integer dtype. # This is a violation of the op contract. new_output_grads = [] - for o, og in zip(node.outputs, output_grads): + for o, og in zip(node.outputs, output_grads, strict=True): o_dt = getattr(o.type, "dtype", None) og_dt = getattr(og.type, "dtype", None) if ( @@ -1143,7 +1143,7 @@ def try_to_copy_if_needed(var): # Make sure that, if new_output_grads[i] has a floating point # dtype, it is the same dtype as outputs[i] - for o, ng in zip(node.outputs, new_output_grads): + for o, ng in zip(node.outputs, new_output_grads, strict=True): o_dt = getattr(o.type, "dtype", None) ng_dt = getattr(ng.type, "dtype", None) if ( @@ -1165,7 +1165,9 @@ def try_to_copy_if_needed(var): # by the user, not computed by Op.grad, and some gradients are # only computed and returned, but never passed as another # node's output grads. 
- for idx, packed in enumerate(zip(node.outputs, new_output_grads)): + for idx, packed in enumerate( + zip(node.outputs, new_output_grads, strict=True) + ): orig_output, new_output_grad = packed if not hasattr(orig_output, "shape"): continue @@ -1231,7 +1233,7 @@ def try_to_copy_if_needed(var): not in [ in_to_out and out_to_cost and not out_int for in_to_out, out_to_cost, out_int in zip( - in_to_outs, outputs_connected, output_is_int + in_to_outs, outputs_connected, output_is_int, strict=True ) ] ) @@ -1312,7 +1314,7 @@ def try_to_copy_if_needed(var): # Check that op.connection_pattern matches the connectivity # logic driving the op.grad method for i, (ipt, ig, connected) in enumerate( - zip(inputs, input_grads, inputs_connected) + zip(inputs, input_grads, inputs_connected, strict=True) ): actually_connected = not isinstance(ig.type, DisconnectedType) @@ -1599,7 +1601,7 @@ def abs_rel_errors(self, g_pt): if len(g_pt) != len(self.gf): raise ValueError("argument has wrong number of elements", len(g_pt)) errs = [] - for i, (a, b) in enumerate(zip(g_pt, self.gf)): + for i, (a, b) in enumerate(zip(g_pt, self.gf, strict=True)): if a.shape != b.shape: raise ValueError( f"argument element {i} has wrong shapes {a.shape}, {b.shape}" diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index 2ffd101c23..6ddff2eeeb 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -272,7 +272,7 @@ def clone_with_new_inputs( # as the output type depends on the input values and not just their types output_type_depends_on_input_value = self.op._output_type_depends_on_input_value - for i, (curr, new) in enumerate(zip(self.inputs, new_inputs)): + for i, (curr, new) in enumerate(zip(self.inputs, new_inputs, strict=True)): # Check if the input type changed or if the Op has output types that depend on input values if (curr.type != new.type) or output_type_depends_on_input_value: # In strict mode, the cloned graph is assumed to be mathematically equivalent to the original one. 
@@ -1302,7 +1302,7 @@ def clone_node_and_cache( if new_node.op is not node.op: clone_d.setdefault(node.op, new_node.op) - for old_o, new_o in zip(node.outputs, new_node.outputs): + for old_o, new_o in zip(node.outputs, new_node.outputs, strict=True): clone_d.setdefault(old_o, new_o) return new_node @@ -1891,7 +1891,7 @@ def equal_computations( if in_ys is None: in_ys = [] - for x, y in zip(xs, ys): + for x, y in zip(xs, ys, strict=True): if not isinstance(x, Variable) and not isinstance(y, Variable): return np.array_equal(x, y) if not isinstance(x, Variable): @@ -1914,13 +1914,13 @@ def equal_computations( if len(in_xs) != len(in_ys): return False - for _x, _y in zip(in_xs, in_ys): + for _x, _y in zip(in_xs, in_ys, strict=True): if not (_y.type.in_same_class(_x.type)): return False - common = set(zip(in_xs, in_ys)) + common = set(zip(in_xs, in_ys, strict=True)) different: set[tuple[Variable, Variable]] = set() - for dx, dy in zip(xs, ys): + for dx, dy in zip(xs, ys, strict=True): assert isinstance(dx, Variable) # We checked above that both dx and dy have an owner or not if dx.owner is None: @@ -1956,7 +1956,7 @@ def compare_nodes(nd_x, nd_y, common, different): return False else: all_in_common = True - for dx, dy in zip(nd_x.outputs, nd_y.outputs): + for dx, dy in zip(nd_x.outputs, nd_y.outputs, strict=True): if (dx, dy) in different: return False if (dx, dy) not in common: @@ -1966,7 +1966,7 @@ def compare_nodes(nd_x, nd_y, common, different): return True # Compare the individual inputs for equality - for dx, dy in zip(nd_x.inputs, nd_y.inputs): + for dx, dy in zip(nd_x.inputs, nd_y.inputs, strict=True): if (dx, dy) not in common: # Equality between the variables is unknown, compare # their respective owners, if they have some @@ -2001,7 +2001,7 @@ def compare_nodes(nd_x, nd_y, common, different): # If the code reaches this statement then the inputs are pair-wise # equivalent so the outputs of the current nodes are also # pair-wise equivalents - for dx, dy in zip(nd_x.outputs, nd_y.outputs): + for dx, dy in zip(nd_x.outputs, nd_y.outputs, strict=True): common.add((dx, dy)) return True diff --git a/pytensor/graph/op.py b/pytensor/graph/op.py index 160a65dd7a..5f9f19b947 100644 --- a/pytensor/graph/op.py +++ b/pytensor/graph/op.py @@ -231,14 +231,14 @@ def make_node(self, *inputs: Variable) -> Apply: ) if not all( expected_type.is_super(var.type) - for var, expected_type in zip(inputs, self.itypes) + for var, expected_type in zip(inputs, self.itypes, strict=True) ): raise TypeError( f"Invalid input types for Op {self}:\n" + "\n".join( f"Input {i}/{len(inputs)}: Expected {inp}, got {out}" for i, (inp, out) in enumerate( - zip(self.itypes, (inp.type for inp in inputs)), + zip(self.itypes, (inp.type for inp in inputs), strict=True), start=1, ) if inp != out diff --git a/pytensor/graph/replace.py b/pytensor/graph/replace.py index 9b12192452..5092d55e6b 100644 --- a/pytensor/graph/replace.py +++ b/pytensor/graph/replace.py @@ -78,7 +78,7 @@ def clone_replace( items = list(_format_replace(replace).items()) tmp_replace = [(x, x.type()) for x, y in items] - new_replace = [(x, y) for ((_, x), (_, y)) in zip(tmp_replace, items)] + new_replace = [(x, y) for ((_, x), (_, y)) in zip(tmp_replace, items, strict=True)] _, _outs, _ = rebuild_collect_shared(output, [], tmp_replace, [], **rebuild_kwds) # TODO Explain why we call it twice ?! 
@@ -295,11 +295,11 @@ def vectorize_graph( inputs = truncated_graph_inputs(seq_outputs, ancestors_to_include=replace.keys()) new_inputs = [replace.get(inp, inp) for inp in inputs] - vect_vars = dict(zip(inputs, new_inputs)) + vect_vars = dict(zip(inputs, new_inputs, strict=True)) for node in io_toposort(inputs, seq_outputs): vect_inputs = [vect_vars.get(inp, inp) for inp in node.inputs] vect_node = vectorize_node(node, *vect_inputs) - for output, vect_output in zip(node.outputs, vect_node.outputs): + for output, vect_output in zip(node.outputs, vect_node.outputs, strict=True): if output in vect_vars: # This can happen when some outputs of a multi-output node are given a replacement, # while some of the remaining outputs are still needed in the graph. diff --git a/pytensor/graph/rewriting/basic.py b/pytensor/graph/rewriting/basic.py index 2bc0508f7d..faec736c98 100644 --- a/pytensor/graph/rewriting/basic.py +++ b/pytensor/graph/rewriting/basic.py @@ -399,14 +399,14 @@ def print_profile(cls, stream, prof, level=0): file=stream, ) ll = [] - for rewrite, nb_n in zip(rewrites, nb_nodes): + for rewrite, nb_n in zip(rewrites, nb_nodes, strict=True): if hasattr(rewrite, "__name__"): name = rewrite.__name__ else: name = rewrite.name idx = rewrites.index(rewrite) ll.append((name, rewrite.__class__.__name__, idx, *nb_n)) - lll = sorted(zip(prof, ll), key=lambda a: a[0]) + lll = sorted(zip(prof, ll, strict=True), key=lambda a: a[0]) for t, rewrite in lll[::-1]: i = rewrite[2] @@ -480,7 +480,8 @@ def merge_profile(prof1, prof2): new_rewrite = SequentialGraphRewriter(*new_l) new_nb_nodes = [ - (p1[0] + p2[0], p1[1] + p2[1]) for p1, p2 in zip(prof1[8], prof2[8]) + (p1[0] + p2[0], p1[1] + p2[1]) + for p1, p2 in zip(prof1[8], prof2[8], strict=True) ] new_nb_nodes.extend(prof1[8][len(new_nb_nodes) :]) new_nb_nodes.extend(prof2[8][len(new_nb_nodes) :]) @@ -635,7 +636,7 @@ def process_node(self, fgraph, node): inputs_match = all( node_in is cand_in - for node_in, cand_in in zip(node.inputs, candidate.inputs) + for node_in, cand_in in zip(node.inputs, candidate.inputs, strict=True) ) if inputs_match and node.op == candidate.op: @@ -649,6 +650,7 @@ def process_node(self, fgraph, node): node.outputs, candidate.outputs, ["merge"] * len(node.outputs), + strict=True, ) ) @@ -721,7 +723,9 @@ def apply(self, fgraph): inputs_match = all( node_in is cand_in for node_in, cand_in in zip( - var.owner.inputs, candidate_var.owner.inputs + var.owner.inputs, + candidate_var.owner.inputs, + strict=True, ) ) @@ -1434,7 +1438,7 @@ def transform(self, fgraph, node): repl = self.op2.make_node(*node.inputs) if self.transfer_tags: repl.tag = copy.copy(node.tag) - for output, new_output in zip(node.outputs, repl.outputs): + for output, new_output in zip(node.outputs, repl.outputs, strict=True): new_output.tag = copy.copy(output.tag) return repl.outputs @@ -1614,7 +1618,7 @@ def transform(self, fgraph, node, get_nodes=True): for real_node in self.get_nodes(fgraph, node): ret = self.transform(fgraph, real_node, get_nodes=False) if ret is not False and ret is not None: - return dict(zip(real_node.outputs, ret)) + return dict(zip(real_node.outputs, ret, strict=True)) if node.op != self.op: return False @@ -1646,7 +1650,7 @@ def transform(self, fgraph, node, get_nodes=True): len(node.outputs) == len(ret.owner.outputs) and all( o.type.is_super(new_o.type) - for o, new_o in zip(node.outputs, ret.owner.outputs) + for o, new_o in zip(node.outputs, ret.owner.outputs, strict=True) ) ): return False @@ -1935,7 +1939,7 @@ def process_node( ) # 
None in the replacement mean that this variable isn't used # and we want to remove it - for r, rnew in zip(old_vars, replacements): + for r, rnew in zip(old_vars, replacements, strict=True): if rnew is None and len(fgraph.clients[r]) > 0: raise ValueError( f"Node rewriter {node_rewriter} tried to remove a variable" @@ -1945,7 +1949,7 @@ def process_node( # the replacement repl_pairs = [ (r, rnew) - for r, rnew in zip(old_vars, replacements) + for r, rnew in zip(old_vars, replacements, strict=True) if rnew is not r and rnew is not None ] @@ -2628,17 +2632,23 @@ def print_profile(cls, stream, prof, level=0): print(blanc, "Global, final, and clean up rewriters", file=stream) for i in range(len(loop_timing)): print(blanc, f"Iter {int(i)}", file=stream) - for o, prof in zip(rewrite.global_rewriters, global_sub_profs[i]): + for o, prof in zip( + rewrite.global_rewriters, global_sub_profs[i], strict=True + ): try: o.print_profile(stream, prof, level + 2) except NotImplementedError: print(blanc, "merge not implemented for ", o) - for o, prof in zip(rewrite.final_rewriters, final_sub_profs[i]): + for o, prof in zip( + rewrite.final_rewriters, final_sub_profs[i], strict=True + ): try: o.print_profile(stream, prof, level + 2) except NotImplementedError: print(blanc, "merge not implemented for ", o) - for o, prof in zip(rewrite.cleanup_rewriters, cleanup_sub_profs[i]): + for o, prof in zip( + rewrite.cleanup_rewriters, cleanup_sub_profs[i], strict=True + ): try: o.print_profile(stream, prof, level + 2) except NotImplementedError: @@ -2856,7 +2866,7 @@ def local_recursive_function( outs, rewritten_vars = local_recursive_function( rewrite_list, inp, rewritten_vars, depth + 1 ) - for k, v in zip(inp.owner.outputs, outs): + for k, v in zip(inp.owner.outputs, outs, strict=True): rewritten_vars[k] = v nw_in = outs[inp.owner.outputs.index(inp)] @@ -2874,7 +2884,7 @@ def local_recursive_function( if ret is not False and ret is not None: assert isinstance(ret, Sequence) assert len(ret) == len(node.outputs), rewrite - for k, v in zip(node.outputs, ret): + for k, v in zip(node.outputs, ret, strict=True): rewritten_vars[k] = v results = ret if ret[0].owner: diff --git a/pytensor/ifelse.py b/pytensor/ifelse.py index b41b5f460d..6c85bbce8c 100644 --- a/pytensor/ifelse.py +++ b/pytensor/ifelse.py @@ -170,7 +170,9 @@ def make_node(self, condition: "TensorLike", *true_false_branches: Any): output_vars = [] new_inputs_true_branch = [] new_inputs_false_branch = [] - for input_t, input_f in zip(inputs_true_branch, inputs_false_branch): + for input_t, input_f in zip( + inputs_true_branch, inputs_false_branch, strict=True + ): if not isinstance(input_t, Variable): input_t = as_symbolic(input_t) if not isinstance(input_f, Variable): @@ -207,7 +209,9 @@ def make_node(self, condition: "TensorLike", *true_false_branches: Any): # allowed to have distinct shapes from either branch new_shape = tuple( s_t if s_t == s_f else None - for s_t, s_f in zip(input_t.type.shape, input_f.type.shape) + for s_t, s_f in zip( + input_t.type.shape, input_f.type.shape, strict=True + ) ) # TODO FIXME: The presence of this keyword is a strong # assumption. 
Find something that's guaranteed by the/a @@ -301,7 +305,7 @@ def thunk(): if len(ls) > 0: return ls else: - for out, t in zip(outputs, input_true_branch): + for out, t in zip(outputs, input_true_branch, strict=True): compute_map[out][0] = 1 val = storage_map[t][0] if self.as_view: @@ -321,7 +325,7 @@ def thunk(): if len(ls) > 0: return ls else: - for out, f in zip(outputs, inputs_false_branch): + for out, f in zip(outputs, inputs_false_branch, strict=True): compute_map[out][0] = 1 # can't view both outputs unless destroyhandler # improves @@ -637,7 +641,7 @@ def apply(self, fgraph): old_outs += [proposal.outputs] else: old_outs += proposal.outputs - pairs = list(zip(old_outs, new_outs)) + pairs = list(zip(old_outs, new_outs, strict=True)) fgraph.replace_all_validate(pairs, reason="cond_merge") @@ -736,7 +740,7 @@ def cond_merge_random_op(fgraph, main_node): old_outs += [proposal.outputs] else: old_outs += proposal.outputs - pairs = list(zip(old_outs, new_outs)) + pairs = list(zip(old_outs, new_outs, strict=True)) main_outs = clone_replace(main_node.outputs, replace=pairs) return main_outs diff --git a/pytensor/link/basic.py b/pytensor/link/basic.py index 30154a98ce..ea069c51cf 100644 --- a/pytensor/link/basic.py +++ b/pytensor/link/basic.py @@ -385,11 +385,11 @@ def make_all( f, [ Container(input, storage) - for input, storage in zip(fgraph.inputs, input_storage) + for input, storage in zip(fgraph.inputs, input_storage, strict=True) ], [ Container(output, storage, readonly=True) - for output, storage in zip(fgraph.outputs, output_storage) + for output, storage in zip(fgraph.outputs, output_storage, strict=True) ], thunks, order, @@ -509,7 +509,9 @@ def make_thunk(self, **kwargs): kwargs.pop("input_storage", None) make_all += [x.make_all(**kwargs) for x in self.linkers[1:]] - fns, input_lists, output_lists, thunk_lists, order_lists = zip(*make_all) + fns, input_lists, output_lists, thunk_lists, order_lists = zip( + *make_all, strict=True + ) order_list0 = order_lists[0] for order_list in order_lists[1:]: @@ -521,12 +523,12 @@ def make_thunk(self, **kwargs): inputs0 = input_lists[0] outputs0 = output_lists[0] - thunk_groups = list(zip(*thunk_lists)) - order = [x[0] for x in zip(*order_lists)] + thunk_groups = list(zip(*thunk_lists, strict=True)) + order = [x[0] for x in zip(*order_lists, strict=True)] to_reset = [ thunk.outputs[j] - for thunks, node in zip(thunk_groups, order) + for thunks, node in zip(thunk_groups, order, strict=True) for j, output in enumerate(node.outputs) if output in no_recycling for thunk in thunks @@ -537,12 +539,12 @@ def make_thunk(self, **kwargs): def f(): for inputs in input_lists[1:]: - for input1, input2 in zip(inputs0, inputs): + for input1, input2 in zip(inputs0, inputs, strict=True): input2.storage[0] = copy(input1.storage[0]) for x in to_reset: x[0] = None pre(self, [input.data for input in input_lists[0]], order, thunk_groups) - for i, (thunks, node) in enumerate(zip(thunk_groups, order)): + for i, (thunks, node) in enumerate(zip(thunk_groups, order, strict=True)): try: wrapper(self.fgraph, i, node, *thunks) except Exception: @@ -664,7 +666,9 @@ def thunk( ): outputs = fgraph_jit(*[self.input_filter(x[0]) for x in thunk_inputs]) - for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs): + for o_var, o_storage, o_val in zip( + fgraph.outputs, thunk_outputs, outputs, strict=True + ): compute_map[o_var][0] = True o_storage[0] = self.output_filter(o_var, o_val) return outputs @@ -730,11 +734,11 @@ def make_all(self, input_storage=None, 
output_storage=None, storage_map=None): fn, [ Container(input, storage) - for input, storage in zip(fgraph.inputs, input_storage) + for input, storage in zip(fgraph.inputs, input_storage, strict=True) ], [ Container(output, storage, readonly=True) - for output, storage in zip(fgraph.outputs, output_storage) + for output, storage in zip(fgraph.outputs, output_storage, strict=True) ], thunks, nodes, diff --git a/pytensor/link/c/basic.py b/pytensor/link/c/basic.py index 417580e09c..6fb4c8378e 100644 --- a/pytensor/link/c/basic.py +++ b/pytensor/link/c/basic.py @@ -1112,11 +1112,15 @@ def __compile__( module, [ Container(input, storage) - for input, storage in zip(self.fgraph.inputs, input_storage) + for input, storage in zip( + self.fgraph.inputs, input_storage, strict=True + ) ], [ Container(output, storage, readonly=True) - for output, storage in zip(self.fgraph.outputs, output_storage) + for output, storage in zip( + self.fgraph.outputs, output_storage, strict=True + ) ], error_storage, ) @@ -1887,11 +1891,11 @@ def make_all( f, [ Container(input, storage) - for input, storage in zip(fgraph.inputs, input_storage) + for input, storage in zip(fgraph.inputs, input_storage, strict=True) ], [ Container(output, storage, readonly=True) - for output, storage in zip(fgraph.outputs, output_storage) + for output, storage in zip(fgraph.outputs, output_storage, strict=True) ], thunks, order, @@ -1989,22 +1993,26 @@ def make_thunk(self, **kwargs): ) def f(): - for input1, input2 in zip(i1, i2): + for input1, input2 in zip(i1, i2, strict=True): # Set the inputs to be the same in both branches. # The copy is necessary in order for inplace ops not to # interfere. input2.storage[0] = copy(input1.storage[0]) - for thunk1, thunk2, node1, node2 in zip(thunks1, thunks2, order1, order2): - for output, storage in zip(node1.outputs, thunk1.outputs): + for thunk1, thunk2, node1, node2 in zip( + thunks1, thunks2, order1, order2, strict=True + ): + for output, storage in zip(node1.outputs, thunk1.outputs, strict=True): if output in no_recycling: storage[0] = None - for output, storage in zip(node2.outputs, thunk2.outputs): + for output, storage in zip(node2.outputs, thunk2.outputs, strict=True): if output in no_recycling: storage[0] = None try: thunk1() thunk2() - for output1, output2 in zip(thunk1.outputs, thunk2.outputs): + for output1, output2 in zip( + thunk1.outputs, thunk2.outputs, strict=True + ): self.checker(output1, output2) except Exception: raise_with_op(fgraph, node1) diff --git a/pytensor/link/c/cmodule.py b/pytensor/link/c/cmodule.py index d206c650e0..7416da1e24 100644 --- a/pytensor/link/c/cmodule.py +++ b/pytensor/link/c/cmodule.py @@ -2446,7 +2446,7 @@ def patch_ldflags(flag_list: list[str]) -> list[str]: if not libs: return flag_list libs = GCC_compiler.linking_patch(lib_dirs, libs) - for flag_idx, lib in zip(flag_idxs, libs): + for flag_idx, lib in zip(flag_idxs, libs, strict=True): flag_list[flag_idx] = lib return flag_list diff --git a/pytensor/link/c/op.py b/pytensor/link/c/op.py index 61c90d2b10..bc446556c0 100644 --- a/pytensor/link/c/op.py +++ b/pytensor/link/c/op.py @@ -59,7 +59,7 @@ def make_c_thunk( e = FunctionGraph(node.inputs, node.outputs) e_no_recycling = [ new_o - for (new_o, old_o) in zip(e.outputs, node.outputs) + for (new_o, old_o) in zip(e.outputs, node.outputs, strict=True) if old_o in no_recycling ] cl = pytensor.link.c.basic.CLinker().accept(e, no_recycling=e_no_recycling) diff --git a/pytensor/link/c/params_type.py b/pytensor/link/c/params_type.py index 
9b0d106d8d..7cd8d7f742 100644 --- a/pytensor/link/c/params_type.py +++ b/pytensor/link/c/params_type.py @@ -704,7 +704,7 @@ def c_support_code(self, **kwargs): c_init_list = [] c_cleanup_list = [] c_extract_list = [] - for attribute_name, type_instance in zip(self.fields, self.types): + for attribute_name, type_instance in zip(self.fields, self.types, strict=True): try: # c_support_code() may return a code string or a list of code strings. support_code = type_instance.c_support_code() diff --git a/pytensor/link/jax/dispatch/scan.py b/pytensor/link/jax/dispatch/scan.py index b82fd67e3f..d98328f0cf 100644 --- a/pytensor/link/jax/dispatch/scan.py +++ b/pytensor/link/jax/dispatch/scan.py @@ -30,7 +30,9 @@ def scan(*outer_inputs): seqs = op.outer_seqs(outer_inputs) # JAX `xs` mit_sot_init = [] - for tap, seq in zip(op.info.mit_sot_in_slices, op.outer_mitsot(outer_inputs)): + for tap, seq in zip( + op.info.mit_sot_in_slices, op.outer_mitsot(outer_inputs), strict=True + ): init_slice = seq[: abs(min(tap))] mit_sot_init.append(init_slice) @@ -61,7 +63,9 @@ def jax_args_to_inner_func_args(carry, x): inner_seqs = x mit_sot_flatten = [] - for array, index in zip(inner_mit_sot, op.info.mit_sot_in_slices): + for array, index in zip( + inner_mit_sot, op.info.mit_sot_in_slices, strict=True + ): mit_sot_flatten.extend(array[jnp.array(index)]) inner_scan_inputs = [ @@ -98,8 +102,7 @@ def inner_func_outs_to_jax_outs( inner_mit_sot_new = [ jnp.concatenate([old_mit_sot[1:], new_val[None, ...]], axis=0) for old_mit_sot, new_val in zip( - inner_mit_sot, - inner_mit_sot_outs, + inner_mit_sot, inner_mit_sot_outs, strict=True ) ] @@ -152,7 +155,9 @@ def get_partial_traces(traces): + op.outer_nitsot(outer_inputs) ) partial_traces = [] - for init_state, trace, buffer in zip(init_states, traces, buffers): + for init_state, trace, buffer in zip( + init_states, traces, buffers, strict=True + ): if init_state is not None: # MIT-SOT and SIT-SOT: The final output should be as long as the input buffer trace = jnp.atleast_1d(trace) diff --git a/pytensor/link/jax/dispatch/shape.py b/pytensor/link/jax/dispatch/shape.py index 6d75b7ae6f..6d809252a7 100644 --- a/pytensor/link/jax/dispatch/shape.py +++ b/pytensor/link/jax/dispatch/shape.py @@ -96,7 +96,7 @@ def shape_i(x): def jax_funcify_SpecifyShape(op, node, **kwargs): def specifyshape(x, *shape): assert x.ndim == len(shape) - for actual, expected in zip(x.shape, shape): + for actual, expected in zip(x.shape, shape, strict=True): if expected is None: continue if actual != expected: diff --git a/pytensor/link/jax/dispatch/tensor_basic.py b/pytensor/link/jax/dispatch/tensor_basic.py index bf1a93ce5b..9cd9870616 100644 --- a/pytensor/link/jax/dispatch/tensor_basic.py +++ b/pytensor/link/jax/dispatch/tensor_basic.py @@ -200,7 +200,8 @@ def jax_funcify_Tri(op, node, **kwargs): def tri(*args): # args is N, M, k args = [ - x if const_x is None else const_x for x, const_x in zip(args, const_args) + x if const_x is None else const_x + for x, const_x in zip(args, const_args, strict=True) ] return jnp.tri(*args, dtype=op.dtype) diff --git a/pytensor/link/jax/linker.py b/pytensor/link/jax/linker.py index 667806a80f..2450b24150 100644 --- a/pytensor/link/jax/linker.py +++ b/pytensor/link/jax/linker.py @@ -35,12 +35,14 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs): ] fgraph.replace_all( - zip(shared_rng_inputs, new_shared_rng_inputs), + zip(shared_rng_inputs, new_shared_rng_inputs, strict=True), import_missing=True, reason="JAXLinker.fgraph_convert", ) - for 
old_inp, new_inp in zip(shared_rng_inputs, new_shared_rng_inputs): + for old_inp, new_inp in zip( + shared_rng_inputs, new_shared_rng_inputs, strict=True + ): new_inp_storage = [new_inp.get_value(borrow=True)] storage_map[new_inp] = new_inp_storage old_inp_storage = storage_map.pop(old_inp) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 2b934d049c..f30cf2cc80 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -403,7 +403,7 @@ def py_perform_return(inputs): def py_perform_return(inputs): return tuple( out_type.filter(out[0]) - for out_type, out in zip(output_types, py_perform(inputs)) + for out_type, out in zip(output_types, py_perform(inputs), strict=True) ) @numba_njit @@ -566,7 +566,7 @@ def numba_funcify_SpecifyShape(op, node, **kwargs): func_conditions = [ f"assert x.shape[{i}] == {shape_input_names}" for i, (shape_input, shape_input_names) in enumerate( - zip(shape_inputs, shape_input_names) + zip(shape_inputs, shape_input_names, strict=True) ) if shape_input is not NoneConst ] diff --git a/pytensor/link/numba/dispatch/cython_support.py b/pytensor/link/numba/dispatch/cython_support.py index 36b3e80850..c62594ce94 100644 --- a/pytensor/link/numba/dispatch/cython_support.py +++ b/pytensor/link/numba/dispatch/cython_support.py @@ -45,7 +45,7 @@ def arg_numba_types(self) -> list[DTypeLike]: def can_cast_args(self, args: list[DTypeLike]) -> bool: ok = True count = 0 - for name, dtype in zip(self.arg_names, self.arg_dtypes): + for name, dtype in zip(self.arg_names, self.arg_dtypes, strict=True): if name == "__pyx_skip_dispatch": continue if len(args) <= count: @@ -164,7 +164,10 @@ def __wrapper_address__(self): return self._func_ptr def __call__(self, *args, **kwargs): - args = [dtype(arg) for arg, dtype in zip(args, self._signature.arg_dtypes)] + args = [ + dtype(arg) + for arg, dtype in zip(args, self._signature.arg_dtypes, strict=True) + ] if self.has_pyx_skip_dispatch(): output = self._pyfunc(*args[:-1], **kwargs) else: diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index b6f806bb4c..64768fd366 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -515,8 +515,10 @@ def elemwise(*inputs): inputs = [np.asarray(input) for input in inputs] inputs_bc = np.broadcast_arrays(*inputs) shape = inputs[0].shape - for input, bc in zip(inputs, input_bc_patterns): - for length, allow_bc, iter_length in zip(input.shape, bc, shape): + for input, bc in zip(inputs, input_bc_patterns, strict=True): + for length, allow_bc, iter_length in zip( + input.shape, bc, shape, strict=True + ): if length == 1 and shape and iter_length != 1 and not allow_bc: raise ValueError("Broadcast not allowed.") @@ -527,11 +529,11 @@ def elemwise(*inputs): outs = scalar_op_fn(*vals) if not isinstance(outs, tuple): outs = (outs,) - for out, out_val in zip(outputs, outs): + for out, out_val in zip(outputs, outs, strict=True): out[idx] = out_val outputs_summed = [] - for output, bc in zip(outputs, output_bc_patterns): + for output, bc in zip(outputs, output_bc_patterns, strict=True): axes = tuple(np.nonzero(bc)[0]) outputs_summed.append(output.sum(axes, keepdims=True)) if len(outputs_summed) != 1: diff --git a/pytensor/link/numba/dispatch/extra_ops.py b/pytensor/link/numba/dispatch/extra_ops.py index e2a4668242..b3123d561d 100644 --- a/pytensor/link/numba/dispatch/extra_ops.py +++ b/pytensor/link/numba/dispatch/extra_ops.py @@ -186,7 
+186,7 @@ def ravelmultiindex(*inp): new_arr = arr.T.astype(np.float64).copy() for i, b in enumerate(new_arr): - for j, (d, v) in enumerate(zip(shape, b)): + for j, (d, v) in enumerate(zip(shape, b, strict=True)): if v < 0 or v >= d: mode_fn(new_arr, i, j, v, d) diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py index f2c1bbc185..82ee380029 100644 --- a/pytensor/link/numba/dispatch/scalar.py +++ b/pytensor/link/numba/dispatch/scalar.py @@ -114,7 +114,9 @@ def {scalar_op_fn_name}({input_names}): input_names = [unique_names(v, force_unique=True) for v in node.inputs] converted_call_args = ", ".join( f"direct_cast({i_name}, {i_tmp_dtype_name})" - for i_name, i_tmp_dtype_name in zip(input_names, input_tmp_dtype_names) + for i_name, i_tmp_dtype_name in zip( + input_names, input_tmp_dtype_names, strict=False + ) ) if not has_pyx_skip_dispatch: scalar_op_src = f""" diff --git a/pytensor/link/numba/dispatch/scan.py b/pytensor/link/numba/dispatch/scan.py index 92566a7f78..cc75fc3742 100644 --- a/pytensor/link/numba/dispatch/scan.py +++ b/pytensor/link/numba/dispatch/scan.py @@ -163,10 +163,11 @@ def add_inner_in_expr( op.info.mit_mot_in_slices + op.info.mit_sot_in_slices + op.info.sit_sot_in_slices, + strict=True, ) ) inner_in_names_to_output_taps: dict[str, tuple[int, ...] | None] = dict( - zip(outer_in_mit_mot_names, op.info.mit_mot_out_slices) + zip(outer_in_mit_mot_names, op.info.mit_mot_out_slices, strict=True) ) # Inner-outputs consist of: @@ -373,7 +374,8 @@ def add_output_storage_post_proc_stmt( inner_out_post_processing_block = "\n".join(inner_out_post_processing_stmts) inner_out_to_outer_out_stmts = "\n".join( - f"{s} = {d}" for s, d in zip(inner_out_to_outer_in_stmts, inner_output_names) + f"{s} = {d}" + for s, d in zip(inner_out_to_outer_in_stmts, inner_output_names, strict=True) ) scan_op_src = f""" diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py index 1bf5a6c8fa..617bb12178 100644 --- a/pytensor/link/numba/dispatch/slinalg.py +++ b/pytensor/link/numba/dispatch/slinalg.py @@ -420,7 +420,7 @@ def block_diag(*arrs): out = np.zeros((out_shape[0], out_shape[1]), dtype=dtype) r, c = 0, 0 - for arr, shape in zip(arrs, shapes): + for arr, shape in zip(arrs, shapes, strict=True): rr, cc = shape out[r : r + rr, c : c + cc] = arr r += rr diff --git a/pytensor/link/numba/dispatch/subtensor.py b/pytensor/link/numba/dispatch/subtensor.py index 178ce0b857..40b0518664 100644 --- a/pytensor/link/numba/dispatch/subtensor.py +++ b/pytensor/link/numba/dispatch/subtensor.py @@ -158,7 +158,7 @@ def advancedincsubtensor1_inplace(x, val, idxs): def advancedincsubtensor1_inplace(x, vals, idxs): if not len(idxs) == len(vals): raise ValueError("The number of indices and values must match.") - for idx, val in zip(idxs, vals): + for idx, val in zip(idxs, vals, strict=True): x[idx] = val return x else: @@ -184,7 +184,7 @@ def advancedincsubtensor1_inplace(x, val, idxs): def advancedincsubtensor1_inplace(x, vals, idxs): if not len(idxs) == len(vals): raise ValueError("The number of indices and values must match.") - for idx, val in zip(idxs, vals): + for idx, val in zip(idxs, vals, strict=True): x[idx] += val return x diff --git a/pytensor/link/numba/dispatch/tensor_basic.py b/pytensor/link/numba/dispatch/tensor_basic.py index 09421adeb6..80b05d4e81 100644 --- a/pytensor/link/numba/dispatch/tensor_basic.py +++ b/pytensor/link/numba/dispatch/tensor_basic.py @@ -36,7 +36,9 @@ def numba_funcify_AllocEmpty(op, node, **kwargs): 
shapes_to_items_src = indent( "\n".join( f"{item_name} = to_scalar({shape_name})" - for item_name, shape_name in zip(shape_var_item_names, shape_var_names) + for item_name, shape_name in zip( + shape_var_item_names, shape_var_names, strict=True + ) ), " " * 4, ) @@ -68,7 +70,9 @@ def numba_funcify_Alloc(op, node, **kwargs): shapes_to_items_src = indent( "\n".join( f"{item_name} = to_scalar({shape_name})" - for item_name, shape_name in zip(shape_var_item_names, shape_var_names) + for item_name, shape_name in zip( + shape_var_item_names, shape_var_names, strict=True + ) ), " " * 4, ) diff --git a/pytensor/link/numba/dispatch/vectorize_codegen.py b/pytensor/link/numba/dispatch/vectorize_codegen.py index a680f9747d..74870e29bd 100644 --- a/pytensor/link/numba/dispatch/vectorize_codegen.py +++ b/pytensor/link/numba/dispatch/vectorize_codegen.py @@ -44,7 +44,7 @@ def store_core_outputs(i0, i1, ..., in, o0, o1, ..., on): inner_out_signature = ", ".join(inner_outputs) store_outputs = "\n".join( f"{output}[...] = {inner_output}" - for output, inner_output in zip(outputs, inner_outputs) + for output, inner_output in zip(outputs, inner_outputs, strict=True) ) func_src = f""" def store_core_outputs({inp_signature}, {out_signature}): @@ -137,7 +137,7 @@ def _vectorized( ) core_input_types = [] - for input_type, bc_pattern in zip(input_types, input_bc_patterns): + for input_type, bc_pattern in zip(input_types, input_bc_patterns, strict=True): core_ndim = input_type.ndim - len(bc_pattern) # TODO: Reconsider this if core_ndim == 0: @@ -150,14 +150,18 @@ def _vectorized( core_out_types = [ types.Array(numba.from_dtype(np.dtype(dtype)), len(output_core_shape), "C") - for dtype, output_core_shape in zip(output_dtypes, output_core_shape_types) + for dtype, output_core_shape in zip( + output_dtypes, output_core_shape_types, strict=True + ) ] out_types = [ types.Array( numba.from_dtype(np.dtype(dtype)), batch_ndim + len(output_core_shape), "C" ) - for dtype, output_core_shape in zip(output_dtypes, output_core_shape_types) + for dtype, output_core_shape in zip( + output_dtypes, output_core_shape_types, strict=True + ) ] for output_idx, input_idx in inplace_pattern: @@ -211,7 +215,7 @@ def codegen( inputs = [ arrayobj.make_array(ty)(ctx, builder, val) - for ty, val in zip(input_types, inputs) + for ty, val in zip(input_types, inputs, strict=True) ] in_shapes = [cgutils.unpack_tuple(builder, obj.shape) for obj in inputs] @@ -283,7 +287,9 @@ def compute_itershape( if size is not None: shape = size for i in range(batch_ndim): - for j, (bc, in_shape) in enumerate(zip(broadcast_pattern, in_shapes)): + for j, (bc, in_shape) in enumerate( + zip(broadcast_pattern, in_shapes, strict=True) + ): length = in_shape[i] if bc[i]: with builder.if_then( @@ -318,7 +324,9 @@ def compute_itershape( else: # Size is implied by the broadcast pattern for i in range(batch_ndim): - for j, (bc, in_shape) in enumerate(zip(broadcast_pattern, in_shapes)): + for j, (bc, in_shape) in enumerate( + zip(broadcast_pattern, in_shapes, strict=True) + ): length = in_shape[i] if bc[i]: with builder.if_then( @@ -374,7 +382,7 @@ def make_outputs( one = ir.IntType(64)(1) inplace_dict = dict(inplace) for i, (core_shape, bc, dtype) in enumerate( - zip(output_core_shapes, out_bc, dtypes) + zip(output_core_shapes, out_bc, dtypes, strict=True) ): if i in inplace_dict: output_arrays.append(inputs[inplace_dict[i]]) @@ -388,7 +396,8 @@ def make_outputs( # This is actually an internal numba function, I guess we could # call `numba.nd.unsafe.ndarray` instead? 
batch_shape = [ - length if not bc_dim else one for length, bc_dim in zip(iter_shape, bc) + length if not bc_dim else one + for length, bc_dim in zip(iter_shape, bc, strict=True) ] shape = batch_shape + core_shape array = arrayobj._empty_nd_impl(ctx, builder, arrtype, shape) @@ -458,10 +467,10 @@ def make_loop_call( # Load values from input arrays input_vals = [] - for input, input_type, bc in zip(inputs, input_types, input_bc): + for input, input_type, bc in zip(inputs, input_types, input_bc, strict=True): core_ndim = input_type.ndim - len(bc) - idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc)] + [ + idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc, strict=True)] + [ zero ] * core_ndim ptr = cgutils.get_item_pointer2( @@ -506,13 +515,13 @@ def make_loop_call( # Create output slices to pass to inner func output_slices = [] - for output, output_type, bc in zip(outputs, output_types, output_bc): + for output, output_type, bc in zip(outputs, output_types, output_bc, strict=True): core_ndim = output_type.ndim - len(bc) size_type = output.shape.type.element # type: ignore output_shape = cgutils.unpack_tuple(builder, output.shape) # type: ignore output_strides = cgutils.unpack_tuple(builder, output.strides) # type: ignore - idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc)] + [ + idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc, strict=True)] + [ zero ] * core_ndim ptr = cgutils.get_item_pointer2( diff --git a/pytensor/link/utils.py b/pytensor/link/utils.py index c51b13c427..7f48edcfb6 100644 --- a/pytensor/link/utils.py +++ b/pytensor/link/utils.py @@ -88,7 +88,7 @@ def map_storage( assert len(fgraph.inputs) == len(input_storage) # add input storage into storage_map - for r, storage in zip(fgraph.inputs, input_storage): + for r, storage in zip(fgraph.inputs, input_storage, strict=True): if r in storage_map: assert storage_map[r] is storage, ( "Given input_storage conflicts " @@ -108,7 +108,7 @@ def map_storage( # allocate output storage if output_storage is not None: assert len(fgraph.outputs) == len(output_storage) - for r, storage in zip(fgraph.outputs, output_storage): + for r, storage in zip(fgraph.outputs, output_storage, strict=True): if r in storage_map: assert storage_map[r] is storage, ( "Given output_storage confl" @@ -191,7 +191,7 @@ def streamline_default_f(): x[0] = None try: for thunk, node, old_storage in zip( - thunks, order, post_thunk_old_storage + thunks, order, post_thunk_old_storage, strict=True ): thunk() for old_s in old_storage: @@ -206,7 +206,7 @@ def streamline_nice_errors_f(): for x in no_recycling: x[0] = None try: - for thunk, node in zip(thunks, order): + for thunk, node in zip(thunks, order, strict=True): thunk() except Exception: raise_with_op(fgraph, node, thunk) diff --git a/pytensor/link/vm.py b/pytensor/link/vm.py index 587b379cf0..a9d625a8da 100644 --- a/pytensor/link/vm.py +++ b/pytensor/link/vm.py @@ -244,7 +244,7 @@ def clear_storage(self): def update_profile(self, profile): """Update a profile object.""" for node, thunk, t, c in zip( - self.nodes, self.thunks, self.call_times, self.call_counts + self.nodes, self.thunks, self.call_times, self.call_counts, strict=True ): profile.apply_time[(self.fgraph, node)] += t @@ -310,7 +310,9 @@ def __init__( self.output_storage = output_storage self.inp_storage_and_out_idx = tuple( (inp_storage, self.fgraph.outputs.index(update_vars[inp])) - for inp, inp_storage in zip(self.fgraph.inputs, self.input_storage) + for inp, inp_storage in zip( + self.fgraph.inputs, 
self.input_storage, strict=True + ) if inp in update_vars ) @@ -1241,7 +1243,7 @@ def make_all( self.profile.linker_node_make_thunks += t1 - t0 self.profile.linker_make_thunk_time = linker_make_thunk_time - for node, thunk in zip(order, thunks): + for node, thunk in zip(order, thunks, strict=True): thunk.inputs = [storage_map[v] for v in node.inputs] thunk.outputs = [storage_map[v] for v in node.outputs] @@ -1298,11 +1300,11 @@ def make_all( vm, [ Container(input, storage) - for input, storage in zip(fgraph.inputs, input_storage) + for input, storage in zip(fgraph.inputs, input_storage, strict=True) ], [ Container(output, storage, readonly=True) - for output, storage in zip(fgraph.outputs, output_storage) + for output, storage in zip(fgraph.outputs, output_storage, strict=True) ], thunks, order, diff --git a/pytensor/misc/check_blas.py b/pytensor/misc/check_blas.py index 8ee4482f0e..fc2fe02377 100644 --- a/pytensor/misc/check_blas.py +++ b/pytensor/misc/check_blas.py @@ -59,7 +59,7 @@ def execute(execute=True, verbose=True, M=2000, N=2000, K=2000, iters=10, order= if any(x.op.__class__.__name__ == "Gemm" for x in f.maker.fgraph.toposort()): c_impl = [ hasattr(thunk, "cthunk") - for node, thunk in zip(f.vm.nodes, f.vm.thunks) + for node, thunk in zip(f.vm.nodes, f.vm.thunks, strict=True) if node.op.__class__.__name__ == "Gemm" ] assert len(c_impl) == 1 diff --git a/pytensor/printing.py b/pytensor/printing.py index 5c8bb77752..a974ca21bc 100644 --- a/pytensor/printing.py +++ b/pytensor/printing.py @@ -311,7 +311,7 @@ def debugprint( ) for var, profile, storage_map, topo_order in zip( - outputs_to_print, profile_list, storage_maps, topo_orders + outputs_to_print, profile_list, storage_maps, topo_orders, strict=True ): if hasattr(var.owner, "op"): if ( @@ -941,7 +941,7 @@ def pp_process(input, new_precedence): str(i): x for i, x in enumerate( pp_process(input, precedence) - for input, precedence in zip(node.inputs, precedences) + for input, precedence in zip(node.inputs, precedences, strict=False) ) } r = pattern % d @@ -1448,7 +1448,7 @@ def apply_name(node): if isinstance(fct, Function): # TODO: Get rid of all this `expanded_inputs` nonsense and use # `fgraph.update_mapping` - function_inputs = zip(fct.maker.expanded_inputs, fgraph.inputs) + function_inputs = zip(fct.maker.expanded_inputs, fgraph.inputs, strict=True) for i, fg_ii in reversed(list(function_inputs)): if i.update is not None: k = outputs.pop() diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index d4c41d5cb5..7d0bd97914 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -1146,7 +1146,7 @@ def perform(self, node, inputs, output_storage): else: variables = from_return_values(self.impl(*inputs)) assert len(variables) == len(output_storage) - for storage, variable in zip(output_storage, variables): + for storage, variable in zip(output_storage, variables, strict=True): storage[0] = variable def impl(self, *inputs): @@ -4107,7 +4107,9 @@ def c_support_code(self, **kwargs): def c_support_code_apply(self, node, name): rval = [] - for subnode, subnodename in zip(self.fgraph.toposort(), self.nodenames): + for subnode, subnodename in zip( + self.fgraph.toposort(), self.nodenames, strict=True + ): subnode_support_code = subnode.op.c_support_code_apply( subnode, subnodename % dict(nodename=name) ) @@ -4213,7 +4215,7 @@ def __init__(self, inputs, outputs, name="Composite"): res2 = pytensor.compile.rebuild_collect_shared( inputs=outputs[0].owner.op.inputs, outputs=outputs[0].owner.op.outputs, - 
replace=dict(zip(outputs[0].owner.op.inputs, res[1])), + replace=dict(zip(outputs[0].owner.op.inputs, res[1], strict=True)), ) assert len(res2[1]) == len(outputs) assert len(res[0]) == len(inputs) @@ -4299,7 +4301,7 @@ def make_node(self, *inputs): assert len(inputs) == self.nin res = pytensor.compile.rebuild_collect_shared( self.outputs, - replace=dict(zip(self.inputs, inputs)), + replace=dict(zip(self.inputs, inputs, strict=True)), rebuild_strict=False, ) # After rebuild_collect_shared, the Variable in inputs @@ -4312,7 +4314,7 @@ def make_node(self, *inputs): def perform(self, node, inputs, output_storage): outputs = self.py_perform_fn(*inputs) - for storage, out_val in zip(output_storage, outputs): + for storage, out_val in zip(output_storage, outputs, strict=True): storage[0] = out_val def grad(self, inputs, output_grads): @@ -4382,8 +4384,8 @@ def c_code_template(self): def c_code(self, node, nodename, inames, onames, sub): d = dict( chain( - zip((f"i{int(i)}" for i in range(len(inames))), inames), - zip((f"o{int(i)}" for i in range(len(onames))), onames), + zip((f"i{int(i)}" for i in range(len(inames))), inames, strict=True), + zip((f"o{int(i)}" for i in range(len(onames))), onames, strict=True), ), **sub, ) @@ -4431,7 +4433,7 @@ def apply(self, fgraph): ) # make sure we don't produce any float16. assert not any(o.dtype == "float16" for o in new_node.outputs) - mapping.update(zip(node.outputs, new_node.outputs)) + mapping.update(zip(node.outputs, new_node.outputs, strict=True)) new_ins = [mapping[inp] for inp in fgraph.inputs] new_outs = [mapping[out] for out in fgraph.outputs] @@ -4474,7 +4476,7 @@ def handle_composite(node, mapping): new_op = node.op.clone_float32() new_outs = new_op(*[mapping[i] for i in node.inputs], return_list=True) assert len(new_outs) == len(node.outputs) - for o, no in zip(node.outputs, new_outs): + for o, no in zip(node.outputs, new_outs, strict=True): mapping[o] = no diff --git a/pytensor/scalar/loop.py b/pytensor/scalar/loop.py index 189cd461c7..4c76fa7140 100644 --- a/pytensor/scalar/loop.py +++ b/pytensor/scalar/loop.py @@ -93,7 +93,7 @@ def _validate_updates( ) else: update = outputs - for i, u in zip(init, update): + for i, u in zip(init, update, strict=False): if i.type != u.type: raise TypeError( "Init and update types must be the same: " @@ -166,7 +166,7 @@ def make_node(self, n_steps, *inputs): # Make a new op with the right input types. 
res = rebuild_collect_shared( self.outputs, - replace=dict(zip(self.inputs, inputs)), + replace=dict(zip(self.inputs, inputs, strict=True)), rebuild_strict=False, ) if self.is_while: @@ -207,7 +207,7 @@ def perform(self, node, inputs, output_storage): for i in range(n_steps): carry = inner_fn(*carry, *constant) - for storage, out_val in zip(output_storage, carry): + for storage, out_val in zip(output_storage, carry, strict=True): storage[0] = out_val @property @@ -295,7 +295,7 @@ def c_code_template(self): # Set the carry variables to the output variables _c_code += "\n" - for init, update in zip(carry_subd.values(), update_subd.values()): + for init, update in zip(carry_subd.values(), update_subd.values(), strict=True): _c_code += f"{init} = {update};\n" # _c_code += 'printf("%%ld\\n", i);\n' @@ -321,8 +321,8 @@ def c_code_template(self): def c_code(self, node, nodename, inames, onames, sub): d = dict( chain( - zip((f"i{int(i)}" for i in range(len(inames))), inames), - zip((f"o{int(i)}" for i in range(len(onames))), onames), + zip((f"i{int(i)}" for i in range(len(inames))), inames, strict=True), + zip((f"o{int(i)}" for i in range(len(onames))), onames, strict=True), ), **sub, ) diff --git a/pytensor/scan/basic.py b/pytensor/scan/basic.py index 0f7c9dcc69..e93587c5b6 100644 --- a/pytensor/scan/basic.py +++ b/pytensor/scan/basic.py @@ -884,7 +884,9 @@ def wrap_into_list(x): if condition is not None: outputs.append(condition) fake_nonseqs = [x.type() for x in non_seqs] - fake_outputs = clone_replace(outputs, replace=dict(zip(non_seqs, fake_nonseqs))) + fake_outputs = clone_replace( + outputs, replace=dict(zip(non_seqs, fake_nonseqs, strict=True)) + ) all_inputs = filter( lambda x: ( isinstance(x, Variable) @@ -1047,7 +1049,7 @@ def wrap_into_list(x): if not isinstance(arg, SharedVariable | Constant) ] - inner_replacements.update(dict(zip(other_scan_args, other_inner_args))) + inner_replacements.update(dict(zip(other_scan_args, other_inner_args, strict=True))) if strict: non_seqs_set = set(non_sequences if non_sequences is not None else []) @@ -1069,7 +1071,7 @@ def wrap_into_list(x): ] inner_replacements.update( - dict(zip(other_shared_scan_args, other_shared_inner_args)) + dict(zip(other_shared_scan_args, other_shared_inner_args, strict=True)) ) ## diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py index 4f6dc7e0be..3b80b04ec3 100644 --- a/pytensor/scan/op.py +++ b/pytensor/scan/op.py @@ -170,7 +170,7 @@ def check_broadcast(v1, v2): ) size = min(v1.type.ndim, v2.type.ndim) for n, (b1, b2) in enumerate( - zip(v1.type.broadcastable[-size:], v2.type.broadcastable[-size:]) + zip(v1.type.broadcastable[-size:], v2.type.broadcastable[-size:], strict=False) ): if b1 != b2: a1 = n + size - v1.type.ndim + 1 @@ -577,6 +577,7 @@ def get_oinp_iinp_iout_oout_mappings(self): inner_input_indices, inner_output_indices, outer_output_indices, + strict=True, ): if oout != -1: mappings["outer_inp_from_outer_out"][oout] = oinp @@ -958,7 +959,7 @@ def make_node(self, *inputs): # them have the same dtype argoffset = 0 for inner_seq, outer_seq in zip( - self.inner_seqs(self.inner_inputs), self.outer_seqs(inputs) + self.inner_seqs(self.inner_inputs), self.outer_seqs(inputs), strict=True ): check_broadcast(outer_seq, inner_seq) new_inputs.append(copy_var_format(outer_seq, as_var=inner_seq)) @@ -977,6 +978,7 @@ def make_node(self, *inputs): self.info.mit_mot_in_slices, self.info.mit_mot_out_slices[: self.info.n_mit_mot], self.outer_mitmot(inputs), + strict=True, ) ): outer_mitmot = copy_var_format(_outer_mitmot, 
as_var=inner_mitmot[ipos]) @@ -1031,6 +1033,7 @@ def make_node(self, *inputs): self.info.mit_sot_in_slices, self.outer_mitsot(inputs), self.inner_mitsot_outs(self.inner_outputs), + strict=True, ) ): outer_mitsot = copy_var_format(_outer_mitsot, as_var=inner_mitsots[ipos]) @@ -1083,6 +1086,7 @@ def make_node(self, *inputs): self.inner_sitsot(self.inner_inputs), self.outer_sitsot(inputs), self.inner_sitsot_outs(self.inner_outputs), + strict=True, ) ): outer_sitsot = copy_var_format(_outer_sitsot, as_var=inner_sitsot) @@ -1130,6 +1134,7 @@ def make_node(self, *inputs): self.inner_shared(self.inner_inputs), self.inner_shared_outs(self.inner_outputs), self.outer_shared(inputs), + strict=True, ) ): outer_shared = copy_var_format(_outer_shared, as_var=inner_shared) @@ -1188,7 +1193,9 @@ def make_node(self, *inputs): # type of tensor as the output, it is always a scalar int. new_inputs += [as_tensor_variable(ons) for ons in self.outer_nitsot(inputs)] for inner_nonseq, _outer_nonseq in zip( - self.inner_non_seqs(self.inner_inputs), self.outer_non_seqs(inputs) + self.inner_non_seqs(self.inner_inputs), + self.outer_non_seqs(inputs), + strict=True, ): outer_nonseq = copy_var_format(_outer_nonseq, as_var=inner_nonseq) new_inputs.append(outer_nonseq) @@ -1271,7 +1278,9 @@ def __eq__(self, other): if len(self.inner_outputs) != len(other.inner_outputs): return False - for self_in, other_in in zip(self.inner_inputs, other.inner_inputs): + for self_in, other_in in zip( + self.inner_inputs, other.inner_inputs, strict=True + ): if self_in.type != other_in.type: return False @@ -1406,7 +1415,7 @@ def prepare_fgraph(self, fgraph): fgraph.attach_feature( Supervisor( inp - for spec, inp in zip(wrapped_inputs, fgraph.inputs) + for spec, inp in zip(wrapped_inputs, fgraph.inputs, strict=True) if not ( getattr(spec, "mutable", None) or (hasattr(fgraph, "destroyers") and fgraph.has_destroyers([inp])) @@ -2086,7 +2095,9 @@ def perform(self, node, inputs, output_storage): jout = j + offset_out output_storage[j][0] = inner_output_storage[jout].storage[0] - pos = [(idx + 1) % store for idx, store in zip(pos, store_steps)] + pos = [ + (idx + 1) % store for idx, store in zip(pos, store_steps, strict=True) + ] i = i + 1 # 6. Check if you need to re-order output buffers @@ -2171,7 +2182,7 @@ def perform(self, node, inputs, output_storage): def infer_shape(self, fgraph, node, input_shapes): # input_shapes correspond to the shapes of node.inputs - for inp, inp_shp in zip(node.inputs, input_shapes): + for inp, inp_shp in zip(node.inputs, input_shapes, strict=True): assert inp_shp is None or len(inp_shp) == inp.type.ndim # Here we build 2 variables; @@ -2240,7 +2251,9 @@ def infer_shape(self, fgraph, node, input_shapes): # Non-sequences have a direct equivalent from self.inner_inputs in # node.inputs inner_non_sequences = self.inner_inputs[len(seqs_shape) + len(outs_shape) :] - out_equivalent.update(zip(inner_non_sequences, node.inputs[offset:])) + out_equivalent.update( + zip(inner_non_sequences, node.inputs[offset:], strict=True) + ) if info.as_while: self_outs = self.inner_outputs[:-1] @@ -2274,7 +2287,7 @@ def infer_shape(self, fgraph, node, input_shapes): r = node.outputs[n_outs + x] assert r.ndim == 1 + len(out_shape_x) shp = [node.inputs[offset + info.n_shared_outs + x]] - for i, shp_i in zip(range(1, r.ndim), out_shape_x): + for i, shp_i in zip(range(1, r.ndim), out_shape_x, strict=True): # Validate shp_i. v_shape_i is either None (if invalid), # or a (variable, Boolean) tuple. 
The Boolean indicates # whether variable is shp_i (if True), or an valid @@ -2296,7 +2309,7 @@ def infer_shape(self, fgraph, node, input_shapes): if info.as_while: scan_outs_init = scan_outs scan_outs = [] - for o, x in zip(node.outputs, scan_outs_init): + for o, x in zip(node.outputs, scan_outs_init, strict=True): if x is None: scan_outs.append(None) else: @@ -2572,7 +2585,9 @@ def compute_all_gradients(known_grads): dC_dinps_t[dx] = pt.zeros_like(diff_inputs[dx]) else: disconnected_dC_dinps_t[dx] = False - for Xt, Xt_placeholder in zip(diff_outputs[info.n_mit_mot_outs :], Xts): + for Xt, Xt_placeholder in zip( + diff_outputs[info.n_mit_mot_outs :], Xts, strict=True + ): tmp = forced_replace(dC_dinps_t[dx], Xt, Xt_placeholder) dC_dinps_t[dx] = tmp @@ -2652,7 +2667,9 @@ def compute_all_gradients(known_grads): n = n_steps.tag.test_value else: n = inputs[0].tag.test_value - for taps, x in zip(info.mit_sot_in_slices, self.outer_mitsot_outs(outs)): + for taps, x in zip( + info.mit_sot_in_slices, self.outer_mitsot_outs(outs), strict=True + ): mintap = np.min(taps) if hasattr(x[::-1][:mintap], "test_value"): assert x[::-1][:mintap].tag.test_value.shape[0] == n @@ -2667,7 +2684,9 @@ def compute_all_gradients(known_grads): assert x[::-1].tag.test_value.shape[0] == n outer_inp_seqs += [ x[::-1][: np.min(taps)] - for taps, x in zip(info.mit_sot_in_slices, self.outer_mitsot_outs(outs)) + for taps, x in zip( + info.mit_sot_in_slices, self.outer_mitsot_outs(outs), strict=True + ) ] outer_inp_seqs += [x[::-1][:-1] for x in self.outer_sitsot_outs(outs)] outer_inp_seqs += [x[::-1] for x in self.outer_nitsot_outs(outs)] @@ -2998,6 +3017,7 @@ def compute_all_gradients(known_grads): zip( outputs[offset : offset + info.n_seqs], type_outs[offset : offset + info.n_seqs], + strict=True, ) ): if t == "connected": @@ -3027,7 +3047,7 @@ def compute_all_gradients(known_grads): gradients.append(NullType(t)()) end = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot - for p, (x, t) in enumerate(zip(outputs[:end], type_outs[:end])): + for p, (x, t) in enumerate(zip(outputs[:end], type_outs[:end], strict=True)): if t == "connected": # If the forward scan is in as_while mode, we need to pad # the gradients, so that they match the size of the input @@ -3062,7 +3082,7 @@ def compute_all_gradients(known_grads): for idx in range(info.n_shared_outs): disconnected = True connected_flags = self.connection_pattern(node)[idx + start] - for dC_dout, connected in zip(dC_douts, connected_flags): + for dC_dout, connected in zip(dC_douts, connected_flags, strict=True): if not isinstance(dC_dout.type, DisconnectedType) and connected: disconnected = False if disconnected: @@ -3079,7 +3099,9 @@ def compute_all_gradients(known_grads): begin = end end = begin + n_sitsot_outs - for p, (x, t) in enumerate(zip(outputs[begin:end], type_outs[begin:end])): + for p, (x, t) in enumerate( + zip(outputs[begin:end], type_outs[begin:end], strict=True) + ): if t == "connected": gradients.append(x[-1]) elif t == "disconnected": @@ -3156,7 +3178,7 @@ def R_op(self, inputs, eval_points): e = 1 + info.n_seqs ie = info.n_seqs clean_eval_points = [] - for inp, evp in zip(inputs[b:e], eval_points[b:e]): + for inp, evp in zip(inputs[b:e], eval_points[b:e], strict=True): if evp is not None: clean_eval_points.append(evp) else: @@ -3171,7 +3193,7 @@ def R_op(self, inputs, eval_points): ib = ie ie = ie + int(sum(len(x) for x in info.mit_mot_in_slices)) clean_eval_points = [] - for inp, evp in zip(inputs[b:e], eval_points[b:e]): + for inp, evp in zip(inputs[b:e], 
eval_points[b:e], strict=True): if evp is not None: clean_eval_points.append(evp) else: @@ -3186,7 +3208,7 @@ def R_op(self, inputs, eval_points): ib = ie ie = ie + int(sum(len(x) for x in info.mit_sot_in_slices)) clean_eval_points = [] - for inp, evp in zip(inputs[b:e], eval_points[b:e]): + for inp, evp in zip(inputs[b:e], eval_points[b:e], strict=True): if evp is not None: clean_eval_points.append(evp) else: @@ -3201,7 +3223,7 @@ def R_op(self, inputs, eval_points): ib = ie ie = ie + info.n_sit_sot clean_eval_points = [] - for inp, evp in zip(inputs[b:e], eval_points[b:e]): + for inp, evp in zip(inputs[b:e], eval_points[b:e], strict=True): if evp is not None: clean_eval_points.append(evp) else: @@ -3225,7 +3247,7 @@ def R_op(self, inputs, eval_points): # All other arguments clean_eval_points = [] - for inp, evp in zip(inputs[e:], eval_points[e:]): + for inp, evp in zip(inputs[e:], eval_points[e:], strict=True): if evp is not None: clean_eval_points.append(evp) else: diff --git a/pytensor/scan/rewriting.py b/pytensor/scan/rewriting.py index c0a4b9b208..4e06864117 100644 --- a/pytensor/scan/rewriting.py +++ b/pytensor/scan/rewriting.py @@ -166,7 +166,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node): # Look through non sequences nw_inner_nonseq = [] nw_outer_nonseq = [] - for idx, (nw_in, nw_out) in enumerate(zip(non_seqs, outer_non_seqs)): + for idx, (nw_in, nw_out) in enumerate(zip(non_seqs, outer_non_seqs, strict=True)): if isinstance(nw_out, Constant): givens[nw_in] = nw_out elif nw_in in all_ins: @@ -203,7 +203,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node): allow_gc=op.allow_gc, ) nw_outs = nwScan(*nw_outer, return_list=True) - return dict([("remove", [node]), *zip(node.outputs, nw_outs)]) + return dict([("remove", [node]), *zip(node.outputs, nw_outs, strict=True)]) else: return False @@ -348,7 +348,7 @@ def add_to_replace(y): nw_outer = [] nw_inner = [] for to_repl, repl_in, repl_out in zip( - clean_to_replace, clean_replace_with_in, clean_replace_with_out + clean_to_replace, clean_replace_with_in, clean_replace_with_out, strict=True ): if isinstance(repl_out, Constant): repl_in = repl_out @@ -380,7 +380,7 @@ def add_to_replace(y): # Do not call make_node for test_value nw_node = nwScan(*(node.inputs + nw_outer), return_list=True)[0].owner - replacements = dict(zip(node.outputs, nw_node.outputs)) + replacements = dict(zip(node.outputs, nw_node.outputs, strict=True)) replacements["remove"] = [node] return replacements elif not to_keep_set: @@ -584,7 +584,7 @@ def add_to_replace(y): nw_outer = [] nw_inner = [] for to_repl, repl_in, repl_out in zip( - clean_to_replace, clean_replace_with_in, clean_replace_with_out + clean_to_replace, clean_replace_with_in, clean_replace_with_out, strict=True ): if isinstance(repl_out, Constant): repl_in = repl_out @@ -616,7 +616,7 @@ def add_to_replace(y): return_list=True, )[0].owner - replacements = dict(zip(node.outputs, nw_node.outputs)) + replacements = dict(zip(node.outputs, nw_node.outputs, strict=True)) replacements["remove"] = [node] return replacements @@ -814,7 +814,7 @@ def add_nitsot_outputs( # replacements["remove"] = [old_scan_node] # return new_scan_node, replacements fgraph.replace_all_validate_remove( # type: ignore - list(zip(old_scan_node.outputs, new_node_old_outputs)), + list(zip(old_scan_node.outputs, new_node_old_outputs, strict=True)), remove=[old_scan_node], reason="scan_pushout_add", ) @@ -1020,7 +1020,7 @@ def attempt_scan_inplace( # This whole rewrite should be a simple local rewrite, but, 
because # of this awful approach, it can't be. fgraph.replace_all_validate_remove( # type: ignore - list(zip(node.outputs, new_outs)), + list(zip(node.outputs, new_outs, strict=True)), remove=[node], reason="scan_make_inplace", ) @@ -1941,7 +1941,7 @@ def merge(self, nodes): if not isinstance(new_outs, list | tuple): new_outs = [new_outs] - return list(zip(outer_outs, new_outs)) + return list(zip(outer_outs, new_outs, strict=True)) def belongs_to_set(self, node, set_nodes): """ @@ -2010,7 +2010,9 @@ def belongs_to_set(self, node, set_nodes): ] inner_inputs = op.inner_inputs rep_inner_inputs = rep_op.inner_inputs - for nominal_input, rep_nominal_input in zip(nominal_inputs, rep_nominal_inputs): + for nominal_input, rep_nominal_input in zip( + nominal_inputs, rep_nominal_inputs, strict=True + ): conds.append(node.inputs[mapping[inner_inputs.index(nominal_input)]]) rep_conds.append( rep_node.inputs[rep_mapping[rep_inner_inputs.index(rep_nominal_input)]] @@ -2067,7 +2069,7 @@ def make_equiv(lo, li): seeno = {} left = [] right = [] - for o, i in zip(lo, li): + for o, i in zip(lo, li, strict=True): if o in seeno: left += [i] right += [o] @@ -2104,7 +2106,7 @@ def scan_merge_inouts(fgraph, node): if has_duplicates(a.outer_in_seqs): new_outer_seqs = [] new_inner_seqs = [] - for out_seq, in_seq in zip(a.outer_in_seqs, a.inner_in_seqs): + for out_seq, in_seq in zip(a.outer_in_seqs, a.inner_in_seqs, strict=True): if out_seq in new_outer_seqs: i = new_outer_seqs.index(out_seq) inp_equiv[in_seq] = new_inner_seqs[i] @@ -2117,7 +2119,9 @@ def scan_merge_inouts(fgraph, node): if has_duplicates(a.outer_in_non_seqs): new_outer_nseqs = [] new_inner_nseqs = [] - for out_nseq, in_nseq in zip(a.outer_in_non_seqs, a.inner_in_non_seqs): + for out_nseq, in_nseq in zip( + a.outer_in_non_seqs, a.inner_in_non_seqs, strict=True + ): if out_nseq in new_outer_nseqs: i = new_outer_nseqs.index(out_nseq) inp_equiv[in_nseq] = new_inner_nseqs[i] @@ -2180,7 +2184,7 @@ def scan_merge_inouts(fgraph, node): if has_duplicates(na.outer_in_mit_mot): seen = {} for omm, imm, _sl in zip( - na.outer_in_mit_mot, na.inner_in_mit_mot, na.mit_mot_in_slices + na.outer_in_mit_mot, na.inner_in_mit_mot, na.mit_mot_in_slices, strict=True ): sl = tuple(_sl) if (omm, sl) in seen: @@ -2193,7 +2197,7 @@ def scan_merge_inouts(fgraph, node): if has_duplicates(na.outer_in_mit_sot): seen = {} for oms, ims, _sl in zip( - na.outer_in_mit_sot, na.inner_in_mit_sot, na.mit_sot_in_slices + na.outer_in_mit_sot, na.inner_in_mit_sot, na.mit_sot_in_slices, strict=True ): sl = tuple(_sl) if (oms, sl) in seen: @@ -2227,7 +2231,7 @@ def map_out(outer_i, inner_o, outer_o, seen): na.outer_out_nit_sot = [ map_out(outer_i, inner_o, outer_o, seen) for outer_i, inner_o, outer_o in zip( - na.outer_in_nit_sot, na.inner_out_nit_sot, na.outer_out_nit_sot + na.outer_in_nit_sot, na.inner_out_nit_sot, na.outer_out_nit_sot, strict=True ) ] @@ -2237,7 +2241,7 @@ def map_out(outer_i, inner_o, outer_o, seen): na.outer_out_sit_sot = [ map_out(outer_i, inner_o, outer_o, seen) for outer_i, inner_o, outer_o in zip( - na.outer_in_sit_sot, na.inner_out_sit_sot, na.outer_out_sit_sot + na.outer_in_sit_sot, na.inner_out_sit_sot, na.outer_out_sit_sot, strict=True ) ] @@ -2247,7 +2251,7 @@ def map_out(outer_i, inner_o, outer_o, seen): na.outer_out_mit_sot = [ map_out(outer_i, inner_o, outer_o, seen) for outer_i, inner_o, outer_o in zip( - na.outer_in_mit_sot, na.inner_out_mit_sot, na.outer_out_mit_sot + na.outer_in_mit_sot, na.inner_out_mit_sot, na.outer_out_mit_sot, strict=True ) ] @@ 
-2261,6 +2265,7 @@ def map_out(outer_i, inner_o, outer_o, seen): na.inner_out_mit_mot, na.outer_out_mit_mot, na.mit_mot_out_slices, + strict=True, ): for s_outer_imm, s_inner_omm, s_outer_omm, sosl in seen: if ( @@ -2275,7 +2280,9 @@ def map_out(outer_i, inner_o, outer_o, seen): new_outer_out_mit_mot.append(outer_omm) na.outer_out_mit_mot = new_outer_out_mit_mot if remove: - return dict([("remove", remove), *zip(node.outputs, na.outer_outputs)]) + return dict( + [("remove", remove), *zip(node.outputs, na.outer_outputs, strict=True)] + ) return na.outer_outputs @@ -2300,7 +2307,7 @@ def scan_push_out_dot1(fgraph, node): sitsot_outs = op.inner_sitsot_outs(op.inner_outputs) outer_sitsot = op.outer_sitsot_outs(node.outputs) seqs = op.inner_seqs(op.inner_inputs) - for inp, out, outer_out in zip(sitsot_ins, sitsot_outs, outer_sitsot): + for inp, out, outer_out in zip(sitsot_ins, sitsot_outs, outer_sitsot, strict=True): if ( out.owner and isinstance(out.owner.op, Elemwise) @@ -2453,10 +2460,12 @@ def scan_push_out_dot1(fgraph, node): new_out = dot(val, out_seq) pos = node.outputs.index(outer_out) - old_new = list(zip(node.outputs[:pos], new_outs[:pos])) + old_new = list(zip(node.outputs[:pos], new_outs[:pos], strict=True)) old = fgraph.clients[node.outputs[pos]][0][0].outputs[0] old_new.append((old, new_out)) - old_new += list(zip(node.outputs[pos + 1 :], new_outs[pos:])) + old_new += list( + zip(node.outputs[pos + 1 :], new_outs[pos:], strict=True) + ) replacements = dict(old_new) replacements["remove"] = [node] return replacements diff --git a/pytensor/scan/utils.py b/pytensor/scan/utils.py index c55820eb68..611012b97e 100644 --- a/pytensor/scan/utils.py +++ b/pytensor/scan/utils.py @@ -559,7 +559,7 @@ def reconstruct_graph(inputs, outputs, tag=None): tag = "" nw_inputs = [safe_new(x, tag) for x in inputs] - givens = {x: nw_x for nw_x, x in zip(nw_inputs, inputs)} + givens = {x: nw_x for nw_x, x in zip(nw_inputs, inputs, strict=True)} nw_outputs = clone_replace(outputs, replace=givens) return (nw_inputs, nw_outputs) diff --git a/pytensor/sparse/basic.py b/pytensor/sparse/basic.py index a1f7fd5b13..ef980e943e 100644 --- a/pytensor/sparse/basic.py +++ b/pytensor/sparse/basic.py @@ -2849,7 +2849,7 @@ def choose(continuous, derivative): else: return None - return [choose(c, d) for c, d in zip(is_continuous, derivative)] + return [choose(c, d) for c, d in zip(is_continuous, derivative, strict=True)] def infer_shape(self, fgraph, node, ins_shapes): def _get(l): @@ -2928,7 +2928,7 @@ def choose(continuous, derivative): else: return None - return [choose(c, d) for c, d in zip(is_continuous, derivative)] + return [choose(c, d) for c, d in zip(is_continuous, derivative, strict=True)] def infer_shape(self, fgraph, node, ins_shapes): def _get(l): diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 119c44c647..920958fd59 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -1545,6 +1545,7 @@ def make_node(self, value, *shape): extended_value_broadcastable, extended_value_static_shape, static_shape, + strict=True, ) ): # If value is not broadcastable and we don't know the target static shape: use value static shape @@ -1565,7 +1566,7 @@ def make_node(self, value, *shape): def _check_runtime_broadcast(node, value, shape): value_static_shape = node.inputs[0].type.shape for v_static_dim, value_dim, out_dim in zip( - value_static_shape[::-1], value.shape[::-1], shape[::-1] + value_static_shape[::-1], value.shape[::-1], shape[::-1], strict=False ): if v_static_dim is None and 
value_dim == 1 and out_dim != 1: raise ValueError(Alloc._runtime_broadcast_error_msg) @@ -1668,6 +1669,7 @@ def grad(self, inputs, grads): inputs[0].type.shape, # We need the dimensions corresponding to x grads[0].type.shape[-inputs[0].ndim :], + strict=False, ) ): if ib == 1 and gb != 1: @@ -2187,7 +2189,7 @@ def grad(self, inputs, g_outputs): ] # Else, we have to make them zeros before joining them new_g_outputs = [] - for o, g in zip(outputs, g_outputs): + for o, g in zip(outputs, g_outputs, strict=True): if isinstance(g.type, DisconnectedType): new_g_outputs.append(o.zeros_like()) else: @@ -2624,7 +2626,7 @@ def grad(self, axis_and_tensors, grads): else specify_broadcastable( g, *(ax for (ax, s) in enumerate(t.type.shape) if s == 1) ) - for t, g in zip(tens, split_gz) + for t, g in zip(tens, split_gz, strict=True) ] rval = rval + split_gz else: @@ -2736,7 +2738,7 @@ def vectorize_join(op: Join, node, batch_axis, *batch_inputs): ): batch_ndims = { batch_input.type.ndim - old_input.type.ndim - for batch_input, old_input in zip(batch_inputs, old_inputs) + for batch_input, old_input in zip(batch_inputs, old_inputs, strict=True) } if len(batch_ndims) == 1: [batch_ndim] = batch_ndims @@ -3319,7 +3321,7 @@ def __getitem__(self, *args): tuple([1] * j + [r.shape[0]] + [1] * (ndim - 1 - j)) for j, r in enumerate(ranges) ] - ranges = [r.reshape(shape) for r, shape in zip(ranges, shapes)] + ranges = [r.reshape(shape) for r, shape in zip(ranges, shapes, strict=True)] if self.sparse: grids = ranges else: @@ -3391,7 +3393,7 @@ def make_node(self, x, y, inverse): out_shape = [ 1 if xb == 1 and yb == 1 else None - for xb, yb in zip(x.type.shape, y.type.shape) + for xb, yb in zip(x.type.shape, y.type.shape, strict=True) ] out_type = tensor(dtype=x.type.dtype, shape=out_shape) @@ -3456,7 +3458,7 @@ def perform(self, node, inp, out): # Make sure the output is big enough out_s = [] - for xdim, ydim in zip(x_s, y_s): + for xdim, ydim in zip(x_s, y_s, strict=True): if xdim == ydim: outdim = xdim elif xdim == 1: @@ -3516,7 +3518,7 @@ def grad(self, inp, grads): assert gx.type.ndim == x.type.ndim assert all( s1 == s2 - for s1, s2 in zip(gx.type.shape, x.type.shape) + for s1, s2 in zip(gx.type.shape, x.type.shape, strict=True) if s1 == 1 or s2 == 1 ) @@ -3957,7 +3959,7 @@ def moveaxis( order = [n for n in range(a.ndim) if n not in source] - for dest, src in sorted(zip(destination, source)): + for dest, src in sorted(zip(destination, source, strict=True)): order.insert(dest, src) result = a.dimshuffle(order) @@ -4307,7 +4309,7 @@ def _make_along_axis_idx(arr_shape, indices, axis): # build a fancy index, consisting of orthogonal aranges, with the # requested index inserted at the right location fancy_index = [] - for dim, n in zip(dest_dims, arr_shape): + for dim, n in zip(dest_dims, arr_shape, strict=True): if dim is None: fancy_index.append(indices) else: diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py index 08956a0534..e4ae81c35c 100644 --- a/pytensor/tensor/blockwise.py +++ b/pytensor/tensor/blockwise.py @@ -88,7 +88,7 @@ def __getstate__(self): def _create_dummy_core_node(self, inputs: Sequence[TensorVariable]) -> Apply: core_input_types = [] - for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)): + for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig, strict=True)): if inp.type.ndim < len(sig): raise ValueError( f"Input {i} {inp} has insufficient core dimensions for signature {self.signature}" @@ -106,7 +106,9 @@ def _create_dummy_core_node(self, inputs: 
Sequence[TensorVariable]) -> Apply: raise ValueError( f"Insufficient number of outputs for signature {self.signature}: {len(core_node.outputs)}" ) - for i, (core_out, sig) in enumerate(zip(core_node.outputs, self.outputs_sig)): + for i, (core_out, sig) in enumerate( + zip(core_node.outputs, self.outputs_sig, strict=True) + ): if core_out.type.ndim != len(sig): raise ValueError( f"Output {i} of {self.core_op} has wrong number of core dimensions for signature {self.signature}: {core_out.type.ndim}" @@ -120,12 +122,13 @@ def make_node(self, *inputs): core_node = self._create_dummy_core_node(inputs) batch_ndims = max( - inp.type.ndim - len(sig) for inp, sig in zip(inputs, self.inputs_sig) + inp.type.ndim - len(sig) + for inp, sig in zip(inputs, self.inputs_sig, strict=True) ) batched_inputs = [] batch_shapes = [] - for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)): + for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig, strict=True)): # Append missing dims to the left missing_batch_ndims = batch_ndims - (inp.type.ndim - len(sig)) if missing_batch_ndims: @@ -140,7 +143,7 @@ def make_node(self, *inputs): try: batch_shape = tuple( broadcast_static_dim_lengths(batch_dims) - for batch_dims in zip(*batch_shapes) + for batch_dims in zip(*batch_shapes, strict=True) ) except ValueError: raise ValueError( @@ -166,10 +169,10 @@ def infer_shape( batch_ndims = self.batch_ndim(node) core_dims: dict[str, Any] = {} batch_shapes = [input_shape[:batch_ndims] for input_shape in input_shapes] - for input_shape, sig in zip(input_shapes, self.inputs_sig): + for input_shape, sig in zip(input_shapes, self.inputs_sig, strict=True): core_shape = input_shape[batch_ndims:] - for core_dim, dim_name in zip(core_shape, sig): + for core_dim, dim_name in zip(core_shape, sig, strict=True): prev_core_dim = core_dims.get(core_dim) if prev_core_dim is None: core_dims[dim_name] = core_dim @@ -180,7 +183,7 @@ def infer_shape( batch_shape = broadcast_shape(*batch_shapes, arrays_are_shapes=True) out_shapes = [] - for output, sig in zip(node.outputs, self.outputs_sig): + for output, sig in zip(node.outputs, self.outputs_sig, strict=True): core_out_shape = [] for i, dim_name in enumerate(sig): # The output dim is the same as another input dim @@ -211,17 +214,17 @@ def as_core(t, core_t): with config.change_flags(compute_test_value="off"): safe_inputs = [ tensor(dtype=inp.type.dtype, shape=(None,) * len(sig)) - for inp, sig in zip(inputs, self.inputs_sig) + for inp, sig in zip(inputs, self.inputs_sig, strict=True) ] core_node = self._create_dummy_core_node(safe_inputs) core_inputs = [ as_core(inp, core_inp) - for inp, core_inp in zip(inputs, core_node.inputs) + for inp, core_inp in zip(inputs, core_node.inputs, strict=True) ] core_ograds = [ as_core(ograd, core_ograd) - for ograd, core_ograd in zip(ograds, core_node.outputs) + for ograd, core_ograd in zip(ograds, core_node.outputs, strict=True) ] core_outputs = core_node.outputs @@ -230,7 +233,11 @@ def as_core(t, core_t): igrads = vectorize_graph( [core_igrad for core_igrad in core_igrads if core_igrad is not None], replace=dict( - zip(core_inputs + core_outputs + core_ograds, inputs + outputs + ograds) + zip( + core_inputs + core_outputs + core_ograds, + inputs + outputs + ograds, + strict=True, + ) ), ) @@ -256,7 +263,7 @@ def L_op(self, inputs, outs, ograds): # the return value obviously zero so that gradient.grad can tell # this op did the right thing. 
new_rval = [] - for elem, inp in zip(rval, inputs): + for elem, inp in zip(rval, inputs, strict=True): if isinstance(elem.type, NullType | DisconnectedType): new_rval.append(elem) else: @@ -270,7 +277,7 @@ def L_op(self, inputs, outs, ograds): # Sum out the broadcasted dimensions batch_ndims = self.batch_ndim(outs[0].owner) batch_shape = outs[0].type.shape[:batch_ndims] - for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)): + for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig, strict=True)): if isinstance(rval[i].type, NullType | DisconnectedType): continue @@ -278,7 +285,9 @@ def L_op(self, inputs, outs, ograds): to_sum = [ j - for j, (inp_s, out_s) in enumerate(zip(inp.type.shape, batch_shape)) + for j, (inp_s, out_s) in enumerate( + zip(inp.type.shape, batch_shape, strict=False) + ) if inp_s == 1 and out_s != 1 ] if to_sum: @@ -318,9 +327,14 @@ def _check_runtime_broadcast(self, node, inputs): for dims_and_bcast in zip( *[ - zip(input.shape[:batch_ndim], sinput.type.broadcastable[:batch_ndim]) - for input, sinput in zip(inputs, node.inputs) - ] + zip( + input.shape[:batch_ndim], + sinput.type.broadcastable[:batch_ndim], + strict=True, + ) + for input, sinput in zip(inputs, node.inputs, strict=True) + ], + strict=True, ): if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast: raise ValueError( @@ -341,7 +355,9 @@ def perform(self, node, inputs, output_storage): if not isinstance(res, tuple): res = (res,) - for node_out, out_storage, r in zip(node.outputs, output_storage, res): + for node_out, out_storage, r in zip( + node.outputs, output_storage, res, strict=True + ): out_dtype = getattr(node_out, "dtype", None) if out_dtype and out_dtype != r.dtype: r = np.asarray(r, dtype=out_dtype) diff --git a/pytensor/tensor/conv/abstract_conv.py b/pytensor/tensor/conv/abstract_conv.py index 73d402cfca..0addd2b5f0 100644 --- a/pytensor/tensor/conv/abstract_conv.py +++ b/pytensor/tensor/conv/abstract_conv.py @@ -506,7 +506,7 @@ def check_dim(given, computed): return all( check_dim(given, computed) - for (given, computed) in zip(output_shape, computed_output_shape) + for (given, computed) in zip(output_shape, computed_output_shape, strict=True) ) diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index de966f1a78..ecb380b389 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -212,7 +212,7 @@ def make_node(self, _input): "The number of dimensions of the " f"input is incorrect for this op. Expected {self.input_broadcastable}, got {ib}." 
) - for expected, b in zip(self.input_broadcastable, ib): + for expected, b in zip(self.input_broadcastable, ib, strict=True): if expected and not b: raise TypeError( "The broadcastable pattern of the " @@ -446,7 +446,7 @@ def get_output_info(self, dim_shuffle, *inputs): out_shapes = [ [ broadcast_static_dim_lengths(shape) - for shape in zip(*[inp.type.shape for inp in inputs]) + for shape in zip(*[inp.type.shape for inp in inputs], strict=True) ] ] * shadow.nout except ValueError: @@ -459,8 +459,7 @@ def get_output_info(self, dim_shuffle, *inputs): if inplace_pattern: for overwriter, overwritten in inplace_pattern.items(): for out_s, in_s in zip( - out_shapes[overwriter], - inputs[overwritten].type.shape, + out_shapes[overwriter], inputs[overwritten].type.shape, strict=True ): if in_s == 1 and out_s != 1: raise ValueError( @@ -491,7 +490,7 @@ def make_node(self, *inputs): out_dtypes, out_shapes, inputs = self.get_output_info(DimShuffle, *inputs) outputs = [ TensorType(dtype=dtype, shape=shape)() - for dtype, shape in zip(out_dtypes, out_shapes) + for dtype, shape in zip(out_dtypes, out_shapes, strict=True) ] return Apply(self, inputs, outputs) @@ -513,7 +512,9 @@ def R_op(self, inputs, eval_points): bgrads = self._bgrad(inputs, outs, ograds) rop_out = None - for jdx, (inp, eval_point) in enumerate(zip(inputs, eval_points)): + for jdx, (inp, eval_point) in enumerate( + zip(inputs, eval_points, strict=True) + ): # if None, then we can just ignore this branch .. # what we do is to assume that for any non-differentiable # branch, the gradient is actually 0, which I think is not @@ -556,7 +557,7 @@ def L_op(self, inputs, outs, ograds): # the return value obviously zero so that gradient.grad can tell # this op did the right thing. new_rval = [] - for elem, ipt in zip(rval, inputs): + for elem, ipt in zip(rval, inputs, strict=True): if isinstance(elem.type, NullType | DisconnectedType): new_rval.append(elem) else: @@ -642,7 +643,7 @@ def transform(r): return new_r ret = [] - for scalar_igrad, ipt in zip(scalar_igrads, inputs): + for scalar_igrad, ipt in zip(scalar_igrads, inputs, strict=True): if scalar_igrad is None: # undefined gradient ret.append(None) @@ -765,7 +766,7 @@ def perform(self, node, inputs, output_storage): variables = [variables] for i, (variable, storage, nout) in enumerate( - zip(variables, output_storage, node.outputs) + zip(variables, output_storage, node.outputs, strict=True) ): if getattr(variable, "dtype", "") == "object": # Since numpy 1.6, function created with numpy.frompyfunc @@ -800,9 +801,10 @@ def perform(self, node, inputs, output_storage): def _check_runtime_broadcast(node, inputs): for dims_and_bcast in zip( *[ - zip(input.shape, sinput.type.broadcastable) - for input, sinput in zip(inputs, node.inputs) - ] + zip(input.shape, sinput.type.broadcastable, strict=False) + for input, sinput in zip(inputs, node.inputs, strict=True) + ], + strict=True, ): if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast: raise ValueError( @@ -831,9 +833,11 @@ def _c_all(self, node, nodename, inames, onames, sub): # assert that inames and inputs order stay consistent. # This is to protect again futur change of uniq. 
assert len(inames) == len(inputs) - ii, iii = list(zip(*uniq(list(zip(_inames, node.inputs))))) - assert all(x == y for x, y in zip(ii, inames)) - assert all(x == y for x, y in zip(iii, inputs)) + ii, iii = list( + zip(*uniq(list(zip(_inames, node.inputs, strict=True))), strict=True) + ) + assert all(x == y for x, y in zip(ii, inames, strict=True)) + assert all(x == y for x, y in zip(iii, inputs, strict=True)) defines = "" undefs = "" @@ -854,9 +858,10 @@ def _c_all(self, node, nodename, inames, onames, sub): zip( *[ (r, s, r.type.dtype_specs()[1]) - for r, s in zip(node.outputs, onames) + for r, s in zip(node.outputs, onames, strict=True) if r not in dmap - ] + ], + strict=True, ) ) if real: @@ -868,7 +873,14 @@ def _c_all(self, node, nodename, inames, onames, sub): # (output, name), transposed (c type name not needed since we don't # need to allocate. aliased = list( - zip(*[(r, s) for (r, s) in zip(node.outputs, onames) if r in dmap]) + zip( + *[ + (r, s) + for (r, s) in zip(node.outputs, onames, strict=True) + if r in dmap + ], + strict=True, + ) ) if aliased: aliased_outputs, aliased_onames = aliased @@ -886,7 +898,7 @@ def _c_all(self, node, nodename, inames, onames, sub): # dimensionality) nnested = len(orders[0]) sub = dict(sub) - for i, (input, iname) in enumerate(zip(inputs, inames)): + for i, (input, iname) in enumerate(zip(inputs, inames, strict=True)): # the c generators will substitute the input names for # references to loop variables lv0, lv1, ... sub[f"lv{i}"] = iname @@ -896,7 +908,7 @@ def _c_all(self, node, nodename, inames, onames, sub): # Check if all inputs (except broadcasted scalar) are fortran. # In that case, create a fortran output ndarray. - z = list(zip(inames, inputs)) + z = list(zip(inames, inputs, strict=True)) alloc_fortran = " && ".join( f"PyArray_ISFORTRAN({arr})" for arr, var in z @@ -911,7 +923,9 @@ def _c_all(self, node, nodename, inames, onames, sub): # We loop over the "real" outputs, i.e., those that are not # inplace (must be allocated) and we declare/allocate/check # them - for output, oname, odtype in zip(real_outputs, real_onames, real_odtypes): + for output, oname, odtype in zip( + real_outputs, real_onames, real_odtypes, strict=True + ): i += 1 # before this loop, i = number of inputs sub[f"lv{i}"] = oname sub["olv"] = oname @@ -928,7 +942,7 @@ def _c_all(self, node, nodename, inames, onames, sub): # inplace (overwrite the contents of one of the inputs) and # make the output pointers point to their corresponding input # pointers. 
- for output, oname in zip(aliased_outputs, aliased_onames): + for output, oname in zip(aliased_outputs, aliased_onames, strict=True): olv_index = inputs.index(dmap[output][0]) iname = inames[olv_index] # We make the output point to the corresponding input and @@ -989,12 +1003,16 @@ def _c_all(self, node, nodename, inames, onames, sub): task_decl = "".join( f"{dtype}& {name}_i = *{name}_iter;\n" for name, dtype in zip( - inames + list(real_onames), idtypes + list(real_odtypes) + inames + list(real_onames), + idtypes + list(real_odtypes), + strict=True, ) ) preloops = {} - for i, (loop_order, dtype) in enumerate(zip(loop_orders, dtypes)): + for i, (loop_order, dtype) in enumerate( + zip(loop_orders, dtypes, strict=True) + ): for j, index in enumerate(loop_order): if index != "x": preloops.setdefault(j, "") @@ -1066,7 +1084,9 @@ def _c_all(self, node, nodename, inames, onames, sub): # assume they will have the same size or all( len(set(inp_shape)) == 1 and None not in inp_shape - for inp_shape in zip(*(inp.type.shape for inp in node.inputs)) + for inp_shape in zip( + *(inp.type.shape for inp in node.inputs), strict=True + ) ) ): z = onames[0] @@ -1075,7 +1095,9 @@ def _c_all(self, node, nodename, inames, onames, sub): npy_intp n = PyArray_SIZE({z}); """ index = "" - for x, var in zip(inames + onames, inputs + node.outputs): + for x, var in zip( + inames + onames, inputs + node.outputs, strict=True + ): if not all(s == 1 for s in var.type.shape): contig += f""" dtype_{x} * {x}_ptr = (dtype_{x}*) PyArray_DATA({x}); @@ -1097,7 +1119,7 @@ def _c_all(self, node, nodename, inames, onames, sub): }} """ if contig is not None: - z = list(zip(inames + onames, inputs + node.outputs)) + z = list(zip(inames + onames, inputs + node.outputs, strict=True)) all_broadcastable = all(s == 1 for s in var.type.shape) cond1 = " && ".join( f"PyArray_ISCONTIGUOUS({arr})" @@ -1505,7 +1527,7 @@ def _c_all(self, node, name, inames, onames, sub): nnested = len(order1) sub = dict(sub) - for i, (input, iname) in enumerate(zip(node.inputs, inames)): + for i, (input, iname) in enumerate(zip(node.inputs, inames, strict=True)): sub[f"lv{i}"] = iname decl = "" diff --git a/pytensor/tensor/elemwise_cgen.py b/pytensor/tensor/elemwise_cgen.py index 3e37bf7d1a..e70bb936eb 100644 --- a/pytensor/tensor/elemwise_cgen.py +++ b/pytensor/tensor/elemwise_cgen.py @@ -7,7 +7,7 @@ def make_declare(loop_orders, dtypes, sub): """ decl = "" - for i, (loop_order, dtype) in enumerate(zip(loop_orders, dtypes)): + for i, (loop_order, dtype) in enumerate(zip(loop_orders, dtypes, strict=True)): var = sub[f"lv{int(i)}"] # input name corresponding to ith loop variable # we declare an iteration variable # and an integer for the number of dimensions @@ -38,7 +38,7 @@ def make_declare(loop_orders, dtypes, sub): def make_checks(loop_orders, dtypes, sub): init = "" - for i, (loop_order, dtype) in enumerate(zip(loop_orders, dtypes)): + for i, (loop_order, dtype) in enumerate(zip(loop_orders, dtypes, strict=True)): var = f"%(lv{int(i)})s" # List of dimensions of var that are not broadcasted nonx = [x for x in loop_order if x != "x"] @@ -92,7 +92,7 @@ def make_checks(loop_orders, dtypes, sub): "If broadcasting was intended, use `specify_broadcastable` on the relevant input." 
) - for matches in zip(*loop_orders): + for matches in zip(*loop_orders, strict=True): to_compare = [(j, x) for j, x in enumerate(matches) if x != "x"] # elements of to_compare are pairs ( input_variable_idx, input_variable_dim_idx ) @@ -140,7 +140,7 @@ def compute_output_dims_lengths(array_name: str, loop_orders, sub) -> str: Note: We could specialize C code even further with the known static output shapes """ dims_c_code = "" - for i, candidates in enumerate(zip(*loop_orders)): + for i, candidates in enumerate(zip(*loop_orders, strict=True)): # Borrow the length of the first non-broadcastable input dimension for j, candidate in enumerate(candidates): if candidate != "x": @@ -260,7 +260,7 @@ def loop_over(preloop, code, indices, i): """ preloops = {} - for i, (loop_order, dtype) in enumerate(zip(loop_orders, dtypes)): + for i, (loop_order, dtype) in enumerate(zip(loop_orders, dtypes, strict=True)): for j, index in enumerate(loop_order): if index != "x": preloops.setdefault(j, "") @@ -277,7 +277,14 @@ def loop_over(preloop, code, indices, i): s = "" for i, (pre_task, task), indices in reversed( - list(zip(range(len(loop_tasks) - 1), loop_tasks, list(zip(*loop_orders)))) + list( + zip( + range(len(loop_tasks) - 1), + loop_tasks, + list(zip(*loop_orders, strict=True)), + strict=False, + ) + ) ): s = loop_over(preloops.get(i, "") + pre_task, s + task, indices, i) @@ -524,7 +531,7 @@ def loop_over(preloop, code, indices, i): """ preloops = {} - for i, (loop_order, dtype) in enumerate(zip(loop_orders, dtypes)): + for i, (loop_order, dtype) in enumerate(zip(loop_orders, dtypes, strict=True)): for j, index in enumerate(loop_order): if index != "x": preloops.setdefault(j, "") @@ -543,7 +550,14 @@ def loop_over(preloop, code, indices, i): else: s = "" for i, (pre_task, task), indices in reversed( - list(zip(range(len(loop_tasks) - 1), loop_tasks, list(zip(*loop_orders)))) + list( + zip( + range(len(loop_tasks) - 1), + loop_tasks, + list(zip(*loop_orders, strict=True)), + strict=False, + ) + ) ): s = loop_over(preloops.get(i, "") + pre_task, s + task, indices, i) diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index cf809a55ef..d7b071b59b 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -1528,13 +1528,16 @@ def broadcast_shape_iter( array_shapes = [ (one,) * (max_dims - a.ndim) - + tuple(one if t_sh == 1 else sh for sh, t_sh in zip(a.shape, a.type.shape)) + + tuple( + one if t_sh == 1 else sh + for sh, t_sh in zip(a.shape, a.type.shape, strict=True) + ) for a in _arrays ] result_dims = [] - for dim_shapes in zip(*array_shapes): + for dim_shapes in zip(*array_shapes, strict=True): # Get the shapes in this dimension that are not broadcastable # (i.e. 
not symbolically known to be broadcastable) non_bcast_shapes = [shape for shape in dim_shapes if shape != one] diff --git a/pytensor/tensor/functional.py b/pytensor/tensor/functional.py index e7a5371b02..34ba436b4b 100644 --- a/pytensor/tensor/functional.py +++ b/pytensor/tensor/functional.py @@ -85,7 +85,7 @@ def inner(*inputs): # Create dummy core inputs by stripping the batched dimensions of inputs core_inputs = [] - for input, input_sig in zip(inputs, inputs_sig): + for input, input_sig in zip(inputs, inputs_sig, strict=True): if not isinstance(input, TensorVariable): raise TypeError( f"Inputs to vectorize function must be TensorVariable, got {type(input)}" @@ -119,7 +119,9 @@ def inner(*inputs): ) # Vectorize graph by replacing dummy core inputs by original inputs - outputs = vectorize_graph(core_outputs, replace=dict(zip(core_inputs, inputs))) + outputs = vectorize_graph( + core_outputs, replace=dict(zip(core_inputs, inputs, strict=True)) + ) return outputs return inner diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py index 6db6ae2638..ac23b60516 100644 --- a/pytensor/tensor/nlinalg.py +++ b/pytensor/tensor/nlinalg.py @@ -362,7 +362,7 @@ def grad(self, inputs, g_outputs): def _zero_disconnected(outputs, grads): l = [] - for o, g in zip(outputs, grads): + for o, g in zip(outputs, grads, strict=True): if isinstance(g.type, DisconnectedType): l.append(o.zeros_like()) else: @@ -664,7 +664,7 @@ def s_grad_only( return s_grad_only(U, VT, ds) for disconnected, output_grad, output in zip( - is_disconnected, output_grads, [U, s, VT] + is_disconnected, output_grads, [U, s, VT], strict=True ): if disconnected: new_output_grads.append(output.zeros_like()) diff --git a/pytensor/tensor/random/basic.py b/pytensor/tensor/random/basic.py index 4a2c47b2af..d5e346a5bf 100644 --- a/pytensor/tensor/random/basic.py +++ b/pytensor/tensor/random/basic.py @@ -1862,7 +1862,7 @@ def rng_fn(cls, rng, p, size): # to `p.shape[:-1]` in the call to `vsearchsorted` below. 
if len(size) < (p.ndim - 1): raise ValueError("`size` is incompatible with the shape of `p`") - for s, ps in zip(reversed(size), reversed(p.shape[:-1])): + for s, ps in zip(reversed(size), reversed(p.shape[:-1]), strict=True): if s == 1 and ps != 1: raise ValueError("`size` is incompatible with the shape of `p`") diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py index ba400454cd..dc37668ec8 100644 --- a/pytensor/tensor/random/op.py +++ b/pytensor/tensor/random/op.py @@ -152,11 +152,13 @@ def _supp_shape_from_params(self, dist_params, param_shapes=None): # Try to infer missing support dims from signature of params for param, param_sig, ndim_params in zip( - dist_params, self.inputs_sig, self.ndims_params + dist_params, self.inputs_sig, self.ndims_params, strict=True ): if ndim_params == 0: continue - for param_dim, dim in zip(param.shape[-ndim_params:], param_sig): + for param_dim, dim in zip( + param.shape[-ndim_params:], param_sig, strict=True + ): if dim in core_out_shape and core_out_shape[dim] is None: core_out_shape[dim] = param_dim @@ -231,7 +233,7 @@ def _infer_shape( # Fail early when size is incompatible with parameters for i, (param, param_ndim_supp) in enumerate( - zip(dist_params, self.ndims_params) + zip(dist_params, self.ndims_params, strict=True) ): param_batched_dims = getattr(param, "ndim", 0) - param_ndim_supp if param_batched_dims > size_len: @@ -255,7 +257,7 @@ def extract_batch_shape(p, ps, n): batch_shape = tuple( s if not b else constant(1, "int64") - for s, b in zip(shape[:-n], p.type.broadcastable[:-n]) + for s, b in zip(shape[:-n], p.type.broadcastable[:-n], strict=True) ) return batch_shape @@ -264,7 +266,9 @@ def extract_batch_shape(p, ps, n): # independent variate dimensions are left. params_batch_shape = tuple( extract_batch_shape(p, ps, n) - for p, ps, n in zip(dist_params, param_shapes, self.ndims_params) + for p, ps, n in zip( + dist_params, param_shapes, self.ndims_params, strict=False + ) ) if len(params_batch_shape) == 1: diff --git a/pytensor/tensor/random/rewriting/basic.py b/pytensor/tensor/random/rewriting/basic.py index 2fd617d8be..de82ab83ac 100644 --- a/pytensor/tensor/random/rewriting/basic.py +++ b/pytensor/tensor/random/rewriting/basic.py @@ -48,7 +48,7 @@ def random_make_inplace(fgraph, node): props["inplace"] = True new_op = type(op)(**props) new_outputs = new_op.make_node(*node.inputs).outputs - for old_out, new_out in zip(node.outputs, new_outputs): + for old_out, new_out in zip(node.outputs, new_outputs, strict=True): copy_stack_trace(old_out, new_out) return new_outputs @@ -171,7 +171,7 @@ def local_dimshuffle_rv_lift(fgraph, node): # Updates the params to reflect the Dimshuffled dimensions new_dist_params = [] - for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params): + for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params, strict=True): # Add the parameter support dimension indexes to the batched dimensions Dimshuffle param_new_order = batched_dims_ds_order + tuple( range(batched_dims, batched_dims + param_ndim_supp) @@ -290,12 +290,12 @@ def is_nd_advanced_idx(idx, dtype) -> bool: # non-broadcastable (non-degenerate) parameter dims. These parameters and the new size # should still correctly broadcast any degenerate parameter dims. 
new_dist_params = [] - for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params): + for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params, strict=True): # Check which dims are broadcasted by either size or other parameters bcast_param_dims = tuple( dim for dim, (param_dim_bcast, output_dim_bcast) in enumerate( - zip(param.type.broadcastable, rv.type.broadcastable) + zip(param.type.broadcastable, rv.type.broadcastable, strict=False) ) if param_dim_bcast and not output_dim_bcast ) diff --git a/pytensor/tensor/random/utils.py b/pytensor/tensor/random/utils.py index 075d09b053..e96fd779a5 100644 --- a/pytensor/tensor/random/utils.py +++ b/pytensor/tensor/random/utils.py @@ -44,7 +44,7 @@ def params_broadcast_shapes( max_fn = maximum if use_pytensor else max rev_extra_dims: list[int] = [] - for ndim_param, param_shape in zip(ndims_params, param_shapes): + for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=True): # We need this in order to use `len` param_shape = tuple(param_shape) extras = tuple(param_shape[: (len(param_shape) - ndim_param)]) @@ -67,7 +67,7 @@ def max_bcast(x, y): (extra_dims + tuple(param_shape)[-ndim_param:]) if ndim_param > 0 else extra_dims - for ndim_param, param_shape in zip(ndims_params, param_shapes) + for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=True) ] return bcast_shapes @@ -112,7 +112,9 @@ def broadcast_params( for p in params: param_shape = tuple( 1 if bcast else s - for s, bcast in zip(p.shape, getattr(p, "broadcastable", (False,) * p.ndim)) + for s, bcast in zip( + p.shape, getattr(p, "broadcastable", (False,) * p.ndim), strict=True + ) ) use_pytensor |= isinstance(p, Variable) param_shapes.append(param_shape) @@ -123,7 +125,8 @@ def broadcast_params( broadcast_to_fn = broadcast_to if use_pytensor else np.broadcast_to bcast_params = [ - broadcast_to_fn(param, shape) for shape, param in zip(shapes, params) + broadcast_to_fn(param, shape) + for shape, param in zip(shapes, params, strict=True) ] return bcast_params @@ -137,7 +140,8 @@ def explicit_expand_dims( """Introduce explicit expand_dims in RV parameters that are implicitly broadcasted together and/or by size.""" batch_dims = [ - param.type.ndim - ndim_param for param, ndim_param in zip(params, ndim_params) + param.type.ndim - ndim_param + for param, ndim_param in zip(params, ndim_params, strict=False) ] if size_length is not None: @@ -146,7 +150,7 @@ def explicit_expand_dims( max_batch_dims = max(batch_dims, default=0) new_params = [] - for new_param, batch_dim in zip(params, batch_dims): + for new_param, batch_dim in zip(params, batch_dims, strict=True): missing_dims = max_batch_dims - batch_dim if missing_dims: new_param = shape_padleft(new_param, missing_dims) @@ -161,7 +165,7 @@ def compute_batch_shape( params = explicit_expand_dims(params, ndims_params) batch_params = [ param[(..., *(0,) * core_ndim)] - for param, core_ndim in zip(params, ndims_params) + for param, core_ndim in zip(params, ndims_params, strict=True) ] return broadcast_arrays(*batch_params)[0].shape @@ -279,7 +283,9 @@ def seed(self, seed=None): self.gen_seedgen = np.random.SeedSequence(seed) old_r_seeds = self.gen_seedgen.spawn(len(self.state_updates)) - for (old_r, new_r), old_r_seed in zip(self.state_updates, old_r_seeds): + for (old_r, new_r), old_r_seed in zip( + self.state_updates, old_r_seeds, strict=True + ): old_r.set_value(self.rng_ctor(old_r_seed), borrow=True) def gen(self, op: "RandomVariable", *args, **kwargs) -> TensorVariable: diff --git 
a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index 4a7570dad3..1365f0b767 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -97,11 +97,11 @@ def broadcasted_by(x: TensorVariable, y: TensorVariable) -> bool: if len(bx) < len(by): return True bx = bx[-len(by) :] - return any(bx_dim and not by_dim for bx_dim, by_dim in zip(bx, by)) + return any(bx_dim and not by_dim for bx_dim, by_dim in zip(bx, by, strict=True)) def merge_broadcastables(broadcastables): - return [all(bcast) for bcast in zip(*broadcastables)] + return [all(bcast) for bcast in zip(*broadcastables, strict=True)] def alloc_like( @@ -1203,7 +1203,7 @@ def local_merge_alloc(fgraph, node): # broadcasted dimensions to its inputs[0]. Eg: # Alloc(Alloc(m, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w) i = 0 - for dim_inner, dim_outer in zip(dims_inner_rev, dims_outer_rev): + for dim_inner, dim_outer in zip(dims_inner_rev, dims_outer_rev, strict=False): if dim_inner != dim_outer: if isinstance(dim_inner, Constant) and dim_inner.data == 1: pass diff --git a/pytensor/tensor/rewriting/blas.py b/pytensor/tensor/rewriting/blas.py index cc8dd472e6..43e748165e 100644 --- a/pytensor/tensor/rewriting/blas.py +++ b/pytensor/tensor/rewriting/blas.py @@ -502,7 +502,7 @@ def on_import(new_node): ].tag.values_eq_approx = values_eq_approx_remove_inf_nan try: fgraph.replace_all_validate_remove( - list(zip(node.outputs, new_outputs)), + list(zip(node.outputs, new_outputs, strict=True)), [old_dot22], reason="GemmOptimizer", # For now we disable the warning as we know case diff --git a/pytensor/tensor/rewriting/blockwise.py b/pytensor/tensor/rewriting/blockwise.py index 0bed304c29..019ff1c5fc 100644 --- a/pytensor/tensor/rewriting/blockwise.py +++ b/pytensor/tensor/rewriting/blockwise.py @@ -109,7 +109,7 @@ def local_blockwise_alloc(fgraph, node): new_inputs = [] batch_shapes = [] can_push_any_alloc = False - for inp, inp_sig in zip(node.inputs, op.inputs_sig): + for inp, inp_sig in zip(node.inputs, op.inputs_sig, strict=True): if inp.owner and isinstance(inp.owner.op, Alloc): # Push batch dims from Alloc value, *shape = inp.owner.inputs @@ -130,6 +130,7 @@ def local_blockwise_alloc(fgraph, node): for broadcastable, dim in zip( squeezed_value.type.broadcastable[:squeezed_value_batch_ndim], tuple(squeezed_value.shape)[:squeezed_value_batch_ndim], + strict=True, ) ] squeezed_value = alloc(squeezed_value, *batch_shape, *core_shape) @@ -143,7 +144,7 @@ def local_blockwise_alloc(fgraph, node): tuple( 1 if broadcastable else dim for broadcastable, dim in zip( - inp.type.broadcastable, shape[:batch_ndim] + inp.type.broadcastable, shape[:batch_ndim], strict=False ) ) ) @@ -166,7 +167,9 @@ def local_blockwise_alloc(fgraph, node): # We pick the most parsimonious batch dim from the pushed Alloc missing_ndim = old_out_type.ndim - new_out_type.ndim batch_shape = ([1] * missing_ndim + list(new_outs[0].shape))[:batch_ndim] - for i, batch_dims in enumerate(zip(*batch_shapes)): # Transpose shape tuples + for i, batch_dims in enumerate( + zip(*batch_shapes, strict=True) + ): # Transpose shape tuples for batch_dim in batch_dims: if batch_dim == 1: continue diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 99dee1fd3f..a2859ce441 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -300,7 +300,7 @@ def apply(self, fgraph): ) new_node = new_outputs[0].owner - for r, new_r in zip(node.outputs, new_outputs): + for r, 
new_r in zip(node.outputs, new_outputs, strict=True): prof["nb_call_replace"] += 1 fgraph.replace( r, new_r, reason="inplace_elemwise_optimizer" @@ -1036,12 +1036,12 @@ def update_fuseable_mappings_after_fg_replace( ) if not isinstance(composite_outputs, list): composite_outputs = [composite_outputs] - for old_out, composite_out in zip(outputs, composite_outputs): + for old_out, composite_out in zip(outputs, composite_outputs, strict=True): if old_out.name: composite_out.name = old_out.name fgraph.replace_all_validate( - list(zip(outputs, composite_outputs)), + list(zip(outputs, composite_outputs, strict=True)), reason=self.__class__.__name__, ) nb_replacement += 1 @@ -1117,7 +1117,7 @@ def local_useless_composite_outputs(fgraph, node): used_inputs = [node.inputs[i] for i in used_inputs_idxs] c = ps.Composite(inputs=used_inner_inputs, outputs=used_inner_outputs) e = Elemwise(scalar_op=c)(*used_inputs, return_list=True) - return dict(zip([node.outputs[i] for i in used_outputs_idxs], e)) + return dict(zip([node.outputs[i] for i in used_outputs_idxs], e, strict=True)) @node_rewriter([CAReduce]) @@ -1217,7 +1217,9 @@ def local_inline_composite_constants(fgraph, node): new_outer_inputs = [] new_inner_inputs = [] inner_replacements = {} - for outer_inp, inner_inp in zip(node.inputs, composite_op.fgraph.inputs): + for outer_inp, inner_inp in zip( + node.inputs, composite_op.fgraph.inputs, strict=True + ): # Complex variables don't have a `c_literal` that can be inlined if "complex" not in outer_inp.type.dtype: unique_value = get_unique_constant_value(outer_inp) @@ -1354,7 +1356,7 @@ def local_useless_2f1grad_loop(fgraph, node): replacements = {converges: new_converges} i = 0 - for grad_var, is_used in zip(grad_vars, grad_var_is_used): + for grad_var, is_used in zip(grad_vars, grad_var_is_used, strict=True): if not is_used: continue replacements[grad_var] = new_outs[i] diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 75dba82d97..f0d98dadb1 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -1139,7 +1139,9 @@ def transform(self, fgraph, node): num, denum = self.simplify(list(orig_num), list(orig_denum), out.type) def same(x, y): - return len(x) == len(y) and all(np.all(xe == ye) for xe, ye in zip(x, y)) + return len(x) == len(y) and all( + np.all(xe == ye) for xe, ye in zip(x, y, strict=True) + ) if ( same(orig_num, num) @@ -2442,7 +2444,9 @@ def distribute_greedy(pos_pairs, neg_pairs, num, denum, out_type, minscore=0): [(n + num, d + denum, out_type) for (n, d) in neg_pairs], ) ) - for (n, d), (nn, dd) in zip(pos_pairs + neg_pairs, new_pos_pairs + new_neg_pairs): + for (n, d), (nn, dd) in zip( + pos_pairs + neg_pairs, new_pos_pairs + new_neg_pairs, strict=True + ): # We calculate how many operations we are saving with the new # num and denum score += len(n) + div_cost * len(d) - len(nn) - div_cost * len(dd) diff --git a/pytensor/tensor/rewriting/shape.py b/pytensor/tensor/rewriting/shape.py index 1426a7d993..2f53ebae1d 100644 --- a/pytensor/tensor/rewriting/shape.py +++ b/pytensor/tensor/rewriting/shape.py @@ -186,7 +186,7 @@ def get_shape(self, var, idx): # Only change the variables and dimensions that would introduce # extra computation - for new_shps, out in zip(o_shapes, node.outputs): + for new_shps, out in zip(o_shapes, node.outputs, strict=True): if not hasattr(out.type, "ndim"): continue @@ -578,7 +578,7 @@ def on_import(self, fgraph, node, reason): new_shape += sh[len(new_shape) :] o_shapes[sh_idx] = 
tuple(new_shape) - for r, s in zip(node.outputs, o_shapes): + for r, s in zip(node.outputs, o_shapes, strict=True): self.set_shape(r, s) def on_change_input(self, fgraph, node, i, r, new_r, reason): @@ -709,7 +709,7 @@ def same_shape( sx = canon_shapes[: len(sx)] sy = canon_shapes[len(sx) :] - for dx, dy in zip(sx, sy): + for dx, dy in zip(sx, sy, strict=True): if not equal_computations([dx], [dy]): return False @@ -778,7 +778,7 @@ def f(fgraph, node): # rewrite. if rval.type.ndim == node.outputs[0].type.ndim and all( s1 == s2 - for s1, s2 in zip(rval.type.shape, node.outputs[0].type.shape) + for s1, s2 in zip(rval.type.shape, node.outputs[0].type.shape, strict=True) if s1 == 1 or s2 == 1 ): return [rval] @@ -817,7 +817,7 @@ def local_useless_reshape(fgraph, node): and output.type.ndim == 1 and all( s1 == s2 - for s1, s2 in zip(inp.type.shape, output.type.shape) + for s1, s2 in zip(inp.type.shape, output.type.shape, strict=True) if s1 == 1 or s2 == 1 ) ): @@ -1100,7 +1100,9 @@ def local_specify_shape_lift(fgraph, node): nonbcast_dims = { i - for i, (dim, bcast) in enumerate(zip(shape, out_broadcastable)) + for i, (dim, bcast) in enumerate( + zip(shape, out_broadcastable, strict=True) + ) if (not bcast and not NoneConst.equals(dim)) } new_elem_inps = elem_inps.copy() @@ -1202,7 +1204,7 @@ def local_useless_dimshuffle_in_reshape(fgraph, node): new_order = node.inputs[0].owner.op.new_order inp = node.inputs[0].owner.inputs[0] new_order_of_nonbroadcast = [] - for i, s in zip(new_order, node.inputs[0].type.shape): + for i, s in zip(new_order, node.inputs[0].type.shape, strict=True): if s != 1: new_order_of_nonbroadcast.append(i) no_change_in_order = all( @@ -1226,7 +1228,7 @@ def local_useless_unbroadcast(fgraph, node): x = node.inputs[0] if x.type.ndim == node.outputs[0].type.ndim and all( s1 == s2 - for s1, s2 in zip(x.type.shape, node.outputs[0].type.shape) + for s1, s2 in zip(x.type.shape, node.outputs[0].type.shape, strict=True) if s1 == 1 or s2 == 1 ): # No broadcastable flag was modified diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py index 8ee86e6021..4a0ccbc487 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -648,7 +648,7 @@ def local_subtensor_of_alloc(fgraph, node): # Slices to take from val val_slices = [] - for i, (sl, dim) in enumerate(zip(slices, dims)): + for i, (sl, dim) in enumerate(zip(slices, dims, strict=False)): # If val was not copied over that dim, # we need to take the appropriate subtensor on it. 
if i >= n_added_dims: @@ -1772,7 +1772,7 @@ def local_join_subtensors(fgraph, node): if all( idxs_nonaxis_subtensor1 == idxs_nonaxis_subtensor2 for i, (idxs_nonaxis_subtensor1, idxs_nonaxis_subtensor2) in enumerate( - zip(idxs_subtensor1, idxs_subtensor2) + zip(idxs_subtensor1, idxs_subtensor2, strict=True) ) if i != axis ): @@ -1914,7 +1914,7 @@ def local_blockwise_advanced_inc_subtensor(fgraph, node): x_batch_bcast = x.type.broadcastable[:batch_ndim] y_batch_bcast = y.type.broadcastable[:batch_ndim] - if any(xb and not yb for xb, yb in zip(x_batch_bcast, y_batch_bcast)): + if any(xb and not yb for xb, yb in zip(x_batch_bcast, y_batch_bcast, strict=True)): # Need to broadcast batch x dims batch_shape = tuple( x_dim if (not xb or yb) else y_dim @@ -1923,6 +1923,7 @@ def local_blockwise_advanced_inc_subtensor(fgraph, node): tuple(x.shape)[:batch_ndim], y_batch_bcast, tuple(y.shape)[:batch_ndim], + strict=True, ) ) core_shape = tuple(x.shape)[batch_ndim:] diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py index 236c34b442..21a5519466 100644 --- a/pytensor/tensor/shape.py +++ b/pytensor/tensor/shape.py @@ -436,7 +436,7 @@ def make_node(self, x, *shape): ) type_shape = [None] * x.ndim - for i, (xts, s) in enumerate(zip(x.type.shape, shape)): + for i, (xts, s) in enumerate(zip(x.type.shape, shape, strict=True)): if xts is not None: type_shape[i] = xts else: @@ -459,7 +459,9 @@ def perform(self, node, inp, out_): raise AssertionError( f"SpecifyShape: Got {x.ndim} dimensions (shape {x.shape}), expected {ndim} dimensions with shape {tuple(shape)}." ) - if not all(xs == s for xs, s in zip(x.shape, shape) if s is not None): + if not all( + xs == s for xs, s in zip(x.shape, shape, strict=True) if s is not None + ): raise AssertionError( f"SpecifyShape: Got shape {x.shape}, expected {tuple(int(s) if s is not None else None for s in shape)}." ) @@ -523,7 +525,9 @@ def c_code(self, node, name, i_names, o_names, sub): """ ) - for i, (shp_name, shp) in enumerate(zip(shape_names, node.inputs[1:])): + for i, (shp_name, shp) in enumerate( + zip(shape_names, node.inputs[1:], strict=True) + ): if NoneConst.equals(shp): continue code += dedent( @@ -586,7 +590,7 @@ def specify_shape( # The above is a type error in Python 3.9 but not 3.12. # Thus we need to ignore unused-ignore on 3.12. 
new_shape_info = any( - s != xts for (s, xts) in zip(shape, x.type.shape) if s is not None + s != xts for (s, xts) in zip(shape, x.type.shape, strict=False) if s is not None ) # If shape does not match x.ndim, we rely on the `Op` to raise a ValueError if not new_shape_info and len(shape) == x.type.ndim: diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index db8303b2d8..997e024c78 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -897,7 +897,7 @@ def grad(self, inputs, gout): return [gout[0][slc] for slc in slices] def infer_shape(self, fgraph, nodes, shapes): - first, second = zip(*shapes) + first, second = zip(*shapes, strict=True) return [(pt.add(*first), pt.add(*second))] def _validate_and_prepare_inputs(self, matrices, as_tensor_func): diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index 41b4c6bd5a..3a684d2c07 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -523,7 +523,7 @@ def basic_shape(shape, indices): """ res_shape = () - for idx, n in zip(indices, shape): + for idx, n in zip(indices, shape, strict=False): if isinstance(idx, slice): res_shape += (slice_len(idx, n),) elif isinstance(getattr(idx, "type", None), SliceType): @@ -611,7 +611,7 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False): ) for basic, grp_dim_indices in idx_groups: - dim_nums, grp_indices = zip(*grp_dim_indices) + dim_nums, grp_indices = zip(*grp_dim_indices, strict=True) remaining_dims = tuple(dim for dim in remaining_dims if dim not in dim_nums) if basic: @@ -839,7 +839,7 @@ def make_node(self, x, *inputs): assert len(inputs) == len(input_types) - for input, expected_type in zip(inputs, input_types): + for input, expected_type in zip(inputs, input_types, strict=True): if not expected_type.is_super(input.type): raise TypeError( f"Incompatible types for Subtensor template. Expected {input.type}, got {expected_type}." @@ -861,7 +861,7 @@ def extract_const(value): except NotScalarConstantError: return value, False - for the_slice, length in zip(padded, x.type.shape): + for the_slice, length in zip(padded, x.type.shape, strict=True): if not isinstance(the_slice, slice): continue @@ -916,7 +916,7 @@ def infer_shape(self, fgraph, node, shapes): len(xshp) - len(self.idx_list) ) i = 0 - for idx, xl in zip(padded, xshp): + for idx, xl in zip(padded, xshp, strict=True): if isinstance(idx, slice): # If it is the default (None, None, None) slice, or a variant, # the shape will be xl @@ -1688,7 +1688,7 @@ def make_node(self, x, y, *inputs): raise IndexError( "Not enough inputs to fill in the Subtensor template.", inputs, idx_list ) - for input, expected_type in zip(inputs, input_types): + for input, expected_type in zip(inputs, input_types, strict=True): if not expected_type.is_super(input.type): raise TypeError( f"Wrong type for Subtensor template. Expected {input.type}, got {expected_type}." 
@@ -2703,7 +2703,7 @@ def is_bool_index(idx): indices = node.inputs[1:] index_shapes = [] - for idx, ishape in zip(indices, ishapes[1:]): + for idx, ishape in zip(indices, ishapes[1:], strict=True): # Mixed bool indexes are converted to nonzero entries if is_bool_index(idx): index_shapes.extend( @@ -2807,7 +2807,7 @@ def vectorize_advanced_subtensor(op: AdvancedSubtensor, node, *batch_inputs): x_is_batched = x.type.ndim < batch_x.type.ndim idxs_are_batched = any( batch_idx.type.ndim > idx.type.ndim - for batch_idx, idx in zip(batch_idxs, idxs) + for batch_idx, idx in zip(batch_idxs, idxs, strict=True) if isinstance(batch_idx, TensorVariable) ) diff --git a/pytensor/tensor/type.py b/pytensor/tensor/type.py index 3ba34a2903..f389dc68b1 100644 --- a/pytensor/tensor/type.py +++ b/pytensor/tensor/type.py @@ -251,7 +251,7 @@ def filter(self, data, strict=False, allow_downcast=None) -> np.ndarray: if not all( ds == ts if ts is not None else True - for ds, ts in zip(data.shape, self.shape) + for ds, ts in zip(data.shape, self.shape, strict=True) ): raise TypeError( f"The type's shape ({self.shape}) is not compatible with the data's ({data.shape})" @@ -326,7 +326,10 @@ def is_super(self, otype): and otype.ndim == self.ndim # `otype` is allowed to be as or more shape-specific than `self`, # but not less - and all(sb == ob or sb is None for sb, ob in zip(self.shape, otype.shape)) + and all( + sb == ob or sb is None + for sb, ob in zip(self.shape, otype.shape, strict=True) + ) ): return True diff --git a/pytensor/tensor/utils.py b/pytensor/tensor/utils.py index 8f8ef99657..c1a5ce3682 100644 --- a/pytensor/tensor/utils.py +++ b/pytensor/tensor/utils.py @@ -98,7 +98,7 @@ def shape_of_variables( numeric_input_dims = [dim for inp in fgraph.inputs for dim in input_shapes[inp]] numeric_output_dims = compute_shapes(*numeric_input_dims) - sym_to_num_dict = dict(zip(output_dims, numeric_output_dims)) + sym_to_num_dict = dict(zip(output_dims, numeric_output_dims, strict=True)) l = {} for var in shape_feature.shape_of: diff --git a/pytensor/tensor/variable.py b/pytensor/tensor/variable.py index 613fb80f3e..885ba7d25a 100644 --- a/pytensor/tensor/variable.py +++ b/pytensor/tensor/variable.py @@ -1061,7 +1061,9 @@ def __init__(self, type: _TensorTypeType, data, name=None): data_shape = np.shape(data) if len(data_shape) != type.ndim or any( - ds != ts for ds, ts in zip(np.shape(data), type.shape) if ts is not None + ds != ts + for ds, ts in zip(np.shape(data), type.shape, strict=True) + if ts is not None ): raise ValueError( f"Shape of data ({data_shape}) does not match shape of type ({type.shape})" diff --git a/tests/compile/function/test_types.py b/tests/compile/function/test_types.py index af292eb10d..9c730fb38e 100644 --- a/tests/compile/function/test_types.py +++ b/tests/compile/function/test_types.py @@ -371,7 +371,7 @@ def test_copy_share_memory(self): # Assert storages of SharedVariable without updates are shared for (input, _1, _2), here, there in zip( - ori.indices, ori.input_storage, cpy.input_storage + ori.indices, ori.input_storage, cpy.input_storage, strict=True ): assert here.data is there.data @@ -467,7 +467,7 @@ def test_swap_SharedVariable_with_given(self): swap={train_x: test_x, train_y: test_y}, delete_updates=True ) - for in1, in2 in zip(test_def.maker.inputs, test_cpy.maker.inputs): + for in1, in2 in zip(test_def.maker.inputs, test_cpy.maker.inputs, strict=True): assert in1.value is in2.value def test_copy_delete_updates(self): @@ -905,7 +905,7 @@ def test_deepcopy(self): # print(f"{f.defaults 
= }") # print(f"{g.defaults = }") for (f_req, f_feed, f_val), (g_req, g_feed, g_val) in zip( - f.defaults, g.defaults + f.defaults, g.defaults, strict=True ): assert f_req == g_req and f_feed == g_feed and f_val == g_val @@ -1076,7 +1076,7 @@ def test_optimizations_preserved(self): tf = f.maker.fgraph.toposort() tg = f.maker.fgraph.toposort() assert len(tf) == len(tg) - for nf, ng in zip(tf, tg): + for nf, ng in zip(tf, tg, strict=True): assert nf.op == ng.op assert len(nf.inputs) == len(ng.inputs) assert len(nf.outputs) == len(ng.outputs) diff --git a/tests/compile/test_builders.py b/tests/compile/test_builders.py index d71094bfed..23e9193c0e 100644 --- a/tests/compile/test_builders.py +++ b/tests/compile/test_builders.py @@ -722,5 +722,5 @@ def test_debugprint(): └─ *2- [id I] """ - for truth, out in zip(exp_res.split("\n"), lines): + for truth, out in zip(exp_res.split("\n"), lines, strict=True): assert truth.strip() == out.strip() diff --git a/tests/d3viz/test_formatting.py b/tests/d3viz/test_formatting.py index f0cbd3fdd7..9f5f8be9ec 100644 --- a/tests/d3viz/test_formatting.py +++ b/tests/d3viz/test_formatting.py @@ -19,7 +19,7 @@ def setup_method(self): def node_counts(self, graph): node_types = [node.get_attributes()["node_type"] for node in graph.get_nodes()] a, b = np.unique(node_types, return_counts=True) - nc = dict(zip(a, b)) + nc = dict(zip(a, b, strict=True)) return nc @pytest.mark.parametrize("mode", ["FAST_RUN", "FAST_COMPILE"]) diff --git a/tests/graph/test_fg.py b/tests/graph/test_fg.py index f2550d348e..e82a59e790 100644 --- a/tests/graph/test_fg.py +++ b/tests/graph/test_fg.py @@ -32,13 +32,22 @@ def test_pickle(self): s = pickle.dumps(func) new_func = pickle.loads(s) - assert all(type(a) is type(b) for a, b in zip(func.inputs, new_func.inputs)) - assert all(type(a) is type(b) for a, b in zip(func.outputs, new_func.outputs)) + assert all( + type(a) is type(b) + for a, b in zip(func.inputs, new_func.inputs, strict=True) + ) + assert all( + type(a) is type(b) + for a, b in zip(func.outputs, new_func.outputs, strict=True) + ) assert all( type(a.op) is type(b.op) - for a, b in zip(func.apply_nodes, new_func.apply_nodes) + for a, b in zip(func.apply_nodes, new_func.apply_nodes, strict=True) + ) + assert all( + a.type == b.type + for a, b in zip(func.variables, new_func.variables, strict=True) ) - assert all(a.type == b.type for a, b in zip(func.variables, new_func.variables)) def test_validate_inputs(self): var1 = op1() diff --git a/tests/graph/utils.py b/tests/graph/utils.py index d48e0b2a35..86b52a7ed1 100644 --- a/tests/graph/utils.py +++ b/tests/graph/utils.py @@ -137,7 +137,9 @@ def __init__(self, inner_inputs, inner_outputs): if not isinstance(v, Constant) ] outputs = clone_replace(inner_outputs, replace=input_replacements) - _, inputs = zip(*input_replacements) if input_replacements else (None, []) + _, inputs = ( + zip(*input_replacements, strict=True) if input_replacements else (None, []) + ) self.fgraph = FunctionGraph(inputs, outputs, clone=False) def make_node(self, *inputs): diff --git a/tests/link/jax/test_basic.py b/tests/link/jax/test_basic.py index 5cd2bd54c6..49e26e4c8d 100644 --- a/tests/link/jax/test_basic.py +++ b/tests/link/jax/test_basic.py @@ -80,7 +80,7 @@ def compare_jax_and_py( py_res = pytensor_py_fn(*test_inputs) if len(fgraph.outputs) > 1: - for j, p in zip(jax_res, py_res): + for j, p in zip(jax_res, py_res, strict=True): assert_fn(j, p) else: assert_fn(jax_res, py_res) diff --git a/tests/link/jax/test_random.py b/tests/link/jax/test_random.py 
index dfbc888e30..fe41e635ff 100644 --- a/tests/link/jax/test_random.py +++ b/tests/link/jax/test_random.py @@ -61,7 +61,9 @@ def test_random_updates(rng_ctor): # Check that original rng variable content was not overwritten when calling jax_typify assert all( a == b if not isinstance(a, np.ndarray) else np.array_equal(a, b) - for a, b in zip(rng.get_value().__getstate__(), original_value.__getstate__()) + for a, b in zip( + rng.get_value().__getstate__(), original_value.__getstate__(), strict=True + ) ) @@ -92,7 +94,9 @@ def test_replaced_shared_rng_storage_order(noise_first): ), "Test may need to be tweaked" # Confirm that input_storage type and fgraph input order are aligned - for storage, fgrapn_input in zip(f.input_storage, f.maker.fgraph.inputs): + for storage, fgrapn_input in zip( + f.input_storage, f.maker.fgraph.inputs, strict=True + ): assert storage.type == fgrapn_input.type assert mu.get_value() == 1 diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index cfbc61eaca..dfadc58a69 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -292,7 +292,7 @@ def assert_fn(x, y): eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode) if len(fn_outputs) > 1: - for j, p in zip(numba_res, py_res): + for j, p in zip(numba_res, py_res, strict=True): assert_fn(j, p) else: assert_fn(numba_res[0], py_res[0]) diff --git a/tests/link/numba/test_scan.py b/tests/link/numba/test_scan.py index 5db0f24222..5b9436688b 100644 --- a/tests/link/numba/test_scan.py +++ b/tests/link/numba/test_scan.py @@ -488,7 +488,7 @@ def step(seq1, seq2, mitsot1, mitsot2, sitsot1): ref_fn = pytensor.function(list(test), outs, mode=get_mode("FAST_COMPILE")) ref_res = ref_fn(*test.values()) - for numba_r, ref_r in zip(numba_res, ref_res): + for numba_r, ref_r in zip(numba_res, ref_res, strict=True): np.testing.assert_array_almost_equal(numba_r, ref_r) benchmark(numba_fn, *test.values()) diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index 27c1b1bd6a..6dde84f5f9 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -66,7 +66,7 @@ def compare_pytorch_and_py( py_res = pytensor_py_fn(*test_inputs) if len(fgraph.outputs) > 1: - for j, p in zip(pytorch_res, py_res): + for j, p in zip(pytorch_res, py_res, strict=True): assert_fn(j.cpu(), p) else: assert_fn([pytorch_res[0].cpu()], py_res) diff --git a/tests/link/test_link.py b/tests/link/test_link.py index a2e264759b..7d84c2a478 100644 --- a/tests/link/test_link.py +++ b/tests/link/test_link.py @@ -44,7 +44,7 @@ def execute(*args): got = len(args) if got != takes: raise TypeError(f"Function call takes exactly {takes} args ({got} given)") - for arg, variable in zip(args, inputs): + for arg, variable in zip(args, inputs, strict=True): variable.data = arg thunk() if unpack_single: diff --git a/tests/scan/test_basic.py b/tests/scan/test_basic.py index 880fcbd5fc..f0ed90c94d 100644 --- a/tests/scan/test_basic.py +++ b/tests/scan/test_basic.py @@ -174,7 +174,7 @@ def max_err(self, _g_pt): raise ValueError("argument has wrong number of elements", len(g_pt)) errs = [] - for i, (a, b) in enumerate(zip(g_pt, self.gx)): + for i, (a, b) in enumerate(zip(g_pt, self.gx, strict=True)): if a.shape != b.shape: raise ValueError( f"argument element {i} has wrong shape {(a.shape, b.shape)}" @@ -202,7 +202,10 @@ def scan_project_sum(*args, **kwargs): rng.add_default_updates = False factors = [rng.uniform(0.1, 0.9, size=s.shape) for s in scan_outputs] # Random values (?) 
- return (sum((s * f).sum() for s, f in zip(scan_outputs, factors)), updates) + return ( + sum((s * f).sum() for s, f in zip(scan_outputs, factors, strict=True)), + updates, + ) def asarrayX(value): @@ -3844,7 +3847,7 @@ def one_step(x_t, h_tm2, h_tm1, W_ih, W_hh, b_h, W_ho, b_o): gparams = grad(cost, params) updates = [ (param, param - gparam * learning_rate) - for param, gparam in zip(params, gparams) + for param, gparam in zip(params, gparams, strict=True) ] learn_rnn_fn = function(inputs=[x, t], outputs=cost, updates=updates, mode=mode) function(inputs=[x], outputs=y, mode=mode) diff --git a/tests/scan/test_printing.py b/tests/scan/test_printing.py index 42d81fbf11..80fe808f59 100644 --- a/tests/scan/test_printing.py +++ b/tests/scan/test_printing.py @@ -64,7 +64,7 @@ def test_debugprint_sitsot(): ├─ *0- [id X] -> [id E] (inner_in_sit_sot-0) └─ *1- [id Y] -> [id M] (inner_in_non_seqs-0)""" - for truth, out in zip(expected_output.split("\n"), lines): + for truth, out in zip(expected_output.split("\n"), lines, strict=False): assert truth.strip() == out.strip() @@ -122,7 +122,7 @@ def test_debugprint_sitsot_no_extra_info(): ├─ *0- [id X] -> [id E] └─ *1- [id Y] -> [id M]""" - for truth, out in zip(expected_output.split("\n"), lines): + for truth, out in zip(expected_output.split("\n"), lines, strict=False): assert truth.strip() == out.strip() @@ -190,7 +190,7 @@ def test_debugprint_nitsot(): ├─ *2- [id BA] -> [id W] (inner_in_non_seqs-0) └─ *1- [id BB] -> [id U] (inner_in_seqs-1)""" - for truth, out in zip(expected_output.split("\n"), lines): + for truth, out in zip(expected_output.split("\n"), lines, strict=False): assert truth.strip() == out.strip() @@ -305,7 +305,7 @@ def compute_A_k(A, k): ├─ *0- [id CB] -> [id BG] (inner_in_sit_sot-0) └─ *1- [id CC] -> [id BO] (inner_in_non_seqs-0)""" - for truth, out in zip(expected_output.split("\n"), lines): + for truth, out in zip(expected_output.split("\n"), lines, strict=False): assert truth.strip() == out.strip() fg = FunctionGraph([c, k, A], [final_result]) @@ -404,7 +404,7 @@ def compute_A_k(A, k): ├─ *0- [id CA] (inner_in_sit_sot-0) └─ *1- [id CB] (inner_in_non_seqs-0)""" - for truth, out in zip(expected_output.split("\n"), lines): + for truth, out in zip(expected_output.split("\n"), lines, strict=False): assert truth.strip() == out.strip() @@ -479,7 +479,7 @@ def fn(a_m2, a_m1, b_m2, b_m1): ├─ *3- [id BF] -> [id O] (inner_in_mit_sot-1-1) └─ *2- [id BG] -> [id O] (inner_in_mit_sot-1-0)""" - for truth, out in zip(expected_output.split("\n"), lines): + for truth, out in zip(expected_output.split("\n"), lines, strict=False): assert truth.strip() == out.strip() @@ -615,7 +615,7 @@ def test_debugprint_mitmot(): ├─ *0- [id CT] -> [id H] (inner_in_sit_sot-0) └─ *1- [id CW] -> [id P] (inner_in_non_seqs-0)""" - for truth, out in zip(expected_output.split("\n"), lines): + for truth, out in zip(expected_output.split("\n"), lines, strict=False): assert truth.strip() == out.strip() @@ -677,7 +677,7 @@ def no_shared_fn(n, x_tm1, M): output_str = debugprint(out, file="str", print_op_info=True) lines = output_str.split("\n") - for truth, out in zip(expected_output.split("\n"), lines): + for truth, out in zip(expected_output.split("\n"), lines, strict=True): assert truth.strip() == out.strip() diff --git a/tests/scan/test_utils.py b/tests/scan/test_utils.py index a26c2cbd4b..3586101ada 100644 --- a/tests/scan/test_utils.py +++ b/tests/scan/test_utils.py @@ -220,7 +220,7 @@ def test_ScanArgs_remove_inner_input(): test_v = sigmas_t rm_info = 
scan_args_copy.remove_from_fields(test_v, rm_dependents=False) - removed_nodes, _ = zip(*rm_info) + removed_nodes, _ = zip(*rm_info, strict=True) assert sigmas_t in removed_nodes assert sigmas_t not in scan_args_copy.inner_in_seqs @@ -232,7 +232,7 @@ def test_ScanArgs_remove_inner_input(): # This removal includes dependents rm_info = scan_args_copy.remove_from_fields(test_v, rm_dependents=True) - removed_nodes, _ = zip(*rm_info) + removed_nodes, _ = zip(*rm_info, strict=True) # `sigmas[t]` (i.e. inner-graph input) should be gone assert sigmas_t in removed_nodes @@ -288,7 +288,7 @@ def test_ScanArgs_remove_outer_input(): scan_args_copy = copy(scan_args) test_v = sigmas_in rm_info = scan_args_copy.remove_from_fields(test_v, rm_dependents=True) - removed_nodes, _ = zip(*rm_info) + removed_nodes, _ = zip(*rm_info, strict=True) # `sigmas_in` (i.e. outer-graph input) should be gone assert scan_args.outer_in_seqs[-1] in removed_nodes @@ -334,7 +334,7 @@ def test_ScanArgs_remove_inner_output(): scan_args_copy = copy(scan_args) test_v = Y_t rm_info = scan_args_copy.remove_from_fields(test_v, rm_dependents=True) - removed_nodes, _ = zip(*rm_info) + removed_nodes, _ = zip(*rm_info, strict=True) # `Y_t` (i.e. inner-graph output) should be gone assert Y_t in removed_nodes @@ -371,7 +371,7 @@ def test_ScanArgs_remove_outer_output(): scan_args_copy = copy(scan_args) test_v = Y_rv rm_info = scan_args_copy.remove_from_fields(test_v, rm_dependents=True) - removed_nodes, _ = zip(*rm_info) + removed_nodes, _ = zip(*rm_info, strict=True) # `Y_t` (i.e. inner-graph output) should be gone assert Y_t in removed_nodes @@ -409,7 +409,7 @@ def test_ScanArgs_remove_nonseq_outer_input(): scan_args_copy = copy(scan_args) test_v = Gamma_rv rm_info = scan_args_copy.remove_from_fields(test_v, rm_dependents=True) - removed_nodes, _ = zip(*rm_info) + removed_nodes, _ = zip(*rm_info, strict=True) assert Gamma_rv in removed_nodes assert Gamma_in in removed_nodes @@ -447,7 +447,7 @@ def test_ScanArgs_remove_nonseq_inner_input(): scan_args_copy = copy(scan_args) test_v = Gamma_in rm_info = scan_args_copy.remove_from_fields(test_v, rm_dependents=True) - removed_nodes, _ = zip(*rm_info) + removed_nodes, _ = zip(*rm_info, strict=True) assert Gamma_in in removed_nodes assert Gamma_rv in removed_nodes @@ -482,7 +482,7 @@ def test_ScanArgs_remove_shared_inner_output(): scan_update = scan_args.inner_out_shared[0] scan_args_copy = copy(scan_args) rm_info = scan_args_copy.remove_from_fields(scan_update, rm_dependents=True) - removed_nodes, _ = zip(*rm_info) + removed_nodes, _ = zip(*rm_info, strict=True) assert rng_in in removed_nodes assert all(v in removed_nodes for v in scan_args.inner_out_shared) diff --git a/tests/sparse/test_basic.py b/tests/sparse/test_basic.py index e4f2a69404..f2f3d76b62 100644 --- a/tests/sparse/test_basic.py +++ b/tests/sparse/test_basic.py @@ -335,7 +335,7 @@ def f(spdata): oconv = conv_none def conv_op(*inputs): - ipt = [conv(i) for i, conv in zip(inputs, iconv)] + ipt = [conv(i) for i, conv in zip(inputs, iconv, strict=True)] out = op(*ipt) return oconv(out) @@ -2193,7 +2193,7 @@ def setup_method(self): def test_op(self): for format in sparse.sparse_formats: - for shape in zip(range(5, 9), range(3, 7)[::-1]): + for shape in zip(range(5, 9), range(3, 7)[::-1], strict=True): variable, data = sparse_random_inputs(format, shape=shape) f = pytensor.function(variable, self.op(*variable)) @@ -2204,7 +2204,7 @@ def test_op(self): def test_infer_shape(self): for format in sparse.sparse_formats: - for shape in 
zip(range(5, 9), range(3, 7)[::-1]): + for shape in zip(range(5, 9), range(3, 7)[::-1], strict=True): variable, data = sparse_random_inputs(format, shape=shape) self._compile_and_check( variable, [self.op(*variable)], data, self.op_class @@ -2212,7 +2212,7 @@ def test_infer_shape(self): def test_grad(self): for format in sparse.sparse_formats: - for shape in zip(range(5, 9), range(3, 7)[::-1]): + for shape in zip(range(5, 9), range(3, 7)[::-1], strict=True): variable, data = sparse_random_inputs(format, shape=shape) verify_grad_sparse(self.op, data, structured=False) @@ -2224,7 +2224,7 @@ def setup_method(self): def test_op(self): for format in sparse.sparse_formats: - for shape in zip(range(5, 9), range(3, 7)[::-1]): + for shape in zip(range(5, 9), range(3, 7)[::-1], strict=True): variable, data = sparse_random_inputs(format, shape=shape) data[0][0, 0] = data[0][1, 1] = 0 @@ -2243,7 +2243,7 @@ def test_op(self): def test_grad(self): for format in sparse.sparse_formats: - for shape in zip(range(5, 9), range(3, 7)[::-1]): + for shape in zip(range(5, 9), range(3, 7)[::-1], strict=True): variable, data = sparse_random_inputs(format, shape=shape) verify_grad_sparse(self.op, data, structured=False) diff --git a/tests/tensor/conv/test_abstract_conv.py b/tests/tensor/conv/test_abstract_conv.py index 223e3774c2..fbd45ffdaa 100644 --- a/tests/tensor/conv/test_abstract_conv.py +++ b/tests/tensor/conv/test_abstract_conv.py @@ -461,7 +461,8 @@ def get_output_shape( self, inputs_shape, filters_shape, subsample, border_mode, filter_dilation ): dil_filters = tuple( - (s - 1) * d + 1 for s, d in zip(filters_shape[2:], filter_dilation) + (s - 1) * d + 1 + for s, d in zip(filters_shape[2:], filter_dilation, strict=True) ) if border_mode == "valid": border_mode = (0,) * (len(inputs_shape) - 2) @@ -484,6 +485,7 @@ def get_output_shape( subsample, border_mode, filter_dilation, + strict=True, ) ), ) @@ -760,7 +762,7 @@ def test_all(self): db = self.default_border_mode dflip = self.default_filter_flip dprovide_shape = self.default_provide_shape - for i, f in zip(self.inputs_shapes, self.filters_shapes): + for i, f in zip(self.inputs_shapes, self.filters_shapes, strict=True): for provide_shape in self.provide_shape: self.run_test_case(i, f, ds, db, dflip, provide_shape) if min(i) > 0 and min(f) > 0: @@ -1761,7 +1763,9 @@ def test_conv2d_grad_wrt_inputs(self): # the outputs of `pytensor.tensor.conv` forward grads to make sure the # results are the same. - for in_shape, fltr_shape in zip(self.inputs_shapes, self.filters_shapes): + for in_shape, fltr_shape in zip( + self.inputs_shapes, self.filters_shapes, strict=False + ): for bm in self.border_modes: for ss in self.subsamples: for ff in self.filter_flip: @@ -1823,7 +1827,9 @@ def test_conv2d_grad_wrt_weights(self): # the outputs of `pytensor.tensor.conv` forward grads to make sure the # results are the same. 
- for in_shape, fltr_shape in zip(self.inputs_shapes, self.filters_shapes): + for in_shape, fltr_shape in zip( + self.inputs_shapes, self.filters_shapes, strict=False + ): for bm in self.border_modes: for ss in self.subsamples: for ff in self.filter_flip: @@ -1915,7 +1921,7 @@ def test_fwd(self): kern_sym = tensor5("kern") for imshp, kshp, groups in zip( - self.img_shape, self.kern_shape, self.num_groups + self.img_shape, self.kern_shape, self.num_groups, strict=True ): img = np.random.random(imshp).astype(config.floatX) kern = np.random.random(kshp).astype(config.floatX) @@ -1951,7 +1957,7 @@ def test_fwd(self): ) ref_concat_output = [ ref_func(img_arr, kern_arr) - for img_arr, kern_arr in zip(split_imgs, split_kern) + for img_arr, kern_arr in zip(split_imgs, split_kern, strict=True) ] ref_concat_output = np.concatenate(ref_concat_output, axis=1) @@ -1967,7 +1973,11 @@ def test_gradweights(self): img_sym = tensor5("img") top_sym = tensor5("kern") for imshp, kshp, tshp, groups in zip( - self.img_shape, self.kern_shape, self.top_shape, self.num_groups + self.img_shape, + self.kern_shape, + self.top_shape, + self.num_groups, + strict=True, ): img = np.random.random(imshp).astype(config.floatX) top = np.random.random(tshp).astype(config.floatX) @@ -2005,7 +2015,7 @@ def test_gradweights(self): ) ref_concat_output = [ ref_func(img_arr, top_arr) - for img_arr, top_arr in zip(split_imgs, split_top) + for img_arr, top_arr in zip(split_imgs, split_top, strict=True) ] ref_concat_output = np.concatenate(ref_concat_output, axis=0) @@ -2028,7 +2038,11 @@ def test_gradinputs(self): kern_sym = tensor5("kern") top_sym = tensor5("top") for imshp, kshp, tshp, groups in zip( - self.img_shape, self.kern_shape, self.top_shape, self.num_groups + self.img_shape, + self.kern_shape, + self.top_shape, + self.num_groups, + strict=True, ): kern = np.random.random(kshp).astype(config.floatX) top = np.random.random(tshp).astype(config.floatX) @@ -2066,7 +2080,7 @@ def test_gradinputs(self): ) ref_concat_output = [ ref_func(kern_arr, top_arr) - for kern_arr, top_arr in zip(split_kerns, split_top) + for kern_arr, top_arr in zip(split_kerns, split_top, strict=True) ] ref_concat_output = np.concatenate(ref_concat_output, axis=1) @@ -2368,6 +2382,7 @@ def test_fwd(self): self.subsample, self.num_groups, self.verify_flags, + strict=True, ): img = np.random.random(imshp).astype(config.floatX) kern = np.random.random(kshp).astype(config.floatX) @@ -2426,6 +2441,7 @@ def test_gradweight(self): self.subsample, self.num_groups, self.verify_flags, + strict=True, ): img = np.random.random(imshp).astype(config.floatX) top = np.random.random(topshp).astype(config.floatX) @@ -2494,6 +2510,7 @@ def test_gradinput(self): self.subsample, self.num_groups, self.verify_flags, + strict=True, ): single_kshp = kshp[:1] + kshp[3:] @@ -2576,7 +2593,9 @@ def test_fwd(self): img_sym = tensor4("img") kern_sym = tensor4("kern") - for imshp, kshp, pad in zip(self.img_shape, self.kern_shape, self.border_mode): + for imshp, kshp, pad in zip( + self.img_shape, self.kern_shape, self.border_mode, strict=True + ): img = np.random.random(imshp).astype(config.floatX) kern = np.random.random(kshp).astype(config.floatX) @@ -2627,7 +2646,11 @@ def test_gradweight(self): top_sym = tensor4("top") for imshp, kshp, topshp, pad in zip( - self.img_shape, self.kern_shape, self.topgrad_shape, self.border_mode + self.img_shape, + self.kern_shape, + self.topgrad_shape, + self.border_mode, + strict=True, ): img = np.random.random(imshp).astype(config.floatX) top = 
np.random.random(topshp).astype(config.floatX) @@ -2684,7 +2707,11 @@ def test_gradinput(self): top_sym = tensor4("top") for imshp, kshp, topshp, pad in zip( - self.img_shape, self.kern_shape, self.topgrad_shape, self.border_mode + self.img_shape, + self.kern_shape, + self.topgrad_shape, + self.border_mode, + strict=True, ): kern = np.random.random(kshp).astype(config.floatX) top = np.random.random(topshp).astype(config.floatX) diff --git a/tests/tensor/random/rewriting/test_basic.py b/tests/tensor/random/rewriting/test_basic.py index f342d5b81c..acc793156f 100644 --- a/tests/tensor/random/rewriting/test_basic.py +++ b/tests/tensor/random/rewriting/test_basic.py @@ -140,7 +140,7 @@ def test_inplace_rewrites(rv_op): assert new_op._props_dict() == (op._props_dict() | {"inplace": True}) assert all( np.array_equal(a.data, b.data) - for a, b in zip(new_op.dist_params(new_node), op.dist_params(node)) + for a, b in zip(new_op.dist_params(new_node), op.dist_params(node), strict=True) ) assert np.array_equal(new_op.size_param(new_node).data, op.size_param(node).data) assert check_stack_trace(f) diff --git a/tests/tensor/random/test_utils.py b/tests/tensor/random/test_utils.py index 3616b2fd24..70e8a710e9 100644 --- a/tests/tensor/random/test_utils.py +++ b/tests/tensor/random/test_utils.py @@ -271,7 +271,7 @@ def __init__(self, seed=123): g2 = Graph(seed=987) f2 = function([], g2.y) - for su1, su2 in zip(g1.rng.state_updates, g2.rng.state_updates): + for su1, su2 in zip(g1.rng.state_updates, g2.rng.state_updates, strict=True): su2[0].set_value(su1[0].get_value()) np.testing.assert_array_almost_equal(f1(), f2(), decimal=6) diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index 692598c2c7..0d97d124d3 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -988,10 +988,12 @@ def test_elemwise_fusion(self, case, nb_repeat=1, assert_len_topo=True): else: out = [ self._shared(np.zeros((5,) * g_.ndim, dtype=od), "out") - for g_, od in zip(g, out_dtype) + for g_, od in zip(g, out_dtype, strict=True) ] - assert all(o.dtype == g_.dtype for o, g_ in zip(out, g)) - f = function(sym_inputs, [], updates=list(zip(out, g)), mode=self.mode) + assert all(o.dtype == g_.dtype for o, g_ in zip(out, g, strict=True)) + f = function( + sym_inputs, [], updates=list(zip(out, g, strict=True)), mode=self.mode + ) for x in range(nb_repeat): f(*val_inputs) out = [o.get_value() for o in out] @@ -1001,7 +1003,7 @@ def test_elemwise_fusion(self, case, nb_repeat=1, assert_len_topo=True): if any(o == "float32" for o in out_dtype): atol = 1e-6 - for o, a in zip(out, answer): + for o, a in zip(out, answer, strict=True): np.testing.assert_allclose(o, a * nb_repeat, atol=atol) topo = f.maker.fgraph.toposort() @@ -1021,7 +1023,7 @@ def test_elemwise_fusion(self, case, nb_repeat=1, assert_len_topo=True): ) assert expected_len_sym_inputs == len(sym_inputs) - for od, o in zip(out_dtype, out): + for od, o in zip(out_dtype, out, strict=True): assert od == o.dtype def test_fusion_35_inputs(self): diff --git a/tests/tensor/rewriting/test_subtensor.py b/tests/tensor/rewriting/test_subtensor.py index f7ea7cdce4..2fe7b3920b 100644 --- a/tests/tensor/rewriting/test_subtensor.py +++ b/tests/tensor/rewriting/test_subtensor.py @@ -1389,7 +1389,7 @@ def test_none_slice(self): for x_s in self.x_shapes: x_val = self.rng.uniform(size=x_s).astype(config.floatX) - for i_val in zip(*values): + for i_val in zip(*values, strict=True): f(x_val, *i_val) def 
test_none_index(self): @@ -1447,7 +1447,7 @@ def test_none_index(self): for x_s in self.x_shapes: x_val = self.rng.uniform(size=x_s).astype(config.floatX) - for i_val in zip(*values): + for i_val in zip(*values, strict=True): # The index could be out of bounds # In that case, an Exception should be raised, # otherwise, we let DebugMode check f @@ -1568,7 +1568,7 @@ def test_stack_trace(self): incs = [set_subtensor(x[idx], y) for y in ys] outs = [inc[idx] for inc in incs] - for y, out in zip(ys, outs): + for y, out in zip(ys, outs, strict=True): f = function([x, y, idx], out, self.mode) assert check_stack_trace(f, ops_to_check=(Assert, ps.Cast)) diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 49c8e9c38c..71541bcf46 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -420,7 +420,7 @@ def test_make_vector(self, dtype, inputs): # The gradient should be 0 utt.assert_allclose(g_val, 0) else: - for var, grval in zip((b, i, d), g_val): + for var, grval in zip((b, i, d), g_val, strict=True): float_inputs = [] if var.dtype in int_dtypes: pass @@ -777,6 +777,7 @@ def test_alloc_constant_folding(self): # AdvancedIncSubtensor (some_matrix[idx, idx], 1), ], + strict=True, ): derp = pt_sum(dense_dot(subtensor, variables)) @@ -1120,7 +1121,7 @@ def check(m): assert np.allclose(res_matrix, np.vstack(np.nonzero(m))) - for i, j in zip(res_tuple, np.nonzero(m)): + for i, j in zip(res_tuple, np.nonzero(m), strict=True): assert np.allclose(i, j) rand0d = np.empty(()) @@ -2170,7 +2171,7 @@ def test_split_view(self, linker): ) x_test = np.arange(5, dtype=config.floatX) res = f(x_test) - for r, expected in zip(res, ([], [0, 1, 2], [3, 4])): + for r, expected in zip(res, ([], [0, 1, 2], [3, 4]), strict=True): assert np.allclose(r, expected) if linker == "py": assert r.base is x_test @@ -2951,8 +2952,8 @@ def test_mgrid_numpy_equiv(self): mgrid[0:1:0.1, 1:10:1.0, 10:100:10.0], mgrid[0:2:1, 1:10:1, 10:100:10], ) - for n, t in zip(nmgrid, tmgrid): - for ng, tg in zip(n, t): + for n, t in zip(nmgrid, tmgrid, strict=True): + for ng, tg in zip(n, t, strict=True): utt.assert_allclose(ng, tg.eval()) def test_ogrid_numpy_equiv(self): @@ -2966,8 +2967,8 @@ def test_ogrid_numpy_equiv(self): ogrid[0:1:0.1, 1:10:1.0, 10:100:10.0], ogrid[0:2:1, 1:10:1, 10:100:10], ) - for n, t in zip(nogrid, togrid): - for ng, tg in zip(n, t): + for n, t in zip(nogrid, togrid, strict=True): + for ng, tg in zip(n, t, strict=True): utt.assert_allclose(ng, tg.eval()) def test_mgrid_pytensor_variable_numpy_equiv(self): @@ -2979,8 +2980,10 @@ def test_mgrid_pytensor_variable_numpy_equiv(self): timgrid = mgrid[l:2:1, 1:m:1, 10:100:n] ff = pytensor.function([i, j, k], tfmgrid) fi = pytensor.function([l, m, n], timgrid) - for n, t in zip((nfmgrid, nimgrid), (ff(0, 10, 10.0), fi(0, 10, 10))): - for ng, tg in zip(n, t): + for n, t in zip( + (nfmgrid, nimgrid), (ff(0, 10, 10.0), fi(0, 10, 10)), strict=True + ): + for ng, tg in zip(n, t, strict=True): utt.assert_allclose(ng, tg) def test_ogrid_pytensor_variable_numpy_equiv(self): @@ -2992,8 +2995,10 @@ def test_ogrid_pytensor_variable_numpy_equiv(self): tiogrid = ogrid[l:2:1, 1:m:1, 10:100:n] ff = pytensor.function([i, j, k], tfogrid) fi = pytensor.function([l, m, n], tiogrid) - for n, t in zip((nfogrid, niogrid), (ff(0, 10, 10.0), fi(0, 10, 10))): - for ng, tg in zip(n, t): + for n, t in zip( + (nfogrid, niogrid), (ff(0, 10, 10.0), fi(0, 10, 10)), strict=True + ): + for ng, tg in zip(n, t, strict=True): utt.assert_allclose(ng, tg) @@ -3038,7 +3043,7 @@ def 
test_dim2(self): assert np.all(f_inverse(inv_val) == p_val) # Check that, for each permutation, # permutation(inverse) == inverse(permutation) = identity - for p_row, i_row in zip(p_val, inv_val): + for p_row, i_row in zip(p_val, inv_val, strict=True): assert np.all(p_row[i_row] == np.arange(10)) assert np.all(i_row[p_row] == np.arange(10)) @@ -3104,7 +3109,9 @@ def test_2_2(self): # Each row of p contains a permutation to apply to the corresponding # row of input - out_bis = np.asarray([i_row[p_row] for i_row, p_row in zip(input_val, p_val)]) + out_bis = np.asarray( + [i_row[p_row] for i_row, p_row in zip(input_val, p_val, strict=True)] + ) assert np.all(out_val == out_bis) # Verify gradient @@ -4660,7 +4667,7 @@ def test_where(): np.testing.assert_allclose(np.where(cond, ift, iff), where(cond, ift, iff).eval()) # Test for only condition input - for np_output, pt_output in zip(np.where(cond), where(cond)): + for np_output, pt_output in zip(np.where(cond), where(cond), strict=True): np.testing.assert_allclose(np_output, pt_output.eval()) # Test for error diff --git a/tests/tensor/test_blas.py b/tests/tensor/test_blas.py index 34a1d1bcf9..103b2e8cd5 100644 --- a/tests/tensor/test_blas.py +++ b/tests/tensor/test_blas.py @@ -2593,7 +2593,7 @@ def test_ger(self): lambda xs, ys: np.asarray( [ x * y if x.ndim == 0 or y.ndim == 0 else np.dot(x, y) - for x, y in zip(xs, ys) + for x, y in zip(xs, ys, strict=True) ], dtype=ps.upcast(xs.dtype, ys.dtype), ) @@ -2696,7 +2696,7 @@ def check_first_dim(inverted): assert x.strides[0] == direction * np.dtype(config.floatX).itemsize assert not (x.flags["C_CONTIGUOUS"] or x.flags["F_CONTIGUOUS"]) result = f(x, w) - ref_result = np.asarray([np.dot(u, v) for u, v in zip(x, w)]) + ref_result = np.asarray([np.dot(u, v) for u, v in zip(x, w, strict=True)]) utt.assert_allclose(ref_result, result) for inverted in (0, 1): diff --git a/tests/tensor/test_blockwise.py b/tests/tensor/test_blockwise.py index 43f9b77f4f..353448abd8 100644 --- a/tests/tensor/test_blockwise.py +++ b/tests/tensor/test_blockwise.py @@ -251,7 +251,7 @@ def create_batched_inputs(self, batch_idx: int | None = None): vec_inputs = [] vec_inputs_testvals = [] for idx, (batch_shape, param_sig) in enumerate( - zip(batch_shapes, self.params_sig) + zip(batch_shapes, self.params_sig, strict=True) ): if batch_idx is not None and idx != batch_idx: # Skip out combinations in which other inputs are batched diff --git a/tests/tensor/test_casting.py b/tests/tensor/test_casting.py index 4ddfd40ed8..46bdb2c910 100644 --- a/tests/tensor/test_casting.py +++ b/tests/tensor/test_casting.py @@ -72,6 +72,7 @@ def test_illegal(self): _convert_to_float32, _convert_to_float64, ], + strict=True, ), ) def test_basic(self, type1, type2, converter): diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index 94e91821fa..659e8b0234 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -337,6 +337,7 @@ def test_fill(self): [self.op, self.cop], [self.type, self.ctype], [self.rand_val, self.rand_cval], + strict=True, ): x = t(pytensor.config.floatX, shape=(None, None))("x") y = t(pytensor.config.floatX, shape=(1, 1))("y") @@ -368,6 +369,7 @@ def test_weird_strides(self): [self.op, self.cop], [self.type, self.ctype], [self.rand_val, self.rand_cval], + strict=True, ): x = t(pytensor.config.floatX, shape=(None,) * 5)("x") y = t(pytensor.config.floatX, shape=(None,) * 5)("y") @@ -388,6 +390,7 @@ def test_same_inputs(self): [self.op, self.cop], [self.type, self.ctype], [self.rand_val, 
self.rand_cval], + strict=True, ): x = t(pytensor.config.floatX, shape=(None,) * 2)("x") e = op(ps.add)(x, x) diff --git a/tests/tensor/test_extra_ops.py b/tests/tensor/test_extra_ops.py index 3b3cc5ec7f..3cf59a54cd 100644 --- a/tests/tensor/test_extra_ops.py +++ b/tests/tensor/test_extra_ops.py @@ -367,6 +367,7 @@ def setup_method(self): [1, None, None], [1, None, 1, 1, None], ], + strict=True, ), ) def test_op(self, shape, var_shape): @@ -390,6 +391,7 @@ def test_op(self, shape, var_shape): [1, None, None], [1, None, 1, 1, None], ], + strict=True, ), ) def test_infer_shape(self, shape, var_shape): @@ -409,6 +411,7 @@ def test_infer_shape(self, shape, var_shape): [True, False, False], [True, False, True, True, False], ], + strict=True, ), ) def test_grad(self, shape, broadcast): @@ -424,6 +427,7 @@ def test_grad(self, shape, broadcast): [1, None, None], [1, None, 1, 1, None], ], + strict=True, ), ) def test_var_interface(self, shape, var_shape): @@ -509,6 +513,7 @@ def setup_method(self): [1, 1, 0, 1, 0], ], [(2, 3), (4, 3), (4, 3), (4, 3), (4, 3), (3, 5)], + strict=True, ), ) def test_op(self, axis, cond, shape): @@ -893,11 +898,13 @@ def test_basic_vector(self, x, inp, axis): np.unique(inp, False, True, True, axis=axis), np.unique(inp, True, True, True, axis=axis), ] - for params, outs_expected in zip(self.op_params, list_outs_expected): + for params, outs_expected in zip( + self.op_params, list_outs_expected, strict=True + ): out = pt.unique(x, *params, axis=axis) f = pytensor.function(inputs=[x], outputs=out) outs = f(inp) - for out, out_exp in zip(outs, outs_expected): + for out, out_exp in zip(outs, outs_expected, strict=True): utt.assert_allclose(out, out_exp) @pytest.mark.parametrize( @@ -1066,7 +1073,7 @@ def shape_tuple(x, use_bcast=True): if use_bcast: return tuple( s if not bcast else 1 - for s, bcast in zip(tuple(x.shape), x.broadcastable) + for s, bcast in zip(tuple(x.shape), x.broadcastable, strict=True) ) else: return tuple(s for s in tuple(x.shape)) @@ -1206,12 +1213,12 @@ def test_broadcast_shape_constants(): def test_broadcast_shape_symbolic(s1_vals, s2_vals, exp_res): s1s = pt.lscalars(len(s1_vals)) eval_point = {} - for s, s_val in zip(s1s, s1_vals): + for s, s_val in zip(s1s, s1_vals, strict=True): eval_point[s] = s_val s.tag.test_value = s_val s2s = pt.lscalars(len(s2_vals)) - for s, s_val in zip(s2s, s2_vals): + for s, s_val in zip(s2s, s2_vals, strict=True): eval_point[s] = s_val s.tag.test_value = s_val diff --git a/tests/tensor/test_nlinalg.py b/tests/tensor/test_nlinalg.py index 1a13992011..4b83446c5f 100644 --- a/tests/tensor/test_nlinalg.py +++ b/tests/tensor/test_nlinalg.py @@ -198,7 +198,7 @@ def test_svd(self, core_shape, full_matrix, compute_uv, batched, test_imag): np_outputs = np_outputs if isinstance(np_outputs, tuple) else [np_outputs] - for np_val, pt_val in zip(np_outputs, pt_outputs): + for np_val, pt_val in zip(np_outputs, pt_outputs, strict=True): assert _allclose(np_val, pt_val) def test_svd_infer_shape(self): diff --git a/tests/tensor/test_subtensor.py b/tests/tensor/test_subtensor.py index d02880f543..c989903f5a 100644 --- a/tests/tensor/test_subtensor.py +++ b/tests/tensor/test_subtensor.py @@ -1053,7 +1053,7 @@ def test_shape_i_const(self): shapes += [data.get_value(borrow=True)[start:stop:step].shape] f = self.function([], outs, mode=mode_opt, op=subtensor_ops, N=0) t_shapes = f() - for t_shape, shape in zip(t_shapes, shapes): + for t_shape, shape in zip(t_shapes, shapes, strict=True): assert np.all(t_shape == shape) assert Subtensor not in 
[x.op for x in f.maker.fgraph.toposort()] @@ -1317,7 +1317,9 @@ def test_advanced1_inc_and_set(self): f_outs = f(*all_inputs_num) assert len(f_outs) == len(all_outputs_num) - for params, f_out, output_num in zip(all_params, f_outs, all_outputs_num): + for params, f_out, output_num in zip( + all_params, f_outs, all_outputs_num, strict=True + ): # NB: if this assert fails, it will probably be easier to debug if # you enable the debug code above. assert np.allclose(f_out, output_num), (params, f_out, output_num) @@ -1394,7 +1396,7 @@ def test_adv1_inc_sub_notlastdim_1_2dval_broadcast(self): shape_i = ((4,), (4, 2)) shape_val = ((3, 1), (3, 1, 1)) - for i, shp_i, shp_v in zip(sym_i, shape_i, shape_val): + for i, shp_i, shp_v in zip(sym_i, shape_i, shape_val, strict=True): sub_m = m[:, i] m1 = set_subtensor(sub_m, np.zeros(shp_v)) m2 = inc_subtensor(sub_m, np.ones(shp_v)) @@ -1424,7 +1426,7 @@ def test_adv1_inc_sub_notlastdim_1_2dval_no_broadcast(self): shape_i = ((4,), (4, 2)) shape_val = ((3, 4), (3, 4, 2)) - for i, shp_i, shp_v in zip(sym_i, shape_i, shape_val): + for i, shp_i, shp_v in zip(sym_i, shape_i, shape_val, strict=True): sub_m = m[:, i] m1 = set_subtensor(sub_m, np.zeros(shp_v)) m2 = inc_subtensor(sub_m, np.ones(shp_v)) @@ -1841,7 +1843,7 @@ def test_index_into_vec_w_matrix(self): assert a.type.ndim == self.ix2.type.ndim assert all( s1 == s2 - for s1, s2 in zip(a.type.shape, self.ix2.type.shape) + for s1, s2 in zip(a.type.shape, self.ix2.type.shape, strict=True) if s1 == 1 or s2 == 1 ) @@ -2601,7 +2603,9 @@ def idx_as_tensor(x): def bcast_shape_tuple(x): if not hasattr(x, "shape"): return x - return tuple(s if ss != 1 else 1 for s, ss in zip(tuple(x.shape), x.type.shape)) + return tuple( + s if ss != 1 else 1 for s, ss in zip(tuple(x.shape), x.type.shape, strict=True) + ) test_idx = np.ix_(np.array([True, True]), np.array([True]), np.array([True, True])) diff --git a/tests/tensor/utils.py b/tests/tensor/utils.py index 2f97d0e18f..82b5fc014a 100644 --- a/tests/tensor/utils.py +++ b/tests/tensor/utils.py @@ -508,7 +508,9 @@ def test_good(self): if not isinstance(expecteds, list | tuple): expecteds = (expecteds,) - for i, (variable, expected) in enumerate(zip(variables, expecteds)): + for i, (variable, expected) in enumerate( + zip(variables, expecteds, strict=True) + ): condition = ( variable.dtype != expected.dtype or variable.shape != expected.shape diff --git a/tests/test_gradient.py b/tests/test_gradient.py index c45d07662d..79c55caf44 100644 --- a/tests/test_gradient.py +++ b/tests/test_gradient.py @@ -68,6 +68,7 @@ def grad_sources_inputs(sources, inputs): wrt=inputs, consider_constant=inputs, ), + strict=True, ) ) @@ -629,7 +630,9 @@ def test_known_grads(): rng = np.random.default_rng([2012, 11, 15]) values = [rng.standard_normal(10), rng.integers(10), rng.standard_normal()] - values = [np.cast[ipt.dtype](value) for ipt, value in zip(inputs, values)] + values = [ + np.cast[ipt.dtype](value) for ipt, value in zip(inputs, values, strict=True) + ] true_grads = grad(cost, inputs, disconnected_inputs="ignore") true_grads = pytensor.function(inputs, true_grads) @@ -637,14 +640,14 @@ def test_known_grads(): for layer in layers: first = grad(cost, layer, disconnected_inputs="ignore") - known = dict(zip(layer, first)) + known = dict(zip(layer, first, strict=True)) full = grad( cost=None, known_grads=known, wrt=inputs, disconnected_inputs="ignore" ) full = pytensor.function(inputs, full) full = full(*values) assert len(true_grads) == len(full) - for a, b, var in zip(true_grads, full, 
inputs): + for a, b, var in zip(true_grads, full, inputs, strict=True): assert np.allclose(a, b) @@ -742,7 +745,9 @@ def test_subgraph_grad(): inputs = [t, x] rng = np.random.default_rng([2012, 11, 15]) values = [rng.standard_normal(2), rng.standard_normal(3)] - values = [np.cast[ipt.dtype](value) for ipt, value in zip(inputs, values)] + values = [ + np.cast[ipt.dtype](value) for ipt, value in zip(inputs, values, strict=True) + ] wrt = [w2, w1] cost = cost2 + cost1 @@ -755,13 +760,13 @@ def test_subgraph_grad(): param_grad, next_grad = subgraph_grad( wrt=params[i], end=grad_ends[i], start=next_grad, cost=costs[i] ) - next_grad = dict(zip(grad_ends[i], next_grad)) + next_grad = dict(zip(grad_ends[i], next_grad, strict=True)) param_grads.extend(param_grad) pgrads = pytensor.function(inputs, param_grads) pgrads = pgrads(*values) - for true_grad, pgrad in zip(true_grads, pgrads): + for true_grad, pgrad in zip(true_grads, pgrads, strict=True): assert np.sum(np.abs(true_grad - pgrad)) < 0.00001 diff --git a/tests/test_ifelse.py b/tests/test_ifelse.py index d506d96df6..5ca7de6e63 100644 --- a/tests/test_ifelse.py +++ b/tests/test_ifelse.py @@ -234,14 +234,14 @@ def test_multiple_out_grad(self): np.asarray(rng.uniform(size=(l,)), pytensor.config.floatX) for l in lens ] outs_1 = f(1, *values) - assert all(x.shape[0] == y for x, y in zip(outs_1, lens)) + assert all(x.shape[0] == y for x, y in zip(outs_1, lens, strict=True)) assert np.all(outs_1[0] == 1.0) assert np.all(outs_1[1] == 1.0) assert np.all(outs_1[2] == 0.0) assert np.all(outs_1[3] == 0.0) outs_0 = f(0, *values) - assert all(x.shape[0] == y for x, y in zip(outs_1, lens)) + assert all(x.shape[0] == y for x, y in zip(outs_1, lens, strict=True)) assert np.all(outs_0[0] == 0.0) assert np.all(outs_0[1] == 0.0) assert np.all(outs_0[2] == 1.0) diff --git a/tests/test_printing.py b/tests/test_printing.py index d5b0707442..73403880e9 100644 --- a/tests/test_printing.py +++ b/tests/test_printing.py @@ -385,7 +385,7 @@ def test_debugprint_inner_graph(): └─ *1- [id F] """ - for exp_line, res_line in zip(exp_res.split("\n"), lines): + for exp_line, res_line in zip(exp_res.split("\n"), lines, strict=True): assert exp_line.strip() == res_line.strip() # Test nested inner-graph `Op`s @@ -413,7 +413,7 @@ def test_debugprint_inner_graph(): └─ *1- [id E] """ - for exp_line, res_line in zip(exp_res.split("\n"), lines): + for exp_line, res_line in zip(exp_res.split("\n"), lines, strict=True): assert exp_line.strip() == res_line.strip() diff --git a/tests/typed_list/test_basic.py b/tests/typed_list/test_basic.py index 4b309c2324..466bdc865d 100644 --- a/tests/typed_list/test_basic.py +++ b/tests/typed_list/test_basic.py @@ -587,7 +587,7 @@ def test_correct_answer(self): z = make_list((x, y)) fc = pytensor.function([a, b], c) fz = pytensor.function([x, y], z) - for m, n in zip(fc(A, B), [A, B]): + for m, n in zip(fc(A, B), [A, B], strict=True): assert (m == n).all() - for m, n in zip(fz(X, Y), [X, Y]): + for m, n in zip(fz(X, Y), [X, Y], strict=True): assert (m == n).all() diff --git a/tests/unittest_tools.py b/tests/unittest_tools.py index a556e3a275..476d430048 100644 --- a/tests/unittest_tools.py +++ b/tests/unittest_tools.py @@ -216,7 +216,7 @@ def _compile_and_check( if excluding: mode = mode.excluding(*excluding) if warn: - for var, inp in zip(inputs, numeric_inputs): + for var, inp in zip(inputs, numeric_inputs, strict=True): if isinstance(inp, int | float | list | tuple): inp = var.type.filter(inp) if not hasattr(inp, "shape"): @@ -261,7 +261,7 @@ def 
_compile_and_check( # Check that the shape produced agrees with the actual shape. numeric_outputs = outputs_function(*numeric_inputs) numeric_shapes = shapes_function(*numeric_inputs) - for out, shape in zip(numeric_outputs, numeric_shapes): + for out, shape in zip(numeric_outputs, numeric_shapes, strict=True): assert np.all(out.shape == shape), (out.shape, shape) From a30e7e845330e4bbf5f8d649a95587dff45ec23e Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Mon, 24 Jun 2024 10:17:38 -0400 Subject: [PATCH 02/12] Enable the ruff rule ensuring explicit strictness for zips --- pyproject.toml | 6 +++++- pytensor/link/c/op.py | 2 +- pytensor/link/pytorch/dispatch/shape.py | 2 +- pytensor/tensor/pad.py | 4 +++- pytensor/tensor/rewriting/ofg.py | 8 ++++++-- 5 files changed, 16 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 81a1285da8..6ad61f4a41 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -125,8 +125,12 @@ line-length = 88 exclude = ["doc/", "pytensor/_version.py"] [tool.ruff.lint] -select = ["C", "E", "F", "I", "UP", "W", "RUF", "PERF", "PTH", "ISC"] +select = ["B905", "C", "E", "F", "I", "UP", "W", "RUF", "PERF", "PTH", "ISC"] ignore = ["C408", "C901", "E501", "E741", "RUF012", "PERF203", "ISC001"] +unfixable = [ + # zip-strict: the auto-fix adds `strict=False` but we might want `strict=True` instead + "B905", +] [tool.ruff.lint.isort] diff --git a/pytensor/link/c/op.py b/pytensor/link/c/op.py index bc446556c0..74905d686f 100644 --- a/pytensor/link/c/op.py +++ b/pytensor/link/c/op.py @@ -352,7 +352,7 @@ def load_c_code(self, func_files: Iterable[Path]) -> None: "be used at the same time." ) - for func_file, code in zip(func_files, self.func_codes): + for func_file, code in zip(func_files, self.func_codes, strict=True): if self.backward_re.search(code): # This is backward compat code that will go away in a while diff --git a/pytensor/link/pytorch/dispatch/shape.py b/pytensor/link/pytorch/dispatch/shape.py index 7633e28e01..12d938a40f 100644 --- a/pytensor/link/pytorch/dispatch/shape.py +++ b/pytensor/link/pytorch/dispatch/shape.py @@ -34,7 +34,7 @@ def shape_i(x): def pytorch_funcify_SpecifyShape(op, node, **kwargs): def specifyshape(x, *shape): assert x.ndim == len(shape) - for actual, expected in zip(x.shape, shape): + for actual, expected in zip(x.shape, shape, strict=True): if expected is None: continue if actual != expected: diff --git a/pytensor/tensor/pad.py b/pytensor/tensor/pad.py index 91aef44004..2a3b8b4588 100644 --- a/pytensor/tensor/pad.py +++ b/pytensor/tensor/pad.py @@ -263,7 +263,9 @@ def _linear_ramp_pad( dtype=padded.dtype, axis=axis, ) - for end_value, edge, width in zip(end_value_pair, edge_pair, width_pair) + for end_value, edge, width in zip( + end_value_pair, edge_pair, width_pair, strict=True + ) ) # Reverse the direction of the ramp for the "right" side diff --git a/pytensor/tensor/rewriting/ofg.py b/pytensor/tensor/rewriting/ofg.py index 265f3ff2e8..978cbd03bb 100644 --- a/pytensor/tensor/rewriting/ofg.py +++ b/pytensor/tensor/rewriting/ofg.py @@ -18,7 +18,9 @@ def inline_ofg_expansion(fgraph, node): if not op.is_inline: return False - new_out = clone_replace(op.inner_outputs, dict(zip(op.inner_inputs, node.inputs))) + new_out = clone_replace( + op.inner_outputs, dict(zip(op.inner_inputs, node.inputs, strict=True)) + ) copy_stack_trace(op.inner_outputs, new_out) return new_out @@ -62,7 +64,9 @@ def late_inline_OpFromGraph(fgraph, node): """ op = node.op - new_out = clone_replace(op.inner_outputs, dict(zip(op.inner_inputs, 
node.inputs))) + new_out = clone_replace( + op.inner_outputs, dict(zip(op.inner_inputs, node.inputs, strict=True)) + ) copy_stack_trace(op.inner_outputs, new_out) return new_out From 9ac53248c4a665e41abd912c825af168d1043460 Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Sat, 29 Jun 2024 02:30:32 -0400 Subject: [PATCH 03/12] Make non-strict zips strict in tests/scan --- tests/scan/test_printing.py | 35 +++++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/tests/scan/test_printing.py b/tests/scan/test_printing.py index 80fe808f59..451333b207 100644 --- a/tests/scan/test_printing.py +++ b/tests/scan/test_printing.py @@ -62,9 +62,10 @@ def test_debugprint_sitsot(): Scan{scan_fn, while_loop=False, inplace=none} [id C] ← Mul [id W] (inner_out_sit_sot-0) ├─ *0- [id X] -> [id E] (inner_in_sit_sot-0) - └─ *1- [id Y] -> [id M] (inner_in_non_seqs-0)""" + └─ *1- [id Y] -> [id M] (inner_in_non_seqs-0) + """ - for truth, out in zip(expected_output.split("\n"), lines, strict=False): + for truth, out in zip(expected_output.split("\n"), lines, strict=True): assert truth.strip() == out.strip() @@ -120,9 +121,10 @@ def test_debugprint_sitsot_no_extra_info(): Scan{scan_fn, while_loop=False, inplace=none} [id C] ← Mul [id W] ├─ *0- [id X] -> [id E] - └─ *1- [id Y] -> [id M]""" + └─ *1- [id Y] -> [id M] + """ - for truth, out in zip(expected_output.split("\n"), lines, strict=False): + for truth, out in zip(expected_output.split("\n"), lines, strict=True): assert truth.strip() == out.strip() @@ -188,9 +190,10 @@ def test_debugprint_nitsot(): ├─ *0- [id Y] -> [id S] (inner_in_seqs-0) └─ Pow [id Z] ├─ *2- [id BA] -> [id W] (inner_in_non_seqs-0) - └─ *1- [id BB] -> [id U] (inner_in_seqs-1)""" + └─ *1- [id BB] -> [id U] (inner_in_seqs-1) + """ - for truth, out in zip(expected_output.split("\n"), lines, strict=False): + for truth, out in zip(expected_output.split("\n"), lines, strict=True): assert truth.strip() == out.strip() @@ -303,9 +306,10 @@ def compute_A_k(A, k): Scan{scan_fn, while_loop=False, inplace=none} [id BE] ← Mul [id CA] (inner_out_sit_sot-0) ├─ *0- [id CB] -> [id BG] (inner_in_sit_sot-0) - └─ *1- [id CC] -> [id BO] (inner_in_non_seqs-0)""" + └─ *1- [id CC] -> [id BO] (inner_in_non_seqs-0) + """ - for truth, out in zip(expected_output.split("\n"), lines, strict=False): + for truth, out in zip(expected_output.split("\n"), lines, strict=True): assert truth.strip() == out.strip() fg = FunctionGraph([c, k, A], [final_result]) @@ -402,9 +406,10 @@ def compute_A_k(A, k): → *1- [id CB] -> [id BA] (inner_in_non_seqs-0) ← Mul [id CC] (inner_out_sit_sot-0) ├─ *0- [id CA] (inner_in_sit_sot-0) - └─ *1- [id CB] (inner_in_non_seqs-0)""" + └─ *1- [id CB] (inner_in_non_seqs-0) + """ - for truth, out in zip(expected_output.split("\n"), lines, strict=False): + for truth, out in zip(expected_output.split("\n"), lines, strict=True): assert truth.strip() == out.strip() @@ -477,9 +482,10 @@ def fn(a_m2, a_m1, b_m2, b_m1): └─ *0- [id BD] -> [id E] (inner_in_mit_sot-0-0) ← Add [id BE] (inner_out_mit_sot-1) ├─ *3- [id BF] -> [id O] (inner_in_mit_sot-1-1) - └─ *2- [id BG] -> [id O] (inner_in_mit_sot-1-0)""" + └─ *2- [id BG] -> [id O] (inner_in_mit_sot-1-0) + """ - for truth, out in zip(expected_output.split("\n"), lines, strict=False): + for truth, out in zip(expected_output.split("\n"), lines, strict=True): assert truth.strip() == out.strip() @@ -613,9 +619,10 @@ def test_debugprint_mitmot(): Scan{scan_fn, while_loop=False, inplace=none} [id F] ← Mul [id CV] 
(inner_out_sit_sot-0) ├─ *0- [id CT] -> [id H] (inner_in_sit_sot-0) - └─ *1- [id CW] -> [id P] (inner_in_non_seqs-0)""" + └─ *1- [id CW] -> [id P] (inner_in_non_seqs-0) + """ - for truth, out in zip(expected_output.split("\n"), lines, strict=False): + for truth, out in zip(expected_output.split("\n"), lines, strict=True): assert truth.strip() == out.strip() From 0d33b3721c8b15d2dccaae0c77789575c13e744e Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Sat, 29 Jun 2024 02:53:09 -0400 Subject: [PATCH 04/12] Make non-strict zips strict in tensor/elemwise_cgen --- pytensor/tensor/elemwise_cgen.py | 38 +++++++++++++------------------- 1 file changed, 15 insertions(+), 23 deletions(-) diff --git a/pytensor/tensor/elemwise_cgen.py b/pytensor/tensor/elemwise_cgen.py index e70bb936eb..3593da6ed4 100644 --- a/pytensor/tensor/elemwise_cgen.py +++ b/pytensor/tensor/elemwise_cgen.py @@ -208,7 +208,13 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"): """ -def make_loop(loop_orders, dtypes, loop_tasks, sub, openmp=None): +def make_loop( + loop_orders: list[tuple[int | str, ...]], + dtypes: list, + loop_tasks: list, + sub: dict[str, str], + openmp: bool = False, +): """ Make a nested loop over several arrays and associate specific code to each level of nesting. @@ -226,7 +232,7 @@ def make_loop(loop_orders, dtypes, loop_tasks, sub, openmp=None): string is code to be executed before the ith loop starts, the second one contains code to be executed just before going to the next element of the ith dimension. - The last element if loop_tasks is a single string, containing code + The last element of loop_tasks is a single string, containing code to be executed at the very end. sub : dictionary Maps 'lv#' to a suitable variable name. @@ -259,7 +265,7 @@ def loop_over(preloop, code, indices, i): }} """ - preloops = {} + preloops: dict[int, str] = {} for i, (loop_order, dtype) in enumerate(zip(loop_orders, dtypes, strict=True)): for j, index in enumerate(loop_order): if index != "x": @@ -276,16 +282,8 @@ def loop_over(preloop, code, indices, i): s = "" - for i, (pre_task, task), indices in reversed( - list( - zip( - range(len(loop_tasks) - 1), - loop_tasks, - list(zip(*loop_orders, strict=True)), - strict=False, - ) - ) - ): + tasks_indices = zip(loop_tasks[:-1], zip(*loop_orders, strict=True), strict=True) + for i, ((pre_task, task), indices) in reversed(list(enumerate(tasks_indices))): s = loop_over(preloops.get(i, "") + pre_task, s + task, indices, i) s += loop_tasks[-1] @@ -549,16 +547,10 @@ def loop_over(preloop, code, indices, i): s = preloops.get(0, "") else: s = "" - for i, (pre_task, task), indices in reversed( - list( - zip( - range(len(loop_tasks) - 1), - loop_tasks, - list(zip(*loop_orders, strict=True)), - strict=False, - ) - ) - ): + tasks_indices = zip( + loop_tasks[:-1], zip(*loop_orders, strict=True), strict=True + ) + for i, ((pre_task, task), indices) in reversed(list(enumerate(tasks_indices))): s = loop_over(preloops.get(i, "") + pre_task, s + task, indices, i) s += loop_tasks[-1] From 2cac558ff9c8206c3b47efe1d285a6364434c7f4 Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Sat, 29 Jun 2024 03:24:48 -0400 Subject: [PATCH 05/12] Make non-strict zip strict in scalar/loop.py --- pytensor/scalar/loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/scalar/loop.py b/pytensor/scalar/loop.py index 4c76fa7140..a55c6f3e05 100644 --- a/pytensor/scalar/loop.py +++ b/pytensor/scalar/loop.py @@ -93,7 +93,7 @@ def _validate_updates( ) else: update = outputs - for 
i, u in zip(init, update, strict=False): + for i, u in zip(init[: len(update)], update, strict=True): if i.type != u.type: raise TypeError( "Init and update types must be the same: " From 4ebad14121d73293e61c735631aaa314d10b7672 Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Sat, 29 Jun 2024 03:29:55 -0400 Subject: [PATCH 06/12] Make non-strict zip strict in printing.py --- pytensor/printing.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/pytensor/printing.py b/pytensor/printing.py index a974ca21bc..1fc73fb2bc 100644 --- a/pytensor/printing.py +++ b/pytensor/printing.py @@ -930,7 +930,7 @@ def process(self, output, pstate): ) idx = node.outputs.index(output) pattern, precedences = self.patterns[idx] - precedences += (1000,) * len(node.inputs) + precedences += (1000,) * (len(node.inputs) - len(precedences)) def pp_process(input, new_precedence): with set_precedence(pstate, new_precedence): @@ -938,10 +938,9 @@ def pp_process(input, new_precedence): return r d = { - str(i): x - for i, x in enumerate( - pp_process(input, precedence) - for input, precedence in zip(node.inputs, precedences, strict=False) + str(i): pp_process(input, precedence) + for i, (input, precedence) in enumerate( + zip(node.inputs, precedences, strict=True) ) } r = pattern % d From fed69b928962d3c88096610f2e746516833587f1 Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Wed, 3 Jul 2024 05:49:56 -0400 Subject: [PATCH 07/12] Make non-strict zip strict in test_abstract_conv --- tests/tensor/conv/test_abstract_conv.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/tensor/conv/test_abstract_conv.py b/tests/tensor/conv/test_abstract_conv.py index fbd45ffdaa..23ba23e1e9 100644 --- a/tests/tensor/conv/test_abstract_conv.py +++ b/tests/tensor/conv/test_abstract_conv.py @@ -1745,7 +1745,7 @@ def setup_method(self): self.random_stream = np.random.default_rng(utt.fetch_seed()) self.inputs_shapes = [(8, 1, 12, 12), (1, 1, 5, 5), (1, 1, 5, 6), (1, 1, 6, 6)] - self.filters_shapes = [(5, 1, 2, 2), (1, 1, 3, 3)] + self.filters_shapes = [(5, 1, 2, 2), (1, 1, 3, 3)] * 2 self.subsamples = [(1, 1), (2, 2)] self.border_modes = ["valid", "full"] @@ -1764,7 +1764,7 @@ def test_conv2d_grad_wrt_inputs(self): # results are the same. for in_shape, fltr_shape in zip( - self.inputs_shapes, self.filters_shapes, strict=False + self.inputs_shapes, self.filters_shapes, strict=True ): for bm in self.border_modes: for ss in self.subsamples: @@ -1828,7 +1828,7 @@ def test_conv2d_grad_wrt_weights(self): # results are the same. 
for in_shape, fltr_shape in zip( - self.inputs_shapes, self.filters_shapes, strict=False + self.inputs_shapes, self.filters_shapes, strict=True ): for bm in self.border_modes: for ss in self.subsamples: From f674b3056f56d984a56c5248acead17087c4168c Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Wed, 3 Jul 2024 06:20:52 -0400 Subject: [PATCH 08/12] Rewrite local_merge_alloc to remove a non-strict zip --- pytensor/tensor/rewriting/basic.py | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index 1365f0b767..34400a91d5 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -1196,25 +1196,23 @@ def local_merge_alloc(fgraph, node): inputs_inner = node.inputs[0].owner.inputs dims_outer = inputs_outer[1:] dims_inner = inputs_inner[1:] - dims_outer_rev = dims_outer[::-1] - dims_inner_rev = dims_inner[::-1] + assert len(dims_inner) <= len(dims_outer) # check if the pattern of broadcasting is matched, in the reversed ordering. # The reverse ordering is needed when an Alloc add an implicit new # broadcasted dimensions to its inputs[0]. Eg: # Alloc(Alloc(m, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w) - i = 0 - for dim_inner, dim_outer in zip(dims_inner_rev, dims_outer_rev, strict=False): - if dim_inner != dim_outer: - if isinstance(dim_inner, Constant) and dim_inner.data == 1: - pass - else: - dims_outer[-1 - i] = Assert( - "You have a shape error in your graph. To see a better" - " error message and a stack trace of where in your code" - " the error is created, use the PyTensor flags" - " optimizer=None or optimizer=fast_compile." - )(dim_outer, eq(dim_outer, dim_inner)) - i += 1 + for i, dim_inner in enumerate(reversed(dims_inner)): + dim_outer = dims_outer[-1 - i] + if dim_inner == dim_outer: + continue + if isinstance(dim_inner, Constant) and dim_inner.data == 1: + continue + dims_outer[-1 - i] = Assert( + "You have a shape error in your graph. To see a better" + " error message and a stack trace of where in your code" + " the error is created, use the PyTensor flags" + " optimizer=None or optimizer=fast_compile." 
+ )(dim_outer, eq(dim_outer, dim_inner)) return [alloc(inputs_inner[0], *dims_outer)] From 6b3902626f17b9185517e698d0d995bf0629df1d Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Wed, 3 Jul 2024 06:40:52 -0400 Subject: [PATCH 09/12] Make non-strict zip strict in tensor/random/utils --- pytensor/tensor/random/utils.py | 2 +- tests/tensor/random/test_op.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pytensor/tensor/random/utils.py b/pytensor/tensor/random/utils.py index e96fd779a5..1bdb936bdf 100644 --- a/pytensor/tensor/random/utils.py +++ b/pytensor/tensor/random/utils.py @@ -141,7 +141,7 @@ def explicit_expand_dims( batch_dims = [ param.type.ndim - ndim_param - for param, ndim_param in zip(params, ndim_params, strict=False) + for param, ndim_param in zip(params, ndim_params, strict=True) ] if size_length is not None: diff --git a/tests/tensor/random/test_op.py b/tests/tensor/random/test_op.py index 8e74b06bd4..edec9a4389 100644 --- a/tests/tensor/random/test_op.py +++ b/tests/tensor/random/test_op.py @@ -74,16 +74,16 @@ def test_RandomVariable_basics(strict_test_value_flags): # `dtype` is respected rv = RandomVariable("normal", signature="(),()->()", dtype="int32") with config.change_flags(compute_test_value="off"): - rv_out = rv() + rv_out = rv(0, 0) assert rv_out.dtype == "int32" - rv_out = rv(dtype="int64") + rv_out = rv(0, 0, dtype="int64") assert rv_out.dtype == "int64" with pytest.raises( ValueError, match="Cannot change the dtype of a normal RV from int32 to float32", ): - assert rv(dtype="float32").dtype == "float32" + assert rv(0, 0, dtype="float32").dtype == "float32" def test_RandomVariable_bcast(strict_test_value_flags): From 0de617fe30e485da0d025c66a6fea1f9cbc4c7ca Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Wed, 3 Jul 2024 06:50:32 -0400 Subject: [PATCH 10/12] Make non-strict zip strict in local_subtensor_of_alloc --- pytensor/tensor/rewriting/subtensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py index 4a0ccbc487..30e90b05a9 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -648,7 +648,7 @@ def local_subtensor_of_alloc(fgraph, node): # Slices to take from val val_slices = [] - for i, (sl, dim) in enumerate(zip(slices, dims, strict=False)): + for i, (sl, dim) in enumerate(zip(slices, dims[: len(slices)], strict=True)): # If val was not copied over that dim, # we need to take the appropriate subtensor on it. 
if i >= n_added_dims: From e6069da5aafe2a2bffff4203c152494f364495fc Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Sun, 7 Jul 2024 03:41:14 -0400 Subject: [PATCH 11/12] Make non-strict zip strict in tensor/subtensor.py --- pytensor/tensor/subtensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index 3a684d2c07..76077fd62a 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -523,7 +523,7 @@ def basic_shape(shape, indices): """ res_shape = () - for idx, n in zip(indices, shape, strict=False): + for n, idx in zip(shape[: len(indices)], indices, strict=True): if isinstance(idx, slice): res_shape += (slice_len(idx, n),) elif isinstance(getattr(idx, "type", None), SliceType): From c58bf33e5f65299d3644b6ebb055912c6cd7eb16 Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Sun, 7 Jul 2024 04:03:43 -0400 Subject: [PATCH 12/12] Make non-strict zip strict in tensor/shape.py --- pytensor/tensor/shape.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py index 21a5519466..a4ef904ad1 100644 --- a/pytensor/tensor/shape.py +++ b/pytensor/tensor/shape.py @@ -589,11 +589,15 @@ def specify_shape( x = ptb.as_tensor_variable(x) # type: ignore[arg-type,unused-ignore] # The above is a type error in Python 3.9 but not 3.12. # Thus we need to ignore unused-ignore on 3.12. - new_shape_info = any( - s != xts for (s, xts) in zip(shape, x.type.shape, strict=False) if s is not None - ) + # If shape does not match x.ndim, we rely on the `Op` to raise a ValueError - if not new_shape_info and len(shape) == x.type.ndim: + if len(shape) != x.type.ndim: + return _specify_shape(x, *shape) + + new_shape_matches = all( + s == xts for (s, xts) in zip(shape, x.type.shape, strict=True) if s is not None + ) + if new_shape_matches: return x return _specify_shape(x, *shape)
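
For reference, the snippet below is a minimal standalone sketch, not part of the patch series, and the variable names are illustrative only. It shows the PEP 618 semantics the whole series standardizes on (Python >= 3.10): with `strict=True` a length mismatch between the zipped iterables raises `ValueError` instead of being silently truncated, and intentional truncation is expressed by slicing explicitly, as in the `scalar/loop.py` and `tensor/subtensor.py` hunks above.

    # Illustration only (hypothetical names); requires Python >= 3.10 for strict=.
    params = ["loc", "scale"]
    ndim_params = [0]  # one entry short, e.g. a mismatched signature

    # Default zip: the mismatch is silently swallowed and only one pair comes out.
    assert list(zip(params, ndim_params)) == [("loc", 0)]

    # strict=True surfaces the same mismatch as an error at the zip call site.
    try:
        list(zip(params, ndim_params, strict=True))
    except ValueError as err:
        print(f"caught: {err}")  # zip() argument 2 is shorter than argument 1

    # When truncation really is intended, slice explicitly and keep strict=True,
    # mirroring the basic_shape and ScalarLoop hunks above.
    shape = (3, 4, 5)
    indices = (slice(None),)  # fewer indices than dimensions is legitimate here
    assert list(zip(shape[: len(indices)], indices, strict=True)) == [(3, slice(None))]

This is also the reason `B905` is listed under `unfixable` in `pyproject.toml`: the ruff autofix would insert `strict=False` at every call site, whereas the series reviews each `zip` and uses `strict=True` unless truncation is deliberate.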