diff --git a/rten-tensor/src/tensor.rs b/rten-tensor/src/tensor.rs index a0cd19cc..30520c09 100644 --- a/rten-tensor/src/tensor.rs +++ b/rten-tensor/src/tensor.rs @@ -154,11 +154,19 @@ pub trait AsView: Layout { /// Return a view with dimensions permuted in the order given by `dims`. fn permuted( &self, - dims: Self::Index<'_>, + order: Self::Index<'_>, ) -> TensorBase { - self.view().permuted(dims) + self.view().permuted(order) } + /// Return a view with a given shape, without copying any data. This + /// requires that the tensor is contiguous. + /// + /// The new shape must have the same number of elments as the current + /// shape. The result will have a static rank if `shape` is an array or + /// a dynamic rank if it is a slice. + /// + /// Panics if the tensor is not contiguous. fn reshaped( &self, shape: S, @@ -187,10 +195,10 @@ pub trait AsView: Layout { /// Slice this tensor and return a static-rank view with `M` dimensions. /// - /// Use [AsView::slice_dyn] instead the number of dimensions in the returned - /// view is unknown at compile time. + /// Use [AsView::slice_dyn] instead if the number of dimensions in the + /// returned view is unknown at compile time. /// - /// Panics if the dimension count is not `M`. + /// Panics if the dimension count of the result is not `M`. fn slice(&self, range: R) -> NdTensorView { self.view().slice(range) } @@ -221,7 +229,12 @@ pub trait AsView: Layout { where Self::Elem: Clone; - /// Return a vector with the same shape but with strides in contiguous order. + /// Return a tensor with the same shape as this tensor/view but with the + /// data contiguous in memory and arranged in the same order as the + /// logical/iteration order (used by `iter`). + /// + /// This will return a view if the data is already contiguous or copy + /// data into a new buffer otherwise. /// /// Certain operations require or are faster with contiguous tensors. fn to_contiguous(&self) -> TensorBase, Self::Layout> @@ -239,7 +252,7 @@ pub trait AsView: Layout { where Self::Elem: Clone; - /// Return clone of this tensor which uniquely owns its elements. + /// Return a copy of this tensor/view which uniquely owns its elements. fn to_tensor(&self) -> TensorBase, Self::Layout> where Self::Elem: Clone, @@ -488,6 +501,9 @@ impl<'a> IntoLayout for &'a [usize] { /// Trait which extends [MutLayout] with support for changing the number of /// dimensions in-place. +/// +/// This is only implemented for [DynLayout], since layouts that have a static +/// rank cannot change their dimension count at runtime. pub trait ResizeLayout: MutLayout { /// Insert a size-one axis at the given index in the shape. This will have /// the same stride as the dimension that follows it. @@ -503,7 +519,7 @@ impl ResizeLayout for DynLayout { /// Trait for converting types into indices for use with a given layout. /// /// Static-rank tensors can be indexed with `[usize; N]` arrays. Dynamic-rank -/// tensors can be indexed with any type that can be converted to a `&[usize]` +/// tensors can be indexed with any type that can be converted to an `&[usize]` /// slice. pub trait AsIndex { /// Convert `self` into an index for use the layout `L`. @@ -528,7 +544,9 @@ impl, L: MutLayout> TensorBase { let layout = L::from_shape(shape); assert!( data.as_ref().len() == layout.len(), - "data length does not match shape" + "data length {} does not match shape {:?}", + data.as_ref().len(), + layout.shape().as_ref(), ); TensorBase { data, @@ -543,7 +561,7 @@ impl, L: MutLayout> TensorBase { /// This will fail if the data length is incorrect for the shape and stride /// combination, or if the strides lead to overlap (see [OverlapPolicy]). /// See also [TensorBase::from_slice_with_strides] which is a similar method - /// for immutable views, that does allow overlapping strides. + /// for immutable views that does allow overlapping strides. pub fn from_data_with_strides( shape: L::Index<'_>, data: S, @@ -730,8 +748,8 @@ impl + AsMut<[T]>, L: MutLayout> TensorBase { /// Slice this tensor and return a static-rank view with `M` dimensions. /// - /// Use [AsView::slice_dyn] instead the number of dimensions in the returned - /// view is unknown at compile time. + /// Use [AsView::slice_dyn] instead if the number of dimensions in the + /// returned view is unknown at compile time. /// /// Panics if the dimension count is not `M`. pub fn slice_mut( @@ -891,6 +909,11 @@ impl TensorBase, L> { } /// Make the underlying data in this tensor contiguous. + /// + /// This means that after calling `make_contiguous`, the elements are + /// guaranteed to be stored in the same order as the logical order in + /// which `iter` yields elements. This method is cheap if the storage is + /// already contiguous. pub fn make_contiguous(&mut self) where T: Clone, @@ -1027,7 +1050,7 @@ impl<'a, T, L: Clone + MutLayout> TensorBase { InnerIter::new(self.view()) } - /// Return the scalar value in this tensor if it has 0 dimensions. + /// Return the scalar value in this tensor if it has one element. pub fn item(&self) -> Option<&'a T> { match self.ndim() { 0 => Some(&self.data[0]), @@ -1408,6 +1431,7 @@ impl From> for TensorBase, L> where [usize; 1]: AsIndex, { + /// Create a 1D tensor from a vector. fn from(vec: Vec) -> Self { Self::from_data([vec.len()].as_index(), vec) } @@ -1417,6 +1441,7 @@ impl<'a, T, L: Clone + MutLayout> From<&'a [T]> for TensorBase where [usize; 1]: AsIndex, { + /// Create a 1D view from a slice. fn from(slice: &'a [T]) -> Self { Self::from_data([slice.len()].as_index(), slice) } @@ -1426,6 +1451,7 @@ impl<'a, T, L: Clone + MutLayout, const N: usize> From<&'a [T; N]> for TensorBas where [usize; 1]: AsIndex, { + /// Create a 1D view from a slice of known length. fn from(slice: &'a [T; N]) -> Self { Self::from_data([slice.len()].as_index(), slice.as_slice()) } @@ -1458,7 +1484,7 @@ impl + AsMut<[T]>, const N: usize> TensorBase /// Store an array of `M` elements into successive entries of a tensor along /// the `dim` axis. /// - /// See [NdTensorBase::get_array] for more details. + /// See [TensorBase::get_array] for more details. #[inline] pub fn set_array(&mut self, base: [usize; N], dim: usize, values: [T; M]) where @@ -1588,7 +1614,7 @@ where { type Error = DimensionError; - /// Convert a dynamic-dimensional tensor or view into a static-dimensional one. + /// Convert a tensor or view with dynamic rank into a static rank one. /// /// Fails if `value` does not have `N` dimensions. fn try_from(value: TensorBase) -> Result { @@ -1658,7 +1684,7 @@ where /// /// This offers a middle-ground between regular indexing, which bounds-checks /// each index element, and unchecked indexing, which does no bounds-checking -/// at all. +/// at all and is thus unsafe. pub struct WeaklyCheckedView, L: MutLayout> { base: TensorBase, } @@ -1987,7 +2013,7 @@ mod tests { } #[test] - #[should_panic(expected = "data length does not match shape")] + #[should_panic(expected = "data length 4 does not match shape [2, 2, 2]")] fn test_from_data_shape_mismatch() { NdTensor::from_data([2, 2, 2], vec![1, 2, 3, 4]); } @@ -2219,11 +2245,15 @@ mod tests { fn test_item() { let tensor = NdTensor::from_data([], vec![5.]); assert_eq!(tensor.item(), Some(&5.)); + let tensor = NdTensor::from_data([1], vec![6.]); + assert_eq!(tensor.item(), Some(&6.)); let tensor = NdTensor::from_data([2], vec![2., 3.]); assert_eq!(tensor.item(), None); let tensor = Tensor::from_data(&[], vec![5.]); assert_eq!(tensor.item(), Some(&5.)); + let tensor = Tensor::from_data(&[1], vec![6.]); + assert_eq!(tensor.item(), Some(&6.)); let tensor = Tensor::from_data(&[2], vec![2., 3.]); assert_eq!(tensor.item(), None); } @@ -2236,6 +2266,11 @@ mod tests { tensor.iter().copied().collect::>(), &[1., 2., 3., 4.] ); + let transposed = tensor.transposed(); + assert_eq!( + transposed.iter().copied().collect::>(), + &[1., 3., 2., 4.] + ); let data = vec![1., 2., 3., 4.]; let tensor = Tensor::from_data(&[2, 2], data); @@ -2243,6 +2278,11 @@ mod tests { tensor.iter().copied().collect::>(), &[1., 2., 3., 4.] ); + let transposed = tensor.transposed(); + assert_eq!( + transposed.iter().copied().collect::>(), + &[1., 3., 2., 4.] + ); } #[test] @@ -2576,10 +2616,20 @@ mod tests { fn test_to_contiguous() { let data = vec![1., 2., 3., 4.]; let tensor = NdTensor::from_data([2, 2], data); - // TODO - Try both contiguous and non-contiguous tensors. - let tensor = tensor.to_contiguous(); - // TODO - Check the actual storage from to_contiguous + + // Tensor is already contiguous, so this is a no-op. + let mut tensor = tensor.to_contiguous(); assert_eq!(tensor.to_vec(), &[1., 2., 3., 4.]); + + // Swap strides to make tensor non-contiguous. + tensor.transpose(); + assert!(!tensor.is_contiguous()); + assert_eq!(tensor.to_vec(), &[1., 3., 2., 4.]); + + // Create a new contiguous copy. + let tensor = tensor.to_contiguous(); + assert!(tensor.is_contiguous()); + assert_eq!(tensor.to_vec(), &[1., 3., 2., 4.]); } #[test]