Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement pad #748

Merged
merged 38 commits into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
91311d5
Refactor linspace, logspace, and geomspace to match numpy implementation
jessegrabowski May 11, 2024
a98b8ae
Add `pt.pad`
jessegrabowski May 4, 2024
02566b6
Use subclassed `OpFromGraph` to represent `pad` Op
jessegrabowski May 18, 2024
32b14d2
Add test for `flip`
jessegrabowski May 18, 2024
3e65827
Address reviewer feedback
jessegrabowski May 18, 2024
c2b8465
Remove `inplace` argument to `set_subtensor`
jessegrabowski May 18, 2024
8f213a3
Delay setting dtype of `xspace` Ops until after all computation to ma…
jessegrabowski May 18, 2024
be6ed82
Use `shape_padright` instead of `.reshape` tricks
jessegrabowski May 18, 2024
9c76a8f
Add test for `dtype` kwarg on `xspace` Ops
jessegrabowski May 18, 2024
eeb9fa3
Save keyword arguments in `Pad` `OpFromGraph`
jessegrabowski May 18, 2024
c28faaa
Add test for arbitrary padding at higher dimensions
jessegrabowski May 18, 2024
ab99a1e
First draft JAX overload
jessegrabowski May 18, 2024
e93fa56
Expect symbolic `num` argument
jessegrabowski May 18, 2024
2d55a06
Split `wrap_pad` into separate function; eliminate use of `scan`
jessegrabowski Jul 7, 2024
48037c6
<<DO NOT MERGE>> testing notebook
jessegrabowski Jul 8, 2024
b4ffdc5
Add `reflect` and `symmetric` padding
jessegrabowski Jul 12, 2024
6808648
Remove test notebook
jessegrabowski Jul 12, 2024
0f9ca38
Correct reflect and symmetric implementations
jessegrabowski Jul 13, 2024
c1dd0bc
Fix docs, JAX test
jessegrabowski Jul 13, 2024
3b34779
Remove `_broadcast_inputs` helper, update docstrings
jessegrabowski Jul 13, 2024
4cd9702
Remove `OpFromGraph` and associated `JAX` dispatch
jessegrabowski Jul 13, 2024
dbda326
Revert "Remove `OpFromGraph` and associated `JAX` dispatch"
jessegrabowski Jul 13, 2024
32ae3eb
Add issue link to `reflect_type` error message
jessegrabowski Jul 13, 2024
d543ed6
Move `flip` to `tensor/subtensor.py`, add docstring
jessegrabowski Jul 13, 2024
ba6c613
Move `slice_at_axis` to `tensor/subtensor` and expose it in `pytensor…
jessegrabowski Jul 13, 2024
aa95403
Appease mypy
jessegrabowski Jul 13, 2024
2a11ffa
Appease mypy, add docstring to `pad`
jessegrabowski Jul 13, 2024
89f8cdf
Appease mypy, add docstring to `pad`
jessegrabowski Jul 13, 2024
0f5e2ba
Fix doctests
jessegrabowski Jul 13, 2024
74e39b2
Fix doctests
jessegrabowski Jul 13, 2024
d169928
Fix doctests
jessegrabowski Jul 13, 2024
e429bab
Propagate all optional arguments to JAX
jessegrabowski Jul 13, 2024
8722bfe
Propagate all optional arguments to JAX
jessegrabowski Jul 13, 2024
3f8a377
Appease mypy
jessegrabowski Jul 13, 2024
163b441
Test `NUMBA` backend
jessegrabowski Jul 14, 2024
2c9f727
I love mypy
jessegrabowski Jul 14, 2024
19ed2c0
Skip failing numba test
jessegrabowski Jul 14, 2024
bbeb300
Merge branch 'main' into pad
jessegrabowski Jul 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 13 additions & 44 deletions pytensor/tensor/pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from pytensor.tensor.math import max as pt_max
from pytensor.tensor.math import min as pt_min
from pytensor.tensor.shape import specify_broadcastable
from pytensor.tensor.subtensor import flip, set_subtensor
from pytensor.tensor.subtensor import flip, set_subtensor, slice_at_axis


PadMode = Literal[
Expand Down Expand Up @@ -52,37 +52,6 @@
}


def _slice_at_axis(sl: slice, axis: int) -> tuple[slice, ...]:
    """
    Construct tuple of slices to slice an array in the given dimension.

    Adapted from numpy.lib.arraypad._slice_at_axis
    https://github.com/numpy/numpy/blob/300096d384046eee479b0c7a70f79e308da52bff/numpy/lib/_arraypad_impl.py#L33

    Parameters
    ----------
    sl : slice
        The slice for the given dimension.
    axis : int
        The axis to which `sl` is applied. All other dimensions are left
        "unsliced". Negative values count backwards from the last dimension.

    Returns
    -------
    sl : tuple of slices
        A tuple containing `sl` at the position selected by `axis`. The
        dimensions between `sl` and the nearest end of the array are filled
        with full slices, and an ``Ellipsis`` stands in for all remaining
        dimensions.

    Examples
    --------

    .. code-block:: python

        _slice_at_axis(slice(None, 3, -1), 1)
        # (slice(None, None, None), slice(None, 3, -1), Ellipsis)

        _slice_at_axis(slice(None, 3, -1), -1)
        # (Ellipsis, slice(None, 3, -1))
    """
    if axis >= 0:
        return (slice(None),) * axis + (sl,) + (...,)  # type: ignore

    # Multiplying a tuple by a negative number yields an empty tuple, which
    # would silently slice axis 0. Instead, anchor the slice to the right:
    # axis = -1 needs zero trailing full slices, axis = -2 needs one, etc.
    n_trailing = -axis - 1
    return (...,) + (sl,) + (slice(None),) * n_trailing  # type: ignore


def _get_edges(
padded: TensorVariable, axis: int, width_pair: tuple[TensorVariable, TensorVariable]
) -> tuple[TensorVariable, TensorVariable]:
Expand Down Expand Up @@ -110,11 +79,11 @@
`axis` which will have a length of 1.
"""
left_index = width_pair[0]
left_slice = _slice_at_axis(slice(left_index, left_index + 1), axis)
left_slice = slice_at_axis(slice(left_index, left_index + 1), axis)
left_edge = padded[left_slice]

right_index = padded.shape[axis] - width_pair[1]
right_slice = _slice_at_axis(slice(right_index - 1, right_index), axis)
right_slice = slice_at_axis(slice(right_index - 1, right_index), axis)
right_edge = padded[right_slice]

return left_edge, right_edge
Expand All @@ -139,8 +108,8 @@
width_pair: tuple[TensorVariable, TensorVariable],
axis: int,
) -> tuple[tuple[slice, ...], tuple[slice, ...]]:
left_slice = _slice_at_axis(slice(None, width_pair[0]), axis)
right_slice = _slice_at_axis(slice(dim_shape - width_pair[1], None), axis)
left_slice = slice_at_axis(slice(None, width_pair[0]), axis)
right_slice = slice_at_axis(slice(dim_shape - width_pair[1], None), axis)

return left_slice, right_slice

Expand Down Expand Up @@ -224,20 +193,20 @@
left_length = (
minimum(left_length, max_length) if left_length is not None else max_length
)
left_slice = _slice_at_axis(slice(left_index, left_index + left_length), axis)
left_slice = slice_at_axis(slice(left_index, left_index + left_length), axis)
left_chunk = padded[left_slice]
left_stat = stat_func(left_chunk, axis=axis, keepdims=True)
if left_length is None and right_length is None:
# We could also return early in the more general case of left_length == right_length, but we don't necessarily
# know these shapes.
# TODO: Add rewrite to simplify in this case
return left_stat, left_stat

Check warning on line 203 in pytensor/tensor/pad.py

View check run for this annotation

Codecov / codecov/patch

pytensor/tensor/pad.py#L203

Added line #L203 was not covered by tests

# Calculate statistic for the right side
right_length = (
minimum(right_length, max_length) if right_length is not None else max_length
)
right_slice = _slice_at_axis(slice(right_index - right_length, right_index), axis)
right_slice = slice_at_axis(slice(right_index - right_length, right_index), axis)
right_chunk = padded[right_slice]
right_stat = stat_func(right_chunk, axis=axis, keepdims=True)

Expand Down Expand Up @@ -298,7 +267,7 @@
)

# Reverse the direction of the ramp for the "right" side
right_ramp = right_ramp[_slice_at_axis(slice(None, None, -1), axis)] # type: ignore
right_ramp = right_ramp[slice_at_axis(slice(None, None, -1), axis)] # type: ignore

padded = set_subtensor(padded[left_slice], left_ramp)
padded = set_subtensor(padded[right_slice], right_ramp)
Expand Down Expand Up @@ -336,7 +305,7 @@
x = parts.reshape(new_shape)

# Trim the excess on the active dimension
trim_slice = _slice_at_axis(slice(left_trim, -right_trim), axis)
trim_slice = slice_at_axis(slice(left_trim, -right_trim), axis)
x = x[trim_slice]

return x
Expand Down Expand Up @@ -393,7 +362,7 @@
)
right_trim = x.shape[axis] - right_trim

trim_slice = _slice_at_axis(slice(left_trim, right_trim), axis)
trim_slice = slice_at_axis(slice(left_trim, right_trim), axis)
x = x[trim_slice]

return x
Expand All @@ -407,7 +376,7 @@
for axis in range(x.ndim):
trimmed_size = x.shape[axis] - 1

trim_slice = _slice_at_axis(slice(None, -1), axis)
trim_slice = slice_at_axis(slice(None, -1), axis)
x_trimmed = x[trim_slice]
x_flipped = flip(x, axis=axis)[trim_slice]

Expand All @@ -429,8 +398,8 @@
inner_func=partial(_reflect_inner, padding_left=False),
)

left_trim = _slice_at_axis(slice(trimmed_size - remainders[0] - 1, -1), axis)
right_trim = _slice_at_axis(
left_trim = slice_at_axis(slice(trimmed_size - remainders[0] - 1, -1), axis)
right_trim = slice_at_axis(
slice(1, right_padding.shape[axis] - trimmed_size + remainders[1] + 1), axis
)

Expand All @@ -457,7 +426,7 @@

def pad(x: TensorLike, pad_width: TensorLike, mode: PadMode = "constant", **kwargs):
if any(value not in allowed_kwargs[mode] for value in kwargs.keys()):
raise ValueError(

Check warning on line 429 in pytensor/tensor/pad.py

View check run for this annotation

Codecov / codecov/patch

pytensor/tensor/pad.py#L429

Added line #L429 was not covered by tests
f"Invalid keyword arguments for mode '{mode}': {kwargs.keys()}"
)
x = as_tensor(x, name="x")
Expand All @@ -479,7 +448,7 @@
if mode == "median":
# TODO: Revisit this after we implement a quantile function.
# See https://github.com/pymc-devs/pytensor/issues/53
raise NotImplementedError("Median padding not implemented")

Check warning on line 451 in pytensor/tensor/pad.py

View check run for this annotation

Codecov / codecov/patch

pytensor/tensor/pad.py#L451

Added line #L451 was not covered by tests
stat_func = cast(Callable, stat_funcs[mode])
stat_length = kwargs.get("stat_length")
if stat_length is not None:
Expand Down
53 changes: 52 additions & 1 deletion pytensor/tensor/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3021,8 +3021,57 @@ def _get_vector_length_Subtensor(op, var):
raise ValueError(f"Length of {var} cannot be determined")


def slice_at_axis(sl: slice, axis: int) -> tuple[slice, ...]:
    """
    Construct tuple of slices to slice an array in the given dimension.

    Adapted from numpy.lib.arraypad._slice_at_axis
    https://github.com/numpy/numpy/blob/300096d384046eee479b0c7a70f79e308da52bff/numpy/lib/_arraypad_impl.py#L33

    Parameters
    ----------
    sl : slice
        The slice for the given dimension.
    axis : int
        The axis to which `sl` is applied. All other dimensions are left
        "unsliced". Negative values count backwards from the last dimension.

    Returns
    -------
    sl : tuple of slices
        A tuple containing `sl` at the position selected by `axis`. The
        dimensions between `sl` and the nearest end of the array are filled
        with full slices, and an ``Ellipsis`` stands in for all remaining
        dimensions.

    Examples
    --------

    .. code-block:: python

        import numpy as np

        import pytensor
        import pytensor.tensor as pt

        s = pt.slice_at_axis(slice(None, 1), 1)
        s
        # Output: (slice(None, None, None), slice(None, 1, None), Ellipsis)

        x = pt.tensor('x', shape=(None, None, None))
        x_sliced = x[s]

        f = pytensor.function([x], x_sliced)
        x_val = np.arange(27).reshape(3, 3, 3)
        f(x_val)
        # Output: array([[[ 0.,  1.,  2.]],
        #                [[ 9., 10., 11.]],
        #                [[18., 19., 20.]]])
    """
    if axis >= 0:
        # `axis` leading full slices, then `sl`; Ellipsis covers the rest
        return (slice(None),) * axis + (sl,) + (...,)  # type: ignore
    else:
        # If axis = -1 we want zero right padding (and so on), so subtract one
        axis = abs(axis) - 1
        return (...,) + (sl,) + (slice(None),) * axis  # type: ignore


def flip(
arr: TensorVariable, axis: int | tuple[int] | TensorVariable = None
arr: TensorVariable, axis: int | tuple[int] | TensorVariable | None = None
) -> TensorVariable:
"""
Reverse the order of elements in a tensor along the given axis.
Expand Down Expand Up @@ -3066,12 +3115,14 @@ def flip(
slice(None, None, -1) if i in axis else slice(None, None, None)
for i in range(arr.ndim)
]

return arr[index]


__all__ = [
"take",
"flip",
"slice_at_axis",
"inc_subtensor",
"set_subtensor",
]
11 changes: 11 additions & 0 deletions tests/tensor/test_subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
index_vars_to_types,
indexed_result_shape,
set_subtensor,
slice_at_axis,
take,
)
from pytensor.tensor.type import (
Expand Down Expand Up @@ -2903,6 +2904,16 @@ def test_vectorize_adv_subtensor(
)


def test_slice_at_axis():
    """Check the static shape obtained when indexing with `slice_at_axis`."""
    x = ptb.tensor("x", shape=(3, 4, 5))
    first_entry = slice(None, 1)

    # Non-negative axis: only the leading dimension is reduced to length 1
    sliced_leading = x[slice_at_axis(first_entry, axis=0)]
    assert sliced_leading.type.shape == (1, 4, 5)

    # Negative axis: -2 addresses the middle dimension of a 3d tensor
    sliced_middle = x[slice_at_axis(first_entry, axis=-2)]
    assert sliced_middle.type.shape == (3, 1, 5)


@pytest.mark.parametrize(
"size", [(3,), (3, 3), (3, 5, 5)], ids=["1d", "2d square", "3d square"]
)
Expand Down
Loading