Skip to content

Commit

Permalink
Add a few additional tests and doc comments in tensor.rs
Browse files Browse the repository at this point in the history
  • Loading branch information
robertknight committed Jan 22, 2024
1 parent 5575881 commit e3964f6
Showing 1 changed file with 70 additions and 20 deletions.
90 changes: 70 additions & 20 deletions rten-tensor/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,11 +154,19 @@ pub trait AsView: Layout {
/// Return a view with dimensions permuted in the order given by `dims`.
fn permuted(
&self,
dims: Self::Index<'_>,
order: Self::Index<'_>,
) -> TensorBase<Self::Elem, &[Self::Elem], Self::Layout> {
self.view().permuted(dims)
self.view().permuted(order)
}

/// Return a view with a given shape, without copying any data. This
/// requires that the tensor is contiguous.
///
/// The new shape must have the same number of elments as the current
/// shape. The result will have a static rank if `shape` is an array or
/// a dynamic rank if it is a slice.
///
/// Panics if the tensor is not contiguous.
fn reshaped<S: IntoLayout>(
&self,
shape: S,
Expand Down Expand Up @@ -187,10 +195,10 @@ pub trait AsView: Layout {

/// Slice this tensor and return a static-rank view with `M` dimensions.
///
/// Use [AsView::slice_dyn] instead the number of dimensions in the returned
/// view is unknown at compile time.
/// Use [AsView::slice_dyn] instead if the number of dimensions in the
/// returned view is unknown at compile time.
///
/// Panics if the dimension count is not `M`.
/// Panics if the dimension count of the result is not `M`.
fn slice<const M: usize, R: IntoSliceItems>(&self, range: R) -> NdTensorView<Self::Elem, M> {
self.view().slice(range)
}
Expand Down Expand Up @@ -221,7 +229,12 @@ pub trait AsView: Layout {
where
Self::Elem: Clone;

/// Return a vector with the same shape but with strides in contiguous order.
/// Return a tensor with the same shape as this tensor/view but with the
/// data contiguous in memory and arranged in the same order as the
/// logical/iteration order (used by `iter`).
///
/// This will return a view if the data is already contiguous or copy
/// data into a new buffer otherwise.
///
/// Certain operations require or are faster with contiguous tensors.
fn to_contiguous(&self) -> TensorBase<Self::Elem, Cow<[Self::Elem]>, Self::Layout>
Expand All @@ -239,7 +252,7 @@ pub trait AsView: Layout {
where
Self::Elem: Clone;

/// Return clone of this tensor which uniquely owns its elements.
/// Return a copy of this tensor/view which uniquely owns its elements.
fn to_tensor(&self) -> TensorBase<Self::Elem, Vec<Self::Elem>, Self::Layout>
where
Self::Elem: Clone,
Expand Down Expand Up @@ -488,6 +501,9 @@ impl<'a> IntoLayout for &'a [usize] {

/// Trait which extends [MutLayout] with support for changing the number of
/// dimensions in-place.
///
/// This is only implemented for [DynLayout], since layouts that have a static
/// rank cannot change their dimension count at runtime.
pub trait ResizeLayout: MutLayout {
/// Insert a size-one axis at the given index in the shape. This will have
/// the same stride as the dimension that follows it.
Expand All @@ -503,7 +519,7 @@ impl ResizeLayout for DynLayout {
/// Trait for converting types into indices for use with a given layout.
///
/// Static-rank tensors can be indexed with `[usize; N]` arrays. Dynamic-rank
/// tensors can be indexed with any type that can be converted to a `&[usize]`
/// tensors can be indexed with any type that can be converted to an `&[usize]`
/// slice.
pub trait AsIndex<L: Layout> {
/// Convert `self` into an index for use the layout `L`.
Expand All @@ -528,7 +544,9 @@ impl<T, S: AsRef<[T]>, L: MutLayout> TensorBase<T, S, L> {
let layout = L::from_shape(shape);
assert!(
data.as_ref().len() == layout.len(),
"data length does not match shape"
"data length {} does not match shape {:?}",
data.as_ref().len(),
layout.shape().as_ref(),
);
TensorBase {
data,
Expand All @@ -543,7 +561,7 @@ impl<T, S: AsRef<[T]>, L: MutLayout> TensorBase<T, S, L> {
/// This will fail if the data length is incorrect for the shape and stride
/// combination, or if the strides lead to overlap (see [OverlapPolicy]).
/// See also [TensorBase::from_slice_with_strides] which is a similar method
/// for immutable views, that does allow overlapping strides.
/// for immutable views that does allow overlapping strides.
pub fn from_data_with_strides(
shape: L::Index<'_>,
data: S,
Expand Down Expand Up @@ -730,8 +748,8 @@ impl<T, S: AsRef<[T]> + AsMut<[T]>, L: MutLayout> TensorBase<T, S, L> {

/// Slice this tensor and return a static-rank view with `M` dimensions.
///
/// Use [AsView::slice_dyn] instead the number of dimensions in the returned
/// view is unknown at compile time.
/// Use [AsView::slice_dyn] instead if the number of dimensions in the
/// returned view is unknown at compile time.
///
/// Panics if the dimension count is not `M`.
pub fn slice_mut<const M: usize, R: IntoSliceItems>(
Expand Down Expand Up @@ -891,6 +909,11 @@ impl<T, L: Clone + MutLayout> TensorBase<T, Vec<T>, L> {
}

/// Make the underlying data in this tensor contiguous.
///
/// This means that after calling `make_contiguous`, the elements are
/// guaranteed to be stored in the same order as the logical order in
/// which `iter` yields elements. This method is cheap if the storage is
/// already contiguous.
pub fn make_contiguous(&mut self)
where
T: Clone,
Expand Down Expand Up @@ -1027,7 +1050,7 @@ impl<'a, T, L: Clone + MutLayout> TensorBase<T, &'a [T], L> {
InnerIter::new(self.view())
}

/// Return the scalar value in this tensor if it has 0 dimensions.
/// Return the scalar value in this tensor if it has one element.
pub fn item(&self) -> Option<&'a T> {
match self.ndim() {
0 => Some(&self.data[0]),
Expand Down Expand Up @@ -1408,6 +1431,7 @@ impl<T, L: Clone + MutLayout> From<Vec<T>> for TensorBase<T, Vec<T>, L>
where
[usize; 1]: AsIndex<L>,
{
/// Create a 1D tensor from a vector.
fn from(vec: Vec<T>) -> Self {
Self::from_data([vec.len()].as_index(), vec)
}
Expand All @@ -1417,6 +1441,7 @@ impl<'a, T, L: Clone + MutLayout> From<&'a [T]> for TensorBase<T, &'a [T], L>
where
[usize; 1]: AsIndex<L>,
{
/// Create a 1D view from a slice.
fn from(slice: &'a [T]) -> Self {
Self::from_data([slice.len()].as_index(), slice)
}
Expand All @@ -1426,6 +1451,7 @@ impl<'a, T, L: Clone + MutLayout, const N: usize> From<&'a [T; N]> for TensorBas
where
[usize; 1]: AsIndex<L>,
{
/// Create a 1D view from a slice of known length.
fn from(slice: &'a [T; N]) -> Self {
Self::from_data([slice.len()].as_index(), slice.as_slice())
}
Expand Down Expand Up @@ -1458,7 +1484,7 @@ impl<T, S: AsRef<[T]> + AsMut<[T]>, const N: usize> TensorBase<T, S, NdLayout<N>
/// Store an array of `M` elements into successive entries of a tensor along
/// the `dim` axis.
///
/// See [NdTensorBase::get_array] for more details.
/// See [TensorBase::get_array] for more details.
#[inline]
pub fn set_array<const M: usize>(&mut self, base: [usize; N], dim: usize, values: [T; M])
where
Expand Down Expand Up @@ -1588,7 +1614,7 @@ where
{
type Error = DimensionError;

/// Convert a dynamic-dimensional tensor or view into a static-dimensional one.
/// Convert a tensor or view with dynamic rank into a static rank one.
///
/// Fails if `value` does not have `N` dimensions.
fn try_from(value: TensorBase<T, S1, DynLayout>) -> Result<Self, Self::Error> {
Expand Down Expand Up @@ -1658,7 +1684,7 @@ where
///
/// This offers a middle-ground between regular indexing, which bounds-checks
/// each index element, and unchecked indexing, which does no bounds-checking
/// at all.
/// at all and is thus unsafe.
pub struct WeaklyCheckedView<T, S: AsRef<[T]>, L: MutLayout> {
base: TensorBase<T, S, L>,
}
Expand Down Expand Up @@ -1987,7 +2013,7 @@ mod tests {
}

#[test]
#[should_panic(expected = "data length does not match shape")]
#[should_panic(expected = "data length 4 does not match shape [2, 2, 2]")]
fn test_from_data_shape_mismatch() {
NdTensor::from_data([2, 2, 2], vec![1, 2, 3, 4]);
}
Expand Down Expand Up @@ -2219,11 +2245,15 @@ mod tests {
fn test_item() {
let tensor = NdTensor::from_data([], vec![5.]);
assert_eq!(tensor.item(), Some(&5.));
let tensor = NdTensor::from_data([1], vec![6.]);
assert_eq!(tensor.item(), Some(&6.));
let tensor = NdTensor::from_data([2], vec![2., 3.]);
assert_eq!(tensor.item(), None);

let tensor = Tensor::from_data(&[], vec![5.]);
assert_eq!(tensor.item(), Some(&5.));
let tensor = Tensor::from_data(&[1], vec![6.]);
assert_eq!(tensor.item(), Some(&6.));
let tensor = Tensor::from_data(&[2], vec![2., 3.]);
assert_eq!(tensor.item(), None);
}
Expand All @@ -2236,13 +2266,23 @@ mod tests {
tensor.iter().copied().collect::<Vec<_>>(),
&[1., 2., 3., 4.]
);
let transposed = tensor.transposed();
assert_eq!(
transposed.iter().copied().collect::<Vec<_>>(),
&[1., 3., 2., 4.]
);

let data = vec![1., 2., 3., 4.];
let tensor = Tensor::from_data(&[2, 2], data);
assert_eq!(
tensor.iter().copied().collect::<Vec<_>>(),
&[1., 2., 3., 4.]
);
let transposed = tensor.transposed();
assert_eq!(
transposed.iter().copied().collect::<Vec<_>>(),
&[1., 3., 2., 4.]
);
}

#[test]
Expand Down Expand Up @@ -2576,10 +2616,20 @@ mod tests {
fn test_to_contiguous() {
let data = vec![1., 2., 3., 4.];
let tensor = NdTensor::from_data([2, 2], data);
// TODO - Try both contiguous and non-contiguous tensors.
let tensor = tensor.to_contiguous();
// TODO - Check the actual storage from to_contiguous

// Tensor is already contiguous, so this is a no-op.
let mut tensor = tensor.to_contiguous();
assert_eq!(tensor.to_vec(), &[1., 2., 3., 4.]);

// Swap strides to make tensor non-contiguous.
tensor.transpose();
assert!(!tensor.is_contiguous());
assert_eq!(tensor.to_vec(), &[1., 3., 2., 4.]);

// Create a new contiguous copy.
let tensor = tensor.to_contiguous();
assert!(tensor.is_contiguous());
assert_eq!(tensor.to_vec(), &[1., 3., 2., 4.]);
}

#[test]
Expand Down

0 comments on commit e3964f6

Please sign in to comment.