From 7464c83d5816fc7fdfa596efa5dbfbb3ceeeb790 Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Fri, 12 Jan 2024 21:52:54 +0000 Subject: [PATCH 01/12] Initial implementation of unified static/dynamic rank tensor Add a new tensor base type in `rten_tensor::unified_tensor` which is generic over the layout and can represent either a static or dynamic rank tensor. Unifying the implementation will help to avoid unintended API differences between the two. The API is largely compatible with the existing tensor types and the same set of aliases are provided (`NdTensor*` for static rank, `Tensor*` for dynamic rank). --- rten-tensor/src/lib.rs | 3 + rten-tensor/src/unified_tensor.rs | 2718 +++++++++++++++++++ rten-tensor/src/unified_tensor/iterators.rs | 247 ++ 3 files changed, 2968 insertions(+) create mode 100644 rten-tensor/src/unified_tensor.rs create mode 100644 rten-tensor/src/unified_tensor/iterators.rs diff --git a/rten-tensor/src/lib.rs b/rten-tensor/src/lib.rs index 1bb248b9..56fad6f8 100644 --- a/rten-tensor/src/lib.rs +++ b/rten-tensor/src/lib.rs @@ -44,6 +44,9 @@ mod overlap; mod range; mod tensor; +#[cfg(test)] +mod unified_tensor; + /// Trait for sources of random data for tensors, for use with [Tensor::rand]. pub trait RandomSource { /// Generate the next random value. diff --git a/rten-tensor/src/unified_tensor.rs b/rten-tensor/src/unified_tensor.rs new file mode 100644 index 00000000..6ef41f18 --- /dev/null +++ b/rten-tensor/src/unified_tensor.rs @@ -0,0 +1,2718 @@ +use std::borrow::Cow; +use std::marker::PhantomData; +use std::ops::{Index, IndexMut, Range}; + +use crate::errors::{DimensionError, FromDataError, SliceError}; +use crate::iterators::{BroadcastIter, Iter, IterMut, Lanes, LanesMut, MutViewRef, ViewRef}; +use crate::layout::{DynLayout, Layout, MatrixLayout, NdLayout, OverlapPolicy}; +use crate::{IntoSliceItems, RandomSource, SliceItem}; + +pub mod iterators; +use iterators::{AxisChunks, AxisChunksMut, AxisIter, AxisIterMut, InnerIter, InnerIterMut}; + +/// The base type for multi-dimensional arrays. This consists of storage for +/// elements, plus a _layout_ which maps from a multi-dimensional array index +/// to a storage offset. This base type is not normally used directly but +/// instead through a type alias which selects the storage type and layout. +/// +/// The storage can be owned (like a `Vec`), borrowed (like `&[T]`) or +/// mutably borrowed (like `&mut [T]`). The layout can have a dimension count +/// that is determined statically (ie. forms part of the tensor's type), see +/// [NdLayout] or is only known at runtime, see [DynLayout]. +#[derive(Debug)] +pub struct TensorBase, L: MutLayout> { + data: S, + layout: L, + element_type: PhantomData, +} + +/// Trait implemented by all variants of [TensorBase], which provides a +/// `view` method to get an immutable view of the tensor, plus methods which +/// forward to such a view. +/// +/// The purpose of this trait is to allow methods to be specialized for +/// immutable views by preserving the lifetime of the underlying data in +/// return types (eg. `iter` returns `&[T]` in the trait, but `&'a [T]` in +/// the view). This allows for chaining operations on views together (eg. +/// `tensor.slice(...).transpose()`) without needing to separate each step +/// into separate statements. +/// +/// This trait is conceptually similar to the way [std::ops::Deref] in the Rust +/// standard library allows a `Vec` to have all the methods of an `&[T]`. +/// +/// If stable Rust gains support for specialization or a `Deref` trait that can +/// return non-references (see. https://github.com/rust-lang/rfcs/issues/997) +/// this will become unnecessary. +pub trait AsView: Layout { + /// Type of element stored in this tensor. + type Elem; + + /// The underlying layout of this tensor. It must have the same index + /// type (eg. `[usize; N]` or `&[usize]`) as this view. + type Layout: for<'a> MutLayout = Self::Index<'a>>; + + /// Return a borrowed view of this tensor. + fn view(&self) -> TensorBase; + + /// Return the layout of this tensor. + fn layout(&self) -> &Self::Layout; + + /// Return a view of this tensor with a dynamic rank. + fn as_dyn(&self) -> TensorBase + where + Self::Layout: Into, + { + self.view().as_dyn() + } + + /// Return an iterator over slices of this tensor along a given axis. + fn axis_chunks(&self, dim: usize, chunk_size: usize) -> AxisChunks { + self.view().axis_chunks(dim, chunk_size) + } + + /// Return an iterator over slices of this tensor along a given axis. + fn axis_iter(&self, dim: usize) -> AxisIter { + self.view().axis_iter(dim) + } + + /// Broadcast this view to another shape. + /// + /// If `shape` is an array (`[usize; N]`), the result will have a + /// static-rank layout with `N` dims. If `shape` is a slice, the result will + /// have a dynamic-rank layout. + fn broadcast(&self, shape: S) -> TensorBase + where + Self::Layout: BroadcastLayout, + { + self.view().broadcast(shape) + } + + /// Return an iterator over elements of this tensor, broadcast to `shape`. + /// + /// This is equivalent to `self.broadcast(shape).iter()` but has some + /// additional optimizations. + fn broadcast_iter(&self, shape: &[usize]) -> BroadcastIter { + self.view().broadcast_iter(shape) + } + + /// Return the layout of this tensor as a slice, if it is contiguous. + fn data(&self) -> Option<&[Self::Elem]>; + + /// Return a reference to the element at a given index, or `None` if the + /// index is invalid. + fn get>(&self, index: I) -> Option<&Self::Elem>; + + /// Return an iterator over the innermost N dimensions. + fn inner_iter(&self) -> InnerIter { + self.view().inner_iter() + } + + /// Insert a size-1 axis at the given index. + fn insert_axis(&mut self, index: usize) + where + Self::Layout: ResizeLayout; + + /// Return the scalar value in this tensor if it has 0 dimensions. + fn item(&self) -> Option<&Self::Elem> { + self.view().item() + } + + /// Return an iterator over elements in this tensor in their logical order. + fn iter(&self) -> Iter; + + /// Return an iterator over 1D slices of this tensor along a given axis. + fn lanes(&self, dim: usize) -> Lanes { + self.view().lanes(dim) + } + + /// Return a new tensor with the same shape, formed by applying `f` to each + /// element in this tensor. + fn map(&self, f: F) -> TensorBase, Self::Layout> + where + F: Fn(&Self::Elem) -> U, + { + self.view().map(f) + } + + /// Re-order the axes of this tensor to move the axis at index `from` to + /// `to`. + /// + /// Panics if `from` or `to` is >= `self.ndim()`. + fn move_axis(&mut self, from: usize, to: usize); + + /// Convert this tensor to one with the same shape but a static dimension + /// count. + /// + /// Panics if `self.ndim() != N`. + fn nd_view(&self) -> TensorBase> { + self.view().nd_view() + } + + /// Permute the dimensions of this tensor. + fn permute(&mut self, order: Self::Index<'_>); + + /// Return a view with dimensions permuted in the order given by `dims`. + fn permuted( + &self, + dims: Self::Index<'_>, + ) -> TensorBase { + self.view().permuted(dims) + } + + fn reshaped( + &self, + shape: S, + ) -> TensorBase { + self.view().reshaped(shape) + } + + /// Reverse the order of dimensions in this tensor. + fn transpose(&mut self); + + /// Return a view with the order of dimensions reversed. + fn transposed(&self) -> TensorBase { + self.view().transposed() + } + + /// Slice this tensor and return a dynamic-rank view. + /// + /// Fails if the range has more dimensions than the view or is out of bounds + /// for any dimension. + fn try_slice(&self, range: R) -> Result, SliceError> { + self.view().try_slice(range) + } + + /// Slice this tensor and return a static-rank view with `M` dimensions. + /// + /// Use [AsView::slice_dyn] instead the number of dimensions in the returned + /// view is unknown at compile time. + /// + /// Panics if the dimension count is not `M`. + fn slice(&self, range: R) -> NdTensorView { + self.view().slice(range) + } + + /// Slice this tensor and return a dynamic-rank view. + fn slice_dyn(&self, range: R) -> TensorView { + self.view().slice_dyn(range) + } + + /// Return an iterator over a slice of this tensor. + /// + /// This is similar to `self.slice(range).iter()` except that it + /// returns an iterator directly instead of creating an intermediate view. + /// Also slicing with this method is more flexible as negative steps are + /// supported for items in `range`. + fn slice_iter(&self, range: &[SliceItem]) -> Iter { + self.view().slice_iter(range) + } + + /// Return a view of this tensor with all dimensions of size 1 removed. + fn squeezed(&self) -> TensorView { + self.view().squeezed() + } + + /// Return a vector containing the elements of this tensor in their logical + /// order, ie. as if the tensor were flattened into one dimension. + fn to_vec(&self) -> Vec + where + Self::Elem: Clone; + + /// Return a vector with the same shape but with strides in contiguous order. + /// + /// Certain operations require or are faster with contiguous tensors. + fn to_contiguous(&self) -> TensorBase, Self::Layout> + where + Self::Elem: Clone, + { + self.view().to_contiguous() + } + + /// Return a copy of this tensor with a given shape. + fn to_shape( + &self, + shape: S, + ) -> TensorBase, S::Layout> + where + Self::Elem: Clone; + + /// Return clone of this tensor which uniquely owns its elements. + fn to_tensor(&self) -> TensorBase, Self::Layout> + where + Self::Elem: Clone, + { + let data = self.to_vec(); + TensorBase::from_data(self.layout().shape(), data) + } + + /// Return a view which performs "weak" checking when indexing via + /// `view[]`. See [WeaklyCheckedView] for an explanation. + fn weakly_checked_view(&self) -> WeaklyCheckedView { + self.view().weakly_checked_view() + } +} + +/// MutLayout extends [Layout] with methods for creating, modifying and +/// transforming layouts. +pub trait MutLayout: Layout + Clone { + /// Create a new contiguous layout with a given shape. + fn from_shape(shape: Self::Index<'_>) -> Self; + + /// Create a layout with custom strides. + /// + /// The strides specify the offset gap between successive entries along a + /// given axis. `overlap` controls whether the layout is allowed to map + /// multiple indices to the same element. This can be true for immutable + /// views, but must be false for tensors or views that are mutable. + fn from_shape_and_strides( + shape: Self::Index<'_>, + strides: Self::Index<'_>, + overlap: OverlapPolicy, + ) -> Result; + + /// Move the axis at position `from` to `to` by swapping their strides. + fn move_axis(&mut self, from: usize, to: usize); + + /// Return a layout with the axes permuted according to the given order. + fn permuted(&self, order: Self::Index<'_>) -> Self; + + /// Combine or split dimensions by reshaping the layout to a given shape. + /// + /// This will fail if the layout is not contiguous. + fn reshaped(&self, shape: S) -> S::Layout { + assert!( + self.is_contiguous(), + "tried to reshape non-contiguous layout" + ); + shape.into_layout() + } + + // Modify the size of a dimension. This does not alter the strides. + fn resize_dim(&mut self, dim: usize, size: usize); + + /// Reverse the order of dimensions. This is equivalent to + /// `self.permuted([N-1, N-2, ... 0])`. + fn transposed(&self) -> Self; + + /// Slice the layout and return a static rank layout with `M` dimensions. + fn slice(&self, range: &[SliceItem]) -> (Range, NdLayout); + + /// Slice the layout and return a dynamic rank layout. + fn slice_dyn(&self, range: &[SliceItem]) -> (Range, DynLayout); + + /// Return a layout with all size-one dimensions removed. + fn squeezed(&self) -> DynLayout; + + /// Attempt to slice the layout or return an error if the range is invalid + /// for the layout's shape. + fn try_slice( + &self, + range: R, + ) -> Result<(Range, DynLayout), SliceError>; +} + +/// Trait for broadcasting a layout from one shape to another. +pub trait BroadcastLayout { + /// Broadcast the `self` layout to a given shape. + fn broadcast>(&self, shape: S) -> L; +} + +impl BroadcastLayout> for NdLayout { + fn broadcast>>(&self, shape: S) -> NdLayout { + let shape: [usize; M] = shape.as_ref().try_into().unwrap(); + self.broadcast(shape) + } +} + +impl BroadcastLayout for NdLayout { + fn broadcast>(&self, shape: S) -> DynLayout { + let dyn_layout: DynLayout = self.into(); + dyn_layout.broadcast(shape.as_ref()) + } +} + +impl BroadcastLayout for DynLayout { + fn broadcast>(&self, shape: S) -> DynLayout { + self.broadcast(shape.as_ref()) + } +} + +impl BroadcastLayout> for DynLayout { + fn broadcast>>(&self, shape: S) -> NdLayout { + let dyn_broadcast = self.broadcast(shape.as_ref()); + (&dyn_broadcast).try_into().unwrap() + } +} + +impl MutLayout for NdLayout { + fn from_shape(shape: [usize; N]) -> Self { + Self::from_shape(shape) + } + + fn from_shape_and_strides( + shape: Self::Index<'_>, + strides: Self::Index<'_>, + overlap: OverlapPolicy, + ) -> Result { + Self::try_from_shape_and_strides(shape, strides, overlap) + } + + fn move_axis(&mut self, from: usize, to: usize) { + assert!(from < N && to < N); + let mut dyn_layout = self.as_dyn(); + dyn_layout.move_axis(from, to); + *self = NdLayout::try_from(&dyn_layout).unwrap(); + } + + fn permuted(&self, order: [usize; N]) -> NdLayout { + self.permuted(order) + } + + fn resize_dim(&mut self, dim: usize, size: usize) { + self.resize_dim(dim, size) + } + + fn transposed(&self) -> NdLayout { + self.transposed() + } + + fn slice(&self, range: &[SliceItem]) -> (Range, NdLayout) { + self.slice(range) + } + + fn slice_dyn(&self, range: &[SliceItem]) -> (Range, DynLayout) { + self.as_dyn().slice(range) + } + + fn squeezed(&self) -> DynLayout { + self.as_dyn().squeezed() + } + + fn try_slice( + &self, + range: R, + ) -> Result<(Range, DynLayout), SliceError> { + let items = range.into_slice_items(); + self.as_dyn().try_slice(items.as_ref()) + } +} + +impl MutLayout for DynLayout { + fn from_shape(shape: &[usize]) -> Self { + Self::from_shape(shape) + } + + fn from_shape_and_strides( + shape: &[usize], + strides: &[usize], + overlap: OverlapPolicy, + ) -> Result { + Self::try_from_shape_and_strides(shape, strides, overlap) + } + + fn move_axis(&mut self, from: usize, to: usize) { + self.move_axis(from, to) + } + + fn permuted(&self, order: &[usize]) -> DynLayout { + self.permuted(order) + } + + fn resize_dim(&mut self, dim: usize, size: usize) { + self.resize_dim(dim, size) + } + + fn transposed(&self) -> DynLayout { + self.transposed() + } + + fn slice(&self, range: &[SliceItem]) -> (Range, NdLayout) { + let (offset_range, dyn_layout) = self.slice(range); + let nd_layout = NdLayout::try_from(&dyn_layout).unwrap_or_else(|_| { + panic!( + "expected sliced tensor to have {} dims but it has {}", + M, + dyn_layout.ndim() + ); + }); + (offset_range, nd_layout) + } + + fn slice_dyn(&self, range: &[SliceItem]) -> (Range, DynLayout) { + self.slice(range) + } + + fn squeezed(&self) -> DynLayout { + self.squeezed() + } + + fn try_slice( + &self, + range: R, + ) -> Result<(Range, DynLayout), SliceError> { + let items = range.into_slice_items(); + self.try_slice(items.as_ref()) + } +} + +/// Trait for shapes which can be used to create a contiguous layout. +/// +/// This is implemented for `[usize; N]` for creating static-rank layouts from +/// arrays, and `&[usize]` for creating dynamic-rank layouts from slices. +pub trait IntoLayout: AsRef<[usize]> { + /// The type of layout produced from this shape. + type Layout: MutLayout; + + /// Convert this shape into a contiguous layout. + fn into_layout(self) -> Self::Layout; +} + +impl IntoLayout for [usize; N] { + type Layout = NdLayout; + + fn into_layout(self) -> NdLayout { + NdLayout::from_shape(self) + } +} + +impl<'a> IntoLayout for &'a [usize] { + type Layout = DynLayout; + + fn into_layout(self) -> DynLayout { + DynLayout::from_shape(self) + } +} + +/// Trait which extends [MutLayout] with support for changing the number of +/// dimensions in-place. +pub trait ResizeLayout: MutLayout { + /// Insert a size-one axis at the given index in the shape. This will have + /// the same stride as the dimension that follows it. + fn insert_axis(&mut self, index: usize); +} + +impl ResizeLayout for DynLayout { + fn insert_axis(&mut self, index: usize) { + self.insert_dim(index) + } +} + +/// Trait for converting types into indices for use with a given layout. +/// +/// Static-rank tensors can be indexed with `[usize; N]` arrays. Dynamic-rank +/// tensors can be indexed with any type that can be converted to a `&[usize]` +/// slice. +pub trait AsIndex { + /// Convert `self` into an index for use the layout `L`. + fn as_index(&self) -> L::Index<'_>; +} + +impl> AsIndex for T { + fn as_index(&self) -> &[usize] { + self.as_ref() + } +} + +impl AsIndex> for [usize; N] { + fn as_index(&self) -> [usize; N] { + *self + } +} + +impl, L: MutLayout> TensorBase { + /// Construct a new tensor from a given shape and storage. + pub fn from_data(shape: L::Index<'_>, data: S) -> TensorBase { + let layout = L::from_shape(shape); + assert!( + data.as_ref().len() == layout.len(), + "data length does not match shape" + ); + TensorBase { + data, + layout, + element_type: PhantomData, + } + } + + /// Construct a new tensor from a given shape and storage, and custom + /// strides. + /// + /// This will fail if the data length is incorrect for the shape and stride + /// combination, or if the strides lead to overlap (see [OverlapPolicy]). + /// See also [TensorBase::from_slice_with_strides] which is a similar method + /// for immutable views, that does allow overlapping strides. + pub fn from_data_with_strides( + shape: L::Index<'_>, + data: S, + strides: L::Index<'_>, + ) -> Result, FromDataError> { + let layout = L::from_shape_and_strides(shape, strides, OverlapPolicy::DisallowOverlap)?; + if layout.min_data_len() > data.as_ref().len() { + return Err(FromDataError::StorageTooShort); + } + Ok(TensorBase { + data, + layout, + element_type: PhantomData, + }) + } + + /// Convert the current tensor into a dynamic rank tensor without copying + /// any data. + pub fn into_dyn(self) -> TensorBase + where + L: Into, + { + TensorBase { + data: self.data, + layout: self.layout.into(), + element_type: PhantomData, + } + } + + /// Attempt to convert this tensor's layout to a static-rank layout with `N` + /// dimensions. + fn nd_layout(&self) -> Option> { + if self.ndim() != N { + return None; + } + let shape: [usize; N] = std::array::from_fn(|i| self.size(i)); + let strides: [usize; N] = std::array::from_fn(|i| self.stride(i)); + let layout = + NdLayout::try_from_shape_and_strides(shape, strides, OverlapPolicy::AllowOverlap) + .expect("invalid layout"); + Some(layout) + } +} + +impl + AsMut<[T]>, L: MutLayout> TensorBase { + /// Return an iterator over mutable slices of this tensor along a given + /// axis. Each view yielded has one dimension fewer than the current layout. + pub fn axis_iter_mut(&mut self, dim: usize) -> AxisIterMut { + AxisIterMut::new(self.view_mut(), dim) + } + + /// Return an iterator over mutable slices of this tensor along a given + /// axis. Each view yielded has the same rank as this tensor, but the + /// dimension `dim` will only have `chunk_size` entries. + pub fn axis_chunks_mut(&mut self, dim: usize, chunk_size: usize) -> AxisChunksMut { + AxisChunksMut::new(self.view_mut(), dim, chunk_size) + } + + /// Replace each element in this tensor with the result of applying `f` to + /// the element. + pub fn apply T>(&mut self, f: F) { + if self.is_contiguous() { + self.data.as_mut().iter_mut().for_each(|x| *x = f(x)); + } else { + self.iter_mut().for_each(|x| *x = f(x)); + } + } + + /// Return a mutable view of this tensor with a dynamic dimension count. + pub fn as_dyn_mut(&mut self) -> TensorBase + where + L: Clone + Into, + { + TensorBase { + layout: self.layout.clone().into(), + data: self.data.as_mut(), + element_type: PhantomData, + } + } + + /// Copy elements from another tensor into this tensor. + /// + /// This tensor and `other` must have the same shape. + pub fn copy_from>(&mut self, other: &TensorBase) + where + T: Clone, + L: Clone, + { + assert!(self.shape() == other.shape()); + for (out, x) in self.iter_mut().zip(other.iter()) { + *out = x.clone(); + } + } + + /// Return the data in this tensor as a slice if it is contiguous. + pub fn data_mut(&mut self) -> Option<&mut [T]> { + self.layout.is_contiguous().then_some(self.data.as_mut()) + } + + /// Replace all elements of this tensor with `value`. + pub fn fill(&mut self, value: T) + where + T: Clone, + { + self.apply(|_| value.clone()) + } + + /// Return a mutable reference to the element at `index`, or `None` if the + /// index is invalid. + pub fn get_mut>(&mut self, index: I) -> Option<&mut T> { + self.try_offset(index.as_index()) + .map(|offset| &mut self.data.as_mut()[offset]) + } + + /// Return the element at a given index, without performing any bounds- + /// checking. + /// + /// # Safety + /// + /// The caller must ensure that the index is valid for the tensor's shape. + pub unsafe fn get_unchecked_mut>(&mut self, index: I) -> &mut T { + self.data + .as_mut() + .get_unchecked_mut(self.layout.offset_unchecked(index.as_index())) + } + + pub(crate) fn mut_view_ref(&mut self) -> MutViewRef { + MutViewRef::new(self.data.as_mut(), &self.layout) + } + + /// Return a mutable iterator over the N innermost dimensions of this tensor. + pub fn inner_iter_mut(&mut self) -> InnerIterMut { + InnerIterMut::new(self.view_mut()) + } + + /// Return a mutable iterator over the elements of this tensor, in their + /// logical order. + pub fn iter_mut(&mut self) -> IterMut { + IterMut::new(self.mut_view_ref()) + } + + /// Return an iterator over mutable 1D slices of this tensor along a given + /// dimension. + pub fn lanes_mut(&mut self, dim: usize) -> LanesMut { + LanesMut::new(self.mut_view_ref(), dim) + } + + /// Return a view of this tensor with a static dimension count. + /// + /// Panics if `self.ndim() != N`. + pub fn nd_view_mut(&mut self) -> TensorBase> { + assert!(self.ndim() == N, "ndim {} != {}", self.ndim(), N); + TensorBase { + layout: self.nd_layout().unwrap(), + data: self.data.as_mut(), + element_type: PhantomData, + } + } + + /// Permute the order of dimensions according to the given order. + /// + /// See [AsView::permuted]. + pub fn permuted_mut(&mut self, order: L::Index<'_>) -> TensorBase { + TensorBase { + layout: self.layout.permuted(order), + data: self.data.as_mut(), + element_type: PhantomData, + } + } + + /// Change the layout of the tensor without moving any data. + /// + /// See [AsView::reshaped]. + pub fn reshaped_mut( + &mut self, + shape: SH, + ) -> TensorBase { + TensorBase { + layout: self.layout.reshaped(shape), + data: self.data.as_mut(), + element_type: PhantomData, + } + } + + /// Slice this tensor and return a static-rank view with `M` dimensions. + /// + /// Use [AsView::slice_dyn] instead the number of dimensions in the returned + /// view is unknown at compile time. + /// + /// Panics if the dimension count is not `M`. + pub fn slice_mut( + &mut self, + range: R, + ) -> NdTensorViewMut { + let range = range.into_slice_items(); + let (offset_range, sliced_layout) = self.layout.slice(range.as_ref()); + NdTensorViewMut { + data: &mut self.data.as_mut()[offset_range], + layout: sliced_layout, + element_type: PhantomData, + } + } + + /// Slice this tensor and return a dynamic-rank view. + pub fn slice_mut_dyn(&mut self, range: R) -> TensorViewMut { + let range = range.into_slice_items(); + let (offset_range, sliced_layout) = self.layout.slice_dyn(range.as_ref()); + TensorViewMut { + data: &mut self.data.as_mut()[offset_range], + layout: sliced_layout, + element_type: PhantomData, + } + } + + /// Slice this tensor and return a dynamic-rank view. + /// + /// Fails if the range has more dimensions than the view or is out of bounds + /// for any dimension. + pub fn try_slice_mut( + &mut self, + range: R, + ) -> Result, SliceError> { + let (offset_range, layout) = self.layout.try_slice(range)?; + Ok(TensorBase { + data: &mut self.data.as_mut()[offset_range], + layout, + element_type: PhantomData, + }) + } + + /// Return a mutable view of this tensor. + pub fn view_mut(&mut self) -> TensorBase + where + L: Clone, + { + TensorBase { + data: self.data.as_mut(), + layout: self.layout.clone(), + element_type: PhantomData, + } + } + + /// Return a mutable view that performs only "weak" checking when indexing, + /// this is faster but can hide bugs. See [WeaklyCheckedView]. + pub fn weakly_checked_view_mut(&mut self) -> WeaklyCheckedView { + WeaklyCheckedView { + base: self.view_mut(), + } + } +} + +impl TensorBase, L> { + /// Create a new 1D tensor filled with an arithmetic sequence of values + /// in the range `[start, end)` separated by `step`. If `step` is omitted, + /// it defaults to 1. + pub fn arange(start: T, end: T, step: Option) -> TensorBase, L> + where + T: Copy + PartialOrd + From + std::ops::Add, + [usize; 1]: AsIndex, + { + let step = step.unwrap_or((true).into()); + let mut data = Vec::new(); + let mut curr = start; + while curr < end { + data.push(curr); + curr = curr + step; + } + TensorBase::from_data([data.len()].as_index(), data) + } + + /// Create a new 1D tensor from a `Vec`. + pub fn from_vec(vec: Vec) -> TensorBase, L> + where + [usize; 1]: AsIndex, + { + TensorBase::from_data([vec.len()].as_index(), vec) + } + + /// Clip dimension `dim` to `[range.start, range.end)`. The new size for + /// the dimension must be <= the old size. + /// + /// This currently requires `T: Copy` to support efficiently moving data + /// from the new start offset to the beginning of the element buffer. + pub fn clip_dim(&mut self, dim: usize, range: Range) + where + T: Copy, + { + let (start, end) = (range.start, range.end); + + assert!(start <= end, "start must be <= end"); + assert!(end <= self.size(dim), "end must be <= dim size"); + + let start_offset = self.layout.stride(dim) * start; + self.layout.resize_dim(dim, end - start); + + let range = start_offset..start_offset + self.layout.min_data_len(); + self.data.copy_within(range.clone(), 0); + self.data.truncate(range.end - range.start); + } + + /// Consume self and return the underlying data as a contiguous tensor. + /// + /// See also [TensorBase::to_vec]. + pub fn into_data(self) -> Vec + where + T: Clone, + { + if self.is_contiguous() { + self.data + } else { + self.to_vec() + } + } + + /// Consume self and return a new contiguous tensor with the given shape. + /// + /// This avoids copying the data if it is already contiguous. + pub fn into_shape(self, shape: S) -> TensorBase, S::Layout> + where + T: Clone, + { + TensorBase { + data: self.into_data(), + layout: shape.into_layout(), + element_type: PhantomData, + } + } + + /// Create a new 0D tensor from a scalar value. + pub fn from_scalar(value: T) -> TensorBase, L> + where + [usize; 0]: AsIndex, + { + TensorBase::from_data([].as_index(), vec![value]) + } + + /// Create a new tensor with a given shape and all elements set to `value`. + pub fn full(shape: L::Index<'_>, value: T) -> TensorBase, L> + where + T: Clone, + { + let n_elts = shape.as_ref().iter().product(); + let data = vec![value; n_elts]; + TensorBase::from_data(shape, data) + } + + /// Make the underlying data in this tensor contiguous. + pub fn make_contiguous(&mut self) + where + T: Clone, + { + if self.is_contiguous() { + return; + } + self.data = self.to_vec(); + self.layout = L::from_shape(self.layout.shape()); + } + + /// Create a new tensor with a given shape and elements populated using + /// numbers generated by `rand_src`. + pub fn rand>( + shape: L::Index<'_>, + rand_src: &mut R, + ) -> TensorBase, L> { + let data: Vec<_> = std::iter::from_fn(|| Some(rand_src.next())) + .take(shape.as_ref().iter().product()) + .collect(); + TensorBase::from_data(shape, data) + } + + /// Create a new tensor with a given shape, with all elements set to their + /// default value (ie. zero for numeric types). + pub fn zeros(shape: L::Index<'_>) -> TensorBase, L> + where + T: Clone + Default, + { + Self::full(shape, T::default()) + } +} + +impl<'a, T, L: Clone + MutLayout> TensorBase { + pub fn axis_iter(&self, dim: usize) -> AxisIter<'a, T, L> { + AxisIter::new(self, dim) + } + + pub fn axis_chunks(&self, dim: usize, chunk_size: usize) -> AxisChunks<'a, T, L> { + AxisChunks::new(self, dim, chunk_size) + } + + /// Return a view of this tensor with a dynamic dimension count. + /// + /// See [AsView::as_dyn]. + pub fn as_dyn(&self) -> TensorBase + where + L: Clone + Into, + { + TensorBase { + data: self.data, + layout: self.layout.clone().into(), + element_type: PhantomData, + } + } + + /// Broadcast this view to another shape. + /// + /// See [AsView::broadcast]. + pub fn broadcast(&self, shape: S) -> TensorBase + where + L: BroadcastLayout, + { + TensorBase { + layout: self.layout.broadcast(shape), + data: self.data, + element_type: PhantomData, + } + } + + /// Return an iterator over elements as if this tensor was broadcast to + /// another shape. + /// + /// See [AsView::broadcast_iter]. + pub fn broadcast_iter(&self, shape: &[usize]) -> BroadcastIter<'a, T> { + BroadcastIter::new(self.view_ref(), shape) + } + + /// Return the data in this tensor as a slice if it is contiguous, ie. + /// the order of elements in the slice is the same as the logical order + /// yielded by `iter`, and there are no gaps. + pub fn data(&self) -> Option<&'a [T]> { + self.layout.is_contiguous().then_some(self.data) + } + + /// Return this view's underlying data as a slice. + /// + /// Unlike the `data` method, this method does not check if the storage + /// is contiguous in memory (ie. elements are stored in the same order as + /// returned by `iter`, with no gaps). + /// + /// Note there is no safe equivalent of this method for mutable views + /// because this could lead to overlapping mutable slices. + pub fn non_contiguous_data(&self) -> &'a [T] { + self.data + } + + /// Create a new view with a given shape and data slice, and custom strides. + /// + /// If you do not need to specify custom strides, use [TensorBase::from_data] + /// instead. This method is similar to [TensorBase::from_data_with_strides], + /// but allows strides that lead to internal overlap (see [OverlapPolicy]). + pub fn from_slice_with_strides( + shape: L::Index<'_>, + data: &'a [T], + strides: L::Index<'_>, + ) -> Result, FromDataError> { + let layout = L::from_shape_and_strides(shape, strides, OverlapPolicy::AllowOverlap)?; + if layout.min_data_len() > data.as_ref().len() { + return Err(FromDataError::StorageTooShort); + } + Ok(TensorBase { + data, + layout, + element_type: PhantomData, + }) + } + + /// Return the element at a given index, without performing any bounds- + /// checking. + /// + /// # Safety + /// + /// The caller must ensure that the index is valid for the tensor's shape. + pub unsafe fn get_unchecked>(&self, index: I) -> &'a T { + self.data + .get_unchecked(self.layout.offset_unchecked(index.as_index())) + } + + /// Return an iterator over the inner `N` dimensions of this tensor. + /// + /// See [AsView::inner_iter]. + pub fn inner_iter(&self) -> InnerIter<'a, T, L, N> { + InnerIter::new(self.view()) + } + + /// Return the scalar value in this tensor if it has 0 dimensions. + pub fn item(&self) -> Option<&'a T> { + match self.ndim() { + 0 => Some(&self.data[0]), + _ if self.len() == 1 => self.iter().next(), + _ => None, + } + } + + /// Return an iterator over elements of this tensor in their logical order. + /// + /// See [AsView::iter]. + pub fn iter(&self) -> Iter<'a, T> { + Iter::new(self.view_ref()) + } + + /// Return an iterator over 1D slices of this tensor along a given dimension. + /// + /// See [AsView::lanes]. + pub fn lanes(&self, dim: usize) -> Lanes<'a, T> { + Lanes::new(self.view_ref(), dim) + } + + /// Return a view of this tensor with a static dimension count. + /// + /// Panics if `self.ndim() != N`. + pub fn nd_view(&self) -> TensorBase> { + assert!(self.ndim() == N, "ndim {} != {}", self.ndim(), N); + TensorBase { + data: self.data, + layout: self.nd_layout().unwrap(), + element_type: PhantomData, + } + } + + /// Permute the axes of this tensor according to `order`. + /// + /// See [AsView::permuted]. + pub fn permuted(&self, order: L::Index<'_>) -> TensorBase { + TensorBase { + data: self.data, + layout: self.layout.permuted(order), + element_type: PhantomData, + } + } + + /// Change the shape of this tensor without copying data. + /// + /// See [AsView::reshaped]. + pub fn reshaped(&self, shape: S) -> TensorBase { + TensorBase { + data: self.data, + layout: self.layout.reshaped(shape), + element_type: PhantomData, + } + } + + /// Slice this tensor and return a static-rank view. See [AsView::slice]. + pub fn slice(&self, range: R) -> NdTensorView<'a, T, M> { + let range = range.into_slice_items(); + let (offset_range, sliced_layout) = self.layout.slice(range.as_ref()); + NdTensorView { + data: &self.data[offset_range], + layout: sliced_layout, + element_type: PhantomData, + } + } + + /// Slice this tensor and return a dynamic-rank view. See [AsView::slice_dyn]. + pub fn slice_dyn(&self, range: R) -> TensorView<'a, T> { + let range = range.into_slice_items(); + let (offset_range, sliced_layout) = self.layout.slice_dyn(range.as_ref()); + TensorView { + data: &self.data[offset_range], + layout: sliced_layout, + element_type: PhantomData, + } + } + + /// See [AsView::slice_iter]. + pub fn slice_iter(&self, range: &[SliceItem]) -> Iter<'a, T> { + Iter::slice(self.view_ref(), range) + } + + /// Remove all size-one dimensions from this tensor. + /// + /// See [AsView::squeezed]. + pub fn squeezed(&self) -> TensorView<'a, T> { + TensorBase { + data: self.data, + layout: self.layout.squeezed(), + element_type: PhantomData, + } + } + + /// Return a view of this tensor with elements stored in contiguous order. + /// + /// If the data is already contiguous, no copy is made, otherwise the + /// elements are copied into a new buffer in contiguous order. + pub fn to_contiguous(&self) -> TensorBase, L> + where + T: Clone, + { + if self.is_contiguous() { + TensorBase { + data: Cow::Borrowed(self.data), + layout: self.layout.clone(), + element_type: PhantomData, + } + } else { + let data = self.to_vec(); + TensorBase { + data: Cow::Owned(data), + layout: L::from_shape(self.layout.shape()), + element_type: PhantomData, + } + } + } + + /// Reverse the order of dimensions in this tensor. See [AsView::transposed]. + pub fn transposed(&self) -> TensorBase { + TensorBase { + data: self.data, + layout: self.layout.transposed(), + element_type: PhantomData, + } + } + + pub fn try_slice(&self, range: R) -> Result, SliceError> { + let (offset_range, layout) = self.layout.try_slice(range)?; + Ok(TensorBase { + data: &self.data[offset_range], + layout, + element_type: PhantomData, + }) + } + + /// Return a read-only view of this tensor. See [AsView::view]. + pub fn view(&self) -> TensorBase { + TensorBase { + data: self.data, + layout: self.layout.clone(), + element_type: PhantomData, + } + } + + pub(crate) fn view_ref(&self) -> ViewRef<'a, '_, T, L> { + ViewRef::new(self.data, &self.layout) + } + + pub fn weakly_checked_view(&self) -> WeaklyCheckedView { + WeaklyCheckedView { base: self.view() } + } +} + +impl, L: MutLayout> Layout for TensorBase { + type Index<'a> = L::Index<'a>; + type Indices = L::Indices; + + fn ndim(&self) -> usize { + self.layout.ndim() + } + + fn len(&self) -> usize { + self.layout.len() + } + + fn is_empty(&self) -> bool { + self.layout.is_empty() + } + + fn shape(&self) -> Self::Index<'_> { + self.layout.shape() + } + + fn size(&self, dim: usize) -> usize { + self.layout.size(dim) + } + + fn strides(&self) -> Self::Index<'_> { + self.layout.strides() + } + + fn stride(&self, dim: usize) -> usize { + self.layout.stride(dim) + } + + fn indices(&self) -> Self::Indices { + self.layout.indices() + } + + fn try_offset(&self, index: Self::Index<'_>) -> Option { + self.layout.try_offset(index) + } +} + +impl, L: MutLayout + MatrixLayout> MatrixLayout for TensorBase { + fn rows(&self) -> usize { + self.layout.rows() + } + + fn cols(&self) -> usize { + self.layout.cols() + } + + fn row_stride(&self) -> usize { + self.layout.row_stride() + } + + fn col_stride(&self) -> usize { + self.layout.col_stride() + } +} + +impl, L: MutLayout + Clone> AsView for TensorBase { + type Elem = T; + type Layout = L; + + fn iter(&self) -> Iter { + self.view().iter() + } + + fn data(&self) -> Option<&[Self::Elem]> { + self.view().data() + } + + fn insert_axis(&mut self, index: usize) + where + L: ResizeLayout, + { + self.layout.insert_axis(index) + } + + fn layout(&self) -> &L { + &self.layout + } + + fn map(&self, f: F) -> TensorBase, L> + where + F: Fn(&Self::Elem) -> U, + { + let data: Vec<_> = self.iter().map(f).collect(); + TensorBase::from_data(self.shape(), data) + } + + fn move_axis(&mut self, from: usize, to: usize) { + self.layout.move_axis(from, to); + } + + fn view(&self) -> TensorBase { + TensorBase { + data: self.data.as_ref(), + layout: self.layout.clone(), + element_type: PhantomData, + } + } + + fn get>(&self, index: I) -> Option<&Self::Elem> { + self.try_offset(index.as_index()) + .map(|offset| &self.data.as_ref()[offset]) + } + + fn permute(&mut self, order: Self::Index<'_>) { + self.layout = self.layout.permuted(order); + } + + fn to_vec(&self) -> Vec + where + T: Clone, + { + if let Some(data) = self.data() { + data.to_vec() + } else { + // TODO - Add fast path for low rank that doesn't use iterators. + self.view().iter().cloned().collect() + } + } + + fn to_shape( + &self, + shape: SH, + ) -> TensorBase, SH::Layout> + where + T: Clone, + { + TensorBase { + data: self.to_vec(), + layout: shape.into_layout(), + element_type: PhantomData, + } + } + + fn transpose(&mut self) { + self.layout = self.layout.transposed(); + } +} + +impl, const N: usize> TensorBase> { + /// Load an array of `M` elements from successive entries of a tensor along + /// the `dim` axis. + /// + /// eg. If `base` is `[0, 1, 2]`, dim=0 and `M` = 4 this will return an + /// array with values from indices `[0, 1, 2]`, `[1, 1, 2]` ... `[3, 1, 2]`. + /// + /// Panics if any of the array indices are out of bounds. + #[inline] + pub fn get_array(&self, base: [usize; N], dim: usize) -> [T; M] + where + T: Copy + Default, + { + let offsets: [usize; M] = array_offsets(&self.layout, base, dim); + let data = self.data.as_ref(); + let mut result = [T::default(); M]; + for i in 0..M { + // Safety: `array_offsets` returns valid offsets + result[i] = unsafe { *data.get_unchecked(offsets[i]) }; + } + result + } +} + +impl TensorBase, DynLayout> { + /// Reshape this tensor in place. This is cheap if the tensor is contiguous, + /// as only the layout will be changed, but requires copying data otherwise. + pub fn reshape(&mut self, shape: &[usize]) + where + T: Clone, + { + if !self.is_contiguous() { + self.data = self.to_vec(); + } + self.layout = DynLayout::from_shape(shape); + } +} + +impl<'a, T> TensorBase { + /// Reshape this view. + /// + /// Panics if the view is not contiguous. + pub fn reshape(&mut self, shape: &[usize]) + where + T: Clone, + { + assert!(self.is_contiguous(), "can only reshape contiguous views"); + self.layout = DynLayout::from_shape(shape); + } +} + +impl<'a, T> TensorBase { + /// Reshape this view. + /// + /// Panics if the view is not contiguous. + pub fn reshape(&mut self, shape: &[usize]) + where + T: Clone, + { + assert!(self.is_contiguous(), "can only reshape contiguous views"); + self.layout = DynLayout::from_shape(shape); + } +} + +impl FromIterator for TensorBase, L> +where + [usize; 1]: AsIndex, +{ + /// Create a new 1D tensor filled with an arithmetic sequence of values + /// in the range `[start, end)` separated by `step`. If `step` is omitted, + /// it defaults to 1. + fn from_iter>(iter: I) -> TensorBase, L> { + let data: Vec = iter.into_iter().collect(); + TensorBase::from_data([data.len()].as_index(), data) + } +} + +impl From> for TensorBase, L> +where + [usize; 1]: AsIndex, +{ + fn from(vec: Vec) -> Self { + Self::from_data([vec.len()].as_index(), vec) + } +} + +impl<'a, T, L: Clone + MutLayout> From<&'a [T]> for TensorBase +where + [usize; 1]: AsIndex, +{ + fn from(slice: &'a [T]) -> Self { + Self::from_data([slice.len()].as_index(), slice) + } +} + +impl<'a, T, L: Clone + MutLayout, const N: usize> From<&'a [T; N]> for TensorBase +where + [usize; 1]: AsIndex, +{ + fn from(slice: &'a [T; N]) -> Self { + Self::from_data([slice.len()].as_index(), slice.as_slice()) + } +} + +/// Return the offsets of `M` successive elements along the `dim` axis, starting +/// at index `base`. +/// +/// Panics if any of the M element indices are out of bounds. +fn array_offsets( + layout: &NdLayout, + base: [usize; N], + dim: usize, +) -> [usize; M] { + assert!( + base[dim] < usize::MAX - M && layout.size(dim) >= base[dim] + M, + "array indices invalid" + ); + + let offset = layout.offset(base); + let stride = layout.stride(dim); + let mut offsets = [0; M]; + for i in 0..M { + offsets[i] = offset + i * stride; + } + offsets +} + +impl + AsMut<[T]>, const N: usize> TensorBase> { + /// Store an array of `M` elements into successive entries of a tensor along + /// the `dim` axis. + /// + /// See [NdTensorBase::get_array] for more details. + #[inline] + pub fn set_array(&mut self, base: [usize; N], dim: usize, values: [T; M]) + where + T: Copy, + { + let offsets: [usize; M] = array_offsets(&self.layout, base, dim); + let data = self.data.as_mut(); + + for i in 0..M { + // Safety: `array_offsets` returns valid offsets. + unsafe { *data.get_unchecked_mut(offsets[i]) = values[i] }; + } + } +} + +impl> TensorBase> { + /// Convert this vector to a static array of length `M`. + /// + /// Panics if the length of this vector is not M. + #[inline] + pub fn to_array(&self) -> [T; M] + where + T: Copy + Default, + { + self.get_array([0], 0) + } +} + +impl + AsMut<[T]>> TensorBase> { + /// Fill this vector with values from a static array of length `M`. + /// + /// Panics if the length of this vector is not M. + #[inline] + pub fn assign_array(&mut self, values: [T; M]) + where + T: Copy + Default, + { + self.set_array([0], 0, values) + } +} + +/// View of a slice of a tensor with a static dimension count. +pub type NdTensorView<'a, T, const N: usize> = TensorBase>; + +/// Tensor with a static dimension count. +pub type NdTensor = TensorBase, NdLayout>; + +/// Mutable view of a slice of a tensor with a static dimension count. +pub type NdTensorViewMut<'a, T, const N: usize> = TensorBase>; + +/// View of a slice as a matrix. +pub type Matrix<'a, T = f32> = NdTensorView<'a, T, 2>; + +/// Mutable view of a slice as a matrix. +pub type MatrixMut<'a, T = f32> = NdTensorViewMut<'a, T, 2>; + +/// Tensor with a dynamic dimension count. +pub type Tensor = TensorBase, DynLayout>; + +/// View of a slice of a tensor with a dynamic dimension count. +pub type TensorView<'a, T = f32> = TensorBase; + +/// Mutable view of a slice of a tensor with a dynamic dimension count. +pub type TensorViewMut<'a, T = f32> = TensorBase; + +impl, L: MutLayout, I: AsIndex> Index for TensorBase { + type Output = T; + + /// Return the element at a given index. + /// + /// Panics if the index is out of bounds along any dimension. + fn index(&self, index: I) -> &Self::Output { + let offset = self.layout.offset(index.as_index()); + &self.data.as_ref()[offset] + } +} + +impl + AsMut<[T]>, L: MutLayout, I: AsIndex> IndexMut + for TensorBase +{ + /// Return the element at a given index. + /// + /// Panics if the index is out of bounds along any dimension. + fn index_mut(&mut self, index: I) -> &mut Self::Output { + let offset = self.layout.offset(index.as_index()); + &mut self.data.as_mut()[offset] + } +} + +impl + Clone, L: MutLayout + Clone> Clone for TensorBase { + fn clone(&self) -> TensorBase { + let data = self.data.clone(); + TensorBase { + data, + layout: self.layout.clone(), + element_type: PhantomData, + } + } +} + +impl + Copy, L: MutLayout + Copy> Copy for TensorBase {} + +impl, L: MutLayout, V: AsView> PartialEq + for TensorBase +{ + fn eq(&self, other: &V) -> bool { + self.shape().as_ref() == other.shape().as_ref() && self.iter().eq(other.iter()) + } +} + +impl, const N: usize> From>> + for TensorBase +{ + fn from(tensor: TensorBase>) -> Self { + Self { + data: tensor.data, + layout: tensor.layout.into(), + element_type: PhantomData, + } + } +} + +impl, S2: AsRef<[T]>, const N: usize> TryFrom> + for TensorBase> +where + S1: Into, +{ + type Error = DimensionError; + + /// Convert a dynamic-dimensional tensor or view into a static-dimensional one. + /// + /// Fails if `value` does not have `N` dimensions. + fn try_from(value: TensorBase) -> Result { + let layout: NdLayout = value.layout().try_into()?; + Ok(TensorBase { + data: value.data.into(), + layout, + element_type: PhantomData, + }) + } +} + +// Trait for scalar (ie. non-array) values. +// +// This is used as a bound in contexts where we don't want a generic type +// `T` to be inferred as an array type. +pub trait Scalar {} + +impl Scalar for i32 {} +impl Scalar for f32 {} + +// The `T: Scalar` bound avoids ambiguity when choosing a `Tensor::from` +// impl for a nested array literal, as it prevents `T` from matching an array +// type. + +impl From<[T; D0]> for TensorBase, L> +where + [usize; 1]: AsIndex, +{ + /// Construct a 1D tensor from a 1D array. + fn from(value: [T; D0]) -> Self { + Self::from_data([D0].as_index(), value.iter().cloned().collect()) + } +} + +impl From<[[T; D1]; D0]> + for TensorBase, L> +where + [usize; 2]: AsIndex, +{ + /// Construct a 2D tensor from a nested array. + fn from(value: [[T; D1]; D0]) -> Self { + let data: Vec<_> = value.iter().flat_map(|y| y.iter()).cloned().collect(); + Self::from_data([D0, D1].as_index(), data) + } +} + +impl + From<[[[T; D2]; D1]; D0]> for TensorBase, L> +where + [usize; 3]: AsIndex, +{ + /// Construct a 3D tensor from a nested array. + fn from(value: [[[T; D2]; D1]; D0]) -> Self { + let data: Vec<_> = value + .iter() + .flat_map(|y| y.iter().flat_map(|z| z.iter())) + .cloned() + .collect(); + Self::from_data([D0, D1, D2].as_index(), data) + } +} + +/// A view of a tensor which does "weak" checking when indexing via +/// `view[]`. This means that it does not bounds-check individual +/// dimensions, but does bounds-check the computed offset. +/// +/// This offers a middle-ground between regular indexing, which bounds-checks +/// each index element, and unchecked indexing, which does no bounds-checking +/// at all. +pub struct WeaklyCheckedView, L: MutLayout> { + base: TensorBase, +} + +impl, L: MutLayout> Layout for WeaklyCheckedView { + type Index<'a> = L::Index<'a>; + type Indices = L::Indices; + + fn ndim(&self) -> usize { + self.base.ndim() + } + + fn try_offset(&self, index: Self::Index<'_>) -> Option { + self.base.try_offset(index) + } + + fn len(&self) -> usize { + self.base.len() + } + + fn shape(&self) -> Self::Index<'_> { + self.base.shape() + } + + fn strides(&self) -> Self::Index<'_> { + self.base.strides() + } + + fn indices(&self) -> Self::Indices { + self.base.indices() + } +} + +impl, L: MutLayout, I: AsIndex> Index for WeaklyCheckedView { + type Output = T; + fn index(&self, index: I) -> &Self::Output { + &self.base.data.as_ref()[self.base.layout.offset_unchecked(index.as_index())] + } +} + +impl + AsMut<[T]>, L: MutLayout, I: AsIndex> IndexMut + for WeaklyCheckedView +{ + fn index_mut(&mut self, index: I) -> &mut Self::Output { + let offset = self.base.layout.offset_unchecked(index.as_index()); + &mut self.base.data.as_mut()[offset] + } +} + +#[cfg(test)] +mod tests { + use super::{AsView, NdTensor, NdTensorView, Tensor}; + use crate::errors::FromDataError; + use crate::layout::MatrixLayout; + use crate::prelude::*; + use crate::rng::XorShiftRng; + use crate::SliceItem; + + #[test] + fn test_apply() { + let data = vec![1., 2., 3., 4.]; + let mut tensor = NdTensor::from_data([2, 2], data); + tensor.apply(|x| *x * 2.); + assert_eq!(tensor.to_vec(), &[2., 4., 6., 8.]); + } + + #[test] + fn test_arange() { + let x = Tensor::arange(2, 6, None); + let y = NdTensor::arange(2, 6, None); + assert_eq!(x.data(), Some([2, 3, 4, 5].as_slice())); + assert_eq!(y.data(), Some([2, 3, 4, 5].as_slice())); + } + + #[test] + fn test_as_dyn() { + let data = vec![1., 2., 3., 4.]; + let tensor = NdTensor::from_data([2, 2], data); + let dyn_view = tensor.as_dyn(); + assert_eq!(dyn_view.shape(), tensor.shape().as_ref()); + assert_eq!(dyn_view.to_vec(), tensor.to_vec()); + } + + #[test] + fn test_as_dyn_mut() { + let data = vec![1., 2., 3., 4.]; + let mut tensor = NdTensor::from_data([2, 2], data); + let mut dyn_view = tensor.as_dyn_mut(); + + dyn_view[[0, 0]] = 9.; + + assert_eq!(tensor[[0, 0]], 9.); + } + + #[test] + fn test_assign_array() { + let mut tensor = NdTensor::zeros([2, 2]); + let mut transposed = tensor.view_mut(); + + transposed.permute([1, 0]); + transposed.slice_mut(0).assign_array([1, 2]); + transposed.slice_mut(1).assign_array([3, 4]); + + assert_eq!(tensor.iter().copied().collect::>(), [1, 3, 2, 4]); + } + + #[test] + fn test_axis_chunks() { + let tensor = NdTensor::arange(0, 8, None).into_shape([4, 2]); + let mut row_chunks = tensor.axis_chunks(0, 2); + + let chunk = row_chunks.next().unwrap(); + assert_eq!(chunk.shape(), &[2, 2]); + assert_eq!(chunk.to_vec(), &[0, 1, 2, 3]); + + let chunk = row_chunks.next().unwrap(); + assert_eq!(chunk.shape(), &[2, 2]); + assert_eq!(chunk.to_vec(), &[4, 5, 6, 7]); + + assert!(row_chunks.next().is_none()); + } + + #[test] + fn test_axis_chunks_mut() { + let mut tensor = NdTensor::arange(1, 9, None).into_shape([4, 2]); + let mut row_chunks = tensor.axis_chunks_mut(0, 2); + + let mut chunk = row_chunks.next().unwrap(); + chunk.apply(|x| x * 2); + + let mut chunk = row_chunks.next().unwrap(); + chunk.apply(|x| x * -2); + + assert!(row_chunks.next().is_none()); + assert_eq!(tensor.to_vec(), [2, 4, 6, 8, -10, -12, -14, -16]); + } + + #[test] + fn test_axis_iter() { + let tensor = NdTensor::arange(0, 4, None).into_shape([2, 2]); + let mut rows = tensor.axis_iter(0); + + let row = rows.next().unwrap(); + assert_eq!(row.shape(), &[2]); + assert_eq!(row.to_vec(), &[0, 1]); + + let row = rows.next().unwrap(); + assert_eq!(row.shape(), &[2]); + assert_eq!(row.to_vec(), &[2, 3]); + + assert!(rows.next().is_none()); + } + + #[test] + fn test_axis_iter_mut() { + let mut tensor = NdTensor::arange(1, 5, None).into_shape([2, 2]); + let mut rows = tensor.axis_iter_mut(0); + + let mut row = rows.next().unwrap(); + row.apply(|x| x * 2); + + let mut row = rows.next().unwrap(); + row.apply(|x| x * -2); + + assert!(rows.next().is_none()); + assert_eq!(tensor.to_vec(), [2, 4, -6, -8]); + } + + #[test] + fn test_broadcast() { + let data = vec![1., 2., 3., 4.]; + let dest_shape = [3, 1, 2, 2]; + let expected_data: Vec<_> = data.iter().copied().cycle().take(data.len() * 3).collect(); + let ndtensor = NdTensor::from_data([2, 2], data); + + // Broadcast static -> static. + let view = ndtensor.broadcast(dest_shape); + assert_eq!(view.shape(), dest_shape); + assert_eq!(view.to_vec(), expected_data); + + // Broadcast static -> dynamic. + let view = ndtensor.broadcast(dest_shape.as_slice()); + assert_eq!(view.shape(), dest_shape); + assert_eq!(view.to_vec(), expected_data); + + // Broadcast dynamic -> static. + let tensor = ndtensor.as_dyn(); + let view = tensor.broadcast(dest_shape); + assert_eq!(view.shape(), dest_shape); + assert_eq!(view.to_vec(), expected_data); + + // Broadcast dynamic -> dynamic. + let view = tensor.broadcast(dest_shape.as_slice()); + assert_eq!(view.shape(), dest_shape); + assert_eq!(view.to_vec(), expected_data); + } + + #[test] + fn test_broadcast_iter() { + let tensor = NdTensor::from_data([1], vec![3]); + let elems: Vec<_> = tensor.broadcast_iter(&[2, 2]).copied().collect(); + assert_eq!(elems, &[3, 3, 3, 3]); + } + + #[test] + fn test_clip_dim() { + let mut tensor = NdTensor::arange(0, 10, None).into_shape([3, 3]); + tensor.clip_dim(0, 0..3); // No-op + assert_eq!(tensor.shape(), [3, 3]); + + tensor.clip_dim(0, 1..2); // Remove first and last rows + assert_eq!(tensor.shape(), [1, 3]); + assert_eq!(tensor.data(), Some([3, 4, 5].as_slice())); + } + + #[test] + fn test_clone() { + let data = vec![1., 2., 3., 4.]; + let tensor = NdTensor::from_data([2, 2], data); + let cloned = tensor.clone(); + assert_eq!(tensor.shape(), cloned.shape()); + assert_eq!(tensor.to_vec(), cloned.to_vec()); + } + + #[test] + fn test_copy_view() { + let data = vec![1., 2., 3., 4.]; + let view = NdTensorView::from_data([2, 2], &data); + + // Verify that views are copyable, if their layout is. + let view2 = view; + + assert_eq!(view.shape(), view2.shape()); + } + + #[test] + fn test_copy_from() { + let mut dest = Tensor::zeros(&[2, 2]); + let src = Tensor::from_data(&[2, 2], vec![1., 2., 3., 4.]); + dest.copy_from(&src); + assert_eq!(dest.to_vec(), &[1., 2., 3., 4.]); + } + + #[test] + fn test_data() { + let data = vec![1., 2., 3., 4., 5., 6.]; + let tensor = NdTensorView::from_data([2, 3], &data); + assert_eq!(tensor.data(), Some(data.as_slice())); + + let permuted = tensor.permuted([1, 0]); + assert_eq!(permuted.shape(), [3, 2]); + assert_eq!(permuted.data(), None); + } + + #[test] + fn test_data_mut() { + let mut data = vec![1., 2., 3., 4., 5., 6.]; + let mut tensor = NdTensor::from_data([2, 3], data.clone()); + assert_eq!(tensor.data_mut(), Some(data.as_mut_slice())); + + let mut permuted = tensor.permuted_mut([1, 0]); + assert_eq!(permuted.shape(), [3, 2]); + assert_eq!(permuted.data_mut(), None); + } + + #[test] + fn test_fill() { + let data = vec![1., 2., 3., 4.]; + let mut tensor = NdTensor::from_data([2, 2], data); + tensor.fill(9.); + assert_eq!(tensor.to_vec(), &[9., 9., 9., 9.]); + } + + #[test] + fn test_from_nested_array() { + let x = NdTensor::from([1, 2, 3]); + assert_eq!(x.shape(), [3]); + assert_eq!(x.data(), Some([1, 2, 3].as_slice())); + + let x = NdTensor::from([[1, 2], [3, 4]]); + assert_eq!(x.shape(), [2, 2]); + assert_eq!(x.data(), Some([1, 2, 3, 4].as_slice())); + + let x = NdTensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]); + assert_eq!(x.shape(), [2, 2, 2]); + assert_eq!(x.data(), Some([1, 2, 3, 4, 5, 6, 7, 8].as_slice())); + } + + #[test] + fn test_from_vec_or_slice() { + let x = NdTensor::from(vec![1, 2, 3, 4]); + assert_eq!(x.shape(), [4]); + assert_eq!(x.data(), Some([1, 2, 3, 4].as_slice())); + + let x = NdTensorView::from(&[1, 2, 3]); + assert_eq!(x.shape(), [3]); + assert_eq!(x.data(), Some([1, 2, 3].as_slice())); + } + + #[test] + fn test_dyn_tensor_from_nd_tensor() { + let x = NdTensor::from_data([2, 2], vec![1, 2, 3, 4]); + let y: Tensor = x.into(); + assert_eq!(y.data(), Some([1, 2, 3, 4].as_slice())); + assert_eq!(y.shape(), &[2, 2]); + } + + #[test] + fn test_nd_tensor_from_dyn_tensor() { + let x = Tensor::from_data(&[2, 2], vec![1, 2, 3, 4]); + let y: NdTensor = x.try_into().unwrap(); + assert_eq!(y.data(), Some([1, 2, 3, 4].as_slice())); + assert_eq!(y.shape(), [2, 2]); + + let x = Tensor::from_data(&[2, 2], vec![1, 2, 3, 4]); + let y: Result, _> = x.try_into(); + assert!(y.is_err()); + } + + #[test] + fn test_from_data() { + let x = NdTensor::from_data([1, 2, 2], vec![1, 2, 3, 4]); + assert_eq!(x.shape(), [1, 2, 2]); + assert_eq!(x.strides(), [4, 2, 1]); + assert_eq!(x.to_vec(), [1, 2, 3, 4]); + } + + #[test] + #[should_panic(expected = "data length does not match shape")] + fn test_from_data_shape_mismatch() { + NdTensor::from_data([2, 2, 2], vec![1, 2, 3, 4]); + } + + #[test] + fn test_from_data_with_strides() { + let x = NdTensor::from_data_with_strides([2, 2, 1], vec![1, 2, 3, 4], [1, 2, 4]).unwrap(); + assert_eq!(x.shape(), [2, 2, 1]); + assert_eq!(x.strides(), [1, 2, 4]); + assert_eq!(x.to_vec(), [1, 3, 2, 4]); + + // Invalid (wrong storage length) + let x = NdTensor::from_data_with_strides([2, 2, 2], vec![1, 2, 3, 4], [1, 2, 4]); + assert_eq!(x, Err(FromDataError::StorageTooShort)); + + // Invalid strides (overlapping) + let x = NdTensor::from_data_with_strides([2, 2], vec![1, 2], [0, 1]); + assert_eq!(x, Err(FromDataError::MayOverlap)); + } + + #[test] + fn test_from_slice_with_strides() { + // The strides here are overlapping, but `from_slice_with_strides` + // allows this since it is a read-only view. + let data = [1, 2]; + let x = NdTensorView::from_slice_with_strides([2, 2], &data, [0, 1]).unwrap(); + assert_eq!(x.to_vec(), [1, 2, 1, 2]); + } + + #[test] + fn test_from_iter() { + let x: Tensor = [1., 2., 3., 4.].into_iter().collect(); + assert_eq!(x.shape(), &[4]); + assert_eq!(x.data(), Some([1., 2., 3., 4.].as_slice())); + + let y: NdTensor<_, 1> = [1., 2., 3., 4.].into_iter().collect(); + assert_eq!(y.shape(), [4]); + assert_eq!(y.data(), Some([1., 2., 3., 4.].as_slice())); + } + + #[test] + fn test_from_scalar() { + let x = Tensor::from_scalar(5.); + let y = NdTensor::from_scalar(6.); + assert_eq!(x.item(), Some(&5.)); + assert_eq!(y.item(), Some(&6.)); + } + + #[test] + fn test_from_vec() { + let x = NdTensor::from_vec(vec![1, 2, 3, 4]); + assert_eq!(x.shape(), [4]); + assert_eq!(x.data(), Some([1, 2, 3, 4].as_slice())); + } + + #[test] + fn test_full() { + let tensor = NdTensor::full([2, 2], 2.); + assert_eq!(tensor.shape(), [2, 2]); + assert_eq!(tensor.data(), Some([2., 2., 2., 2.].as_slice())); + } + + #[test] + fn test_get() { + // NdLayout + let data = vec![1., 2., 3., 4.]; + let tensor: NdTensor = NdTensor::from_data([2, 2], data); + assert_eq!(tensor.get([1, 1]), Some(&4.)); + assert_eq!(tensor.get([2, 1]), None); + + // DynLayout + let data = vec![1., 2., 3., 4.]; + let tensor: Tensor = Tensor::from_data(&[2, 2], data); + assert_eq!(tensor.get([1, 1]), Some(&4.)); + assert_eq!(tensor.get([2, 1]), None); // Invalid index + assert_eq!(tensor.get([1, 2, 3]), None); // Incorrect dim count + } + + #[test] + fn test_get_array() { + let tensor = NdTensor::arange(1, 17, None).into_shape([4, 2, 2]); + + // First dim, zero base. + let values: [i32; 4] = tensor.get_array([0, 0, 0], 0); + assert_eq!(values, [1, 5, 9, 13]); + + // First dim, different base. + let values: [i32; 4] = tensor.get_array([0, 1, 1], 0); + assert_eq!(values, [4, 8, 12, 16]); + + // Last dim, zero base. + let values: [i32; 2] = tensor.get_array([0, 0, 0], 2); + assert_eq!(values, [1, 2]); + } + + #[test] + fn test_get_mut() { + let data = vec![1., 2., 3., 4.]; + let mut tensor: NdTensor = NdTensor::from_data([2, 2], data); + if let Some(elem) = tensor.get_mut([1, 1]) { + *elem = 9.; + } + assert_eq!(tensor[[1, 1]], 9.); + assert_eq!(tensor.get_mut([2, 1]), None); + } + + #[test] + fn test_get_unchecked() { + let ndtensor = NdTensor::arange(1, 5, None); + for i in 0..ndtensor.size(0) { + assert_eq!( + unsafe { ndtensor.view().get_unchecked([i]) }, + &ndtensor[[i]] + ); + } + + let tensor = Tensor::arange(1, 5, None); + for i in 0..tensor.size(0) { + assert_eq!(unsafe { tensor.view().get_unchecked([i]) }, &ndtensor[[i]]); + } + } + + #[test] + fn test_get_unchecked_mut() { + let mut ndtensor = NdTensor::arange(1, 5, None); + for i in 0..ndtensor.size(0) { + unsafe { *ndtensor.get_unchecked_mut([i]) += 1 } + } + assert_eq!(ndtensor.to_vec(), &[2, 3, 4, 5]); + + let mut tensor = Tensor::arange(1, 5, None); + for i in 0..tensor.size(0) { + unsafe { *tensor.get_unchecked_mut([i]) += 1 } + } + assert_eq!(tensor.to_vec(), &[2, 3, 4, 5]); + } + + #[test] + fn test_index_and_index_mut() { + // NdLayout + let data = vec![1., 2., 3., 4.]; + let mut tensor: NdTensor = NdTensor::from_data([2, 2], data); + assert_eq!(tensor[[1, 1]], 4.); + tensor[[1, 1]] = 9.; + assert_eq!(tensor[[1, 1]], 9.); + + // DynLayout + let data = vec![1., 2., 3., 4.]; + let mut tensor: Tensor = Tensor::from_data(&[2, 2], data); + assert_eq!(tensor[[1, 1]], 4.); + tensor[&[1, 1]] = 9.; + assert_eq!(tensor[[1, 1]], 9.); + } + + #[test] + fn test_into_data() { + let tensor = NdTensor::from_data([2], vec![2., 3.]); + assert_eq!(tensor.into_data(), vec![2., 3.]); + } + + #[test] + fn test_into_dyn() { + let tensor = NdTensor::from_data([2, 2], vec![1., 2., 3., 4.]); + let dyn_tensor = tensor.into_dyn(); + assert_eq!(dyn_tensor.shape(), &[2, 2]); + assert_eq!(dyn_tensor.data(), Some([1., 2., 3., 4.].as_slice())); + } + + #[test] + fn test_into_shape() { + // Contiguous tensor. + let tensor = NdTensor::from_data([2, 2], vec![1., 2., 3., 4.]); + let reshaped = tensor.into_shape([4]); + assert_eq!(reshaped.shape(), [4]); + assert_eq!(reshaped.data(), Some([1., 2., 3., 4.].as_slice())); + + // Non-contiguous tensor. + let mut tensor = NdTensor::from_data([2, 2], vec![1., 2., 3., 4.]); + tensor.transpose(); + let reshaped = tensor.into_shape([4]); + assert_eq!(reshaped.shape(), [4]); + assert_eq!(reshaped.data(), Some([1., 3., 2., 4.].as_slice())); + } + + #[test] + fn test_inner_iter() { + let tensor = Tensor::from_data(&[2, 2], vec![1, 2, 3, 4]); + let mut rows = tensor.inner_iter::<1>(); + + let row = rows.next().unwrap(); + assert_eq!(row.shape(), [2]); + assert_eq!(row.to_vec(), &[1, 2]); + + let row = rows.next().unwrap(); + assert_eq!(row.shape(), [2]); + assert_eq!(row.to_vec(), &[3, 4]); + + assert_eq!(rows.next(), None); + } + + #[test] + fn test_inner_iter_mut() { + let mut tensor = Tensor::from_data(&[2, 2], vec![1, 2, 3, 4]); + let mut rows = tensor.inner_iter_mut::<1>(); + + let mut row = rows.next().unwrap(); + assert_eq!(row.shape(), [2]); + row.apply(|x| x * 2); + + let mut row = rows.next().unwrap(); + assert_eq!(row.shape(), [2]); + row.apply(|x| x * 2); + + assert_eq!(rows.next(), None); + + assert_eq!(tensor.to_vec(), &[2, 4, 6, 8]); + } + + #[test] + fn test_insert_axis() { + let mut tensor = Tensor::from_data(&[2, 2], vec![1, 2, 3, 4]); + tensor.insert_axis(0); + assert_eq!(tensor.shape(), &[1, 2, 2]); + tensor.insert_axis(3); + assert_eq!(tensor.shape(), &[1, 2, 2, 1]); + } + + #[test] + fn test_item() { + let tensor = NdTensor::from_data([], vec![5.]); + assert_eq!(tensor.item(), Some(&5.)); + let tensor = NdTensor::from_data([2], vec![2., 3.]); + assert_eq!(tensor.item(), None); + + let tensor = Tensor::from_data(&[], vec![5.]); + assert_eq!(tensor.item(), Some(&5.)); + let tensor = Tensor::from_data(&[2], vec![2., 3.]); + assert_eq!(tensor.item(), None); + } + + #[test] + fn test_iter() { + let data = vec![1., 2., 3., 4.]; + let tensor = NdTensor::from_data([2, 2], data); + assert_eq!( + tensor.iter().copied().collect::>(), + &[1., 2., 3., 4.] + ); + + let data = vec![1., 2., 3., 4.]; + let tensor = Tensor::from_data(&[2, 2], data); + assert_eq!( + tensor.iter().copied().collect::>(), + &[1., 2., 3., 4.] + ); + } + + #[test] + fn test_iter_mut() { + let data = vec![1., 2., 3., 4.]; + let mut tensor = NdTensor::from_data([2, 2], data); + tensor.iter_mut().for_each(|x| *x *= 2.); + assert_eq!(tensor.to_vec(), &[2., 4., 6., 8.]); + } + + #[test] + fn test_lanes() { + let data = vec![1., 2., 3., 4.]; + let tensor = NdTensor::from_data([2, 2], data); + let mut lanes = tensor.lanes(1); + assert_eq!( + lanes.next().unwrap().copied().collect::>(), + &[1., 2.] + ); + assert_eq!( + lanes.next().unwrap().copied().collect::>(), + &[3., 4.] + ); + } + + #[test] + fn test_lanes_mut() { + let data = vec![1., 2., 3., 4.]; + let mut tensor = NdTensor::from_data([2, 2], data); + let mut lanes = tensor.lanes_mut(1); + assert_eq!(lanes.next().unwrap().collect::>(), &[&1., &2.]); + assert_eq!(lanes.next().unwrap().collect::>(), &[&3., &4.]); + } + + #[test] + fn test_make_contiguous() { + let mut tensor = NdTensor::from_data([2, 2], vec![1., 2., 3., 4.]); + assert!(tensor.is_contiguous()); + + // No-op, since tensor is already contiguous. + tensor.make_contiguous(); + assert!(tensor.is_contiguous()); + + // On a non-contiguous tensor, the data should be shuffled. + tensor.transpose(); + assert!(!tensor.is_contiguous()); + tensor.make_contiguous(); + assert!(tensor.is_contiguous()); + assert_eq!(tensor.data(), Some([1., 3., 2., 4.].as_slice())); + } + + #[test] + fn test_map() { + let data = vec![1., 2., 3., 4.]; + let tensor = NdTensor::from_data([2, 2], data); + let doubled = tensor.map(|x| x * 2.); + assert_eq!(doubled.to_vec(), &[2., 4., 6., 8.]); + } + + #[test] + fn test_matrix_layout() { + let data = vec![1., 2., 3., 4., 5., 6.]; + let tensor = NdTensorView::from_data([2, 3], &data); + assert_eq!(tensor.rows(), 2); + assert_eq!(tensor.row_stride(), 3); + assert_eq!(tensor.cols(), 3); + assert_eq!(tensor.col_stride(), 1); + } + + #[test] + fn test_move_axis() { + let data = vec![1., 2., 3., 4., 5., 6.]; + let mut tensor = NdTensorView::from_data([2, 3], &data); + + tensor.move_axis(1, 0); + assert_eq!(tensor.shape(), [3, 2]); + assert_eq!(tensor.to_vec(), &[1., 4., 2., 5., 3., 6.]); + + tensor.move_axis(0, 1); + assert_eq!(tensor.shape(), [2, 3]); + assert_eq!(tensor.to_vec(), &[1., 2., 3., 4., 5., 6.]); + } + + #[test] + fn test_nd_view() { + let tensor: Tensor = Tensor::zeros(&[1, 4, 5]); + + // Dynamic -> static rank conversion. + let nd_view = tensor.nd_view::<3>(); + assert_eq!(nd_view.shape(), [1, 4, 5]); + assert_eq!(nd_view.strides().as_ref(), tensor.strides()); + + // Static -> static rank conversion. Pointless, but it should compile. + let nd_view_2 = nd_view.nd_view::<3>(); + assert_eq!(nd_view_2.shape(), nd_view.shape()); + } + + #[test] + fn test_nd_view_mut() { + let mut tensor: Tensor = Tensor::zeros(&[1, 4, 5]); + let mut nd_view = tensor.nd_view_mut::<3>(); + assert_eq!(nd_view.shape(), [1, 4, 5]); + + nd_view[[0, 0, 0]] = 9.; + + assert_eq!(tensor[[0, 0, 0]], 9.); + } + + #[test] + fn test_non_contiguous_data() { + let mut tensor = NdTensor::from_data([2, 2], vec![1, 2, 3, 4]); + assert_eq!(tensor.data(), Some(tensor.view().non_contiguous_data())); + + tensor.transpose(); + + assert!(tensor.data().is_none()); + assert_eq!(tensor.view().non_contiguous_data(), [1, 2, 3, 4]); + } + + #[test] + fn test_rand() { + let mut rng = XorShiftRng::new(1234); + let tensor = NdTensor::rand([2, 2], &mut rng); + assert_eq!(tensor.shape(), [2, 2]); + for &x in tensor.iter() { + assert!(x >= 0. && x <= 1.); + } + } + + #[test] + fn test_permute() { + let data = vec![1., 2., 3., 4., 5., 6.]; + let mut tensor = NdTensorView::from_data([2, 3], &data); + + tensor.permute([1, 0]); + + assert_eq!(tensor.shape(), [3, 2]); + assert_eq!(tensor.to_vec(), &[1., 4., 2., 5., 3., 6.]); + } + + #[test] + fn test_permuted() { + let data = vec![1., 2., 3., 4., 5., 6.]; + let tensor = NdTensorView::from_data([2, 3], &data); + + let permuted = tensor.permuted([1, 0]); + + assert_eq!(permuted.shape(), [3, 2]); + assert_eq!(permuted.to_vec(), &[1., 4., 2., 5., 3., 6.]); + } + + #[test] + fn test_permuted_mut() { + let data = vec![1., 2., 3., 4., 5., 6.]; + let mut tensor = NdTensor::from_data([2, 3], data); + + let mut permuted = tensor.permuted_mut([1, 0]); + permuted[[2, 1]] = 8.; + + assert_eq!(permuted.shape(), [3, 2]); + assert_eq!(permuted.to_vec(), &[1., 4., 2., 5., 3., 8.]); + } + + #[test] + fn test_reshape() { + // Owned tensor + let mut tensor = Tensor::::from_data(&[2, 2], vec![1., 2., 3., 4.]); + tensor.transpose(); + tensor.reshape(&[4]); + assert_eq!(tensor.shape(), &[4]); + assert_eq!(tensor.to_vec(), &[1., 3., 2., 4.]); + + // View + let mut view = tensor.view(); + view.reshape(&[2, 2]); + assert_eq!(view.shape(), &[2, 2]); + + // Mut view + let mut view_mut = tensor.view_mut(); + view_mut.reshape(&[2, 2]); + assert_eq!(view_mut.shape(), &[2, 2]); + } + + #[test] + fn test_reshaped() { + let data = vec![1., 2., 3., 4., 5., 6.]; + let tensor = NdTensorView::from_data([1, 1, 2, 1, 3], &data); + + // Reshape to static dim count + let reshaped = tensor.reshaped([6]); + assert_eq!(reshaped.shape(), [6]); + + // Reshape to dynamic dim count + let reshaped = tensor.reshaped([6].as_slice()); + assert_eq!(reshaped.shape(), &[6]); + } + + #[test] + fn test_reshaped_mut() { + let data = vec![1., 2., 3., 4., 5., 6.]; + let mut tensor = NdTensor::from_data([1, 1, 2, 1, 3], data); + + let mut reshaped = tensor.reshaped_mut([6]); + reshaped[[0]] = 0.; + reshaped[[5]] = 0.; + + assert_eq!(tensor.data(), Some([0., 2., 3., 4., 5., 0.].as_slice())); + } + + #[test] + fn test_set_array() { + let mut tensor = NdTensor::arange(1, 17, None).into_shape([4, 2, 2]); + tensor.set_array([0, 0, 0], 0, [-1, -2, -3, -4]); + assert_eq!( + tensor.iter().copied().collect::>(), + &[-1, 2, 3, 4, -2, 6, 7, 8, -3, 10, 11, 12, -4, 14, 15, 16] + ); + } + + #[test] + fn test_slice_with_ndlayout() { + let data = vec![1., 2., 3., 4.]; + let tensor = NdTensor::from_data([2, 2], data); + + let row_one = tensor.slice(0); + assert_eq!(row_one[[0]], 1.); + assert_eq!(row_one[[1]], 2.); + + let row_two = tensor.slice(1); + assert_eq!(row_two[[0]], 3.); + assert_eq!(row_two[[1]], 4.); + } + + #[test] + fn test_slice_dyn_with_ndlayout() { + let data = vec![1., 2., 3., 4.]; + let tensor = NdTensor::from_data([2, 2], data); + + let row_one = tensor.slice_dyn(0); + assert_eq!(row_one[[0]], 1.); + assert_eq!(row_one[[1]], 2.); + + let row_two = tensor.slice_dyn(1); + assert_eq!(row_two[[0]], 3.); + assert_eq!(row_two[[1]], 4.); + } + + #[test] + fn test_slice_with_dynlayout() { + let data = vec![1., 2., 3., 4.]; + let tensor = Tensor::from_data(&[2, 2], data); + + let row_one = tensor.slice(0); + assert_eq!(row_one[[0]], 1.); + assert_eq!(row_one[[1]], 2.); + + let row_two = tensor.slice(1); + assert_eq!(row_two[[0]], 3.); + assert_eq!(row_two[[1]], 4.); + } + + #[test] + fn test_slice_dyn_with_dynlayout() { + let data = vec![1., 2., 3., 4.]; + let tensor = Tensor::from_data(&[2, 2], data); + + let row_one = tensor.slice_dyn(0); + assert_eq!(row_one[[0]], 1.); + assert_eq!(row_one[[1]], 2.); + + let row_two = tensor.slice_dyn(1); + assert_eq!(row_two[[0]], 3.); + assert_eq!(row_two[[1]], 4.); + } + + #[test] + fn test_slice_iter() { + let data = vec![1., 2., 3., 4.]; + let tensor = Tensor::from_data(&[2, 2], data); + let row_one: Vec<_> = tensor + .slice_iter(&[SliceItem::Index(0), SliceItem::full_range()]) + .copied() + .collect(); + assert_eq!(row_one, &[1., 2.]); + } + + #[test] + fn test_slice_mut() { + let data = vec![1., 2., 3., 4.]; + let mut tensor = NdTensor::from_data([2, 2], data); + + let mut row = tensor.slice_mut(1); + row[[0]] = 8.; + row[[1]] = 9.; + + assert_eq!(tensor.to_vec(), &[1., 2., 8., 9.]); + } + + #[test] + fn test_slice_mut_dyn() { + let data = vec![1., 2., 3., 4.]; + let mut tensor = NdTensor::from_data([2, 2], data); + + let mut row = tensor.slice_mut_dyn(1); + row[[0]] = 8.; + row[[1]] = 9.; + + assert_eq!(tensor.to_vec(), &[1., 2., 8., 9.]); + } + + #[test] + fn test_squeezed() { + let data = vec![1., 2., 3., 4., 5., 6.]; + let tensor = NdTensorView::from_data([1, 1, 2, 1, 3], &data); + + let squeezed = tensor.squeezed(); + + assert_eq!(squeezed.shape(), &[2, 3]); + } + + #[test] + fn test_to_array() { + let tensor = NdTensor::arange(1., 5., None).into_shape([2, 2]); + let col0: [f32; 2] = tensor.view().transposed().slice::<1, _>(0).to_array(); + let col1: [f32; 2] = tensor.view().transposed().slice::<1, _>(1).to_array(); + assert_eq!(col0, [1., 3.]); + assert_eq!(col1, [2., 4.]); + } + + #[test] + fn test_to_contiguous() { + let data = vec![1., 2., 3., 4.]; + let tensor = NdTensor::from_data([2, 2], data); + // TODO - Try both contiguous and non-contiguous tensors. + let tensor = tensor.to_contiguous(); + // TODO - Check the actual storage from to_contiguous + assert_eq!(tensor.to_vec(), &[1., 2., 3., 4.]); + } + + #[test] + fn test_to_shape() { + let tensor = NdTensor::from_data([2, 2], vec![1, 2, 3, 4]); + let flat = tensor.to_shape([4]); + assert_eq!(flat.shape(), [4]); + assert_eq!(flat.data(), Some([1, 2, 3, 4].as_slice())); + } + + #[test] + fn test_to_vec() { + // Contiguous case + let tensor = NdTensor::from_data([2, 2], vec![1, 2, 3, 4]); + assert_eq!(tensor.to_vec(), &[1, 2, 3, 4]); + + // Non-contiguous case + let mut tensor = tensor.clone(); + tensor.transpose(); + assert_eq!(tensor.to_vec(), &[1, 3, 2, 4]); + } + + #[test] + fn test_to_tensor() { + let data = vec![1., 2., 3., 4.]; + let view = NdTensorView::from_data([2, 2], &data); + let tensor = view.to_tensor(); + assert_eq!(tensor.shape(), view.shape()); + assert_eq!(tensor.to_vec(), view.to_vec()); + } + + #[test] + fn test_transpose() { + let data = vec![1., 2., 3., 4., 5., 6.]; + let mut tensor = NdTensorView::from_data([2, 3], &data); + + tensor.transpose(); + + assert_eq!(tensor.shape(), [3, 2]); + assert_eq!(tensor.to_vec(), &[1., 4., 2., 5., 3., 6.]); + } + + #[test] + fn test_transposed() { + let data = vec![1., 2., 3., 4., 5., 6.]; + let tensor = NdTensorView::from_data([2, 3], &data); + + let permuted = tensor.transposed(); + + assert_eq!(permuted.shape(), [3, 2]); + assert_eq!(permuted.to_vec(), &[1., 4., 2., 5., 3., 6.]); + } + + #[test] + fn test_try_slice() { + let data = vec![1., 2., 3., 4.]; + let tensor = Tensor::from_data(&[2, 2], data); + + let row = tensor.try_slice(0); + assert!(row.is_ok()); + assert_eq!(row.unwrap().data(), Some([1., 2.].as_slice())); + + let row = tensor.try_slice(1); + assert!(row.is_ok()); + + let row = tensor.try_slice(2); + assert!(row.is_err()); + } + + #[test] + fn test_try_slice_mut() { + let data = vec![1., 2., 3., 4.]; + let mut tensor = Tensor::from_data(&[2, 2], data); + + let mut row = tensor.try_slice_mut(0).unwrap(); + row[[0]] += 1.; + row[[1]] += 1.; + assert_eq!(row.data(), Some([2., 3.].as_slice())); + + let row = tensor.try_slice_mut(1); + assert!(row.is_ok()); + + let row = tensor.try_slice(2); + assert!(row.is_err()); + } + + #[test] + fn test_view() { + let tensor = NdTensor::from_data([2, 2], vec![1, 2, 3, 4]); + let view = tensor.view(); + assert_eq!(view.data(), Some([1, 2, 3, 4].as_slice())); + } + + #[test] + fn test_view_mut() { + let mut tensor = NdTensor::from_data([2, 2], vec![1, 2, 3, 4]); + let mut view = tensor.view_mut(); + view[[0, 0]] = 0; + view[[1, 1]] = 0; + assert_eq!(tensor.data(), Some([0, 2, 3, 0].as_slice())); + } + + #[test] + fn test_weakly_checked_view() { + let tensor = NdTensor::from_data([2, 2], vec![1, 2, 3, 4]); + let view = tensor.weakly_checked_view(); + + // Valid indexing should work the same as a normal view. + for y in 0..tensor.size(0) { + for x in 0..tensor.size(1) { + assert_eq!(view[[y, x]], tensor[[y, x]]); + } + } + + // Indexes that are invalid, but lead to an in-bounds offset, won't + // trigger a panic, unlike a normal view. + assert_eq!(view[[0, 2]], 3); + } + + #[test] + fn test_weakly_checked_view_mut() { + let mut tensor = NdTensor::from_data([2, 2], vec![1, 2, 3, 4]); + let mut view = tensor.weakly_checked_view_mut(); + + // Valid indices + view[[0, 0]] = 5; + view[[1, 1]] = 6; + + // Indices that are invalid, but lead to an in-bounds offset, won't + // trigger a panic, unlike a normal view. + view[[0, 2]] = 7; + + assert_eq!(tensor.data(), Some([5, 2, 7, 6].as_slice())); + } + + #[test] + fn test_zeros() { + let tensor = NdTensor::zeros([2, 2]); + assert_eq!(tensor.shape(), [2, 2]); + assert_eq!(tensor.data(), Some([0, 0, 0, 0].as_slice())); + } +} diff --git a/rten-tensor/src/unified_tensor/iterators.rs b/rten-tensor/src/unified_tensor/iterators.rs new file mode 100644 index 00000000..f0e23e57 --- /dev/null +++ b/rten-tensor/src/unified_tensor/iterators.rs @@ -0,0 +1,247 @@ +use std::ops::Add; + +use crate::index_iterator::DynIndices; +use crate::layout::Layout; +use crate::range::to_slice_items; + +use super::{ + AsView, MutLayout, NdTensorView, NdTensorViewMut, TensorBase, TensorView, TensorViewMut, +}; + +/// Iterator over views of the N innermost dimensions of a tensor with element +/// type `T` and layout `L`. +pub struct InnerIter<'a, T, L: MutLayout, const N: usize> { + outer_indices: DynIndices, + view: TensorBase, +} + +impl<'a, T, L: MutLayout, const N: usize> InnerIter<'a, T, L, N> { + pub fn new(view: TensorBase) -> Self { + assert!(view.ndim() >= N); + let outer_dims = view.ndim() - N; + let outer_indices = DynIndices::from_shape(&view.shape().as_ref()[..outer_dims]); + InnerIter { + outer_indices, + view, + } + } +} + +impl<'a, T, L: MutLayout, const N: usize> Iterator for InnerIter<'a, T, L, N> { + type Item = NdTensorView<'a, T, N>; + + fn next(&mut self) -> Option { + self.outer_indices.next().map(|idx| { + let slice_items = to_slice_items(&idx); + self.view.slice(slice_items.as_slice()) + }) + } + + fn size_hint(&self) -> (usize, Option) { + self.outer_indices.size_hint() + } +} + +impl<'a, T, L: MutLayout, const N: usize> ExactSizeIterator for InnerIter<'a, T, L, N> {} + +/// Iterator over mutable views of the N innermost dimensions of a tensor. +pub struct InnerIterMut<'a, T, L: MutLayout, const N: usize> { + outer_indices: DynIndices, + view: TensorBase, +} + +impl<'a, T, L: MutLayout, const N: usize> InnerIterMut<'a, T, L, N> { + pub fn new(view: TensorBase) -> Self { + assert!(view.ndim() >= N); + let outer_dims = view.ndim() - N; + let outer_indices = DynIndices::from_shape(&view.shape().as_ref()[..outer_dims]); + InnerIterMut { + outer_indices, + view, + } + } +} + +impl<'a, T, L: MutLayout, const N: usize> Iterator for InnerIterMut<'a, T, L, N> { + type Item = NdTensorViewMut<'a, T, N>; + + fn next(&mut self) -> Option { + self.outer_indices.next().map(|idx| { + let slice_items = to_slice_items(&idx); + let view: NdTensorViewMut<'_, T, N> = self.view.slice_mut(slice_items.as_slice()); + unsafe { + // Safety: Outer view is non-broadcasting, and we increment the + // outer index each time, so returned views will not overlap. + std::mem::transmute::, NdTensorViewMut<'a, T, N>>(view) + } + }) + } + + fn size_hint(&self) -> (usize, Option) { + self.outer_indices.size_hint() + } +} + +impl<'a, T, L: MutLayout, const N: usize> ExactSizeIterator for InnerIterMut<'a, T, L, N> {} + +/// Iterator over slices of a tensor along an axis. See [TensorView::axis_iter]. +pub struct AxisIter<'a, T, L: MutLayout> { + view: TensorBase, + index: usize, +} + +impl<'a, T, L: MutLayout> AxisIter<'a, T, L> { + pub fn new(view: &TensorBase, dim: usize) -> AxisIter<'a, T, L> { + let mut permuted = view.clone(); + permuted.move_axis(dim, 0); + AxisIter { + view: permuted, + index: 0, + } + } +} + +impl<'a, T, L: MutLayout> Iterator for AxisIter<'a, T, L> { + type Item = TensorView<'a, T>; + + fn next(&mut self) -> Option { + if self.index >= self.view.size(0) { + None + } else { + let view = self.view.slice_dyn([self.index]); + self.index += 1; + Some(view) + } + } +} + +/// Iterator over mutable slices of a tensor along an axis. See [TensorViewMut::axis_iter_mut]. +pub struct AxisIterMut<'a, T, L: MutLayout> { + view: TensorBase, + index: usize, +} + +impl<'a, T, L: MutLayout> AxisIterMut<'a, T, L> { + pub fn new(mut view: TensorBase, dim: usize) -> AxisIterMut<'a, T, L> { + // See notes in `Layout` about internal overlap. + assert!( + !view.layout().is_broadcast(), + "Cannot mutably iterate over broadcasting view" + ); + view.move_axis(dim, 0); + AxisIterMut { view, index: 0 } + } +} + +impl<'a, T, L: MutLayout> Iterator for AxisIterMut<'a, T, L> { + type Item = TensorViewMut<'a, T>; + + fn next(&mut self) -> Option { + if self.index >= self.view.size(0) { + None + } else { + let index = self.index; + self.index += 1; + + // Safety: This is non-broadcasting view, and we increment the index + // each time, so returned views will not overlap. + let view = unsafe { + let view = self.view.slice_mut_dyn([index]); + std::mem::transmute::, TensorViewMut<'a, T>>(view) + }; + Some(view) + } + } +} + +/// Iterator over slices of a tensor along an axis. See [TensorView::axis_chunks]. +pub struct AxisChunks<'a, T, L: MutLayout> { + view: TensorBase, + index: usize, + chunk_size: usize, +} + +impl<'a, T, L: MutLayout> AxisChunks<'a, T, L> { + pub fn new( + view: &TensorBase, + dim: usize, + chunk_size: usize, + ) -> AxisChunks<'a, T, L> { + let mut permuted = view.clone(); + permuted.move_axis(dim, 0); + AxisChunks { + view: permuted, + index: 0, + chunk_size, + } + } +} + +impl<'a, T, L: MutLayout> Iterator for AxisChunks<'a, T, L> { + type Item = TensorView<'a, T>; + + fn next(&mut self) -> Option { + let size = self.view.size(0); + if self.index >= self.view.size(0) { + None + } else { + let view = self + .view + .slice_dyn(self.index..self.index.add(self.chunk_size).min(size)); + self.index += self.chunk_size; + Some(view) + } + } +} + +/// Iterator over mutable slices of a tensor along an axis. See [TensorViewMut::axis_chunks_mut]. +pub struct AxisChunksMut<'a, T, L: MutLayout> { + view: TensorBase, + index: usize, + chunk_size: usize, +} + +impl<'a, T, L: MutLayout> AxisChunksMut<'a, T, L> { + pub fn new( + mut view: TensorBase, + dim: usize, + chunk_size: usize, + ) -> AxisChunksMut<'a, T, L> { + // See notes in `Layout` about internal overlap. + assert!( + !view.layout().is_broadcast(), + "Cannot mutably iterate over broadcasting view" + ); + view.move_axis(dim, 0); + AxisChunksMut { + view, + chunk_size, + index: 0, + } + } +} + +impl<'a, T, L: MutLayout> Iterator for AxisChunksMut<'a, T, L> { + type Item = TensorViewMut<'a, T>; + + fn next(&mut self) -> Option { + let size = self.view.size(0); + + if self.index >= size { + None + } else { + let index = self.index; + self.index += self.chunk_size; + + // Safety: This is non-broadcasting view, and we increment the index + // each time, so returned views will not overlap. + let view = unsafe { + let view = self + .view + .slice_mut_dyn(index..index.add(self.chunk_size).min(size)); + std::mem::transmute::, TensorViewMut<'a, T>>(view) + }; + Some(view) + } + } +} From 578fcf2390fdfe0ecbf87c29ba5f8b647f156bd5 Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Sun, 21 Jan 2024 08:21:41 +0000 Subject: [PATCH 02/12] Switch `rten-tensor` over to unified tensor - Change the library entry point to export the unified tensor types / traits and aliases, under the same names as the legacy tensor types - Add missing trait import in iterator tests - Remove old `AxisIter*`, `InnerIter*` types from rten-tensor iterators --- rten-tensor/src/iterators.rs | 244 +---------------------------------- rten-tensor/src/lib.rs | 29 +++-- 2 files changed, 20 insertions(+), 253 deletions(-) diff --git a/rten-tensor/src/iterators.rs b/rten-tensor/src/iterators.rs index b74bd48b..96288fac 100644 --- a/rten-tensor/src/iterators.rs +++ b/rten-tensor/src/iterators.rs @@ -1,11 +1,9 @@ use std::iter::{repeat, zip, Cycle, FusedIterator, StepBy, Take}; -use std::ops::{Add, Range}; +use std::ops::Range; use std::slice; use super::range::{SliceItem, SliceRange}; -use crate::{ - to_slice_items, DynIndices, Layout, NdTensorView, NdTensorViewMut, TensorView, TensorViewMut, -}; +use crate::Layout; /// Borrowed reference to a tensor's data and layout. This differs from /// [TensorView] in that it borrows the layout rather than having its own. @@ -643,164 +641,6 @@ impl<'a, T> ExactSizeIterator for BroadcastIter<'a, T> {} impl<'a, T> FusedIterator for BroadcastIter<'a, T> {} -/// Iterator over slices of a tensor along an axis. See [TensorView::axis_iter]. -pub struct AxisIter<'a, T> { - view: TensorView<'a, T>, - index: usize, -} - -impl<'a, T> AxisIter<'a, T> { - pub fn new(view: &TensorView<'a, T>, dim: usize) -> AxisIter<'a, T> { - let mut permuted = view.clone(); - permuted.move_axis(dim, 0); - AxisIter { - view: permuted, - index: 0, - } - } -} - -impl<'a, T> Iterator for AxisIter<'a, T> { - type Item = TensorView<'a, T>; - - fn next(&mut self) -> Option { - if self.index >= self.view.size(0) { - None - } else { - let view = self.view.slice([self.index]); - self.index += 1; - Some(view) - } - } -} - -/// Iterator over mutable slices of a tensor along an axis. See [TensorViewMut::axis_iter_mut]. -pub struct AxisIterMut<'a, T> { - view: TensorViewMut<'a, T>, - index: usize, -} - -impl<'a, T> AxisIterMut<'a, T> { - pub fn new(mut view: TensorViewMut<'a, T>, dim: usize) -> AxisIterMut<'a, T> { - // See notes in `Layout` about internal overlap. - assert!( - !view.layout().is_broadcast(), - "Cannot mutably iterate over broadcasting view" - ); - view.move_axis(dim, 0); - AxisIterMut { view, index: 0 } - } -} - -impl<'a, T> Iterator for AxisIterMut<'a, T> { - type Item = TensorViewMut<'a, T>; - - fn next(&mut self) -> Option { - if self.index >= self.view.size(0) { - None - } else { - let index = self.index; - self.index += 1; - - // Safety: This is non-broadcasting view, and we increment the index - // each time, so returned views will not overlap. - let view = unsafe { - let view = self.view.slice_mut([index]); - std::mem::transmute::, TensorViewMut<'a, T>>(view) - }; - Some(view) - } - } -} - -/// Iterator over slices of a tensor along an axis. See [TensorView::axis_chunks]. -pub struct AxisChunks<'a, T> { - view: TensorView<'a, T>, - index: usize, - chunk_size: usize, -} - -impl<'a, T> AxisChunks<'a, T> { - pub fn new(view: &TensorView<'a, T>, dim: usize, chunk_size: usize) -> AxisChunks<'a, T> { - let mut permuted = view.clone(); - permuted.move_axis(dim, 0); - AxisChunks { - view: permuted, - index: 0, - chunk_size, - } - } -} - -impl<'a, T> Iterator for AxisChunks<'a, T> { - type Item = TensorView<'a, T>; - - fn next(&mut self) -> Option { - let size = self.view.size(0); - if self.index >= self.view.size(0) { - None - } else { - let view = self - .view - .slice(self.index..self.index.add(self.chunk_size).min(size)); - self.index += self.chunk_size; - Some(view) - } - } -} - -/// Iterator over mutable slices of a tensor along an axis. See [TensorViewMut::axis_chunks_mut]. -pub struct AxisChunksMut<'a, T> { - view: TensorViewMut<'a, T>, - index: usize, - chunk_size: usize, -} - -impl<'a, T> AxisChunksMut<'a, T> { - pub fn new( - mut view: TensorViewMut<'a, T>, - dim: usize, - chunk_size: usize, - ) -> AxisChunksMut<'a, T> { - // See notes in `Layout` about internal overlap. - assert!( - !view.layout().is_broadcast(), - "Cannot mutably iterate over broadcasting view" - ); - view.move_axis(dim, 0); - AxisChunksMut { - view, - chunk_size, - index: 0, - } - } -} - -impl<'a, T> Iterator for AxisChunksMut<'a, T> { - type Item = TensorViewMut<'a, T>; - - fn next(&mut self) -> Option { - let size = self.view.size(0); - - if self.index >= size { - None - } else { - let index = self.index; - self.index += self.chunk_size; - - // Safety: This is non-broadcasting view, and we increment the index - // each time, so returned views will not overlap. - let view = unsafe { - let view = self - .view - .slice_mut(index..index.add(self.chunk_size).min(size)); - std::mem::transmute::, TensorViewMut<'a, T>>(view) - }; - Some(view) - } - } -} - /// Iterator over the ranges of a tensor's data that correspond to 1D lanes /// along a particular dimension. struct LaneRanges { @@ -964,89 +804,11 @@ impl<'a, T> Iterator for LanesMut<'a, T> { } } -/// Iterator over views of the N innermost dimensions of a tensor. -pub struct InnerIter<'a, T, const N: usize> { - outer_indices: DynIndices, - view: TensorView<'a, T>, -} - -impl<'a, T, const N: usize> InnerIter<'a, T, N> { - pub fn new(view: TensorView<'a, T>) -> Self { - assert!(view.ndim() >= N); - let outer_dims = view.ndim() - N; - InnerIter { - outer_indices: DynIndices::from_shape(&view.shape()[..outer_dims]), - view, - } - } -} - -impl<'a, T, const N: usize> Iterator for InnerIter<'a, T, N> { - type Item = NdTensorView<'a, T, N>; - - fn next(&mut self) -> Option { - self.outer_indices.next().map(|idx| { - let slice_items = to_slice_items(&idx); - self.view.slice(slice_items.as_slice()).try_into().unwrap() - }) - } - - fn size_hint(&self) -> (usize, Option) { - self.outer_indices.size_hint() - } -} - -impl<'a, T, const N: usize> ExactSizeIterator for InnerIter<'a, T, N> {} - -/// Iterator over mutable views of the N innermost dimensions of a tensor. -pub struct InnerIterMut<'a, T, const N: usize> { - outer_indices: DynIndices, - view: TensorViewMut<'a, T>, -} - -impl<'a, T, const N: usize> InnerIterMut<'a, T, N> { - pub fn new(view: TensorViewMut<'a, T>) -> Self { - assert!(view.ndim() >= N); - let outer_dims = view.ndim() - N; - InnerIterMut { - outer_indices: DynIndices::from_shape(&view.shape()[..outer_dims]), - view, - } - } -} - -impl<'a, T, const N: usize> Iterator for InnerIterMut<'a, T, N> { - type Item = NdTensorViewMut<'a, T, N>; - - fn next(&mut self) -> Option { - self.outer_indices.next().map(|idx| { - let slice_items = to_slice_items(&idx); - let view: NdTensorViewMut<'_, T, N> = self - .view - .slice_mut(slice_items.as_slice()) - .try_into() - .unwrap(); - - unsafe { - // Safety: Outer view is non-broadcasting, and we increment the - // outer index each time, so returned views will not overlap. - std::mem::transmute::, NdTensorViewMut<'a, T, N>>(view) - } - }) - } - - fn size_hint(&self) -> (usize, Option) { - self.outer_indices.size_hint() - } -} - -impl<'a, T, const N: usize> ExactSizeIterator for InnerIterMut<'a, T, N> {} - // Tests for iterator internals. Most tests of iterators are currently done via // tests on tensor methods. #[cfg(test)] mod tests { - use crate::{Lanes, LanesMut, Tensor}; + use crate::{AsView, Lanes, LanesMut, Tensor}; #[test] fn test_lanes_empty() { diff --git a/rten-tensor/src/lib.rs b/rten-tensor/src/lib.rs index 56fad6f8..5817e7d8 100644 --- a/rten-tensor/src/lib.rs +++ b/rten-tensor/src/lib.rs @@ -39,12 +39,9 @@ mod index_iterator; mod iterators; mod layout; mod macros; -mod ndtensor; mod overlap; mod range; -mod tensor; -#[cfg(test)] mod unified_tensor; /// Trait for sources of random data for tensors, for use with [Tensor::rand]. @@ -54,21 +51,29 @@ pub trait RandomSource { } pub use index_iterator::{DynIndices, Indices, NdIndices}; -pub use iterators::{ - AxisChunks, AxisChunksMut, AxisIter, AxisIterMut, BroadcastIter, InnerIter, InnerIterMut, Iter, - IterMut, Lanes, LanesMut, Offsets, -}; +pub use iterators::{BroadcastIter, Iter, IterMut, Lanes, LanesMut, Offsets}; pub use layout::{is_valid_permutation, DynLayout, Layout, MatrixLayout, NdLayout}; -pub use ndtensor::{ - Matrix, MatrixMut, NdTensor, NdTensorBase, NdTensorView, NdTensorViewMut, NdView, -}; pub use range::{to_slice_items, DynSliceItems, IntoSliceItems, SliceItem, SliceRange}; -pub use tensor::{Tensor, TensorBase, TensorView, TensorViewMut, View}; + +pub use unified_tensor::{ + AsView, Matrix, MatrixMut, MutLayout, NdTensor, NdTensorView, NdTensorViewMut, Tensor, + TensorBase, TensorView, TensorViewMut, +}; + +// For backwards compatibility. +pub type NdTensorBase = TensorBase>; + +pub use unified_tensor::iterators::{ + AxisChunks, AxisChunksMut, AxisIter, AxisIterMut, InnerIter, InnerIterMut, +}; + +// For backwards compatibility. +pub use unified_tensor::{AsView as View, AsView as NdView}; /// This module provides a convenient way to import the most common traits /// from this library via a glob import. pub mod prelude { - pub use super::{Layout, NdView, View}; + pub use super::{AsView, Layout, NdView, View}; } // These modules are public for use by other crates in this repo, but From fb82289890d4679eb272f4b708b4515c8881dd27 Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Sun, 21 Jan 2024 08:23:00 +0000 Subject: [PATCH 03/12] Adapt rten to changes in unified tensor API Adapt rten to the API changes arising from unifying `TensorBase` and `NdTensorBase`. --- src/ctc.rs | 1 + src/gemm/packing.rs | 4 +-- src/ops/conv.rs | 28 ++++++++-------- src/ops/gather.rs | 12 +++---- src/ops/layout.rs | 14 ++++---- src/ops/matmul.rs | 4 +-- src/ops/norm.rs | 4 +-- src/ops/operators.rs | 6 ++-- src/ops/pad.rs | 2 +- src/ops/pooling.rs | 12 ++++--- src/ops/reduce.rs | 7 ++-- src/ops/resize.rs | 10 +++--- src/ops/rnn.rs | 58 +++++++++++++++------------------ src/ops/slice.rs | 2 +- src/ops/split.rs | 5 +-- src/ops/trilu.rs | 7 ++-- src/ops/variadic_elementwise.rs | 1 + 17 files changed, 91 insertions(+), 86 deletions(-) diff --git a/src/ctc.rs b/src/ctc.rs index 0760f789..c1482b29 100644 --- a/src/ctc.rs +++ b/src/ctc.rs @@ -389,6 +389,7 @@ impl Default for CtcDecoder { #[cfg(test)] mod tests { + use rten_tensor::prelude::*; use rten_tensor::NdTensor; use super::{log_sum_exp, CtcDecoder, CtcHypothesis}; diff --git a/src/gemm/packing.rs b/src/gemm/packing.rs index 2ae46361..8b32448c 100644 --- a/src/gemm/packing.rs +++ b/src/gemm/packing.rs @@ -16,7 +16,7 @@ pub fn pack_a_block(out: &mut [f32], a: Matrix, rows: Range, c let a_cols = cols.len(); // Safety: Loops below must only access valid offsets in `a_data`. - let a_data = unsafe { a.data_unchecked() }; + let a_data = a.non_contiguous_data(); let row_stride = a.row_stride(); let col_stride = a.col_stride(); @@ -86,7 +86,7 @@ pub fn pack_b_block(out: &mut [f32], b: Matrix, rows: Range, c let b_col_stride = b.col_stride(); // Safety: Loops below must only access valid offsets in `b_data`. - let b_data = unsafe { b.data_unchecked() }; + let b_data = b.non_contiguous_data(); let n_panels = round_up(b_cols, K::NR) / K::NR; for panel in 0..n_panels { diff --git a/src/ops/conv.rs b/src/ops/conv.rs index 06078968..c5359e74 100644 --- a/src/ops/conv.rs +++ b/src/ops/conv.rs @@ -202,7 +202,7 @@ fn conv_2d_pointwise( let gemm = GemmExecutor::new(); for n in 0..batch { - let mut out_item = output.slice_mut([n]); + let mut out_item = output.slice_mut::<2, _>([n]); let out_row_stride = out_item.stride(0); let in_mat = input.slice::<3, _>([n]).reshaped([in_c, in_h * in_w]); @@ -277,7 +277,7 @@ fn conv_2d_depthwise( for n in 0..batch { for c in 0..in_c { - let kernel_view = kernel.slice([c, 0]).unchecked(); + let kernel_view = kernel.slice([c, 0]).weakly_checked_view(); // For efficiency, use manual slicing in the inner loops to extract // input/output rows. @@ -483,18 +483,17 @@ pub fn conv( let out_chans = out_chan_start..out_chan_start + out_channels_per_group; let kernel_mat = kernel - .slice([out_chans.clone()]) - .reshaped(&[out_channels_per_group, in_channels_per_group * k_h * k_w]) - .nd_view(); + .slice::<4, _>([out_chans.clone()]) + .reshaped([out_channels_per_group, in_channels_per_group * k_h * k_w]); let prepacked_kernel = gemm.prepack_a(kernel_mat); - let in_group = input.slice((.., in_chan_start..in_chan_end)); - let mut out_group = output.slice_mut((.., out_chans.clone())); + let in_group = input.slice_dyn((.., in_chan_start..in_chan_end)); + let mut out_group = output.slice_mut_dyn((.., out_chans.clone())); zip(out_group.axis_iter_mut(0), in_group.axis_iter(0)) .par_bridge() .for_each(|(mut out_item, in_item)| { - let mut out_mat = out_item.reshaped_mut(&[out_channels_per_group, out_h * out_w]); + let mut out_mat = out_item.reshaped_mut([out_channels_per_group, out_h * out_w]); let out_row_stride = out_mat.stride(0); let im2col = VirtualIm2Col::new( @@ -578,7 +577,7 @@ fn col2im( let columns_shape = columns.shape(); let mut col_data_iter = columns.data().unwrap().iter(); - let mut out_view = output.unchecked_mut(); + let mut out_view = output.weakly_checked_view_mut(); // Loop order must match dim order of `columns`. for y in 0..columns_shape[0] { @@ -633,18 +632,21 @@ pub fn conv_transpose( let kernel = kernel.to_contiguous(); let mut col2im_mat = Tensor::zeros(&[in_h * in_w, out_c * k_h * k_w]); - let kernel_mat = kernel.reshaped(&[k_in_c, out_c * k_h * k_w]); + let kernel_mat = kernel.reshaped([k_in_c, out_c * k_h * k_w]); // The implementation here is the inverse of the im2col-based convolution. for n in 0..batch { - let input_mat = input.slice([n]).reshaped(&[in_c, in_h * in_w]).transposed(); + let input_mat = input + .slice::<3, _>([n]) + .reshaped([in_c, in_h * in_w]) + .transposed(); let col2im_row_stride = col2im_mat.stride(0); gemm( col2im_mat.data_mut().unwrap(), col2im_row_stride, - input_mat.nd_view(), - kernel_mat.nd_view(), + input_mat, + kernel_mat, 1., /* alpha */ 0., /* beta */ ); diff --git a/src/ops/gather.rs b/src/ops/gather.rs index 9ada9141..4b730c23 100644 --- a/src/ops/gather.rs +++ b/src/ops/gather.rs @@ -39,7 +39,7 @@ pub fn gather( if let (0, Some(index)) = (indices.ndim(), indices.item()) { let mut slice_range = full_range(input.ndim()); slice_range[axis] = SliceItem::Index(*index as isize); - let output = input.slice(slice_range.as_slice()).to_tensor(); + let output = input.slice_dyn(slice_range.as_slice()).to_tensor(); return Ok(output); } @@ -60,8 +60,8 @@ pub fn gather( out_range[axis + i] = SliceItem::Index(index_val as isize); } - let in_slice = input.slice(in_range.as_slice()); - let mut out_slice = output.slice_mut(out_range.as_slice()); + let in_slice = input.slice_dyn(in_range.as_slice()); + let mut out_slice = output.slice_mut_dyn(out_range.as_slice()); out_slice.copy_from(&in_slice); } @@ -155,7 +155,7 @@ fn gather_elements_4d( /// Expand a tensor to 4 dims by inserting `n` axes at the front. fn unsqueeze_n(mut view: TensorView, n: usize) -> TensorView { for _ in 0..n { - view.insert_dim(0); + view.insert_axis(0); } view } @@ -180,7 +180,7 @@ pub fn gather_elements( let pad = FAST_PATH_NDIM - input.ndim(); let mut output = output.view_mut(); for _ in 0..pad { - output.insert_dim(0); + output.insert_axis(0); } gather_elements_4d( output.view_mut(), @@ -397,7 +397,7 @@ pub fn scatter_nd< let mut output = data.to_tensor(); for index in DynIndices::from_shape(update_indices) { let update_idx = to_slice_items(&index); - let update_slice = updates.slice(update_idx.as_slice()); + let update_slice = updates.slice_dyn(update_idx.as_slice()); let output_idx: DynSliceItems = indices .try_slice(update_idx.as_slice()) diff --git a/src/ops/layout.rs b/src/ops/layout.rs index 0e2b7c9c..d1c08b3b 100644 --- a/src/ops/layout.rs +++ b/src/ops/layout.rs @@ -607,7 +607,7 @@ mod tests { // Reshape with an unspecified (-1) dim and nonzero-length input let input = Tensor::from_data(&[2, 2], vec![-0.5, 0.5, 3.0, -5.5]); let shape = ndtensor!([1, -1, 2]); - let expected = input.to_shape(&[1, 2, 2]); + let expected = input.to_shape([1, 2, 2].as_slice()); let result = reshape(input.view(), &shape.view(), false /* allow_zero */).unwrap(); expect_equal(&result, &expected)?; @@ -620,7 +620,7 @@ mod tests { false, /* allow_zero */ ) .unwrap(); - let expected = zero_sized_input.to_shape(&[100, 0]); + let expected = zero_sized_input.to_shape([100, 0].as_slice()); expect_equal(&result, &expected)?; Ok(()) @@ -632,14 +632,14 @@ mod tests { // size should be copied. let input = Tensor::from_data(&[1, 1, 4], vec![-0.5, 0.5, 3.0, -5.5]); let shape = ndtensor!([-1, 0]); - let expected = input.to_shape(&[4, 1]); + let expected = input.to_shape([4, 1].as_slice()); let result = reshape(input.view(), &shape.view(), false /* allow_zero */).unwrap(); expect_equal(&result, &expected)?; // Case where copied input dim is also zero. let input = Tensor::::from_data(&[0], vec![]); let shape = ndtensor!([0]); - let expected = input.to_shape(&[0]); + let expected = input.to_shape([0].as_slice()); let result = reshape(input.view(), &shape.view(), false /* allow_zero */).unwrap(); expect_equal(&result, &expected)?; @@ -658,7 +658,7 @@ mod tests { let input = Tensor::::from_data(&[0, 0, 10], vec![]); let shape = ndtensor!([10, 0, 0]); let result = reshape(input.view(), &shape.view(), true /* allow_zero */).unwrap(); - let expected = input.to_shape(&[10, 0, 0]); + let expected = input.to_shape([10, 0, 0].as_slice()); expect_equal(&result, &expected)?; Ok(()) @@ -698,7 +698,7 @@ mod tests { fn test_reshape_in_place() { let mut input = Tensor::from_data(&[2, 2], vec![-0.5, 0.5, 3.0, -5.5]); let shape = ndtensor!([4]); - let expected = input.to_shape(&[4]); + let expected = input.to_shape([4].as_slice()); reshape_in_place(&mut input, &shape.view(), false /* allow_zero */).unwrap(); assert_eq!(&input, &expected); } @@ -707,7 +707,7 @@ mod tests { fn test_reshape_op() -> Result<(), Box> { let input = Tensor::from_data(&[2, 2], vec![-0.5, 0.5, 3.0, -5.5]); let shape = Tensor::from_data(&[1], vec![4]); - let expected = input.to_shape(&[4]); + let expected = input.to_shape([4].as_slice()); let op = Reshape { allow_zero: false }; let result = op diff --git a/src/ops/matmul.rs b/src/ops/matmul.rs index a4dae0b8..74b19a60 100644 --- a/src/ops/matmul.rs +++ b/src/ops/matmul.rs @@ -119,8 +119,8 @@ pub fn matmul(a: TensorView, b: TensorView) -> Result { let a_broadcast_shape = [out_prefix.as_slice(), &[a_rows, a_cols]].concat(); let b_broadcast_shape = [out_prefix.as_slice(), &[b_rows, b_cols]].concat(); - let a_broadcast = a.broadcast(&a_broadcast_shape); - let b_broadcast = b.broadcast(&b_broadcast_shape); + let a_broadcast = a.broadcast(a_broadcast_shape.as_slice()); + let b_broadcast = b.broadcast(b_broadcast_shape.as_slice()); let out_row_stride = output.stride(output.ndim() - 2); let out_batches = output diff --git a/src/ops/norm.rs b/src/ops/norm.rs index 65a7e6a0..84e9eb75 100644 --- a/src/ops/norm.rs +++ b/src/ops/norm.rs @@ -30,7 +30,7 @@ pub fn batch_norm_in_place( let chan_bias = bias[[c]]; let mut out_view = input.slice_mut([n, c]); - let mut out_view = out_view.unchecked_mut(); + let mut out_view = out_view.weakly_checked_view_mut(); // The batch norm formula, from the ONNX spec, is: // @@ -160,7 +160,7 @@ pub fn instance_normalization_in_place( for n in 0..batch { for c in 0..chans { - let mut slice = input.slice_mut([n, c]); + let mut slice = input.slice_mut_dyn([n, c]); let chan_scale = scale[[c]]; let chan_bias = bias[[c]]; let chan_mean = slice_sum(slice.data().unwrap()) / slice.len() as f32; diff --git a/src/ops/operators.rs b/src/ops/operators.rs index 534eb802..ff80e973 100644 --- a/src/ops/operators.rs +++ b/src/ops/operators.rs @@ -1,7 +1,7 @@ use std::fmt::Debug; use rten_tensor::prelude::*; -use rten_tensor::{NdTensorBase, NdTensorView, Tensor, TensorBase, TensorView}; +use rten_tensor::{DynLayout, NdTensorBase, NdTensorView, Tensor, TensorBase, TensorView}; use crate::number::{Identities, IsInt}; use crate::ops::OpError; @@ -69,7 +69,7 @@ pub trait FloatOperators { fn softmax(&self, axis: isize) -> Result; } -impl> Operators for TensorBase { +impl> Operators for TensorBase { type Elem = T; fn arg_max(&self, axis: isize, keep_dims: bool) -> Result, OpError> @@ -171,7 +171,7 @@ impl, const N: usize> Operators for NdTensorBase { } } -impl> FloatOperators for TensorBase { +impl> FloatOperators for TensorBase { fn matmul(&self, other: TensorView) -> Result { matmul(self.view(), other) } diff --git a/src/ops/pad.rs b/src/ops/pad.rs index b6a10083..a1b1f38c 100644 --- a/src/ops/pad.rs +++ b/src/ops/pad.rs @@ -44,7 +44,7 @@ pub fn pad( let mut output = Tensor::from_data(&out_shape, vec![const_val; out_len]); for (out, in_) in zip( - output.slice_mut(non_pad_region.as_slice()).iter_mut(), + output.slice_mut_dyn(non_pad_region.as_slice()).iter_mut(), input.iter(), ) { *out = *in_; diff --git a/src/ops/pooling.rs b/src/ops/pooling.rs index 249770af..4612697c 100644 --- a/src/ops/pooling.rs +++ b/src/ops/pooling.rs @@ -110,8 +110,8 @@ pub fn average_pool( for n in 0..batch { for chan in 0..in_c { let mut out_view = output.slice_mut([n, chan]); - let mut out_view = out_view.unchecked_mut(); - let in_view = input.slice([n, chan]).unchecked(); + let mut out_view = out_view.weakly_checked_view_mut(); + let in_view = input.slice([n, chan]).weakly_checked_view(); for out_y in 0..out_h { for out_x in 0..out_w { @@ -177,8 +177,10 @@ pub fn global_average_pool(input: TensorView) -> Result { const N: usize = 4; for (chan_group, mut out_group) in zip( - input.slice(n).axis_chunks(0, N), - output.slice_mut((n, .., 0, 0)).axis_chunks_mut(0, N), + input.slice::<3, _>(n).axis_chunks(0, N), + output + .slice_mut::<1, _>((n, .., 0, 0)) + .axis_chunks_mut(0, N), ) { if chan_group.size(0) == N { // Compute average over batch of N channels in parallel. @@ -200,7 +202,7 @@ pub fn global_average_pool(input: TensorView) -> Result { } else { // Compute average over remaining channels. for i in 0..chan_group.size(0) { - let sum: f32 = chan_group.slice([i]).iter().sum(); + let sum: f32 = chan_group.slice::<2, _>([i]).iter().sum(); out_group[[i]] = sum / (in_h * in_w) as f32; } } diff --git a/src/ops/reduce.rs b/src/ops/reduce.rs index df34d9ec..87136931 100644 --- a/src/ops/reduce.rs +++ b/src/ops/reduce.rs @@ -3,7 +3,7 @@ use std::iter::zip; use rten_tensor; use rten_tensor::prelude::*; -use rten_tensor::{DynIndices, NdTensor, SliceItem, Tensor, TensorView}; +use rten_tensor::{DynIndices, NdTensor, NdTensorView, SliceItem, Tensor, TensorView}; use crate::number::Identities; use crate::ops::layout::squeeze_in_place; @@ -46,7 +46,8 @@ fn select_max_index std::cmp::Ordering>( if !keep_dims { let axes = &[resolved_axis as i32]; - squeeze_in_place(&mut reduced, Some(axes.into())).expect("Invalid axis"); + let axes = NdTensorView::from_data([1], axes); + squeeze_in_place(&mut reduced, Some(axes)).expect("Invalid axis"); } Ok(reduced) @@ -849,7 +850,7 @@ mod tests { expect_equal(&result, &expected)?; let result = reduce_l2(input.view(), Some(&[2]), true /* keep_dims */).unwrap(); - let expected = expected.to_shape(&[3, 2, 1]); + let expected = expected.to_shape([3, 2, 1].as_slice()); expect_equal(&result, &expected)?; Ok(()) diff --git a/src/ops/resize.rs b/src/ops/resize.rs index ba25d338..c29cc565 100644 --- a/src/ops/resize.rs +++ b/src/ops/resize.rs @@ -181,7 +181,7 @@ pub fn resize_image(input: TensorView, size: [usize; 2]) -> Result([n]); + let mut out_image = output.slice_mut::<3, _>([n]); out_image .axis_chunks_mut(0, CHAN_GROUP_SIZE) @@ -424,7 +424,7 @@ mod tests { for case in cases { let result = resize( case.image.view(), - ResizeTarget::Scales((&case.scales).into()), + ResizeTarget::Scales(case.scales.as_slice().into()), ResizeMode::Nearest, CoordTransformMode::HalfPixel, NearestMode::RoundPreferFloor, @@ -582,7 +582,7 @@ mod tests { for case in cases { let result = resize( case.image.view(), - ResizeTarget::Scales((&case.scales).into()), + ResizeTarget::Scales(case.scales.as_slice().into()), ResizeMode::Linear, CoordTransformMode::HalfPixel, NearestMode::Floor, diff --git a/src/ops/rnn.rs b/src/ops/rnn.rs index 75670638..fa783077 100644 --- a/src/ops/rnn.rs +++ b/src/ops/rnn.rs @@ -166,12 +166,10 @@ fn extract_matrix(tensor: TensorView, dir: usize, num_gates: usize, gate_index: let hidden_total = tensor.size(1); assert!(hidden_total % num_gates == 0); let hidden_size = hidden_total / num_gates; - tensor - .slice(( - dir, - (gate_index * hidden_size..(gate_index + 1) * hidden_size), - )) - .nd_view() + tensor.slice::<2, _>(( + dir, + (gate_index * hidden_size..(gate_index + 1) * hidden_size), + )) } /// Extract weights and biases for a specific RNN gate/output from a tensor that @@ -203,10 +201,8 @@ fn extract_weights_and_bias<'a>( let rec_weight = extract_matrix(recurrent_weights, dir, num_gates, gate_index).transposed(); let bias = bias.map(|bias| { let nth_gate = |gate_index| (gate_index * hidden_size)..((gate_index + 1) * hidden_size); - let input_bias = bias.slice((dir, nth_gate(gate_index))).nd_view(); - let hidden_bias = bias - .slice((dir, nth_gate(gate_index + num_gates))) - .nd_view(); + let input_bias = bias.slice::<1, _>((dir, nth_gate(gate_index))); + let hidden_bias = bias.slice::<1, _>((dir, nth_gate(gate_index + num_gates))); (input_bias, hidden_bias) }); (weight, rec_weight, bias) @@ -311,8 +307,8 @@ pub fn gru( extract_gru_weights_and_bias(dir, HIDDEN_GATE); for seq in sequence_for_dir(direction, dir, seq_len) { - let in_item = input.slice([seq]); - let hidden_item = hidden.slice([dir]); + let in_item = input.slice_dyn([seq]); + let hidden_item = hidden.slice_dyn([dir]); // From the ONNX spec, the intermediate values are computed as: // @@ -406,7 +402,7 @@ pub fn gru( } // Compute next hidden state - let mut hidden_item = hidden.slice_mut([dir]); + let mut hidden_item = hidden.slice_mut_dyn([dir]); for (hidden, update, hidden_gate) in zip3( hidden_item.iter_mut(), update_gate.iter(), @@ -416,8 +412,8 @@ pub fn gru( } hidden_seq - .slice_mut([seq, dir]) - .copy_from(&hidden_item.view()); + .slice_mut_dyn([seq, dir]) + .copy_from(&hidden_item.as_dyn()); } } @@ -577,8 +573,8 @@ pub fn lstm( // supported. // - `f`, `g` and `h` are activations. `f`=sigmoid, `g` and `h` // are tanh. - let in_item = input.slice([seq]); - let hidden_item = hidden.slice([dir]); + let in_item = input.slice_dyn([seq]); + let hidden_item = hidden.slice_dyn([dir]); // Compute outputs for input, forget, cell and output gates. compute_rnn_gate( @@ -626,7 +622,7 @@ pub fn lstm( ); // Compute new values of cell and hidden state - let mut cell_item = cell.slice_mut([dir]); + let mut cell_item = cell.slice_mut_dyn([dir]); for (cell, forget_gate, input_gate, cell_gate) in zip4( cell_item.iter_mut(), @@ -637,7 +633,7 @@ pub fn lstm( *cell = forget_gate * *cell + input_gate * cell_gate; } - let mut hidden_item = hidden.slice_mut([dir]); + let mut hidden_item = hidden.slice_mut_dyn([dir]); for (hidden, out_gate, cell) in zip3(hidden_item.iter_mut(), out_gate.iter(), cell_item.iter()) { @@ -645,7 +641,7 @@ pub fn lstm( } hidden_seq - .slice_mut([seq, dir]) + .slice_mut_dyn([seq, dir]) .copy_from(&hidden_item.view()); } } @@ -961,7 +957,7 @@ mod tests { let is_bidirectional = params.get("weight_ih_l0_reverse").is_some(); let mut input = read_tensor(&case["input"]).expect("failed to read input"); - input.insert_dim(1); // Add batch dim + input.insert_axis(1); // Add batch dim let mut expected = read_tensor(&case["output"]).expect("failed to read output"); @@ -970,9 +966,9 @@ mod tests { let es = expected.shape(); expected.reshape(&[es[0], 2, es[1] / 2]); } else { - expected.insert_dim(1); + expected.insert_axis(1); } - expected.insert_dim(2); // Add batch dim + expected.insert_axis(2); // Add batch dim let read_param = |name| match op { Op::Lstm => reorder_ifco_to_iofc( @@ -986,45 +982,45 @@ mod tests { }; let mut weights = read_param("weight_ih_l0"); - weights.insert_dim(0); // Add directions dim + weights.insert_axis(0); // Add directions dim let mut hidden_weights = read_param("weight_hh_l0"); - hidden_weights.insert_dim(0); // Add directions dim + hidden_weights.insert_axis(0); // Add directions dim let input_bias = read_param("bias_ih_l0"); let hidden_bias = read_param("bias_hh_l0"); let mut bias = concat(&[input_bias.view(), hidden_bias.view()], 0).unwrap(); - bias.insert_dim(0); // Add directions dim + bias.insert_axis(0); // Add directions dim // If this is a bidirectional RNN, there will be `_reverse`-suffixed // versions of the bias and weight params. Extract these and concatenate // with the forwards direction values. if is_bidirectional { let mut rev_weights = read_param("weight_ih_l0_reverse"); - rev_weights.insert_dim(0); // Add directions dim + rev_weights.insert_axis(0); // Add directions dim weights = concat(&[weights.view(), rev_weights.view()], 0).unwrap(); let mut rev_hidden_weights = read_param("weight_hh_l0_reverse"); - rev_hidden_weights.insert_dim(0); // Add directions dim + rev_hidden_weights.insert_axis(0); // Add directions dim hidden_weights = concat(&[hidden_weights.view(), rev_hidden_weights.view()], 0).unwrap(); let rev_input_bias = read_param("bias_ih_l0_reverse"); let rev_hidden_bias = read_param("bias_hh_l0_reverse"); let mut rev_bias = concat(&[rev_input_bias.view(), rev_hidden_bias.view()], 0).unwrap(); - rev_bias.insert_dim(0); // Add directions dim + rev_bias.insert_axis(0); // Add directions dim bias = concat(&[bias.view(), rev_bias.view()], 0).unwrap(); } let initial_hidden = case.get("initial_hidden").map(|param| { let mut init = read_tensor(param).expect("failed to read initial hidden state"); - init.insert_dim(1); // Add batch dim + init.insert_axis(1); // Add batch dim init }); let initial_cell = case.get("initial_cell").map(|param| { let mut init = read_tensor(param).expect("failed to read initial cell state"); - init.insert_dim(1); // Add batch dim + init.insert_axis(1); // Add batch dim init }); diff --git a/src/ops/slice.rs b/src/ops/slice.rs index 8f47fba0..fcaceb36 100644 --- a/src/ops/slice.rs +++ b/src/ops/slice.rs @@ -495,7 +495,7 @@ mod tests { .unwrap(); assert_eq!( sliced, - Tensor::from_data(case.expected_shape, case.expected_elements) + Tensor::from_data(case.expected_shape, case.expected_elements.to_vec()) ); } } diff --git a/src/ops/split.rs b/src/ops/split.rs index 337bacd2..e7d92629 100644 --- a/src/ops/split.rs +++ b/src/ops/split.rs @@ -38,7 +38,7 @@ pub fn split( split_start += split_size; - input.view().slice(slice_range.as_slice()).to_tensor() + input.slice_dyn(slice_range.as_slice()).to_tensor() }) .collect(); @@ -67,7 +67,8 @@ impl Operator for Split { #[cfg(test)] mod tests { - use rten_tensor::{tensor, View}; + use rten_tensor::prelude::*; + use rten_tensor::tensor; use crate::ops::{split, OpError}; diff --git a/src/ops/trilu.rs b/src/ops/trilu.rs index 086726fc..15199253 100644 --- a/src/ops/trilu.rs +++ b/src/ops/trilu.rs @@ -55,6 +55,7 @@ impl Operator for Trilu { #[cfg(test)] mod tests { use crate::ops::{trilu, OpError}; + use rten_tensor::prelude::*; use rten_tensor::{tensor, Tensor}; #[test] @@ -66,7 +67,7 @@ mod tests { k: i32, } - let in_3x3 = Tensor::arange(1, 10, None).into_shape(&[3, 3]); + let in_3x3 = Tensor::arange(1, 10, None).into_shape([3, 3].as_slice()); let cases = [ // k = 0, upper = true @@ -166,7 +167,7 @@ mod tests { }, // Non-square (wide) matrix Case { - input: Tensor::arange(1, 16, None).into_shape(&[3, 5]), + input: Tensor::arange(1, 16, None).into_shape([3, 5].as_slice()), expected: Tensor::from([ [1, 2, 3, 4, 5], // [0, 7, 8, 9, 10], // @@ -177,7 +178,7 @@ mod tests { }, // Non-square (tall) matrix Case { - input: Tensor::arange(1, 16, None).into_shape(&[5, 3]), + input: Tensor::arange(1, 16, None).into_shape([5, 3].as_slice()), expected: Tensor::from([ [1, 2, 3], // [0, 5, 6], // diff --git a/src/ops/variadic_elementwise.rs b/src/ops/variadic_elementwise.rs index abe2fedb..9b8a13ab 100644 --- a/src/ops/variadic_elementwise.rs +++ b/src/ops/variadic_elementwise.rs @@ -175,6 +175,7 @@ impl Operator for Sum { #[cfg(test)] mod tests { + use rten_tensor::prelude::*; use rten_tensor::test_util::eq_with_nans; use rten_tensor::{tensor, Tensor}; From f6e4922de72f9cd2c9dde1d4407749f53ad7e128 Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Sun, 21 Jan 2024 08:45:44 +0000 Subject: [PATCH 04/12] Import tensor traits in rten-imageproc tests --- rten-imageproc/src/contours.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/rten-imageproc/src/contours.rs b/rten-imageproc/src/contours.rs index 536ab509..e4ed9053 100644 --- a/rten-imageproc/src/contours.rs +++ b/rten-imageproc/src/contours.rs @@ -217,6 +217,7 @@ pub fn find_contours(mask: NdTensorView, mode: RetrievalMode) -> Polygon mod tests { use std::iter::zip; + use rten_tensor::prelude::*; use rten_tensor::NdTensor; use crate::tests::border_points; From 665a6c19f8bcae5a1c0b4138f69749354895d603 Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Sun, 21 Jan 2024 10:09:44 +0000 Subject: [PATCH 05/12] Adapt rten-examples to unified tensor API changes --- rten-examples/src/bert_qa.rs | 4 ++-- rten-examples/src/deeplab.rs | 4 ++-- rten-examples/src/detr.rs | 2 +- rten-examples/src/jina_similarity.rs | 10 +++++----- rten-examples/src/wav2vec2.rs | 2 +- rten-examples/src/yolo.rs | 2 +- 6 files changed, 12 insertions(+), 12 deletions(-) diff --git a/rten-examples/src/bert_qa.rs b/rten-examples/src/bert_qa.rs index 09d85b43..de523f35 100644 --- a/rten-examples/src/bert_qa.rs +++ b/rten-examples/src/bert_qa.rs @@ -91,7 +91,7 @@ fn extract_nbest_answers<'a>( .iter() .map(|tid| *tid as i32) .collect::>() - .into_shape(&[1, query_context.token_ids().len()]); + .into_shape([1, query_context.token_ids().len()].as_slice()); let attention_mask = Tensor::full(&[batch, input_ids.len()], 1i32); let input_ids_id = model.node_id("input_ids")?; @@ -112,7 +112,7 @@ fn extract_nbest_answers<'a>( .token_type_ids() .map(|tid| tid as i32) .collect::>() - .into_shape(&[1, query_context.token_ids().len()]); + .into_shape([1, query_context.token_ids().len()].as_slice()); inputs.push((type_ids_id, type_ids.view().into())); } diff --git a/rten-examples/src/deeplab.rs b/rten-examples/src/deeplab.rs index e1a38d41..744ef5b1 100644 --- a/rten-examples/src/deeplab.rs +++ b/rten-examples/src/deeplab.rs @@ -112,7 +112,7 @@ fn main() -> Result<(), Box> { let mut image: Tensor = read_image(&args.image)?.into(); normalize_image(image.nd_view_mut()); - image.insert_dim(0); // Add batch dim + image.insert_axis(0); // Add batch dim // Resize image according to metadata in the model. let input_shape = model @@ -132,7 +132,7 @@ fn main() -> Result<(), Box> { output.permute(&[0, 2, 3, 1]); // (N,class,H,W) => (N,H,W,class) let seg_classes: NdTensor = output - .slice(0) + .slice_dyn(0) .arg_max(-1, false /* keep_dims */)? .try_into()?; let [out_height, out_width] = seg_classes.shape(); diff --git a/rten-examples/src/detr.rs b/rten-examples/src/detr.rs index 460d4932..81548077 100644 --- a/rten-examples/src/detr.rs +++ b/rten-examples/src/detr.rs @@ -292,7 +292,7 @@ fn main() -> Result<(), Box> { let [_, image_height, image_width] = image.shape(); let mut image = image.as_dyn().to_tensor(); - image.insert_dim(0); // Add batch dim + image.insert_axis(0); // Add batch dim // Resize input image according to min/max side length constraints. // diff --git a/rten-examples/src/jina_similarity.rs b/rten-examples/src/jina_similarity.rs index db6391d5..9c9ccd50 100644 --- a/rten-examples/src/jina_similarity.rs +++ b/rten-examples/src/jina_similarity.rs @@ -103,7 +103,7 @@ fn embed_sentence_batch( let token_ids = encoded.token_ids(); for (tid, input_id) in token_ids .iter() - .zip(input_ids.slice_mut((i, ..token_ids.len())).iter_mut()) + .zip(input_ids.slice_mut_dyn((i, ..token_ids.len())).iter_mut()) { *input_id = *tid as i32; } @@ -114,7 +114,7 @@ fn embed_sentence_batch( let mut attention_mask = Tensor::zeros(&[batch, max_sequence_len]); for (i, encoded) in encoded.iter().enumerate() { attention_mask - .slice_mut((i, ..encoded.token_ids().len())) + .slice_mut::<1, _>((i, ..encoded.token_ids().len())) .fill(1i32); } @@ -147,7 +147,7 @@ fn embed_sentence_batch( // Take the mean of the non-padding elements along the sequence // dimension. let seq_len = input.token_ids().len(); - item.slice(..seq_len) + item.slice_dyn(..seq_len) .reduce_mean(Some(&[0]), false /* keep_dims */) .unwrap() }) @@ -157,7 +157,7 @@ fn embed_sentence_batch( .map(|mp| { // Re-add batch dim. let mut view = mp.view(); - view.insert_dim(0); + view.insert_axis(0); view }) .collect(); @@ -241,7 +241,7 @@ fn main() -> Result<(), Box> { // all be "high" values (close to 1.0). They should be used only for // comparison with other scores. let mut scores: Vec<(usize, f32)> = similarities - .slice(0) + .slice_dyn(0) .iter() .copied() .enumerate() diff --git a/rten-examples/src/wav2vec2.rs b/rten-examples/src/wav2vec2.rs index 9f666cc1..f7f4313d 100644 --- a/rten-examples/src/wav2vec2.rs +++ b/rten-examples/src/wav2vec2.rs @@ -117,7 +117,7 @@ fn main() -> Result<(), Box> { let samples = read_wav_file(&args.wav_file)?; let mut sample_batch = Tensor::from_vec(samples); - sample_batch.insert_dim(0); + sample_batch.insert_axis(0); let result: NdTensor = model .run_one(sample_batch.view().into(), None)? diff --git a/rten-examples/src/yolo.rs b/rten-examples/src/yolo.rs index efcfa3f6..27a768de 100644 --- a/rten-examples/src/yolo.rs +++ b/rten-examples/src/yolo.rs @@ -106,7 +106,7 @@ fn main() -> Result<(), Box> { let [_, image_height, image_width] = image.shape(); let mut image = image.as_dyn().to_tensor(); - image.insert_dim(0); // Add batch dim + image.insert_axis(0); // Add batch dim let input_shape = model .input_shape(0) From 856220ee20454544513185d06a0f9781eb73a6c3 Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Sun, 21 Jan 2024 10:12:18 +0000 Subject: [PATCH 06/12] Remove the old rten-tensor implementations --- rten-tensor/src/ndtensor.rs | 1563 ----------------------- rten-tensor/src/tensor.rs | 2397 ----------------------------------- 2 files changed, 3960 deletions(-) delete mode 100644 rten-tensor/src/ndtensor.rs delete mode 100644 rten-tensor/src/tensor.rs diff --git a/rten-tensor/src/ndtensor.rs b/rten-tensor/src/ndtensor.rs deleted file mode 100644 index a67b7ab2..00000000 --- a/rten-tensor/src/ndtensor.rs +++ /dev/null @@ -1,1563 +0,0 @@ -use std::borrow::Cow; -use std::iter::zip; -use std::marker::PhantomData; -use std::ops::{Index, IndexMut}; - -use crate::errors::{DimensionError, FromDataError}; -use crate::index_iterator::NdIndices; -use crate::iterators::{Iter, IterMut, MutViewRef, ViewRef}; -use crate::layout::{Layout, MatrixLayout, NdLayout, OverlapPolicy}; -use crate::tensor::{TensorBase, TensorView, TensorViewMut, View}; -use crate::{IntoSliceItems, RandomSource}; - -/// Multi-dimensional array view with a static dimension count. This trait -/// includes operations that are available on tensors that own their data -/// ([NdTensor]), as well as views ([NdTensorView], [NdTensorViewMut]). -/// -/// `N` is the static rank of this tensor. -/// -/// [NdTensorView] implements specialized versions of these methods as inherent -/// methods, which preserve lifetiems on the result. -pub trait NdView: Layout { - /// The data type of elements in this tensor. - type Elem; - - /// Return a view of this tensor with a dynamic dimension count. - fn as_dyn(&self) -> TensorView { - self.view().as_dyn() - } - - /// Return the underlying data of the tensor as a slice, if it is contiguous. - fn data(&self) -> Option<&[Self::Elem]>; - - /// Return the element at a given index, or `None` if the index is out of - /// bounds in any dimension. - fn get(&self, index: [usize; N]) -> Option<&Self::Elem> { - self.view().get(index) - } - - /// Return an iterator over elements of this tensor, in their logical order. - fn iter(&self) -> Iter { - self.view().iter() - } - - /// Create a view of this tensor which is broadcasted to `shape`. - /// - /// See notes in [View::broadcast]. - /// - /// Panics if the shape is not broadcast-compatible with the current shape. - fn broadcast(&self, shape: [usize; M]) -> NdTensorView { - self.view().broadcast(shape) - } - - /// Return a copy of this tensor with each element replaced by `f(element)`. - /// - /// The order in which elements are visited is unspecified and may not - /// correspond to the logical order. - fn map(&self, f: F) -> NdTensor - where - F: Fn(&Self::Elem) -> U, - { - self.view().map(f) - } - - /// Return a new view with a given shape. - /// - /// The current view must be contiguous and the new shape must have the - /// same product as the current shape. - fn reshaped(&self, shape: [usize; M]) -> NdTensorView { - self.view().reshaped(shape) - } - - /// Return a new view with the dimensions re-ordered according to `dims`. - fn permuted(&self, dims: [usize; N]) -> NdTensorView { - self.view().permuted(dims) - } - - /// Return a new view with the order of dimensions reversed. - fn transposed(&self) -> NdTensorView { - self.view().transposed() - } - - /// Return an immutable view of part of this tensor. - /// - /// `M` specifies the number of dimensions that the layout must have after - /// slicing with `range`. Panics if the sliced layout has a different number - /// of dims. - /// - /// If the range has fewer dimensions than the tensor, they refer to the - /// leading dimensions. - /// - /// See [IntoSliceItems] for a description of how slices can be specified. - /// Slice ranges are currently restricted to use positive steps. In other - /// words, NumPy-style slicing with negative steps is not supported. - fn slice(&self, range: R) -> NdTensorView { - self.view().slice(range) - } - - /// Return a tensor with data laid out in contiguous order. This will - /// be a view if the data is already contiguous, or a copy otherwise. - fn to_contiguous(&self) -> NdTensorBase, N> - where - Self::Elem: Clone, - { - self.view().to_contiguous() - } - - /// Return a new contiguous tensor with the same shape and elements as this - /// view. - fn to_tensor(&self) -> NdTensor - where - Self::Elem: Clone, - { - self.view().to_tensor() - } - - /// Return an immutable view of this tensor. - fn view(&self) -> NdTensorView; -} - -/// N-dimensional array, where `N` is specified as generic argument. -/// -/// `T` is the element type, `S` is the element storage and `N` is the number -/// of dimensions. -/// -/// Most code will not use `NdTensorBase` directly but instead use the type -/// aliases [NdTensor], [NdTensorView] and [NdTensorViewMut]. [NdTensor] owns -/// its elements, and the other two types are views of slices. -/// -/// All [NdTensorBase] variants implement the [Layout] trait which provide -/// operations related to the shape and strides of the tensor, and the -/// [NdView] trait which provides common methods applicable to all variants. -#[derive(Clone, Copy, Debug)] -pub struct NdTensorBase, const N: usize> { - data: S, - layout: NdLayout, - - /// Avoids compiler complaining `T` is unused. - element_type: PhantomData, -} - -/// Return the offsets of `M` successive elements along the `dim` axis, starting -/// at index `base`. -/// -/// Panics if any of the M element indices are out of bounds. -fn array_offsets( - layout: &NdLayout, - base: [usize; N], - dim: usize, -) -> [usize; M] { - assert!( - base[dim] < usize::MAX - M && layout.size(dim) >= base[dim] + M, - "array indices invalid" - ); - - let offset = layout.offset(base); - let stride = layout.stride(dim); - let mut offsets = [0; M]; - for i in 0..M { - offsets[i] = offset + i * stride; - } - offsets -} - -impl, const N: usize> NdTensorBase { - pub fn from_data(shape: [usize; N], data: S) -> NdTensorBase { - Self::from_data_with_strides(shape, data, NdLayout::contiguous_strides(shape)) - .expect("data length too short for shape") - } - - /// Constructs a tensor from the associated storage type and optional - /// strides. - /// - /// If creating an immutable view with strides, prefer - /// [NdTensorBase::from_slice]. This method enforces that every index in the - /// tensor maps to a unique element in the data. This upholds Rust's rules - /// for mutable aliasing. [NdTensorBase::from_slice] does not have this - /// restriction. - pub fn from_data_with_strides( - shape: [usize; N], - data: S, - strides: [usize; N], - ) -> Result, FromDataError> { - NdLayout::try_from_shape_and_strides(shape, strides, OverlapPolicy::DisallowOverlap) - .and_then(|layout| { - if layout.min_data_len() > data.as_ref().len() { - Err(FromDataError::StorageTooShort) - } else { - Ok(layout) - } - }) - .map(|layout| NdTensorBase { - data, - layout, - element_type: PhantomData, - }) - } - - /// Consume self and return the underlying element storage. - pub fn into_data(self) -> S { - self.data - } - - /// Consume self and return a dynamic-rank tensor. - pub fn into_dyn(self) -> TensorBase { - let layout = self.layout.as_dyn(); - TensorBase::new(self.data, &layout) - } - - /// Return the layout which maps indices to offsets in the data. - pub fn layout(&self) -> &NdLayout { - &self.layout - } - - /// Return a new tensor by applying `f` to each element of this tensor. - pub fn map(&self, f: F) -> NdTensor - where - F: Fn(&T) -> U, - { - // Convert to dynamic and back to benefit from fast paths in - // `Tensor::map`. - self.as_dyn().map(f).try_into().unwrap() - } - - /// Change the layout to put dimensions in the order specified by `dims`. - /// - /// This does not modify the order of elements in the data buffer, it just - /// updates the strides used by indexing. - pub fn permute(&mut self, dims: [usize; N]) { - self.layout = self.layout.permuted(dims); - } - - /// Return a new contiguous tensor with the same shape and elements as this - /// view. - pub fn to_tensor(&self) -> NdTensor - where - T: Clone, - { - // Convert to dynamic and back to benefit from fast paths in - // `Tensor::to_tensor`. - self.as_dyn().to_tensor().try_into().unwrap() - } - - /// Return a copy of the elements of this tensor in their logical order - /// as a vector. - /// - /// This is equivalent to `self.iter().cloned().collect()` but faster - /// when the tensor is already contiguous or has a small number (<= 4) - /// dimensions. - pub fn to_vec(&self) -> Vec - where - T: Clone, - { - self.as_dyn().to_vec() - } - - /// Return an immutable view of this tensor. - pub fn view(&self) -> NdTensorView { - NdTensorView { - data: self.data.as_ref(), - layout: self.layout, - element_type: PhantomData, - } - } - - /// Load an array of `M` elements from successive entries of a tensor along - /// the `dim` axis. - /// - /// eg. If `base` is `[0, 1, 2]`, dim=0 and `M` = 4 this will return an - /// array with values from indices `[0, 1, 2]`, `[1, 1, 2]` ... `[3, 1, 2]`. - /// - /// Panics if any of the array indices are out of bounds. - #[inline] - pub fn get_array(&self, base: [usize; N], dim: usize) -> [T; M] - where - T: Copy + Default, - { - let offsets: [usize; M] = array_offsets(&self.layout, base, dim); - let data = self.data.as_ref(); - let mut result = [T::default(); M]; - for i in 0..M { - // Safety: `array_offsets` returns valid offsets - result[i] = unsafe { *data.get_unchecked(offsets[i]) }; - } - result - } -} - -impl> NdTensorBase { - /// Convert this vector to a static array of length `M`. - /// - /// Panics if the length of this vector is not M. - #[inline] - pub fn to_array(&self) -> [T; M] - where - T: Copy + Default, - { - self.get_array([0], 0) - } -} - -impl + AsMut<[T]>> NdTensorBase { - /// Fill this vector with values from a static array of length `M`. - /// - /// Panics if the length of this vector is not M. - #[inline] - pub fn assign_array(&mut self, values: [T; M]) - where - T: Copy + Default, - { - self.set_array([0], 0, values) - } -} - -impl, const N: usize> NdView for NdTensorBase { - type Elem = T; - - fn data(&self) -> Option<&[T]> { - self.is_contiguous().then_some(self.data.as_ref()) - } - - fn view(&self) -> NdTensorView { - NdTensorBase { - data: self.data.as_ref(), - layout: self.layout, - element_type: PhantomData, - } - } -} - -/// Convert a slice into a contiguous 1D tensor view. -impl<'a, T, S: AsRef<[T]>> From<&'a S> for NdTensorBase { - fn from(data: &'a S) -> Self { - Self::from_data([data.as_ref().len()], data.as_ref()) - } -} - -impl<'a, T, const N: usize> NdTensorView<'a, T, N> { - /// Constructs a view from a slice and optional strides. - /// - /// Unlike [NdTensorBase::from_data], combinations of strides which cause - /// multiple indices in the tensor to refer to the same data element are - /// allowed. Since the returned view is immutable, this will not enable - /// violation of Rust's aliasing rules. - pub fn from_slice_with_strides( - shape: [usize; N], - data: &'a [T], - strides: [usize; N], - ) -> Result { - NdLayout::try_from_shape_and_strides(shape, strides, OverlapPolicy::AllowOverlap) - .and_then(|layout| { - if layout.min_data_len() > data.as_ref().len() { - Err(FromDataError::StorageTooShort) - } else { - Ok(layout) - } - }) - .map(|layout| NdTensorBase { - data, - layout, - element_type: PhantomData, - }) - } - - /// Return the element at a given index, without performing any bounds- - /// checking. - /// - /// # Safety - /// - /// The caller must ensure that the index is valid for the tensor's shape. - pub unsafe fn get_unchecked(&self, index: [usize; N]) -> &'a T { - self.data.get_unchecked(self.layout.offset_unchecked(index)) - } - - /// Return a view of this tensor where indexing checks the bounds of offsets - /// into the data buffer, but not individual dimensions. This is faster, but - /// can hide errors. - pub fn unchecked(&self) -> UncheckedNdTensor { - let base = NdTensorBase { - data: self.data, - layout: self.layout, - element_type: PhantomData, - }; - UncheckedNdTensor { base } - } -} - -/// Specialized versions of the [NdView] methods for immutable views. -/// These preserve the underlying lifetime of the view in results, allowing for -/// method calls to be chained. -impl<'a, T, const N: usize> NdTensorView<'a, T, N> { - pub fn as_dyn(&self) -> TensorView<'a, T> { - TensorView::new(self.data, &self.layout.as_dyn()) - } - - pub fn data(&self) -> Option<&'a [T]> { - self.is_contiguous().then_some(self.data) - } - - /// Return the view's underlying data as a slice, without checking whether - /// it is contiguous. - /// - /// # Safety - /// - /// It is the caller's responsibility not to access elements in the slice - /// which are not part of this view. - pub unsafe fn data_unchecked(&self) -> &'a [T] { - self.data - } - - pub fn get(&self, index: [usize; N]) -> Option<&'a T> { - self.layout - .try_offset(index) - .and_then(|offset| self.data.get(offset)) - } - - pub fn iter(&self) -> Iter<'a, T> { - Iter::new(self.view_ref()) - } - - fn view_ref(&self) -> ViewRef<'a, '_, T, NdLayout> { - ViewRef::new(self.data, &self.layout) - } - - fn broadcast(&self, shape: [usize; M]) -> NdTensorView<'a, T, M> { - NdTensorView { - layout: self.layout.broadcast(shape), - data: self.data, - element_type: PhantomData, - } - } - - pub fn permuted(&self, dims: [usize; N]) -> NdTensorView<'a, T, N> { - NdTensorBase { - data: self.data, - layout: self.layout.permuted(dims), - element_type: PhantomData, - } - } - - pub fn transposed(&self) -> NdTensorView<'a, T, N> { - NdTensorBase { - data: self.data, - layout: self.layout.transposed(), - element_type: PhantomData, - } - } - - pub fn reshaped(&self, shape: [usize; M]) -> NdTensorView<'a, T, M> { - NdTensorBase { - data: self.data, - layout: self.layout.reshaped(shape), - element_type: PhantomData, - } - } - - pub fn to_contiguous(&self) -> NdTensorBase, N> - where - T: Clone, - { - if self.is_contiguous() { - NdTensorBase { - data: Cow::Borrowed(self.data), - layout: self.layout, - element_type: PhantomData, - } - } else { - let data = self.to_vec(); - NdTensorBase { - data: Cow::Owned(data), - layout: NdLayout::from_shape(self.layout.shape()), - element_type: PhantomData, - } - } - } - - pub fn slice(&self, range: R) -> NdTensorView<'a, T, M> { - let range = range.into_slice_items(); - let (offset_range, sliced_layout) = self.layout.slice(range.as_ref()); - NdTensorView { - data: &self.data[offset_range], - layout: sliced_layout, - element_type: PhantomData, - } - } -} - -impl + AsMut<[T]>, const N: usize> NdTensorBase { - /// Return the underlying data of the tensor as a mutable slice, if it is - /// contiguous. - pub fn data_mut(&mut self) -> Option<&mut [T]> { - self.is_contiguous().then_some(self.data.as_mut()) - } - - /// Return a mutable reference to the element at a given index. - pub fn get_mut(&mut self, index: [usize; N]) -> Option<&mut T> { - self.layout - .try_offset(index) - .and_then(|offset| self.data.as_mut().get_mut(offset)) - } - - /// Return the element at a given index, without performing any bounds- - /// checking. - /// - /// # Safety - /// - /// The caller must ensure that the index is valid for the tensor's shape. - pub unsafe fn get_unchecked_mut(&mut self, index: [usize; N]) -> &mut T { - let offset = self.layout.offset_unchecked(index); - self.data.as_mut().get_unchecked_mut(offset) - } - - /// Return a mutable view of this tensor. - pub fn view_mut(&mut self) -> NdTensorViewMut { - NdTensorViewMut { - data: self.data.as_mut(), - layout: self.layout, - element_type: PhantomData, - } - } - - /// Return a mutable view of part of this tensor. - /// - /// `M` specifies the number of dimensions that the layout must have after - /// slicing with `range`. Panics if the sliced layout has a different number - /// of dims. - pub fn slice_mut( - &mut self, - range: R, - ) -> NdTensorViewMut { - let range = range.into_slice_items(); - let (offset_range, sliced_layout) = self.layout.slice(range.as_ref()); - NdTensorViewMut { - data: &mut self.data.as_mut()[offset_range], - layout: sliced_layout, - element_type: PhantomData, - } - } - - /// Return a mutable view of this tensor which uses unchecked indexing. - /// - /// See [NdTensorView::unchecked] for more details. - pub fn unchecked_mut(&mut self) -> UncheckedNdTensor { - let base = NdTensorBase { - data: self.data.as_mut(), - layout: self.layout, - element_type: PhantomData, - }; - UncheckedNdTensor { base } - } - - /// Return a view of this tensor with a dynamic dimension count. - pub fn as_dyn_mut(&mut self) -> TensorViewMut { - TensorViewMut::new(self.data.as_mut(), &self.layout.as_dyn()) - } - - /// Return a mutable iterator over elements of this tensor. - pub fn iter_mut(&mut self) -> IterMut { - IterMut::new(self.mut_view_ref()) - } - - fn mut_view_ref(&mut self) -> MutViewRef> { - MutViewRef::new(self.data.as_mut(), &self.layout) - } - - /// Replace elements of this tensor with `f(element)`. - /// - /// This is the in-place version of `map`. - /// - /// The order in which elements are visited is unspecified and may not - /// correspond to the logical order. - pub fn apply T>(&mut self, f: F) { - if self.is_contiguous() { - self.data.as_mut().iter_mut().for_each(|x| *x = f(x)); - } else { - self.iter_mut().for_each(|x| *x = f(x)); - } - } - - /// Replace all elements of this tensor with `value`. - pub fn fill(&mut self, value: T) - where - T: Clone, - { - self.apply(|_| value.clone()); - } - - /// Copy elements from another tensor into this tensor. - /// - /// This tensor and `other` must have the same shape. - pub fn copy_from(&mut self, other: &NdTensorView) - where - T: Clone, - { - assert!(self.shape() == other.shape()); - for (out, x) in zip(self.iter_mut(), other.iter()) { - *out = x.clone(); - } - } - - /// Store an array of `M` elements into successive entries of a tensor along - /// the `dim` axis. - /// - /// See [NdTensorBase::get_array] for more details. - #[inline] - pub fn set_array(&mut self, base: [usize; N], dim: usize, values: [T; M]) - where - T: Copy, - { - let offsets: [usize; M] = array_offsets(&self.layout, base, dim); - let data = self.data.as_mut(); - - for i in 0..M { - // Safety: `array_offsets` returns valid offsets. - unsafe { *data.get_unchecked_mut(offsets[i]) = values[i] }; - } - } -} - -impl NdTensorBase, N> { - /// Create a new tensor with a given shape, contigous layout and all - /// elements set to zero (or whatever `T::default()` returns). - pub fn zeros(shape: [usize; N]) -> Self { - Self::full(shape, T::default()) - } - - /// Create a new tensor with a given shape, contiguous layout and all - /// elements initialized to `element`. - pub fn full(shape: [usize; N], element: T) -> Self { - let layout = NdLayout::from_shape(shape); - NdTensorBase { - data: vec![element; layout.len()], - layout, - element_type: PhantomData, - } - } - - /// Create a new tensor filled with random numbers from a given source. - pub fn rand>(shape: [usize; N], rand_src: &mut R) -> NdTensor - where - T: Clone + Default, - { - let mut tensor = NdTensor::zeros(shape); - tensor.data.fill_with(|| rand_src.next()); - tensor - } -} - -impl, S2: AsRef<[T]>, const N: usize> TryFrom> - for NdTensorBase -where - S1: Into, -{ - type Error = DimensionError; - - /// Convert a dynamic-dimensional tensor or view into a static-dimensional one. - /// - /// Fails if `value` does not have `N` dimensions. - fn try_from(value: TensorBase) -> Result { - let layout: NdLayout = value.layout().try_into()?; - Ok(NdTensorBase { - data: value.into_data().into(), - layout, - element_type: PhantomData, - }) - } -} - -impl, const N: usize> Index<[usize; N]> for NdTensorBase { - type Output = T; - fn index(&self, index: [usize; N]) -> &Self::Output { - &self.data.as_ref()[self.layout.offset(index)] - } -} - -impl + AsMut<[T]>, const N: usize> IndexMut<[usize; N]> for NdTensorBase { - fn index_mut(&mut self, index: [usize; N]) -> &mut Self::Output { - let offset = self.layout.offset(index); - &mut self.data.as_mut()[offset] - } -} - -impl, const N: usize> Layout for NdTensorBase { - type Index<'a> = [usize; N]; - type Indices = NdIndices; - - fn ndim(&self) -> usize { - N - } - - fn len(&self) -> usize { - self.layout.len() - } - - fn try_offset(&self, index: [usize; N]) -> Option { - self.layout.try_offset(index) - } - - fn is_empty(&self) -> bool { - self.layout.is_empty() - } - - fn shape(&self) -> Self::Index<'_> { - self.layout.shape() - } - - fn size(&self, dim: usize) -> usize { - self.layout.size(dim) - } - - fn strides(&self) -> Self::Index<'_> { - self.layout.strides() - } - - fn stride(&self, dim: usize) -> usize { - self.layout.stride(dim) - } - - fn indices(&self) -> Self::Indices { - self.layout.indices() - } -} - -impl> MatrixLayout for NdTensorBase { - fn rows(&self) -> usize { - self.layout.rows() - } - - fn cols(&self) -> usize { - self.layout.cols() - } - - fn row_stride(&self) -> usize { - self.layout.row_stride() - } - - fn col_stride(&self) -> usize { - self.layout.col_stride() - } -} - -/// Variant of [NdTensorBase] which owns its elements, using a `Vec` as -/// the backing storage. -pub type NdTensor = NdTensorBase, N>; - -/// Variant of [NdTensorBase] which borrows its elements from an [NdTensor]. -/// -/// Conceptually the relationship between [NdTensorView] and [NdTensor] is -/// similar to that between `[T]` and `Vec`. They share the same element -/// buffer, but views can have distinct layouts, with some limitations. -pub type NdTensorView<'a, T, const N: usize> = NdTensorBase; - -/// Variant of [NdTensorBase] which mutably borrows its elements from an -/// [NdTensor]. -/// -/// This is similar to [NdTensorView], except elements in the underyling -/// Tensor can be modified through it. -pub type NdTensorViewMut<'a, T, const N: usize> = NdTensorBase; - -/// Alias for viewing a slice as a 2D matrix. -pub type Matrix<'a, T = f32> = NdTensorBase; - -/// Alias for viewing a mutable slice as a 2D matrix. -pub type MatrixMut<'a, T = f32> = NdTensorBase; - -/// A variant of NdTensor which does not bounds-check individual dimensions -/// when indexing, but does still bounds-check the offset into the underlying -/// storage, and hence is not unsafe. -/// -/// Indexing using `UncheckedNdTensor` is faster than normal indexing into -/// NdTensorBase, but not as fast as the unsafe [NdTensorBase::get_unchecked] -/// method, which doesn't bounds-check individual dimensions or the final -/// offset into the data. -pub struct UncheckedNdTensor, const N: usize> { - base: NdTensorBase, -} - -impl, const N: usize> Layout for UncheckedNdTensor { - type Index<'a> = [usize; N]; - type Indices = NdIndices; - - fn ndim(&self) -> usize { - N - } - - fn try_offset(&self, index: [usize; N]) -> Option { - self.base.try_offset(index) - } - - fn len(&self) -> usize { - self.base.len() - } - - fn shape(&self) -> Self::Index<'_> { - self.base.shape() - } - - fn strides(&self) -> Self::Index<'_> { - self.base.strides() - } - - fn indices(&self) -> Self::Indices { - self.base.indices() - } -} - -impl, const N: usize> Index<[usize; N]> for UncheckedNdTensor { - type Output = T; - fn index(&self, index: [usize; N]) -> &Self::Output { - &self.base.data.as_ref()[self.base.layout.offset_unchecked(index)] - } -} - -impl + AsMut<[T]>, const N: usize> IndexMut<[usize; N]> - for UncheckedNdTensor -{ - fn index_mut(&mut self, index: [usize; N]) -> &mut Self::Output { - let offset = self.base.layout.offset_unchecked(index); - &mut self.base.data.as_mut()[offset] - } -} - -impl FromIterator for NdTensor { - fn from_iter(iter: I) -> Self - where - I: IntoIterator, - { - let data: Vec<_> = FromIterator::from_iter(iter); - let len = data.len(); - NdTensor::from_data([len], data) - } -} - -impl, S2: AsRef<[T]>, const N: usize> PartialEq> - for NdTensorBase -{ - fn eq(&self, other: &NdTensorBase) -> bool { - self.shape() == other.shape() && self.iter().eq(other.iter()) - } -} - -#[cfg(test)] -mod tests { - use crate::errors::{DimensionError, FromDataError}; - use crate::{ - ndtensor, Layout, MatrixLayout, NdTensor, NdTensorView, NdTensorViewMut, NdView, - RandomSource, SliceItem, Tensor, View, - }; - - /// Return elements of `tensor` in their logical order. - fn tensor_elements(tensor: NdTensorView) -> Vec { - tensor.iter().cloned().collect() - } - - /// Create a tensor where the value of each element is its logical index - /// plus one. - fn steps(shape: [usize; N]) -> NdTensor { - let mut x = NdTensor::zeros(shape); - for (index, elt) in x.iter_mut().enumerate() { - *elt = (index + 1) as i32; - } - x - } - - #[test] - fn test_ndtensor_apply() { - let mut tensor = ndtensor!((2, 2); [1, 2, 3, 4]); - - // Whole tensor - tensor.apply(|x| x * 2); - assert_eq!(tensor.to_vec(), &[2, 4, 6, 8]); - - // Non-contiguous slice - tensor.slice_mut::<1, _>((.., 0)).apply(|_| 0); - assert_eq!(tensor.to_vec(), &[0, 4, 0, 8]); - } - - #[test] - fn test_ndtensor_fill() { - let mut x = NdTensor::zeros([2, 2]); - x.fill(1i32); - assert_eq!(x.to_vec(), &[1, 1, 1, 1]); - - x.slice_mut::<1, _>(0).fill(2); - x.slice_mut::<1, _>(1).fill(3); - - assert_eq!(x.to_vec(), &[2, 2, 3, 3]); - } - - // Test conversion of a static-dim tensor with default strides, to a - // dynamic dim tensor. - #[test] - fn test_ndtensor_as_dyn() { - let tensor = NdTensor::from_data([2, 2], vec![1, 2, 3, 4]); - let dyn_tensor = tensor.as_dyn(); - assert_eq!(tensor.shape(), dyn_tensor.shape()); - assert_eq!(tensor.data(), dyn_tensor.data()); - } - - #[test] - fn test_ndtensor_as_dyn_mut() { - let mut tensor = NdTensor::from_data([2, 2], vec![1, 2, 3, 4]); - let mut dyn_tensor = tensor.as_dyn_mut(); - assert_eq!(dyn_tensor.shape(), [2, 2]); - assert_eq!(dyn_tensor.data_mut().unwrap(), &[1, 2, 3, 4]); - } - - // Test conversion of a static-dim tensor with broadcasting strides (ie. - // some strides are 0), to a dynamic dim tensor. - #[test] - fn test_ndtensor_as_dyn_broadcast() { - let data = [1, 2, 3, 4]; - let view = NdTensorView::from_slice_with_strides([4, 4], &data, [0, 1]).unwrap(); - let dyn_view = view.as_dyn(); - let elements: Vec<_> = dyn_view.iter().copied().collect(); - assert_eq!(elements, &[1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4]); - } - - #[test] - fn test_ndtensor_broadcast() { - let x = NdTensor::from_data([2], vec![1, 2]); - let bx = x.broadcast([3, 2]); - assert_eq!(bx.shape(), [3, 2]); - assert_eq!(bx.strides(), [0, 1]); - assert_eq!(bx.as_dyn().to_vec(), &[1, 2, 1, 2, 1, 2]); - - let x = NdTensor::from_data([], vec![3]); - let bx = x.broadcast([2, 4]); - assert_eq!(bx.shape(), [2, 4]); - assert_eq!(bx.strides(), [0, 0]); - assert_eq!(bx.as_dyn().to_vec(), &[3, 3, 3, 3, 3, 3, 3, 3]); - } - - #[test] - #[should_panic(expected = "Cannot broadcast to specified shape")] - fn test_ndtensor_broadcast_invalid() { - let x = NdTensor::from_data([2], vec![1, 2]); - x.broadcast([1, 4]); - } - - #[test] - fn test_ndtensor_copy_from() { - let x = NdTensor::from_data([2, 2], vec![1, 2, 3, 4]); - let mut y = NdTensor::zeros(x.shape()); - - y.copy_from(&x.view()); - - assert_eq!(y, x); - } - - #[test] - fn test_ndtensor_from_data() { - let data = vec![1., 2., 3., 4.]; - let view = NdTensorView::::from_data([2, 2], &data); - assert_eq!(view.data(), Some(data.as_slice())); - assert_eq!(view.shape(), [2, 2]); - assert_eq!(view.strides(), [2, 1]); - } - - #[test] - fn test_ndtensor_from_data_custom_strides() { - struct Case { - data: Vec, - shape: [usize; 2], - strides: [usize; 2], - } - - let cases = [ - // Contiguous view (no gaps, shortest stride last) - Case { - data: vec![1., 2., 3., 4.], - shape: [2, 2], - strides: [2, 1], - }, - // Transposed view (reversed strides) - Case { - data: vec![1., 2., 3., 4.], - shape: [2, 2], - strides: [1, 2], - }, - // Sliced view (gaps between elements) - Case { - data: vec![1.; 10], - shape: [2, 2], - strides: [4, 2], - }, - // Sliced view (gaps between rows) - Case { - data: vec![1.; 10], - shape: [2, 2], - strides: [4, 1], - }, - ]; - - for case in cases { - let result = NdTensorView::::from_data_with_strides( - case.shape, - &case.data, - case.strides, - ) - .unwrap(); - assert_eq!(result.shape(), case.shape); - assert_eq!(result.strides(), case.strides); - assert_eq!( - result.data(), - result.is_contiguous().then_some(case.data.as_slice()) - ); - } - } - - #[test] - fn test_ndtensor_from_iterator() { - let tensor: NdTensor = [1., 2., 3., 4.].into_iter().collect(); - assert_eq!(tensor_elements(tensor.view()), [1., 2., 3., 4.]); - } - - #[test] - fn test_slice_into_1d_ndtensor() { - let data = &[1., 2., 3., 4.]; - let view: NdTensorView = data.into(); - assert_eq!(view.data(), Some(data.as_slice())); - assert_eq!(view.shape(), [4]); - assert_eq!(view.strides(), [1]); - } - - #[test] - fn test_ndtensor_from_slice_with_strides() { - let data = vec![1., 2., 3., 4.]; - let view = NdTensorView::::from_slice_with_strides([2, 2], &data, [2, 1]).unwrap(); - assert_eq!(view.data(), Some(data.as_slice())); - assert_eq!(view.shape(), [2, 2]); - assert_eq!(view.strides(), [2, 1]); - } - - #[test] - fn test_ndtensor_from_slice_with_strides_too_short() { - let data = vec![1., 2., 3., 4.]; - let result = NdTensorView::::from_slice_with_strides([3, 3], &data, [2, 1]); - assert_eq!(result.err(), Some(FromDataError::StorageTooShort)); - } - - #[test] - fn test_ndtensor_from_data_fails_if_overlap() { - struct Case { - data: Vec, - shape: [usize; 3], - strides: [usize; 3], - } - - let cases = [ - // Broadcasting view (zero strides) - Case { - data: vec![1., 2., 3., 4.], - shape: [10, 2, 2], - strides: [0, 2, 1], - }, - // Case where there is actually no overlap, but `from_data` fails - // with a `MayOverlap` error due to the conservative logic it uses. - Case { - data: vec![1.; (3 * 3) + (3 * 4) + 1], - shape: [1, 4, 4], - strides: [20, 3, 4], - }, - ]; - - for case in cases { - let result = NdTensorView::::from_data_with_strides( - case.shape, - &case.data, - case.strides, - ); - assert_eq!(result.err(), Some(FromDataError::MayOverlap)); - } - } - - #[test] - fn test_ndtensor_from_slice_allows_overlap() { - let data = vec![1., 2., 3., 4.]; - let result = NdTensorView::::from_slice_with_strides([10, 2, 2], &data, [0, 2, 1]); - assert!(result.is_ok()); - } - - #[test] - fn test_ndtensor_try_from_tensor() { - // Tensor -> NdTensor - let tensor = Tensor::zeros(&[1, 10, 20]); - let ndtensor: NdTensor = tensor.clone().try_into().unwrap(); - assert_eq!(ndtensor.data(), tensor.data()); - assert_eq!(ndtensor.shape(), tensor.shape()); - assert_eq!(ndtensor.strides(), tensor.strides()); - - // Failed Tensor -> NdTensor - let matrix: Result, _> = tensor.clone().try_into(); - assert_eq!(matrix, Err(DimensionError {})); - - // TensorView -> NdTensorView - let ndview: NdTensorView = tensor.view().try_into().unwrap(); - assert_eq!(ndview.data(), tensor.data()); - assert_eq!(ndview.shape(), tensor.shape()); - assert_eq!(ndview.strides(), tensor.strides()); - - // TensorViewMut -> NdTensorViewMut - let mut tensor = Tensor::zeros(&[1, 10, 20]); - let mut ndview: NdTensorViewMut = tensor.view_mut().try_into().unwrap(); - ndview[[0, 0, 0]] = 1; - assert_eq!(tensor[[0, 0, 0]], 1); - } - - #[test] - fn test_ndtensor_get() { - let tensor = NdTensor::::zeros([5, 10, 15]); - - assert_eq!(tensor.get([0, 0, 0]), Some(&0)); - assert_eq!(tensor.get([4, 9, 14]), Some(&0)); - assert_eq!(tensor.get([5, 9, 14]), None); - assert_eq!(tensor.get([4, 10, 14]), None); - assert_eq!(tensor.get([4, 9, 15]), None); - } - - #[test] - fn test_ndtensor_get_array() { - let tensor = steps([4, 2, 2]); - - // First dim, zero base. - let values: [i32; 4] = tensor.get_array([0, 0, 0], 0); - assert_eq!(values, [1, 5, 9, 13]); - - // First dim, different base. - let values: [i32; 4] = tensor.get_array([0, 1, 1], 0); - assert_eq!(values, [4, 8, 12, 16]); - - // Last dim, zero base. - let values: [i32; 2] = tensor.get_array([0, 0, 0], 2); - assert_eq!(values, [1, 2]); - } - - #[test] - fn test_ndtensor_set_array() { - let mut tensor = steps([4, 2, 2]); - tensor.set_array([0, 0, 0], 0, [-1, -2, -3, -4]); - assert_eq!( - tensor.iter().copied().collect::>(), - &[-1, 2, 3, 4, -2, 6, 7, 8, -3, 10, 11, 12, -4, 14, 15, 16] - ); - } - - #[test] - fn test_ndtensor_assign_array() { - let mut tensor = NdTensor::zeros([2, 2]); - let mut transposed = tensor.view_mut(); - - transposed.permute([1, 0]); - transposed.slice_mut(0).assign_array([1, 2]); - transposed.slice_mut(1).assign_array([3, 4]); - - assert_eq!(tensor.iter().copied().collect::>(), [1, 3, 2, 4]); - } - - #[test] - #[should_panic(expected = "array indices invalid")] - fn test_ndtensor_get_array_invalid_index() { - let tensor = steps([4, 2, 2]); - tensor.get_array::<5>([0, 0, 0], 0); - } - - #[test] - #[should_panic(expected = "array indices invalid")] - fn test_ndtensor_get_array_invalid_index_2() { - let tensor = steps([4, 2, 2]); - tensor.get_array::<4>([1, 0, 0], 0); - } - - #[test] - fn test_ndtensor_get_mut() { - let mut tensor = NdTensor::::zeros([5, 10, 15]); - - assert_eq!(tensor.get_mut([0, 0, 0]), Some(&mut 0)); - assert_eq!(tensor.get_mut([4, 9, 14]), Some(&mut 0)); - assert_eq!(tensor.get_mut([5, 9, 14]), None); - assert_eq!(tensor.get_mut([4, 10, 14]), None); - assert_eq!(tensor.get_mut([4, 9, 15]), None); - } - - #[test] - fn test_ndtensor_get_unchecked() { - let tensor = NdTensor::::zeros([5, 10, 15]); - let tensor = tensor.view(); - unsafe { - assert_eq!(tensor.get_unchecked([0, 0, 0]), &0); - assert_eq!(tensor.get_unchecked([4, 9, 14]), &0); - } - } - - #[test] - fn test_ndtensor_get_unchecked_mut() { - let mut tensor = NdTensor::::zeros([5, 10, 15]); - unsafe { - assert_eq!(tensor.get_unchecked_mut([0, 0, 0]), &0); - assert_eq!(tensor.get_unchecked_mut([4, 9, 14]), &0); - } - } - - #[test] - fn test_ndtensor_into_dyn() { - let nd_tensor = ndtensor!((2, 3); [0., 1., 2., 3., 4., 5., 6.]); - let tensor = nd_tensor.into_dyn(); - assert_eq!(tensor.shape(), [2, 3]); - assert_eq!( - tensor.iter().copied().collect::>(), - [0., 1., 2., 3., 4., 5., 6.] - ); - } - - #[test] - fn test_ndtensor_iter() { - let tensor = NdTensor::::from_data([2, 2], vec![1, 2, 3, 4]); - let elements: Vec<_> = tensor.iter().copied().collect(); - assert_eq!(elements, &[1, 2, 3, 4]); - } - - #[test] - fn test_ndtensor_iter_mut() { - let mut tensor = NdTensor::::zeros([2, 2]); - tensor - .iter_mut() - .enumerate() - .for_each(|(i, el)| *el = i as i32); - let elements: Vec<_> = tensor.iter().copied().collect(); - assert_eq!(elements, &[0, 1, 2, 3]); - } - - #[test] - fn test_ndtensor_map() { - let tensor = NdTensor::::from_data([2, 2], vec![1, 2, 3, 4]); - let doubled = tensor.map(|x| x * 2); - assert_eq!(tensor_elements(doubled.view()), &[2, 4, 6, 8]); - } - - #[test] - fn test_ndtensor_to_array() { - let tensor = ndtensor!((2, 2); [1., 2., 3., 4.]); - let col0: [f32; 2] = tensor.view().transposed().slice::<1, _>(0).to_array(); - let col1: [f32; 2] = tensor.view().transposed().slice::<1, _>(1).to_array(); - assert_eq!(col0, [1., 3.]); - assert_eq!(col1, [2., 4.]); - } - - #[test] - fn test_ndtensor_to_tensor() { - let data = vec![1., 2., 3., 4.]; - let view = NdTensorView::::from_data([2, 2], &data).permuted([1, 0]); - let owned = view.to_tensor(); - assert_eq!(owned.shape(), view.shape()); - assert!(owned.is_contiguous()); - } - - #[test] - fn test_ndtensor_to_vec() { - let tensor = ndtensor!((2, 2); [1, 2, 3, 4]); - let tensor = tensor.view().transposed(); - assert_eq!(tensor.to_vec(), &[1, 3, 2, 4]); - } - - #[test] - fn test_ndtensor_partial_eq() { - let a = NdTensor::from_data([2, 2], vec![1, 2, 3, 4]); - let b = NdTensor::from_data([2, 2], vec![1, 2, 3, 4]); - let c = NdTensor::from_data([1, 4], vec![1, 2, 3, 4]); - let d = NdTensor::from_data([2, 2], vec![1, 2, 3, 5]); - - assert_eq!(a, b); - assert_ne!(a, c); - assert_ne!(a, d); - } - - #[test] - fn test_ndtensor_permuted() { - let data = vec![1, 2, 3, 4]; - let view = NdTensorView::from(&data).reshaped([2, 2]); - let transposed = view.permuted([1, 0]); - assert_eq!(tensor_elements(transposed), &[1, 3, 2, 4]); - - let transposed = transposed.permuted([1, 0]); - assert_eq!(tensor_elements(transposed), &[1, 2, 3, 4]); - } - - #[test] - fn test_ndtensor_permute() { - let data = vec![1, 2, 3, 4]; - let mut view = NdTensorView::from(&data).reshaped([2, 2]); - view.permute([1, 0]); - assert_eq!(tensor_elements(view), &[1, 3, 2, 4]); - view.permute([1, 0]); - assert_eq!(tensor_elements(view), &[1, 2, 3, 4]); - } - - #[test] - fn test_ndtensor_rand() { - struct NotRandom { - next: f32, - } - - impl RandomSource for NotRandom { - fn next(&mut self) -> f32 { - let curr = self.next; - self.next += 1.0; - curr - } - } - - let mut rng = NotRandom { next: 0. }; - - let tensor = NdTensor::rand([2, 2], &mut rng); - assert_eq!(tensor.shape(), [2, 2]); - assert_eq!(tensor.to_vec(), &[0., 1., 2., 3.]); - } - - #[test] - #[should_panic(expected = "permutation is invalid")] - fn test_ndtensor_permuted_panics_if_dims_invalid() { - let data = vec![1, 2, 3, 4]; - let view = NdTensorView::from(&data).reshaped([2, 2]); - view.permuted([2, 0]); - } - - #[test] - fn test_ndtensor_reshaped() { - let data = vec![1, 2, 3, 4]; - let view = NdTensorView::from(&data); - let matrix = view.reshaped([2, 2]); - assert_eq!(matrix.shape(), [2, 2]); - assert_eq!(tensor_elements(matrix), &[1, 2, 3, 4]); - } - - #[test] - #[should_panic(expected = "new shape must have same number of elements as current shape")] - fn test_ndtensor_reshaped_panics_if_product_not_equal() { - let data = vec![1, 2, 3, 4]; - let view = NdTensorView::from(&data); - view.reshaped([2, 3]); - } - - #[test] - #[should_panic(expected = "can only reshape a contiguous tensor")] - fn test_ndtensor_reshaped_panics_if_not_contiguous() { - let data = vec![1, 2, 3, 4]; - let view = NdTensorView::from(&data).reshaped([2, 2]); - let transposed = view.transposed(); - transposed.reshaped([4]); - } - - #[test] - fn test_ndtensor_to_contiguous() { - let x = NdTensor::from_data([3, 3], vec![1, 2, 3, 4, 5, 6, 7, 8, 9]); - let y = x.to_contiguous(); - assert!(y.is_contiguous()); - assert_eq!(y.data().unwrap().as_ptr(), x.data().unwrap().as_ptr()); - - let x = x.permuted([1, 0]); - assert!(!x.is_contiguous()); - - let y = x.to_contiguous(); - assert!(y.is_contiguous()); - assert_eq!( - y.data(), - Some(x.iter().copied().collect::>().as_slice()) - ); - } - - #[test] - fn test_ndtensor_transposed() { - let data = vec![1, 2, 3, 4]; - let view = NdTensorView::from(&data).reshaped([2, 2]); - assert_eq!(tensor_elements(view), &[1, 2, 3, 4]); - let view = view.transposed(); - assert_eq!(tensor_elements(view), &[1, 3, 2, 4]); - - let view = NdTensorView::from(&data).reshaped([1, 1, 4]); - let transposed = view.transposed(); - assert_eq!(transposed.shape(), [4, 1, 1]); - } - - #[test] - fn test_ndtensor_slice() { - let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; - let view = NdTensorView::from(&data).reshaped([4, 4]); - let slice: NdTensorView<_, 2> = view.slice([1..3, 1..3]); - assert_eq!(tensor_elements(slice), &[6, 7, 10, 11]); - } - - #[test] - fn test_ndtensor_slice_step() { - let data: Vec = (0..25).collect(); - let view = NdTensorView::from(&data).reshaped([5, 5]); - let slice: NdTensorView<_, 2> = - view.slice((SliceItem::range(0, None, 2), SliceItem::range(0, None, 2))); - assert_eq!(slice.shape(), [3, 3]); - assert_eq!( - slice.iter().copied().collect::>(), - [0, 2, 4, 10, 12, 14, 20, 22, 24] - ); - } - - #[test] - #[should_panic(expected = "sliced dims != 3")] - fn test_ndtensor_slice_wrong_dims() { - let data = vec![1, 2, 3, 4]; - let view = NdTensorView::from(&data).reshaped([2, 2]); - view.slice::<3, _>([0..2, 0..2]); - } - - #[test] - fn test_ndtensor_slice_mut() { - let mut data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; - let mut view = NdTensorViewMut::::from_data([4, 4], &mut data); - let mut slice = view.slice_mut([1..3, 1..3]); - slice[[0, 0]] = -1; - slice[[0, 1]] = -2; - slice[[1, 0]] = -3; - slice[[1, 1]] = -4; - assert_eq!( - tensor_elements(view.view()), - &[1, 2, 3, 4, 5, -1, -2, 8, 9, -3, -4, 12, 13, 14, 15, 16] - ); - } - - #[test] - #[should_panic(expected = "sliced dims != 3")] - fn test_ndtensor_slice_mut_wrong_dims() { - let mut data = vec![1, 2, 3, 4]; - let mut view = NdTensorViewMut::::from_data([2, 2], &mut data); - view.slice_mut::<3, _>([0..2, 0..2]); - } - - #[test] - fn test_matrix_layout() { - let data = vec![1., 2., 3., 4.]; - let mat = NdTensorView::from(&data).reshaped([2, 2]); - assert_eq!(mat.data(), Some(data.as_slice())); - assert_eq!(mat.rows(), 2); - assert_eq!(mat.cols(), 2); - assert_eq!(mat.row_stride(), 2); - assert_eq!(mat.col_stride(), 1); - } - - #[test] - #[ignore] - fn bench_iter() { - use crate::rng::XorShiftRng; - use crate::test_util::bench_loop; - - let mut rng = XorShiftRng::new(1234); - - // Create 4D NCHW tensor, such as is common in vision models. - let tensor = NdTensor::rand([5, 64, 256, 256], &mut rng); - let n_iters = 1; - - // Iteration via `for` loop; - let for_stats = bench_loop(n_iters, || { - let mut sum = 0.; - let data = tensor.data().unwrap(); - for i in 0..data.len() { - sum += data[i]; - } - assert!(sum > 0.); - }); - println!( - "NCHW iteration via for loop: {:.3}ms", - for_stats.duration_ms() - ); - - // Iteration via slice traversal. - let slice_iter_stats = bench_loop(n_iters, || { - let sum = tensor.data().unwrap().iter().sum::(); - assert!(sum > 0.); - }); - println!( - "NCHW iteration via slice iter: {:.3}ms", - slice_iter_stats.duration_ms() - ); - - // Iteration via tensor iterator (contiguous). - let iter_contiguous_stats = bench_loop(n_iters, || { - let sum = tensor.iter().sum::(); - assert!(sum > 0.); - }); - println!( - "NCHW iteration via contiguous iter: {:.3}ms", - iter_contiguous_stats.duration_ms() - ); - - // Iteration via tensor iterator (non-contiguous). - let iter_non_contiguous_stats = bench_loop(n_iters, || { - let sum = tensor.permuted([1, 0, 2, 3]).iter().sum::(); - assert!(sum > 0.); - }); - println!( - "NCHW iteration via non-contiguous iter: {:.3}ms", - iter_non_contiguous_stats.duration_ms() - ); - - // Iteration via indexing. - let indexing_stats = bench_loop(n_iters, || { - let mut sum = 0.; - let [batch, chans, height, width] = tensor.shape(); - for n in 0..batch { - for c in 0..chans { - for h in 0..height { - for w in 0..width { - sum += tensor[[n, c, h, w]]; - } - } - } - } - assert!(sum > 0.); - }); - println!( - "NCHW iteration via indexing: {:.3}ms", - indexing_stats.duration_ms() - ); - - // Iteration via unchecked indexing. - let unchecked_indexing_stats = bench_loop(n_iters, || { - let unchecked = tensor.view().unchecked(); - let mut sum = 0.; - let [batch, chans, height, width] = tensor.shape(); - for n in 0..batch { - for c in 0..chans { - for h in 0..height { - for w in 0..width { - sum += unchecked[[n, c, h, w]]; - } - } - } - } - assert!(sum > 0.); - }); - println!( - "NCHW iteration via unchecked indexing: {:.3}ms", - unchecked_indexing_stats.duration_ms() - ); - - // Iteration via dynamic rank indexing. - let dyn_indexing_stats = bench_loop(n_iters, || { - let dyn_view = tensor.as_dyn(); - let mut sum = 0.; - let [batch, chans, height, width] = tensor.shape(); - for n in 0..batch { - for c in 0..chans { - for h in 0..height { - for w in 0..width { - sum += dyn_view[[n, c, h, w]]; - } - } - } - } - assert!(sum > 0.); - }); - println!( - "NCHW iteration via dyn indexing: {:.3}ms", - dyn_indexing_stats.duration_ms() - ); - } -} diff --git a/rten-tensor/src/tensor.rs b/rten-tensor/src/tensor.rs deleted file mode 100644 index 7f7ad038..00000000 --- a/rten-tensor/src/tensor.rs +++ /dev/null @@ -1,2397 +0,0 @@ -use std::borrow::Cow; -use std::fmt::Debug; -use std::io; -use std::io::Write; -use std::iter::zip; -use std::marker::PhantomData; -use std::ops::{Index, IndexMut, Range}; - -use crate::errors::SliceError; -use crate::iterators::{ - AxisChunks, AxisChunksMut, AxisIter, AxisIterMut, BroadcastIter, InnerIter, InnerIterMut, Iter, - IterMut, Lanes, LanesMut, MutViewRef, ViewRef, -}; -use crate::layout::{DynLayout, Layout}; -use crate::ndtensor::{NdTensorBase, NdTensorView, NdTensorViewMut}; -use crate::range::{IntoSliceItems, SliceItem}; -use crate::RandomSource; - -/// Multi-dimensional array view with a dynamic dimension count. This trait -/// includes operations that are available on tensors that own their data -/// ([Tensor]) as well as views ([TensorView], [TensorViewMut]). -/// -/// [TensorView] implements specialized versions of these methods as -/// inherent methods, which preserve lifetimes on the result. -pub trait View: Layout { - /// The data type of elements in this tensor. - type Elem; - - /// Return an iterator over slices of this tensor along a given axis. - fn axis_iter(&self, dim: usize) -> AxisIter { - self.view().axis_iter(dim) - } - - /// Return an iterator over slices of this tensor along a given axis. - fn axis_chunks(&self, dim: usize, chunk_size: usize) -> AxisChunks { - self.view().axis_chunks(dim, chunk_size) - } - - /// Create a view of this tensor which is broadcasted to `shape`. - /// - /// A broadcasted view behaves as if the underlying tensor had the broadcasted - /// shape, repeating elements as necessary to fill the given dimensions. - /// Broadcasting is only possible if the actual and broadcast shapes are - /// compatible according to ONNX's rules. See - /// . - /// - /// See also - /// for worked examples of how broadcasting works. - /// - /// Panics if `shape` is not broadcast-compatible with the current shape. - fn broadcast(&self, shape: &[usize]) -> TensorView { - self.view().broadcast(shape) - } - - /// Return an iterator over the elements of this view broadcasted to - /// `shape`. - /// - /// This is functionally the same as `self.broadcast(shape).iter()` but - /// has some additional optimizations for common broadcasting uses - eg. - /// when the sequence of elements in the broadcasted view is equivalent to - /// cycling the original view using `tensor.iter().cycle()`. - fn broadcast_iter(&self, shape: &[usize]) -> BroadcastIter { - self.view().broadcast_iter(shape) - } - - /// Return the underlying data of the tensor as a slice, if it is contiguous. - fn data(&self) -> Option<&[Self::Elem]>; - - /// Return an iterator over views of the innermost N dimensions of this - /// tensor. - fn inner_iter(&self) -> InnerIter { - self.view().inner_iter() - } - - /// Returns the single item if this tensor is a 0-dimensional tensor - /// (ie. a scalar) - fn item(&self) -> Option<&Self::Elem> { - self.view().item() - } - - /// Return the element at a given index, or `None` if the index is out of - /// bounds in any dimension. - #[inline] - fn get>(&self, index: I) -> Option<&Self::Elem> { - self.view().get(index) - } - - /// Return an iterator over elements of this tensor, in their logical order. - fn iter(&self) -> Iter { - self.view().iter() - } - - /// Return an iterator over all 1D slices ("lanes") along a given axis. - /// - /// Each slice is an iterator over the elements in that lane. - fn lanes(&self, dim: usize) -> Lanes { - self.view().lanes(dim) - } - - /// Return a copy of this tensor with each element replaced by `f(element)`. - /// - /// The order in which elements are visited is unspecified and may not - /// correspond to the logical order. - fn map(&self, f: F) -> Tensor - where - F: Fn(&Self::Elem) -> U, - { - let data = if let Some(data) = self.data() { - data.iter().map(f).collect() - } else { - self.iter().map(f).collect() - }; - Tensor { - data, - layout: DynLayout::from_shape(self.shape().as_ref()), - element_type: PhantomData, - } - } - - /// Return a view with a static rank. - /// - /// Panics if the rank of this tensor is not `N`. - fn nd_view(&self) -> NdTensorView { - self.view().nd_view() - } - - /// Return a new view with the given shape. - /// - /// The current view must be contiguous and the new shape must have the - /// same product as the current shape. - fn reshaped(&self, shape: &[usize]) -> TensorView { - self.view().reshaped(shape) - } - - /// Return a new view with the dimensions re-ordered according to `dims`. - fn permuted(&self, dims: &[usize]) -> TensorView { - self.view().permuted(dims) - } - - /// Return a view of part of this tensor. - /// - /// `range` specifies the indices or ranges of this tensor to include in the - /// returned view. If the range has fewer dimensions than the tensor, they - /// refer to the leading dimensions. - /// - /// See [IntoSliceItems] for a description of how slices can be specified. - /// Slice ranges are currently restricted to use positive steps. In other - /// words, NumPy-style slicing with negative steps is not supported. - fn slice(&self, range: R) -> TensorView { - self.view().slice(range) - } - - /// Variant of [View::slice] which returns an error if the slice spec - /// is invalid, instead of panicking. - fn try_slice(&self, range: R) -> Result, SliceError> { - self.view().try_slice(range) - } - - /// Return an iterator over a slice of this tensor. - /// - /// This is similar to `self.slice(range).iter()` except that it - /// returns an iterator directly instead of creating an intermediate view. - /// Also slicing with this method is more flexible as negative steps are - /// supported for items in `range`. - fn slice_iter(&self, range: &[SliceItem]) -> Iter { - self.view().slice_iter(range) - } - - /// Return a view of this tensor with all dimensions of size 1 removed. - fn squeezed(&self) -> TensorView { - self.view().squeezed() - } - - /// Return a tensor with data laid out in contiguous order. This will - /// be a view if the data is already contiguous, or a copy otherwise. - fn to_contiguous(&self) -> TensorBase> - where - Self::Elem: Clone, - { - self.view().to_contiguous() - } - - /// Return a new contiguous tensor with the same shape and elements as this - /// view. - fn to_tensor(&self) -> Tensor - where - Self::Elem: Clone, - { - self.view().to_tensor() - } - - /// Return a copy of the elements of this tensor in their logical order - /// as a vector. - /// - /// This is equivalent to `self.iter().cloned().collect()` but faster - /// when the tensor is already contiguous or has a small number (<= 4) - /// dimensions. - fn to_vec(&self) -> Vec - where - Self::Elem: Clone, - { - self.view().to_vec() - } - - /// Return a new view with the order of dimensions reversed. - fn transposed(&self) -> TensorView { - self.view().transposed() - } - - /// Return an immutable view of this tensor. - fn view(&self) -> TensorView; -} - -/// N-dimensional array, where `N` is determined at runtime based on the shape -/// that is specified when the tensor is constructed. -/// -/// `T` is the element type and `S` is the element storage. -/// -/// Most code will not use `TensorBase` directly but instead use the type -/// aliases [Tensor], [TensorView] and [TensorViewMut]. [Tensor] owns -/// its elements, and the other two types are views of slices. -/// -/// All [TensorBase] variants implement the [Layout] trait which provide -/// operations related to the shape and strides of the tensor, and the -/// [View] trait which provides common methods applicable to all variants. -#[derive(Debug)] -pub struct TensorBase> { - data: S, - layout: DynLayout, - element_type: PhantomData, -} - -/// Variant of [TensorBase] which borrows its elements from a [Tensor]. -/// -/// Conceptually the relationship between [TensorView] and [Tensor] is similar -/// to that between `[T]` and `Vec`. They share the same element buffer, but -/// views can have distinct layouts, with some limitations. -pub type TensorView<'a, T = f32> = TensorBase; - -/// Variant of [TensorBase] which mutably borrows its elements from a [Tensor]. -/// -/// This is similar to [TensorView], except elements in the underyling -/// Tensor can be modified through it. -pub type TensorViewMut<'a, T = f32> = TensorBase; - -impl> TensorBase { - /// Create a new tensor with a given layout and storage. - pub(crate) fn new(data: S, layout: &DynLayout) -> Self { - TensorBase { - data, - layout: layout.clone(), - element_type: PhantomData, - } - } - - /// Create a new tensor from a given shape and set of elements. No copying - /// is required. - pub fn from_data>(shape: &[usize], data: D) -> Self { - let data = data.into(); - assert!( - shape[..].iter().product::() == data.as_ref().len(), - "Number of elements given by shape {:?} does not match data length {}", - shape, - data.as_ref().len() - ); - TensorBase { - data, - layout: DynLayout::from_shape(shape), - element_type: PhantomData, - } - } - - /// Consume self and return the underlying element buffer. - /// - /// As with [TensorBase::data], there is no guarantee about the ordering of - /// elements. - pub fn into_data(self) -> S { - self.data - } - - /// Return an immutable view of this tensor. - /// - /// Views share the same element array, but can have an independent layout, - /// with some limitations. - pub fn view(&self) -> TensorView { - TensorView::new(self.data.as_ref(), &self.layout) - } - - /// Change the layout to put dimensions in the order specified by `dims`. - /// - /// This does not modify the order of elements in the data buffer, it just - /// updates the strides used by indexing. - pub fn permute(&mut self, dims: &[usize]) { - self.layout.permute(dims); - } - - /// Move the index at axis `from` to `to`, keeping the relative order of - /// other dimensions the same. This is like NumPy's `moveaxis` function. - /// - /// Panics if the `from` or `to` axes are >= `self.ndim()`. - pub fn move_axis(&mut self, from: usize, to: usize) { - self.layout.move_axis(from, to); - } - - /// Reverse the order of dimensions. - /// - /// This does not modify the order of elements in the data buffer, it just - /// changes the strides used by indexing. - pub fn transpose(&mut self) { - self.layout.transpose(); - } - - /// Insert a dimension of size one at index `dim`. - pub fn insert_dim(&mut self, dim: usize) { - self.layout.insert_dim(dim); - } - - /// Return the layout which maps indices to offsets in the data. - pub fn layout(&self) -> &DynLayout { - &self.layout - } -} - -/// Specialized versions of the [View] methods for immutable views. -/// These preserve the underlying lifetime of the view in results, allowing for -/// method calls to be chained. -impl<'a, T> TensorView<'a, T> { - pub fn axis_iter(&self, dim: usize) -> AxisIter<'a, T> { - AxisIter::new(self, dim) - } - - pub fn axis_chunks(&self, dim: usize, chunk_size: usize) -> AxisChunks<'a, T> { - AxisChunks::new(self, dim, chunk_size) - } - - pub fn broadcast(&self, shape: &[usize]) -> TensorView<'a, T> { - Self { - data: self.data, - layout: self.layout.broadcast(shape), - element_type: PhantomData, - } - } - - pub fn broadcast_iter(&self, shape: &[usize]) -> BroadcastIter<'a, T> { - assert!( - self.can_broadcast_to(shape), - "Cannot broadcast to specified shape" - ); - BroadcastIter::new(self.view_ref(), shape) - } - - pub fn data(&self) -> Option<&'a [T]> { - self.is_contiguous().then_some(self.data) - } - - /// Return the view's underlying data as a slice, without checking whether - /// it is contiguous. - /// - /// # Safety - /// - /// It is the caller's responsibility not to access elements in the slice - /// which are not part of this view. - pub unsafe fn data_unchecked(&self) -> &'a [T] { - self.data - } - - #[inline] - fn get>(&self, index: I) -> Option<&'a T> { - let offset = self.layout.try_offset(index.as_ref())?; - Some(&self.data[offset]) - } - - pub fn inner_iter(&self) -> InnerIter<'a, T, N> { - InnerIter::new(self.clone()) - } - - pub fn iter(&self) -> Iter<'a, T> { - Iter::new(self.view_ref()) - } - - pub(crate) fn view_ref(&self) -> ViewRef<'a, '_, T, DynLayout> { - ViewRef::new(self.data, &self.layout) - } - - pub fn item(&self) -> Option<&'a T> { - match self.ndim() { - 0 => Some(&self.data[0]), - _ if self.len() == 1 => self.iter().next(), - _ => None, - } - } - - pub fn lanes(&self, dim: usize) -> Lanes<'a, T> { - Lanes::new(self.view_ref(), dim) - } - - pub fn nd_view(&self) -> NdTensorView<'a, T, N> { - assert!(self.ndim() == N); - let shape: [usize; N] = self.shape().try_into().unwrap(); - let strides: [usize; N] = self.strides().try_into().unwrap(); - NdTensorView::from_slice_with_strides(shape, self.data, strides).unwrap() - } - - pub fn permuted(&self, dims: &[usize]) -> TensorView<'a, T> { - Self { - data: self.data, - layout: self.layout.permuted(dims), - element_type: PhantomData, - } - } - - pub fn reshaped(&self, shape: &[usize]) -> TensorView<'a, T> { - Self { - data: self.data, - layout: self.layout.reshaped(shape), - element_type: PhantomData, - } - } - - /// Change the layout of this view to have the given shape. - /// - /// The current view must be contiguous and the new shape must have the - /// same product as the current shape. - pub fn reshape(&mut self, shape: &[usize]) { - self.layout.reshape(shape); - } - - pub fn slice(&self, range: R) -> TensorView<'a, T> { - self.try_slice(range.into_slice_items().as_ref()).unwrap() - } - - pub fn try_slice(&self, range: R) -> Result, SliceError> { - let (offset_range, layout) = self.layout.try_slice(range.into_slice_items().as_ref())?; - Ok(TensorBase { - data: &self.data[offset_range], - layout, - element_type: PhantomData, - }) - } - - pub fn slice_iter(&self, range: &[SliceItem]) -> Iter<'a, T> { - Iter::slice(self.view_ref(), range) - } - - pub fn squeezed(&self) -> TensorView<'a, T> { - TensorBase { - data: self.data, - layout: self.layout.squeezed(), - element_type: PhantomData, - } - } - - pub fn to_contiguous(&self) -> TensorBase> - where - T: Clone, - { - if self.is_contiguous() { - TensorBase { - data: Cow::Borrowed(self.data), - layout: self.layout.clone(), - element_type: PhantomData, - } - } else { - let data = self.to_vec(); - TensorBase { - data: Cow::Owned(data), - layout: DynLayout::from_shape(self.layout().shape()), - element_type: PhantomData, - } - } - } - - pub fn transposed(&self) -> TensorView<'a, T> { - Self { - data: self.data, - layout: self.layout.transposed(), - element_type: PhantomData, - } - } -} - -impl> Layout for TensorBase { - type Index<'a> = ::Index<'a>; - type Indices = ::Indices; - - /// Return the number of dimensions. - fn ndim(&self) -> usize { - self.layout.ndim() - } - - fn try_offset(&self, index: &[usize]) -> Option { - self.layout.try_offset(index) - } - - /// Returns the number of elements in the array. - fn len(&self) -> usize { - self.layout.len() - } - - /// Returns true if the array has no elements. - fn is_empty(&self) -> bool { - self.layout.is_empty() - } - - /// Returns an array of the sizes of each dimension. - fn shape(&self) -> Self::Index<'_> { - self.layout.shape() - } - - /// Returns the size of the dimension `dim`. - fn size(&self, dim: usize) -> usize { - self.layout.size(dim) - } - - /// Returns an array of the strides of each dimension. - fn strides(&self) -> Self::Index<'_> { - self.layout.strides() - } - - /// Returns the offset between adjacent indices along dimension `dim`. - fn stride(&self, dim: usize) -> usize { - self.layout.stride(dim) - } - - /// Return an iterator over all valid indices in this tensor. - fn indices(&self) -> Self::Indices { - self.layout.indices() - } -} - -impl> View for TensorBase { - type Elem = T; - - fn data(&self) -> Option<&[T]> { - self.is_contiguous().then_some(self.data.as_ref()) - } - - fn to_tensor(&self) -> Tensor - where - T: Clone, - { - Tensor::from_data(self.shape(), self.to_vec()) - } - - fn to_vec(&self) -> Vec - where - T: Clone, - { - if let Some(data) = self.data() { - data.to_vec() - } else { - // This branch is equivalent to - // `x.iter().cloned().collect::>()` but uses a faster - // iteration method that is optimized for tensors with few (<= 4) - // dimensions. - let mut data = Vec::with_capacity(self.len()); - let ptr: *mut T = data.as_mut_ptr(); - - let mut offset = 0; - fast_for_each_element(self.view(), |elt| { - // Safety: `fast_for_each_element` calls fn `self.len()` times, - // matching the buffer capacity. - unsafe { *ptr.add(offset) = elt.clone() }; - offset += 1; - }); - - // Safety: Length here matches capacity passed to `Vec::with_capacity`. - unsafe { data.set_len(self.len()) } - - data - } - } - - fn view(&self) -> TensorView { - TensorView::new(self.data.as_ref(), &self.layout) - } -} - -impl, T, S: AsRef<[T]>> Index for TensorBase { - type Output = T; - - fn index(&self, index: I) -> &Self::Output { - let offset = self.layout.offset(index.as_ref()); - &self.data.as_ref()[offset] - } -} - -impl + AsMut<[T]>> TensorBase { - /// Copy elements from another tensor into this tensor. - /// - /// This tensor and `other` must have the same shape. - pub fn copy_from(&mut self, other: &TensorView) - where - T: Clone, - { - assert!(self.shape() == other.shape()); - for (out, x) in zip(self.iter_mut(), other.iter()) { - *out = x.clone(); - } - } - - /// Return the underlying data of the tensor as a mutable slice, if it is - /// contiguous. - pub fn data_mut(&mut self) -> Option<&mut [T]> { - self.is_contiguous().then_some(self.data.as_mut()) - } - - /// Return a mutable iterator over elements of this view. - pub fn iter_mut(&mut self) -> IterMut { - IterMut::new(self.mut_view_ref()) - } - - pub(crate) fn mut_view_ref(&mut self) -> MutViewRef { - MutViewRef::new(self.data.as_mut(), &self.layout) - } - - /// Return an iterator over mutable slices of this tensor along a given - /// axis. - pub fn axis_iter_mut(&mut self, dim: usize) -> AxisIterMut { - AxisIterMut::new(self.view_mut(), dim) - } - - pub fn axis_chunks_mut(&mut self, dim: usize, chunk_size: usize) -> AxisChunksMut { - AxisChunksMut::new(self.view_mut(), dim, chunk_size) - } - - /// Return the element at a given index, or `None` if the index is out of - /// bounds in any dimension. - #[inline] - pub fn get_mut>(&mut self, index: I) -> Option<&mut T> { - let offset = self.layout.try_offset(index.as_ref())?; - Some(&mut self.data.as_mut()[offset]) - } - - /// Return an iterator over views of the innermost N dimensions of this - /// tensor. - pub fn inner_iter_mut(&mut self) -> InnerIterMut { - InnerIterMut::new(self.view_mut()) - } - - /// Return a mutable iterator over all 1D slices of this tensor along a - /// given axis. - pub fn lanes_mut(&mut self, dim: usize) -> LanesMut { - LanesMut::new(self.mut_view_ref(), dim) - } - - /// Replace elements of this tensor with `f(element)`. - /// - /// This is the in-place version of `map`. - /// - /// The order in which elements are visited is unspecified and may not - /// correspond to the logical order. - pub fn apply T>(&mut self, f: F) { - if self.is_contiguous() { - self.data.as_mut().iter_mut().for_each(|x| *x = f(x)); - } else { - self.iter_mut().for_each(|x| *x = f(x)); - } - } - - /// Replace all elements of this tensor with `value`. - pub fn fill(&mut self, value: T) - where - T: Clone, - { - self.apply(|_| value.clone()); - } - - /// Return a new view with the dimensions re-ordered according to `dims`. - pub fn permuted_mut(&mut self, dims: &[usize]) -> TensorViewMut { - TensorBase { - data: self.data.as_mut(), - layout: self.layout.permuted(dims), - element_type: PhantomData, - } - } - - /// Return a new view with a given shape. This has the same requirements - /// as `reshape`. - pub fn reshaped_mut(&mut self, shape: &[usize]) -> TensorViewMut { - TensorBase { - data: self.data.as_mut(), - layout: self.layout.reshaped(shape), - element_type: PhantomData, - } - } - - /// Return a new mutable slice of this tensor. - /// - /// Slices are specified in the same way as for [TensorBase::slice]. - pub fn slice_mut(&mut self, range: R) -> TensorViewMut { - self.try_slice_mut(range.into_slice_items().as_ref()) - .unwrap() - } - - /// Variant of [TensorViewMut::slice_mut] which returns an error instead of - /// panicking if the slice range is invalid. - pub fn try_slice_mut( - &mut self, - range: R, - ) -> Result, SliceError> { - let (offset_range, layout) = self.layout.try_slice(range.into_slice_items().as_ref())?; - let data = &mut self.data.as_mut()[offset_range]; - Ok(TensorViewMut { - data, - layout, - element_type: PhantomData, - }) - } - - /// Return a new view with the order of dimensions reversed. - pub fn transposed_mut(&mut self) -> TensorViewMut { - TensorBase { - data: self.data.as_mut(), - layout: self.layout.transposed(), - element_type: PhantomData, - } - } - - /// Return a mutable view of this tensor. - /// - /// Views share the same element array, but can have an independent layout, - /// with some limitations. - pub fn view_mut(&mut self) -> TensorViewMut { - TensorViewMut::new(self.data.as_mut(), &self.layout) - } - - /// Return a mutable view with a static rank. - /// - /// Panics if the rank of this tensor is not `N`. - pub fn nd_view_mut(&mut self) -> NdTensorViewMut { - assert!(self.ndim() == N); - let shape: [usize; N] = self.shape().try_into().unwrap(); - let strides: [usize; N] = self.strides().try_into().unwrap(); - NdTensorViewMut::from_data_with_strides(shape, self.data.as_mut(), strides).unwrap() - } -} - -impl<'a, T> TensorViewMut<'a, T> { - /// Consume this view and return the underlying data slice. - /// - /// This differs from [Self::data_mut] as the lifetime of the returned slice - /// is tied to the underlying tensor, rather than the view. - pub fn into_data_mut(self) -> &'a mut [T] { - self.data - } -} - -impl, T, S: AsRef<[T]> + AsMut<[T]>> IndexMut for TensorBase { - fn index_mut(&mut self, index: I) -> &mut Self::Output { - let offset = self.layout.offset(index.as_ref()); - &mut self.data.as_mut()[offset] - } -} - -/// Variant of [TensorBase] which owns its elements, using a `Vec` as -/// the backing storage. -pub type Tensor = TensorBase>; - -impl Tensor { - /// Create a new zero-filled tensor with a given shape. - pub fn zeros(shape: &[usize]) -> Tensor - where - T: Clone + Default, - { - let n_elts = shape.iter().product(); - let data = vec![T::default(); n_elts]; - Tensor { - data, - layout: DynLayout::from_shape(shape), - element_type: PhantomData, - } - } - - /// Create a new tensor filled with a given value. - pub fn full(shape: &[usize], value: T) -> Tensor - where - T: Clone, - { - let n_elts = shape.iter().product(); - let data = vec![value; n_elts]; - Tensor { - data, - layout: DynLayout::from_shape(shape), - element_type: PhantomData, - } - } - - /// Create a new tensor filled with random numbers from a given source. - pub fn rand>(shape: &[usize], rand_src: &mut R) -> Tensor - where - T: Clone + Default, - { - let mut tensor = Tensor::zeros(shape); - tensor.data.fill_with(|| rand_src.next()); - tensor - } - - /// Create a new 1D tensor filled with an arithmetic sequence of values - /// in the range `[start, end)` separated by `step`. If `step` is omitted, - /// it defaults to 1. - pub fn arange(start: T, end: T, step: Option) -> Tensor - where - T: Copy + PartialOrd + From + std::ops::Add, - { - let step = step.unwrap_or((true).into()); - let mut data = Vec::new(); - let mut curr = start; - while curr < end { - data.push(curr); - curr = curr + step; - } - Tensor::from_vec(data) - } - - /// Create a new 0-dimensional (scalar) tensor from a single value. - pub fn from_scalar(value: T) -> Tensor { - Self::from_data(&[], vec![value]) - } - - /// Create a new 1-dimensional tensor from a vector. No copying is required. - pub fn from_vec(data: Vec) -> Tensor { - Self::from_data(&[data.len()], data) - } - - /// Clone this tensor with a new shape. The new shape must have the same - /// total number of elements as the existing shape. See `reshape`. - pub fn to_shape(&self, shape: &[usize]) -> Tensor - where - T: Clone, - { - Self::from_data(shape, self.to_vec()) - } - - /// Clip dimension `dim` to `[range.start, range.end)`. The new size for - /// the dimension must be <= the old size. - /// - /// This currently requires `T: Copy` to support efficiently moving data - /// from the new start offset to the beginning of the element buffer. - pub fn clip_dim(&mut self, dim: usize, range: Range) - where - T: Copy, - { - let (start, end) = (range.start, range.end); - - assert!(start <= end, "start must be <= end"); - assert!(end <= self.size(dim), "end must be <= dim size"); - - let start_offset = self.layout.stride(dim) * start; - self.layout.resize_dim(dim, end - start); - - let range = start_offset..start_offset + self.layout.min_data_len(); - self.data.copy_within(range.clone(), 0); - self.data.truncate(range.end - range.start); - } - - /// Convert the internal layout of elements to be contiguous, as reported - /// by `is_contiguous`. - /// - /// This is a no-op if the tensor is already contiguous. - pub fn make_contiguous(&mut self) - where - T: Clone, - { - if self.is_contiguous() { - return; - } - self.data = self.to_vec(); - self.layout.make_contiguous(); - } - - /// Update the shape of the tensor. - /// - /// The total number of elements for the new shape must be the same as the - /// existing shape. - /// - /// This is a cheap operation if the tensor is contiguous, but requires - /// copying data if the tensor has a non-contiguous layout. - pub fn reshape(&mut self, shape: &[usize]) - where - T: Clone, - { - let len: usize = shape.iter().product(); - let current_len = self.len(); - assert!( - len == current_len, - "New shape must have same total elements as current shape" - ); - - // We currently always copy data whenever the input is non-contiguous. - // However there are cases of custom strides where copies could be - // avoided. See https://pytorch.org/docs/stable/generated/torch.Tensor.view.html. - self.make_contiguous(); - self.layout = DynLayout::from_shape(shape); - } - - /// Like [Tensor::reshape] but consumes self. - pub fn into_shape(mut self, new_shape: &[usize]) -> Tensor - where - T: Clone, - { - self.reshape(new_shape); - self - } -} - -impl> TensorBase { - /// Serialize the tensor to a simple binary format. - /// - /// The serialized data is in little-endian order and has the structure: - /// - /// ```text - /// [ndim: u32][dims: u32 * rank][elements: T * product(dims)] - /// ``` - /// - /// Where `T` is the tensor's element type. - pub fn write(&self, writer: &mut W) -> io::Result<()> { - let mut buf_writer = io::BufWriter::new(writer); - let ndim: u32 = self.ndim() as u32; - buf_writer.write_all(&ndim.to_le_bytes())?; - for &dim in self.shape() { - buf_writer.write_all(&(dim as u32).to_le_bytes())?; - } - for el in self.iter() { - buf_writer.write_all(&el.to_le_bytes())?; - } - buf_writer.flush()?; - Ok(()) - } -} - -impl, V: View> PartialEq for TensorBase { - fn eq(&self, other: &V) -> bool { - self.shape() == other.shape().as_ref() && self.iter().eq(other.iter()) - } -} - -impl + Clone> Clone for TensorBase { - fn clone(&self) -> TensorBase { - let data = self.data.clone(); - TensorBase { - data, - layout: self.layout.clone(), - element_type: PhantomData, - } - } -} - -impl, S2: AsRef<[T]>, const N: usize> From> - for TensorBase -where - S1: Into, -{ - fn from(value: NdTensorBase) -> TensorBase { - let layout: DynLayout = value.layout().into(); - TensorBase { - data: value.into_data().into(), - layout, - element_type: PhantomData, - } - } -} - -impl FromIterator for Tensor { - fn from_iter(iter: I) -> Self - where - I: IntoIterator, - { - let data: Vec<_> = FromIterator::from_iter(iter); - Tensor::from_vec(data) - } -} - -// Trait for scalar (ie. non-array) values. -// -// This is used as a bound in contexts where we don't want a generic type -// `T` to be inferred as an array type. -pub trait Scalar {} - -impl Scalar for i32 {} -impl Scalar for f32 {} - -// The `T: Scalar` bound avoids ambiguity when choosing a `Tensor::from` -// impl for a nested array literal, as it prevents `T` from matching an array -// type. - -impl From<[T; N]> for Tensor { - /// Construct a 1D tensor from a 1D array. - fn from(value: [T; N]) -> Tensor { - Tensor::from_vec(value.iter().cloned().collect()) - } -} - -impl From<[[T; N]; M]> for Tensor { - /// Construct a 2D tensor from a nested array. - fn from(value: [[T; N]; M]) -> Tensor { - let data: Vec<_> = value.iter().flat_map(|y| y.iter()).cloned().collect(); - Tensor::from_data(&[M, N], data) - } -} - -impl From<[[[T; K]; N]; M]> - for Tensor -{ - /// Construct a 3D tensor from a nested array. - fn from(value: [[[T; K]; N]; M]) -> Tensor { - let data: Vec<_> = value - .iter() - .flat_map(|y| y.iter().flat_map(|z| z.iter())) - .cloned() - .collect(); - Tensor::from_data(&[M, N, K], data) - } -} - -/// Call `f` with every element in `x` in logical order. -/// -/// This is equivalent to `x.iter().for_each(f)` but is faster that Rust's -/// standard iteration protocol when `x` is non-contiguous and has <= 4 -/// dimensions. -fn fast_for_each_element(mut x: TensorView, mut f: F) { - if x.ndim() > 4 { - x.iter().for_each(f) - } else { - while x.ndim() < 4 { - x.insert_dim(0); - } - - // Safety: We only access valid offsets according to the shape and - // strides in the loop below. - let x_data = unsafe { x.data_unchecked() }; - let x: NdTensorView = x.nd_view(); - let shape = x.shape(); - let strides = x.strides(); - - assert!(x_data.len() >= x.layout().min_data_len()); - - for i0 in 0..shape[0] { - for i1 in 0..shape[1] { - for i2 in 0..shape[2] { - for i3 in 0..shape[3] { - let offset = - i0 * strides[0] + i1 * strides[1] + i2 * strides[2] + i3 * strides[3]; - - // Safety: We checked data length > max offset produced - // by layout. - let elt = unsafe { x_data.get_unchecked(offset) }; - f(elt) - } - } - } - } - } -} - -#[cfg(test)] -mod tests { - use std::ops::IndexMut; - - use crate::iterators::Offsets; - use crate::rng::XorShiftRng; - use crate::tensor; - use crate::{ - Lanes, LanesMut, Layout, NdTensor, NdView, SliceItem, SliceRange, Tensor, TensorView, - TensorViewMut, View, - }; - - /// Create a tensor where the value of each element is its logical index - /// plus one. - fn steps(shape: &[usize]) -> Tensor { - let steps: usize = shape.iter().product(); - Tensor::arange(1, steps as i32 + 1, None).into_shape(shape) - } - - #[test] - fn test_apply() { - let mut x = steps(&[3, 3]); - x.apply(|el| el * el); - let expected = Tensor::from_data(&[3, 3], vec![1, 4, 9, 16, 25, 36, 49, 64, 81]); - assert_eq!(x, expected); - } - - #[test] - fn test_fill() { - let mut x = Tensor::zeros(&[2, 2]); - x.fill(1i32); - assert_eq!(x.to_vec(), &[1, 1, 1, 1]); - - x.slice_mut(0).fill(2); - x.slice_mut(1).fill(3); - - assert_eq!(x.to_vec(), &[2, 2, 3, 3]); - } - - #[test] - fn test_arange() { - let x = Tensor::arange(1, 5, None); - assert_eq!(x.to_vec(), [1, 2, 3, 4]); - - let x = Tensor::arange(1, 10, Some(2)); - assert_eq!(x.to_vec(), [1, 3, 5, 7, 9]); - - let x = Tensor::arange(1., 5., None); - assert_eq!(x.to_vec(), [1., 2., 3., 4.]); - - let x = Tensor::arange(1., 5., Some(0.5)); - assert_eq!(x.to_vec(), [1., 1.5, 2., 2.5, 3., 3.5, 4., 4.5]); - } - - #[test] - fn test_axis_iter() { - let x = steps(&[2, 3, 4]); - - // First dimension. - let views: Vec<_> = x.axis_iter(0).collect(); - assert_eq!(views.len(), 2); - assert_eq!(views[0], x.slice([0])); - assert_eq!(views[1], x.slice([1])); - - // Second dimension. - let views: Vec<_> = x.axis_iter(1).collect(); - assert_eq!(views.len(), 3); - assert_eq!(views[0], x.slice((.., 0))); - assert_eq!(views[1], x.slice((.., 1))); - } - - #[test] - fn test_axis_iter_mut() { - let mut x = steps(&[2, 3]); - let y0 = x.slice([0]).to_tensor(); - let y1 = x.slice([1]).to_tensor(); - - // First dimension. - let mut views: Vec<_> = x.axis_iter_mut(0).collect(); - assert_eq!(views.len(), 2); - assert_eq!(views[0], y0); - assert_eq!(views[1], y1); - views[0].iter_mut().for_each(|x| *x += 1); - views[1].iter_mut().for_each(|x| *x += 2); - assert_eq!(x.to_vec(), &[2, 3, 4, 6, 7, 8]); - - let z0 = x.slice((.., 0)).to_tensor(); - let z1 = x.slice((.., 1)).to_tensor(); - - // Second dimension. - let views: Vec<_> = x.axis_iter_mut(1).collect(); - assert_eq!(views.len(), 3); - assert_eq!(views[0], z0); - assert_eq!(views[1], z1); - } - - #[test] - fn test_axis_chunks() { - let x = steps(&[4, 2, 2]); - - let mut chunks = x.axis_chunks(0, 2); - - let chunk = chunks.next().expect("expected chunk"); - assert_eq!(chunk.shape(), &[2, 2, 2]); - assert_eq!( - chunk.iter().copied().collect::>(), - &[1, 2, 3, 4, 5, 6, 7, 8] - ); - - let chunk = chunks.next().expect("expected chunk"); - assert_eq!(chunk.shape(), &[2, 2, 2]); - assert_eq!( - chunk.iter().copied().collect::>(), - &[9, 10, 11, 12, 13, 14, 15, 16] - ); - - assert!(chunks.next().is_none()); - } - - #[test] - fn test_axis_chunks_mut() { - let mut x = steps(&[4, 2, 2]); - - let mut chunks = x.axis_chunks_mut(0, 2); - - let mut chunk = chunks.next().expect("expected chunk"); - assert_eq!(chunk.shape(), &[2, 2, 2]); - chunk.iter_mut().for_each(|x| *x *= 10); - assert_eq!( - chunk.iter().copied().collect::>(), - &[10, 20, 30, 40, 50, 60, 70, 80] - ); - - let mut chunk = chunks.next().expect("expected chunk"); - assert_eq!(chunk.shape(), &[2, 2, 2]); - chunk.iter_mut().for_each(|x| *x *= 10); - assert_eq!( - chunk.iter().copied().collect::>(), - &[90, 100, 110, 120, 130, 140, 150, 160] - ); - - assert!(chunks.next().is_none()); - } - - #[test] - fn test_clip_dim() { - let mut x = steps(&[3, 3]); - x.clip_dim(0, 1..2); - x.clip_dim(1, 1..2); - assert_eq!(x.to_vec(), vec![5]); - } - - #[test] - fn test_clip_dim_start() { - let mut x = steps(&[3, 3]); - - // Clip the start of the tensor, adjusting the `base` offset. - x.clip_dim(0, 1..3); - - // Indexing should reflect the slice. - assert_eq!(x.to_vec(), &[4, 5, 6, 7, 8, 9]); - assert_eq!(x[[0, 0]], 4); - assert_eq!(*x.index_mut([0, 0]), 4); - - // Slices returned by `data` should reflect the slice. - assert_eq!(x.data(), Some([4, 5, 6, 7, 8, 9].as_slice())); - assert_eq!(x.data_mut().as_deref(), Some([4, 5, 6, 7, 8, 9].as_slice())); - - // Offsets should be relative to the sliced returned by `data`, - // `data_mut`. - assert_eq!( - Offsets::new(&x).collect::>(), - &[0, 1, 2, 3, 4, 5] - ); - assert_eq!(x.layout().offset(&[0, 0]), 0); - } - - #[test] - fn test_copy_from() { - let x = steps(&[3, 3]); - let mut y = Tensor::zeros(x.shape()); - - y.copy_from(&x.view()); - - assert_eq!(y, x); - } - - #[test] - fn test_from_arrays() { - // 2D - let x = Tensor::from([[2, 3], [4, 5], [6, 7]]); - assert_eq!(x.shape(), &[3, 2]); - assert_eq!(x.data(), Some([2, 3, 4, 5, 6, 7].as_slice())); - - // 3D - let x = Tensor::from([[[2, 3], [4, 5], [6, 7]]]); - assert_eq!(x.shape(), &[1, 3, 2]); - assert_eq!(x.data(), Some([2, 3, 4, 5, 6, 7].as_slice())); - } - - #[test] - fn test_from_scalar() { - let x = Tensor::from_scalar(5); - assert_eq!(x.shape().len(), 0); - assert_eq!(x.data(), Some([5].as_slice())); - } - - #[test] - fn test_from_vec() { - let x = tensor!([1, 2, 3]); - assert_eq!(x.shape(), &[3]); - assert_eq!(x.data(), Some([1, 2, 3].as_slice())); - } - - #[test] - fn test_full() { - let x = Tensor::full(&[2, 2], 1.0); - assert_eq!(x.shape(), &[2, 2]); - assert_eq!(x.data(), Some([1., 1., 1., 1.].as_slice())); - } - - #[test] - fn test_from_iterator() { - let x: Tensor = FromIterator::from_iter(0..10); - assert_eq!(x.shape(), &[10]); - assert_eq!(x.data(), Some([0, 1, 2, 3, 4, 5, 6, 7, 8, 9].as_slice())); - } - - #[test] - fn test_stride() { - let x = Tensor::::zeros(&[2, 5, 7, 3]); - assert_eq!(x.stride(3), 1); - assert_eq!(x.stride(2), 3); - assert_eq!(x.stride(1), 7 * 3); - assert_eq!(x.stride(0), 5 * 7 * 3); - } - - #[test] - fn test_strides() { - let x = Tensor::::zeros(&[2, 5, 7, 3]); - assert_eq!(x.strides(), [5 * 7 * 3, 7 * 3, 3, 1]); - } - - #[test] - fn test_get() { - let mut x = Tensor::::zeros(&[2, 2]); - - x.data[0] = 1.0; - x.data[1] = 2.0; - x.data[2] = 3.0; - x.data[3] = 4.0; - - // Index with fixed-sized array. - assert_eq!(x.get([0, 0]), Some(&1.0)); - assert_eq!(x.get([0, 1]), Some(&2.0)); - assert_eq!(x.get([1, 0]), Some(&3.0)); - assert_eq!(x.get([1, 1]), Some(&4.0)); - - // Invalid indices - assert_eq!(x.get([0, 2]), None); - assert_eq!(x.get([2, 0]), None); - assert_eq!(x.get([1, 0, 0]), None); - - // Index with slice. - assert_eq!(x.get([0, 0].as_slice()), Some(&1.0)); - assert_eq!(x.get([0, 1].as_slice()), Some(&2.0)); - assert_eq!(x.get([1, 0].as_slice()), Some(&3.0)); - assert_eq!(x.get([1, 1].as_slice()), Some(&4.0)); - } - - #[test] - fn test_index() { - let mut x = Tensor::::zeros(&[2, 2]); - - x.data[0] = 1.0; - x.data[1] = 2.0; - x.data[2] = 3.0; - x.data[3] = 4.0; - - // Index with fixed-sized array. - assert_eq!(x[[0, 0]], 1.0); - assert_eq!(x[[0, 1]], 2.0); - assert_eq!(x[[1, 0]], 3.0); - assert_eq!(x[[1, 1]], 4.0); - - // Index with slice. - assert_eq!(x[[0, 0].as_slice()], 1.0); - assert_eq!(x[[0, 1].as_slice()], 2.0); - assert_eq!(x[[1, 0].as_slice()], 3.0); - assert_eq!(x[[1, 1].as_slice()], 4.0); - } - - #[test] - fn test_index_scalar() { - let x = Tensor::from_scalar(5.0); - assert_eq!(x[[]], 5.0); - } - - #[test] - fn test_index_mut() { - let mut x = Tensor::::zeros(&[2, 2]); - - x[[0, 0]] = 1.0; - x[[0, 1]] = 2.0; - x[[1, 0]] = 3.0; - x[[1, 1]] = 4.0; - - assert_eq!(x.data[0], 1.0); - assert_eq!(x.data[1], 2.0); - assert_eq!(x.data[2], 3.0); - assert_eq!(x.data[3], 4.0); - } - - #[test] - fn test_get_mut() { - let mut x = Tensor::::zeros(&[2, 2]); - - *x.get_mut([0, 0]).unwrap() = 1.0; - *x.get_mut([0, 1]).unwrap() = 2.0; - *x.get_mut([1, 0]).unwrap() = 3.0; - *x.get_mut([1, 1]).unwrap() = 4.0; - - assert_eq!(x.data[0], 1.0); - assert_eq!(x.data[1], 2.0); - assert_eq!(x.data[2], 3.0); - assert_eq!(x.data[3], 4.0); - - assert_eq!(x.get_mut([1, 2]), None); - } - - #[test] - #[should_panic] - fn test_index_panics_if_invalid() { - let x = Tensor::::zeros(&[2, 2]); - x[[2, 0]]; - } - - #[test] - #[should_panic] - fn test_index_panics_if_wrong_dim_count() { - let x = Tensor::::zeros(&[2, 2]); - x[[0, 0, 0]]; - } - - #[test] - fn test_indices() { - let x = Tensor::::zeros(&[2, 2]); - let x_indices = { - let mut indices = Vec::new(); - let mut iter = x.indices(); - while let Some(index) = iter.next() { - indices.push(index.to_vec()); - } - indices - }; - assert_eq!( - x_indices, - &[vec![0, 0], vec![0, 1], vec![1, 0], vec![1, 1],] - ); - } - - #[test] - fn test_item() { - let scalar = Tensor::from_scalar(5.0); - assert_eq!(scalar.item(), Some(&5.0)); - - let vec_one_item = tensor!([5.0]); - assert_eq!(vec_one_item.item(), Some(&5.0)); - - let vec_many_items = tensor!([1.0, 2.0]); - assert_eq!(vec_many_items.item(), None); - - let matrix_one_item = Tensor::from_data(&[1, 1], vec![5.0]); - assert_eq!(matrix_one_item.item(), Some(&5.0)); - } - - #[test] - fn test_map() { - // Contiguous tensor. - let x = steps(&[2, 3]).map(|val| val * 2); - assert_eq!(x.to_vec(), &[2, 4, 6, 8, 10, 12]); - - // Non-contiguous view. - let x = steps(&[2, 3]); - let x = x.transposed(); - assert!(!x.is_contiguous()); - assert_eq!(x.to_vec(), &[1, 4, 2, 5, 3, 6]); - let x = x.map(|val| val * 2); - assert_eq!(x.to_vec(), &[2, 8, 4, 10, 6, 12]); - } - - #[test] - fn test_move_axis() { - let mut x = steps(&[2, 3]); - x.move_axis(1, 0); - assert_eq!(x.shape(), [3, 2]); - } - - #[test] - fn test_ndim() { - let scalar = Tensor::from_scalar(5.0); - let vec = tensor!([5.0]); - let matrix = Tensor::from_data(&[1, 1], vec![5.0]); - - assert_eq!(scalar.ndim(), 0); - assert_eq!(vec.ndim(), 1); - assert_eq!(matrix.ndim(), 2); - } - - #[test] - fn test_partial_eq() { - let x = tensor!([1, 2, 3, 4, 5]); - let y = x.clone(); - let z = x.to_shape(&[1, 5]); - - // Int tensors are equal if they have the same shape and elements. - assert_eq!(&x, &y); - assert_ne!(&x, &z); - } - - #[test] - fn test_len() { - let scalar = Tensor::from_scalar(5); - let vec = tensor!([1, 2, 3]); - let matrix = Tensor::from_data(&[2, 2], vec![1, 2, 3, 4]); - - assert_eq!(scalar.len(), 1); - assert_eq!(vec.len(), 3); - assert_eq!(matrix.len(), 4); - } - - #[test] - fn test_is_empty() { - assert!(Tensor::::from_vec(vec![]).is_empty()); - assert!(!tensor!([1]).is_empty()); - assert!(!Tensor::from_scalar(5.0).is_empty()); - } - - #[test] - fn test_reshape() { - let mut rng = XorShiftRng::new(1234); - let mut x = Tensor::rand(&[10, 5, 3, 7], &mut rng); - let x_data: Vec = x.data().unwrap().to_vec(); - - assert_eq!(x.shape(), &[10, 5, 3, 7]); - - x.reshape(&[10, 5, 3 * 7]); - - assert_eq!(x.shape(), &[10, 5, 3 * 7]); - assert_eq!(x.data(), Some(x_data.as_slice())); - } - - #[test] - fn test_reshape_non_contiguous() { - let mut rng = XorShiftRng::new(1234); - let mut x = Tensor::rand(&[10, 10], &mut rng); - - // Set the input up so that it is non-contiguous and has a non-zero - // `base` offset. - x.permute(&[1, 0]); - x.clip_dim(0, 2..8); - - // Reshape the tensor. This should copy the data and reset the `base` - // offset. - x.reshape(&[x.shape().iter().product()]); - - // After reshaping, we should be able to successfully read all the elements. - // Note this test doesn't check that the correct elements were read. - let elts: Vec<_> = x.iter().collect(); - assert_eq!(elts.len(), 60); - - // Set up another input so it is non-contiguous and has a non-zero `base` offset. - let mut x = steps(&[3, 3]); - x.clip_dim(0, 1..3); - x.clip_dim(1, 1..3); - - // Flatten the input with reshape. - x.reshape(&[4]); - - // Check that the correct elements were read. - assert_eq!(x.to_vec(), &[5, 6, 8, 9]); - } - - #[test] - fn test_reshape_copies_with_custom_strides() { - let mut rng = XorShiftRng::new(1234); - let mut x = Tensor::rand(&[10, 10], &mut rng); - - // Give the tensor a non-default stride - x.clip_dim(1, 0..8); - assert!(!x.is_contiguous()); - let x_elements = x.to_vec(); - - x.reshape(&[80]); - - // Since the tensor had a non-default stride, `reshape` will have copied - // data. - assert_eq!(x.shape(), &[80]); - assert!(x.is_contiguous()); - assert_eq!(x.data(), Some(x_elements.as_slice())); - } - - #[test] - #[should_panic(expected = "New shape must have same total elements as current shape")] - fn test_reshape_with_wrong_size() { - let mut rng = XorShiftRng::new(1234); - let mut x = Tensor::rand(&[10, 5, 3, 7], &mut rng); - x.reshape(&[10, 5]); - } - - #[test] - fn test_permute() { - // Test with a vector (this is a no-op) - let mut input = steps(&[5]); - assert!(input.iter().eq([1, 2, 3, 4, 5].iter())); - input.permute(&[0]); - assert!(input.iter().eq([1, 2, 3, 4, 5].iter())); - - // Test with a matrix (ie. transpose the matrix) - let mut input = steps(&[2, 3]); - assert!(input.iter().eq([1, 2, 3, 4, 5, 6].iter())); - input.permute(&[1, 0]); - assert_eq!(input.shape(), &[3, 2]); - assert!(input.iter().eq([1, 4, 2, 5, 3, 6].iter())); - - // Test with a higher-rank tensor. For this test we don't list out the - // full permuted element sequence, but just check the shape and strides - // were updated. - let mut input = steps(&[3, 4, 5]); - let (stride_0, stride_1, stride_2) = (input.stride(0), input.stride(1), input.stride(2)); - input.permute(&[2, 0, 1]); - assert_eq!(input.shape(), &[5, 3, 4]); - assert_eq!( - (input.stride(0), input.stride(1), input.stride(2)), - (stride_2, stride_0, stride_1) - ); - } - - #[test] - #[should_panic(expected = "permutation is invalid")] - fn test_permute_wrong_dim_count() { - let mut input = steps(&[2, 3]); - input.permute(&[1, 2, 3]); - } - - #[test] - fn test_transpose() { - // Test with a vector (this is a no-op) - let mut input = steps(&[5]); - input.transpose(); - assert_eq!(input.shape(), &[5]); - - // Test with a matrix - let mut input = steps(&[2, 3]); - assert!(input.iter().eq([1, 2, 3, 4, 5, 6].iter())); - input.transpose(); - assert_eq!(input.shape(), &[3, 2]); - assert!(input.iter().eq([1, 4, 2, 5, 3, 6].iter())); - - // Test with a higher-rank tensor - let mut input = steps(&[1, 3, 7]); - input.transpose(); - assert_eq!(input.shape(), [7, 3, 1]); - } - - #[test] - fn test_insert_dim() { - // Insert dims in contiguous tensor. - let mut input = steps(&[2, 3]); - input.insert_dim(1); - assert_eq!(input.shape(), &[2, 1, 3]); - assert_eq!(input.strides(), &[3, 6, 1]); - - input.insert_dim(1); - assert_eq!(input.shape(), &[2, 1, 1, 3]); - assert_eq!(input.strides(), &[3, 6, 6, 1]); - - input.insert_dim(0); - assert_eq!(input.shape(), &[1, 2, 1, 1, 3]); - assert_eq!(input.strides(), &[6, 3, 6, 6, 1]); - - // Insert dims in non-contiguous tensor. - let mut input = steps(&[2, 3]); - input.transpose(); - input.insert_dim(0); - assert_eq!(input.shape(), &[1, 3, 2]); - assert_eq!(input.strides(), &[6, 1, 3]); - - input.insert_dim(3); - assert_eq!(input.shape(), &[1, 3, 2, 1]); - assert_eq!(input.strides(), &[6, 1, 3, 6]); - - // Insert dims in a tensor where the smallest stride is > 1. - let input = steps(&[2, 4]); - let mut view = input.slice((.., SliceRange::new(0, None, 2))); - assert_eq!(view.shape(), &[2, 2]); - assert_eq!(view.strides(), &[4, 2]); - - view.insert_dim(0); - - assert_eq!(view.shape(), &[1, 2, 2]); - assert_eq!(view.strides(), &[8, 4, 2]); - } - - #[test] - fn test_to_shape() { - let mut rng = XorShiftRng::new(1234); - let x = Tensor::rand(&[10, 5, 3, 7], &mut rng); - let y = x.to_shape(&[10, 5, 3 * 7]); - - assert_eq!(y.shape(), &[10, 5, 3 * 7]); - assert_eq!(y.data(), x.data()); - } - - #[test] - fn test_nd_view() { - let mut rng = XorShiftRng::new(1234); - let x = Tensor::rand(&[10, 5, 3, 7], &mut rng); - let x_view = x.nd_view::<4>(); - assert_eq!(x_view.shape(), x.shape()); - assert_eq!(x_view.strides(), x.strides()); - assert_eq!(x_view.data(), x.data()); - } - - #[test] - fn test_nd_view_mut() { - let mut rng = XorShiftRng::new(1234); - let mut x = Tensor::rand(&[10, 5, 3, 7], &mut rng); - let layout = x.layout().clone(); - let x_view = x.nd_view_mut::<4>(); - assert_eq!(x_view.shape(), layout.shape()); - assert_eq!(x_view.strides(), layout.strides()); - } - - #[test] - fn test_iter_for_contiguous_array() { - for dims in 1..7 { - let mut shape = Vec::new(); - for d in 0..dims { - shape.push(d + 1); - } - let mut rng = XorShiftRng::new(1234); - let x = Tensor::rand(&shape, &mut rng); - - let elts: Vec = x.iter().copied().collect(); - - assert_eq!(x.data(), Some(elts.as_slice())); - } - } - - #[test] - fn test_iter_for_empty_array() { - let empty = Tensor::::zeros(&[3, 0, 5]); - assert!(empty.iter().next().is_none()); - } - - #[test] - fn test_iter_for_non_contiguous_array() { - let mut x = Tensor::zeros(&[3, 3]); - for (index, elt) in x.iter_mut().enumerate() { - *elt = index + 1; - } - - // Initially tensor is contiguous, so data buffer and element sequence - // match. - assert_eq!( - x.data(), - Some(x.iter().copied().collect::>().as_slice()) - ); - - // Slice the tensor along an outer dimension. This will leave the tensor - // contiguous, and hence `data` and `elements` should return the same - // elements. - x.clip_dim(0, 0..2); - assert_eq!(x.data(), Some([1, 2, 3, 4, 5, 6].as_slice())); - assert_eq!(x.iter().copied().collect::>(), &[1, 2, 3, 4, 5, 6]); - // Test with step > 1 to exercise `Elements::nth`. - assert_eq!(x.iter().step_by(2).copied().collect::>(), &[1, 3, 5]); - - // Slice the tensor along an inner dimension. The tensor will no longer - // be contiguous and hence `elements` will return different results than - // `data`. - x.clip_dim(1, 0..2); - assert_eq!(x.data(), None); - assert_eq!(x.iter().copied().collect::>(), &[1, 2, 4, 5]); - // Test with step > 1 to exercise `Elements::nth`. - assert_eq!(x.iter().step_by(2).copied().collect::>(), &[1, 4]); - } - - // PyTorch and numpy do not allow iteration over a scalar, but it seems - // consistent for `Tensor::iter` to always yield `Tensor::len` elements, - // and `len` returns 1 for a scalar. - #[test] - fn test_iter_for_scalar() { - let x = Tensor::from_scalar(5.0); - let elements = x.iter().copied().collect::>(); - assert_eq!(&elements, &[5.0]); - } - - #[test] - fn test_iter_mut_for_contiguous_array() { - for dims in 1..7 { - let mut shape = Vec::new(); - for d in 0..dims { - shape.push(d + 1); - } - let mut rng = XorShiftRng::new(1234); - let mut x = Tensor::rand(&shape, &mut rng); - - let elts: Vec = x.iter().map(|x| x * 2.).collect(); - - for elt in x.iter_mut() { - *elt *= 2.; - } - - assert_eq!(x.data(), Some(elts.as_slice())); - } - } - - #[test] - fn test_iter_mut_for_non_contiguous_array() { - let mut x = Tensor::zeros(&[3, 3]); - for (index, elt) in x.iter_mut().enumerate() { - *elt = index + 1; - } - x.permute(&[1, 0]); - - let x_doubled: Vec = x.iter().map(|x| x * 2).collect(); - for elt in x.iter_mut() { - *elt *= 2; - } - assert_eq!(x.to_vec(), x_doubled); - } - - #[test] - fn test_lanes() { - let x = steps(&[3, 3]); - - let collect_lane = - |lanes: &mut Lanes<'_, i32>| lanes.next().map(|lane| lane.copied().collect::>()); - - let mut rows = x.lanes(1); - assert_eq!(collect_lane(&mut rows), Some([1, 2, 3].to_vec())); - assert_eq!(collect_lane(&mut rows), Some([4, 5, 6].to_vec())); - assert_eq!(collect_lane(&mut rows), Some([7, 8, 9].to_vec())); - - let mut cols = x.lanes(0); - assert_eq!(collect_lane(&mut cols), Some([1, 4, 7].to_vec())); - assert_eq!(collect_lane(&mut cols), Some([2, 5, 8].to_vec())); - assert_eq!(collect_lane(&mut cols), Some([3, 6, 9].to_vec())); - } - - #[test] - fn test_lanes_mut() { - let update_lanes = |lanes: LanesMut<'_, i32>| { - let mut lane_idx = 0; - for lane in lanes { - for el in lane { - *el = lane_idx; - } - lane_idx += 1; - } - }; - - let mut x = Tensor::zeros(&[3, 3]); - let rows = x.lanes_mut(1); - update_lanes(rows); - assert_eq!(x.to_vec(), &[0, 0, 0, 1, 1, 1, 2, 2, 2]); - - let mut x = Tensor::zeros(&[3, 3]); - let cols = x.lanes_mut(0); - update_lanes(cols); - assert_eq!(x.to_vec(), &[0, 1, 2, 0, 1, 2, 0, 1, 2]); - } - - #[test] - fn test_to_vec() { - let mut x = steps(&[3, 3]); - - // Contiguous case. This should use the fast-path. - assert_eq!(x.to_vec(), x.iter().copied().collect::>()); - - // Non-contiguous case. - x.clip_dim(1, 0..2); - assert!(!x.is_contiguous()); - assert_eq!(x.to_vec(), x.iter().copied().collect::>()); - } - - #[test] - fn test_offsets() { - let mut rng = XorShiftRng::new(1234); - let mut x = Tensor::rand(&[10, 10], &mut rng); - - let x_elts: Vec<_> = x.to_vec(); - - let x_offsets = Offsets::new(&x); - let x_data = x.data_mut().unwrap(); - let x_elts_from_offset: Vec<_> = x_offsets.map(|off| x_data[off]).collect(); - - assert_eq!(x_elts, x_elts_from_offset); - } - - #[test] - fn test_offsets_nth() { - let x = steps(&[3]); - let mut iter = Offsets::new(&x); - assert_eq!(iter.nth(0), Some(0)); - assert_eq!(iter.nth(0), Some(1)); - assert_eq!(iter.nth(0), Some(2)); - assert_eq!(iter.nth(0), None); - - let x = steps(&[10]); - let mut iter = Offsets::new(&x); - assert_eq!(iter.nth(1), Some(1)); - assert_eq!(iter.nth(5), Some(7)); - assert_eq!(iter.nth(1), Some(9)); - assert_eq!(iter.nth(0), None); - } - - #[test] - fn test_from_data() { - let scalar = Tensor::from_data(&[], vec![1.0]); - assert_eq!(scalar.len(), 1); - - let matrix = Tensor::from_data(&[2, 2], vec![1, 2, 3, 4]); - assert_eq!(matrix.shape(), &[2, 2]); - assert_eq!(matrix.data(), Some([1, 2, 3, 4].as_slice())); - } - - #[test] - fn test_from_data_with_slice() { - let matrix = TensorView::from_data(&[2, 2], [1, 2, 3, 4].as_slice()); - assert_eq!(matrix.shape(), &[2, 2]); - assert_eq!(matrix.data(), Some([1, 2, 3, 4].as_slice())); - } - - #[test] - fn test_from_data_with_mut_slice() { - let mut data = vec![1, 2, 3, 4]; - let mut matrix = TensorViewMut::from_data(&[2, 2], &mut data[..]); - matrix[[0, 1]] = 5; - matrix[[1, 0]] = 6; - assert_eq!(data, &[1, 5, 6, 4]); - } - - #[test] - #[should_panic] - fn test_from_data_panics_with_wrong_len() { - Tensor::from_data(&[1], vec![1, 2, 3]); - } - - #[test] - #[should_panic] - fn test_from_data_panics_if_scalar_data_empty() { - Tensor::::from_data(&[], vec![]); - } - - #[test] - #[should_panic] - fn test_from_data_panics_if_scalar_data_has_many_elements() { - Tensor::from_data(&[], vec![1, 2, 3]); - } - - #[test] - fn test_from_ndtensor() { - // NdTensor -> Tensor - let ndtensor = NdTensor::zeros([1, 10, 20]); - let tensor: Tensor = ndtensor.clone().into(); - assert_eq!(tensor.data(), ndtensor.data()); - assert_eq!(tensor.shape(), ndtensor.shape()); - assert_eq!(tensor.strides(), ndtensor.strides()); - - // NdTensorView -> TensorView - let view: TensorView = ndtensor.view().into(); - assert_eq!(view.shape(), ndtensor.shape()); - - // NdTensorViewMut -> TensorViewMut - let mut ndtensor = NdTensor::zeros([1, 10, 20]); - let mut view: TensorViewMut = ndtensor.view_mut().into(); - view[[0, 0, 0]] = 1; - assert_eq!(ndtensor[[0, 0, 0]], 1); - } - - #[test] - fn test_is_contiguous() { - let mut x = Tensor::zeros(&[3, 3]); - for (index, elt) in x.iter_mut().enumerate() { - *elt = index + 1; - } - - // Freshly-allocated tensor - assert!(x.is_contiguous()); - - // Tensor where outermost dimension has been clipped at the end. - let mut y = x.clone(); - y.clip_dim(0, 0..2); - assert!(y.is_contiguous()); - assert_eq!(y.data(), Some([1, 2, 3, 4, 5, 6].as_slice())); - - // Tensor where outermost dimension has been clipped at the start. - let mut y = x.clone(); - y.clip_dim(0, 1..3); - assert!(y.is_contiguous()); - assert_eq!(y.data(), Some([4, 5, 6, 7, 8, 9].as_slice())); - - // Tensor where inner dimension has been clipped at the start. - let mut y = x.clone(); - y.clip_dim(1, 1..3); - assert!(!y.is_contiguous()); - - // Tensor where inner dimension has been clipped at the end. - let mut y = x.clone(); - y.clip_dim(1, 0..2); - assert!(!y.is_contiguous()); - } - - #[test] - fn test_is_contiguous_1d() { - let mut x = Tensor::zeros(&[10]); - for (index, elt) in x.iter_mut().enumerate() { - *elt = index + 1; - } - - assert!(x.is_contiguous()); - x.clip_dim(0, 0..5); - assert!(x.is_contiguous()); - } - - #[test] - fn test_make_contiguous() { - let mut x = steps(&[3, 3]); - assert!(x.is_contiguous()); - - // Clip outer dimension at start. This will modify the base offset. - x.clip_dim(0, 1..3); - - // Clip inner dimension at start. This will modify the strides. - x.clip_dim(1, 1..3); - assert!(!x.is_contiguous()); - - x.make_contiguous(); - assert!(x.is_contiguous()); - assert_eq!(x.to_vec(), &[5, 6, 8, 9]); - } - - #[test] - fn test_to_contiguous() { - let x = steps(&[3, 3]); - let y = x.to_contiguous(); - assert!(y.is_contiguous()); - assert_eq!(y.data().unwrap().as_ptr(), x.data().unwrap().as_ptr()); - - let x = x.permuted(&[1, 0]); - let y = x.to_contiguous(); - assert!(y.is_contiguous()); - assert_eq!(x.data(), None); - assert_eq!(y.data().unwrap(), x.to_vec()); - } - - #[test] - fn test_broadcast_iter() { - let x = steps(&[1, 2, 1, 2]); - assert_eq!(x.to_vec(), &[1, 2, 3, 4]); - - // Broadcast a 1-size dimension to size 2 - let bx = x.broadcast_iter(&[2, 2, 1, 2]); - assert_eq!(bx.copied().collect::>(), &[1, 2, 3, 4, 1, 2, 3, 4]); - - // Broadcast a different 1-size dimension to size 2 - let bx = x.broadcast_iter(&[1, 2, 2, 2]); - assert_eq!(bx.copied().collect::>(), &[1, 2, 1, 2, 3, 4, 3, 4]); - - // Broadcast to a larger number of dimensions - let x = steps(&[5]); - let bx = x.broadcast_iter(&[1, 5]); - assert_eq!(bx.copied().collect::>(), &[1, 2, 3, 4, 5]); - } - - #[test] - fn test_broadcast_iter_with_scalar() { - let scalar = Tensor::from_scalar(7); - let bx = scalar.broadcast_iter(&[3, 3]); - assert_eq!( - bx.copied().collect::>(), - &[7, 7, 7, 7, 7, 7, 7, 7, 7] - ); - } - - #[test] - #[should_panic(expected = "Cannot broadcast to specified shape")] - fn test_broadcast_iter_with_invalid_shape() { - let x = steps(&[2, 2]); - x.broadcast_iter(&[3, 2]); - } - - #[test] - #[should_panic(expected = "Cannot broadcast to specified shape")] - fn test_broadcast_iter_with_shorter_shape() { - let x = steps(&[2, 2]); - x.broadcast_iter(&[4]); - } - - #[test] - fn test_broadcast() { - let x = steps(&[1, 2, 1, 2]); - assert_eq!(x.to_vec(), &[1, 2, 3, 4]); - - // Broadcast a 1-size dimension to size 2 - let bx = x.broadcast(&[2, 2, 1, 2]); - assert_eq!(bx.shape(), &[2, 2, 1, 2]); - assert_eq!(bx.strides(), &[0, x.stride(1), x.stride(2), x.stride(3)]); - assert_eq!(bx.to_vec(), &[1, 2, 3, 4, 1, 2, 3, 4]); - - // Broadcast a different 1-size dimension to size 2 - let bx = x.broadcast(&[1, 2, 2, 2]); - assert_eq!(bx.shape(), &[1, 2, 2, 2]); - assert_eq!(bx.strides(), &[x.stride(0), x.stride(1), 0, x.stride(3)]); - assert_eq!(bx.to_vec(), &[1, 2, 1, 2, 3, 4, 3, 4]); - - // Broadcast to a larger number of dimensions - let x = steps(&[5]); - let bx = x.broadcast(&[1, 5]); - assert_eq!(bx.shape(), &[1, 5]); - assert_eq!(bx.strides(), &[0, x.stride(0)]); - assert_eq!(bx.to_vec(), &[1, 2, 3, 4, 5]); - - // Broadcast a scalar - let scalar = Tensor::from_scalar(7); - let bx = scalar.broadcast(&[3, 3]); - assert_eq!(bx.shape(), &[3, 3]); - assert_eq!(bx.strides(), &[0, 0]); - assert_eq!(bx.to_vec(), &[7, 7, 7, 7, 7, 7, 7, 7, 7]); - } - - #[test] - #[should_panic(expected = "Cannot broadcast to specified shape")] - fn test_broadcast_invalid() { - let x = steps(&[2, 2]); - x.broadcast(&[4]); - } - - #[test] - fn test_can_broadcast_to() { - let x = steps(&[1, 5, 10]); - assert!(x.can_broadcast_to(&[2, 5, 10])); - assert!(x.can_broadcast_to(&[1, 5, 10])); - assert!(!x.can_broadcast_to(&[1, 1, 10])); - } - - #[test] - fn test_can_broadcast_with() { - let x = steps(&[1, 5, 10]); - assert!(x.can_broadcast_with(&[2, 5, 10])); - assert!(x.can_broadcast_with(&[1, 5, 10])); - assert!(x.can_broadcast_with(&[1, 1, 10])); - } - - #[test] - fn test_inner_iter() { - let x = steps(&[2, 2, 2]); - let mut iter = x.inner_iter::<2>(); - - let mat = iter.next().unwrap(); - assert_eq!(mat.shape(), [2, 2]); - assert_eq!(mat.iter().copied().collect::>(), [1, 2, 3, 4]); - - let mat = iter.next().unwrap(); - assert_eq!(mat.shape(), [2, 2]); - assert_eq!(mat.iter().copied().collect::>(), [5, 6, 7, 8]); - - assert_eq!(iter.next(), None); - } - - #[test] - fn test_inner_iter_mut() { - let mut x = steps(&[2, 2, 2]); - let mut iter = x.inner_iter_mut::<2>(); - - let mat = iter.next().unwrap(); - assert_eq!(mat.shape(), [2, 2]); - assert_eq!(mat.iter().copied().collect::>(), [1, 2, 3, 4]); - - let mat = iter.next().unwrap(); - assert_eq!(mat.shape(), [2, 2]); - assert_eq!(mat.iter().copied().collect::>(), [5, 6, 7, 8]); - - assert_eq!(iter.next(), None); - } - - // Common slice tests for all slicing functions. - macro_rules! slice_tests { - ($x:expr, $method:ident) => { - assert_eq!($x.shape(), &[2, 3, 4]); - - // 1D index - let y = $x.$method([0]); - assert_eq!(y.shape(), [3, 4]); - assert_eq!(y.to_vec(), (1..=(3 * 4)).into_iter().collect::>()); - - // Negative 1D index - let y = $x.$method([-1]); - assert_eq!(y.shape(), [3, 4]); - assert_eq!( - y.to_vec(), - ((3 * 4 + 1)..=(2 * 3 * 4)) - .into_iter() - .collect::>() - ); - - // 2D index - let y = $x.$method([0, 1]); - assert_eq!(y.shape(), [4]); - assert_eq!(y.to_vec(), (5..=8).into_iter().collect::>()); - - // 3D index - let y = $x.$method([0, 1, 2]); - assert_eq!(y.shape(), []); - assert_eq!(y.item(), Some(&7)); - - // Full range - let y = $x.$method([..]); - assert_eq!(y.shape(), [2, 3, 4]); - assert_eq!(y.to_vec(), $x.to_vec()); - - // Partial ranges - let y = $x.$method((.., ..2, 1..)); - assert_eq!(y.shape(), [2, 2, 3]); - - // Stepped range - let y = $x.$method((.., .., SliceItem::range(0, None, 2))); - assert_eq!( - y.to_vec(), - $x.iter() - .copied() - .enumerate() - .filter_map(|(i, x)| (i % 2 == 0).then_some(x)) - .collect::>() - ); - - // Mixed indices and ranges - let y = $x.$method((.., 0, ..)); - assert_eq!(y.shape(), [2, 4]); - - let y = $x.$method((.., .., 0)); - assert_eq!(y.shape(), [2, 3]); - assert_eq!(y.to_vec(), &[1, 5, 9, 13, 17, 21]); - }; - } - - #[test] - fn test_slice() { - let x = steps(&[2, 3, 4]); - slice_tests!(x.view(), slice); - } - - #[test] - fn test_slice_mut() { - let mut x = steps(&[2, 3, 4]); - slice_tests!(x, slice_mut); - } - - #[test] - fn test_slice_iter() { - let sr = |start, end| SliceItem::range(start, Some(end), 1); - let x = steps(&[3, 3]); - - // Slice that extracts a specific index - let slice: Vec<_> = x - .slice_iter(&[SliceItem::Index(0), SliceItem::full_range()]) - .copied() - .collect(); - assert_eq!(slice, &[1, 2, 3]); - - // Slice that removes start of each dimension - let slice: Vec<_> = x.slice_iter(&[sr(1, 3), sr(1, 3)]).copied().collect(); - assert_eq!(slice, &[5, 6, 8, 9]); - - // Slice that removes end of each dimension - let slice: Vec<_> = x.slice_iter(&[sr(0, 2), sr(0, 2)]).copied().collect(); - assert_eq!(slice, &[1, 2, 4, 5]); - - // Slice that removes start and end of first dimension - let slice: Vec<_> = x.slice_iter(&[sr(1, 2), sr(0, 3)]).copied().collect(); - assert_eq!(slice, &[4, 5, 6]); - - // Slice that removes start and end of second dimension - let slice: Vec<_> = x.slice_iter(&[sr(0, 3), sr(1, 2)]).copied().collect(); - assert_eq!(slice, &[2, 5, 8]); - } - - #[test] - fn test_slice_iter_with_step() { - let sr = |start, end, step| SliceItem::range(start, Some(end), step); - let x = steps(&[10]); - - // Positive steps > 1. - let slice: Vec<_> = x.slice_iter(&[sr(0, 10, 2)]).copied().collect(); - assert_eq!(slice, &[1, 3, 5, 7, 9]); - - let slice: Vec<_> = x.slice_iter(&[sr(0, 10, 3)]).copied().collect(); - assert_eq!(slice, &[1, 4, 7, 10]); - - let slice: Vec<_> = x.slice_iter(&[sr(0, 10, 10)]).copied().collect(); - assert_eq!(slice, &[1]); - - // Negative steps. - let slice: Vec<_> = x.slice_iter(&[sr(10, -11, -1)]).copied().collect(); - assert_eq!(slice, &[10, 9, 8, 7, 6, 5, 4, 3, 2, 1]); - - let slice: Vec<_> = x.slice_iter(&[sr(8, 0, -1)]).copied().collect(); - assert_eq!(slice, &[9, 8, 7, 6, 5, 4, 3, 2]); - - let slice: Vec<_> = x.slice_iter(&[sr(10, 0, -2)]).copied().collect(); - assert_eq!(slice, &[10, 8, 6, 4, 2]); - - let slice: Vec<_> = x.slice_iter(&[sr(10, 0, -10)]).copied().collect(); - assert_eq!(slice, &[10]); - } - - #[test] - fn test_slice_iter_negative_indices() { - let sr = |start, end| SliceItem::range(start, Some(end), 1); - let x = steps(&[10]); - - // Negative start - let slice: Vec<_> = x.slice_iter(&[sr(-2, 10)]).copied().collect(); - assert_eq!(slice, &[9, 10]); - - // Negative end - let slice: Vec<_> = x.slice_iter(&[sr(7, -1)]).copied().collect(); - assert_eq!(slice, &[8, 9]); - - // Negative start and end - let slice: Vec<_> = x.slice_iter(&[sr(-3, -1)]).copied().collect(); - assert_eq!(slice, &[8, 9]); - } - - #[test] - fn test_slice_iter_clamps_indices() { - let sr = |start, end, step| SliceItem::range(start, Some(end), step); - let x = steps(&[5]); - - // Test cases for positive steps (ie. traversing forwards). - - // Positive start out of bounds - let slice: Vec<_> = x.slice_iter(&[sr(10, 11, 1)]).collect(); - assert_eq!(slice.len(), 0); - - // Positive end out of bounds - let slice: Vec<_> = x.slice_iter(&[sr(0, 10, 1)]).copied().collect(); - assert_eq!(slice, &[1, 2, 3, 4, 5]); - - // Negative start out of bounds - let slice: Vec<_> = x.slice_iter(&[sr(-10, 5, 1)]).copied().collect(); - assert_eq!(slice, &[1, 2, 3, 4, 5]); - - // Negative end out of bounds - let slice: Vec<_> = x.slice_iter(&[sr(-10, -5, 1)]).collect(); - assert_eq!(slice.len(), 0); - - // Test cases for negative steps (ie. traversing backwards). - - // Positive start out of bounds - let slice: Vec<_> = x.slice_iter(&[sr(10, -6, -1)]).copied().collect(); - assert_eq!(slice, &[5, 4, 3, 2, 1]); - - // Positive end out of bounds - let slice: Vec<_> = x.slice_iter(&[sr(0, 10, -1)]).collect(); - assert_eq!(slice.len(), 0); - - // Negative start out of bounds - let slice: Vec<_> = x.slice_iter(&[sr(-10, 5, -1)]).collect(); - assert_eq!(slice.len(), 0); - - // Negative end out of bounds - let slice: Vec<_> = x.slice_iter(&[sr(-1, -10, -1)]).copied().collect(); - assert_eq!(slice, &[5, 4, 3, 2, 1]); - } - - #[test] - fn test_slice_iter_start_end_step_combinations() { - let sr = |start, end, step| SliceItem::range(start, Some(end), step); - let x = steps(&[3]); - - // Test various combinations of slice starts, ends and steps that are - // positive and negative, in-bounds and out-of-bounds, and ensure they - // don't cause a panic. - for start in -5..5 { - for end in -5..5 { - for step in -5..5 { - if step == 0 { - continue; - } - x.slice_iter(&[sr(start, end, step)]).for_each(drop); - } - } - } - } - - #[test] - fn test_squeezed() { - let mut rng = XorShiftRng::new(1234); - let x = Tensor::rand(&[1, 1, 10, 20], &mut rng); - let y = x.squeezed(); - assert_eq!(y.data(), x.data()); - assert_eq!(y.shape(), &[10, 20]); - assert_eq!(y.stride(0), 20); - assert_eq!(y.stride(1), 1); - } - - #[test] - fn test_write() -> std::io::Result<()> { - use std::io::{Cursor, Read}; - let x = Tensor::from_data(&[2, 3], vec![1., 2., 3., 4., 5., 6.]); - let mut buf: Vec = Vec::new(); - - x.write(&mut buf)?; - - assert_eq!(buf.len(), 4 + x.ndim() * 4 + x.len() * 4); - - let mut cursor = Cursor::new(buf); - let mut tmp = [0u8; 4]; - - cursor.read(&mut tmp)?; - let ndim = u32::from_le_bytes(tmp); - assert_eq!(ndim, x.ndim() as u32); - - for &size in x.shape().iter() { - cursor.read(&mut tmp)?; - let written_size = u32::from_le_bytes(tmp); - assert_eq!(written_size, size as u32); - } - - for el in x.iter().copied() { - cursor.read(&mut tmp)?; - let written_el = f32::from_le_bytes(tmp); - assert_eq!(written_el, el); - } - - Ok(()) - } -} From 19dc6ee6fa2e47af961432170b6dcadf084f142f Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Sun, 21 Jan 2024 20:22:25 +0000 Subject: [PATCH 07/12] Rename unified_tensor module back to `tensor` Now that the legacy implementation has been removed, we can drop the `unified_` part of the name, and also combine iterators back into a single module. --- rten-tensor/src/iterators.rs | 249 +++++++++++++++++- rten-tensor/src/lib.rs | 16 +- .../src/{unified_tensor.rs => tensor.rs} | 8 +- rten-tensor/src/unified_tensor/iterators.rs | 247 ----------------- 4 files changed, 257 insertions(+), 263 deletions(-) rename rten-tensor/src/{unified_tensor.rs => tensor.rs} (99%) delete mode 100644 rten-tensor/src/unified_tensor/iterators.rs diff --git a/rten-tensor/src/iterators.rs b/rten-tensor/src/iterators.rs index 96288fac..ad5753fc 100644 --- a/rten-tensor/src/iterators.rs +++ b/rten-tensor/src/iterators.rs @@ -1,9 +1,14 @@ use std::iter::{repeat, zip, Cycle, FusedIterator, StepBy, Take}; -use std::ops::Range; +use std::ops::{Add, Range}; use std::slice; -use super::range::{SliceItem, SliceRange}; -use crate::Layout; +use crate::index_iterator::DynIndices; +use crate::layout::Layout; +use crate::range::{to_slice_items, SliceItem, SliceRange}; + +use super::{ + AsView, MutLayout, NdTensorView, NdTensorViewMut, TensorBase, TensorView, TensorViewMut, +}; /// Borrowed reference to a tensor's data and layout. This differs from /// [TensorView] in that it borrows the layout rather than having its own. @@ -804,6 +809,244 @@ impl<'a, T> Iterator for LanesMut<'a, T> { } } +/// Iterator over views of the N innermost dimensions of a tensor with element +/// type `T` and layout `L`. +pub struct InnerIter<'a, T, L: MutLayout, const N: usize> { + outer_indices: DynIndices, + view: TensorBase, +} + +impl<'a, T, L: MutLayout, const N: usize> InnerIter<'a, T, L, N> { + pub fn new(view: TensorBase) -> Self { + assert!(view.ndim() >= N); + let outer_dims = view.ndim() - N; + let outer_indices = DynIndices::from_shape(&view.shape().as_ref()[..outer_dims]); + InnerIter { + outer_indices, + view, + } + } +} + +impl<'a, T, L: MutLayout, const N: usize> Iterator for InnerIter<'a, T, L, N> { + type Item = NdTensorView<'a, T, N>; + + fn next(&mut self) -> Option { + self.outer_indices.next().map(|idx| { + let slice_items = to_slice_items(&idx); + self.view.slice(slice_items.as_slice()) + }) + } + + fn size_hint(&self) -> (usize, Option) { + self.outer_indices.size_hint() + } +} + +impl<'a, T, L: MutLayout, const N: usize> ExactSizeIterator for InnerIter<'a, T, L, N> {} + +/// Iterator over mutable views of the N innermost dimensions of a tensor. +pub struct InnerIterMut<'a, T, L: MutLayout, const N: usize> { + outer_indices: DynIndices, + view: TensorBase, +} + +impl<'a, T, L: MutLayout, const N: usize> InnerIterMut<'a, T, L, N> { + pub fn new(view: TensorBase) -> Self { + assert!(view.ndim() >= N); + let outer_dims = view.ndim() - N; + let outer_indices = DynIndices::from_shape(&view.shape().as_ref()[..outer_dims]); + InnerIterMut { + outer_indices, + view, + } + } +} + +impl<'a, T, L: MutLayout, const N: usize> Iterator for InnerIterMut<'a, T, L, N> { + type Item = NdTensorViewMut<'a, T, N>; + + fn next(&mut self) -> Option { + self.outer_indices.next().map(|idx| { + let slice_items = to_slice_items(&idx); + let view: NdTensorViewMut<'_, T, N> = self.view.slice_mut(slice_items.as_slice()); + unsafe { + // Safety: Outer view is non-broadcasting, and we increment the + // outer index each time, so returned views will not overlap. + std::mem::transmute::, NdTensorViewMut<'a, T, N>>(view) + } + }) + } + + fn size_hint(&self) -> (usize, Option) { + self.outer_indices.size_hint() + } +} + +impl<'a, T, L: MutLayout, const N: usize> ExactSizeIterator for InnerIterMut<'a, T, L, N> {} + +/// Iterator over slices of a tensor along an axis. See [TensorView::axis_iter]. +pub struct AxisIter<'a, T, L: MutLayout> { + view: TensorBase, + index: usize, +} + +impl<'a, T, L: MutLayout> AxisIter<'a, T, L> { + pub fn new(view: &TensorBase, dim: usize) -> AxisIter<'a, T, L> { + let mut permuted = view.clone(); + permuted.move_axis(dim, 0); + AxisIter { + view: permuted, + index: 0, + } + } +} + +impl<'a, T, L: MutLayout> Iterator for AxisIter<'a, T, L> { + type Item = TensorView<'a, T>; + + fn next(&mut self) -> Option { + if self.index >= self.view.size(0) { + None + } else { + let view = self.view.slice_dyn([self.index]); + self.index += 1; + Some(view) + } + } +} + +/// Iterator over mutable slices of a tensor along an axis. See [TensorViewMut::axis_iter_mut]. +pub struct AxisIterMut<'a, T, L: MutLayout> { + view: TensorBase, + index: usize, +} + +impl<'a, T, L: MutLayout> AxisIterMut<'a, T, L> { + pub fn new(mut view: TensorBase, dim: usize) -> AxisIterMut<'a, T, L> { + // See notes in `Layout` about internal overlap. + assert!( + !view.layout().is_broadcast(), + "Cannot mutably iterate over broadcasting view" + ); + view.move_axis(dim, 0); + AxisIterMut { view, index: 0 } + } +} + +impl<'a, T, L: MutLayout> Iterator for AxisIterMut<'a, T, L> { + type Item = TensorViewMut<'a, T>; + + fn next(&mut self) -> Option { + if self.index >= self.view.size(0) { + None + } else { + let index = self.index; + self.index += 1; + + // Safety: This is non-broadcasting view, and we increment the index + // each time, so returned views will not overlap. + let view = unsafe { + let view = self.view.slice_mut_dyn([index]); + std::mem::transmute::, TensorViewMut<'a, T>>(view) + }; + Some(view) + } + } +} + +/// Iterator over slices of a tensor along an axis. See [TensorView::axis_chunks]. +pub struct AxisChunks<'a, T, L: MutLayout> { + view: TensorBase, + index: usize, + chunk_size: usize, +} + +impl<'a, T, L: MutLayout> AxisChunks<'a, T, L> { + pub fn new( + view: &TensorBase, + dim: usize, + chunk_size: usize, + ) -> AxisChunks<'a, T, L> { + let mut permuted = view.clone(); + permuted.move_axis(dim, 0); + AxisChunks { + view: permuted, + index: 0, + chunk_size, + } + } +} + +impl<'a, T, L: MutLayout> Iterator for AxisChunks<'a, T, L> { + type Item = TensorView<'a, T>; + + fn next(&mut self) -> Option { + let size = self.view.size(0); + if self.index >= self.view.size(0) { + None + } else { + let view = self + .view + .slice_dyn(self.index..self.index.add(self.chunk_size).min(size)); + self.index += self.chunk_size; + Some(view) + } + } +} + +/// Iterator over mutable slices of a tensor along an axis. See [TensorViewMut::axis_chunks_mut]. +pub struct AxisChunksMut<'a, T, L: MutLayout> { + view: TensorBase, + index: usize, + chunk_size: usize, +} + +impl<'a, T, L: MutLayout> AxisChunksMut<'a, T, L> { + pub fn new( + mut view: TensorBase, + dim: usize, + chunk_size: usize, + ) -> AxisChunksMut<'a, T, L> { + // See notes in `Layout` about internal overlap. + assert!( + !view.layout().is_broadcast(), + "Cannot mutably iterate over broadcasting view" + ); + view.move_axis(dim, 0); + AxisChunksMut { + view, + chunk_size, + index: 0, + } + } +} + +impl<'a, T, L: MutLayout> Iterator for AxisChunksMut<'a, T, L> { + type Item = TensorViewMut<'a, T>; + + fn next(&mut self) -> Option { + let size = self.view.size(0); + + if self.index >= size { + None + } else { + let index = self.index; + self.index += self.chunk_size; + + // Safety: This is non-broadcasting view, and we increment the index + // each time, so returned views will not overlap. + let view = unsafe { + let view = self + .view + .slice_mut_dyn(index..index.add(self.chunk_size).min(size)); + std::mem::transmute::, TensorViewMut<'a, T>>(view) + }; + Some(view) + } + } +} + // Tests for iterator internals. Most tests of iterators are currently done via // tests on tensor methods. #[cfg(test)] diff --git a/rten-tensor/src/lib.rs b/rten-tensor/src/lib.rs index 5817e7d8..c2756981 100644 --- a/rten-tensor/src/lib.rs +++ b/rten-tensor/src/lib.rs @@ -41,8 +41,7 @@ mod layout; mod macros; mod overlap; mod range; - -mod unified_tensor; +mod tensor; /// Trait for sources of random data for tensors, for use with [Tensor::rand]. pub trait RandomSource { @@ -51,11 +50,14 @@ pub trait RandomSource { } pub use index_iterator::{DynIndices, Indices, NdIndices}; -pub use iterators::{BroadcastIter, Iter, IterMut, Lanes, LanesMut, Offsets}; +pub use iterators::{ + AxisChunks, AxisChunksMut, AxisIter, AxisIterMut, BroadcastIter, InnerIter, InnerIterMut, Iter, + IterMut, Lanes, LanesMut, Offsets, +}; pub use layout::{is_valid_permutation, DynLayout, Layout, MatrixLayout, NdLayout}; pub use range::{to_slice_items, DynSliceItems, IntoSliceItems, SliceItem, SliceRange}; -pub use unified_tensor::{ +pub use tensor::{ AsView, Matrix, MatrixMut, MutLayout, NdTensor, NdTensorView, NdTensorViewMut, Tensor, TensorBase, TensorView, TensorViewMut, }; @@ -63,12 +65,8 @@ pub use unified_tensor::{ // For backwards compatibility. pub type NdTensorBase = TensorBase>; -pub use unified_tensor::iterators::{ - AxisChunks, AxisChunksMut, AxisIter, AxisIterMut, InnerIter, InnerIterMut, -}; - // For backwards compatibility. -pub use unified_tensor::{AsView as View, AsView as NdView}; +pub use tensor::{AsView as View, AsView as NdView}; /// This module provides a convenient way to import the most common traits /// from this library via a glob import. diff --git a/rten-tensor/src/unified_tensor.rs b/rten-tensor/src/tensor.rs similarity index 99% rename from rten-tensor/src/unified_tensor.rs rename to rten-tensor/src/tensor.rs index 6ef41f18..272a5e40 100644 --- a/rten-tensor/src/unified_tensor.rs +++ b/rten-tensor/src/tensor.rs @@ -3,13 +3,13 @@ use std::marker::PhantomData; use std::ops::{Index, IndexMut, Range}; use crate::errors::{DimensionError, FromDataError, SliceError}; -use crate::iterators::{BroadcastIter, Iter, IterMut, Lanes, LanesMut, MutViewRef, ViewRef}; +use crate::iterators::{ + AxisChunks, AxisChunksMut, AxisIter, AxisIterMut, BroadcastIter, InnerIter, InnerIterMut, Iter, + IterMut, Lanes, LanesMut, MutViewRef, ViewRef, +}; use crate::layout::{DynLayout, Layout, MatrixLayout, NdLayout, OverlapPolicy}; use crate::{IntoSliceItems, RandomSource, SliceItem}; -pub mod iterators; -use iterators::{AxisChunks, AxisChunksMut, AxisIter, AxisIterMut, InnerIter, InnerIterMut}; - /// The base type for multi-dimensional arrays. This consists of storage for /// elements, plus a _layout_ which maps from a multi-dimensional array index /// to a storage offset. This base type is not normally used directly but diff --git a/rten-tensor/src/unified_tensor/iterators.rs b/rten-tensor/src/unified_tensor/iterators.rs deleted file mode 100644 index f0e23e57..00000000 --- a/rten-tensor/src/unified_tensor/iterators.rs +++ /dev/null @@ -1,247 +0,0 @@ -use std::ops::Add; - -use crate::index_iterator::DynIndices; -use crate::layout::Layout; -use crate::range::to_slice_items; - -use super::{ - AsView, MutLayout, NdTensorView, NdTensorViewMut, TensorBase, TensorView, TensorViewMut, -}; - -/// Iterator over views of the N innermost dimensions of a tensor with element -/// type `T` and layout `L`. -pub struct InnerIter<'a, T, L: MutLayout, const N: usize> { - outer_indices: DynIndices, - view: TensorBase, -} - -impl<'a, T, L: MutLayout, const N: usize> InnerIter<'a, T, L, N> { - pub fn new(view: TensorBase) -> Self { - assert!(view.ndim() >= N); - let outer_dims = view.ndim() - N; - let outer_indices = DynIndices::from_shape(&view.shape().as_ref()[..outer_dims]); - InnerIter { - outer_indices, - view, - } - } -} - -impl<'a, T, L: MutLayout, const N: usize> Iterator for InnerIter<'a, T, L, N> { - type Item = NdTensorView<'a, T, N>; - - fn next(&mut self) -> Option { - self.outer_indices.next().map(|idx| { - let slice_items = to_slice_items(&idx); - self.view.slice(slice_items.as_slice()) - }) - } - - fn size_hint(&self) -> (usize, Option) { - self.outer_indices.size_hint() - } -} - -impl<'a, T, L: MutLayout, const N: usize> ExactSizeIterator for InnerIter<'a, T, L, N> {} - -/// Iterator over mutable views of the N innermost dimensions of a tensor. -pub struct InnerIterMut<'a, T, L: MutLayout, const N: usize> { - outer_indices: DynIndices, - view: TensorBase, -} - -impl<'a, T, L: MutLayout, const N: usize> InnerIterMut<'a, T, L, N> { - pub fn new(view: TensorBase) -> Self { - assert!(view.ndim() >= N); - let outer_dims = view.ndim() - N; - let outer_indices = DynIndices::from_shape(&view.shape().as_ref()[..outer_dims]); - InnerIterMut { - outer_indices, - view, - } - } -} - -impl<'a, T, L: MutLayout, const N: usize> Iterator for InnerIterMut<'a, T, L, N> { - type Item = NdTensorViewMut<'a, T, N>; - - fn next(&mut self) -> Option { - self.outer_indices.next().map(|idx| { - let slice_items = to_slice_items(&idx); - let view: NdTensorViewMut<'_, T, N> = self.view.slice_mut(slice_items.as_slice()); - unsafe { - // Safety: Outer view is non-broadcasting, and we increment the - // outer index each time, so returned views will not overlap. - std::mem::transmute::, NdTensorViewMut<'a, T, N>>(view) - } - }) - } - - fn size_hint(&self) -> (usize, Option) { - self.outer_indices.size_hint() - } -} - -impl<'a, T, L: MutLayout, const N: usize> ExactSizeIterator for InnerIterMut<'a, T, L, N> {} - -/// Iterator over slices of a tensor along an axis. See [TensorView::axis_iter]. -pub struct AxisIter<'a, T, L: MutLayout> { - view: TensorBase, - index: usize, -} - -impl<'a, T, L: MutLayout> AxisIter<'a, T, L> { - pub fn new(view: &TensorBase, dim: usize) -> AxisIter<'a, T, L> { - let mut permuted = view.clone(); - permuted.move_axis(dim, 0); - AxisIter { - view: permuted, - index: 0, - } - } -} - -impl<'a, T, L: MutLayout> Iterator for AxisIter<'a, T, L> { - type Item = TensorView<'a, T>; - - fn next(&mut self) -> Option { - if self.index >= self.view.size(0) { - None - } else { - let view = self.view.slice_dyn([self.index]); - self.index += 1; - Some(view) - } - } -} - -/// Iterator over mutable slices of a tensor along an axis. See [TensorViewMut::axis_iter_mut]. -pub struct AxisIterMut<'a, T, L: MutLayout> { - view: TensorBase, - index: usize, -} - -impl<'a, T, L: MutLayout> AxisIterMut<'a, T, L> { - pub fn new(mut view: TensorBase, dim: usize) -> AxisIterMut<'a, T, L> { - // See notes in `Layout` about internal overlap. - assert!( - !view.layout().is_broadcast(), - "Cannot mutably iterate over broadcasting view" - ); - view.move_axis(dim, 0); - AxisIterMut { view, index: 0 } - } -} - -impl<'a, T, L: MutLayout> Iterator for AxisIterMut<'a, T, L> { - type Item = TensorViewMut<'a, T>; - - fn next(&mut self) -> Option { - if self.index >= self.view.size(0) { - None - } else { - let index = self.index; - self.index += 1; - - // Safety: This is non-broadcasting view, and we increment the index - // each time, so returned views will not overlap. - let view = unsafe { - let view = self.view.slice_mut_dyn([index]); - std::mem::transmute::, TensorViewMut<'a, T>>(view) - }; - Some(view) - } - } -} - -/// Iterator over slices of a tensor along an axis. See [TensorView::axis_chunks]. -pub struct AxisChunks<'a, T, L: MutLayout> { - view: TensorBase, - index: usize, - chunk_size: usize, -} - -impl<'a, T, L: MutLayout> AxisChunks<'a, T, L> { - pub fn new( - view: &TensorBase, - dim: usize, - chunk_size: usize, - ) -> AxisChunks<'a, T, L> { - let mut permuted = view.clone(); - permuted.move_axis(dim, 0); - AxisChunks { - view: permuted, - index: 0, - chunk_size, - } - } -} - -impl<'a, T, L: MutLayout> Iterator for AxisChunks<'a, T, L> { - type Item = TensorView<'a, T>; - - fn next(&mut self) -> Option { - let size = self.view.size(0); - if self.index >= self.view.size(0) { - None - } else { - let view = self - .view - .slice_dyn(self.index..self.index.add(self.chunk_size).min(size)); - self.index += self.chunk_size; - Some(view) - } - } -} - -/// Iterator over mutable slices of a tensor along an axis. See [TensorViewMut::axis_chunks_mut]. -pub struct AxisChunksMut<'a, T, L: MutLayout> { - view: TensorBase, - index: usize, - chunk_size: usize, -} - -impl<'a, T, L: MutLayout> AxisChunksMut<'a, T, L> { - pub fn new( - mut view: TensorBase, - dim: usize, - chunk_size: usize, - ) -> AxisChunksMut<'a, T, L> { - // See notes in `Layout` about internal overlap. - assert!( - !view.layout().is_broadcast(), - "Cannot mutably iterate over broadcasting view" - ); - view.move_axis(dim, 0); - AxisChunksMut { - view, - chunk_size, - index: 0, - } - } -} - -impl<'a, T, L: MutLayout> Iterator for AxisChunksMut<'a, T, L> { - type Item = TensorViewMut<'a, T>; - - fn next(&mut self) -> Option { - let size = self.view.size(0); - - if self.index >= size { - None - } else { - let index = self.index; - self.index += self.chunk_size; - - // Safety: This is non-broadcasting view, and we increment the index - // each time, so returned views will not overlap. - let view = unsafe { - let view = self - .view - .slice_mut_dyn(index..index.add(self.chunk_size).min(size)); - std::mem::transmute::, TensorViewMut<'a, T>>(view) - }; - Some(view) - } - } -} From baa1c54fc7f29e64aad0ea87242b8832d01d3988 Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Mon, 22 Jan 2024 06:43:56 +0000 Subject: [PATCH 08/12] Remove unnecessary `size` method call --- rten-tensor/src/iterators.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rten-tensor/src/iterators.rs b/rten-tensor/src/iterators.rs index ad5753fc..0a23cbe7 100644 --- a/rten-tensor/src/iterators.rs +++ b/rten-tensor/src/iterators.rs @@ -983,7 +983,7 @@ impl<'a, T, L: MutLayout> Iterator for AxisChunks<'a, T, L> { fn next(&mut self) -> Option { let size = self.view.size(0); - if self.index >= self.view.size(0) { + if self.index >= size { None } else { let view = self From 3d611335694424406d6f49bd5362721bd4cad15b Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Mon, 22 Jan 2024 06:45:03 +0000 Subject: [PATCH 09/12] Use static-rank views more in RNN ops Static-rank views are preferable when the dimensionality is known, as they are more efficient and the compiler can better catch errors. --- src/ops/rnn.rs | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/ops/rnn.rs b/src/ops/rnn.rs index fa783077..b1ff4385 100644 --- a/src/ops/rnn.rs +++ b/src/ops/rnn.rs @@ -139,9 +139,9 @@ fn compute_rnn_gate( gemm: &GemmExecutor, mut output: TensorViewMut, act: Activation, - input: &TensorView, + input: &NdTensorView, input_weight: GemmInputB, - hidden: &TensorView, + hidden: &NdTensorView, hidden_weight: GemmInputB, bias: Option<(NdTensorView, NdTensorView)>, ) { @@ -307,8 +307,8 @@ pub fn gru( extract_gru_weights_and_bias(dir, HIDDEN_GATE); for seq in sequence_for_dir(direction, dir, seq_len) { - let in_item = input.slice_dyn([seq]); - let hidden_item = hidden.slice_dyn([dir]); + let in_item = input.slice::<2, _>([seq]); + let hidden_item = hidden.slice::<2, _>([dir]); // From the ONNX spec, the intermediate values are computed as: // @@ -402,7 +402,7 @@ pub fn gru( } // Compute next hidden state - let mut hidden_item = hidden.slice_mut_dyn([dir]); + let mut hidden_item = hidden.slice_mut::<2, _>([dir]); for (hidden, update, hidden_gate) in zip3( hidden_item.iter_mut(), update_gate.iter(), @@ -412,8 +412,8 @@ pub fn gru( } hidden_seq - .slice_mut_dyn([seq, dir]) - .copy_from(&hidden_item.as_dyn()); + .slice_mut::<2, _>([seq, dir]) + .copy_from(&hidden_item); } } @@ -573,8 +573,8 @@ pub fn lstm( // supported. // - `f`, `g` and `h` are activations. `f`=sigmoid, `g` and `h` // are tanh. - let in_item = input.slice_dyn([seq]); - let hidden_item = hidden.slice_dyn([dir]); + let in_item = input.slice::<2, _>([seq]); + let hidden_item = hidden.slice::<2, _>([dir]); // Compute outputs for input, forget, cell and output gates. compute_rnn_gate( @@ -622,7 +622,7 @@ pub fn lstm( ); // Compute new values of cell and hidden state - let mut cell_item = cell.slice_mut_dyn([dir]); + let mut cell_item = cell.slice_mut::<2, _>([dir]); for (cell, forget_gate, input_gate, cell_gate) in zip4( cell_item.iter_mut(), @@ -633,7 +633,7 @@ pub fn lstm( *cell = forget_gate * *cell + input_gate * cell_gate; } - let mut hidden_item = hidden.slice_mut_dyn([dir]); + let mut hidden_item = hidden.slice_mut::<2, _>([dir]); for (hidden, out_gate, cell) in zip3(hidden_item.iter_mut(), out_gate.iter(), cell_item.iter()) { @@ -641,8 +641,8 @@ pub fn lstm( } hidden_seq - .slice_mut_dyn([seq, dir]) - .copy_from(&hidden_item.view()); + .slice_mut::<2, _>([seq, dir]) + .copy_from(&hidden_item); } } From 0b02610b0b3fdc16d616ee7339967e08a647dea5 Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Mon, 22 Jan 2024 06:46:31 +0000 Subject: [PATCH 10/12] Simplify unnecessary `from_data` call This was added at an intermediate point during the unified tensor implementation, before `From` had been added. --- src/ops/resize.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ops/resize.rs b/src/ops/resize.rs index c29cc565..c0284366 100644 --- a/src/ops/resize.rs +++ b/src/ops/resize.rs @@ -181,7 +181,7 @@ pub fn resize_image(input: TensorView, size: [usize; 2]) -> Result Date: Mon, 22 Jan 2024 07:45:09 +0000 Subject: [PATCH 11/12] Rename `try_slice` => `try_slice_dyn` This aligns with `slice_dyn`. --- rten-tensor/src/tensor.rs | 20 +++++++++++++------- src/ops/gather.rs | 2 +- src/ops/slice.rs | 2 +- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/rten-tensor/src/tensor.rs b/rten-tensor/src/tensor.rs index 272a5e40..a0cd19cc 100644 --- a/rten-tensor/src/tensor.rs +++ b/rten-tensor/src/tensor.rs @@ -178,8 +178,11 @@ pub trait AsView: Layout { /// /// Fails if the range has more dimensions than the view or is out of bounds /// for any dimension. - fn try_slice(&self, range: R) -> Result, SliceError> { - self.view().try_slice(range) + fn try_slice_dyn( + &self, + range: R, + ) -> Result, SliceError> { + self.view().try_slice_dyn(range) } /// Slice this tensor and return a static-rank view with `M` dimensions. @@ -1152,7 +1155,10 @@ impl<'a, T, L: Clone + MutLayout> TensorBase { } } - pub fn try_slice(&self, range: R) -> Result, SliceError> { + pub fn try_slice_dyn( + &self, + range: R, + ) -> Result, SliceError> { let (offset_range, layout) = self.layout.try_slice(range)?; Ok(TensorBase { data: &self.data[offset_range], @@ -2632,14 +2638,14 @@ mod tests { let data = vec![1., 2., 3., 4.]; let tensor = Tensor::from_data(&[2, 2], data); - let row = tensor.try_slice(0); + let row = tensor.try_slice_dyn(0); assert!(row.is_ok()); assert_eq!(row.unwrap().data(), Some([1., 2.].as_slice())); - let row = tensor.try_slice(1); + let row = tensor.try_slice_dyn(1); assert!(row.is_ok()); - let row = tensor.try_slice(2); + let row = tensor.try_slice_dyn(2); assert!(row.is_err()); } @@ -2656,7 +2662,7 @@ mod tests { let row = tensor.try_slice_mut(1); assert!(row.is_ok()); - let row = tensor.try_slice(2); + let row = tensor.try_slice_dyn(2); assert!(row.is_err()); } diff --git a/src/ops/gather.rs b/src/ops/gather.rs index 4b730c23..e4c449ab 100644 --- a/src/ops/gather.rs +++ b/src/ops/gather.rs @@ -400,7 +400,7 @@ pub fn scatter_nd< let update_slice = updates.slice_dyn(update_idx.as_slice()); let output_idx: DynSliceItems = indices - .try_slice(update_idx.as_slice()) + .try_slice_dyn(update_idx.as_slice()) .map_err(|_| OpError::InvalidValue("invalid scatter index"))? .iter() .map(|x| SliceItem::Index(*x as isize)) diff --git a/src/ops/slice.rs b/src/ops/slice.rs index fcaceb36..da0c09f7 100644 --- a/src/ops/slice.rs +++ b/src/ops/slice.rs @@ -60,7 +60,7 @@ pub fn slice( // all ranges except those with a negative step. This benefits from // optimizations that `Tensor::to_tensor` has for slices that are already // contiguous or have a small number of dims. - if let Ok(slice_view) = input.try_slice(items.as_slice()) { + if let Ok(slice_view) = input.try_slice_dyn(items.as_slice()) { return Ok(slice_view.to_tensor()); } From e3964f63e6cc5bcfbadb38ac71ee4b009bf65918 Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Mon, 22 Jan 2024 07:47:00 +0000 Subject: [PATCH 12/12] Add a few additional tests and doc comments in tensor.rs --- rten-tensor/src/tensor.rs | 90 ++++++++++++++++++++++++++++++--------- 1 file changed, 70 insertions(+), 20 deletions(-) diff --git a/rten-tensor/src/tensor.rs b/rten-tensor/src/tensor.rs index a0cd19cc..30520c09 100644 --- a/rten-tensor/src/tensor.rs +++ b/rten-tensor/src/tensor.rs @@ -154,11 +154,19 @@ pub trait AsView: Layout { /// Return a view with dimensions permuted in the order given by `dims`. fn permuted( &self, - dims: Self::Index<'_>, + order: Self::Index<'_>, ) -> TensorBase { - self.view().permuted(dims) + self.view().permuted(order) } + /// Return a view with a given shape, without copying any data. This + /// requires that the tensor is contiguous. + /// + /// The new shape must have the same number of elments as the current + /// shape. The result will have a static rank if `shape` is an array or + /// a dynamic rank if it is a slice. + /// + /// Panics if the tensor is not contiguous. fn reshaped( &self, shape: S, @@ -187,10 +195,10 @@ pub trait AsView: Layout { /// Slice this tensor and return a static-rank view with `M` dimensions. /// - /// Use [AsView::slice_dyn] instead the number of dimensions in the returned - /// view is unknown at compile time. + /// Use [AsView::slice_dyn] instead if the number of dimensions in the + /// returned view is unknown at compile time. /// - /// Panics if the dimension count is not `M`. + /// Panics if the dimension count of the result is not `M`. fn slice(&self, range: R) -> NdTensorView { self.view().slice(range) } @@ -221,7 +229,12 @@ pub trait AsView: Layout { where Self::Elem: Clone; - /// Return a vector with the same shape but with strides in contiguous order. + /// Return a tensor with the same shape as this tensor/view but with the + /// data contiguous in memory and arranged in the same order as the + /// logical/iteration order (used by `iter`). + /// + /// This will return a view if the data is already contiguous or copy + /// data into a new buffer otherwise. /// /// Certain operations require or are faster with contiguous tensors. fn to_contiguous(&self) -> TensorBase, Self::Layout> @@ -239,7 +252,7 @@ pub trait AsView: Layout { where Self::Elem: Clone; - /// Return clone of this tensor which uniquely owns its elements. + /// Return a copy of this tensor/view which uniquely owns its elements. fn to_tensor(&self) -> TensorBase, Self::Layout> where Self::Elem: Clone, @@ -488,6 +501,9 @@ impl<'a> IntoLayout for &'a [usize] { /// Trait which extends [MutLayout] with support for changing the number of /// dimensions in-place. +/// +/// This is only implemented for [DynLayout], since layouts that have a static +/// rank cannot change their dimension count at runtime. pub trait ResizeLayout: MutLayout { /// Insert a size-one axis at the given index in the shape. This will have /// the same stride as the dimension that follows it. @@ -503,7 +519,7 @@ impl ResizeLayout for DynLayout { /// Trait for converting types into indices for use with a given layout. /// /// Static-rank tensors can be indexed with `[usize; N]` arrays. Dynamic-rank -/// tensors can be indexed with any type that can be converted to a `&[usize]` +/// tensors can be indexed with any type that can be converted to an `&[usize]` /// slice. pub trait AsIndex { /// Convert `self` into an index for use the layout `L`. @@ -528,7 +544,9 @@ impl, L: MutLayout> TensorBase { let layout = L::from_shape(shape); assert!( data.as_ref().len() == layout.len(), - "data length does not match shape" + "data length {} does not match shape {:?}", + data.as_ref().len(), + layout.shape().as_ref(), ); TensorBase { data, @@ -543,7 +561,7 @@ impl, L: MutLayout> TensorBase { /// This will fail if the data length is incorrect for the shape and stride /// combination, or if the strides lead to overlap (see [OverlapPolicy]). /// See also [TensorBase::from_slice_with_strides] which is a similar method - /// for immutable views, that does allow overlapping strides. + /// for immutable views that does allow overlapping strides. pub fn from_data_with_strides( shape: L::Index<'_>, data: S, @@ -730,8 +748,8 @@ impl + AsMut<[T]>, L: MutLayout> TensorBase { /// Slice this tensor and return a static-rank view with `M` dimensions. /// - /// Use [AsView::slice_dyn] instead the number of dimensions in the returned - /// view is unknown at compile time. + /// Use [AsView::slice_dyn] instead if the number of dimensions in the + /// returned view is unknown at compile time. /// /// Panics if the dimension count is not `M`. pub fn slice_mut( @@ -891,6 +909,11 @@ impl TensorBase, L> { } /// Make the underlying data in this tensor contiguous. + /// + /// This means that after calling `make_contiguous`, the elements are + /// guaranteed to be stored in the same order as the logical order in + /// which `iter` yields elements. This method is cheap if the storage is + /// already contiguous. pub fn make_contiguous(&mut self) where T: Clone, @@ -1027,7 +1050,7 @@ impl<'a, T, L: Clone + MutLayout> TensorBase { InnerIter::new(self.view()) } - /// Return the scalar value in this tensor if it has 0 dimensions. + /// Return the scalar value in this tensor if it has one element. pub fn item(&self) -> Option<&'a T> { match self.ndim() { 0 => Some(&self.data[0]), @@ -1408,6 +1431,7 @@ impl From> for TensorBase, L> where [usize; 1]: AsIndex, { + /// Create a 1D tensor from a vector. fn from(vec: Vec) -> Self { Self::from_data([vec.len()].as_index(), vec) } @@ -1417,6 +1441,7 @@ impl<'a, T, L: Clone + MutLayout> From<&'a [T]> for TensorBase where [usize; 1]: AsIndex, { + /// Create a 1D view from a slice. fn from(slice: &'a [T]) -> Self { Self::from_data([slice.len()].as_index(), slice) } @@ -1426,6 +1451,7 @@ impl<'a, T, L: Clone + MutLayout, const N: usize> From<&'a [T; N]> for TensorBas where [usize; 1]: AsIndex, { + /// Create a 1D view from a slice of known length. fn from(slice: &'a [T; N]) -> Self { Self::from_data([slice.len()].as_index(), slice.as_slice()) } @@ -1458,7 +1484,7 @@ impl + AsMut<[T]>, const N: usize> TensorBase /// Store an array of `M` elements into successive entries of a tensor along /// the `dim` axis. /// - /// See [NdTensorBase::get_array] for more details. + /// See [TensorBase::get_array] for more details. #[inline] pub fn set_array(&mut self, base: [usize; N], dim: usize, values: [T; M]) where @@ -1588,7 +1614,7 @@ where { type Error = DimensionError; - /// Convert a dynamic-dimensional tensor or view into a static-dimensional one. + /// Convert a tensor or view with dynamic rank into a static rank one. /// /// Fails if `value` does not have `N` dimensions. fn try_from(value: TensorBase) -> Result { @@ -1658,7 +1684,7 @@ where /// /// This offers a middle-ground between regular indexing, which bounds-checks /// each index element, and unchecked indexing, which does no bounds-checking -/// at all. +/// at all and is thus unsafe. pub struct WeaklyCheckedView, L: MutLayout> { base: TensorBase, } @@ -1987,7 +2013,7 @@ mod tests { } #[test] - #[should_panic(expected = "data length does not match shape")] + #[should_panic(expected = "data length 4 does not match shape [2, 2, 2]")] fn test_from_data_shape_mismatch() { NdTensor::from_data([2, 2, 2], vec![1, 2, 3, 4]); } @@ -2219,11 +2245,15 @@ mod tests { fn test_item() { let tensor = NdTensor::from_data([], vec![5.]); assert_eq!(tensor.item(), Some(&5.)); + let tensor = NdTensor::from_data([1], vec![6.]); + assert_eq!(tensor.item(), Some(&6.)); let tensor = NdTensor::from_data([2], vec![2., 3.]); assert_eq!(tensor.item(), None); let tensor = Tensor::from_data(&[], vec![5.]); assert_eq!(tensor.item(), Some(&5.)); + let tensor = Tensor::from_data(&[1], vec![6.]); + assert_eq!(tensor.item(), Some(&6.)); let tensor = Tensor::from_data(&[2], vec![2., 3.]); assert_eq!(tensor.item(), None); } @@ -2236,6 +2266,11 @@ mod tests { tensor.iter().copied().collect::>(), &[1., 2., 3., 4.] ); + let transposed = tensor.transposed(); + assert_eq!( + transposed.iter().copied().collect::>(), + &[1., 3., 2., 4.] + ); let data = vec![1., 2., 3., 4.]; let tensor = Tensor::from_data(&[2, 2], data); @@ -2243,6 +2278,11 @@ mod tests { tensor.iter().copied().collect::>(), &[1., 2., 3., 4.] ); + let transposed = tensor.transposed(); + assert_eq!( + transposed.iter().copied().collect::>(), + &[1., 3., 2., 4.] + ); } #[test] @@ -2576,10 +2616,20 @@ mod tests { fn test_to_contiguous() { let data = vec![1., 2., 3., 4.]; let tensor = NdTensor::from_data([2, 2], data); - // TODO - Try both contiguous and non-contiguous tensors. - let tensor = tensor.to_contiguous(); - // TODO - Check the actual storage from to_contiguous + + // Tensor is already contiguous, so this is a no-op. + let mut tensor = tensor.to_contiguous(); assert_eq!(tensor.to_vec(), &[1., 2., 3., 4.]); + + // Swap strides to make tensor non-contiguous. + tensor.transpose(); + assert!(!tensor.is_contiguous()); + assert_eq!(tensor.to_vec(), &[1., 3., 2., 4.]); + + // Create a new contiguous copy. + let tensor = tensor.to_contiguous(); + assert!(tensor.is_contiguous()); + assert_eq!(tensor.to_vec(), &[1., 3., 2., 4.]); } #[test]