Skip to content

Commit

Permalink
Make non-strict zip strict in tensor/shape.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Armavica committed Jul 23, 2024
1 parent e6069da commit c58bf33
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions pytensor/tensor/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,11 +589,15 @@ def specify_shape(
x = ptb.as_tensor_variable(x) # type: ignore[arg-type,unused-ignore]
# The above is a type error in Python 3.9 but not 3.12.
# Thus we need to ignore unused-ignore on 3.12.
new_shape_info = any(
s != xts for (s, xts) in zip(shape, x.type.shape, strict=False) if s is not None
)

# If shape does not match x.ndim, we rely on the `Op` to raise a ValueError
if not new_shape_info and len(shape) == x.type.ndim:
if len(shape) != x.type.ndim:
return _specify_shape(x, *shape)

new_shape_matches = all(
s == xts for (s, xts) in zip(shape, x.type.shape, strict=True) if s is not None
)
if new_shape_matches:
return x

return _specify_shape(x, *shape)
Expand Down

0 comments on commit c58bf33

Please sign in to comment.