-
Notifications
You must be signed in to change notification settings - Fork 101
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
9a0e937
commit 11d5b88
Showing
3 changed files
with
72 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import jax.numpy as jnp | ||
|
||
from pytensor.graph import FunctionGraph | ||
from pytensor.link.jax.dispatch import jax_funcify | ||
from pytensor.tensor.blockwise import Blockwise | ||
|
||
|
||
@jax_funcify.register(Blockwise)
def funcify_Blockwise(op: Blockwise, node, *args, **kwargs):
    """Return a JAX implementation of a ``Blockwise`` op.

    The op's core computation is compiled to a JAX callable and then
    vectorized over the leading (batch) dimensions with ``jnp.vectorize``,
    driven by the op's gufunc-style signature.
    """
    # Compile the core (single batch element) graph to a JAX callable.
    dummy_node = op._create_dummy_core_node(node.inputs)
    core_graph = FunctionGraph(inputs=dummy_node.inputs, outputs=dummy_node.outputs)
    compiled_core = jax_funcify(core_graph)

    # jax_funcify produces a tuple-returning callable; when the op has a
    # single output, unwrap it so jnp.vectorize sees a plain return value.
    if len(node.outputs) != 1:
        core_fn = compiled_core
    else:

        def core_fn(*core_inputs):
            return compiled_core(*core_inputs)[0]

    vectorized_fn = jnp.vectorize(core_fn, signature=op.signature)

    def blockwise_fn(*inputs):
        # Mirror the default backend's runtime broadcasting check before
        # dispatching to the vectorized core function.
        op._check_runtime_broadcast(node, inputs)
        return vectorized_fn(*inputs)

    return blockwise_fn
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import numpy as np | ||
import pytest | ||
|
||
from pytensor import config | ||
from pytensor.graph import FunctionGraph | ||
from pytensor.tensor import tensor | ||
from pytensor.tensor.blockwise import Blockwise | ||
from pytensor.tensor.math import Dot, matmul | ||
from tests.link.jax.test_basic import compare_jax_and_py | ||
from tests.tensor.test_blockwise import check_blockwise_runtime_broadcasting | ||
|
||
|
||
# Skip this entire test module when JAX is not installed.
jax = pytest.importorskip("jax")
|
||
|
||
def test_runtime_broadcasting():
    """Exercise the shared runtime-broadcasting checks against the JAX backend."""
    check_blockwise_runtime_broadcasting("JAX")
|
||
|
||
# Blockwise equivalent of matmul, but with a deliberately non-standard
# (yet shape-equivalent) gufunc signature, to exercise the generic
# Blockwise dispatch path rather than any matmul-specific handling.
odd_matmul = Blockwise(Dot(), signature="(i00,i01),(i10,i11)->(o00,o01)")
|
||
|
||
@pytest.mark.parametrize("matmul_op", (matmul, odd_matmul))
def test_matmul(matmul_op):
    """Check JAX transpilation of batched matmul via Blockwise.

    Beyond numerical agreement with the Python backend, assert that the
    emitted jaxpr matches what a bare ``jax.jit(jnp.matmul)`` produces,
    i.e. the Blockwise machinery adds no extra operations.
    """
    rng = np.random.default_rng(14)
    x = tensor("a", shape=(2, 3, 5))
    y = tensor("b", shape=(2, 5, 3))
    inputs = [x, y]
    values = [rng.normal(size=inp.type.shape).astype(config.floatX) for inp in inputs]

    result = matmul_op(x, y)
    assert isinstance(result.owner.op, Blockwise)
    graph = FunctionGraph(inputs, [result])
    compiled_fn, _ = compare_jax_and_py(graph, values)

    # Check we are not adding any unnecessary stuff: after renaming the
    # fgraph wrapper, the jaxpr must be identical to plain jnp.matmul's.
    actual_jaxpr = str(jax.make_jaxpr(compiled_fn.vm.jit_fn)(*values))
    actual_jaxpr = actual_jaxpr.replace("name=jax_funcified_fgraph", "name=matmul")
    expected_jaxpr = str(jax.make_jaxpr(jax.jit(jax.numpy.matmul))(*values))
    assert actual_jaxpr == expected_jaxpr