Skip to content

Commit

Permalink
Merge pull request #47 from robertknight/fix-gather-negative-index
Browse files Browse the repository at this point in the history
Allow negative indices in `Gather` op
  • Loading branch information
robertknight authored Feb 2, 2024
2 parents 1881790 + 573ded4 commit f6184a4
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion src/ops/gather.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ pub fn gather<T: Copy + Default>(
let axis = resolve_axis(input.ndim(), axis)?;

for index in indices.iter().copied() {
if index < 0 || index >= input.size(axis) as i32 {
let size = input.size(axis) as i32;
if index < -size || index >= size {
return Err(OpError::InvalidValue("Entry in `indices` is out of range"));
}
}
Expand Down Expand Up @@ -500,6 +501,13 @@ mod tests {
let result = gather(input.view(), 1, indices.view()).unwrap();
expect_equal(&result, &expected)?;

// Negative index values.
let input = Tensor::from([1, 2, 3]);
let indices = Tensor::from([-1, -2, -3]);
let expected = Tensor::from([3, 2, 1]);
let result = gather(input.view(), 0, indices.view()).unwrap();
assert_eq!(&result, &expected);

Ok(())
}

Expand Down

0 comments on commit f6184a4

Please sign in to comment.