From f1f6fa3f288895dfa2e3a8f1e2d8bd455e7920c3 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Wed, 24 Feb 2021 17:42:51 -0500 Subject: [PATCH 01/16] Funsor as a function --- funsor/jvp.py | 58 ++++++++++++++++++++ test/test_jvp.py | 137 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 195 insertions(+) create mode 100644 funsor/jvp.py create mode 100644 test/test_jvp.py diff --git a/funsor/jvp.py b/funsor/jvp.py new file mode 100644 index 000000000..7c616b39e --- /dev/null +++ b/funsor/jvp.py @@ -0,0 +1,58 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +from collections import defaultdict +from functools import singledispatch + +import funsor.ops as ops +from funsor.interpreter import interpretation +from funsor.domains import Bint, Real, Reals +from funsor.ops import AssociativeOp, LogOp, Op +from funsor.terms import ( + Binary, + Funsor, + Number, + Reduce, + Tuple, + Unary, + Variable, + eager, + lazy, +) + + +@eager.register(Binary, AssociativeOp, Tuple, Tuple) +def jvp_binary(op, lhs, rhs): + lhs_primal, lhs_tangent = lhs + rhs_primal, rhs_tangent = rhs + primal = Binary(op, lhs_primal, rhs_primal) + with interpretation(lazy): + if op is ops.add: + tangent = lhs_tangent + rhs_tangent + elif op is ops.mul: + tangent = rhs_primal * lhs_tangent + lhs_primal * rhs_tangent + else: + raise NotImplementedError + return Tuple(primal, tangent) + + +@eager.register(Reduce, AssociativeOp, Tuple, frozenset) +def jvp_reduce(op, arg, reduced_vars): + arg_primal, arg_tangent = arg + primal = Reduce(op, arg_primal, reduced_vars) + with interpretation(lazy): + if op is ops.add: + tangent = Reduce(op, arg_tangent, reduced_vars) + elif op is ops.mul: + tangent = Reduce(ops.add, arg_tangent * primal / arg_primal, reduced_vars) + else: + raise NotImplementedError + return Tuple(primal, tangent) + + +@eager.register(Unary, LogOp, Tuple) +def jvp_log(op, arg): + arg_primal, arg_tangent = arg + primal = Unary(op, arg_primal) + tangent = Binary(ops.truediv, arg_tangent, arg_primal) + return Tuple(primal, tangent) diff --git a/test/test_jvp.py b/test/test_jvp.py new file mode 100644 index 000000000..a619cf8f9 --- /dev/null +++ b/test/test_jvp.py @@ -0,0 +1,137 @@ +# Copyright Contributors to the Pyro project. 
+# SPDX-License-Identifier: Apache-2.0 + +from collections import OrderedDict + +import pytest + +import funsor.jvp +import funsor.ops as ops +from funsor.domains import Bint, Real, Reals +from funsor.interpreter import interpretation +from funsor.optimizer import apply_optimizer +from funsor.tensor import Tensor +from funsor.terms import Number, Tuple, Variable, lazy, reflect +from funsor.testing import assert_close, random_tensor + +try: + import torch +except ImportError: + pytest.skip() + + +def test_identity(): + x = random_tensor(OrderedDict(i=Bint[2])) + dx = random_tensor(OrderedDict(i=Bint[2])) + x_ = Tuple((x, dx)) + f, df = x_ + assert_close(f, x) + assert_close(df, dx) + + +def test_log(): + x = random_tensor(OrderedDict(i=Bint[2])) + dx = random_tensor(OrderedDict(i=Bint[2])) + x_ = Tuple((x, dx)) + f, df = x_.log() + assert_close(f, x.log()) + assert_close(df, dx / x) + + +def test_add(): + x = random_tensor(OrderedDict(i=Bint[2])) + y = random_tensor(OrderedDict(j=Bint[3])) + dx = random_tensor(OrderedDict(i=Bint[2])) + dy = random_tensor(OrderedDict(j=Bint[3])) + # dx = Number(1.0) # Variable("dx", Real) + # dy = Number(1.0) # Variable("dy", Real) + x_ = Tuple((x, dx)) + y_ = Tuple((y, dy)) + f_ = x_ + y_ + f, df = f_ + assert_close(f, x + y) + assert_close(df, dx + dy) + + +def test_linearize_add(): + x = random_tensor(OrderedDict(i=Bint[2])) + y = random_tensor(OrderedDict(j=Bint[3])) + dx = Variable("dx", Real) + dy = random_tensor(OrderedDict(j=Bint[3])) + # dy = Number(1.0) # Variable("dy", Real) + x_ = Tuple((x, dx)) + y_ = Tuple((y, dy)) + f_ = x_ + y_ + f, df = f_ + breakpoint() + assert_close(f, x + y) + assert_close(df, dx + dy) + + +def test_mul(): + x = random_tensor(OrderedDict(i=Bint[2])) + y = random_tensor(OrderedDict(j=Bint[3])) + dx = random_tensor(OrderedDict(i=Bint[2])) + dy = random_tensor(OrderedDict(j=Bint[3])) + # dx = Number(1.0) # Variable("dx", Real) + # dy = Number(0.0) # Variable("dy", Real) + x_ = Tuple((x, dx)) + y_ = Tuple((y, dy)) + f, df = x_ * y_ + assert_close(f, x * y) + assert_close(df, (x * dy + y * dx).align(tuple(df.inputs.keys()))) + + +def test_reduce_sum(): + x = random_tensor(OrderedDict(j=Bint[4])) + dx = random_tensor(OrderedDict(j=Bint[4])) + x_ = Tuple((x, dx)) + f, df = x_.reduce(ops.add, "j") + assert_close(f, x.reduce(ops.add, "j")) + assert_close(apply_optimizer(df), dx.reduce(ops.add, "j")) + + +def test_linearize_reduce_sum(): + x = random_tensor(OrderedDict(j=Bint[4])) + dx = random_tensor(OrderedDict(j=Bint[4])) + x_ = Tuple((x, dx)) + f, df = x_.reduce(ops.add, "j") + breakpoint() + assert_close(f, x.reduce(ops.add, "j")) + assert_close(df, dx.reduce(ops.add, "j")) + + +def test_reduce_prod(): + x = random_tensor(OrderedDict(j=Bint[4])) + dx = random_tensor(OrderedDict(j=Bint[4])) + x_ = Tuple((x, dx)) + f, df = x_.reduce(ops.mul, "j") + assert_close(f, x.reduce(ops.mul, "j")) + assert_close(df, (f * dx / x).reduce(ops.add, "j")) + + +def test_matmul_tensor(): + x = random_tensor(OrderedDict(j=Bint[4])) + y = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) + dx = random_tensor(OrderedDict(j=Bint[4])) + dy = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) + x_ = Tuple((x, dx)) + y_ = Tuple((y, dy)) + xy_ = x_ * y_ + z, dz = xy_.reduce(ops.add, "j") + assert_close(z, (x * y).reduce(ops.add, "j")) + assert_close(dz, (y * dx + x * dy).reduce(ops.add, "j")) + + +def test_compose(): + x = random_tensor(OrderedDict(j=Bint[4])) + y = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) + dx = 
random_tensor(OrderedDict(j=Bint[4])) + dy = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) + x_ = Tuple((x, dx)) + logx_ = x_.log() + y_ = Tuple((y, dy)) + logxy_ = logx_ * y_ + z, dz = logxy_.reduce(ops.add, "j") + assert_close(z, (x.log() * y).reduce(ops.add, "j")) + assert_close(dz, (y * dx / x + x.log() * dy).reduce(ops.add, "j")) From 6bc49a2d7f1c4be515d8ec9c90c9b9875c321193 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Wed, 24 Feb 2021 18:35:05 -0500 Subject: [PATCH 02/16] Tangent space --- funsor/jvp.py | 65 +++++++++++++++++++----------------------------- test/test_jvp.py | 64 +++++++++++------------------------------------ 2 files changed, 40 insertions(+), 89 deletions(-) diff --git a/funsor/jvp.py b/funsor/jvp.py index 7c616b39e..81f1d9b71 100644 --- a/funsor/jvp.py +++ b/funsor/jvp.py @@ -1,58 +1,45 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -from collections import defaultdict -from functools import singledispatch - import funsor.ops as ops -from funsor.interpreter import interpretation -from funsor.domains import Bint, Real, Reals -from funsor.ops import AssociativeOp, LogOp, Op -from funsor.terms import ( - Binary, - Funsor, - Number, - Reduce, - Tuple, - Unary, - Variable, - eager, - lazy, -) - - -@eager.register(Binary, AssociativeOp, Tuple, Tuple) +from funsor.ops import AssociativeOp, LogOp +from funsor.terms import Binary, Reduce, Tuple, Unary, eager + + +class Tangent(Tuple): + pass + + +@eager.register(Binary, AssociativeOp, Tangent, Tangent) def jvp_binary(op, lhs, rhs): lhs_primal, lhs_tangent = lhs rhs_primal, rhs_tangent = rhs primal = Binary(op, lhs_primal, rhs_primal) - with interpretation(lazy): - if op is ops.add: - tangent = lhs_tangent + rhs_tangent - elif op is ops.mul: - tangent = rhs_primal * lhs_tangent + lhs_primal * rhs_tangent - else: - raise NotImplementedError - return Tuple(primal, tangent) + if op is ops.add: + tangent = lhs_tangent + rhs_tangent + elif op is ops.mul: + tangent = rhs_primal * lhs_tangent + lhs_primal * rhs_tangent + else: + raise NotImplementedError + return Tangent(primal, tangent) -@eager.register(Reduce, AssociativeOp, Tuple, frozenset) +@eager.register(Reduce, AssociativeOp, Tangent, frozenset) def jvp_reduce(op, arg, reduced_vars): arg_primal, arg_tangent = arg primal = Reduce(op, arg_primal, reduced_vars) - with interpretation(lazy): - if op is ops.add: - tangent = Reduce(op, arg_tangent, reduced_vars) - elif op is ops.mul: - tangent = Reduce(ops.add, arg_tangent * primal / arg_primal, reduced_vars) - else: - raise NotImplementedError - return Tuple(primal, tangent) + if op is ops.add: + tangent = Reduce(op, arg_tangent, reduced_vars) + elif op is ops.mul: + tangent = Reduce(ops.add, arg_tangent * primal / arg_primal, reduced_vars) + else: + raise NotImplementedError + return Tangent(primal, tangent) -@eager.register(Unary, LogOp, Tuple) +@eager.register(Unary, LogOp, Tangent) def jvp_log(op, arg): arg_primal, arg_tangent = arg primal = Unary(op, arg_primal) tangent = Binary(ops.truediv, arg_tangent, arg_primal) - return Tuple(primal, tangent) + return Tangent(primal, tangent) diff --git a/test/test_jvp.py b/test/test_jvp.py index a619cf8f9..6ed1e8130 100644 --- a/test/test_jvp.py +++ b/test/test_jvp.py @@ -3,27 +3,16 @@ from collections import OrderedDict -import pytest - -import funsor.jvp import funsor.ops as ops -from funsor.domains import Bint, Real, Reals -from funsor.interpreter import interpretation -from funsor.optimizer import apply_optimizer -from 
funsor.tensor import Tensor -from funsor.terms import Number, Tuple, Variable, lazy, reflect +from funsor.domains import Bint +from funsor.jvp import Tangent from funsor.testing import assert_close, random_tensor -try: - import torch -except ImportError: - pytest.skip() - def test_identity(): x = random_tensor(OrderedDict(i=Bint[2])) dx = random_tensor(OrderedDict(i=Bint[2])) - x_ = Tuple((x, dx)) + x_ = Tangent((x, dx)) f, df = x_ assert_close(f, x) assert_close(df, dx) @@ -32,7 +21,7 @@ def test_identity(): def test_log(): x = random_tensor(OrderedDict(i=Bint[2])) dx = random_tensor(OrderedDict(i=Bint[2])) - x_ = Tuple((x, dx)) + x_ = Tangent((x, dx)) f, df = x_.log() assert_close(f, x.log()) assert_close(df, dx / x) @@ -45,25 +34,10 @@ def test_add(): dy = random_tensor(OrderedDict(j=Bint[3])) # dx = Number(1.0) # Variable("dx", Real) # dy = Number(1.0) # Variable("dy", Real) - x_ = Tuple((x, dx)) - y_ = Tuple((y, dy)) - f_ = x_ + y_ - f, df = f_ - assert_close(f, x + y) - assert_close(df, dx + dy) - - -def test_linearize_add(): - x = random_tensor(OrderedDict(i=Bint[2])) - y = random_tensor(OrderedDict(j=Bint[3])) - dx = Variable("dx", Real) - dy = random_tensor(OrderedDict(j=Bint[3])) - # dy = Number(1.0) # Variable("dy", Real) - x_ = Tuple((x, dx)) - y_ = Tuple((y, dy)) + x_ = Tangent((x, dx)) + y_ = Tangent((y, dy)) f_ = x_ + y_ f, df = f_ - breakpoint() assert_close(f, x + y) assert_close(df, dx + dy) @@ -75,8 +49,8 @@ def test_mul(): dy = random_tensor(OrderedDict(j=Bint[3])) # dx = Number(1.0) # Variable("dx", Real) # dy = Number(0.0) # Variable("dy", Real) - x_ = Tuple((x, dx)) - y_ = Tuple((y, dy)) + x_ = Tangent((x, dx)) + y_ = Tangent((y, dy)) f, df = x_ * y_ assert_close(f, x * y) assert_close(df, (x * dy + y * dx).align(tuple(df.inputs.keys()))) @@ -85,18 +59,8 @@ def test_mul(): def test_reduce_sum(): x = random_tensor(OrderedDict(j=Bint[4])) dx = random_tensor(OrderedDict(j=Bint[4])) - x_ = Tuple((x, dx)) - f, df = x_.reduce(ops.add, "j") - assert_close(f, x.reduce(ops.add, "j")) - assert_close(apply_optimizer(df), dx.reduce(ops.add, "j")) - - -def test_linearize_reduce_sum(): - x = random_tensor(OrderedDict(j=Bint[4])) - dx = random_tensor(OrderedDict(j=Bint[4])) - x_ = Tuple((x, dx)) + x_ = Tangent((x, dx)) f, df = x_.reduce(ops.add, "j") - breakpoint() assert_close(f, x.reduce(ops.add, "j")) assert_close(df, dx.reduce(ops.add, "j")) @@ -104,7 +68,7 @@ def test_linearize_reduce_sum(): def test_reduce_prod(): x = random_tensor(OrderedDict(j=Bint[4])) dx = random_tensor(OrderedDict(j=Bint[4])) - x_ = Tuple((x, dx)) + x_ = Tangent((x, dx)) f, df = x_.reduce(ops.mul, "j") assert_close(f, x.reduce(ops.mul, "j")) assert_close(df, (f * dx / x).reduce(ops.add, "j")) @@ -115,8 +79,8 @@ def test_matmul_tensor(): y = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) dx = random_tensor(OrderedDict(j=Bint[4])) dy = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) - x_ = Tuple((x, dx)) - y_ = Tuple((y, dy)) + x_ = Tangent((x, dx)) + y_ = Tangent((y, dy)) xy_ = x_ * y_ z, dz = xy_.reduce(ops.add, "j") assert_close(z, (x * y).reduce(ops.add, "j")) @@ -128,9 +92,9 @@ def test_compose(): y = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) dx = random_tensor(OrderedDict(j=Bint[4])) dy = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) - x_ = Tuple((x, dx)) + x_ = Tangent((x, dx)) logx_ = x_.log() - y_ = Tuple((y, dy)) + y_ = Tangent((y, dy)) logxy_ = logx_ * y_ z, dz = logxy_.reduce(ops.add, "j") assert_close(z, (x.log() * y).reduce(ops.add, "j")) From 
fb2a7a017c598e3e071c9c15c830b502013e1d71 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Fri, 26 Feb 2021 18:43:42 -0500 Subject: [PATCH 03/16] JVP Functor --- funsor/jvp.py | 44 ++++++++++- test/test_jvp.py | 194 +++++++++++++++++++++++++++++++++++++---------- 2 files changed, 197 insertions(+), 41 deletions(-) diff --git a/funsor/jvp.py b/funsor/jvp.py index 81f1d9b71..4f39ce799 100644 --- a/funsor/jvp.py +++ b/funsor/jvp.py @@ -3,13 +3,53 @@ import funsor.ops as ops from funsor.ops import AssociativeOp, LogOp -from funsor.terms import Binary, Reduce, Tuple, Unary, eager +from funsor.terms import Binary, Reduce, Tuple, Unary, eager, lazy, Variable, Number +from funsor.interpreter import interpretation +from funsor.domains import Bint, Real +from collections import defaultdict class Tangent(Tuple): pass +class JVP: + def __init__(self, primal, tangent=defaultdict(lambda: Number(0.0))): + self.primal = primal + self.tangent = tangent.copy() + self.tangent[str(id(primal))] = Variable(str(id(primal)), Real) + + def __add__(self, other): + primal = self.primal + other.primal + tangent = defaultdict(lambda: Number(0.0)) + for key, value in self.tangent.items(): + tangent[key] += value + for key, value in other.tangent.items(): + tangent[key] += value + tangent[str(id(self.primal))] += other.primal - other.primal + tangent[str(id(other.primal))] += self.primal - self.primal + return JVP(primal, tangent) + + def __mul__(self, other): + primal = self.primal * other.primal + tangent = defaultdict(lambda: Number(0.0)) + for key, value in self.tangent.items(): + tangent[key] += value + for key, value in other.tangent.items(): + tangent[key] += value + tangent[str(id(self.primal))] *= other.primal + tangent[str(id(other.primal))] *= self.primal + return JVP(primal, tangent) + + def log(self): + primal = self.primal.log() + tangent = self.tangent + tangent[str(id(self.primal))] /= self.primal + return JVP(primal, tangent) + + + +@lazy.register(Binary, AssociativeOp, Tangent, Tangent) @eager.register(Binary, AssociativeOp, Tangent, Tangent) def jvp_binary(op, lhs, rhs): lhs_primal, lhs_tangent = lhs @@ -24,6 +64,7 @@ def jvp_binary(op, lhs, rhs): return Tangent(primal, tangent) +@lazy.register(Reduce, AssociativeOp, Tangent, frozenset) @eager.register(Reduce, AssociativeOp, Tangent, frozenset) def jvp_reduce(op, arg, reduced_vars): arg_primal, arg_tangent = arg @@ -37,6 +78,7 @@ def jvp_reduce(op, arg, reduced_vars): return Tangent(primal, tangent) +@lazy.register(Unary, LogOp, Tangent) @eager.register(Unary, LogOp, Tangent) def jvp_log(op, arg): arg_primal, arg_tangent = arg diff --git a/test/test_jvp.py b/test/test_jvp.py index 6ed1e8130..1c58f1b00 100644 --- a/test/test_jvp.py +++ b/test/test_jvp.py @@ -4,26 +4,41 @@ from collections import OrderedDict import funsor.ops as ops -from funsor.domains import Bint -from funsor.jvp import Tangent +from funsor.domains import Bint, Real +from funsor.jvp import Tangent, JVP +from funsor.vjp import transpose from funsor.testing import assert_close, random_tensor +from funsor.terms import Variable, Number, lazy +from funsor.tensor import Tensor +from funsor.optimizer import apply_optimizer +from funsor.interpreter import interpretation -def test_identity(): +import torch +import funsor +funsor.set_backend("torch") + + +def test_id(): x = random_tensor(OrderedDict(i=Bint[2])) dx = random_tensor(OrderedDict(i=Bint[2])) - x_ = Tangent((x, dx)) - f, df = x_ - assert_close(f, x) + x_ = JVP(x) + with lazy: + f = x_ + assert_close(f.primal, x) + df = 
f.tangent[str(id(x))](**{str(id(x)): dx}) assert_close(df, dx) def test_log(): - x = random_tensor(OrderedDict(i=Bint[2])) + x = Tensor(torch.tensor([1., 2.]), OrderedDict(i=Bint[2])) dx = random_tensor(OrderedDict(i=Bint[2])) - x_ = Tangent((x, dx)) - f, df = x_.log() - assert_close(f, x.log()) + x_ = JVP(x) + with lazy: + f = x_.log() + primal = apply_optimizer(f.primal) + assert_close(primal, x.log()) + df = f.tangent[str(id(x))](**{str(id(x)): dx}) assert_close(df, dx / x) @@ -32,14 +47,33 @@ def test_add(): y = random_tensor(OrderedDict(j=Bint[3])) dx = random_tensor(OrderedDict(i=Bint[2])) dy = random_tensor(OrderedDict(j=Bint[3])) - # dx = Number(1.0) # Variable("dx", Real) - # dy = Number(1.0) # Variable("dy", Real) - x_ = Tangent((x, dx)) - y_ = Tangent((y, dy)) - f_ = x_ + y_ - f, df = f_ - assert_close(f, x + y) - assert_close(df, dx + dy) + x_ = JVP(x) + y_ = JVP(y) + with lazy: + f = x_ + y_ + + primal = apply_optimizer(f.primal) + assert_close(primal, x + y) + + dfdx = f.tangent[str(id(x))](**{str(id(x)): dx}) + assert_close(dfdx, dx+y-y) + + dfdy = f.tangent[str(id(y))](**{str(id(y)): dy}) + assert_close(dfdy, dy+x-x) + + +def test_add_two(): + x = random_tensor(OrderedDict(i=Bint[2])) + dx = Tensor(torch.tensor([1, 1]), OrderedDict(i=Bint[2])) + x_ = JVP(x) + with lazy: + f = x_ + x_ + + primal = apply_optimizer(f.primal) + assert_close(primal, x + x) + + dfdx = f.tangent[str(id(x))](**{str(id(x)): dx}) + assert_close(dfdx, 2*dx) def test_mul(): @@ -47,22 +81,63 @@ def test_mul(): y = random_tensor(OrderedDict(j=Bint[3])) dx = random_tensor(OrderedDict(i=Bint[2])) dy = random_tensor(OrderedDict(j=Bint[3])) - # dx = Number(1.0) # Variable("dx", Real) - # dy = Number(0.0) # Variable("dy", Real) - x_ = Tangent((x, dx)) - y_ = Tangent((y, dy)) - f, df = x_ * y_ - assert_close(f, x * y) - assert_close(df, (x * dy + y * dx).align(tuple(df.inputs.keys()))) + x_ = JVP(x) + y_ = JVP(y) + with lazy: + f = x_ * y_ + + primal = apply_optimizer(f.primal) + assert_close(primal, x * y) + + dfdx = f.tangent[str(id(x))](**{str(id(x)): dx}) + assert_close(dfdx, dx*y) + + dfdy = f.tangent[str(id(y))](**{str(id(y)): dy}) + assert_close(dfdy, dy*x) + + # jacfwd + dx = Tensor(torch.eye(2), OrderedDict(i=Bint[2], l=Bint[2])) + jacdx = f.tangent[str(id(x))](**{str(id(x)): dx}) + assert_close(jacdx, dx*y) + + +def test_mul_add(): + x = random_tensor(OrderedDict(i=Bint[2])) + y = random_tensor(OrderedDict(j=Bint[3])) + z = random_tensor(OrderedDict(k=Bint[4])) + dx = random_tensor(OrderedDict(i=Bint[2])) + dy = random_tensor(OrderedDict(j=Bint[3])) + dz = random_tensor(OrderedDict(k=Bint[4])) + x_ = JVP(x) + y_ = JVP(y) + z_ = JVP(z) + with lazy: + f = x_ * y_ + z_ + + primal = apply_optimizer(f.primal) + assert_close(primal, x * y + z) + + dfdx = f.tangent[str(id(x))](**{str(id(x)): dx}) + assert_close(dfdx, dx*y) + + dfdy = f.tangent[str(id(y))](**{str(id(y)): dy}) + # assert_close(dfdy, dy*x+z-z) + + dfdz = f.tangent[str(id(z))](**{str(id(z)): dz}) + breakpoint() + assert_close(dfdz, dz+x*y-x*y) def test_reduce_sum(): x = random_tensor(OrderedDict(j=Bint[4])) dx = random_tensor(OrderedDict(j=Bint[4])) - x_ = Tangent((x, dx)) - f, df = x_.reduce(ops.add, "j") - assert_close(f, x.reduce(ops.add, "j")) - assert_close(df, dx.reduce(ops.add, "j")) + Tx = Variable("dx", Real) + x_ = Tangent((x, Tx)) + with lazy: + f, df = x_.reduce(ops.add, "j") + breakpoint() + assert_close(apply_optimizer(f), x.reduce(ops.add, "j")) + assert_close(df(dx=dx), dx.reduce(ops.add, "j")) def test_reduce_prod(): @@ -74,17 
+149,46 @@ def test_reduce_prod(): assert_close(df, (f * dx / x).reduce(ops.add, "j")) +def test_reduce_jacfwd(): + x = random_tensor(OrderedDict(j=Bint[4])) + # dx = Tensor(torch.tensor([1, 0, 0, 0]), OrderedDict(j=Bint[4])) + dx = Tensor(torch.eye(4), OrderedDict(j=Bint[4], l=Bint[4])) + x_ = Tangent((x, dx)) + f, df = x_.reduce(ops.mul, "j") + assert_close(f, x.reduce(ops.mul, "j")) + assert_close(df, (f * dx / x).reduce(ops.add, "j")) + + def test_matmul_tensor(): x = random_tensor(OrderedDict(j=Bint[4])) y = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) dx = random_tensor(OrderedDict(j=Bint[4])) dy = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) + # Tx = Variable("dx", Real) + # Ty = Variable("dy", Real) x_ = Tangent((x, dx)) y_ = Tangent((y, dy)) - xy_ = x_ * y_ - z, dz = xy_.reduce(ops.add, "j") - assert_close(z, (x * y).reduce(ops.add, "j")) - assert_close(dz, (y * dx + x * dy).reduce(ops.add, "j")) + with lazy: + x @ y + xy_ = x_ * y_ + z, dz = xy_.reduce(ops.add, "j") + assert_close(apply_optimizer(z), (x * y).reduce(ops.add, "j")) + assert_close(apply_optimizer(dz), (y * dx + x * dy).reduce(ops.add, "j")) + + +def test_matmul_jacfwd(): + x = random_tensor(OrderedDict(j=Bint[4])) + y = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) + dx = Tensor(torch.eye(4), OrderedDict(j=Bint[4], l=Bint[4])) + dy = Number(0.0) + x_ = Tangent((x, dx)) + y_ = Tangent((y, dy)) + with lazy: + xy_ = x_ * y_ + z, dz = xy_.reduce(ops.add, "j") + assert_close(apply_optimizer(z), (x * y).reduce(ops.add, "j")) + assert_close(apply_optimizer(dz), (y * dx).reduce(ops.add, "j")) + assert_close(apply_optimizer(dz), y(j="l").align(tuple(dz.inputs.keys()))) def test_compose(): @@ -92,10 +196,20 @@ def test_compose(): y = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) dx = random_tensor(OrderedDict(j=Bint[4])) dy = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) - x_ = Tangent((x, dx)) - logx_ = x_.log() - y_ = Tangent((y, dy)) - logxy_ = logx_ * y_ - z, dz = logxy_.reduce(ops.add, "j") - assert_close(z, (x.log() * y).reduce(ops.add, "j")) - assert_close(dz, (y * dx / x + x.log() * dy).reduce(ops.add, "j")) + Tx = Variable("dx", Real) + Ty = Variable("dy", Real) + x_ = Tangent((x, Tx)) + y_ = Tangent((y, Ty)) + with lazy: + logx_ = x_.log() + logxy_ = logx_ * y_ + z, dz = logxy_.reduce(ops.add, "j") + + breakpoint() + actual_z = apply_optimizer(z) + expected_z = (x.log() * y).reduce(ops.add, "j") + assert_close(actual_z, expected_z) + actual_dz = apply_optimizer(dz(**{"dx": dx, "dy": dy})) + expected_dz = (y * dx / x + x.log() * dy).reduce(ops.add, "j") + breakpoint() + assert_close(actual_dz, expected_dz) From 16c37be38685e0bcbf61a93055b76cb95243d654 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Fri, 5 Mar 2021 23:28:47 -0500 Subject: [PATCH 04/16] GetitemOp, Lambda --- test/test_jvp.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/test/test_jvp.py b/test/test_jvp.py index 1c58f1b00..fbe18f886 100644 --- a/test/test_jvp.py +++ b/test/test_jvp.py @@ -165,13 +165,16 @@ def test_matmul_tensor(): dx = random_tensor(OrderedDict(j=Bint[4])) dy = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) # Tx = Variable("dx", Real) - # Ty = Variable("dy", Real) + Ty = Variable("dy", Reals[4, 5])["j", "k"] x_ = Tangent((x, dx)) - y_ = Tangent((y, dy)) + y_ = Tangent((y, Ty)) with lazy: - x @ y xy_ = x_ * y_ z, dz = xy_.reduce(ops.add, "j") + + dy = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) + dy = funsor.terms.Lambda("k", funsor.terms.Lambda("j", dy)) + dz(dy=dy) 
assert_close(apply_optimizer(z), (x * y).reduce(ops.add, "j")) assert_close(apply_optimizer(dz), (y * dx + x * dy).reduce(ops.add, "j")) From b25ca21d5246af4a43f7022af17ca95d81506713 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Wed, 10 Mar 2021 04:31:50 -0500 Subject: [PATCH 05/16] save changes --- funsor/domains.py | 6 +++++- funsor/jvp.py | 4 ++-- test/test_jvp.py | 50 ++++++++++++++++++++++++++++++++++------------- 3 files changed, 43 insertions(+), 17 deletions(-) diff --git a/funsor/domains.py b/funsor/domains.py index 617e965c5..c447f1f0a 100644 --- a/funsor/domains.py +++ b/funsor/domains.py @@ -261,6 +261,7 @@ def _find_domain_getitem(op, lhs_domain, rhs_domain): return Array[dtype, shape] elif isinstance(lhs_domain, ProductDomain): # XXX should this return a Union? + return Real raise NotImplementedError( "Cannot statically infer domain from: " f"{lhs_domain}[{rhs_domain}]" ) @@ -325,7 +326,10 @@ def _find_domain_associative_generic(op, *domains): return Array[domains[0].dtype, ()] lhs, rhs = domains - if lhs.dtype == "real" or rhs.dtype == "real": + # FIXME + if lhs is rhs: + return lhs + elif lhs.dtype == "real" or rhs.dtype == "real": dtype = "real" elif op in (ops.add, ops.mul, ops.pow, ops.max, ops.min): dtype = op(lhs.dtype - 1, rhs.dtype - 1) + 1 diff --git a/funsor/jvp.py b/funsor/jvp.py index 4f39ce799..617bf11d7 100644 --- a/funsor/jvp.py +++ b/funsor/jvp.py @@ -49,7 +49,7 @@ def log(self): -@lazy.register(Binary, AssociativeOp, Tangent, Tangent) +# @lazy.register(Binary, AssociativeOp, Tangent, Tangent) @eager.register(Binary, AssociativeOp, Tangent, Tangent) def jvp_binary(op, lhs, rhs): lhs_primal, lhs_tangent = lhs @@ -64,7 +64,7 @@ def jvp_binary(op, lhs, rhs): return Tangent(primal, tangent) -@lazy.register(Reduce, AssociativeOp, Tangent, frozenset) +# @lazy.register(Reduce, AssociativeOp, Tangent, frozenset) @eager.register(Reduce, AssociativeOp, Tangent, frozenset) def jvp_reduce(op, arg, reduced_vars): arg_primal, arg_tangent = arg diff --git a/test/test_jvp.py b/test/test_jvp.py index fbe18f886..8398d51cf 100644 --- a/test/test_jvp.py +++ b/test/test_jvp.py @@ -4,14 +4,14 @@ from collections import OrderedDict import funsor.ops as ops -from funsor.domains import Bint, Real +from funsor.domains import Bint, Real, Reals from funsor.jvp import Tangent, JVP -from funsor.vjp import transpose from funsor.testing import assert_close, random_tensor -from funsor.terms import Variable, Number, lazy +from funsor.terms import Variable, Number, lazy, Lambda, Binary, Funsor from funsor.tensor import Tensor from funsor.optimizer import apply_optimizer from funsor.interpreter import interpretation +from funsor.factory import make_funsor, Bound, Fresh, Has import torch @@ -159,24 +159,46 @@ def test_reduce_jacfwd(): assert_close(df, (f * dx / x).reduce(ops.add, "j")) +@make_funsor +def MatMul( + a: Has[{"i"}], + b: Has[{"i"}], + i: Bound + ) -> Fresh[lambda a: a]: + return Prod(a, b).reduce(ops.add, i) + +@make_funsor +def Prod( + x: Funsor, + y: Funsor + ) -> Fresh[lambda x: x]: + return x * y + + + def test_matmul_tensor(): x = random_tensor(OrderedDict(j=Bint[4])) y = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) dx = random_tensor(OrderedDict(j=Bint[4])) dy = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) - # Tx = Variable("dx", Real) - Ty = Variable("dy", Reals[4, 5])["j", "k"] - x_ = Tangent((x, dx)) + Tx = Variable("dx", Reals[4])["j"] + Ty = Variable("dy", Reals[4])["j"] + Dx = Lambda(Variable("j", Bint[4]), dx) + Dy = Lambda(Variable("j", Bint[4]), dy) + 
x_ = Tangent((x, Tx)) y_ = Tangent((y, Ty)) - with lazy: - xy_ = x_ * y_ - z, dz = xy_.reduce(ops.add, "j") - dy = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) - dy = funsor.terms.Lambda("k", funsor.terms.Lambda("j", dy)) - dz(dy=dy) - assert_close(apply_optimizer(z), (x * y).reduce(ops.add, "j")) - assert_close(apply_optimizer(dz), (y * dx + x * dy).reduce(ops.add, "j")) + with funsor.terms.eager: + z, dz = MatMul(x_, y_, "j") + breakpoint() + + actual_z = apply_optimizer(z) + actual_dz = dz(dx=Dx, dy=Dy) + expected_z = (x * y).reduce(ops.add, "j") + expected_dz = (y * dx + x * dy).reduce(ops.add, "j") + + assert_close(actual_z, expected_z) + assert_close(actual_dz, expected_dz) def test_matmul_jacfwd(): From 40ba97d0c20f1f2fb61e19cbb45e88458caf6fbb Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Thu, 11 Mar 2021 13:37:05 -0500 Subject: [PATCH 06/16] linearize wip --- funsor/linearize.py | 140 +++++++++++++++++++++++++++++++++++++++++ test/test_linearize.py | 46 ++++++++++++++ 2 files changed, 186 insertions(+) create mode 100644 funsor/linearize.py create mode 100644 test/test_linearize.py diff --git a/funsor/linearize.py b/funsor/linearize.py new file mode 100644 index 000000000..a0df942f2 --- /dev/null +++ b/funsor/linearize.py @@ -0,0 +1,140 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import inspect +import typing +import warnings +from collections import OrderedDict +from functools import singledispatch + +import makefun + +from funsor.instrument import debug_logged +from funsor.terms import Funsor, FunsorMeta, Variable, eager, to_funsor + + + +def linearize(fn, primals): + """ + Decorator to dynamically create a subclass of + :class:`~funsor.terms.Funsor`, together with a single default eager + pattern. + + This infers inputs, outputs, fresh, and bound variables from type hints + follow the following convention: + + - Funsor inputs are typed :class:`~funsor.terms.Funsor`. + - Bound variable inputs (names) are typed :class:`Bound`. + - Fresh variable inputs (names) are typed :class:`Fresh` together with + lambda to compute the dependent domain. + - Ground value inputs (e.g. Python ints) are typed :class:`Value` together with + their actual data type, e.g. ``Value[int]``. + - The return value is typed :class:`Fresh` together with a lambda to + compute the dependent return domain. + + For example to unflatten a single coordinate into a pair of coordinates we + could define:: + + @make_funsor + def Unflatten( + x: Funsor, + i: Bound, + i_over_2: Fresh[lambda i: Bint[i.size // 2]], + i_mod_2: Fresh[lambda: Bint[2]], + ) -> Fresh[lambda x: x]: + assert i.output.size % 2 == 0 + return x(**{i.name: i_over_2 * Number(2, 3) + i_mod_2}) + + :param callable fn: A type annotated function of Funsors. + :rtype: subclas of :class:`~funsor.terms.Funsor` + """ + breakpoint() + input_types = typing.get_type_hints(fn) + for name, hint in input_types.items(): + if not (hint in (Funsor, Bound) or isinstance(hint, (Fresh, Value, Has))): + raise TypeError(f"Invalid type hint {name}: {hint}") + output_type = input_types.pop("return") + hints = tuple(input_types.values()) + + class ResultMeta(FunsorMeta): + def __call__(cls, *args): + args = list(args) + + # Compute domains of bound variables. 
+ for i, (name, arg) in enumerate(zip(cls._ast_fields, args)): + hint = input_types[name] + if hint is Funsor or isinstance(hint, Has): # TODO support domains + args[i] = to_funsor(arg) + elif hint is Bound: + for other in args: + if isinstance(other, Funsor): + domain = other.inputs.get(arg, None) + if domain is not None: + arg = to_funsor(arg, domain) + if not isinstance(arg, Variable): + raise ValueError(f"Cannot infer domain of {name}={arg}") + args[i] = arg + elif isinstance(hint, Value): + if not isinstance(arg, hint.value_type): + raise TypeError( + f"invalid dependent value type: {arg}: {hint.value_type}" + ) + args[i] = arg + + # Compute domains of fresh variables. + dependent_args = _get_dependent_args(cls._ast_fields, hints, args) + for i, (hint, arg) in enumerate(zip(hints, args)): + if isinstance(hint, Fresh): + domain = hint(**dependent_args) + args[i] = to_funsor(arg, domain) + return super().__call__(*args) + + @makefun.with_signature( + "__init__({})".format(", ".join(["self"] + list(input_types))) + ) + def __init__(self, **kwargs): + args = tuple(kwargs[k] for k in self._ast_fields) + dependent_args = _get_dependent_args(self._ast_fields, hints, args) + output = output_type(**dependent_args) + inputs = OrderedDict() + bound = {} + for hint, arg, arg_name in zip(hints, args, self._ast_fields): + if hint is Funsor: + assert isinstance(arg, Funsor) + inputs.update(arg.inputs) + elif isinstance(hint, Has): + assert isinstance(arg, Funsor) + inputs.update(arg.inputs) + for name in hint.bound: + if kwargs[name] not in arg.input_vars: + warnings.warn( + f"Argument {arg_name} is missing bound variable {kwargs[name]} from argument {name}." + f"Are you sure {name} will always appear in {arg_name}?", + SyntaxWarning, + ) + for hint, arg in zip(hints, args): + if hint is Bound: + bound[arg.name] = inputs.pop(arg.name) + for hint, arg in zip(hints, args): + if isinstance(hint, Fresh): + for k, d in arg.inputs.items(): + if k not in bound: + inputs[k] = d + fresh = frozenset() + Funsor.__init__(self, inputs, output, fresh, bound) + for name, arg in zip(self._ast_fields, args): + setattr(self, name, arg) + + def _alpha_convert(self, alpha_subs): + alpha_subs = {k: to_funsor(v, self.bound[k]) for k, v in alpha_subs.items()} + return Funsor._alpha_convert(self, alpha_subs) + + ResultMeta.__name__ = f"{fn.__name__}Meta" + Result = ResultMeta( + fn.__name__, (Funsor,), {"__init__": __init__, "_alpha_convert": _alpha_convert} + ) + pattern = (Result,) + tuple( + _hint_to_pattern(input_types[k]) for k in Result._ast_fields + ) + eager.register(*pattern)(_erase_types(fn)) + return Result diff --git a/test/test_linearize.py b/test/test_linearize.py new file mode 100644 index 000000000..e46d5dd29 --- /dev/null +++ b/test/test_linearize.py @@ -0,0 +1,46 @@ +from collections import OrderedDict +from funsor.factory import make_funsor, Fresh, Has +from funsor.terms import Funsor, Number +# from funsor.linearize import linearize +from funsor.testing import assert_close, random_tensor +from funsor.domains import Real, Reals, Bint +from funsor.jvp import Tangent +from makefun import create_function, with_signature, partial + +def test_mul(): + @make_funsor + def Mul( + x: Funsor, + y: Funsor + ) -> Fresh[lambda x: x]: + return x * y + x = random_tensor(OrderedDict(j=Bint[4])) + y = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) + dx = random_tensor(OrderedDict(j=Bint[4])) + dy = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) + # x_ = Tangent((x, Tx)) + # y_ = Tangent((y, Ty)) + def linear_fn(x, y, 
Tx, Ty): + print("hello fellas!") + x_ = Tangent((x, Tx)) + y_ = Tangent((y, Ty)) + return Mul(x_, y_)[1] + + breakpoint() + LinearMul = partial(linear_fn, x=x, y=y) + + # def linearize(fn, (x, y)): + # def linear_fn(Tx, Ty): + # x_ = Tangent((x, Tx)) + # y_ = Tangent((y, Ty)) + # return Mul(x_, y_) + # return + + # Mul(x, y), LinearMul = linearize(Mul, (x, y)) + @make_funsor + def LinearMul( + dx: Funsor, + dy: Funsor + ) -> Fresh[lambda dx: dx]: + return x * dy + y * dx + assert Mul(x, y) == x * y From 596de4f0e5d84ce7207e64b204698ca5e01b63d4 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sat, 13 Mar 2021 20:26:08 -0500 Subject: [PATCH 07/16] implement linearize --- funsor/jvp.py | 144 +++++++++++++++++++++++++++++------------------ test/test_jvp.py | 112 +++++++++++++++++++----------------- 2 files changed, 148 insertions(+), 108 deletions(-) diff --git a/funsor/jvp.py b/funsor/jvp.py index 617bf11d7..e8d83baeb 100644 --- a/funsor/jvp.py +++ b/funsor/jvp.py @@ -3,85 +3,117 @@ import funsor.ops as ops from funsor.ops import AssociativeOp, LogOp -from funsor.terms import Binary, Reduce, Tuple, Unary, eager, lazy, Variable, Number +from funsor.terms import Binary, Reduce, Tuple, Unary, eager, lazy, Variable, Number, Lambda from funsor.interpreter import interpretation -from funsor.domains import Bint, Real +from funsor.domains import Bint, Real, Array, Reals from collections import defaultdict +from functools import reduce, singledispatch +from funsor import Tensor -class Tangent(Tuple): +def to_var(x, name): + var = Variable(name, Array["real", x.data.shape])[tuple(x.inputs)] + return var + + +def to_arg(x): + input_vars = tuple(Variable(key, value) for key, value in x.inputs.items()) + arg = reduce(lambda a, b: Lambda(b, a), reversed(input_vars), x) + return arg + + +def fjit(cls, *args): + new_args = [] + for field, arg in zip(cls._ast_fields, args): + if isinstance(arg, (Number, Tensor)): + arg = to_var(arg, field) + new_args.append(arg) + new_args = tuple(new_args) + return cls(*new_args) + + +def linearize(cls, *args, log=True): + jvp = logJVP if log else JVP + new_args = [] + for arg_name, arg in zip(cls._ast_fields, args): + if isinstance(arg, (Number, Tensor)): + tangent_var = to_var(arg, "d" + arg_name) + arg = jvp(arg, tangent_var) + new_args.append(arg) + new_args = tuple(new_args) + return cls(*new_args) + + +def get_linear_terms(expr, linear_vars): + breakpoint() pass -class JVP: - def __init__(self, primal, tangent=defaultdict(lambda: Number(0.0))): - self.primal = primal - self.tangent = tangent.copy() - self.tangent[str(id(primal))] = Variable(str(id(primal)), Real) - - def __add__(self, other): - primal = self.primal + other.primal - tangent = defaultdict(lambda: Number(0.0)) - for key, value in self.tangent.items(): - tangent[key] += value - for key, value in other.tangent.items(): - tangent[key] += value - tangent[str(id(self.primal))] += other.primal - other.primal - tangent[str(id(other.primal))] += self.primal - self.primal - return JVP(primal, tangent) - - def __mul__(self, other): - primal = self.primal * other.primal - tangent = defaultdict(lambda: Number(0.0)) - for key, value in self.tangent.items(): - tangent[key] += value - for key, value in other.tangent.items(): - tangent[key] += value - tangent[str(id(self.primal))] *= other.primal - tangent[str(id(other.primal))] *= self.primal - return JVP(primal, tangent) - - def log(self): - primal = self.primal.log() - tangent = self.tangent - tangent[str(id(self.primal))] /= self.primal - return JVP(primal, 
tangent) - - - -# @lazy.register(Binary, AssociativeOp, Tangent, Tangent) -@eager.register(Binary, AssociativeOp, Tangent, Tangent) +@singledispatch +def transpose(expr, linear_vars): + get_linear_terms(expr, linear_vars) + out_shape = tuple(value.size for key, value in expr.inputs.items() if key not in linear_vars) + out_inputs = tuple(key for key in expr.inputs if key not in linear_vars) + out_tangent = Variable("dout", Array["real", out_shape])[out_inputs] + breakpoint() + pass + + +class JVP(Tuple): + """ + Tuple:(Primal, Tanget) + Semiring: (Add, Mul) + """ + sum_op = ops.add + prod_op = ops.mul + + +class logJVP(Tuple): + """ + Tuple: (LogPrimal, LogTanget) + Semiring: (Logaddexp, Add) + """ + sum_op = ops.logaddexp + prod_op = ops.add + + +@eager.register(Binary, AssociativeOp, JVP, JVP) +@eager.register(Binary, AssociativeOp, logJVP, logJVP) def jvp_binary(op, lhs, rhs): + sum_op = lhs.sum_op + prod_op = lhs.prod_op lhs_primal, lhs_tangent = lhs rhs_primal, rhs_tangent = rhs primal = Binary(op, lhs_primal, rhs_primal) - if op is ops.add: - tangent = lhs_tangent + rhs_tangent - elif op is ops.mul: - tangent = rhs_primal * lhs_tangent + lhs_primal * rhs_tangent + if op is sum_op: + tangent = sum_op(lhs_tangent, rhs_tangent) + elif op is prod_op: + tangent = sum_op(prod_op(rhs_primal, lhs_tangent), prod_op(lhs_primal, rhs_tangent)) else: raise NotImplementedError - return Tangent(primal, tangent) + return type(lhs)(primal, tangent) -# @lazy.register(Reduce, AssociativeOp, Tangent, frozenset) -@eager.register(Reduce, AssociativeOp, Tangent, frozenset) +@lazy.register(Reduce, AssociativeOp, JVP, frozenset) +@eager.register(Reduce, AssociativeOp, logJVP, frozenset) def jvp_reduce(op, arg, reduced_vars): + sum_op, prod_op = arg.sum_op, arg.prod_op arg_primal, arg_tangent = arg primal = Reduce(op, arg_primal, reduced_vars) - if op is ops.add: - tangent = Reduce(op, arg_tangent, reduced_vars) - elif op is ops.mul: - tangent = Reduce(ops.add, arg_tangent * primal / arg_primal, reduced_vars) + if op is sum_op: + tangent = Reduce(sum_op, arg_tangent, reduced_vars) + elif op is prod_op: + div_op = ops.SAFE_BINARY_INVERSES[prod_op] + tangent = Reduce(prod_op, div_op(prod_op(arg_tangent, primal), arg_primal), reduced_vars) else: raise NotImplementedError - return Tangent(primal, tangent) + return type(arg)(primal, tangent) -@lazy.register(Unary, LogOp, Tangent) -@eager.register(Unary, LogOp, Tangent) +@lazy.register(Unary, LogOp, JVP) +@eager.register(Unary, LogOp, JVP) def jvp_log(op, arg): arg_primal, arg_tangent = arg primal = Unary(op, arg_primal) tangent = Binary(ops.truediv, arg_tangent, arg_primal) - return Tangent(primal, tangent) + return JVP(primal, tangent) diff --git a/test/test_jvp.py b/test/test_jvp.py index 8398d51cf..5f2a1f249 100644 --- a/test/test_jvp.py +++ b/test/test_jvp.py @@ -5,13 +5,14 @@ import funsor.ops as ops from funsor.domains import Bint, Real, Reals -from funsor.jvp import Tangent, JVP +from funsor.jvp import JVP, to_var, to_arg, fjit, linearize, transpose from funsor.testing import assert_close, random_tensor from funsor.terms import Variable, Number, lazy, Lambda, Binary, Funsor from funsor.tensor import Tensor from funsor.optimizer import apply_optimizer from funsor.interpreter import interpretation from funsor.factory import make_funsor, Bound, Fresh, Has +from funsor.sum_product import MarkovProduct import torch @@ -132,7 +133,7 @@ def test_reduce_sum(): x = random_tensor(OrderedDict(j=Bint[4])) dx = random_tensor(OrderedDict(j=Bint[4])) Tx = Variable("dx", 
Real) - x_ = Tangent((x, Tx)) + x_ = JVP((x, Tx)) with lazy: f, df = x_.reduce(ops.add, "j") breakpoint() @@ -143,7 +144,7 @@ def test_reduce_sum(): def test_reduce_prod(): x = random_tensor(OrderedDict(j=Bint[4])) dx = random_tensor(OrderedDict(j=Bint[4])) - x_ = Tangent((x, dx)) + x_ = JVP((x, dx)) f, df = x_.reduce(ops.mul, "j") assert_close(f, x.reduce(ops.mul, "j")) assert_close(df, (f * dx / x).reduce(ops.add, "j")) @@ -153,7 +154,7 @@ def test_reduce_jacfwd(): x = random_tensor(OrderedDict(j=Bint[4])) # dx = Tensor(torch.tensor([1, 0, 0, 0]), OrderedDict(j=Bint[4])) dx = Tensor(torch.eye(4), OrderedDict(j=Bint[4], l=Bint[4])) - x_ = Tangent((x, dx)) + x_ = JVP((x, dx)) f, df = x_.reduce(ops.mul, "j") assert_close(f, x.reduce(ops.mul, "j")) assert_close(df, (f * dx / x).reduce(ops.add, "j")) @@ -175,66 +176,73 @@ def Prod( return x * y - -def test_matmul_tensor(): +def test_fjit(): + # Product x = random_tensor(OrderedDict(j=Bint[4])) y = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) - dx = random_tensor(OrderedDict(j=Bint[4])) - dy = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) - Tx = Variable("dx", Reals[4])["j"] - Ty = Variable("dy", Reals[4])["j"] - Dx = Lambda(Variable("j", Bint[4]), dx) - Dy = Lambda(Variable("j", Bint[4]), dy) - x_ = Tangent((x, Tx)) - y_ = Tangent((y, Ty)) - - with funsor.terms.eager: - z, dz = MatMul(x_, y_, "j") - breakpoint() + cProd = fjit(Prod, x, y) + + x2 = random_tensor(OrderedDict(j=Bint[4])) + y2 = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) + expected = Prod(x2, y2) + actual = cProd(x=to_arg(x2), y=to_arg(y2)) + assert_close(actual, expected) + + # MarkovProduct + trans = random_tensor(OrderedDict(time=Bint[5], prev=Bint[3], curr=Bint[3])) + cMarkovProduct = fjit(MarkovProduct, ops.logaddexp, ops.add, trans, "time", {"prev": "curr"}) - actual_z = apply_optimizer(z) - actual_dz = dz(dx=Dx, dy=Dy) - expected_z = (x * y).reduce(ops.add, "j") - expected_dz = (y * dx + x * dy).reduce(ops.add, "j") - - assert_close(actual_z, expected_z) - assert_close(actual_dz, expected_dz) + trans2 = random_tensor(OrderedDict(time=Bint[5], prev=Bint[3], curr=Bint[3])) + expected = MarkovProduct(ops.logaddexp, ops.add, trans2, "time", {"prev": "curr"}) + actual = cMarkovProduct(trans=to_arg(trans2)) + assert_close(actual, expected) -def test_matmul_jacfwd(): +def test_linearize(): + # Add x = random_tensor(OrderedDict(j=Bint[4])) y = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) - dx = Tensor(torch.eye(4), OrderedDict(j=Bint[4], l=Bint[4])) - dy = Number(0.0) - x_ = Tangent((x, dx)) - y_ = Tangent((y, dy)) - with lazy: - xy_ = x_ * y_ - z, dz = xy_.reduce(ops.add, "j") - assert_close(apply_optimizer(z), (x * y).reduce(ops.add, "j")) - assert_close(apply_optimizer(dz), (y * dx).reduce(ops.add, "j")) - assert_close(apply_optimizer(dz), y(j="l").align(tuple(dz.inputs.keys()))) + z, linearAdd = linearize(Binary, ops.add, x, y, log=False) + dx = random_tensor(OrderedDict(j=Bint[4])) + dy = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) + expected = dx + dy + actual = linearAdd(dlhs=to_arg(dx), drhs=to_arg(dy)) + assert_close(actual, expected) + assert_close(z, x + y) -def test_compose(): + # Add in a LogFunctor x = random_tensor(OrderedDict(j=Bint[4])) y = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) + z, linearAdd = linearize(Binary, ops.add, x, y, log=True) + dx = random_tensor(OrderedDict(j=Bint[4])) dy = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) - Tx = Variable("dx", Real) - Ty = Variable("dy", Real) - x_ = Tangent((x, Tx)) - y_ = Tangent((y, 
Ty)) - with lazy: - logx_ = x_.log() - logxy_ = logx_ * y_ - z, dz = logxy_.reduce(ops.add, "j") + expected = ops.logaddexp(ops.add(y, dx), ops.add(x, dy)) + actual = linearAdd(dlhs=to_arg(dx), drhs=to_arg(dy)) + assert_close(actual, expected) + # MarkovProduct in a LogFunctor + trans = random_tensor(OrderedDict(time=Bint[5], prev=Bint[3], curr=Bint[3])) + z, linearMP = linearize(MarkovProduct, ops.logaddexp, ops.add, trans, "time", {"prev": "curr"}, log=True) + + dtrans = random_tensor(OrderedDict(time=Bint[5], prev=Bint[3], curr=Bint[3])) + # expected = MarkovProduct(ops.logaddexp, ops.add, trans2, "time", {"prev": "curr"}) + actual = linearMP(dtrans=to_arg(dtrans)) breakpoint() - actual_z = apply_optimizer(z) - expected_z = (x.log() * y).reduce(ops.add, "j") - assert_close(actual_z, expected_z) - actual_dz = apply_optimizer(dz(**{"dx": dx, "dy": dy})) - expected_dz = (y * dx / x + x.log() * dy).reduce(ops.add, "j") - breakpoint() - assert_close(actual_dz, expected_dz) + assert_close(actual, expected) + + +def test_transpose(): + # Add + x = random_tensor(OrderedDict(j=Bint[4])) + y = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) + z, linearAdd = linearize(Binary, ops.add, x, y, log=False) + transpose(linearAdd, {"dlhs", "drhs"}) + + dx = random_tensor(OrderedDict(j=Bint[4])) + dy = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) + expected = dx + dy + actual = linearAdd(dlhs=to_arg(dx), drhs=to_arg(dy)) + assert_close(actual, expected) + assert_close(z, x + y) From 84b008ab5391a5814ee3500978c917076c77e7c9 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sat, 13 Mar 2021 20:27:42 -0500 Subject: [PATCH 08/16] rename to autodiff --- funsor/{jvp.py => autodiff.py} | 0 test/{test_jvp.py => test_autodiff.py} | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename funsor/{jvp.py => autodiff.py} (100%) rename test/{test_jvp.py => test_autodiff.py} (98%) diff --git a/funsor/jvp.py b/funsor/autodiff.py similarity index 100% rename from funsor/jvp.py rename to funsor/autodiff.py diff --git a/test/test_jvp.py b/test/test_autodiff.py similarity index 98% rename from test/test_jvp.py rename to test/test_autodiff.py index 5f2a1f249..31a217914 100644 --- a/test/test_jvp.py +++ b/test/test_autodiff.py @@ -5,7 +5,7 @@ import funsor.ops as ops from funsor.domains import Bint, Real, Reals -from funsor.jvp import JVP, to_var, to_arg, fjit, linearize, transpose +from funsor.autodiff import JVP, to_var, to_arg, fjit, linearize, transpose from funsor.testing import assert_close, random_tensor from funsor.terms import Variable, Number, lazy, Lambda, Binary, Funsor from funsor.tensor import Tensor From cd28eab43662fcdb4ae669ea6d6046cd67bc6d4f Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sat, 13 Mar 2021 20:29:54 -0500 Subject: [PATCH 09/16] rm linearize files --- funsor/linearize.py | 140 ----------------------------------------- test/test_linearize.py | 46 -------------- 2 files changed, 186 deletions(-) delete mode 100644 funsor/linearize.py delete mode 100644 test/test_linearize.py diff --git a/funsor/linearize.py b/funsor/linearize.py deleted file mode 100644 index a0df942f2..000000000 --- a/funsor/linearize.py +++ /dev/null @@ -1,140 +0,0 @@ -# Copyright Contributors to the Pyro project. 
-# SPDX-License-Identifier: Apache-2.0 - -import inspect -import typing -import warnings -from collections import OrderedDict -from functools import singledispatch - -import makefun - -from funsor.instrument import debug_logged -from funsor.terms import Funsor, FunsorMeta, Variable, eager, to_funsor - - - -def linearize(fn, primals): - """ - Decorator to dynamically create a subclass of - :class:`~funsor.terms.Funsor`, together with a single default eager - pattern. - - This infers inputs, outputs, fresh, and bound variables from type hints - follow the following convention: - - - Funsor inputs are typed :class:`~funsor.terms.Funsor`. - - Bound variable inputs (names) are typed :class:`Bound`. - - Fresh variable inputs (names) are typed :class:`Fresh` together with - lambda to compute the dependent domain. - - Ground value inputs (e.g. Python ints) are typed :class:`Value` together with - their actual data type, e.g. ``Value[int]``. - - The return value is typed :class:`Fresh` together with a lambda to - compute the dependent return domain. - - For example to unflatten a single coordinate into a pair of coordinates we - could define:: - - @make_funsor - def Unflatten( - x: Funsor, - i: Bound, - i_over_2: Fresh[lambda i: Bint[i.size // 2]], - i_mod_2: Fresh[lambda: Bint[2]], - ) -> Fresh[lambda x: x]: - assert i.output.size % 2 == 0 - return x(**{i.name: i_over_2 * Number(2, 3) + i_mod_2}) - - :param callable fn: A type annotated function of Funsors. - :rtype: subclas of :class:`~funsor.terms.Funsor` - """ - breakpoint() - input_types = typing.get_type_hints(fn) - for name, hint in input_types.items(): - if not (hint in (Funsor, Bound) or isinstance(hint, (Fresh, Value, Has))): - raise TypeError(f"Invalid type hint {name}: {hint}") - output_type = input_types.pop("return") - hints = tuple(input_types.values()) - - class ResultMeta(FunsorMeta): - def __call__(cls, *args): - args = list(args) - - # Compute domains of bound variables. - for i, (name, arg) in enumerate(zip(cls._ast_fields, args)): - hint = input_types[name] - if hint is Funsor or isinstance(hint, Has): # TODO support domains - args[i] = to_funsor(arg) - elif hint is Bound: - for other in args: - if isinstance(other, Funsor): - domain = other.inputs.get(arg, None) - if domain is not None: - arg = to_funsor(arg, domain) - if not isinstance(arg, Variable): - raise ValueError(f"Cannot infer domain of {name}={arg}") - args[i] = arg - elif isinstance(hint, Value): - if not isinstance(arg, hint.value_type): - raise TypeError( - f"invalid dependent value type: {arg}: {hint.value_type}" - ) - args[i] = arg - - # Compute domains of fresh variables. 
- dependent_args = _get_dependent_args(cls._ast_fields, hints, args) - for i, (hint, arg) in enumerate(zip(hints, args)): - if isinstance(hint, Fresh): - domain = hint(**dependent_args) - args[i] = to_funsor(arg, domain) - return super().__call__(*args) - - @makefun.with_signature( - "__init__({})".format(", ".join(["self"] + list(input_types))) - ) - def __init__(self, **kwargs): - args = tuple(kwargs[k] for k in self._ast_fields) - dependent_args = _get_dependent_args(self._ast_fields, hints, args) - output = output_type(**dependent_args) - inputs = OrderedDict() - bound = {} - for hint, arg, arg_name in zip(hints, args, self._ast_fields): - if hint is Funsor: - assert isinstance(arg, Funsor) - inputs.update(arg.inputs) - elif isinstance(hint, Has): - assert isinstance(arg, Funsor) - inputs.update(arg.inputs) - for name in hint.bound: - if kwargs[name] not in arg.input_vars: - warnings.warn( - f"Argument {arg_name} is missing bound variable {kwargs[name]} from argument {name}." - f"Are you sure {name} will always appear in {arg_name}?", - SyntaxWarning, - ) - for hint, arg in zip(hints, args): - if hint is Bound: - bound[arg.name] = inputs.pop(arg.name) - for hint, arg in zip(hints, args): - if isinstance(hint, Fresh): - for k, d in arg.inputs.items(): - if k not in bound: - inputs[k] = d - fresh = frozenset() - Funsor.__init__(self, inputs, output, fresh, bound) - for name, arg in zip(self._ast_fields, args): - setattr(self, name, arg) - - def _alpha_convert(self, alpha_subs): - alpha_subs = {k: to_funsor(v, self.bound[k]) for k, v in alpha_subs.items()} - return Funsor._alpha_convert(self, alpha_subs) - - ResultMeta.__name__ = f"{fn.__name__}Meta" - Result = ResultMeta( - fn.__name__, (Funsor,), {"__init__": __init__, "_alpha_convert": _alpha_convert} - ) - pattern = (Result,) + tuple( - _hint_to_pattern(input_types[k]) for k in Result._ast_fields - ) - eager.register(*pattern)(_erase_types(fn)) - return Result diff --git a/test/test_linearize.py b/test/test_linearize.py deleted file mode 100644 index e46d5dd29..000000000 --- a/test/test_linearize.py +++ /dev/null @@ -1,46 +0,0 @@ -from collections import OrderedDict -from funsor.factory import make_funsor, Fresh, Has -from funsor.terms import Funsor, Number -# from funsor.linearize import linearize -from funsor.testing import assert_close, random_tensor -from funsor.domains import Real, Reals, Bint -from funsor.jvp import Tangent -from makefun import create_function, with_signature, partial - -def test_mul(): - @make_funsor - def Mul( - x: Funsor, - y: Funsor - ) -> Fresh[lambda x: x]: - return x * y - x = random_tensor(OrderedDict(j=Bint[4])) - y = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) - dx = random_tensor(OrderedDict(j=Bint[4])) - dy = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) - # x_ = Tangent((x, Tx)) - # y_ = Tangent((y, Ty)) - def linear_fn(x, y, Tx, Ty): - print("hello fellas!") - x_ = Tangent((x, Tx)) - y_ = Tangent((y, Ty)) - return Mul(x_, y_)[1] - - breakpoint() - LinearMul = partial(linear_fn, x=x, y=y) - - # def linearize(fn, (x, y)): - # def linear_fn(Tx, Ty): - # x_ = Tangent((x, Tx)) - # y_ = Tangent((y, Ty)) - # return Mul(x_, y_) - # return - - # Mul(x, y), LinearMul = linearize(Mul, (x, y)) - @make_funsor - def LinearMul( - dx: Funsor, - dy: Funsor - ) -> Fresh[lambda dx: dx]: - return x * dy + y * dx - assert Mul(x, y) == x * y From cb38dcfe3814d0b56f4a570e96fdc7c3c70c6d20 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sun, 14 Mar 2021 17:49:55 -0400 Subject: [PATCH 10/16] save changes --- 
funsor/autodiff.py | 63 +++++++++++++++++++++++++++++++++---------- test/test_autodiff.py | 43 ++++++++++++++++++++++------- 2 files changed, 82 insertions(+), 24 deletions(-) diff --git a/funsor/autodiff.py b/funsor/autodiff.py index e8d83baeb..271423ea7 100644 --- a/funsor/autodiff.py +++ b/funsor/autodiff.py @@ -9,6 +9,7 @@ from collections import defaultdict from functools import reduce, singledispatch from funsor import Tensor +from funsor.cnf import Contraction def to_var(x, name): @@ -24,37 +25,71 @@ def to_arg(x): def fjit(cls, *args): new_args = [] - for field, arg in zip(cls._ast_fields, args): + for arg_name, arg in zip(cls._ast_fields, args): if isinstance(arg, (Number, Tensor)): - arg = to_var(arg, field) + arg = to_var(arg, arg_name) new_args.append(arg) new_args = tuple(new_args) return cls(*new_args) -def linearize(cls, *args, log=True): +def grad(cls, *args, targets, log=True): + (out_primal, linear_fn), in_tangents = linearize(cls, *args, targets=targets, log=log) + linear_terms = get_linear_terms(linear_fn, set(in_tangents)) + + out_shape = tuple(value.size for key, value in linear_fn.inputs.items() if key not in in_tangents) + out_inputs = tuple(key for key in linear_fn.inputs if key not in in_tangents) + out_tangent = Variable("dout", Array["real", out_shape])[out_inputs] + + grad_dict = {} + for name, var in in_tangents.items(): + grad_dict[name] = transpose(linear_terms[name], var, out_tangent) + return grad_dict + + +def linearize(cls, *args, targets, log=True): jvp = logJVP if log else JVP new_args = [] + in_tangents = {} for arg_name, arg in zip(cls._ast_fields, args): - if isinstance(arg, (Number, Tensor)): - tangent_var = to_var(arg, "d" + arg_name) + # if isinstance(arg, (Number, Tensor)): + if arg in targets: + tangent_var = to_var(arg, arg_name) arg = jvp(arg, tangent_var) + in_tangents[arg_name] = tangent_var new_args.append(arg) new_args = tuple(new_args) - return cls(*new_args) + return cls(*new_args), in_tangents def get_linear_terms(expr, linear_vars): - breakpoint() - pass + if len(linear_vars) == 1: + return {next(iter(linear_vars)): expr} + assert isinstance(expr, Contraction) + assert expr.bin_op is ops.add or expr.bin_op is ops.logaddexp + assert expr.red_op is ops.nullop + terms = {} + for term in expr.terms: + if len(linear_vars.intersection(term.inputs)) == 1: + var = next(iter(linear_vars.intersection(term.inputs))) + terms[var] = term + else: + result = get_linear_terms(term, linear_vars) + terms.update(result) + return terms @singledispatch -def transpose(expr, linear_vars): - get_linear_terms(expr, linear_vars) - out_shape = tuple(value.size for key, value in expr.inputs.items() if key not in linear_vars) - out_inputs = tuple(key for key in expr.inputs if key not in linear_vars) - out_tangent = Variable("dout", Array["real", out_shape])[out_inputs] +def transpose(expr, target, out_tangent): + if expr is target: + return out_tangent + raise ValueError + + +@transpose.register(Binary) +def transpose_binary(expr, target, out_tangent): + if expr is target: + return out_tangent breakpoint() pass @@ -94,7 +129,7 @@ def jvp_binary(op, lhs, rhs): return type(lhs)(primal, tangent) -@lazy.register(Reduce, AssociativeOp, JVP, frozenset) +@eager.register(Reduce, AssociativeOp, JVP, frozenset) @eager.register(Reduce, AssociativeOp, logJVP, frozenset) def jvp_reduce(op, arg, reduced_vars): sum_op, prod_op = arg.sum_op, arg.prod_op diff --git a/test/test_autodiff.py b/test/test_autodiff.py index 31a217914..fcbbf14bb 100644 --- a/test/test_autodiff.py +++ 
b/test/test_autodiff.py @@ -5,7 +5,7 @@ import funsor.ops as ops from funsor.domains import Bint, Real, Reals -from funsor.autodiff import JVP, to_var, to_arg, fjit, linearize, transpose +from funsor.autodiff import JVP, to_var, to_arg, fjit, linearize, grad from funsor.testing import assert_close, random_tensor from funsor.terms import Variable, Number, lazy, Lambda, Binary, Funsor from funsor.tensor import Tensor @@ -198,47 +198,70 @@ def test_fjit(): assert_close(actual, expected) +def test_grad(): + # Add + x = random_tensor(OrderedDict(j=Bint[4])) + y = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) + result = grad(Binary, ops.add, x, y, log=False) + breakpoint() + + dx = random_tensor(OrderedDict(j=Bint[4])) + dy = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) + expected = dx + dy + actual = linearAdd(lhs=to_arg(dx), rhs=to_arg(dy)) + assert_close(actual, expected) + assert_close(z, x + y) + + def test_linearize(): # Add x = random_tensor(OrderedDict(j=Bint[4])) y = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) - z, linearAdd = linearize(Binary, ops.add, x, y, log=False) + (z, linearAdd), linear_vars = linearize(Binary, ops.add, x, y, log=False) dx = random_tensor(OrderedDict(j=Bint[4])) dy = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) expected = dx + dy - actual = linearAdd(dlhs=to_arg(dx), drhs=to_arg(dy)) + actual = linearAdd(lhs=to_arg(dx), rhs=to_arg(dy)) assert_close(actual, expected) assert_close(z, x + y) # Add in a LogFunctor x = random_tensor(OrderedDict(j=Bint[4])) y = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) - z, linearAdd = linearize(Binary, ops.add, x, y, log=True) + with funsor.terms.lazy: + z, linearAdd = linearize(Binary, ops.add, x, y, log=True) dx = random_tensor(OrderedDict(j=Bint[4])) dy = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) expected = ops.logaddexp(ops.add(y, dx), ops.add(x, dy)) - actual = linearAdd(dlhs=to_arg(dx), drhs=to_arg(dy)) + breakpoint() + actual = linearAdd(lhs=to_arg(dx), rhs=to_arg(dy)) assert_close(actual, expected) # MarkovProduct in a LogFunctor trans = random_tensor(OrderedDict(time=Bint[5], prev=Bint[3], curr=Bint[3])) - z, linearMP = linearize(MarkovProduct, ops.logaddexp, ops.add, trans, "time", {"prev": "curr"}, log=True) + with funsor.terms.lazy: + z, linearMP = linearize(MarkovProduct, ops.logaddexp, ops.add, trans, "time", {"prev": "curr"}, log=True) dtrans = random_tensor(OrderedDict(time=Bint[5], prev=Bint[3], curr=Bint[3])) # expected = MarkovProduct(ops.logaddexp, ops.add, trans2, "time", {"prev": "curr"}) - actual = linearMP(dtrans=to_arg(dtrans)) - breakpoint() - assert_close(actual, expected) + actual = linearMP(trans=to_arg(dtrans)) + # assert_close(actual, expected) def test_transpose(): + # Mul + x = random_tensor(OrderedDict(j=Bint[4])) + y = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) + z, linearAdd = linearize(Binary, ops.mul, x, y, log=False) + linear_transpose(linearAdd, {"lhs", "rhs"}, log=False) + # Add x = random_tensor(OrderedDict(j=Bint[4])) y = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) z, linearAdd = linearize(Binary, ops.add, x, y, log=False) - transpose(linearAdd, {"dlhs", "drhs"}) + linear_transpose(linearAdd, {"lhs", "rhs"}, log=False) dx = random_tensor(OrderedDict(j=Bint[4])) dy = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) From 78246504ea35b0451de5822b1cd648f4f186f2d7 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Mon, 15 Mar 2021 16:52:14 -0400 Subject: [PATCH 11/16] save changes --- funsor/autodiff.py | 159 ++++++++++++++++++++++-------------------- 
test/test_autodiff.py | 15 ++-- 2 files changed, 95 insertions(+), 79 deletions(-) diff --git a/funsor/autodiff.py b/funsor/autodiff.py index 271423ea7..2e9ba84a6 100644 --- a/funsor/autodiff.py +++ b/funsor/autodiff.py @@ -1,9 +1,10 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 +import math import funsor.ops as ops from funsor.ops import AssociativeOp, LogOp -from funsor.terms import Binary, Reduce, Tuple, Unary, eager, lazy, Variable, Number, Lambda +from funsor.terms import Binary, Reduce, Tuple, Unary, eager, lazy, Variable, Number, Lambda, Funsor from funsor.interpreter import interpretation from funsor.domains import Bint, Real, Array, Reals from collections import defaultdict @@ -12,6 +13,35 @@ from funsor.cnf import Contraction +class JVP(Tuple): + """ + Tuple:(Primal, Tanget) + Semiring: (Add, Mul) + """ + sum_op = ops.add + prod_op = ops.mul + div_op = ops.safediv + zero = Number(0) + one = Number(1) + + +class logJVP(Tuple): + """ + Tuple: (LogPrimal, LogTanget) + Semiring: (Logaddexp, Add) + """ + sum_op = ops.logaddexp + prod_op = ops.add + div_op = ops.safesub + zero = Number(-math.inf) + one = Number(0) + + +def requires_grad(primal): + tangent = Variable(str(id(primal)), Array["real", primal.data.shape])[tuple(primal.inputs)] + return JVP(primal, tangent) + + def to_var(x, name): var = Variable(name, Array["real", x.data.shape])[tuple(x.inputs)] return var @@ -33,90 +63,57 @@ def fjit(cls, *args): return cls(*new_args) -def grad(cls, *args, targets, log=True): - (out_primal, linear_fn), in_tangents = linearize(cls, *args, targets=targets, log=log) - linear_terms = get_linear_terms(linear_fn, set(in_tangents)) - - out_shape = tuple(value.size for key, value in linear_fn.inputs.items() if key not in in_tangents) - out_inputs = tuple(key for key in linear_fn.inputs if key not in in_tangents) - out_tangent = Variable("dout", Array["real", out_shape])[out_inputs] - - grad_dict = {} - for name, var in in_tangents.items(): - grad_dict[name] = transpose(linear_terms[name], var, out_tangent) - return grad_dict - +def grad(expr, targets, out_adj=None): + out_primal, out_tangent = expr + # in_primals = Tuple(tuple(primal for primal, _ in targets)) + in_tangents = set(tangent for _, tangent in targets) + out_adj = Number(1) if out_adj is None else out_adj + transposes = transpose(out_tangent, Number(1), in_tangents) + result = {} + for target in targets: + result[target] = transposes[target[1]] -def linearize(cls, *args, targets, log=True): - jvp = logJVP if log else JVP - new_args = [] - in_tangents = {} - for arg_name, arg in zip(cls._ast_fields, args): - # if isinstance(arg, (Number, Tensor)): - if arg in targets: - tangent_var = to_var(arg, arg_name) - arg = jvp(arg, tangent_var) - in_tangents[arg_name] = tangent_var - new_args.append(arg) - new_args = tuple(new_args) - return cls(*new_args), in_tangents - - -def get_linear_terms(expr, linear_vars): - if len(linear_vars) == 1: - return {next(iter(linear_vars)): expr} - assert isinstance(expr, Contraction) - assert expr.bin_op is ops.add or expr.bin_op is ops.logaddexp - assert expr.red_op is ops.nullop - terms = {} - for term in expr.terms: - if len(linear_vars.intersection(term.inputs)) == 1: - var = next(iter(linear_vars.intersection(term.inputs))) - terms[var] = term - else: - result = get_linear_terms(term, linear_vars) - terms.update(result) - return terms + # out_shape = tuple(value.size for key, value in out_tangent.inputs.items() if key not in in_tangents.inputs) + # out_inputs = 
tuple(key for key in out_tangent.inputs if key not in in_tangents.inputs) + # out_tangent = Variable("dout", Array["real", out_shape])[out_inputs] + # out_tangent = Number(1.0) + return result @singledispatch -def transpose(expr, target, out_tangent): - if expr is target: - return out_tangent - raise ValueError - - -@transpose.register(Binary) -def transpose_binary(expr, target, out_tangent): - if expr is target: - return out_tangent +def transpose(expr, out_adj, targets, result=defaultdict(lambda: Number(0))): breakpoint() - pass - + if expr in targets: + result[expr] += out_adj + return result + else: + raise ValueError -class JVP(Tuple): - """ - Tuple:(Primal, Tanget) - Semiring: (Add, Mul) - """ - sum_op = ops.add - prod_op = ops.mul +@transpose.register(Contraction) +def transpose_contraction(expr, out_adj, targets, result=defaultdict(lambda: Number(0))): + # assert expr.bin_op is ops.add or expr.bin_op is ops.logaddexp + breakpoint() + assert expr.red_op is ops.nullop + if expr in targets: + result[expr] += out_adj + out_adj = result[expr] -class logJVP(Tuple): - """ - Tuple: (LogPrimal, LogTanget) - Semiring: (Logaddexp, Add) - """ - sum_op = ops.logaddexp - prod_op = ops.add + for term in expr.terms: + if expr.bin_op is ops.add: + term_adj = out_adj + elif expr.bin_op is ops.mul: + term_adj = out_adj * expr / term + else: + raise ValueError + result = transpose(term, term_adj, targets, result) + return result @eager.register(Binary, AssociativeOp, JVP, JVP) @eager.register(Binary, AssociativeOp, logJVP, logJVP) def jvp_binary(op, lhs, rhs): - sum_op = lhs.sum_op - prod_op = lhs.prod_op + sum_op, prod_op = lhs.sum_op, lhs.prod_op lhs_primal, lhs_tangent = lhs rhs_primal, rhs_tangent = rhs primal = Binary(op, lhs_primal, rhs_primal) @@ -129,16 +126,30 @@ def jvp_binary(op, lhs, rhs): return type(lhs)(primal, tangent) +@eager.register(Binary, AssociativeOp, JVP, Tensor) +@eager.register(Binary, AssociativeOp, logJVP, Tensor) +def jvp_binary_jvp_funsor(op, lhs, rhs): + sum_op, prod_op = lhs.sum_op, lhs.prod_op + lhs_primal, lhs_tangent = lhs + primal = Binary(op, lhs_primal, rhs) + if op is sum_op: + tangent = sum_op(lhs_tangent, rhs) + elif op is prod_op: + tangent = prod_op(lhs_tangent, rhs) + else: + raise NotImplementedError + return type(lhs)(primal, tangent) + + @eager.register(Reduce, AssociativeOp, JVP, frozenset) @eager.register(Reduce, AssociativeOp, logJVP, frozenset) def jvp_reduce(op, arg, reduced_vars): - sum_op, prod_op = arg.sum_op, arg.prod_op + sum_op, prod_op, div_op = arg.sum_op, arg.prod_op, arg.div_op arg_primal, arg_tangent = arg primal = Reduce(op, arg_primal, reduced_vars) if op is sum_op: tangent = Reduce(sum_op, arg_tangent, reduced_vars) elif op is prod_op: - div_op = ops.SAFE_BINARY_INVERSES[prod_op] tangent = Reduce(prod_op, div_op(prod_op(arg_tangent, primal), arg_primal), reduced_vars) else: raise NotImplementedError diff --git a/test/test_autodiff.py b/test/test_autodiff.py index fcbbf14bb..52dd526d1 100644 --- a/test/test_autodiff.py +++ b/test/test_autodiff.py @@ -5,9 +5,9 @@ import funsor.ops as ops from funsor.domains import Bint, Real, Reals -from funsor.autodiff import JVP, to_var, to_arg, fjit, linearize, grad +from funsor.autodiff import JVP, to_var, to_arg, fjit, grad, requires_grad, transpose from funsor.testing import assert_close, random_tensor -from funsor.terms import Variable, Number, lazy, Lambda, Binary, Funsor +from funsor.terms import Variable, Number, lazy, Lambda, Binary, Funsor, Tuple from funsor.tensor import Tensor from 
funsor.optimizer import apply_optimizer from funsor.interpreter import interpretation @@ -200,9 +200,14 @@ def test_fjit(): def test_grad(): # Add - x = random_tensor(OrderedDict(j=Bint[4])) - y = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) - result = grad(Binary, ops.add, x, y, log=False) + x = requires_grad(random_tensor(OrderedDict(j=Bint[4]))) + y = requires_grad(random_tensor(OrderedDict(j=Bint[4], k=Bint[5]))) + A = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) + # A = random_tensor(OrderedDict(j=Bint[4])) + out_adj = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) + + z = x * A + result = grad(z, (x,), out_adj) breakpoint() dx = random_tensor(OrderedDict(j=Bint[4])) From 45ab9a79d5c571b32220176a5b7db9bccb160ca2 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Tue, 16 Mar 2021 05:22:12 -0400 Subject: [PATCH 12/16] Expand; autodiff and trace interpretations --- funsor/autodiff.py | 120 +++++++++++---- funsor/interpretations.py | 12 ++ funsor/tensor.py | 13 +- funsor/terms.py | 50 ++++++ test/test_autodiff.py | 317 +++++++++++++------------------------- 5 files changed, 270 insertions(+), 242 deletions(-) diff --git a/funsor/autodiff.py b/funsor/autodiff.py index 2e9ba84a6..35767cae0 100644 --- a/funsor/autodiff.py +++ b/funsor/autodiff.py @@ -11,6 +11,21 @@ from functools import reduce, singledispatch from funsor import Tensor from funsor.cnf import Contraction +from funsor.interpretations import trace, autodiff + + +@trace.register(Binary, AssociativeOp, Funsor, Funsor) +def trace_binary_associativeop(op, lhs, rhs): + with lazy: + result = Binary(op, lhs, rhs) + return result + + +@trace.register(Reduce, AssociativeOp, Funsor, frozenset) +def trace_binary_associativeop(op, arg, reduced_args): + with lazy: + result = Reduce(op, arg, reduced_args) + return result class JVP(Tuple): @@ -68,7 +83,7 @@ def grad(expr, targets, out_adj=None): # in_primals = Tuple(tuple(primal for primal, _ in targets)) in_tangents = set(tangent for _, tangent in targets) out_adj = Number(1) if out_adj is None else out_adj - transposes = transpose(out_tangent, Number(1), in_tangents) + transposes = transpose(out_tangent, out_adj, in_tangents) result = {} for target in targets: result[target] = transposes[target[1]] @@ -85,33 +100,84 @@ def transpose(expr, out_adj, targets, result=defaultdict(lambda: Number(0))): breakpoint() if expr in targets: result[expr] += out_adj - return result + return result + + +@transpose.register(Binary) +def transpose_binary(expr, out_adj, targets, result=defaultdict(lambda: Number(0))): + breakpoint() + if expr in targets: + result[expr] += out_adj + out_adj = result[expr] + + lhs, rhs, op = expr.lhs, expr.rhs, expr.op + + if op is ops.add: + lhs_adj = out_adj.reduce(ops.add, out_adj.input_vars - lhs.input_vars) + rhs_adj = out_adj.reduce(ops.add, out_adj.input_vars - rhs.input_vars) + elif op is ops.mul: + lhs_adj = (out_adj * rhs).reduce(ops.add, out_adj.input_vars - lhs.input_vars) + rhs_adj = (out_adj * lhs).reduce(ops.add, out_adj.input_vars - rhs.input_vars) + else: + return result # is it always correct? 
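    # The reduce over ``out_adj.input_vars - lhs.input_vars`` sums out the
    # variables that appear in the adjoint but not in lhs; it is the transpose
    # of broadcasting lhs along those variables. For ops.mul each operand's
    # adjoint also picks up the other operand, since d(lhs*rhs)/d(lhs) = rhs.
    # For any other binary op the early return above drops the contribution
    # instead of transposing it, so tangents flowing through such nodes are
    # silently lost. The two calls below push each adjoint recursively into
    # its operand.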
+ result = transpose(lhs, lhs_adj, targets, result) + result = transpose(rhs, rhs_adj, targets, result) + return result + + +@transpose.register(Reduce) +def transpose_reduce(expr, out_adj, targets, result=defaultdict(lambda: Number(0))): + breakpoint() + if expr in targets: + result[expr] += out_adj + out_adj = result[expr] + + op, arg, reduced_vars = expr.op, expr.arg, expr.reduced_vars + + if op is ops.add: + arg_adj = out_adj.expand(ops.add, tuple(reduced_vars)) + elif op is ops.mul: + arg_adj = ops.safediv(ops.mul(out_adj, expr), arg) else: raise ValueError + result = transpose(arg, arg_adj, targets, result) + return result @transpose.register(Contraction) def transpose_contraction(expr, out_adj, targets, result=defaultdict(lambda: Number(0))): # assert expr.bin_op is ops.add or expr.bin_op is ops.logaddexp breakpoint() - assert expr.red_op is ops.nullop if expr in targets: result[expr] += out_adj out_adj = result[expr] - for term in expr.terms: - if expr.bin_op is ops.add: - term_adj = out_adj - elif expr.bin_op is ops.mul: - term_adj = out_adj * expr / term - else: - raise ValueError - result = transpose(term, term_adj, targets, result) + if expr.red_op is ops.nullop: + for term in expr.terms: + if expr.bin_op is ops.add: + term_adj = out_adj.reduce(ops.add, out_adj.input_vars - term.input_vars) + elif expr.bin_op is ops.mul: + expr_div_term = reduce(ops.mul, tuple(t for t in expr.terms if t is not term)) + term_adj = (out_adj * expr_div_term).reduce(ops.add, out_adj.input_vars - term.input_vars) + else: + raise ValueError + result = transpose(term, term_adj, targets, result) + elif expr.bin_op is ops.nullop: + for term in expr.terms: # only one term + if expr.red_op is ops.add: + term_adj = out_adj.expand(ops.add, tuple(expr.reduced_vars)) + elif expr.red_op is ops.mul: + term_adj = ops.safediv(ops.mul(out_adj, expr), term) + else: + raise ValueError + result = transpose(term, term_adj, targets, result) + else: + raise ValueError return result -@eager.register(Binary, AssociativeOp, JVP, JVP) -@eager.register(Binary, AssociativeOp, logJVP, logJVP) +@autodiff.register(Binary, AssociativeOp, JVP, JVP) +@autodiff.register(Binary, AssociativeOp, logJVP, logJVP) def jvp_binary(op, lhs, rhs): sum_op, prod_op = lhs.sum_op, lhs.prod_op lhs_primal, lhs_tangent = lhs @@ -126,8 +192,8 @@ def jvp_binary(op, lhs, rhs): return type(lhs)(primal, tangent) -@eager.register(Binary, AssociativeOp, JVP, Tensor) -@eager.register(Binary, AssociativeOp, logJVP, Tensor) +@autodiff.register(Binary, AssociativeOp, JVP, Tensor) +@autodiff.register(Binary, AssociativeOp, logJVP, Tensor) def jvp_binary_jvp_funsor(op, lhs, rhs): sum_op, prod_op = lhs.sum_op, lhs.prod_op lhs_primal, lhs_tangent = lhs @@ -141,25 +207,25 @@ def jvp_binary_jvp_funsor(op, lhs, rhs): return type(lhs)(primal, tangent) -@eager.register(Reduce, AssociativeOp, JVP, frozenset) -@eager.register(Reduce, AssociativeOp, logJVP, frozenset) +@autodiff.register(Reduce, AssociativeOp, JVP, frozenset) +@autodiff.register(Reduce, AssociativeOp, logJVP, frozenset) def jvp_reduce(op, arg, reduced_vars): sum_op, prod_op, div_op = arg.sum_op, arg.prod_op, arg.div_op arg_primal, arg_tangent = arg - primal = Reduce(op, arg_primal, reduced_vars) + out_primal = Reduce(op, arg_primal, reduced_vars) if op is sum_op: tangent = Reduce(sum_op, arg_tangent, reduced_vars) elif op is prod_op: - tangent = Reduce(prod_op, div_op(prod_op(arg_tangent, primal), arg_primal), reduced_vars) + tangent = Reduce(prod_op, div_op(prod_op(arg_tangent, out_primal), arg_primal), 
reduced_vars) else: raise NotImplementedError - return type(arg)(primal, tangent) + return type(arg)(out_primal, tangent) -@lazy.register(Unary, LogOp, JVP) -@eager.register(Unary, LogOp, JVP) -def jvp_log(op, arg): - arg_primal, arg_tangent = arg - primal = Unary(op, arg_primal) - tangent = Binary(ops.truediv, arg_tangent, arg_primal) - return JVP(primal, tangent) +# @lazy.register(Unary, LogOp, JVP) +# @eager.register(Unary, LogOp, JVP) +# def jvp_log(op, arg): +# arg_primal, arg_tangent = arg +# primal = Unary(op, arg_primal) +# tangent = Binary(ops.truediv, arg_tangent, arg_primal) +# return JVP(primal, tangent) diff --git a/funsor/interpretations.py b/funsor/interpretations.py index 48f042de1..178817c0c 100644 --- a/funsor/interpretations.py +++ b/funsor/interpretations.py @@ -329,6 +329,18 @@ def reflect(cls, *args): Eager exact naive interpretation wherever possible. """ +trace_base = DispatchedInterpretation("trace") +trace = PrioritizedInterpretation(trace_base, eager_base, normalize_base, reflect) +""" +Constructs a trace (expression) in terms of primitive operations. +""" + +autodiff_base = DispatchedInterpretation("autodiff") +autodiff = PrioritizedInterpretation(autodiff_base, trace_base, eager_base, normalize_base, reflect) +""" +Constructs a trace (expression) in terms of primitive operations. +""" + die = DispatchedInterpretation("die") eager_or_die = PrioritizedInterpretation(eager_base, die, reflect) diff --git a/funsor/tensor.py b/funsor/tensor.py index 63f5ffb1e..5ee4e8012 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -18,9 +18,10 @@ from . import ops from .delta import Delta from .domains import Array, ArrayType, Bint, Product, Real, Reals, find_domain -from .ops import GetitemOp, MatmulOp, Op, ReshapeOp +from .ops import AssociativeOp, GetitemOp, MatmulOp, Op, ReshapeOp from .terms import ( Binary, + Expand, Finitary, Funsor, FunsorMeta, @@ -682,6 +683,16 @@ def eager_scatter_tensor(op, subs, source, reduced_vars): return Tensor(data, destin_inputs, output.dtype) +@eager.register(Expand, AssociativeOp, (Number, Tensor), tuple) +def eager_tensor_expand(op, arg, expanded_vars): + expanded_shape = tuple(var.output.size for var in expanded_vars) + old_shape = (-1,) * (len(arg.inputs) + len(arg.output.shape)) + new_shape = expanded_shape + old_shape + inputs = OrderedDict([(var.name, var.output) for var in expanded_vars]) + inputs.update(arg.inputs) + return Tensor(ops.expand(arg.data, new_shape), inputs, arg.dtype) + + @eager.register(Binary, Op, Tensor, Number) def eager_binary_tensor_number(op, lhs, rhs): dtype = find_domain(op, lhs.output, rhs.output).dtype diff --git a/funsor/terms.py b/funsor/terms.py index 9423e0266..578937048 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -337,6 +337,14 @@ def item(self): def requires_grad(self): return False + def expand(self, op, expanded_vars): + assert isinstance(op, AssociativeOp) + # Eagerly convert reduced_vars to appropriate things. + assert isinstance(expanded_vars, tuple) + if not expanded_vars: + return self + return Expand(op, self, expanded_vars) + def reduce(self, op, reduced_vars=None): """ Reduce along all or a subset of inputs. @@ -994,6 +1002,48 @@ def die_binary(op, lhs, rhs): raise NotImplementedError(f"Missing pattern for {repr(expr)}") +class Expand(Funsor): + """ + Lazy expand operation over multiple variables. + + The user-facing interface is the :meth:`Funsor.expand` method. + + :param op: An associative operator. 
+ :type op: ~funsor.ops.AssociativeOp + :param funsor arg: An argument to be reduced. + :param frozenset reduced_vars: A set of variables over which to reduce. + """ + + def __init__(self, op, arg, expanded_vars): + assert isinstance(op, AssociativeOp) + assert isinstance(arg, Funsor) + assert isinstance(expanded_vars, tuple) + assert all(isinstance(v, Variable) for v in expanded_vars) + inputs = OrderedDict([(var.name, var.output) for var in expanded_vars]) + inputs.update(arg.inputs) + output = arg.output + fresh = frozenset() + bound = {} + super().__init__(inputs, output, fresh, bound) + self.op = op + self.arg = arg + self.expanded_vars = expanded_vars + + def __repr__(self): + assert self.expanded_vars + rvars = [repr(v) for v in self.expanded_vars] + return "{}.expand({}, {{{}}})".format( + repr(self.arg), self.op.__name__, ", ".join(rvars) + ) + + def __str__(self): + assert self.expanded_vars + rvars = [repr(v) for v in self.expanded_vars] + return "{}.expand({}, {{{}}})".format( + repr(self.arg), self.op.__name__, ", ".join(rvars) + ) + + class Reduce(Funsor): """ Lazy reduction over multiple variables. diff --git a/test/test_autodiff.py b/test/test_autodiff.py index 52dd526d1..52515a712 100644 --- a/test/test_autodiff.py +++ b/test/test_autodiff.py @@ -13,6 +13,7 @@ from funsor.interpreter import interpretation from funsor.factory import make_funsor, Bound, Fresh, Has from funsor.sum_product import MarkovProduct +from funsor.interpretations import trace, autodiff import torch @@ -20,257 +21,145 @@ funsor.set_backend("torch") -def test_id(): - x = random_tensor(OrderedDict(i=Bint[2])) - dx = random_tensor(OrderedDict(i=Bint[2])) - x_ = JVP(x) - with lazy: - f = x_ - assert_close(f.primal, x) - df = f.tangent[str(id(x))](**{str(id(x)): dx}) - assert_close(df, dx) +def test_mul_x_y(): + with autodiff: + # Mul + x = requires_grad(random_tensor(OrderedDict(j=Bint[4]))) + y = requires_grad(random_tensor(OrderedDict(j=Bint[4], k=Bint[5]))) + out_adj = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) + z = x * y + result = grad(z, (x, y), out_adj) -def test_log(): - x = Tensor(torch.tensor([1., 2.]), OrderedDict(i=Bint[2])) - dx = random_tensor(OrderedDict(i=Bint[2])) - x_ = JVP(x) - with lazy: - f = x_.log() - primal = apply_optimizer(f.primal) - assert_close(primal, x.log()) - df = f.tangent[str(id(x))](**{str(id(x)): dx}) - assert_close(df, dx / x) - - -def test_add(): - x = random_tensor(OrderedDict(i=Bint[2])) - y = random_tensor(OrderedDict(j=Bint[3])) - dx = random_tensor(OrderedDict(i=Bint[2])) - dy = random_tensor(OrderedDict(j=Bint[3])) - x_ = JVP(x) - y_ = JVP(y) - with lazy: - f = x_ + y_ + expected_x = (out_adj * y[0]).reduce(ops.add, "k") + expected_y = out_adj * x[0] - primal = apply_optimizer(f.primal) - assert_close(primal, x + y) + actual_x = apply_optimizer(result[x]) + actual_y = apply_optimizer(result[y]) - dfdx = f.tangent[str(id(x))](**{str(id(x)): dx}) - assert_close(dfdx, dx+y-y) + assert_close(actual_x, expected_x) + assert_close(actual_y, expected_y) - dfdy = f.tangent[str(id(y))](**{str(id(y)): dy}) - assert_close(dfdy, dy+x-x) +def test_mul_x_x(): + with autodiff: + # Mul + x = requires_grad(random_tensor(OrderedDict(j=Bint[4]))) + out_adj = random_tensor(OrderedDict(j=Bint[4])) -def test_add_two(): - x = random_tensor(OrderedDict(i=Bint[2])) - dx = Tensor(torch.tensor([1, 1]), OrderedDict(i=Bint[2])) - x_ = JVP(x) - with lazy: - f = x_ + x_ + z = x * x + result = grad(z, (x,), out_adj) - primal = apply_optimizer(f.primal) - assert_close(primal, x + x) + 
expected_x = 2 * out_adj * x[0] + actual_x = apply_optimizer(result[x]) + assert_close(actual_x, expected_x) - dfdx = f.tangent[str(id(x))](**{str(id(x)): dx}) - assert_close(dfdx, 2*dx) +def test_add_x_x(): + with autodiff: + # Add + x = requires_grad(random_tensor(OrderedDict(j=Bint[4]))) + out_adj = random_tensor(OrderedDict(j=Bint[4])) -def test_mul(): - x = random_tensor(OrderedDict(i=Bint[2])) - y = random_tensor(OrderedDict(j=Bint[3])) - dx = random_tensor(OrderedDict(i=Bint[2])) - dy = random_tensor(OrderedDict(j=Bint[3])) - x_ = JVP(x) - y_ = JVP(y) - with lazy: - f = x_ * y_ + z = x + x + result = grad(z, (x,), out_adj) - primal = apply_optimizer(f.primal) - assert_close(primal, x * y) + expected_x = 2 * out_adj + actual_x = apply_optimizer(result[x]) + assert_close(actual_x, expected_x) - dfdx = f.tangent[str(id(x))](**{str(id(x)): dx}) - assert_close(dfdx, dx*y) - dfdy = f.tangent[str(id(y))](**{str(id(y)): dy}) - assert_close(dfdy, dy*x) +def test_add_x_y(): + with autodiff: + # Add + x = requires_grad(random_tensor(OrderedDict(j=Bint[4]))) + y = requires_grad(random_tensor(OrderedDict(j=Bint[4], k=Bint[5]))) + out_adj = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) - # jacfwd - dx = Tensor(torch.eye(2), OrderedDict(i=Bint[2], l=Bint[2])) - jacdx = f.tangent[str(id(x))](**{str(id(x)): dx}) - assert_close(jacdx, dx*y) + z = x + y + result = grad(z, (x, y), out_adj) + expected_x = out_adj.reduce(ops.add, "k") + expected_y = out_adj -def test_mul_add(): - x = random_tensor(OrderedDict(i=Bint[2])) - y = random_tensor(OrderedDict(j=Bint[3])) - z = random_tensor(OrderedDict(k=Bint[4])) - dx = random_tensor(OrderedDict(i=Bint[2])) - dy = random_tensor(OrderedDict(j=Bint[3])) - dz = random_tensor(OrderedDict(k=Bint[4])) - x_ = JVP(x) - y_ = JVP(y) - z_ = JVP(z) - with lazy: - f = x_ * y_ + z_ + actual_x = apply_optimizer(result[x]) + actual_y = apply_optimizer(result[y]) - primal = apply_optimizer(f.primal) - assert_close(primal, x * y + z) + assert_close(actual_x, expected_x) + assert_close(actual_y, expected_y) - dfdx = f.tangent[str(id(x))](**{str(id(x)): dx}) - assert_close(dfdx, dx*y) - dfdy = f.tangent[str(id(y))](**{str(id(y)): dy}) - # assert_close(dfdy, dy*x+z-z) +def test_mul_add_x_x_y(): + with autodiff: + # Add Mul + x = requires_grad(random_tensor(OrderedDict(j=Bint[4]))) + y = requires_grad(random_tensor(OrderedDict(j=Bint[4], k=Bint[5]))) + out_adj = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) - dfdz = f.tangent[str(id(z))](**{str(id(z)): dz}) - breakpoint() - assert_close(dfdz, dz+x*y-x*y) + z = x * x + y + result = grad(z, (x, y), out_adj) + expected_x = 2 * x[0] * out_adj.reduce(ops.add, "k") + expected_y = out_adj -def test_reduce_sum(): - x = random_tensor(OrderedDict(j=Bint[4])) - dx = random_tensor(OrderedDict(j=Bint[4])) - Tx = Variable("dx", Real) - x_ = JVP((x, Tx)) - with lazy: - f, df = x_.reduce(ops.add, "j") - breakpoint() - assert_close(apply_optimizer(f), x.reduce(ops.add, "j")) - assert_close(df(dx=dx), dx.reduce(ops.add, "j")) + actual_x = apply_optimizer(result[x]) + actual_y = apply_optimizer(result[y]) + assert_close(actual_x, expected_x) + assert_close(actual_y, expected_y) -def test_reduce_prod(): - x = random_tensor(OrderedDict(j=Bint[4])) - dx = random_tensor(OrderedDict(j=Bint[4])) - x_ = JVP((x, dx)) - f, df = x_.reduce(ops.mul, "j") - assert_close(f, x.reduce(ops.mul, "j")) - assert_close(df, (f * dx / x).reduce(ops.add, "j")) +def test_mul_add_xx_yy(): + with autodiff: + # Add Mul + x = 
requires_grad(random_tensor(OrderedDict(j=Bint[4]))) + y = requires_grad(random_tensor(OrderedDict(j=Bint[4], k=Bint[5]))) + out_adj = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) -def test_reduce_jacfwd(): - x = random_tensor(OrderedDict(j=Bint[4])) - # dx = Tensor(torch.tensor([1, 0, 0, 0]), OrderedDict(j=Bint[4])) - dx = Tensor(torch.eye(4), OrderedDict(j=Bint[4], l=Bint[4])) - x_ = JVP((x, dx)) - f, df = x_.reduce(ops.mul, "j") - assert_close(f, x.reduce(ops.mul, "j")) - assert_close(df, (f * dx / x).reduce(ops.add, "j")) - - -@make_funsor -def MatMul( - a: Has[{"i"}], - b: Has[{"i"}], - i: Bound - ) -> Fresh[lambda a: a]: - return Prod(a, b).reduce(ops.add, i) + z = x * x + y + y + result = grad(z, (x, y), out_adj) -@make_funsor -def Prod( - x: Funsor, - y: Funsor - ) -> Fresh[lambda x: x]: - return x * y + expected_x = 2 * x[0] * out_adj.reduce(ops.add, "k") + expected_y = 2 * out_adj + actual_x = apply_optimizer(result[x]) + actual_y = apply_optimizer(result[y]) -def test_fjit(): - # Product - x = random_tensor(OrderedDict(j=Bint[4])) - y = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) - cProd = fjit(Prod, x, y) - - x2 = random_tensor(OrderedDict(j=Bint[4])) - y2 = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) - expected = Prod(x2, y2) - actual = cProd(x=to_arg(x2), y=to_arg(y2)) - assert_close(actual, expected) - - # MarkovProduct - trans = random_tensor(OrderedDict(time=Bint[5], prev=Bint[3], curr=Bint[3])) - cMarkovProduct = fjit(MarkovProduct, ops.logaddexp, ops.add, trans, "time", {"prev": "curr"}) - - trans2 = random_tensor(OrderedDict(time=Bint[5], prev=Bint[3], curr=Bint[3])) - expected = MarkovProduct(ops.logaddexp, ops.add, trans2, "time", {"prev": "curr"}) - actual = cMarkovProduct(trans=to_arg(trans2)) - assert_close(actual, expected) - - -def test_grad(): - # Add - x = requires_grad(random_tensor(OrderedDict(j=Bint[4]))) - y = requires_grad(random_tensor(OrderedDict(j=Bint[4], k=Bint[5]))) - A = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) - # A = random_tensor(OrderedDict(j=Bint[4])) - out_adj = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) - - z = x * A - result = grad(z, (x,), out_adj) - breakpoint() - - dx = random_tensor(OrderedDict(j=Bint[4])) - dy = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) - expected = dx + dy - actual = linearAdd(lhs=to_arg(dx), rhs=to_arg(dy)) - assert_close(actual, expected) - assert_close(z, x + y) - - -def test_linearize(): - # Add - x = random_tensor(OrderedDict(j=Bint[4])) - y = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) - (z, linearAdd), linear_vars = linearize(Binary, ops.add, x, y, log=False) + assert_close(actual_x, expected_x) + assert_close(actual_y, expected_y) - dx = random_tensor(OrderedDict(j=Bint[4])) - dy = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) - expected = dx + dy - actual = linearAdd(lhs=to_arg(dx), rhs=to_arg(dy)) - assert_close(actual, expected) - assert_close(z, x + y) - # Add in a LogFunctor - x = random_tensor(OrderedDict(j=Bint[4])) - y = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) - with funsor.terms.lazy: - z, linearAdd = linearize(Binary, ops.add, x, y, log=True) +def test_reduce_x(): + with autodiff: + # Reduce + y = requires_grad(random_tensor(OrderedDict(j=Bint[4], k=Bint[5]))) + out_adj = random_tensor(OrderedDict(j=Bint[4])) - dx = random_tensor(OrderedDict(j=Bint[4])) - dy = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) - expected = ops.logaddexp(ops.add(y, dx), ops.add(x, dy)) - breakpoint() - actual = linearAdd(lhs=to_arg(dx), rhs=to_arg(dy)) - 
assert_close(actual, expected) + z = y.reduce(ops.add, "k") + result = grad(z, (y,), out_adj) - # MarkovProduct in a LogFunctor - trans = random_tensor(OrderedDict(time=Bint[5], prev=Bint[3], curr=Bint[3])) - with funsor.terms.lazy: - z, linearMP = linearize(MarkovProduct, ops.logaddexp, ops.add, trans, "time", {"prev": "curr"}, log=True) + expected_y = out_adj.expand(ops.add, (Variable("k", Bint[5]),)).align(tuple(y[0].inputs.keys())) + actual_y = apply_optimizer(result[y]) + assert_close(actual_y, expected_y) - dtrans = random_tensor(OrderedDict(time=Bint[5], prev=Bint[3], curr=Bint[3])) - # expected = MarkovProduct(ops.logaddexp, ops.add, trans2, "time", {"prev": "curr"}) - actual = linearMP(trans=to_arg(dtrans)) - # assert_close(actual, expected) +def test_trace(): + @make_funsor + def Matmul( + x: Has[{"i"}], + y: Has[{"i"}], + i: Bound + ) -> Fresh[lambda x: x]: + return (x * y).reduce(ops.add, i) -def test_transpose(): - # Mul x = random_tensor(OrderedDict(j=Bint[4])) y = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) - z, linearAdd = linearize(Binary, ops.mul, x, y, log=False) - linear_transpose(linearAdd, {"lhs", "rhs"}, log=False) + eager_z = Matmul(x, y, "j") + with lazy: + lazy_z = Matmul(x, y, "j") - # Add - x = random_tensor(OrderedDict(j=Bint[4])) - y = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) - z, linearAdd = linearize(Binary, ops.add, x, y, log=False) - linear_transpose(linearAdd, {"lhs", "rhs"}, log=False) - - dx = random_tensor(OrderedDict(j=Bint[4])) - dy = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) - expected = dx + dy - actual = linearAdd(dlhs=to_arg(dx), drhs=to_arg(dy)) - assert_close(actual, expected) - assert_close(z, x + y) + with trace: + trace_z = Matmul(x, y, "j") + + assert_close(eager_z, apply_optimizer(lazy_z)) + assert_close(eager_z, apply_optimizer(trace_z)) From ab2157d3d78f57cd7375982363c59c138dc9195d Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Tue, 16 Mar 2021 20:53:22 -0400 Subject: [PATCH 13/16] clean up code --- funsor/autodiff.py | 221 ++++++++++++++++++++++++------------------ test/test_autodiff.py | 100 +++++++++++-------- 2 files changed, 186 insertions(+), 135 deletions(-) diff --git a/funsor/autodiff.py b/funsor/autodiff.py index 35767cae0..0ec5563b8 100644 --- a/funsor/autodiff.py +++ b/funsor/autodiff.py @@ -2,30 +2,29 @@ # SPDX-License-Identifier: Apache-2.0 import math -import funsor.ops as ops -from funsor.ops import AssociativeOp, LogOp -from funsor.terms import Binary, Reduce, Tuple, Unary, eager, lazy, Variable, Number, Lambda, Funsor -from funsor.interpreter import interpretation -from funsor.domains import Bint, Real, Array, Reals from collections import defaultdict from functools import reduce, singledispatch + +import funsor.ops as ops from funsor import Tensor +from funsor.adjoint import _alpha_unmangle from funsor.cnf import Contraction -from funsor.interpretations import trace, autodiff - - -@trace.register(Binary, AssociativeOp, Funsor, Funsor) -def trace_binary_associativeop(op, lhs, rhs): - with lazy: - result = Binary(op, lhs, rhs) - return result - - -@trace.register(Reduce, AssociativeOp, Funsor, frozenset) -def trace_binary_associativeop(op, arg, reduced_args): - with lazy: - result = Reduce(op, arg, reduced_args) - return result +from funsor.domains import Array, Bint, Real, Reals +from funsor.interpretations import autodiff, trace +from funsor.interpreter import interpretation +from funsor.ops import AssociativeOp, LogOp +from funsor.terms import ( + Binary, + Funsor, + Lambda, + Number, + 
Reduce, + Tuple, + Unary, + Variable, + eager, + lazy, +) class JVP(Tuple): @@ -33,151 +32,175 @@ class JVP(Tuple): Tuple:(Primal, Tanget) Semiring: (Add, Mul) """ + sum_op = ops.add prod_op = ops.mul div_op = ops.safediv zero = Number(0) one = Number(1) + @property + def primal(self): + return self[0] + + @property + def tangent(self): + return self[1] + -class logJVP(Tuple): +class LJVP(Tuple): """ Tuple: (LogPrimal, LogTanget) Semiring: (Logaddexp, Add) """ + sum_op = ops.logaddexp prod_op = ops.add div_op = ops.safesub zero = Number(-math.inf) one = Number(0) + @property + def primal(self): + return self[0] + + @property + def tangent(self): + return self[1] -def requires_grad(primal): - tangent = Variable(str(id(primal)), Array["real", primal.data.shape])[tuple(primal.inputs)] - return JVP(primal, tangent) + +@trace.register(Binary, AssociativeOp, Funsor, Funsor) +def trace_binary_associativeop(op, lhs, rhs): + with lazy: + result = Binary(op, lhs, rhs) + return result -def to_var(x, name): - var = Variable(name, Array["real", x.data.shape])[tuple(x.inputs)] - return var +@trace.register(Reduce, AssociativeOp, Funsor, frozenset) +def trace_binary_associativeop(op, arg, reduced_args): + with lazy: + result = Reduce(op, arg, reduced_args) + return result -def to_arg(x): - input_vars = tuple(Variable(key, value) for key, value in x.inputs.items()) - arg = reduce(lambda a, b: Lambda(b, a), reversed(input_vars), x) - return arg +def to_jvp(primal): + input_vars = tuple(Variable(key, value) for key, value in primal.inputs.items()) + output = reduce(lambda x, y: Lambda(y, x), reversed(input_vars), primal).output + tangent_placeholder = Variable(str(id(primal)), output)[tuple(primal.inputs)] + return JVP(primal, tangent_placeholder) -def fjit(cls, *args): - new_args = [] - for arg_name, arg in zip(cls._ast_fields, args): - if isinstance(arg, (Number, Tensor)): - arg = to_var(arg, arg_name) - new_args.append(arg) - new_args = tuple(new_args) - return cls(*new_args) +def to_ljvp(primal): + input_vars = tuple(Variable(key, value) for key, value in primal.inputs.items()) + output = reduce(lambda x, y: Lambda(y, x), reversed(input_vars), primal).output + tangent_placeholder = Variable(str(id(primal)), output)[tuple(primal.inputs)] + return LJVP(primal, tangent_placeholder) -def grad(expr, targets, out_adj=None): - out_primal, out_tangent = expr - # in_primals = Tuple(tuple(primal for primal, _ in targets)) - in_tangents = set(tangent for _, tangent in targets) - out_adj = Number(1) if out_adj is None else out_adj - transposes = transpose(out_tangent, out_adj, in_tangents) +def grad(expr, targets, out_tangent=None): + out_tangent = expr.one if out_tangent is None else out_tangent + in_tangents = set(target.tangent for target in targets) + transposes = transpose( + expr.tangent, out_tangent, in_tangents, defaultdict(lambda: expr.zero) + ) result = {} for target in targets: - result[target] = transposes[target[1]] - - # out_shape = tuple(value.size for key, value in out_tangent.inputs.items() if key not in in_tangents.inputs) - # out_inputs = tuple(key for key in out_tangent.inputs if key not in in_tangents.inputs) - # out_tangent = Variable("dout", Array["real", out_shape])[out_inputs] - # out_tangent = Number(1.0) + result[target] = transposes[target.tangent] return result @singledispatch -def transpose(expr, out_adj, targets, result=defaultdict(lambda: Number(0))): - breakpoint() - if expr in targets: - result[expr] += out_adj +def transpose(expr, out_tangent, in_tangents, result): + if expr in 
in_tangents: + result[expr] += out_tangent return result @transpose.register(Binary) -def transpose_binary(expr, out_adj, targets, result=defaultdict(lambda: Number(0))): - breakpoint() - if expr in targets: - result[expr] += out_adj - out_adj = result[expr] +def transpose_binary(expr, out_tangent, in_tangents, result): + if expr in in_tangents: + result[expr] += out_tangent + out_tangent = result[expr] - lhs, rhs, op = expr.lhs, expr.rhs, expr.op + op, lhs, rhs = expr.op, expr.lhs, expr.rhs if op is ops.add: - lhs_adj = out_adj.reduce(ops.add, out_adj.input_vars - lhs.input_vars) - rhs_adj = out_adj.reduce(ops.add, out_adj.input_vars - rhs.input_vars) + lhs_adj = out_tangent.reduce(ops.add, out_tangent.input_vars - lhs.input_vars) + rhs_adj = out_tangent.reduce(ops.add, out_tangent.input_vars - rhs.input_vars) elif op is ops.mul: - lhs_adj = (out_adj * rhs).reduce(ops.add, out_adj.input_vars - lhs.input_vars) - rhs_adj = (out_adj * lhs).reduce(ops.add, out_adj.input_vars - rhs.input_vars) + lhs_adj = (out_tangent * rhs).reduce( + ops.add, out_tangent.input_vars - lhs.input_vars + ) + rhs_adj = (out_tangent * lhs).reduce( + ops.add, out_tangent.input_vars - rhs.input_vars + ) else: - return result # is it always correct? - result = transpose(lhs, lhs_adj, targets, result) - result = transpose(rhs, rhs_adj, targets, result) + return result # is it always correct? + result = transpose(lhs, lhs_adj, in_tangents, result) + result = transpose(rhs, rhs_adj, in_tangents, result) return result @transpose.register(Reduce) -def transpose_reduce(expr, out_adj, targets, result=defaultdict(lambda: Number(0))): - breakpoint() - if expr in targets: - result[expr] += out_adj - out_adj = result[expr] +def transpose_reduce(expr, out_tangent, in_tangents, result): + if expr in in_tangents: + result[expr] += out_tangent + out_tangent = result[expr] - op, arg, reduced_vars = expr.op, expr.arg, expr.reduced_vars + # fix this in contraction as well + op, arg, reduced_vars = _alpha_unmangle(expr) if op is ops.add: - arg_adj = out_adj.expand(ops.add, tuple(reduced_vars)) + arg_adj = out_tangent.expand(ops.add, tuple(reduced_vars)) elif op is ops.mul: - arg_adj = ops.safediv(ops.mul(out_adj, expr), arg) + arg_adj = ops.safediv(ops.mul(out_tangent, expr), arg) else: raise ValueError - result = transpose(arg, arg_adj, targets, result) + result = transpose(arg, arg_adj, in_tangents, result) return result @transpose.register(Contraction) -def transpose_contraction(expr, out_adj, targets, result=defaultdict(lambda: Number(0))): - # assert expr.bin_op is ops.add or expr.bin_op is ops.logaddexp +def transpose_contraction(expr, out_tangent, in_tangents, result): breakpoint() - if expr in targets: - result[expr] += out_adj - out_adj = result[expr] + if expr in in_tangents: + result[expr] += out_tangent + out_tangent = result[expr] if expr.red_op is ops.nullop: for term in expr.terms: if expr.bin_op is ops.add: - term_adj = out_adj.reduce(ops.add, out_adj.input_vars - term.input_vars) + term_adj = out_tangent.reduce( + ops.add, out_tangent.input_vars - term.input_vars + ) elif expr.bin_op is ops.mul: - expr_div_term = reduce(ops.mul, tuple(t for t in expr.terms if t is not term)) - term_adj = (out_adj * expr_div_term).reduce(ops.add, out_adj.input_vars - term.input_vars) + expr_div_term = reduce( + ops.mul, tuple(t for t in expr.terms if t is not term) + ) + term_adj = (out_tangent * expr_div_term).reduce( + ops.add, out_tangent.input_vars - term.input_vars + ) else: raise ValueError - result = transpose(term, term_adj, 
targets, result) + result = transpose(term, term_adj, in_tangents, result) elif expr.bin_op is ops.nullop: - for term in expr.terms: # only one term + for term in expr.terms: # only one term if expr.red_op is ops.add: - term_adj = out_adj.expand(ops.add, tuple(expr.reduced_vars)) + term_adj = out_tangent.expand(ops.add, tuple(expr.reduced_vars)) elif expr.red_op is ops.mul: - term_adj = ops.safediv(ops.mul(out_adj, expr), term) + term_adj = ops.safediv(ops.mul(out_tangent, expr), term) else: raise ValueError - result = transpose(term, term_adj, targets, result) + result = transpose(term, term_adj, in_tangents, result) else: raise ValueError return result +@eager.register(Binary, AssociativeOp, JVP, JVP) +@eager.register(Binary, AssociativeOp, LJVP, LJVP) @autodiff.register(Binary, AssociativeOp, JVP, JVP) -@autodiff.register(Binary, AssociativeOp, logJVP, logJVP) +@autodiff.register(Binary, AssociativeOp, LJVP, LJVP) def jvp_binary(op, lhs, rhs): sum_op, prod_op = lhs.sum_op, lhs.prod_op lhs_primal, lhs_tangent = lhs @@ -186,14 +209,18 @@ def jvp_binary(op, lhs, rhs): if op is sum_op: tangent = sum_op(lhs_tangent, rhs_tangent) elif op is prod_op: - tangent = sum_op(prod_op(rhs_primal, lhs_tangent), prod_op(lhs_primal, rhs_tangent)) + tangent = sum_op( + prod_op(rhs_primal, lhs_tangent), prod_op(lhs_primal, rhs_tangent) + ) else: raise NotImplementedError return type(lhs)(primal, tangent) +@eager.register(Binary, AssociativeOp, JVP, Tensor) +@eager.register(Binary, AssociativeOp, LJVP, Tensor) @autodiff.register(Binary, AssociativeOp, JVP, Tensor) -@autodiff.register(Binary, AssociativeOp, logJVP, Tensor) +@autodiff.register(Binary, AssociativeOp, LJVP, Tensor) def jvp_binary_jvp_funsor(op, lhs, rhs): sum_op, prod_op = lhs.sum_op, lhs.prod_op lhs_primal, lhs_tangent = lhs @@ -207,8 +234,10 @@ def jvp_binary_jvp_funsor(op, lhs, rhs): return type(lhs)(primal, tangent) +@eager.register(Reduce, AssociativeOp, JVP, frozenset) +@eager.register(Reduce, AssociativeOp, LJVP, frozenset) @autodiff.register(Reduce, AssociativeOp, JVP, frozenset) -@autodiff.register(Reduce, AssociativeOp, logJVP, frozenset) +@autodiff.register(Reduce, AssociativeOp, LJVP, frozenset) def jvp_reduce(op, arg, reduced_vars): sum_op, prod_op, div_op = arg.sum_op, arg.prod_op, arg.div_op arg_primal, arg_tangent = arg @@ -216,7 +245,9 @@ def jvp_reduce(op, arg, reduced_vars): if op is sum_op: tangent = Reduce(sum_op, arg_tangent, reduced_vars) elif op is prod_op: - tangent = Reduce(prod_op, div_op(prod_op(arg_tangent, out_primal), arg_primal), reduced_vars) + tangent = Reduce( + prod_op, div_op(prod_op(arg_tangent, out_primal), arg_primal), reduced_vars + ) else: raise NotImplementedError return type(arg)(out_primal, tangent) diff --git a/test/test_autodiff.py b/test/test_autodiff.py index 52515a712..de2885c15 100644 --- a/test/test_autodiff.py +++ b/test/test_autodiff.py @@ -3,29 +3,29 @@ from collections import OrderedDict +import torch + +import funsor import funsor.ops as ops +from funsor.autodiff import JVP, grad, to_jvp, to_ljvp from funsor.domains import Bint, Real, Reals -from funsor.autodiff import JVP, to_var, to_arg, fjit, grad, requires_grad, transpose -from funsor.testing import assert_close, random_tensor -from funsor.terms import Variable, Number, lazy, Lambda, Binary, Funsor, Tuple -from funsor.tensor import Tensor -from funsor.optimizer import apply_optimizer +from funsor.factory import Bound, Fresh, Has, make_funsor +from funsor.interpretations import autodiff, trace from funsor.interpreter import 
interpretation -from funsor.factory import make_funsor, Bound, Fresh, Has +from funsor.optimizer import apply_optimizer from funsor.sum_product import MarkovProduct -from funsor.interpretations import trace, autodiff - +from funsor.tensor import Tensor +from funsor.terms import Binary, Funsor, Lambda, Number, Tuple, Variable, lazy +from funsor.testing import assert_close, random_tensor -import torch -import funsor funsor.set_backend("torch") def test_mul_x_y(): with autodiff: # Mul - x = requires_grad(random_tensor(OrderedDict(j=Bint[4]))) - y = requires_grad(random_tensor(OrderedDict(j=Bint[4], k=Bint[5]))) + x = to_jvp(random_tensor(OrderedDict(j=Bint[4]))) + y = to_jvp(random_tensor(OrderedDict(j=Bint[4], k=Bint[5]))) out_adj = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) z = x * y @@ -44,7 +44,7 @@ def test_mul_x_y(): def test_mul_x_x(): with autodiff: # Mul - x = requires_grad(random_tensor(OrderedDict(j=Bint[4]))) + x = to_jvp(random_tensor(OrderedDict(j=Bint[4]))) out_adj = random_tensor(OrderedDict(j=Bint[4])) z = x * x @@ -55,10 +55,10 @@ def test_mul_x_x(): assert_close(actual_x, expected_x) -def test_add_x_x(): +def test_add_x_x(): with autodiff: # Add - x = requires_grad(random_tensor(OrderedDict(j=Bint[4]))) + x = to_jvp(random_tensor(OrderedDict(j=Bint[4]))) out_adj = random_tensor(OrderedDict(j=Bint[4])) z = x + x @@ -72,8 +72,8 @@ def test_add_x_x(): def test_add_x_y(): with autodiff: # Add - x = requires_grad(random_tensor(OrderedDict(j=Bint[4]))) - y = requires_grad(random_tensor(OrderedDict(j=Bint[4], k=Bint[5]))) + x = to_jvp(random_tensor(OrderedDict(j=Bint[4]))) + y = to_jvp(random_tensor(OrderedDict(j=Bint[4], k=Bint[5]))) out_adj = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) z = x + y @@ -92,8 +92,8 @@ def test_add_x_y(): def test_mul_add_x_x_y(): with autodiff: # Add Mul - x = requires_grad(random_tensor(OrderedDict(j=Bint[4]))) - y = requires_grad(random_tensor(OrderedDict(j=Bint[4], k=Bint[5]))) + x = to_jvp(random_tensor(OrderedDict(j=Bint[4]))) + y = to_jvp(random_tensor(OrderedDict(j=Bint[4], k=Bint[5]))) out_adj = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) z = x * x + y @@ -112,8 +112,8 @@ def test_mul_add_x_x_y(): def test_mul_add_xx_yy(): with autodiff: # Add Mul - x = requires_grad(random_tensor(OrderedDict(j=Bint[4]))) - y = requires_grad(random_tensor(OrderedDict(j=Bint[4], k=Bint[5]))) + x = to_jvp(random_tensor(OrderedDict(j=Bint[4]))) + y = to_jvp(random_tensor(OrderedDict(j=Bint[4], k=Bint[5]))) out_adj = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) z = x * x + y + y @@ -132,34 +132,54 @@ def test_mul_add_xx_yy(): def test_reduce_x(): with autodiff: # Reduce - y = requires_grad(random_tensor(OrderedDict(j=Bint[4], k=Bint[5]))) + y = to_jvp(random_tensor(OrderedDict(j=Bint[4], k=Bint[5]))) out_adj = random_tensor(OrderedDict(j=Bint[4])) z = y.reduce(ops.add, "k") result = grad(z, (y,), out_adj) - expected_y = out_adj.expand(ops.add, (Variable("k", Bint[5]),)).align(tuple(y[0].inputs.keys())) + expected_y = out_adj.expand(ops.add, (Variable("k", Bint[5]),)) actual_y = apply_optimizer(result[y]) assert_close(actual_y, expected_y) -def test_trace(): - @make_funsor - def Matmul( - x: Has[{"i"}], - y: Has[{"i"}], - i: Bound - ) -> Fresh[lambda x: x]: - return (x * y).reduce(ops.add, i) +def test_mul_reduce_x_y(): + with autodiff: + # Reduce + x = to_jvp(random_tensor(OrderedDict(j=Bint[4]))) + y = to_jvp(random_tensor(OrderedDict(j=Bint[4], k=Bint[5]))) + out_adj = random_tensor(OrderedDict(k=Bint[5])) + + z = (x * y).reduce(ops.add, 
"j") + result = grad(z, (x, y), out_adj) + + expected_x = (y[0] * out_adj).reduce(ops.add, "k") + expected_y = x[0] * out_adj.expand(ops.add, (Variable("j", Bint[4]),)) - x = random_tensor(OrderedDict(j=Bint[4])) - y = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) - eager_z = Matmul(x, y, "j") - with lazy: - lazy_z = Matmul(x, y, "j") + actual_x = apply_optimizer(result[x]) + actual_y = apply_optimizer(result[y]) + + assert_close(actual_x, expected_x) + assert_close(actual_y, expected_y) - with trace: - trace_z = Matmul(x, y, "j") - assert_close(eager_z, apply_optimizer(lazy_z)) - assert_close(eager_z, apply_optimizer(trace_z)) +# def test_trace(): +# @make_funsor +# def Matmul( +# x: Has[{"i"}], +# y: Has[{"i"}], +# i: Bound +# ) -> Fresh[lambda x: x]: +# return (x * y).reduce(ops.add, i) +# +# x = random_tensor(OrderedDict(j=Bint[4])) +# y = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) +# eager_z = Matmul(x, y, "j") +# with lazy: +# lazy_z = Matmul(x, y, "j") +# +# with trace: +# trace_z = Matmul(x, y, "j") +# +# assert_close(eager_z, apply_optimizer(lazy_z)) +# assert_close(eager_z, apply_optimizer(trace_z)) From ccef22d7493d13eec3ff4a4578d6a8d7b1a024cc Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Tue, 16 Mar 2021 22:02:00 -0400 Subject: [PATCH 14/16] tranpose patterns for log jvp --- funsor/autodiff.py | 66 +++++++++-------- funsor/interpretations.py | 4 +- test/test_autodiff.py | 148 ++++++++++++++++++++++++++------------ 3 files changed, 144 insertions(+), 74 deletions(-) diff --git a/funsor/autodiff.py b/funsor/autodiff.py index 0ec5563b8..b5e7f8bd5 100644 --- a/funsor/autodiff.py +++ b/funsor/autodiff.py @@ -101,7 +101,11 @@ def grad(expr, targets, out_tangent=None): out_tangent = expr.one if out_tangent is None else out_tangent in_tangents = set(target.tangent for target in targets) transposes = transpose( - expr.tangent, out_tangent, in_tangents, defaultdict(lambda: expr.zero) + expr.tangent, + out_tangent, + in_tangents, + defaultdict(lambda: expr.zero), + type(expr), ) result = {} for target in targets: @@ -110,54 +114,58 @@ def grad(expr, targets, out_tangent=None): @singledispatch -def transpose(expr, out_tangent, in_tangents, result): +def transpose(expr, out_tangent, in_tangents, result, semiring): if expr in in_tangents: - result[expr] += out_tangent + result[expr] = semiring.sum_op(result[expr], out_tangent) return result @transpose.register(Binary) -def transpose_binary(expr, out_tangent, in_tangents, result): - if expr in in_tangents: - result[expr] += out_tangent - out_tangent = result[expr] +def transpose_binary(expr, out_tangent, in_tangents, result, semiring): op, lhs, rhs = expr.op, expr.lhs, expr.rhs + sum_op, prod_op = semiring.sum_op, semiring.prod_op + + if expr in in_tangents: + result[expr] = sum_op(result[expr], out_tangent) + out_tangent = result[expr] - if op is ops.add: - lhs_adj = out_tangent.reduce(ops.add, out_tangent.input_vars - lhs.input_vars) - rhs_adj = out_tangent.reduce(ops.add, out_tangent.input_vars - rhs.input_vars) - elif op is ops.mul: - lhs_adj = (out_tangent * rhs).reduce( - ops.add, out_tangent.input_vars - lhs.input_vars + if op is sum_op: + lhs_adj = out_tangent.reduce(sum_op, out_tangent.input_vars - lhs.input_vars) + rhs_adj = out_tangent.reduce(sum_op, out_tangent.input_vars - rhs.input_vars) + elif op is prod_op: + lhs_adj = prod_op(rhs, out_tangent).reduce( + sum_op, out_tangent.input_vars - lhs.input_vars ) - rhs_adj = (out_tangent * lhs).reduce( - ops.add, out_tangent.input_vars - rhs.input_vars + rhs_adj = 
prod_op(lhs, out_tangent).reduce( + sum_op, out_tangent.input_vars - rhs.input_vars ) else: return result # is it always correct? - result = transpose(lhs, lhs_adj, in_tangents, result) - result = transpose(rhs, rhs_adj, in_tangents, result) + result = transpose(lhs, lhs_adj, in_tangents, result, semiring) + result = transpose(rhs, rhs_adj, in_tangents, result, semiring) return result @transpose.register(Reduce) -def transpose_reduce(expr, out_tangent, in_tangents, result): - if expr in in_tangents: - result[expr] += out_tangent - out_tangent = result[expr] - +def transpose_reduce(expr, out_tangent, in_tangents, result, semiring): # fix this in contraction as well op, arg, reduced_vars = _alpha_unmangle(expr) + sum_op, prod_op = semiring.sum_op, semiring.prod_op - if op is ops.add: - arg_adj = out_tangent.expand(ops.add, tuple(reduced_vars)) - elif op is ops.mul: - arg_adj = ops.safediv(ops.mul(out_tangent, expr), arg) + if expr in in_tangents: + result[expr] = sum_op(result[expr], out_tangent) + out_tangent = result[expr] + + if op is sum_op: + arg_adj = out_tangent.expand(sum_op, tuple(reduced_vars)) + result = transpose(arg, arg_adj, in_tangents, result, semiring) + return result + elif op is prod_op: + # this is unnecessary + return result else: raise ValueError - result = transpose(arg, arg_adj, in_tangents, result) - return result @transpose.register(Contraction) @@ -246,7 +254,7 @@ def jvp_reduce(op, arg, reduced_vars): tangent = Reduce(sum_op, arg_tangent, reduced_vars) elif op is prod_op: tangent = Reduce( - prod_op, div_op(prod_op(arg_tangent, out_primal), arg_primal), reduced_vars + sum_op, div_op(prod_op(arg_tangent, out_primal), arg_primal), reduced_vars ) else: raise NotImplementedError diff --git a/funsor/interpretations.py b/funsor/interpretations.py index 178817c0c..2e38d6b9b 100644 --- a/funsor/interpretations.py +++ b/funsor/interpretations.py @@ -336,7 +336,9 @@ def reflect(cls, *args): """ autodiff_base = DispatchedInterpretation("autodiff") -autodiff = PrioritizedInterpretation(autodiff_base, trace_base, eager_base, normalize_base, reflect) +autodiff = PrioritizedInterpretation( + autodiff_base, trace_base, eager_base, normalize_base, reflect +) """ Constructs a trace (expression) in terms of primitive operations. """ diff --git a/test/test_autodiff.py b/test/test_autodiff.py index de2885c15..7e28e917a 100644 --- a/test/test_autodiff.py +++ b/test/test_autodiff.py @@ -1,8 +1,11 @@ # Copyright Contributors to the Pyro project. 
# SPDX-License-Identifier: Apache-2.0 +import math from collections import OrderedDict +from functools import reduce +import pytest import torch import funsor @@ -21,65 +24,83 @@ funsor.set_backend("torch") -def test_mul_x_y(): +@pytest.mark.parametrize( + "sum_op,prod_op,tojvp", + [(ops.add, ops.mul, to_jvp), (ops.logaddexp, ops.add, to_ljvp)], +) +def test_mul_x_y(sum_op, prod_op, tojvp): with autodiff: # Mul - x = to_jvp(random_tensor(OrderedDict(j=Bint[4]))) - y = to_jvp(random_tensor(OrderedDict(j=Bint[4], k=Bint[5]))) + x = tojvp(random_tensor(OrderedDict(j=Bint[4]))) + y = tojvp(random_tensor(OrderedDict(j=Bint[4], k=Bint[5]))) out_adj = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) - z = x * y + z = prod_op(x, y) result = grad(z, (x, y), out_adj) - expected_x = (out_adj * y[0]).reduce(ops.add, "k") - expected_y = out_adj * x[0] + expected_x = prod_op(out_adj, y.primal).reduce(sum_op, "k") + expected_y = prod_op(out_adj, x.primal) actual_x = apply_optimizer(result[x]) actual_y = apply_optimizer(result[y]) - assert_close(actual_x, expected_x) - assert_close(actual_y, expected_y) + assert_close(actual_x, expected_x, rtol=1e-5) + assert_close(actual_y, expected_y, rtol=1e-5) -def test_mul_x_x(): +@pytest.mark.parametrize( + "sum_op,prod_op,tojvp", + [(ops.add, ops.mul, to_jvp), (ops.logaddexp, ops.add, to_ljvp)], +) +def test_mul_x_x(sum_op, prod_op, tojvp): with autodiff: # Mul - x = to_jvp(random_tensor(OrderedDict(j=Bint[4]))) + x = tojvp(random_tensor(OrderedDict(j=Bint[4]))) out_adj = random_tensor(OrderedDict(j=Bint[4])) - z = x * x + z = prod_op(x, x) result = grad(z, (x,), out_adj) - expected_x = 2 * out_adj * x[0] + two = 2 if tojvp is to_jvp else math.log(2) + expected_x = reduce(prod_op, (two, out_adj, x.primal)) actual_x = apply_optimizer(result[x]) assert_close(actual_x, expected_x) -def test_add_x_x(): +@pytest.mark.parametrize( + "sum_op,prod_op,tojvp", + [(ops.add, ops.mul, to_jvp), (ops.logaddexp, ops.add, to_ljvp)], +) +def test_add_x_x(sum_op, prod_op, tojvp): with autodiff: # Add - x = to_jvp(random_tensor(OrderedDict(j=Bint[4]))) + x = tojvp(random_tensor(OrderedDict(j=Bint[4]))) out_adj = random_tensor(OrderedDict(j=Bint[4])) - z = x + x + z = sum_op(x, x) result = grad(z, (x,), out_adj) - expected_x = 2 * out_adj + two = 2 if tojvp is to_jvp else math.log(2) + expected_x = prod_op(two, out_adj) actual_x = apply_optimizer(result[x]) assert_close(actual_x, expected_x) -def test_add_x_y(): +@pytest.mark.parametrize( + "sum_op,prod_op,tojvp", + [(ops.add, ops.mul, to_jvp), (ops.logaddexp, ops.add, to_ljvp)], +) +def test_add_x_y(sum_op, prod_op, tojvp): with autodiff: # Add - x = to_jvp(random_tensor(OrderedDict(j=Bint[4]))) - y = to_jvp(random_tensor(OrderedDict(j=Bint[4], k=Bint[5]))) + x = tojvp(random_tensor(OrderedDict(j=Bint[4]))) + y = tojvp(random_tensor(OrderedDict(j=Bint[4], k=Bint[5]))) out_adj = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) - z = x + y + z = sum_op(x, y) result = grad(z, (x, y), out_adj) - expected_x = out_adj.reduce(ops.add, "k") + expected_x = out_adj.reduce(sum_op, "k") expected_y = out_adj actual_x = apply_optimizer(result[x]) @@ -89,17 +110,22 @@ def test_add_x_y(): assert_close(actual_y, expected_y) -def test_mul_add_x_x_y(): +@pytest.mark.parametrize( + "sum_op,prod_op,tojvp", + [(ops.add, ops.mul, to_jvp), (ops.logaddexp, ops.add, to_ljvp)], +) +def test_mul_add_x_x_y(sum_op, prod_op, tojvp): with autodiff: # Add Mul - x = to_jvp(random_tensor(OrderedDict(j=Bint[4]))) - y = to_jvp(random_tensor(OrderedDict(j=Bint[4], 
k=Bint[5]))) + x = tojvp(random_tensor(OrderedDict(j=Bint[4]))) + y = tojvp(random_tensor(OrderedDict(j=Bint[4], k=Bint[5]))) out_adj = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) - z = x * x + y + z = sum_op(prod_op(x, x), y) result = grad(z, (x, y), out_adj) - expected_x = 2 * x[0] * out_adj.reduce(ops.add, "k") + two = 2 if tojvp is to_jvp else math.log(2) + expected_x = reduce(prod_op, (two, x.primal, out_adj.reduce(sum_op, "k"))) expected_y = out_adj actual_x = apply_optimizer(result[x]) @@ -109,18 +135,23 @@ def test_mul_add_x_x_y(): assert_close(actual_y, expected_y) -def test_mul_add_xx_yy(): +@pytest.mark.parametrize( + "sum_op,prod_op,tojvp", + [(ops.add, ops.mul, to_jvp), (ops.logaddexp, ops.add, to_ljvp)], +) +def test_mul_add_xx_yy(sum_op, prod_op, tojvp): with autodiff: # Add Mul - x = to_jvp(random_tensor(OrderedDict(j=Bint[4]))) - y = to_jvp(random_tensor(OrderedDict(j=Bint[4], k=Bint[5]))) + x = tojvp(random_tensor(OrderedDict(j=Bint[4]))) + y = tojvp(random_tensor(OrderedDict(j=Bint[4], k=Bint[5]))) out_adj = random_tensor(OrderedDict(j=Bint[4], k=Bint[5])) - z = x * x + y + y + z = reduce(sum_op, (prod_op(x, x), y, y)) result = grad(z, (x, y), out_adj) - expected_x = 2 * x[0] * out_adj.reduce(ops.add, "k") - expected_y = 2 * out_adj + two = 2 if tojvp is to_jvp else math.log(2) + expected_x = reduce(prod_op, (two, x.primal, out_adj.reduce(sum_op, "k"))) + expected_y = prod_op(two, out_adj) actual_x = apply_optimizer(result[x]) actual_y = apply_optimizer(result[y]) @@ -129,38 +160,67 @@ def test_mul_add_xx_yy(): assert_close(actual_y, expected_y) -def test_reduce_x(): +@pytest.mark.parametrize( + "sum_op,prod_op,tojvp", + [(ops.add, ops.mul, to_jvp), (ops.logaddexp, ops.add, to_ljvp)], +) +def test_reduce_add_x(sum_op, prod_op, tojvp): with autodiff: # Reduce - y = to_jvp(random_tensor(OrderedDict(j=Bint[4], k=Bint[5]))) + y = tojvp(random_tensor(OrderedDict(j=Bint[4], k=Bint[5]))) out_adj = random_tensor(OrderedDict(j=Bint[4])) - z = y.reduce(ops.add, "k") + z = y.reduce(sum_op, "k") result = grad(z, (y,), out_adj) expected_y = out_adj.expand(ops.add, (Variable("k", Bint[5]),)) actual_y = apply_optimizer(result[y]) - assert_close(actual_y, expected_y) + assert_close(actual_y, expected_y, rtol=1e-5) + + +@pytest.mark.parametrize( + "sum_op,prod_op,div_op,tojvp", + [ + (ops.add, ops.mul, ops.safediv, to_jvp), + (ops.logaddexp, ops.add, ops.safesub, to_ljvp), + ], +) +def test_reduce_mul_x(sum_op, prod_op, div_op, tojvp): + with autodiff: + # Reduce + y = tojvp(random_tensor(OrderedDict(j=Bint[4], k=Bint[5]))) + out_adj = random_tensor(OrderedDict(j=Bint[4])) + + z = y.reduce(prod_op, "k") + result = grad(z, (y,), out_adj) + actual_y = apply_optimizer(result[y]) + expected_y = div_op(prod_op(out_adj, z.primal), y.primal) + assert_close(actual_y, apply_optimizer(expected_y), rtol=1e-5) -def test_mul_reduce_x_y(): + +@pytest.mark.parametrize( + "sum_op,prod_op,tojvp", + [(ops.add, ops.mul, to_jvp), (ops.logaddexp, ops.add, to_ljvp)], +) +def test_mul_reduce_x_y(sum_op, prod_op, tojvp): with autodiff: # Reduce - x = to_jvp(random_tensor(OrderedDict(j=Bint[4]))) - y = to_jvp(random_tensor(OrderedDict(j=Bint[4], k=Bint[5]))) + x = tojvp(random_tensor(OrderedDict(j=Bint[4]))) + y = tojvp(random_tensor(OrderedDict(j=Bint[4], k=Bint[5]))) out_adj = random_tensor(OrderedDict(k=Bint[5])) - z = (x * y).reduce(ops.add, "j") + z = prod_op(x, y).reduce(sum_op, "j") result = grad(z, (x, y), out_adj) - expected_x = (y[0] * out_adj).reduce(ops.add, "k") - expected_y = x[0] * 
out_adj.expand(ops.add, (Variable("j", Bint[4]),)) + expected_x = prod_op(y.primal, out_adj).reduce(sum_op, "k") + expected_y = prod_op(x.primal, out_adj.expand(ops.add, (Variable("j", Bint[4]),))) actual_x = apply_optimizer(result[x]) actual_y = apply_optimizer(result[y]) - assert_close(actual_x, expected_x) - assert_close(actual_y, expected_y) + assert_close(actual_x, expected_x, rtol=1e-5) + assert_close(actual_y, expected_y, rtol=1e-5) # def test_trace(): From 3b30d2e41b3364d5c2fe0326d0e4bc8de3ee39b3 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Tue, 16 Mar 2021 22:06:40 -0400 Subject: [PATCH 15/16] remove op from Expand --- funsor/autodiff.py | 4 ++-- funsor/tensor.py | 4 ++-- funsor/terms.py | 17 +++++++---------- test/test_autodiff.py | 4 ++-- 4 files changed, 13 insertions(+), 16 deletions(-) diff --git a/funsor/autodiff.py b/funsor/autodiff.py index b5e7f8bd5..34db0429b 100644 --- a/funsor/autodiff.py +++ b/funsor/autodiff.py @@ -158,7 +158,7 @@ def transpose_reduce(expr, out_tangent, in_tangents, result, semiring): out_tangent = result[expr] if op is sum_op: - arg_adj = out_tangent.expand(sum_op, tuple(reduced_vars)) + arg_adj = out_tangent.expand(tuple(reduced_vars)) result = transpose(arg, arg_adj, in_tangents, result, semiring) return result elif op is prod_op: @@ -194,7 +194,7 @@ def transpose_contraction(expr, out_tangent, in_tangents, result): elif expr.bin_op is ops.nullop: for term in expr.terms: # only one term if expr.red_op is ops.add: - term_adj = out_tangent.expand(ops.add, tuple(expr.reduced_vars)) + term_adj = out_tangent.expand(tuple(expr.reduced_vars)) elif expr.red_op is ops.mul: term_adj = ops.safediv(ops.mul(out_tangent, expr), term) else: diff --git a/funsor/tensor.py b/funsor/tensor.py index 5ee4e8012..354dbccb3 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -683,8 +683,8 @@ def eager_scatter_tensor(op, subs, source, reduced_vars): return Tensor(data, destin_inputs, output.dtype) -@eager.register(Expand, AssociativeOp, (Number, Tensor), tuple) -def eager_tensor_expand(op, arg, expanded_vars): +@eager.register(Expand, (Number, Tensor), tuple) +def eager_tensor_expand(arg, expanded_vars): expanded_shape = tuple(var.output.size for var in expanded_vars) old_shape = (-1,) * (len(arg.inputs) + len(arg.output.shape)) new_shape = expanded_shape + old_shape diff --git a/funsor/terms.py b/funsor/terms.py index 578937048..0768b764e 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -337,13 +337,12 @@ def item(self): def requires_grad(self): return False - def expand(self, op, expanded_vars): - assert isinstance(op, AssociativeOp) + def expand(self, expanded_vars): # Eagerly convert reduced_vars to appropriate things. assert isinstance(expanded_vars, tuple) if not expanded_vars: return self - return Expand(op, self, expanded_vars) + return Expand(self, expanded_vars) def reduce(self, op, reduced_vars=None): """ @@ -1014,8 +1013,7 @@ class Expand(Funsor): :param frozenset reduced_vars: A set of variables over which to reduce. 
""" - def __init__(self, op, arg, expanded_vars): - assert isinstance(op, AssociativeOp) + def __init__(self, arg, expanded_vars): assert isinstance(arg, Funsor) assert isinstance(expanded_vars, tuple) assert all(isinstance(v, Variable) for v in expanded_vars) @@ -1025,22 +1023,21 @@ def __init__(self, op, arg, expanded_vars): fresh = frozenset() bound = {} super().__init__(inputs, output, fresh, bound) - self.op = op self.arg = arg self.expanded_vars = expanded_vars def __repr__(self): assert self.expanded_vars rvars = [repr(v) for v in self.expanded_vars] - return "{}.expand({}, {{{}}})".format( - repr(self.arg), self.op.__name__, ", ".join(rvars) + return "{}.expand({{{}}})".format( + repr(self.arg), ", ".join(rvars) ) def __str__(self): assert self.expanded_vars rvars = [repr(v) for v in self.expanded_vars] - return "{}.expand({}, {{{}}})".format( - repr(self.arg), self.op.__name__, ", ".join(rvars) + return "{}.expand({{{}}})".format( + repr(self.arg), ", ".join(rvars) ) diff --git a/test/test_autodiff.py b/test/test_autodiff.py index 7e28e917a..1f6291d85 100644 --- a/test/test_autodiff.py +++ b/test/test_autodiff.py @@ -173,7 +173,7 @@ def test_reduce_add_x(sum_op, prod_op, tojvp): z = y.reduce(sum_op, "k") result = grad(z, (y,), out_adj) - expected_y = out_adj.expand(ops.add, (Variable("k", Bint[5]),)) + expected_y = out_adj.expand((Variable("k", Bint[5]),)) actual_y = apply_optimizer(result[y]) assert_close(actual_y, expected_y, rtol=1e-5) @@ -214,7 +214,7 @@ def test_mul_reduce_x_y(sum_op, prod_op, tojvp): result = grad(z, (x, y), out_adj) expected_x = prod_op(y.primal, out_adj).reduce(sum_op, "k") - expected_y = prod_op(x.primal, out_adj.expand(ops.add, (Variable("j", Bint[4]),))) + expected_y = prod_op(x.primal, out_adj.expand((Variable("j", Bint[4]),))) actual_x = apply_optimizer(result[x]) actual_y = apply_optimizer(result[y]) From 3fd3cbde4ec1482dd279bc0055d0ca0dd01f81e2 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Tue, 16 Mar 2021 22:58:45 -0400 Subject: [PATCH 16/16] expand Number --- funsor/autodiff.py | 48 +++++++++++++++++++++++++++++----------------- funsor/tensor.py | 14 +++++++++++++- 2 files changed, 43 insertions(+), 19 deletions(-) diff --git a/funsor/autodiff.py b/funsor/autodiff.py index 34db0429b..f11ac2fd9 100644 --- a/funsor/autodiff.py +++ b/funsor/autodiff.py @@ -211,54 +211,66 @@ def transpose_contraction(expr, out_tangent, in_tangents, result): @autodiff.register(Binary, AssociativeOp, LJVP, LJVP) def jvp_binary(op, lhs, rhs): sum_op, prod_op = lhs.sum_op, lhs.prod_op - lhs_primal, lhs_tangent = lhs - rhs_primal, rhs_tangent = rhs - primal = Binary(op, lhs_primal, rhs_primal) + primal = Binary(op, lhs.primal, rhs.primal) if op is sum_op: - tangent = sum_op(lhs_tangent, rhs_tangent) + tangent = sum_op(lhs.tangent, rhs.tangent) elif op is prod_op: tangent = sum_op( - prod_op(rhs_primal, lhs_tangent), prod_op(lhs_primal, rhs_tangent) + prod_op(rhs.primal, lhs.tangent), prod_op(lhs.primal, rhs.tangent) ) else: raise NotImplementedError return type(lhs)(primal, tangent) -@eager.register(Binary, AssociativeOp, JVP, Tensor) -@eager.register(Binary, AssociativeOp, LJVP, Tensor) -@autodiff.register(Binary, AssociativeOp, JVP, Tensor) -@autodiff.register(Binary, AssociativeOp, LJVP, Tensor) +@eager.register(Binary, AssociativeOp, JVP, (Number, Tensor)) +@eager.register(Binary, AssociativeOp, LJVP, (Number, Tensor)) +@autodiff.register(Binary, AssociativeOp, JVP, (Number, Tensor)) +@autodiff.register(Binary, AssociativeOp, LJVP, (Number, Tensor)) def 
jvp_binary_jvp_funsor(op, lhs, rhs): sum_op, prod_op = lhs.sum_op, lhs.prod_op - lhs_primal, lhs_tangent = lhs - primal = Binary(op, lhs_primal, rhs) + primal = Binary(op, lhs.primal, rhs) if op is sum_op: - tangent = sum_op(lhs_tangent, rhs) + tangent = sum_op(lhs.tangent, rhs) elif op is prod_op: - tangent = prod_op(lhs_tangent, rhs) + tangent = prod_op(lhs.tangent, rhs) else: raise NotImplementedError return type(lhs)(primal, tangent) +@eager.register(Binary, AssociativeOp, (Number, Tensor), JVP) +@eager.register(Binary, AssociativeOp, (Number, Tensor), LJVP) +@autodiff.register(Binary, AssociativeOp, (Number, Tensor), JVP) +@autodiff.register(Binary, AssociativeOp, (Number, Tensor), LJVP) +def jvp_binary_jvp_funsor(op, lhs, rhs): + sum_op, prod_op = rhs.sum_op, rhs.prod_op + primal = Binary(op, lhs, rhs.primal) + if op is sum_op: + tangent = sum_op(lhs, rhs.tangent) + elif op is prod_op: + tangent = prod_op(lhs, rhs.tangent) + else: + raise NotImplementedError + return type(rhs)(primal, tangent) + + @eager.register(Reduce, AssociativeOp, JVP, frozenset) @eager.register(Reduce, AssociativeOp, LJVP, frozenset) @autodiff.register(Reduce, AssociativeOp, JVP, frozenset) @autodiff.register(Reduce, AssociativeOp, LJVP, frozenset) def jvp_reduce(op, arg, reduced_vars): sum_op, prod_op, div_op = arg.sum_op, arg.prod_op, arg.div_op - arg_primal, arg_tangent = arg - out_primal = Reduce(op, arg_primal, reduced_vars) + primal = Reduce(op, arg.primal, reduced_vars) if op is sum_op: - tangent = Reduce(sum_op, arg_tangent, reduced_vars) + tangent = Reduce(sum_op, arg.tangent, reduced_vars) elif op is prod_op: tangent = Reduce( - sum_op, div_op(prod_op(arg_tangent, out_primal), arg_primal), reduced_vars + sum_op, div_op(prod_op(arg.tangent, primal), arg.primal), reduced_vars ) else: raise NotImplementedError - return type(arg)(out_primal, tangent) + return type(arg)(primal, tangent) # @lazy.register(Unary, LogOp, JVP) diff --git a/funsor/tensor.py b/funsor/tensor.py index 354dbccb3..77cfe207e 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -683,7 +683,19 @@ def eager_scatter_tensor(op, subs, source, reduced_vars): return Tensor(data, destin_inputs, output.dtype) -@eager.register(Expand, (Number, Tensor), tuple) +@eager.register(Expand, Number, tuple) +def eager_tensor_expand(arg, expanded_vars): + shape = tuple(var.output.size for var in expanded_vars) + inputs = OrderedDict([(var.name, var.output) for var in expanded_vars]) + data = ops.new_full( + funsor.tensor.get_default_prototype(), + shape, + arg.data + ) + return Tensor(data, inputs, arg.dtype) + + +@eager.register(Expand, Tensor, tuple) def eager_tensor_expand(arg, expanded_vars): expanded_shape = tuple(var.output.size for var in expanded_vars) old_shape = (-1,) * (len(arg.inputs) + len(arg.output.shape))
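
Note on the prod_op tangent rule: the prod_op branch of jvp_reduce above computes Reduce(sum_op, div_op(prod_op(arg.tangent, primal), arg.primal), reduced_vars). Writing z for the reduced primal and x_k, \dot x_k for the argument primal and tangent (notation introduced here, not in the patch), this is the ordinary product rule in the (add, mul) semiring,

    z = \prod_k x_k, \qquad
    \dot z = \sum_k \frac{\partial z}{\partial x_k}\,\dot x_k
           = \sum_k \frac{z}{x_k}\,\dot x_k,

and the same expression evaluated in the (logaddexp, add) semiring, where div_op is ops.safesub, reads

    \dot z = \operatorname{logsumexp}_k \bigl(\dot x_k + z - x_k\bigr).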
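
A minimal end-to-end sketch of how the semiring autodiff in this series is exercised, mirroring test_mul_x_y above. The import locations of grad, to_jvp, to_ljvp, and the autodiff interpretation are assumptions (the test file's import hunk is not shown in this patch), so they may need adjusting to match the branch.

    from collections import OrderedDict

    import funsor
    import funsor.ops as ops
    from funsor.autodiff import grad, to_jvp        # assumed import path
    from funsor.domains import Bint
    from funsor.interpretations import autodiff     # defined in this series; path assumed
    from funsor.optimizer import apply_optimizer
    from funsor.testing import random_tensor

    funsor.set_backend("torch")

    with autodiff:
        # primal/tangent pairs in the (add, mul) semiring
        x = to_jvp(random_tensor(OrderedDict(j=Bint[4])))
        y = to_jvp(random_tensor(OrderedDict(j=Bint[4], k=Bint[5])))
        out_adj = random_tensor(OrderedDict(j=Bint[4], k=Bint[5]))

        z = ops.mul(x, y)                  # forward pass: primal plus lazy tangent
        result = grad(z, (x, y), out_adj)  # transpose pass: adjoints w.r.t. x and y

    dx = apply_optimizer(result[x])        # equals (out_adj * y.primal).reduce(ops.add, "k")
    dy = apply_optimizer(result[y])        # equals out_adj * x.primal

The (logaddexp, add) semiring follows the same pattern with to_ljvp in place of to_jvp, as in the parametrized tests above.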