diff --git a/pytensor/link/pytorch/dispatch/scalar.py b/pytensor/link/pytorch/dispatch/scalar.py
index 56ec438c9f..a977c6d4b2 100644
--- a/pytensor/link/pytorch/dispatch/scalar.py
+++ b/pytensor/link/pytorch/dispatch/scalar.py
@@ -2,6 +2,7 @@
 
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
 from pytensor.scalar.basic import (
+    Cast,
     ScalarOp,
 )
 
@@ -38,3 +39,13 @@ def pytorch_func(*args):
             )
 
     return pytorch_func
+
+
+@pytorch_funcify.register(Cast)
+def pytorch_funcify_Cast(op: Cast, node, **kwargs):
+    dtype = getattr(torch, op.o_type.dtype)
+
+    def cast(x):
+        return x.to(dtype=dtype)
+
+    return cast
diff --git a/tests/link/pytorch/test_elemwise.py b/tests/link/pytorch/test_elemwise.py
index afb62848cc..3b5eee956a 100644
--- a/tests/link/pytorch/test_elemwise.py
+++ b/tests/link/pytorch/test_elemwise.py
@@ -11,6 +11,9 @@
 from tests.link.pytorch.test_basic import compare_pytorch_and_py
 
 
+torch = pytest.importorskip("torch")
+
+
 def test_pytorch_Dimshuffle():
     a_pt = matrix("a")
 
@@ -143,3 +146,13 @@ def test_softmax_grad(axis):
     out = SoftmaxGrad(axis=axis)(dy, sm)
     fgraph = FunctionGraph([dy, sm], [out])
     compare_pytorch_and_py(fgraph, [dy_value, sm_value])
+
+
+def test_cast():
+    x = matrix("x", dtype="float32")
+    out = pt.cast(x, "int32")
+    fgraph = FunctionGraph([x], [out])
+    _, [res] = compare_pytorch_and_py(
+        fgraph, [np.arange(6, dtype="float32").reshape(2, 3)]
+    )
+    assert res.dtype == torch.int32