Skip to content

Commit

Permalink
Implement wrap padding
Browse files Browse the repository at this point in the history
  • Loading branch information
jessegrabowski committed May 11, 2024
1 parent d3700de commit d193afa
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 6 deletions.
35 changes: 33 additions & 2 deletions pytensor/tensor/pad.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
from collections.abc import Callable
from typing import Literal

from pytensor.scan import scan
from pytensor.tensor import TensorLike
from pytensor.tensor.basic import TensorVariable, as_tensor, zeros
from pytensor.tensor.basic import (
TensorVariable,
as_tensor,
moveaxis,
zeros,
)
from pytensor.tensor.extra_ops import broadcast_to, linspace
from pytensor.tensor.math import divmod as pt_divmod
from pytensor.tensor.math import max as pt_max
from pytensor.tensor.math import mean, minimum
from pytensor.tensor.math import min as pt_min
Expand All @@ -12,7 +19,7 @@


PadMode = Literal[
"constant", "edge", "linear_ramp", "maximum", "minimum", "mean", "median"
"constant", "edge", "linear_ramp", "maximum", "minimum", "mean", "median", "wrap"
]
stat_funcs = {"maximum": pt_max, "minimum": pt_min, "mean": mean}

Expand Down Expand Up @@ -265,6 +272,28 @@ def _linear_ramp_pad(
return padded


def _wrap_pad(x: TensorVariable, pad_width: TensorVariable) -> TensorVariable:
    """Pad ``x`` by tiling it with copies of itself (numpy's ``mode="wrap"``).

    For each axis, enough full copies of ``x`` are concatenated along that
    axis to cover the requested left/right padding, then the excess is
    trimmed from both ends so exactly ``pad_width[axis]`` padded entries
    remain on each side.

    Parameters
    ----------
    x : TensorVariable
        The tensor to pad.
    pad_width : TensorVariable
        Per-axis (left, right) pad amounts; broadcast below to shape
        ``(x.ndim, 2)`` so a scalar or 1-d input applies to every axis.

    Returns
    -------
    TensorVariable
        The wrap-padded tensor.
    """
    # Normalize pad_width to one (left, right) pair per axis.
    pad_width = broadcast_to(pad_width, as_tensor((x.ndim, 2)))

    for axis in range(x.ndim):
        size = x.shape[axis]
        # How many *full* copies of x each side needs (quotient), plus the
        # partial copy each side needs (remainder). divmod over the length-2
        # row gives vectors; the remainders unpack into left/right scalars.
        repeats, (left_remainder, right_remainder) = pt_divmod(pad_width[axis], size)

        # Amount to slice off the outermost copy on each side so only
        # `remainder` padded entries survive there. Since remainder < size,
        # right_trim >= 1, so -right_trim below is a valid negative stop.
        left_trim = size - left_remainder
        right_trim = size - right_remainder
        total_repeats = repeats.sum() + 3  # left, right, center

        # Stack `total_repeats` identical copies of x along a new leading
        # axis; scan is used because total_repeats is symbolic.
        parts, _ = scan(lambda x: x, non_sequences=[x], n_steps=total_repeats)

        # Merge the copies into the current axis: move the stack axis next
        # to `axis`, then collapse the two with a -1 reshape.
        parts = moveaxis(parts, 0, axis)
        new_shape = [-1 if i == axis else x.shape[i] for i in range(x.ndim)]
        x = parts.reshape(new_shape)
        # Cut the surplus from both ends of this axis only.
        trim_slice = _slice_at_axis(slice(left_trim, -right_trim), axis)
        x = x[trim_slice]

    return x


def pad(x: TensorLike, pad_width: TensorLike, mode: PadMode = "constant", **kwargs):
allowed_kwargs = {
"edge": [],
Expand Down Expand Up @@ -300,6 +329,8 @@ def pad(x: TensorLike, pad_width: TensorLike, mode: PadMode = "constant", **kwar
elif mode == "linear_ramp":
end_values = kwargs.pop("end_values", 0)
return _linear_ramp_pad(x, pad_width, end_values)
elif mode == "wrap":
return _wrap_pad(x, pad_width)


__all__ = ["pad"]
26 changes: 22 additions & 4 deletions tests/tensor/test_pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from pytensor.tensor.pad import PadMode, pad


floatX = pytensor.config.floatX


@pytest.mark.parametrize(
"size", [(3,), (3, 3), (3, 3, 3)], ids=["1d", "2d square", "3d square"]
)
Expand All @@ -13,7 +16,7 @@
def test_constant_pad(
size: tuple, constant: int | float, pad_width: int | tuple[int, ...]
):
x = np.random.normal(size=size)
x = np.random.normal(size=size).astype(floatX)
expected = np.pad(x, pad_width, mode="constant", constant_values=constant)
z = pad(x, pad_width, mode="constant", constant_values=constant)
f = pytensor.function([], z, mode="FAST_COMPILE")
Expand All @@ -28,7 +31,7 @@ def test_constant_pad(
"pad_width", [1, (1, 2)], ids=["symmetrical", "asymmetrical_1d"]
)
def test_edge_pad(size: tuple, pad_width: int | tuple[int, ...]):
x = np.random.normal(size=size)
x = np.random.normal(size=size).astype(floatX)
expected = np.pad(x, pad_width, mode="edge")
z = pad(x, pad_width, mode="edge")
f = pytensor.function([], z, mode="FAST_COMPILE")
Expand All @@ -48,7 +51,7 @@ def test_linear_ramp_pad(
pad_width: int | tuple[int, ...],
end_values: int | float | tuple[int | float, ...],
):
x = np.random.normal(size=size)
x = np.random.normal(size=size).astype(floatX)
expected = np.pad(x, pad_width, mode="linear_ramp", end_values=end_values)
z = pad(x, pad_width, mode="linear_ramp", end_values=end_values)
f = pytensor.function([], z, mode="FAST_COMPILE")
Expand All @@ -70,9 +73,24 @@ def test_stat_pad(
stat: PadMode,
stat_length: int | None,
):
x = np.random.normal(size=size)
x = np.random.normal(size=size).astype(floatX)
expected = np.pad(x, pad_width, mode=stat, stat_length=stat_length)
z = pad(x, pad_width, mode=stat, stat_length=stat_length)
f = pytensor.function([], z, mode="FAST_COMPILE")

np.testing.assert_allclose(expected, f())


@pytest.mark.parametrize(
    "size", [(3,), (3, 3), (3, 5, 5)], ids=["1d", "2d square", "3d square"]
)
@pytest.mark.parametrize(
    "pad_width", [1, (1, 2)], ids=["symmetrical", "asymmetrical_1d"]
)
def test_wrap_pad(size: tuple, pad_width: int | tuple[int, ...]):
    """Wrap-mode padding agrees with ``np.pad(..., mode="wrap")``."""
    data = np.random.normal(size=size).astype(floatX)

    numpy_result = np.pad(data, pad_width, mode="wrap")
    padded = pad(data, pad_width, mode="wrap")
    fn = pytensor.function([], padded, mode="FAST_COMPILE")

    np.testing.assert_allclose(numpy_result, fn())

0 comments on commit d193afa

Please sign in to comment.