diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index 3a684d2c07..76077fd62a 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -523,7 +523,7 @@ def basic_shape(shape, indices): """ res_shape = () - for idx, n in zip(indices, shape, strict=False): + for n, idx in zip(shape[: len(indices)], indices, strict=True): if isinstance(idx, slice): res_shape += (slice_len(idx, n),) elif isinstance(getattr(idx, "type", None), SliceType):