Skip to content

Commit

Permalink
Merge pull request #365 from robertknight/try-slice-with
Browse files Browse the repository at this point in the history
Add fallible variants of `TensorBase::slice_with`
  • Loading branch information
robertknight authored Sep 19, 2024
2 parents 319b97e + f7a5523 commit aecc370
Show file tree
Hide file tree
Showing 5 changed files with 295 additions and 146 deletions.
72 changes: 64 additions & 8 deletions rten-tensor/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
use std::error::Error;
use std::fmt::{Display, Formatter};

use crate::slice_range::SliceRange;

/// Error in a tensor operation if the dimension count is incorrect.
#[derive(Debug, PartialEq)]
pub struct DimensionError {}
Expand Down Expand Up @@ -47,28 +49,82 @@ impl Error for FromDataError {}
#[derive(Clone, Debug, PartialEq)]
pub enum SliceError {
/// The slice spec has more dimensions than the tensor being sliced.
TooManyDims,
TooManyDims {
/// Number of axes in the tensor.
ndim: usize,
/// Number of items in the slice spec.
range_ndim: usize,
},

/// An index in the slice spec is out of bounds for the corresponding tensor
/// dimension.
InvalidIndex,
InvalidIndex {
/// Axis that the error applies to.
axis: usize,
/// Index in the slice range.
index: isize,
/// Size of the dimension.
size: usize,
},

/// A range in the slice spec is out of bounds for the corresponding tensor
/// dimension.
InvalidRange,
InvalidRange {
/// Axis that the error applies to.
axis: usize,

/// The range item.
range: SliceRange,

/// Size of the dimension.
size: usize,
},

/// The step in a slice range is negative, in a context where this is not
/// supported.
InvalidStep,
InvalidStep {
/// Axis that the error applies to.
axis: usize,

/// Size of the dimension.
step: isize,
},

/// There is a mismatch between the actual and expected number of axes
/// in the output slice.
OutputDimsMismatch { actual: usize, expected: usize },
}

impl Display for SliceError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
SliceError::TooManyDims => write!(f, "slice spec has too many dims"),
SliceError::InvalidIndex => write!(f, "slice index is invalid"),
SliceError::InvalidRange => write!(f, "slice range is invalid"),
SliceError::InvalidStep => write!(f, "slice step is invalid"),
SliceError::TooManyDims { ndim, range_ndim } => {
write!(
f,
"slice range has {} items but tensor has only {} dims",
range_ndim, ndim
)
}
SliceError::InvalidIndex { axis, index, size } => write!(
f,
"slice index {} is invalid for axis ({}) of size {}",
index, axis, size
),
SliceError::InvalidRange { axis, range, size } => write!(
f,
"slice range {:?} is invalid for axis ({}) of size {}",
range, axis, size
),
SliceError::InvalidStep { axis, step } => {
write!(f, "slice step {} is invalid for axis {}", step, axis)
}
SliceError::OutputDimsMismatch { actual, expected } => {
write!(
f,
"slice output dims {} does not match expected dims {}",
actual, expected
)
}
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion rten-tensor/src/iterators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ impl LaneRanges {
(0..end).into()
})
.collect();
let (_range, sliced) = layout.slice_dyn(&slice_starts);
let (_range, sliced) = layout.slice_dyn(&slice_starts).unwrap();
let offsets = Offsets::new(&sliced);
LaneRanges {
offsets,
Expand Down
Loading

0 comments on commit aecc370

Please sign in to comment.