
Implement pad #748

Merged
merged 38 commits on Jul 19, 2024
Changes from 25 commits (38 commits total)
91311d5
Refactor linspace, logspace, and geomspace to match numpy implementation
jessegrabowski May 11, 2024
a98b8ae
Add `pt.pad`
jessegrabowski May 4, 2024
02566b6
Use subclassed `OpFromGraph` to represent `pad` Op
jessegrabowski May 18, 2024
32b14d2
Add test for `flip`
jessegrabowski May 18, 2024
3e65827
Address reviewer feedback
jessegrabowski May 18, 2024
c2b8465
Remove `inplace` argument to `set_subtensor`
jessegrabowski May 18, 2024
8f213a3
Delay setting dtype of `xspace` Ops until after all computation to ma…
jessegrabowski May 18, 2024
be6ed82
Use `shape_padright` instead of `.reshape` tricks
jessegrabowski May 18, 2024
9c76a8f
Add test for `dtype` kwarg on `xspace` Ops
jessegrabowski May 18, 2024
eeb9fa3
Save keyword arguments in `Pad` `OpFromGraph`
jessegrabowski May 18, 2024
c28faaa
Add test for arbitrary padding at higher dimensions
jessegrabowski May 18, 2024
ab99a1e
First draft JAX overload
jessegrabowski May 18, 2024
e93fa56
Expect symbolic `num` argument
jessegrabowski May 18, 2024
2d55a06
Split `wrap_pad` into separate function; eliminate use of `scan`
jessegrabowski Jul 7, 2024
48037c6
<<DO NOT MERGE>> testing notebook
jessegrabowski Jul 8, 2024
b4ffdc5
Add `reflect` and `symmetric` padding
jessegrabowski Jul 12, 2024
6808648
Remove test notebook
jessegrabowski Jul 12, 2024
0f9ca38
Correct reflect and symmetric implementations
jessegrabowski Jul 13, 2024
c1dd0bc
Fix docs, JAX test
jessegrabowski Jul 13, 2024
3b34779
Remove `_broadcast_inputs` helper, update docstrings
jessegrabowski Jul 13, 2024
4cd9702
Remove `OpFromGraph` and associated `JAX` dispatch
jessegrabowski Jul 13, 2024
dbda326
Revert "Remove `OpFromGraph` and associated `JAX` dispatch"
jessegrabowski Jul 13, 2024
32ae3eb
Add issue link to `reflect_type` error message
jessegrabowski Jul 13, 2024
d543ed6
Move `flip` to `tensor/subtensor.py`, add docstring
jessegrabowski Jul 13, 2024
ba6c613
Move `slice_at_axis` to `tensor/subtensor` and expose it in `pytensor…
jessegrabowski Jul 13, 2024
aa95403
Appease mypy
jessegrabowski Jul 13, 2024
2a11ffa
Appease mypy, add docstring to `pad`
jessegrabowski Jul 13, 2024
89f8cdf
Appease mypy, add docstring to `pad`
jessegrabowski Jul 13, 2024
0f5e2ba
Fix doctests
jessegrabowski Jul 13, 2024
74e39b2
Fix doctests
jessegrabowski Jul 13, 2024
d169928
Fix doctests
jessegrabowski Jul 13, 2024
e429bab
Propagate all optional arguments to JAX
jessegrabowski Jul 13, 2024
8722bfe
Propagate all optional arguments to JAX
jessegrabowski Jul 13, 2024
3f8a377
Appease mypy
jessegrabowski Jul 13, 2024
163b441
Test `NUMBA` backend
jessegrabowski Jul 14, 2024
2c9f727
I love mypy
jessegrabowski Jul 14, 2024
19ed2c0
Skip failing numba test
jessegrabowski Jul 14, 2024
bbeb300
Merge branch 'main' into pad
jessegrabowski Jul 17, 2024
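
The commits above add `reflect` and `symmetric` padding modes and then correct their implementations (b4ffdc5, 0f9ca38). The distinction between the two modes — whether the edge value itself is repeated in the mirror — matches NumPy's `np.pad` semantics, which this sketch illustrates:

```python
import numpy as np

x = np.array([1, 2, 3])

# reflect: mirror about the edge value; the edge itself is not repeated
print(np.pad(x, 2, mode="reflect"))    # [3 2 1 2 3 2 1]

# symmetric: mirror including the edge value, so edges appear twice
print(np.pad(x, 2, mode="symmetric"))  # [2 1 1 2 3 3 2]
```

Both modes also accept a `reflect_type` keyword (`"even"` or `"odd"`), which is why the JAX dispatch below forwards `reflect_type` for exactly these two modes.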
1 change: 1 addition & 0 deletions pytensor/link/jax/dispatch/__init__.py
@@ -7,6 +7,7 @@
import pytensor.link.jax.dispatch.subtensor
import pytensor.link.jax.dispatch.shape
import pytensor.link.jax.dispatch.extra_ops
import pytensor.link.jax.dispatch.pad
import pytensor.link.jax.dispatch.nlinalg
import pytensor.link.jax.dispatch.slinalg
import pytensor.link.jax.dispatch.random
20 changes: 20 additions & 0 deletions pytensor/link/jax/dispatch/pad.py
@@ -0,0 +1,20 @@
import jax.numpy as jnp

from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.pad import Pad


fixed_kwargs = {"reflect": ["reflect_type"], "symmetric": ["reflect_type"]}


@jax_funcify.register(Pad)
def jax_funcify_pad(op, **kwargs):
    pad_mode = op.pad_mode
    expected_kwargs = fixed_kwargs.get(pad_mode, [])
    mode_kwargs = {kwarg: getattr(op, kwarg) for kwarg in expected_kwargs}

    def pad(x, pad_width, *args):
        return jnp.pad(x, pad_width, mode=pad_mode, **mode_kwargs)

    return pad
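
The `@jax_funcify.register(Pad)` hook above follows PyTensor's backend-dispatch pattern: a registry maps an Op class to a factory that returns the backend implementation. A minimal sketch of that pattern using `functools.singledispatch` — the `FakePad` class and `to_backend` function here are hypothetical stand-ins for illustration, not part of PyTensor's API:

```python
from functools import singledispatch


class FakePad:
    # hypothetical stand-in for the Pad Op, carrying its mode as state
    def __init__(self, pad_mode):
        self.pad_mode = pad_mode


@singledispatch
def to_backend(op, **kwargs):
    # fallback for unregistered Op types
    raise NotImplementedError(f"No backend implementation for {type(op)}")


@to_backend.register(FakePad)
def _(op, **kwargs):
    pad_mode = op.pad_mode

    def pad(x, pad_width):
        # pure-Python stand-in for jnp.pad(x, pad_width, mode=pad_mode)
        if pad_mode == "constant":
            return [0] * pad_width + x + [0] * pad_width
        raise NotImplementedError(pad_mode)

    return pad


fn = to_backend(FakePad("constant"))
print(fn([1, 2, 3], 2))  # [0, 0, 1, 2, 3, 0, 0]
```

The factory closes over the Op's attributes (here `pad_mode`), so the returned function needs only runtime inputs — the same shape as the real `jax_funcify_pad`.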
1 change: 1 addition & 0 deletions pytensor/tensor/__init__.py
@@ -130,6 +130,7 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int:
from pytensor.tensor.extra_ops import *
from pytensor.tensor.io import *
from pytensor.tensor.math import *
from pytensor.tensor.pad import pad
from pytensor.tensor.shape import (
reshape,
shape,