Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement pad #748

Merged
merged 38 commits into from
Jul 19, 2024
Merged
Changes from 1 commit
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
91311d5
Refactor linspace, logspace, and geomspace to match numpy implementation
jessegrabowski May 11, 2024
a98b8ae
Add `pt.pad`
jessegrabowski May 4, 2024
02566b6
Use subclassed `OpFromGraph` to represent `pad` Op
jessegrabowski May 18, 2024
32b14d2
Add test for `flip`
jessegrabowski May 18, 2024
3e65827
Address reviewer feedback
jessegrabowski May 18, 2024
c2b8465
Remove `inplace` argument to `set_subtensor`
jessegrabowski May 18, 2024
8f213a3
Delay setting dtype of `xspace` Ops until after all computation to ma…
jessegrabowski May 18, 2024
be6ed82
Use `shape_padright` instead of `.reshape` tricks
jessegrabowski May 18, 2024
9c76a8f
Add test for `dtype` kwarg on `xspace` Ops
jessegrabowski May 18, 2024
eeb9fa3
Save keyword arguments in `Pad` `OpFromGraph`
jessegrabowski May 18, 2024
c28faaa
Add test for arbitrary padding at higher dimensions
jessegrabowski May 18, 2024
ab99a1e
First draft JAX overload
jessegrabowski May 18, 2024
e93fa56
Expect symbolic `num` argument
jessegrabowski May 18, 2024
2d55a06
Split `wrap_pad` into separate function; eliminate use of `scan`
jessegrabowski Jul 7, 2024
48037c6
<<DO NOT MERGE>> testing notebook
jessegrabowski Jul 8, 2024
b4ffdc5
Add `reflect` and `symmetric` padding
jessegrabowski Jul 12, 2024
6808648
Remove test notebook
jessegrabowski Jul 12, 2024
0f9ca38
Correct reflect and symmetric implementations
jessegrabowski Jul 13, 2024
c1dd0bc
Fix docs, JAX test
jessegrabowski Jul 13, 2024
3b34779
Remove `_broadcast_inputs` helper, update docstrings
jessegrabowski Jul 13, 2024
4cd9702
Remove `OpFromGraph` and associated `JAX` dispatch
jessegrabowski Jul 13, 2024
dbda326
Revert "Remove `OpFromGraph` and associated `JAX` dispatch"
jessegrabowski Jul 13, 2024
32ae3eb
Add issue link to `reflect_type` error message
jessegrabowski Jul 13, 2024
d543ed6
Move `flip` to `tensor/subtensor.py`, add docstring
jessegrabowski Jul 13, 2024
ba6c613
Move `slice_at_axis` to `tensor/subtensor` and expose it in `pytensor…
jessegrabowski Jul 13, 2024
aa95403
Appease mypy
jessegrabowski Jul 13, 2024
2a11ffa
Appease mypy, add docstring to `pad`
jessegrabowski Jul 13, 2024
89f8cdf
Appease mypy, add docstring to `pad`
jessegrabowski Jul 13, 2024
0f5e2ba
Fix doctests
jessegrabowski Jul 13, 2024
74e39b2
Fix doctests
jessegrabowski Jul 13, 2024
d169928
Fix doctests
jessegrabowski Jul 13, 2024
e429bab
Propagate all optional arguments to JAX
jessegrabowski Jul 13, 2024
8722bfe
Propagate all optional arguments to JAX
jessegrabowski Jul 13, 2024
3f8a377
Appease mypy
jessegrabowski Jul 13, 2024
163b441
Test `NUMBA` backend
jessegrabowski Jul 14, 2024
2c9f727
I love mypy
jessegrabowski Jul 14, 2024
19ed2c0
Skip failing numba test
jessegrabowski Jul 14, 2024
bbeb300
Merge branch 'main' into pad
jessegrabowski Jul 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pytensor/tensor/pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@
# We could also return early in the more general case of left_length == right_length, but we don't necessarily
# know these shapes.
# TODO: Add rewrite to simplify in this case
return left_stat, left_stat

Check warning on line 203 in pytensor/tensor/pad.py

View check run for this annotation

Codecov / codecov/patch

pytensor/tensor/pad.py#L203

Added line #L203 was not covered by tests

# Calculate statistic for the right side
right_length = (
Expand Down Expand Up @@ -553,7 +553,7 @@

print(pt.pad(a, (2, 3), 'linear_ramp', end_values=(5, -4)).eval())

..testoutput::
.. testoutput::
jessegrabowski marked this conversation as resolved.
Show resolved Hide resolved

[ 5. 3. 1. 2. 3. 4. 5. 2. -1. -4.]

Expand Down Expand Up @@ -615,7 +615,7 @@

"""
if any(value not in allowed_kwargs[mode] for value in kwargs.keys()):
raise ValueError(

Check warning on line 618 in pytensor/tensor/pad.py

View check run for this annotation

Codecov / codecov/patch

pytensor/tensor/pad.py#L618

Added line #L618 was not covered by tests
f"Invalid keyword arguments for mode '{mode}': {kwargs.keys()}"
)
x = as_tensor(x, name="x")
Expand All @@ -637,7 +637,7 @@
if mode == "median":
# TODO: Revisit this after we implement a quantile function.
# See https://github.com/pymc-devs/pytensor/issues/53
raise NotImplementedError("Median padding not implemented")

Check warning on line 640 in pytensor/tensor/pad.py

View check run for this annotation

Codecov / codecov/patch

pytensor/tensor/pad.py#L640

Added line #L640 was not covered by tests
stat_func = cast(Callable, stat_funcs[mode])
stat_length = kwargs.get("stat_length")
if stat_length is not None:
Expand Down
Loading