From 573ded4cbb59ca4e3ad1a4163cedb89dfade7e65 Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Fri, 2 Feb 2024 22:41:33 +0000 Subject: [PATCH] Allow negative indices in `Gather` op Per the ONNX spec [1]: > All index values are expected to be within bounds [-s, s-1] along axis of size s. The implementation already supported negative values, only the bounds check preceding it did not. [1] https://onnx.ai/onnx/operators/onnx__Gather.html Fixes https://github.com/robertknight/rten/issues/46 --- src/ops/gather.rs | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/ops/gather.rs b/src/ops/gather.rs index e4c449ab..08938c00 100644 --- a/src/ops/gather.rs +++ b/src/ops/gather.rs @@ -26,7 +26,8 @@ pub fn gather( let axis = resolve_axis(input.ndim(), axis)?; for index in indices.iter().copied() { - if index < 0 || index >= input.size(axis) as i32 { + let size = input.size(axis) as i32; + if index < -size || index >= size { return Err(OpError::InvalidValue("Entry in `indices` is out of range")); } } @@ -500,6 +501,13 @@ mod tests { let result = gather(input.view(), 1, indices.view()).unwrap(); expect_equal(&result, &expected)?; + // Negative index values. + let input = Tensor::from([1, 2, 3]); + let indices = Tensor::from([-1, -2, -3]); + let expected = Tensor::from([3, 2, 1]); + let result = gather(input.view(), 0, indices.view()).unwrap(); + assert_eq!(&result, &expected); + Ok(()) }