Skip to content

Commit

Permalink
Implement pad (#748)
Browse files Browse the repository at this point in the history
* Add `pt.pad`

* Refactor linspace, logspace, and geomspace to match numpy implementation

* Add `pt.flip`

* Move `flip` to `tensor/subtensor.py`, add docstring

* Move `slice_at_axis` to `tensor/subtensor` and expose it in `pytensor.tensor`
  • Loading branch information
jessegrabowski authored Jul 19, 2024
1 parent f489cf4 commit 981688c
Show file tree
Hide file tree
Showing 11 changed files with 1,632 additions and 40 deletions.
1 change: 1 addition & 0 deletions pytensor/link/jax/dispatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytensor.link.jax.dispatch.blockwise
import pytensor.link.jax.dispatch.elemwise
import pytensor.link.jax.dispatch.extra_ops
import pytensor.link.jax.dispatch.math
import pytensor.link.jax.dispatch.nlinalg
import pytensor.link.jax.dispatch.pad
import pytensor.link.jax.dispatch.random
Expand Down
53 changes: 53 additions & 0 deletions pytensor/link/jax/dispatch/pad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import jax.numpy as jnp
import numpy as np

from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.pad import Pad


@jax_funcify.register(Pad)
def jax_funcify_pad(op, **kwargs):
    """Build a JAX callable implementing the given ``Pad`` op.

    The signature of the returned closure depends on ``op.pad_mode``:
    modes that take an extra runtime argument ("constant", "linear_ramp",
    and the stat modes when ``has_stat_length`` is set) accept it as a
    third positional input; every other mode takes only ``x`` and
    ``pad_width``.
    """
    mode = op.pad_mode
    reflect_type = op.reflect_type
    has_stat_length = op.has_stat_length

    if mode == "constant":

        def constant_pad(x, pad_width, constant_values):
            return jnp.pad(x, pad_width, mode=mode, constant_values=constant_values)

        return constant_pad

    if mode == "linear_ramp":

        def lr_pad(x, pad_width, end_values):
            # JAX does not allow a dynamic input if end_values is non-scalar
            if not isinstance(end_values, int | float):
                end_values = tuple(np.array(end_values))
            return jnp.pad(x, pad_width, mode=mode, end_values=end_values)

        return lr_pad

    if has_stat_length and mode in ("maximum", "minimum", "mean"):

        def stat_pad(x, pad_width, stat_length):
            # JAX does not allow a dynamic input here, need to cast to tuple
            return jnp.pad(
                x, pad_width, mode=mode, stat_length=tuple(np.array(stat_length))
            )

        return stat_pad

    if mode in ("reflect", "symmetric"):

        def loop_pad(x, pad_width):
            return jnp.pad(x, pad_width, mode=mode, reflect_type=reflect_type)

        return loop_pad

    # Remaining modes ("edge", "wrap", stat modes without stat_length, ...)
    # need no extra keyword arguments.
    def pad(x, pad_width):
        return jnp.pad(x, pad_width, mode=mode)

    return pad
1 change: 1 addition & 0 deletions pytensor/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int:
from pytensor.tensor.extra_ops import *
from pytensor.tensor.io import *
from pytensor.tensor.math import *
from pytensor.tensor.pad import pad
from pytensor.tensor.shape import (
reshape,
shape,
Expand Down
Loading

0 comments on commit 981688c

Please sign in to comment.