Autodiff implementation (experimental) #494

Draft
wants to merge 17 commits into
base: master
282 changes: 282 additions & 0 deletions funsor/autodiff.py
@@ -0,0 +1,282 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import math
from collections import defaultdict
from functools import reduce, singledispatch

import funsor.ops as ops
from funsor import Tensor
from funsor.adjoint import _alpha_unmangle
from funsor.cnf import Contraction
from funsor.domains import Array, Bint, Real, Reals
from funsor.interpretations import autodiff, trace
from funsor.interpreter import interpretation
from funsor.ops import AssociativeOp, LogOp
from funsor.terms import (
Binary,
Funsor,
Lambda,
Number,
Reduce,
Tuple,
Unary,
Variable,
eager,
lazy,
)


class JVP(Tuple):
"""
Tuple: (Primal, Tangent)
Semiring: (Add, Mul)
"""

sum_op = ops.add
prod_op = ops.mul
div_op = ops.safediv
zero = Number(0)
one = Number(1)

@property
def primal(self):
return self[0]

@property
def tangent(self):
return self[1]


class LJVP(Tuple):
"""
Tuple: (LogPrimal, LogTangent)
Semiring: (Logaddexp, Add)
"""

sum_op = ops.logaddexp
prod_op = ops.add
div_op = ops.safesub
zero = Number(-math.inf)
one = Number(0)

@property
def primal(self):
return self[0]

@property
def tangent(self):
return self[1]


@trace.register(Binary, AssociativeOp, Funsor, Funsor)
def trace_binary_associativeop(op, lhs, rhs):
with lazy:
result = Binary(op, lhs, rhs)
return result


@trace.register(Reduce, AssociativeOp, Funsor, frozenset)
def trace_reduce_associativeop(op, arg, reduced_args):
with lazy:
result = Reduce(op, arg, reduced_args)
return result


def to_jvp(primal):
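# Fold the primal's inputs into its output via Lambda to get the full array
# domain, then index a fresh placeholder Variable back by those inputs so the
# tangent has the same inputs and output as the primal.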
input_vars = tuple(Variable(key, value) for key, value in primal.inputs.items())
output = reduce(lambda x, y: Lambda(y, x), reversed(input_vars), primal).output
tangent_placeholder = Variable(str(id(primal)), output)[tuple(primal.inputs)]
return JVP(primal, tangent_placeholder)


def to_ljvp(primal):
input_vars = tuple(Variable(key, value) for key, value in primal.inputs.items())
output = reduce(lambda x, y: Lambda(y, x), reversed(input_vars), primal).output
tangent_placeholder = Variable(str(id(primal)), output)[tuple(primal.inputs)]
return LJVP(primal, tangent_placeholder)


def grad(expr, targets, out_tangent=None):
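# Reverse pass: transpose the traced tangent expression, accumulating an
# adjoint for each target's tangent placeholder under the given semiring.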
out_tangent = expr.one if out_tangent is None else out_tangent
in_tangents = set(target.tangent for target in targets)
transposes = transpose(
expr.tangent,
out_tangent,
in_tangents,
defaultdict(lambda: expr.zero),
type(expr),
)
result = {}
for target in targets:
result[target] = transposes[target.tangent]
return result


@singledispatch
def transpose(expr, out_tangent, in_tangents, result, semiring):
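# Base case: if this subexpression is one of the requested tangents,
# fold the incoming adjoint into its accumulated result.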
if expr in in_tangents:
result[expr] = semiring.sum_op(result[expr], out_tangent)
return result


@transpose.register(Binary)
def transpose_binary(expr, out_tangent, in_tangents, result, semiring):

op, lhs, rhs = expr.op, expr.lhs, expr.rhs
sum_op, prod_op = semiring.sum_op, semiring.prod_op

if expr in in_tangents:
result[expr] = sum_op(result[expr], out_tangent)
out_tangent = result[expr]

if op is sum_op:
lhs_adj = out_tangent.reduce(sum_op, out_tangent.input_vars - lhs.input_vars)
rhs_adj = out_tangent.reduce(sum_op, out_tangent.input_vars - rhs.input_vars)
elif op is prod_op:
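# Product rule: each operand's adjoint is the incoming adjoint times the
# other operand, summed over the variables that operand does not have.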
lhs_adj = prod_op(rhs, out_tangent).reduce(
sum_op, out_tangent.input_vars - lhs.input_vars
)
rhs_adj = prod_op(lhs, out_tangent).reduce(
sum_op, out_tangent.input_vars - rhs.input_vars
)
else:
return result # is it always correct?
result = transpose(lhs, lhs_adj, in_tangents, result, semiring)
result = transpose(rhs, rhs_adj, in_tangents, result, semiring)
return result


@transpose.register(Reduce)
def transpose_reduce(expr, out_tangent, in_tangents, result, semiring):
# fix this in contraction as well
op, arg, reduced_vars = _alpha_unmangle(expr)
sum_op, prod_op = semiring.sum_op, semiring.prod_op

if expr in in_tangents:
result[expr] = sum_op(result[expr], out_tangent)
out_tangent = result[expr]

if op is sum_op:
arg_adj = out_tangent.expand(tuple(reduced_vars))
result = transpose(arg, arg_adj, in_tangents, result, semiring)
return result
elif op is prod_op:
# this is unnecessary
return result
else:
raise ValueError


@transpose.register(Contraction)
def transpose_contraction(expr, out_tangent, in_tangents, result, semiring):
if expr in in_tangents:
result[expr] += out_tangent
out_tangent = result[expr]

if expr.red_op is ops.nullop:
for term in expr.terms:
if expr.bin_op is ops.add:
term_adj = out_tangent.reduce(
ops.add, out_tangent.input_vars - term.input_vars
)
elif expr.bin_op is ops.mul:
expr_div_term = reduce(
ops.mul, tuple(t for t in expr.terms if t is not term)
)
term_adj = (out_tangent * expr_div_term).reduce(
ops.add, out_tangent.input_vars - term.input_vars
)
else:
raise ValueError
result = transpose(term, term_adj, in_tangents, result, semiring)
elif expr.bin_op is ops.nullop:
for term in expr.terms: # only one term
if expr.red_op is ops.add:
term_adj = out_tangent.expand(tuple(expr.reduced_vars))
elif expr.red_op is ops.mul:
term_adj = ops.safediv(ops.mul(out_tangent, expr), term)
else:
raise ValueError
result = transpose(term, term_adj, in_tangents, result, semiring)
else:
raise ValueError
return result


@eager.register(Binary, AssociativeOp, JVP, JVP)
@eager.register(Binary, AssociativeOp, LJVP, LJVP)
@autodiff.register(Binary, AssociativeOp, JVP, JVP)
@autodiff.register(Binary, AssociativeOp, LJVP, LJVP)
def jvp_binary(op, lhs, rhs):
sum_op, prod_op = lhs.sum_op, lhs.prod_op
primal = Binary(op, lhs.primal, rhs.primal)
if op is sum_op:
tangent = sum_op(lhs.tangent, rhs.tangent)
elif op is prod_op:
tangent = sum_op(
prod_op(rhs.primal, lhs.tangent), prod_op(lhs.primal, rhs.tangent)
)
else:
raise NotImplementedError
return type(lhs)(primal, tangent)


@eager.register(Binary, AssociativeOp, JVP, (Number, Tensor))
@eager.register(Binary, AssociativeOp, LJVP, (Number, Tensor))
@autodiff.register(Binary, AssociativeOp, JVP, (Number, Tensor))
@autodiff.register(Binary, AssociativeOp, LJVP, (Number, Tensor))
def jvp_binary_jvp_funsor(op, lhs, rhs):
sum_op, prod_op = lhs.sum_op, lhs.prod_op
primal = Binary(op, lhs.primal, rhs)
if op is sum_op:
tangent = sum_op(lhs.tangent, rhs)
elif op is prod_op:
tangent = prod_op(lhs.tangent, rhs)
else:
raise NotImplementedError
return type(lhs)(primal, tangent)


@eager.register(Binary, AssociativeOp, (Number, Tensor), JVP)
@eager.register(Binary, AssociativeOp, (Number, Tensor), LJVP)
@autodiff.register(Binary, AssociativeOp, (Number, Tensor), JVP)
@autodiff.register(Binary, AssociativeOp, (Number, Tensor), LJVP)
def jvp_binary_funsor_jvp(op, lhs, rhs):
sum_op, prod_op = rhs.sum_op, rhs.prod_op
primal = Binary(op, lhs, rhs.primal)
if op is sum_op:
tangent = sum_op(lhs, rhs.tangent)
elif op is prod_op:
tangent = prod_op(lhs, rhs.tangent)
else:
raise NotImplementedError
return type(rhs)(primal, tangent)


@eager.register(Reduce, AssociativeOp, JVP, frozenset)
@eager.register(Reduce, AssociativeOp, LJVP, frozenset)
@autodiff.register(Reduce, AssociativeOp, JVP, frozenset)
@autodiff.register(Reduce, AssociativeOp, LJVP, frozenset)
def jvp_reduce(op, arg, reduced_vars):
sum_op, prod_op, div_op = arg.sum_op, arg.prod_op, arg.div_op
primal = Reduce(op, arg.primal, reduced_vars)
if op is sum_op:
tangent = Reduce(sum_op, arg.tangent, reduced_vars)
elif op is prod_op:
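# Product rule for a reduction: sum, over the reduced variables, of
# tangent * (reduced product / primal term).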
tangent = Reduce(
sum_op, div_op(prod_op(arg.tangent, primal), arg.primal), reduced_vars
)
else:
raise NotImplementedError
return type(arg)(primal, tangent)


# @lazy.register(Unary, LogOp, JVP)
# @eager.register(Unary, LogOp, JVP)
# def jvp_log(op, arg):
# arg_primal, arg_tangent = arg
# primal = Unary(op, arg_primal)
# tangent = Binary(ops.truediv, arg_tangent, arg_primal)
# return JVP(primal, tangent)
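
Rough usage sketch for review (not part of this diff): wrap each input with to_jvp, build the expression under the autodiff interpretation so tangents are propagated alongside primals, then call grad to transpose the tangent expression back onto the inputs. Variable names, shapes, and values below are made up, and the torch backend is assumed.

from collections import OrderedDict

import torch

import funsor
import funsor.ops as ops
from funsor import Tensor
from funsor.autodiff import grad, to_jvp
from funsor.domains import Bint
from funsor.interpretations import autodiff

funsor.set_backend("torch")

with autodiff:
    # Each primal is paired with a fresh tangent placeholder variable.
    x = to_jvp(Tensor(torch.tensor([1.0, 2.0, 3.0]), OrderedDict(i=Bint[3])))
    y = to_jvp(Tensor(torch.tensor([4.0, 5.0, 6.0]), OrderedDict(i=Bint[3])))
    # jvp_binary and jvp_reduce above propagate both the primal and the tangent.
    z = (x * y).reduce(ops.add, "i")

# Transpose the tangent expression onto the inputs: adjoints of z w.r.t. x and y.
grads = grad(z, (x, y))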
6 changes: 5 additions & 1 deletion funsor/domains.py
@@ -261,6 +261,7 @@ def _find_domain_getitem(op, lhs_domain, rhs_domain):
return Array[dtype, shape]
elif isinstance(lhs_domain, ProductDomain):
# XXX should this return a Union?
return Real
raise NotImplementedError(
"Cannot statically infer domain from: " f"{lhs_domain}[{rhs_domain}]"
)
@@ -325,7 +326,10 @@ def _find_domain_associative_generic(op, *domains):
return Array[domains[0].dtype, ()]

lhs, rhs = domains
if lhs.dtype == "real" or rhs.dtype == "real":
# FIXME
if lhs is rhs:
return lhs
elif lhs.dtype == "real" or rhs.dtype == "real":
dtype = "real"
elif op in (ops.add, ops.mul, ops.pow, ops.max, ops.min):
dtype = op(lhs.dtype - 1, rhs.dtype - 1) + 1
14 changes: 14 additions & 0 deletions funsor/interpretations.py
@@ -329,6 +329,20 @@ def reflect(cls, *args):
Eager exact naive interpretation wherever possible.
"""

trace_base = DispatchedInterpretation("trace")
trace = PrioritizedInterpretation(trace_base, eager_base, normalize_base, reflect)
"""
Constructs a trace (expression) in terms of primitive operations.
"""

autodiff_base = DispatchedInterpretation("autodiff")
autodiff = PrioritizedInterpretation(
autodiff_base, trace_base, eager_base, normalize_base, reflect
)
"""
Automatic differentiation interpretation, layered on top of the trace interpretation.
"""

die = DispatchedInterpretation("die")
eager_or_die = PrioritizedInterpretation(eager_base, die, reflect)

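For orientation, a sketch of how these priority stacks resolve (reviewer illustration, not part of the diff): under autodiff a pattern is matched against autodiff_base first, then trace_base, then falls back to eager, normalize, and reflect. It assumes funsor.autodiff has been imported so its trace rules are registered.

import funsor.autodiff  # noqa: F401  registers the trace/autodiff rules above
import funsor.ops as ops
from funsor.domains import Real
from funsor.interpretations import trace
from funsor.terms import Binary, Number, Variable

x = Variable("x", Real)
with trace:
    # trace_base has a rule for Binary of an AssociativeOp, so the expression
    # is recorded lazily; patterns with no trace rule fall through to eager,
    # then normalize, then reflect.
    expr = Binary(ops.mul, x, Number(2))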
25 changes: 24 additions & 1 deletion funsor/tensor.py
@@ -18,9 +18,10 @@
from . import ops
from .delta import Delta
from .domains import Array, ArrayType, Bint, Product, Real, Reals, find_domain
from .ops import GetitemOp, MatmulOp, Op, ReshapeOp
from .ops import AssociativeOp, GetitemOp, MatmulOp, Op, ReshapeOp
from .terms import (
Binary,
Expand,
Finitary,
Funsor,
FunsorMeta,
@@ -682,6 +683,28 @@ def eager_scatter_tensor(op, subs, source, reduced_vars):
return Tensor(data, destin_inputs, output.dtype)


@eager.register(Expand, Number, tuple)
def eager_number_expand(arg, expanded_vars):
shape = tuple(var.output.size for var in expanded_vars)
inputs = OrderedDict([(var.name, var.output) for var in expanded_vars])
data = ops.new_full(get_default_prototype(), shape, arg.data)
return Tensor(data, inputs, arg.dtype)


@eager.register(Expand, Tensor, tuple)
def eager_tensor_expand(arg, expanded_vars):
expanded_shape = tuple(var.output.size for var in expanded_vars)
old_shape = (-1,) * (len(arg.inputs) + len(arg.output.shape))
new_shape = expanded_shape + old_shape
inputs = OrderedDict([(var.name, var.output) for var in expanded_vars])
inputs.update(arg.inputs)
return Tensor(ops.expand(arg.data, new_shape), inputs, arg.dtype)


@eager.register(Binary, Op, Tensor, Number)
def eager_binary_tensor_number(op, lhs, rhs):
dtype = find_domain(op, lhs.output, rhs.output).dtype
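A small sketch of what the new Expand handlers compute (reviewer illustration; Expand itself is added to funsor/terms.py, which is not shown in this diff): a Number or Tensor is broadcast so that the given variables become additional inputs. Names and values are made up, and the torch backend is assumed.

from collections import OrderedDict

import torch

import funsor
from funsor import Tensor
from funsor.domains import Bint
from funsor.terms import Expand, Variable

funsor.set_backend("torch")

i = Variable("i", Bint[3])
x = Tensor(torch.tensor([1.0, 2.0]), OrderedDict(j=Bint[2]))

# Per eager_tensor_expand above: the data is broadcast to shape (3, 2) and
# "i" becomes a new input alongside the existing "j".
y = Expand(x, (i,))
assert set(y.inputs) == {"i", "j"}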