diff --git a/test/test_adjoint.py b/test/test_adjoint.py index d4fc2696..ad5a3828 100644 --- a/test/test_adjoint.py +++ b/test/test_adjoint.py @@ -259,3 +259,32 @@ def test_sequential_sum_product_adjoint( ) expected_bwd = expected_bwds[operand] assert (actual_bwd_t - expected_bwd).abs().data.max() < 5e-3 * num_steps + + +@pytest.mark.parametrize( + "test", + [ + None, + xfail_param("same"), + xfail_param("empty"), + xfail_param("other"), + xfail_param("reduce"), + ], +) +def test_identity_adjoint(test): + x = random_tensor(OrderedDict(i=Bint[3])) + + with AdjointTape() as tape: + y = 2 * x + if test == "same": + y = y(i="i") + elif test == "empty": + y = y() + elif test == "other": + y = y(j=0) + elif test == "reduce": + y = funsor.terms.Reduce(ops.add, y, frozenset()) + + # these identity tests return Number(4.0) + actual = tape.adjoint(ops.add, ops.mul, y, (x,))[x] + assert actual is funsor.Number(2.0)