Skip to content

Commit

Permalink
Simplify logic with variadic_add and variadic_mul helpers
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Oct 7, 2024
1 parent eaeb3da commit d6d55cc
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 59 deletions.
8 changes: 2 additions & 6 deletions pytensor/tensor/blas.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@
from pytensor.tensor.basic import expand_dims
from pytensor.tensor.blas_headers import blas_header_text, blas_header_version
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import add, mul, neg, sub
from pytensor.tensor.math import add, mul, neg, sub, variadic_add
from pytensor.tensor.shape import shape_padright, specify_broadcastable
from pytensor.tensor.type import DenseTensorType, TensorType, integer_dtypes, tensor

Expand Down Expand Up @@ -1399,11 +1399,7 @@ def item_to_var(t):
item_to_var(input) for k, input in enumerate(lst) if k not in (i, j)
]
add_inputs.extend(gemm_of_sM_list)
if len(add_inputs) > 1:
rval = [add(*add_inputs)]
else:
rval = add_inputs
# print "RETURNING GEMM THING", rval
rval = [variadic_add(*add_inputs)]
return rval, old_dot22


Expand Down
36 changes: 24 additions & 12 deletions pytensor/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -1430,18 +1430,12 @@ def mean(input, axis=None, dtype=None, op=False, keepdims=False, acc_dtype=None)
else:
shp = cast(shp, "float64")

if axis is None:
axis = list(range(input.ndim))
elif isinstance(axis, int | np.integer):
axis = [axis]
elif isinstance(axis, np.ndarray) and axis.ndim == 0:
axis = [int(axis)]
else:
axis = [int(a) for a in axis]

# This sequential division will possibly be optimized by PyTensor:
for i in axis:
s = true_div(s, shp[i])
reduced_dims = (
shp
if axis is None
else [shp[i] for i in normalize_axis_tuple(axis, input.type.ndim)]
)
s /= variadic_mul(*reduced_dims)

# This can happen when axis is an empty list/tuple
if s.dtype != shp.dtype and s.dtype in discrete_dtypes:
Expand Down Expand Up @@ -1597,6 +1591,15 @@ def add(a, *other_terms):
# see decorator for function body


def variadic_add(*args):
"""Add that accepts arbitrary number of inputs, including zero or one."""
if not args:
return 0
if len(args) == 1:
return args[0]
return add(*args)


@scalar_elemwise
def sub(a, b):
"""elementwise subtraction"""
Expand All @@ -1609,6 +1612,15 @@ def mul(a, *other_terms):
# see decorator for function body


def variadic_mul(*args):
"""Mul that accepts arbitrary number of inputs, including zero or one."""
if not args:
return 1
if len(args) == 1:
return args[0]
return mul(*args)


@scalar_elemwise
def true_div(a, b):
"""elementwise [true] division (inverse of multiplication)"""
Expand Down
13 changes: 4 additions & 9 deletions pytensor/tensor/rewriting/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.extra_ops import broadcast_arrays
from pytensor.tensor.math import Sum, add, eq
from pytensor.tensor.math import Sum, eq, variadic_add
from pytensor.tensor.shape import Shape_i, shape_padleft
from pytensor.tensor.type import DenseTensorType, TensorType
from pytensor.tensor.variable import TensorConstant, TensorVariable
Expand Down Expand Up @@ -939,14 +939,9 @@ def local_sum_make_vector(fgraph, node):
if acc_dtype == "float64" and out_dtype != "float64" and config.floatX != "float64":
return

if len(elements) == 0:
element_sum = zeros(dtype=out_dtype, shape=())
elif len(elements) == 1:
element_sum = cast(elements[0], out_dtype)
else:
element_sum = cast(
add(*[cast(value, acc_dtype) for value in elements]), out_dtype
)
element_sum = cast(
variadic_add(*[cast(value, acc_dtype) for value in elements]), out_dtype
)

return [element_sum]

Expand Down
15 changes: 10 additions & 5 deletions pytensor/tensor/rewriting/blas.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,15 @@
)
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import Dot, _matrix_matrix_matmul, add, mul, neg, sub
from pytensor.tensor.math import (
Dot,
_matrix_matrix_matmul,
add,
mul,
neg,
sub,
variadic_add,
)
from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift
from pytensor.tensor.type import (
DenseTensorType,
Expand Down Expand Up @@ -386,10 +394,7 @@ def item_to_var(t):
item_to_var(input) for k, input in enumerate(lst) if k not in (i, j)
]
add_inputs.extend(gemm_of_sM_list)
if len(add_inputs) > 1:
rval = [add(*add_inputs)]
else:
rval = add_inputs
rval = [variadic_add(*add_inputs)]
# print "RETURNING GEMM THING", rval
return rval, old_dot22

Expand Down
25 changes: 7 additions & 18 deletions pytensor/tensor/rewriting/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@
sub,
tri_gamma,
true_div,
variadic_add,
variadic_mul,
)
from pytensor.tensor.math import abs as pt_abs
from pytensor.tensor.math import max as pt_max
Expand Down Expand Up @@ -1270,17 +1272,13 @@ def local_sum_prod_of_mul_or_div(fgraph, node):

if not outer_terms:
return None
elif len(outer_terms) == 1:
[outer_term] = outer_terms
else:
outer_term = mul(*outer_terms)
outer_term = variadic_mul(*outer_terms)

if not inner_terms:
inner_term = None
elif len(inner_terms) == 1:
[inner_term] = inner_terms
else:
inner_term = mul(*inner_terms)
inner_term = variadic_mul(*inner_terms)

else: # true_div
# We only care about removing the denominator out of the reduction
Expand Down Expand Up @@ -2163,10 +2161,7 @@ def local_add_remove_zeros(fgraph, node):
assert cst.type.broadcastable == (True,) * ndim
return [alloc_like(cst, node_output, fgraph)]

if len(new_inputs) == 1:
ret = [alloc_like(new_inputs[0], node_output, fgraph)]
else:
ret = [alloc_like(add(*new_inputs), node_output, fgraph)]
ret = [alloc_like(variadic_add(*new_inputs), node_output, fgraph)]

# The dtype should not be changed. It can happen if the input
# that was forcing upcasting was equal to 0.
Expand Down Expand Up @@ -2277,10 +2272,7 @@ def local_log1p(fgraph, node):
# scalar_inputs are potentially dimshuffled and fill'd scalars
if scalars and np.allclose(np.sum(scalars), 1):
if nonconsts:
if len(nonconsts) > 1:
ninp = add(*nonconsts)
else:
ninp = nonconsts[0]
ninp = variadic_add(*nonconsts)
if ninp.dtype != log_arg.type.dtype:
ninp = ninp.astype(node.outputs[0].dtype)
return [alloc_like(log1p(ninp), node.outputs[0], fgraph)]
Expand Down Expand Up @@ -3104,10 +3096,7 @@ def local_exp_over_1_plus_exp(fgraph, node):
return
# put the new numerator together
new_num = sigmoids + [exp(t) for t in num_exp_x] + num_rest
if len(new_num) == 1:
new_num = new_num[0]
else:
new_num = mul(*new_num)
new_num = variadic_mul(*new_num)

if num_neg ^ denom_neg:
new_num = -new_num
Expand Down
15 changes: 6 additions & 9 deletions pytensor/tensor/rewriting/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
maximum,
minimum,
or_,
variadic_add,
)
from pytensor.tensor.math import all as pt_all
from pytensor.tensor.rewriting.basic import (
Expand Down Expand Up @@ -1241,15 +1242,11 @@ def movable(i):
new_inputs = [i for i in node.inputs if not movable(i)] + [
mi.owner.inputs[0] for mi in movable_inputs
]
if len(new_inputs) == 0:
new_add = new_inputs[0]
else:
new_add = add(*new_inputs)

# Copy over stacktrace from original output, as an error
# (e.g. an index error) in this add operation should
# correspond to an error in the original add operation.
copy_stack_trace(node.outputs[0], new_add)
new_add = variadic_add(*new_inputs)
# Copy over stacktrace from original output, as an error
# (e.g. an index error) in this add operation should
# correspond to an error in the original add operation.
copy_stack_trace(node.outputs[0], new_add)

# stack up the new incsubtensors
tip = new_add
Expand Down

0 comments on commit d6d55cc

Please sign in to comment.