Skip to content

Commit

Permalink
Modify test for median
Browse files (browse the repository at this point in the history)
  • Loading branch information
Dhruvanshu-Joshi committed Jul 13, 2024
1 parent c1cb563 commit 9874a55
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 36 deletions.
43 changes: 32 additions & 11 deletions pytensor/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -1580,33 +1580,54 @@ def median(input, axis=None):
Parameters
----------
input: TensorVariable
The input tensor.
axis: None or int or (list of int) (see `Sum`)
Compute the median along this axis of the tensor.
None means all axes (like numpy).
Notes
-----
This function uses the numpy implementation of median.
"""
from pytensor.ifelse import ifelse

input = as_tensor_variable(input)
input_ndim = input.type.ndim
if axis is None:
input = input.flatten()
axis = 0
axis = list(range(input_ndim))
elif isinstance(axis, int | np.integer):
axis = [axis]

Check warning on line 1596 in pytensor/tensor/math.py

View check run for this annotation

Codecov / codecov/patch

pytensor/tensor/math.py#L1596

Added line #L1596 was not covered by tests
elif isinstance(axis, np.ndarray) and axis.ndim == 0:
axis = [int(axis)]
else:
axis = [int(a) for a in axis]

input = as_tensor_variable(input)
new_axes_order = [i for i in range(input.ndim) if i not in axis] + list(axis)
input = input.dimshuffle(new_axes_order)
input_shape = input.shape

remaining_axis_size = shape(input)[: input.ndim - len(axis)]
flattened_axis_size = prod(shape(input)[input.ndim - len(axis) :])

input = input.reshape(concatenate([remaining_axis_size, [flattened_axis_size]]))
axis = -1

# Sort the input tensor along the specified axis
sorted_input = input.sort(axis=axis)
shape = input.shape[axis]
k = extract_constant(shape) // 2
input_shape = input.shape[axis]
k = extract_constant(input_shape) // 2

indices1 = expand_dims(full_like(sorted_input.take(0, axis=axis), k - 1), axis)
indices2 = expand_dims(full_like(sorted_input.take(0, axis=axis), k), axis)
ans1 = take_along_axis(sorted_input, indices1, axis=axis)
ans2 = take_along_axis(sorted_input, indices2, axis=axis)
median_val_even = (ans1 + ans2) / 2.0

indices = expand_dims(full_like(sorted_input.take(0, axis=axis), k), axis)
median_val_odd = take_along_axis(sorted_input, indices, axis=axis)
median_val = ifelse(eq(mod(shape, 2), 0), median_val_even, median_val_odd)
median_val_odd = (
take_along_axis(sorted_input, indices, axis=axis) / 1.0
) # Divide by one so that the two dtypes passed in ifelse are compatible

median_val = ifelse(eq(mod(input_shape, 2), 0), median_val_even, median_val_odd)
median_val.name = "median"

return median_val.squeeze(axis=axis)


Expand Down
48 changes: 23 additions & 25 deletions tests/tensor/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -3735,32 +3735,30 @@ def test_nan_to_num(nan, posinf, neginf):


@pytest.mark.parametrize(
"data, axis",
"ndim, axis",
[
# 1D array
([1, 7, 3, 6, 5, 2, 4], None),
([1, 7, 3, 6, 5, 2, 4], 0),
# 2D array
([[6, 2], [4, 3], [1, 5]], 0),
([[6, 2], [4, 3], [1, 5]], 1),
# 3D array
([[[6, 2, 3], [1, 5, 8], [4, 7, 9]], [[5, 3, 4], [8, 6, 2], [7, 1, 9]]], None),
([[[6, 2, 3], [1, 5, 8], [4, 7, 9]], [[5, 3, 4], [8, 6, 2], [7, 1, 9]]], 0),
([[[6, 2, 3], [1, 5, 8], [4, 7, 9]], [[5, 3, 4], [8, 6, 2], [7, 1, 9]]], 1),
([[[6, 2, 3], [1, 5, 8], [4, 7, 9]], [[5, 3, 4], [8, 6, 2], [7, 1, 9]]], 2),
# 4D array
(
[
[[[3, 1], [4, 3]], [[0, 5], [6, 2]], [[7, 8], [9, 4]]],
[[[10, 11], [12, 13]], [[14, 15], [16, 17]], [[18, 19], [20, 21]]],
],
3,
),
(2, None),
(2, 1),
(2, (0, 1)),
(3, None),
(3, (1, 2)),
(4, (1, 3, 0)),
],
)
def test_median(data, axis):
x = tensor(shape=np.array(data).shape)
def test_median(ndim, axis):
# Generate random data with both odd and even lengths
shape_even = np.arange(1, ndim + 1) * 2
shape_odd = shape_even - 1

data_even = np.random.rand(*shape_even)
data_odd = np.random.rand(*shape_odd)

x = tensor(dtype="float64", shape=(None,) * ndim)
f = function([x], median(x, axis=axis))
result = f(data)
expected = np.median(data, axis=axis)
assert np.allclose(result, expected)
result_odd = f(data_odd)
result_even = f(data_even)
expected_odd = np.median(data_odd, axis=axis)
expected_even = np.median(data_even, axis=axis)

assert np.allclose(result_odd, expected_odd)
assert np.allclose(result_even, expected_even)

0 comments on commit 9874a55

Please sign in to comment.