diff --git a/pytensor/link/jax/dispatch/__init__.py b/pytensor/link/jax/dispatch/__init__.py index 1d8ae33104..f4098416b8 100644 --- a/pytensor/link/jax/dispatch/__init__.py +++ b/pytensor/link/jax/dispatch/__init__.py @@ -6,6 +6,7 @@ import pytensor.link.jax.dispatch.blockwise import pytensor.link.jax.dispatch.elemwise import pytensor.link.jax.dispatch.extra_ops +import pytensor.link.jax.dispatch.pad import pytensor.link.jax.dispatch.math import pytensor.link.jax.dispatch.nlinalg import pytensor.link.jax.dispatch.random diff --git a/pytensor/link/jax/dispatch/pad.py b/pytensor/link/jax/dispatch/pad.py new file mode 100644 index 0000000000..6d40d20cc1 --- /dev/null +++ b/pytensor/link/jax/dispatch/pad.py @@ -0,0 +1,53 @@ +import jax.numpy as jnp +import numpy as np + +from pytensor.link.jax.dispatch import jax_funcify +from pytensor.tensor.pad import Pad + + +@jax_funcify.register(Pad) +def jax_funcify_pad(op, **kwargs): + pad_mode = op.pad_mode + reflect_type = op.reflect_type + has_stat_length = op.has_stat_length + + if pad_mode == "constant": + + def constant_pad(x, pad_width, constant_values): + return jnp.pad(x, pad_width, mode=pad_mode, constant_values=constant_values) + + return constant_pad + + elif pad_mode == "linear_ramp": + + def lr_pad(x, pad_width, end_values): + # JAX does not allow a dynamic input if end_values is non-scalar + if not isinstance(end_values, int | float): + end_values = tuple(np.array(end_values)) + return jnp.pad(x, pad_width, mode=pad_mode, end_values=end_values) + + return lr_pad + + elif pad_mode in ["maximum", "minimum", "mean"] and has_stat_length: + + def stat_pad(x, pad_width, stat_length): + # JAX does not allow a dynamic input here, need to cast to tuple + return jnp.pad( + x, pad_width, mode=pad_mode, stat_length=tuple(np.array(stat_length)) + ) + + return stat_pad + + elif pad_mode in ["reflect", "symmetric"]: + + def loop_pad(x, pad_width): + return jnp.pad(x, pad_width, mode=pad_mode, reflect_type=reflect_type) + + return loop_pad + + else: + + def pad(x, pad_width): + return jnp.pad(x, pad_width, mode=pad_mode) + + return pad diff --git a/pytensor/tensor/__init__.py b/pytensor/tensor/__init__.py index 3dfa1b4b7a..81cabfa6bd 100644 --- a/pytensor/tensor/__init__.py +++ b/pytensor/tensor/__init__.py @@ -130,6 +130,7 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int: from pytensor.tensor.extra_ops import * from pytensor.tensor.io import * from pytensor.tensor.math import * +from pytensor.tensor.pad import pad from pytensor.tensor.shape import ( reshape, shape, diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index b1eaf4f001..cf809a55ef 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -1,3 +1,4 @@ +import warnings from collections.abc import Collection, Iterable import numpy as np @@ -20,14 +21,24 @@ from pytensor.raise_op import Assert from pytensor.scalar import int32 as int_t from pytensor.scalar import upcast -from pytensor.tensor import as_tensor_variable +from pytensor.tensor import TensorLike, as_tensor_variable from pytensor.tensor import basic as ptb from pytensor.tensor.basic import alloc, second from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.math import abs as pt_abs from pytensor.tensor.math import all as pt_all from pytensor.tensor.math import eq as pt_eq -from pytensor.tensor.math import ge, lt, maximum, minimum, prod, switch +from pytensor.tensor.math import ( + ge, + gt, + log, + lt, + maximum, + minimum, + prod, + 
sign, + switch, +) from pytensor.tensor.math import max as pt_max from pytensor.tensor.math import sum as pt_sum from pytensor.tensor.shape import specify_broadcastable @@ -1584,27 +1595,346 @@ def broadcast_shape_iter( return tuple(result_dims) -def geomspace(start, end, steps, base=10.0): - from pytensor.tensor.math import log +def _check_deprecated_inputs(stop, end, num, steps): + if end is not None: + warnings.warn( + "The 'end' parameter is deprecated and will be removed in a future version. Use 'stop' instead.", + DeprecationWarning, + ) + stop = end + if steps is not None: + warnings.warn( + "The 'steps' parameter is deprecated and will be removed in a future version. Use 'num' instead.", + DeprecationWarning, + ) + num = steps + + return stop, num + + +def _linspace_core( + start: TensorVariable, + stop: TensorVariable, + num: int, + endpoint=True, + retstep=False, + axis=0, +) -> TensorVariable | tuple[TensorVariable, TensorVariable]: + div = (num - 1) if endpoint else num + delta = stop - start + samples = ptb.shape_padright(ptb.arange(0, num), delta.ndim) + + step = delta / div + samples = switch(gt(div, 0), samples * delta / div + start, samples * delta + start) + if endpoint: + samples = switch(gt(num, 1), set_subtensor(samples[-1, ...], stop), samples) + + if axis != 0: + samples = ptb.moveaxis(samples, 0, axis) + + if retstep: + return samples, step + + return samples + + +def _broadcast_base_with_inputs(start, stop, base, axis): + """ + Broadcast the base tensor with the start and stop tensors if base is not a scalar. This is important because it + may change how the axis argument is interpreted in the final output. + + Parameters + ---------- + start: TensorVariable + The start value(s) of the sequence(s). + stop: TensorVariable + The end value(s) of the sequence(s) + base: TensorVariable + The log base value(s) of the sequence(s) + axis: int + The axis along which to generate samples. + + Returns + ------- + start: TensorVariable + The start value(s) of the sequence(s), broadcast with the base tensor if necessary. + stop: TensorVariable + The end value(s) of the sequence(s), broadcast with the base tensor if necessary. + base: TensorVariable + The log base value(s) of the sequence(s), broadcast with the start and stop tensors if necessary. + """ + base = ptb.as_tensor_variable(base) + if base.ndim > 0: + ndmax = len(broadcast_shape(start, stop, base)) + start, stop, base = ( + ptb.shape_padleft(a, ndmax - a.ndim) for a in (start, stop, base) + ) + base = ptb.expand_dims(base, axis=(axis,)) + + return start, stop, base + + +def linspace( + start: TensorLike, + stop: TensorLike, + num: TensorLike = 50, + endpoint: bool = True, + retstep: bool = False, + dtype: str | None = None, + axis: int = 0, + end: TensorLike | None = None, + steps: TensorLike | None = None, +) -> TensorVariable | tuple[TensorVariable, TensorVariable]: + """ + Return evenly spaced numbers over a specified interval. + + Returns `num` evenly spaced samples, calculated over the interval [`start`, `stop`]. + + The endpoint of the interval can optionally be excluded. + + Parameters + ---------- + start: int, float, or TensorVariable + The starting value of the sequence. + + stop: int, float or TensorVariable + The end value of the sequence, unless `endpoint` is set to False. + In that case, the sequence consists of all but the last of `num + 1` evenly spaced samples, such that `stop` is excluded. + + num: int + Number of samples to generate. Must be non-negative. 
-    start = ptb.as_tensor_variable(start)
-    end = ptb.as_tensor_variable(end)
-    return base ** linspace(log(start) / log(base), log(end) / log(base), steps)
+    endpoint: bool
+        Whether to include the endpoint in the range.
+
+    retstep: bool
+        If True, returns both the samples and an array of steps between samples.
 
-def logspace(start, end, steps, base=10.0):
-    start = ptb.as_tensor_variable(start)
-    end = ptb.as_tensor_variable(end)
-    return base ** linspace(start, end, steps)
+    dtype: str, optional
+        dtype of the output tensor(s). If None, the output dtype defaults to ``pytensor.config.floatX``.
+
+    axis: int
+        Axis along which to generate samples. Ignored if both `start` and `stop` have dimension 0. By default, axis=0
+        will insert the samples on a new left-most dimension. To insert samples on a right-most dimension, use axis=-1.
+
+    end: int, float or TensorVariable
+        .. warning::
+            The "end" parameter is deprecated and will be removed in a future version. Use "stop" instead.
+            The end value of the sequence, unless `endpoint` is set to False.
+            In that case, the sequence consists of all but the last of `num + 1` evenly spaced samples, such that `end` is
+            excluded.
+
+    steps: float, int, or TensorVariable
+        .. warning::
+            The "steps" parameter is deprecated and will be removed in a future version. Use "num" instead.
+
+        Number of samples to generate. Must be non-negative.
+
+    Returns
+    -------
+    samples: TensorVariable
+        Tensor containing `num` evenly-spaced values between [start, stop]. The range is inclusive if `endpoint` is True.
+
+    step: TensorVariable
+        Tensor containing the spacing between samples. Only returned if `retstep` is True.
+    """
+    if dtype is None:
+        dtype = pytensor.config.floatX
+    stop, num = _check_deprecated_inputs(stop, end, num, steps)
+    start, stop = broadcast_arrays(start, stop)
+
+    ls = _linspace_core(
+        start=start,
+        stop=stop,
+        num=num,
+        endpoint=endpoint,
+        retstep=retstep,
+        axis=axis,
+    )
+
+    if retstep:
+        samples, step = ls
+        return samples.astype(dtype), step
+
+    return ls.astype(dtype)
+
+
+def geomspace(
+    start: TensorLike,
+    stop: TensorLike,
+    num: int = 50,
+    base: float = 10.0,
+    endpoint: bool = True,
+    dtype: str | None = None,
+    axis: int = 0,
+    end: TensorLike | None = None,
+    steps: TensorLike | None = None,
+) -> TensorVariable:
+    """
+    Return numbers spaced evenly on a log scale (a geometric progression).
+
+    This is similar to logspace, but with endpoints specified directly. Each output sample is a constant multiple of
+    the previous.
+
+    Returns `num` evenly spaced samples, calculated over the interval [`start`, `stop`].
+
+    The endpoint of the interval can optionally be excluded.
+
+    Parameters
+    ----------
+    start: int, float, or TensorVariable
+        The starting value of the sequence.
+
+    stop: int, float or TensorVariable
+        The end value of the sequence, unless `endpoint` is set to False.
+        In that case, the sequence consists of all but the last of `num + 1` evenly spaced samples, such that `stop` is excluded.
+
+    num: int
+        Number of samples to generate. Must be non-negative.
+
+    base: float
+        The base of the log space.
+
+    endpoint: bool
+        Whether to include the endpoint in the range.
+
+    dtype: str, optional
+        dtype of the output tensor(s). If None, the output dtype defaults to ``pytensor.config.floatX``.
+
+    axis: int
+        Axis along which to generate samples. Ignored if both `start` and `stop` have dimension 0. By default, axis=0
+        will insert the samples on a new left-most dimension.
+        To insert samples on a right-most dimension, use axis=-1.
+
+    end: int, float or TensorVariable
+        .. warning::
+            The "end" parameter is deprecated and will be removed in a future version. Use "stop" instead.
+            The end value of the sequence, unless `endpoint` is set to False.
+            In that case, the sequence consists of all but the last of `num + 1` evenly spaced samples, such that `end` is
+            excluded.
+
+    steps: float, int, or TensorVariable
+        .. warning::
+            The "steps" parameter is deprecated and will be removed in a future version. Use "num" instead.
+
+        Number of samples to generate. Must be non-negative.
+
+    Returns
+    -------
+    samples: TensorVariable
+        Tensor containing `num` evenly-spaced (in log space) values between [start, stop]. The range is inclusive if
+        `endpoint` is True.
+    """
+    if dtype is None:
+        dtype = pytensor.config.floatX
+    stop, num = _check_deprecated_inputs(stop, end, num, steps)
+    start, stop = broadcast_arrays(start, stop)
+    start, stop, base = _broadcast_base_with_inputs(start, stop, base, axis)
+
+    out_sign = sign(start)
+    log_start, log_stop = (
+        log(start * out_sign) / log(base),
+        log(stop * out_sign) / log(base),
+    )
+    result = _linspace_core(
+        start=log_start,
+        stop=log_stop,
+        num=num,
+        endpoint=endpoint,
+        axis=0,
+        retstep=False,
+    )
+    result = base**result
+
+    result = switch(gt(num, 0), set_subtensor(result[0, ...], start), result)
+    if endpoint:
+        result = switch(gt(num, 1), set_subtensor(result[-1, ...], stop), result)
+
+    result = result * out_sign
+
+    if axis != 0:
+        result = ptb.moveaxis(result, 0, axis)
+
+    return result.astype(dtype)
+
+
+def logspace(
+    start: TensorLike,
+    stop: TensorLike,
+    num: int = 50,
+    base: float = 10.0,
+    endpoint: bool = True,
+    dtype: str | None = None,
+    axis: int = 0,
+    end: TensorLike | None = None,
+    steps: TensorLike | None = None,
+) -> TensorVariable:
+    """
+    Return numbers spaced evenly on a log scale.
+
+    In linear space, the sequence starts at ``base ** start`` (base to the power of start) and ends with ``base ** stop``
+    (see ``endpoint`` below).
+
+    Parameters
+    ----------
+    start: int, float, or TensorVariable
+        ``base ** start`` is the starting value of the sequence.
+
+    stop: int, float or TensorVariable
+        ``base ** stop`` is the endpoint of the sequence, unless ``endpoint`` is set to False.
+        In that case, ``num + 1`` values are spaced over the interval in log-space, and the first ``num`` are returned.
+
+    num: int, default = 50
+        Number of samples to generate.
+
+    base: float, default = 10.0
+        The base of the log space. The step size between the elements in ``log(samples) / log(base)``
+        (or ``log_base(samples)``) is uniform.
+
+    endpoint: bool
+        Whether to include the endpoint in the range.
+
+    dtype: str, optional
+        dtype of the output tensor(s). If None, the output dtype defaults to ``pytensor.config.floatX``.
+
+    axis: int
+        Axis along which to generate samples. Ignored if both `start` and `stop` have dimension 0. By default, axis=0
+        will insert the samples on a new left-most dimension. To insert samples on a right-most dimension, use axis=-1.
+
+    end: int, float or TensorVariable
+        .. warning::
+            The "end" parameter is deprecated and will be removed in a future version. Use "stop" instead.
+            The end value of the sequence, unless `endpoint` is set to False.
+            In that case, the sequence consists of all but the last of `num + 1` evenly spaced samples, such that `end` is
+            excluded.
+
+    steps: int or TensorVariable
+        .. 
warning:: + The "steps" parameter is deprecated and will be removed in a future version. Use "num" instead. + Number of samples to generate. Must be non-negative + + Returns + ------- + samples: TensorVariable + Tensor containing `num` evenly-spaced (in log-pace) values between [start, stop]. The range is inclusive if + `endpoint` is True. + """ + if dtype is None: + dtype = pytensor.config.floatX + stop, num = _check_deprecated_inputs(stop, end, num, steps) + start, stop = broadcast_arrays(start, stop) + start, stop, base = _broadcast_base_with_inputs(start, stop, base, axis) + + ls = _linspace_core( + start=start, + stop=stop, + num=num, + endpoint=endpoint, + axis=axis, + retstep=False, + ) -def linspace(start, end, steps): - start = ptb.as_tensor_variable(start) - end = ptb.as_tensor_variable(end) - arr = ptb.arange(steps) - arr = ptb.shape_padright(arr, max(start.ndim, end.ndim)) - multiplier = (end - start) / (steps - 1) - return start + arr * multiplier + return (base**ls).astype(dtype) def broadcast_to( diff --git a/pytensor/tensor/pad.py b/pytensor/tensor/pad.py new file mode 100644 index 0000000000..91aef44004 --- /dev/null +++ b/pytensor/tensor/pad.py @@ -0,0 +1,690 @@ +from collections.abc import Callable +from functools import partial +from typing import Literal, cast + +from pytensor.compile.builders import OpFromGraph +from pytensor.ifelse import ifelse +from pytensor.scan import scan +from pytensor.tensor import TensorLike +from pytensor.tensor.basic import ( + TensorVariable, + as_tensor, + concatenate, + expand_dims, + moveaxis, + switch, + zeros, +) +from pytensor.tensor.extra_ops import broadcast_to, linspace +from pytensor.tensor.math import divmod as pt_divmod +from pytensor.tensor.math import eq, gt, mean, minimum +from pytensor.tensor.math import max as pt_max +from pytensor.tensor.math import min as pt_min +from pytensor.tensor.shape import specify_broadcastable +from pytensor.tensor.subtensor import flip, set_subtensor, slice_at_axis + + +PadMode = Literal[ + "constant", + "edge", + "linear_ramp", + "maximum", + "minimum", + "mean", + "median", + "wrap", + "symmetric", + "reflect", +] +stat_funcs = {"maximum": pt_max, "minimum": pt_min, "mean": mean} + +allowed_kwargs = { + "edge": [], + "wrap": [], + "constant": ["constant_values"], + "linear_ramp": ["end_values"], + "maximum": ["stat_length"], + "mean": ["stat_length"], + "median": ["stat_length"], + "minimum": ["stat_length"], + "reflect": ["reflect_type"], + "symmetric": ["reflect_type"], +} + + +def _get_edges( + padded: TensorVariable, axis: int, width_pair: tuple[TensorVariable, TensorVariable] +) -> tuple[TensorVariable, TensorVariable]: + """ + Retrieve edge values from empty-padded array in given dimension. + + Copied from numpy.lib.arraypad._get_edges + https://github.com/numpy/numpy/blob/300096d384046eee479b0c7a70f79e308da52bff/numpy/lib/_arraypad_impl.py#L154 + + Parameters + ---------- + padded : TensorVariable + Empty-padded array. + axis : int + Dimension in which the edges are considered. + width_pair : (TensorVariable, TensorVariable) + Pair of widths that mark the pad area on both sides in the given + dimension. + + Returns + ------- + left_edge, right_edge : TensorVariable + Edge values of the valid area in `padded` in the given dimension. Its + shape will always match `padded` except for the dimension given by + `axis` which will have a length of 1. 
+ """ + left_index = width_pair[0] + left_slice = slice_at_axis(slice(left_index, left_index + 1), axis) + left_edge = padded[left_slice] + + right_index = padded.shape[axis] - width_pair[1] + right_slice = slice_at_axis(slice(right_index - 1, right_index), axis) + right_edge = padded[right_slice] + + return left_edge, right_edge + + +def _symbolic_pad( + x: TensorVariable, pad_width: TensorVariable +) -> tuple[TensorVariable, tuple[slice, ...], TensorVariable]: + pad_width = broadcast_to(pad_width, as_tensor((x.ndim, 2))) + new_shape = as_tensor( + [pad_width[i][0] + size + pad_width[i][1] for i, size in enumerate(x.shape)] + ) + original_area_slice = tuple( + slice(pad_width[i][0], pad_width[i][0] + size) for i, size in enumerate(x.shape) + ) + padded: TensorVariable = set_subtensor(zeros(new_shape)[original_area_slice], x) + return padded, original_area_slice, pad_width + + +def _get_padding_slices( + dim_shape: TensorVariable, + width_pair: tuple[TensorVariable, TensorVariable], + axis: int, +) -> tuple[tuple[slice, ...], tuple[slice, ...]]: + left_slice = slice_at_axis(slice(None, width_pair[0]), axis) + right_slice = slice_at_axis(slice(dim_shape - width_pair[1], None), axis) + + return left_slice, right_slice + + +def _constant_pad( + x: TensorVariable, pad_width: TensorVariable, constant_values: TensorVariable +) -> TensorVariable: + padded, area_slice, pad_width = _symbolic_pad(x, pad_width) + values = broadcast_to(constant_values, as_tensor((padded.ndim, 2))) + + for axis in range(padded.ndim): + width_pair = pad_width[axis] + value_pair = values[axis] + dim_shape = padded.shape[axis] + + left_slice, right_slice = _get_padding_slices(dim_shape, width_pair, axis) + padded = set_subtensor(padded[left_slice], value_pair[0]) + padded = set_subtensor(padded[right_slice], value_pair[1]) + + return padded + + +def _edge_pad(x: TensorVariable, pad_width: TensorVariable) -> TensorVariable: + padded, area_slice, pad_width = _symbolic_pad(x, pad_width) + for axis in range(padded.ndim): + width_pair = pad_width[axis] + dim_shape = padded.shape[axis] + + left_edge, right_edge = _get_edges(padded, axis, width_pair) + left_slice, right_slice = _get_padding_slices(dim_shape, width_pair, axis) + + padded = set_subtensor(padded[left_slice], left_edge) + padded = set_subtensor(padded[right_slice], right_edge) + + return padded + + +def _get_stats( + padded: TensorVariable, + axis: int, + width_pair: TensorVariable, + length_pair: tuple[TensorVariable, TensorVariable] | tuple[None, None], + stat_func: Callable, +): + """ + Calculate statistic for the empty-padded array in given dimension. + + Copied from numpy.lib.arraypad._get_stats + https://github.com/numpy/numpy/blob/300096d384046eee479b0c7a70f79e308da52bff/numpy/lib/_arraypad_impl.py#L230 + + Parameters + ---------- + padded : TensorVariable + Empty-padded array. + axis : int + Dimension in which the statistic is calculated. + width_pair : (TensorVariable, TensorVariable) + Pair of widths that mark the pad area on both sides in the given dimension. + length_pair : 2-element sequence of None or TensorVariable + Gives the number of values in valid area from each side that is taken into account when calculating the + statistic. If None the entire valid area in `padded` is considered. + stat_func : function + Function to compute statistic. The expected signature is + ``stat_func(x: TensorVariable, axis: int, keepdims: bool) -> TensorVariable``. 
+ + Returns + ------- + left_stat, right_stat : TensorVariable + Calculated statistic for both sides of `padded`. + """ + # Calculate indices of the edges of the area with original values + left_index = width_pair[0] + right_index = padded.shape[axis] - width_pair[1] + # as well as its length + max_length = right_index - left_index + + # Limit stat_lengths to max_length + left_length, right_length = length_pair + + # Calculate statistic for the left side + left_length = ( + minimum(left_length, max_length) if left_length is not None else max_length + ) + left_slice = slice_at_axis(slice(left_index, left_index + left_length), axis) + left_chunk = padded[left_slice] + left_stat = stat_func(left_chunk, axis=axis, keepdims=True) + if left_length is None and right_length is None: + # We could also return early in the more general case of left_length == right_length, but we don't necessarily + # know these shapes. + # TODO: Add rewrite to simplify in this case + return left_stat, left_stat + + # Calculate statistic for the right side + right_length = ( + minimum(right_length, max_length) if right_length is not None else max_length + ) + right_slice = slice_at_axis(slice(right_index - right_length, right_index), axis) + right_chunk = padded[right_slice] + right_stat = stat_func(right_chunk, axis=axis, keepdims=True) + + return left_stat, right_stat + + +def _stat_pad( + x: TensorVariable, + pad_width: TensorVariable, + stat_func: Callable, + stat_length: TensorVariable | None, +): + padded, area_slice, pad_width = _symbolic_pad(x, pad_width) + if stat_length is None: + stat_length = [[None, None]] * padded.ndim # type: ignore + else: + stat_length = broadcast_to(stat_length, as_tensor((padded.ndim, 2))) + + for axis in range(padded.ndim): + width_pair = pad_width[axis] + length_pair = stat_length[axis] # type: ignore + dim_shape = padded.shape[axis] + + left_stat, right_stat = _get_stats( + padded, axis, width_pair, length_pair, stat_func + ) + left_slice, right_slice = _get_padding_slices(dim_shape, width_pair, axis) + padded = set_subtensor(padded[left_slice], left_stat) + padded = set_subtensor(padded[right_slice], right_stat) + + return padded + + +def _linear_ramp_pad( + x: TensorVariable, pad_width: TensorVariable, end_values: TensorVariable | int = 0 +) -> TensorVariable: + padded, area_slice, pad_width = _symbolic_pad(x, pad_width) + end_values = as_tensor(end_values) + end_values = broadcast_to(end_values, as_tensor((padded.ndim, 2))) + + for axis in range(padded.ndim): + width_pair = pad_width[axis] + end_value_pair = end_values[axis] + edge_pair = _get_edges(padded, axis, width_pair) + dim_shape = padded.shape[axis] + left_slice, right_slice = _get_padding_slices(dim_shape, width_pair, axis) + + left_ramp, right_ramp = ( + linspace( + start=end_value, + stop=specify_broadcastable(edge, axis).squeeze(axis), + num=width, + endpoint=False, + dtype=padded.dtype, + axis=axis, + ) + for end_value, edge, width in zip(end_value_pair, edge_pair, width_pair) + ) + + # Reverse the direction of the ramp for the "right" side + right_ramp = right_ramp[slice_at_axis(slice(None, None, -1), axis)] # type: ignore + + padded = set_subtensor(padded[left_slice], left_ramp) + padded = set_subtensor(padded[right_slice], right_ramp) + + return padded + + +def _wrap_pad(x: TensorVariable, pad_width: TensorVariable) -> TensorVariable: + pad_width = broadcast_to(pad_width, as_tensor((x.ndim, 2))) + + for axis in range(x.ndim): + size = x.shape[axis] + + # Compute how many complete copies of the input will be 
padded on this dimension, along with the amount of + # overflow on the final copy + repeats, (left_remainder, right_remainder) = pt_divmod(pad_width[axis], size) + + # In the next step we will generate extra copies of the input, and then trim them down to the correct size. + left_trim = size - left_remainder + right_trim = size - right_remainder + + # The total number of copies needed is always the sum of the number of complete copies to add, plus the original + # input itself, plus the two edge copies that will be trimmed down. + total_repeats = repeats.sum() + 3 + + # Create a batch dimension and clone the input the required number of times + parts = expand_dims(x, (0,)).repeat(total_repeats, axis=0) + + # Move the batch dimension to the active dimension + parts = moveaxis(parts, 0, axis) + + # Ravel the active dimension while preserving the shapes of the inactive dimensions. This will expand the + # active dimension to have the correctly padded shape, plus excess to be trimmed + new_shape = [-1 if i == axis else x.shape[i] for i in range(x.ndim)] + x = parts.reshape(new_shape) + + # Trim the excess on the active dimension + trim_slice = slice_at_axis(slice(left_trim, -right_trim), axis) + x = x[trim_slice] + + return x + + +def _build_padding_one_direction(array, array_flipped, repeats, *, inner_func, axis): + [_, parts], _ = scan( + inner_func, + non_sequences=[array, array_flipped], + outputs_info=[0, None], + n_steps=repeats, + ) + + parts = moveaxis(parts, 0, axis) + new_shape = [-1 if i == axis else array.shape[i] for i in range(array.ndim)] + padding = parts.reshape(new_shape) + + return padding + + +def _symmetric_pad(x, pad_width): + def _symmetric_inner(i, x, x_flipped, padding_left): + return i + 1, ifelse(eq(i % 2, int(padding_left)), x_flipped, x) + + pad_width = broadcast_to(pad_width, as_tensor((x.ndim, 2))) + + for axis in range(x.ndim): + x_flipped = flip(x, axis=axis) + original_size = x.shape[axis] + + repeats, remainders = pt_divmod(pad_width[axis], original_size) + has_remainder = gt(remainders, 0) + repeats = repeats + has_remainder + + left_padding = _build_padding_one_direction( + x, + x_flipped, + repeats[0], + axis=axis, + inner_func=partial(_symmetric_inner, padding_left=True), + ) + right_padding = _build_padding_one_direction( + x, + x_flipped, + repeats[1], + axis=axis, + inner_func=partial(_symmetric_inner, padding_left=False), + ) + + x = concatenate([flip(left_padding, axis), x, right_padding], axis=axis) + + (left_trim, right_trim) = switch( + has_remainder, original_size - remainders, remainders + ) + right_trim = x.shape[axis] - right_trim + + trim_slice = slice_at_axis(slice(left_trim, right_trim), axis) + x = x[trim_slice] + + return x + + +def _reflect_pad(x, pad_width): + def _reflect_inner(i, x, x_flipped, padding_left): + return i + 1, ifelse(eq(i % 2, int(padding_left)), x_flipped, x) + + pad_width = broadcast_to(pad_width, as_tensor((x.ndim, 2))) + for axis in range(x.ndim): + trimmed_size = x.shape[axis] - 1 + + trim_slice = slice_at_axis(slice(None, -1), axis) + x_trimmed = x[trim_slice] + x_flipped = flip(x, axis=axis)[trim_slice] + + repeats, remainders = pt_divmod(pad_width[axis], trimmed_size) + repeats = repeats + 1 + + left_padding = _build_padding_one_direction( + x_trimmed, + x_flipped, + repeats[0], + axis=axis, + inner_func=partial(_reflect_inner, padding_left=True), + ) + right_padding = _build_padding_one_direction( + x_trimmed, + x_flipped, + repeats[1], + axis=axis, + inner_func=partial(_reflect_inner, padding_left=False), + ) + 
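+        # Each block produced above holds `repeats` whole reflected copies of the axis, so it is
+        # generally longer than the requested pad width. The slices below trim both blocks down to
+        # exactly pad_width[axis] entries before they are concatenated onto the original tensor.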
+ left_trim = slice_at_axis(slice(trimmed_size - remainders[0] - 1, -1), axis) + right_trim = slice_at_axis( + slice(1, right_padding.shape[axis] - trimmed_size + remainders[1] + 1), axis + ) + + x = concatenate( + [flip(left_padding, axis)[left_trim], x, right_padding[right_trim]], + axis=axis, + ) + return x + + +class Pad(OpFromGraph): + """ + Wrapper Op for Pad graphs + """ + + def __init__( + self, inputs, outputs, pad_mode, reflect_type=None, has_stat_length=False + ): + self.pad_mode = pad_mode + self.reflect_type = reflect_type + self.has_stat_length = has_stat_length + + super().__init__(inputs=inputs, outputs=outputs) + + +def pad( + x: TensorLike, pad_width: TensorLike, mode: PadMode = "constant", **kwargs +) -> TensorVariable: + """ + Pad an array. + + Parameters + ---------- + array : array_like of rank N + The array to pad. + + pad_width : sequence, array_like, or int + Number of values padded to the edges of each axis. + ``((before_1, after_1), ... (before_N, after_N))`` unique pad widths + for each axis. + ``(before, after)`` or ``((before, after),)`` yields same before + and after pad for each axis. + ``(pad,)`` or ``int`` is a shortcut for before = after = pad width + for all axes. + + mode : str or function, optional + One of the following string values or a user supplied function. + + 'constant' (default) + Pads with a constant value. + 'edge' + Pads with the edge values of array. + 'linear_ramp' + Pads with the linear ramp between end_value and the + array edge value. + 'maximum' + Pads with the maximum value of all or part of the + vector along each axis. + 'mean' + Pads with the mean value of all or part of the + vector along each axis. + 'minimum' + Pads with the minimum value of all or part of the + vector along each axis. + 'reflect' + Pads with the reflection of the vector mirrored on + the first and last values of the vector along each + axis. + 'symmetric' + Pads with the reflection of the vector mirrored + along the edge of the array. + 'wrap' + Pads with the wrap of the vector along the axis. + The first values are used to pad the end and the + end values are used to pad the beginning. + + stat_length : sequence or int, optional + Used in 'maximum', 'mean', and 'minimum'. Number of + values at edge of each axis used to calculate the statistic value. + + ``((before_1, after_1), ... (before_N, after_N))`` unique statistic + lengths for each axis. + + ``(before, after)`` or ``((before, after),)`` yields same before + and after statistic lengths for each axis. + + ``(stat_length,)`` or ``int`` is a shortcut for + ``before = after = statistic`` length for all axes. + + Default is ``None``, to use the entire axis. + + constant_values : sequence or scalar, optional + Used in 'constant'. The values to set the padded values for each + axis. + + ``((before_1, after_1), ... (before_N, after_N))`` unique pad constants + for each axis. + + ``(before, after)`` or ``((before, after),)`` yields same before + and after constants for each axis. + + ``(constant,)`` or ``constant`` is a shortcut for + ``before = after = constant`` for all axes. + + Default is 0. + + end_values : sequence or scalar, optional + Used in 'linear_ramp'. The values used for the ending value of the + linear_ramp and that will form the edge of the padded array. + + ``((before_1, after_1), ... (before_N, after_N))`` unique end values + for each axis. + + ``(before, after)`` or ``((before, after),)`` yields same before + and after end values for each axis. 
+ + ``(constant,)`` or ``constant`` is a shortcut for + ``before = after = constant`` for all axes. + + Default is 0. + + reflect_type : str, optional + Only 'even' is currently accepted. Used in 'reflect', and 'symmetric'. The 'even' style is the + default with an unaltered reflection around the edge value. + + Returns + ------- + pad : ndarray + Padded array of rank equal to `array` with shape increased + according to `pad_width`. + + Examples + -------- + + .. testcode:: + + import pytensor.tensor as pt + a = [1, 2, 3, 4, 5] + print(pt.pad(a, (2, 3), 'constant', constant_values=(4, 6)).eval()) + + .. testoutput:: + + [4. 4. 1. 2. 3. 4. 5. 6. 6. 6.] + + .. testcode:: + + print(pt.pad(a, (2, 3), 'edge').eval()) + + .. testoutput:: + + [1. 1. 1. 2. 3. 4. 5. 5. 5. 5.] + + .. testcode:: + + print(pt.pad(a, (2, 3), 'linear_ramp', end_values=(5, -4)).eval()) + + .. testoutput:: + + [ 5. 3. 1. 2. 3. 4. 5. 2. -1. -4.] + + .. testcode:: + + print(pt.pad(a, (2,), 'maximum').eval()) + + .. testoutput:: + + [5. 5. 1. 2. 3. 4. 5. 5. 5.] + + .. testcode:: + + print(pt.pad(a, (2,), 'mean').eval()) + + .. testoutput:: + + [3. 3. 1. 2. 3. 4. 5. 3. 3.] + + .. testcode:: + + a = [[1, 2], [3, 4]] + print(pt.pad(a, ((3, 2), (2, 3)), 'minimum').eval()) + + .. testoutput:: + + [[1. 1. 1. 2. 1. 1. 1.] + [1. 1. 1. 2. 1. 1. 1.] + [1. 1. 1. 2. 1. 1. 1.] + [1. 1. 1. 2. 1. 1. 1.] + [3. 3. 3. 4. 3. 3. 3.] + [1. 1. 1. 2. 1. 1. 1.] + [1. 1. 1. 2. 1. 1. 1.]] + + .. testcode:: + + a = [1, 2, 3, 4, 5] + print(pt.pad(a, (2, 3), 'reflect').eval()) + + .. testoutput:: + + [3 2 1 2 3 4 5 4 3 2] + + .. testcode:: + + print(pt.pad(a, (2, 3), 'symmetric').eval()) + + .. testoutput:: + + [2 1 1 2 3 4 5 5 4 3] + + .. testcode:: + + print(pt.pad(a, (2, 3), 'wrap').eval()) + + .. testoutput:: + + [4 5 1 2 3 4 5 1 2 3] + + """ + if any(value not in allowed_kwargs[mode] for value in kwargs.keys()): + raise ValueError( + f"Invalid keyword arguments for mode '{mode}': {kwargs.keys()}" + ) + x = as_tensor(x, name="x") + pad_width = as_tensor(pad_width, name="pad_width") + inputs = [x, pad_width] + attrs = {} + + if mode == "constant": + constant_values = as_tensor( + kwargs.pop("constant_values", 0), name="constant_values" + ) + inputs += [constant_values] + outputs = _constant_pad(x, pad_width, constant_values) + + elif mode == "edge": + outputs = _edge_pad(x, pad_width) + + elif mode in ["maximum", "minimum", "mean", "median"]: + if mode == "median": + # TODO: Revisit this after we implement a quantile function. + # See https://github.com/pymc-devs/pytensor/issues/53 + raise NotImplementedError("Median padding not implemented") + stat_func = cast(Callable, stat_funcs[mode]) + stat_length = kwargs.get("stat_length") + if stat_length is not None: + attrs.update({"has_stat_length": True}) + stat_length = as_tensor(stat_length, name="stat_length") + inputs += [stat_length] + + outputs = _stat_pad(x, pad_width, stat_func, stat_length) + + elif mode == "linear_ramp": + end_values = kwargs.pop("end_values", 0) + end_values = as_tensor(end_values) + + inputs += [end_values] + outputs = _linear_ramp_pad(x, pad_width, end_values) + + elif mode == "wrap": + outputs = _wrap_pad(x, pad_width) + + elif mode == "symmetric": + reflect_type = kwargs.pop("reflect_type", "even") + if reflect_type == "odd": + raise NotImplementedError( + "Odd reflection not implemented. 
If you need this feature, please open an "
+                "issue at https://github.com/pymc-devs/pytensor/issues"
+            )
+        attrs.update({"reflect_type": reflect_type})
+        outputs = _symmetric_pad(x, pad_width)
+
+    elif mode == "reflect":
+        reflect_type = kwargs.pop("reflect_type", "even")
+        if reflect_type == "odd":
+            raise NotImplementedError(
+                "Odd reflection not implemented. If you need this feature, please open an "
+                "issue at https://github.com/pymc-devs/pytensor/issues"
+            )
+        attrs.update({"reflect_type": reflect_type})
+        outputs = _reflect_pad(x, pad_width)
+
+    else:
+        raise ValueError(f"Invalid mode: {mode}")
+
+    op = Pad(inputs=inputs, outputs=[outputs], pad_mode=mode, **attrs)(*inputs)
+    return cast(TensorVariable, op)
+
+
+__all__ = ["pad", "flip"]
diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py
index a21f2d7dcc..41b4c6bd5a 100644
--- a/pytensor/tensor/subtensor.py
+++ b/pytensor/tensor/subtensor.py
@@ -3013,8 +3013,123 @@ def _get_vector_length_Subtensor(op, var):
         raise ValueError(f"Length of {var} cannot be determined")
 
 
+def slice_at_axis(sl: slice, axis: int) -> tuple[slice, ...]:
+    """
+    Construct tuple of slices to slice an array in the given dimension.
+
+    Copied from numpy.lib.arraypad._slice_at_axis
+    https://github.com/numpy/numpy/blob/300096d384046eee479b0c7a70f79e308da52bff/numpy/lib/_arraypad_impl.py#L33
+
+    Parameters
+    ----------
+    sl : slice
+        The slice for the given dimension.
+    axis : int
+        The axis to which `sl` is applied. All other dimensions are left
+        "unsliced".
+
+    Returns
+    -------
+    sl : tuple of slices
+        A tuple with slices matching `shape` in length.
+
+    Examples
+    --------
+
+    .. testcode::
+
+        import pytensor.tensor as pt
+
+        s = pt.slice_at_axis(slice(None, 1), 1)
+        print(s)
+
+    .. testoutput::
+
+        (slice(None, None, None), slice(None, 1, None), Ellipsis)
+
+    .. testcode::
+
+        import numpy as np
+        import pytensor
+
+        x = pt.tensor('x', shape=(None, None, None))
+        x_sliced = x[s]
+
+        f = pytensor.function([x], x_sliced)
+        x = np.arange(27).reshape(3, 3, 3)
+        print(f(x))
+
+    .. testoutput::
+
+        [[[ 0.  1.  2.]]
+
+         [[ 9. 10. 11.]]
+
+         [[18. 19. 20.]]]
+
+    """
+    if axis >= 0:
+        return (slice(None),) * axis + (sl,) + (...,)  # type: ignore
+    else:
+        # For a negative axis, abs(axis) - 1 dimensions remain to the right of the sliced axis
+        axis = abs(axis) - 1
+        return (...,) + (sl,) + (slice(None),) * axis  # type: ignore
+
+
+def flip(
+    arr: TensorVariable, axis: int | tuple[int] | TensorVariable | None = None
+) -> TensorVariable:
+    """
+    Reverse the order of elements in a tensor along the given axis.
+
+    Parameters
+    ----------
+    arr: TensorVariable
+        Input tensor.
+
+    axis: int | tuple[int] | TensorVariable, optional
+        Axis or axes along which to flip over. The default is to flip over all of the axes of the input tensor.
+
+    Returns
+    -------
+    arr: TensorVariable
+        A view of `arr` with the entries along `axis` reversed.
+
+    Examples
+    --------
+
+    .. testcode::
+
+        import pytensor
+        import pytensor.tensor as pt
+
+        x = pt.tensor('x', shape=(None, None))
+        x_flipped = pt.flip(x, axis=0)
+
+        f = pytensor.function([x], x_flipped)
+        x = [[1, 2], [3, 4]]
+        print(f(x))
+
+    .. testoutput::
+
+        [[3. 4.]
+         [1. 
2.]] + + """ + if axis is None: + index = ((slice(None, None, -1)),) * arr.ndim + else: + if isinstance(axis, int): + axis = (axis,) + index = tuple( + [ + slice(None, None, -1) if i in axis else slice(None, None, None) + for i in range(arr.ndim) + ] + ) + + return cast(TensorVariable, arr[index]) + + __all__ = [ "take", + "flip", + "slice_at_axis", "inc_subtensor", "set_subtensor", ] diff --git a/tests/link/jax/test_pad.py b/tests/link/jax/test_pad.py new file mode 100644 index 0000000000..2321645741 --- /dev/null +++ b/tests/link/jax/test_pad.py @@ -0,0 +1,63 @@ +import numpy as np +import pytest + +import pytensor.tensor as pt +from pytensor import config +from pytensor.graph import FunctionGraph +from pytensor.tensor.pad import PadMode +from tests.link.jax.test_basic import compare_jax_and_py + + +jax = pytest.importorskip("jax") +floatX = config.floatX +RTOL = ATOL = 1e-6 if floatX.endswith("64") else 1e-3 + + +@pytest.mark.parametrize( + "mode, kwargs", + [ + ("constant", {"constant_values": 0}), + ("constant", {"constant_values": (1, 2)}), + ("edge", {}), + ("linear_ramp", {"end_values": 0}), + ("linear_ramp", {"end_values": (1, 2)}), + ("reflect", {"reflect_type": "even"}), + ("wrap", {}), + ("symmetric", {"reflect_type": "even"}), + ("mean", {"stat_length": None}), + ("mean", {"stat_length": (10, 2)}), + ("maximum", {"stat_length": None}), + ("maximum", {"stat_length": (10, 2)}), + ("minimum", {"stat_length": None}), + ("minimum", {"stat_length": (10, 2)}), + ], + ids=[ + "constant_default", + "constant_tuple", + "edge", + "linear_ramp_default", + "linear_ramp_tuple", + "reflect", + "wrap", + "symmetric", + "mean_default", + "mean_tuple", + "maximum_default", + "maximum_tuple", + "minimum_default", + "minimum_tuple", + ], +) +def test_jax_pad(mode: PadMode, kwargs): + x_pt = pt.tensor("x", shape=(3, 3)) + x = np.random.normal(size=(3, 3)) + + res = pt.pad(x_pt, mode=mode, pad_width=3, **kwargs) + res_fg = FunctionGraph([x_pt], [res]) + + compare_jax_and_py( + res_fg, + [x], + assert_fn=lambda x, y: np.testing.assert_allclose(x, y, rtol=RTOL, atol=ATOL), + py_mode="FAST_RUN", + ) diff --git a/tests/link/numba/test_pad.py b/tests/link/numba/test_pad.py new file mode 100644 index 0000000000..11877594d7 --- /dev/null +++ b/tests/link/numba/test_pad.py @@ -0,0 +1,68 @@ +import numpy as np +import pytest + +import pytensor.tensor as pt +from pytensor import config +from pytensor.graph import FunctionGraph +from pytensor.tensor.pad import PadMode +from tests.link.numba.test_basic import compare_numba_and_py + + +floatX = config.floatX +RTOL = ATOL = 1e-6 if floatX.endswith("64") else 1e-3 + + +@pytest.mark.parametrize( + "mode, kwargs", + [ + ("constant", {"constant_values": 0}), + ("constant", {"constant_values": (1, 2)}), + pytest.param( + "edge", + {}, + marks=pytest.mark.skip( + "This is causing a segfault in NUMBA mode, but I have no idea why" + ), + ), + ("linear_ramp", {"end_values": 0}), + ("linear_ramp", {"end_values": (1, 2)}), + ("reflect", {"reflect_type": "even"}), + ("wrap", {}), + ("symmetric", {"reflect_type": "even"}), + ("mean", {"stat_length": None}), + ("mean", {"stat_length": (10, 2)}), + ("maximum", {"stat_length": None}), + ("maximum", {"stat_length": (10, 2)}), + ("minimum", {"stat_length": None}), + ("minimum", {"stat_length": (10, 2)}), + ], + ids=[ + "constant_default", + "constant_tuple", + "edge", + "linear_ramp_default", + "linear_ramp_tuple", + "reflect", + "wrap", + "symmetric", + "mean_default", + "mean_tuple", + "maximum_default", + "maximum_tuple", + 
"minimum_default", + "minimum_tuple", + ], +) +def test_numba_pad(mode: PadMode, kwargs): + x_pt = pt.tensor("x", shape=(3, 3)) + x = np.random.normal(size=(3, 3)) + + res = pt.pad(x_pt, mode=mode, pad_width=3, **kwargs) + res_fg = FunctionGraph([x_pt], [res]) + + compare_numba_and_py( + res_fg, + [x], + assert_fn=lambda x, y: np.testing.assert_allclose(x, y, rtol=RTOL, atol=ATOL), + py_mode="FAST_RUN", + ) diff --git a/tests/tensor/test_extra_ops.py b/tests/tensor/test_extra_ops.py index 4376ab1d32..3b3cc5ec7f 100644 --- a/tests/tensor/test_extra_ops.py +++ b/tests/tensor/test_extra_ops.py @@ -35,9 +35,6 @@ diff, fill_diagonal, fill_diagonal_offset, - geomspace, - linspace, - logspace, ravel_multi_index, repeat, searchsorted, @@ -1281,25 +1278,37 @@ def test_broadcast_arrays(): @pytest.mark.parametrize( - "start, stop, num_samples", + "op", + ["linspace", "logspace", "geomspace"], + ids=["linspace", "logspace", "geomspace"], +) +@pytest.mark.parametrize("dtype", [None, "int", "float"], ids=[None, "int", "float"]) +@pytest.mark.parametrize( + "start, stop, num_samples, endpoint, axis", [ - (1, 10, 50), - (np.array([5, 6]), np.array([[10, 10], [10, 10]]), 25), - (1, np.array([5, 6]), 30), + (1, 10, 50, True, 0), + (1, 10, 1, True, 0), + (np.array([5, 6]), np.array([[10, 10], [10, 10]]), 25, True, 0), + (np.array([5, 6]), np.array([[10, 10], [10, 10]]), 25, True, 1), + (np.array([5, 6]), np.array([[10, 10], [10, 10]]), 25, False, -1), + (1, np.array([5, 6]), 30, True, 0), + (1, np.array([5, 6]), 30, False, -1), ], ) -def test_space_ops(start, stop, num_samples): - z = linspace(start, stop, num_samples) - pytensor_res = function(inputs=[], outputs=z)() - numpy_res = np.linspace(start, stop, num=num_samples) - assert np.allclose(pytensor_res, numpy_res) - - z = logspace(start, stop, num_samples) - pytensor_res = function(inputs=[], outputs=z)() - numpy_res = np.logspace(start, stop, num=num_samples) - assert np.allclose(pytensor_res, numpy_res) - - z = geomspace(start, stop, num_samples) - pytensor_res = function(inputs=[], outputs=z)() - numpy_res = np.geomspace(start, stop, num=num_samples) - assert np.allclose(pytensor_res, numpy_res) +def test_space_ops(op, dtype, start, stop, num_samples, endpoint, axis): + pt_func = getattr(pt, op) + np_func = getattr(np, op) + dtype = dtype + config.floatX[-2:] if dtype is not None else dtype + z = pt_func(start, stop, num_samples, endpoint=endpoint, axis=axis, dtype=dtype) + + numpy_res = np_func( + start, stop, num=num_samples, endpoint=endpoint, dtype=dtype, axis=axis + ) + pytensor_res = function(inputs=[], outputs=z, mode="FAST_COMPILE")() + + np.testing.assert_allclose( + pytensor_res, + numpy_res, + atol=1e-6 if config.floatX.endswith("64") else 1e-4, + rtol=1e-6 if config.floatX.endswith("64") else 1e-4, + ) diff --git a/tests/tensor/test_pad.py b/tests/tensor/test_pad.py new file mode 100644 index 0000000000..54df4a12e1 --- /dev/null +++ b/tests/tensor/test_pad.py @@ -0,0 +1,224 @@ +from typing import Literal + +import numpy as np +import pytest + +import pytensor +from pytensor.tensor.pad import PadMode, pad + + +floatX = pytensor.config.floatX +RTOL = ATOL = 1e-8 if floatX.endswith("64") else 1e-4 + + +def test_unknown_mode_raises(): + x = np.random.normal(size=(3, 3)).astype(floatX) + with pytest.raises(ValueError, match="Invalid mode: unknown"): + pad(x, 1, mode="unknown") + + +@pytest.mark.parametrize( + "size", [(3,), (3, 3), (3, 3, 3)], ids=["1d", "2d square", "3d square"] +) +@pytest.mark.parametrize("constant", [0, 0.0], ids=["int", 
"float"]) +@pytest.mark.parametrize( + "pad_width", + [10, (10, 0), (0, 10)], + ids=["symmetrical", "asymmetrical_left", "asymmetric_right"], +) +def test_constant_pad( + size: tuple, constant: int | float, pad_width: int | tuple[int, ...] +): + x = np.random.normal(size=size).astype(floatX) + expected = np.pad(x, pad_width, mode="constant", constant_values=constant) + z = pad(x, pad_width, mode="constant", constant_values=constant) + assert z.owner.op.pad_mode == "constant" + + f = pytensor.function([], z, mode="FAST_COMPILE") + + np.testing.assert_allclose(expected, f(), atol=ATOL, rtol=RTOL) + + +@pytest.mark.parametrize( + "size", [(3,), (3, 3), (3, 5, 5)], ids=["1d", "2d square", "3d square"] +) +@pytest.mark.parametrize( + "pad_width", + [10, (10, 0), (0, 10)], + ids=["symmetrical", "asymmetrical_left", "asymmetric_right"], +) +def test_edge_pad(size: tuple, pad_width: int | tuple[int, ...]): + x = np.random.normal(size=size).astype(floatX) + expected = np.pad(x, pad_width, mode="edge") + z = pad(x, pad_width, mode="edge") + assert z.owner.op.pad_mode == "edge" + + f = pytensor.function([], z, mode="FAST_COMPILE") + + np.testing.assert_allclose(expected, f(), atol=ATOL, rtol=RTOL) + + +@pytest.mark.parametrize( + "size", [(3,), (3, 3), (3, 5, 5)], ids=["1d", "2d square", "3d square"] +) +@pytest.mark.parametrize( + "pad_width", + [10, (10, 0), (0, 10)], + ids=["symmetrical", "asymmetrical_left", "asymmetric_right"], +) +@pytest.mark.parametrize("end_values", [0, -1], ids=["0", "-1"]) +def test_linear_ramp_pad( + size: tuple, + pad_width: int | tuple[int, ...], + end_values: int | float | tuple[int | float, ...], +): + x = np.random.normal(size=size).astype(floatX) + expected = np.pad(x, pad_width, mode="linear_ramp", end_values=end_values) + z = pad(x, pad_width, mode="linear_ramp", end_values=end_values) + assert z.owner.op.pad_mode == "linear_ramp" + + f = pytensor.function([], z, mode="FAST_COMPILE") + + np.testing.assert_allclose(expected, f(), atol=ATOL, rtol=RTOL) + + +@pytest.mark.parametrize( + "size", [(3,), (3, 3), (3, 5, 5)], ids=["1d", "2d square", "3d square"] +) +@pytest.mark.parametrize( + "pad_width", + [10, (10, 0), (0, 10)], + ids=["symmetrical", "asymmetrical_left", "asymmetric_right"], +) +@pytest.mark.parametrize("stat", ["mean", "minimum", "maximum"]) +@pytest.mark.parametrize("stat_length", [None, 2]) +def test_stat_pad( + size: tuple, + pad_width: int | tuple[int, ...], + stat: PadMode, + stat_length: int | None, +): + x = np.random.normal(size=size).astype(floatX) + expected = np.pad(x, pad_width, mode=stat, stat_length=stat_length) + z = pad(x, pad_width, mode=stat, stat_length=stat_length) + assert z.owner.op.pad_mode == stat + + f = pytensor.function([], z, mode="FAST_COMPILE") + + np.testing.assert_allclose(expected, f(), atol=ATOL, rtol=RTOL) + + +@pytest.mark.parametrize( + "size", [(3,), (3, 3), (3, 5, 5)], ids=["1d", "2d square", "3d square"] +) +@pytest.mark.parametrize( + "pad_width", + [10, (10, 0), (0, 10)], + ids=["symmetrical", "asymmetrical_left", "asymmetric_right"], +) +def test_wrap_pad(size: tuple, pad_width: int | tuple[int, ...]): + x = np.random.normal(size=size).astype(floatX) + expected = np.pad(x, pad_width, mode="wrap") + z = pad(x, pad_width, mode="wrap") + assert z.owner.op.pad_mode == "wrap" + f = pytensor.function([], z, mode="FAST_COMPILE") + + np.testing.assert_allclose(expected, f(), atol=ATOL, rtol=RTOL) + + +@pytest.mark.parametrize( + "size", [(3,), (3, 3), (3, 5, 5)], ids=["1d", "2d square", "3d square"] +) 
+@pytest.mark.parametrize( + "pad_width", + [10, (10, 0), (0, 10)], + ids=["symmetrical", "asymmetrical_left", "asymmetric_right"], +) +@pytest.mark.parametrize( + "reflect_type", + ["even", pytest.param("odd", marks=pytest.mark.xfail(raises=NotImplementedError))], + ids=["even", "odd"], +) +def test_symmetric_pad( + size, + pad_width, + reflect_type: Literal["even", "odd"], +): + x = np.random.normal(size=size).astype(floatX) + expected = np.pad(x, pad_width, mode="symmetric", reflect_type=reflect_type) + z = pad(x, pad_width, mode="symmetric", reflect_type=reflect_type) + assert z.owner.op.pad_mode == "symmetric" + f = pytensor.function([], z, mode="FAST_COMPILE") + + np.testing.assert_allclose(expected, f(), atol=ATOL, rtol=RTOL) + + +@pytest.mark.parametrize( + "size", [(3,), (3, 3), (3, 5, 5)], ids=["1d", "2d square", "3d square"] +) +@pytest.mark.parametrize( + "pad_width", + [10, (10, 0), (0, 10)], + ids=["symmetrical", "asymmetrical_left", "asymmetric_right"], +) +@pytest.mark.parametrize( + "reflect_type", + ["even", pytest.param("odd", marks=pytest.mark.xfail(raises=NotImplementedError))], + ids=["even", "odd"], +) +def test_reflect_pad( + size, + pad_width, + reflect_type: Literal["even", "odd"], +): + x = np.random.normal(size=size).astype(floatX) + expected = np.pad(x, pad_width, mode="reflect", reflect_type=reflect_type) + z = pad(x, pad_width, mode="reflect", reflect_type=reflect_type) + assert z.owner.op.pad_mode == "reflect" + f = pytensor.function([], z, mode="FAST_COMPILE") + + np.testing.assert_allclose(expected, f(), atol=ATOL, rtol=RTOL) + + +@pytest.mark.parametrize( + "mode", + [ + "constant", + "edge", + "linear_ramp", + "wrap", + "symmetric", + "reflect", + "mean", + "maximum", + "minimum", + ], +) +@pytest.mark.parametrize("padding", ["symmetric", "asymmetric"]) +def test_nd_padding(mode, padding): + rng = np.random.default_rng() + n = rng.integers(3, 5) + if padding == "symmetric": + pad_width = [(i, i) for i in rng.integers(1, 5, size=n)] + stat_length = [(i, i) for i in rng.integers(1, 5, size=n)] + else: + pad_width = rng.integers(1, 5, size=(n, 2)).tolist() + stat_length = rng.integers(1, 5, size=(n, 2)).tolist() + + test_kwargs = { + "constant": {"constant_values": 0}, + "linear_ramp": {"end_values": 0}, + "maximum": {"stat_length": stat_length}, + "mean": {"stat_length": stat_length}, + "minimum": {"stat_length": stat_length}, + "reflect": {"reflect_type": "even"}, + "symmetric": {"reflect_type": "even"}, + } + + x = np.random.normal(size=(2,) * n).astype(floatX) + kwargs = test_kwargs.get(mode, {}) + expected = np.pad(x, pad_width, mode=mode, **kwargs) + z = pad(x, pad_width, mode=mode, **kwargs) + f = pytensor.function([], z, mode="FAST_COMPILE") + + np.testing.assert_allclose(expected, f(), atol=ATOL, rtol=RTOL) diff --git a/tests/tensor/test_subtensor.py b/tests/tensor/test_subtensor.py index 427287dcfd..d02880f543 100644 --- a/tests/tensor/test_subtensor.py +++ b/tests/tensor/test_subtensor.py @@ -37,11 +37,13 @@ advanced_subtensor1, as_index_literal, basic_shape, + flip, get_canonical_form_slice, inc_subtensor, index_vars_to_types, indexed_result_shape, set_subtensor, + slice_at_axis, take, ) from pytensor.tensor.type import ( @@ -2902,3 +2904,39 @@ def test_vectorize_adv_subtensor( vectorize_pt(x_test, idx_test), vectorize_np(x_test, idx_test), ) + + +def test_slice_at_axis(): + x = ptb.tensor("x", shape=(3, 4, 5)) + x_sliced = x[slice_at_axis(slice(None, 1), axis=0)] + assert x_sliced.type.shape == (1, 4, 5) + + # Negative axis + x_sliced = 
x[slice_at_axis(slice(None, 1), axis=-2)] + assert x_sliced.type.shape == (3, 1, 5) + + +@pytest.mark.parametrize( + "size", [(3,), (3, 3), (3, 5, 5)], ids=["1d", "2d square", "3d square"] +) +def test_flip(size: tuple[int]): + from itertools import combinations + + ATOL = RTOL = 1e-8 if config.floatX == "float64" else 1e-4 + + x = np.random.normal(size=size).astype(config.floatX) + x_pt = pytensor.tensor.tensor(shape=size, name="x") + expected = np.flip(x, axis=None) + z = flip(x_pt, axis=None) + f = pytensor.function([x_pt], z, mode="FAST_COMPILE") + np.testing.assert_allclose(expected, f(x), atol=ATOL, rtol=RTOL) + + # Test all combinations of axes + flip_options = [ + axes for i in range(1, x.ndim + 1) for axes in combinations(range(x.ndim), r=i) + ] + for axes in flip_options: + expected = np.flip(x, axis=list(axes)) + z = flip(x_pt, axis=list(axes)) + f = pytensor.function([x_pt], z, mode="FAST_COMPILE") + np.testing.assert_allclose(expected, f(x), atol=ATOL, rtol=RTOL)
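For orientation, a minimal usage sketch of the public API introduced in this diff (`pt.linspace` with broadcasting and an `axis` argument, and `pt.pad`). It is illustrative only and not part of the patch; the expected shapes follow the new docstrings.

import numpy as np

import pytensor
import pytensor.tensor as pt

floatX = pytensor.config.floatX

# start (2,) broadcasts against stop (2, 2); axis=-1 appends the sample dimension -> (2, 2, 5)
start = pt.vector("start")
stop = pt.matrix("stop")
samples = pt.linspace(start, stop, num=5, axis=-1)

# pad mirrors np.pad; constant mode surrounds a (3, 3) input with one row/column of zeros -> (5, 5)
x = pt.matrix("x")
padded = pt.pad(x, pad_width=1, mode="constant", constant_values=0)

f = pytensor.function([start, stop, x], [samples, padded])
s, p = f(
    np.zeros(2, dtype=floatX),
    np.ones((2, 2), dtype=floatX),
    np.arange(9, dtype=floatX).reshape(3, 3),
)
print(s.shape, p.shape)  # (2, 2, 5) (5, 5)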