Commit

Add rough version of an autodiff refactor
aseyboldt committed May 29, 2024
1 parent fc21336 commit ea52ead
Showing 2 changed files with 203 additions and 1 deletion.
89 changes: 88 additions & 1 deletion pytensor/gradient.py
@@ -12,7 +12,7 @@
from pytensor.compile.ops import ViewOp
from pytensor.configdefaults import config
from pytensor.graph import utils
from pytensor.graph.basic import Apply, NominalVariable, Variable
from pytensor.graph.basic import Apply, NominalVariable, Variable, io_toposort
from pytensor.graph.null_type import NullType, null_type
from pytensor.graph.op import get_test_values
from pytensor.graph.type import Type
@@ -2292,3 +2292,90 @@ def grad_scale(x, multiplier):
0.416...
"""
return GradScale(multiplier)(x)


# ===========================================
# The following is more or less pseudocode...
# ===========================================

# Use transpose and forward mode autodiff to get reverse mode autodiff
# Ops that only define push_forward (Rop) could use this, which is nice
# because push_forward is usually easier to derive and think about.
def pull_back_through_transpose(outputs, inputs, output_cotangents):
tangents = [input.type() for input in inputs]
output_tangents = push_forward(outputs, inputs, tangents)
return linear_transpose(output_tangents, tangents, output_cotangents)


# Ops that only define pull_back (Lop) could use this to derive push_forward.
def push_forward_through_pull_back(outputs, inputs, tangents):
cotangents = [out.type("u") for out in outputs]
input_cotangents = pull_back(outputs, inputs, cotangents)
return pull_back(input_cotangents, cotangents, tangents)
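

# Hypothetical usage sketch (an assumption, not part of this commit): for
# y = A @ x, which is linear in x, pulling back twice should recover the
# push_forward A @ v of a tangent vector v.
def _example_push_forward_via_pull_back():
    import pytensor.tensor as pt

    A = pt.matrix("A")
    x = pt.vector("x")
    v = pt.vector("v")
    y = A @ x
    # Up to rewrites, this should be equivalent to A @ v.
    (jv,) = push_forward_through_pull_back([y], [x], [v])
    return jv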


def push_forward(outputs, inputs, input_tangents):
# Get the nodes in topological order and precompute
# a set of values that are used in the graph.
nodes = io_toposort(inputs, outputs)
used_values = set(outputs)
for node in reversed(nodes):
if any(output in used_values for output in node.outputs):
used_values.update(node.inputs)

# Maybe a lazy gradient op could use this during rewrite time?
recorded_rewrites = {}
known_tangents = dict(zip(inputs, input_tangents, strict=True))
for node in nodes:
tangents = [known_tangents.get(input, None) for input in node.inputs]
result_nums = [i for i in range(len(node.outputs)) if node.outputs[i] in used_values]
new_outputs, output_tangents = node.op.push_forward(node, tangents, result_nums)
if new_outputs is not None:
recorded_rewrites[node] = new_outputs

for i, tangent in zip(result_nums, output_tangents, strict=True):
known_tangents[node.outputs[i]] = tangent

return [known_tangents[output] for output in outputs]


def pull_back(outputs, inputs, output_cotangents):
known_cotangents = dict(zip(outputs, output_cotangents, strict=True))

nodes = io_toposort(inputs, outputs)
used_values = set(outputs)
for node in reversed(nodes):
if any(output in used_values for output in node.outputs):
used_values.update(node.inputs)

# Maybe a lazy gradient op could use this during rewrite time?
recorded_rewrites = {}
for node in reversed(nodes):
cotangents = [known_cotangents.get(output, None) for output in node.outputs]
argnums = [i for i in range(len(node.inputs)) if node.inputs[i] in used_values]
new_outputs, input_cotangents = node.op.pull_back(node, cotangents, argnums)
if new_outputs is not None:
recorded_rewrites[node] = new_outputs

for i, cotangent in zip(argnums, input_cotangents, strict=True):
input = node.inputs[i]
if input not in known_cotangents:
known_cotangents[input] = cotangent
else:
# TODO check that we are not broadcasting?
known_cotangents[input] += cotangent

return [known_cotangents[input] for input in inputs]

def pullback_grad(cost, wrt):
"""A new pt.grad that uses the pull_back function.
At some point we might want to replace pt.grad with this?
"""
# Error checking and allow non-list wrt...
return pull_back([cost], wrt, [1.])
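
# Hypothetical usage sketch (an assumption, not part of this commit): the
# gradient of sum(x ** 2) obtained through pull_back should correspond to 2 * x.
def _example_pullback_grad():
    import pytensor.tensor as pt

    x = pt.vector("x")
    cost = (x**2).sum()
    (grad_x,) = pullback_grad(cost, [x])
    return grad_x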

def linear_transpose(outputs, inputs, transposed_inputs):
"""Given a linear function from inputs to outputs, return the transposed function."""
# Some loop over reversed(io_toposort(inputs, outputs))...
# Should look similar to pull_back? See the sketch below.
raise NotImplementedError("linear_transpose is not implemented yet.")
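
# A minimal sketch of what that loop might look like (an assumption, not part
# of this commit): the same backward sweep as pull_back, but dispatching to
# node.op.linear_transpose instead of node.op.pull_back.
def _linear_transpose_sketch(outputs, inputs, transposed_inputs):
    known = dict(zip(outputs, transposed_inputs, strict=True))

    nodes = io_toposort(inputs, outputs)
    used_values = set(outputs)
    for node in reversed(nodes):
        if any(output in used_values for output in node.outputs):
            used_values.update(node.inputs)

    for node in reversed(nodes):
        transposed = [known.get(output, None) for output in node.outputs]
        linear_outputs = [i for i, value in enumerate(transposed) if value is not None]
        linear_inputs = [i for i, inp in enumerate(node.inputs) if inp in used_values]
        if not linear_outputs or not linear_inputs:
            continue
        results = node.op.linear_transpose(
            node,
            [transposed[i] for i in linear_outputs],
            linear_inputs,
            linear_outputs,
        )
        for i, value in zip(linear_inputs, results, strict=True):
            inp = node.inputs[i]
            known[inp] = value if inp not in known else known[inp] + value

    return [known[inp] for inp in inputs]
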
115 changes: 115 additions & 0 deletions pytensor/graph/op.py
@@ -6,7 +6,9 @@
from typing import (
TYPE_CHECKING,
Any,
Optional,
Protocol,
Tuple,
TypeVar,
cast,
)
@@ -323,6 +325,119 @@ def __ne__(self, other: Any) -> bool:
# just to self.add_tag_trace
add_tag_trace = staticmethod(add_tag_trace)

def linear_transpose(
self,
node: Apply,
transposed_inputs: Sequence[Variable],
linear_inputs: Sequence[int],
linear_outputs: Sequence[int],
) -> Sequence[Variable]:
"""Transpose a linear function.
The function f: [node.inputs[i] for i in linear_inputs] to [node.outputs[i] ofr i in linear_outputs]
given the remaining inputs as constants must be linear. This function can then
be implemented by an Op, and return f^*(transposed_inputs).
Parameters
----------
node: Apply
The point at which to do the transpose
transposed_inputs:
The inputs for the transposed function.
linear_inputs:
Indices of input arguments to consider.
linear_outputs:
Indices of output arguments to consider.
"""
raise NotImplementedError(f"Linear transpos of {self} is not defined or not implemented.")

def push_forward(
self,
node: Apply,
input_tangents: Sequence[Variable | None],
result_nums: Sequence[int],
) -> Tuple[Sequence[Variable] | None, Sequence[Variable | None]]:
"""Compute the push_forward of tangent vectors at the specified point.
Parameters
----------
node: Apply
The point at which to compute the push_forward. (ie at x = node.inputs
and f(x) = node.outputs).
input_tangents:
The values of the tangent vectors that we wish to map. Values that
are set to None are assumed to be constants.
result_nums:
Compute only the output tangents of [node.outputs[i] for i in argnums].
Returns
-------
alternative_outputs:
Optionally a hint to the rewriter that the outputs of the op could
also be computed with the provided values, if the tangents are also
computed.
output_tangents:
The tangents of the outputs specified in argnums.
If the value is None, this indicates that the output did
not depend on the inputs that had tangents provided..
"""
from pytensor.gradient import DisconnectedType
from pytensor.graph.null_type import NullType
from pytensor.tensor.basic import zeros_like

tangents_filled = [
# TODO do the R_op methods also accept a disconnected_grad?
tangent if tangent is not None else zeros_like(input)
for tangent, input in zip(input_tangents, node.inputs, strict=True)
]
output_tangents = self.R_op(node.inputs, tangents_filled)
output_tangents = [output_tangents[i] for i in result_nums]

mapped_output_tangents = []
for argnum, tangent in zip(result_nums, output_tangents, strict=True):
if isinstance(tangent.type, DisconnectedType):
mapped_output_tangents.append(None)
elif isinstance(tangent.type, NullType):
raise NotImplementedError(
f"The push_forward of argument {argnum} of op "
f"{self} is not implemented or not defined."
)
else:
mapped_output_tangents.append(tangent)
return (None, mapped_output_tangents)

def pull_back(
self,
node: Apply,
output_cotangents: Sequence[Variable | None],
argnums: Sequence[int],
) -> Tuple[Sequence[Variable] | None, Sequence[Variable | None]]:
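"""Compute the pull_back of cotangent vectors at the specified point.

Parameters
----------
node: Apply
The point at which to compute the pull_back (i.e. at x = node.inputs
and f(x) = node.outputs).
output_cotangents:
The values of the cotangent vectors that we wish to map. Values that
are set to None are assumed to be constants.
argnums:
Compute only the input cotangents of [node.inputs[i] for i in argnums].

Returns
-------
alternative_outputs:
Optionally a hint to the rewriter that the outputs of the op could
also be computed with the provided values, if the cotangents are also
computed.
input_cotangents:
The cotangents of the inputs specified in argnums.
If a value is None, this indicates that the corresponding input does
not influence the outputs that had cotangents provided.
"""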
from pytensor.gradient import DisconnectedType
from pytensor.graph.null_type import NullType
from pytensor.tensor.basic import zeros_like

cotangents_filled = [
# TODO do the L_op methods also accept a disconnected_grad?
cotangent if cotangent is not None else zeros_like(output)
for cotangent, output in zip(output_cotangents, node.outputs, strict=True)
]

input_cotangents = self.L_op(node.inputs, node.outputs, cotangents_filled)
input_cotangents = [input_cotangents[i] for i in argnums]

mapped_input_cotangents = []
for argnum, cotangent in zip(argnums, input_cotangents, strict=True):
if isinstance(cotangent.type, DisconnectedType):
mapped_input_cotangents.append(None)
elif isinstance(cotangent.type, NullType):
raise NotImplementedError(
f"The pull_back of argument {argnum} of op "
f"{self} is not implemented or not defined."
)
else:
mapped_input_cotangents.append(cotangent)
return (None, mapped_input_cotangents)

def grad(
self, inputs: Sequence[Variable], output_grads: Sequence[Variable]
) -> list[Variable]:
