Skip to content

Commit

Permalink
Implement Index/Mut for Tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
barakugav committed Aug 27, 2024
1 parent 1ead874 commit e3b7d08
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 0 deletions.
5 changes: 5 additions & 0 deletions executorch-sys/cpp/executorch_rs_ext/api_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <vector>

#include "executorch_rs_ext/api_utils.hpp"
#include "executorch/runtime/core/exec_aten/util/tensor_util.h"

namespace executorch_rs
{
Expand Down Expand Up @@ -145,6 +146,10 @@ namespace executorch_rs
{
return self.mutable_data_ptr();
}
size_t Tensor_coordinate_to_index(const exec_aten::Tensor &self, const size_t *coordinate)
{
return torch::executor::coordinateToIndex(self, coordinate);
}
void Tensor_destructor(exec_aten::Tensor &self)
{
self.~Tensor();
Expand Down
1 change: 1 addition & 0 deletions executorch-sys/cpp/executorch_rs_ext/api_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ namespace executorch_rs
exec_aten::ArrayRef<exec_aten::StridesType> Tensor_strides(const exec_aten::Tensor &self);
const void *Tensor_const_data_ptr(const exec_aten::Tensor &self);
void *Tensor_mutable_data_ptr(const exec_aten::Tensor &self);
size_t Tensor_coordinate_to_index(const exec_aten::Tensor &self, const size_t *coordinate);
void Tensor_destructor(exec_aten::Tensor &self);

// torch::executor::EValue EValue_shallow_clone(torch::executor::EValue *evalue);
Expand Down
93 changes: 93 additions & 0 deletions src/tensor.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! Tensor struct is an input or output tensor to a executorch program.

use core::ops::{Index, IndexMut};
use core::ptr::NonNull;
use std::fmt::Debug;
use std::marker::PhantomData;
Expand Down Expand Up @@ -545,6 +546,36 @@ impl<'a, D: DataTyped + DataMut> TensorBase<'a, D> {
}
}

impl<'a, D: DataTyped> Index<&[usize]> for TensorBase<'a, D> {
type Output = D::Scalar;

fn index(&self, index: &[usize]) -> &Self::Output {
assert_eq!(
index.len(),
self.dim() as usize,
"Invalid number of dimensions"
);
let index =
unsafe { et_rs_c::Tensor_coordinate_to_index(self.as_cpp_tensor(), index.as_ptr()) };
let base_ptr = self.as_ptr();
debug_assert!(!base_ptr.is_null());
unsafe { &*base_ptr.offset(index as isize) }
}
}
impl<'a, D: DataTyped + DataMut> IndexMut<&[usize]> for TensorBase<'a, D> {
fn index_mut(&mut self, index: &[usize]) -> &mut Self::Output {
assert_eq!(
index.len(),
self.dim() as usize,
"Invalid number of dimensions"
);
let index =
unsafe { et_rs_c::Tensor_coordinate_to_index(self.as_cpp_tensor(), index.as_ptr()) };
let base_ptr = self.as_mut_ptr().as_ptr();
unsafe { &mut *base_ptr.offset(index as isize) }
}
}

/// A typed immutable tensor that does not own the underlying data.
pub type Tensor<'a, S> = TensorBase<'a, View<S>>;
impl<'a, S: Scalar> Tensor<'a, S> {
Expand Down Expand Up @@ -835,6 +866,11 @@ impl<A: Scalar, S: ndarray::RawData<Elem = A>, D: Dimension> Array<A, S, D> {
};
TensorImplBase(impl_, PhantomData)
}

/// Get a reference to the underlying ndarray.
pub fn as_ndarray(&self) -> &ArrayBase<S, D> {
&self.array
}
}
impl<A: Scalar, S: ndarray::RawDataMut<Elem = A>, D: Dimension> Array<A, S, D> {
/// Create a [`TensorImplMut`] pointing to this struct's data.
Expand Down Expand Up @@ -1057,4 +1093,61 @@ mod tests {
test_scalar_type::<f64>(|size| vec![0.0; size]);
test_scalar_type::<bool>(|size| vec![false; size]);
}

#[test]
fn test_tensor_index() {
let arr = Array::new(Array3::<i32>::from_shape_fn((4, 5, 3), |(x, y, z)| {
x as i32 * 1337 - y as i32 * 87 + z as i32 * 13
}));
let tensor_impl = arr.as_tensor_impl();
let tensor = Tensor::new(&tensor_impl);

let arr = arr.as_ndarray();
for (ix, &expected) in arr.indexed_iter() {
let ix: [usize; 3] = ix.into();
let actual = tensor[&ix];
assert_eq!(actual, expected);
}
}

#[test]
fn test_tensor_index_mut() {
let mut arr = Array::new(Array3::<i32>::zeros((4, 5, 3)));
let mut tensor_impl = arr.as_tensor_impl_mut();
let mut tensor = TensorMut::new(&mut tensor_impl);

for ix in indexed_iter(&tensor) {
assert_eq!(tensor[&ix], 0);
}
for ix in indexed_iter(&tensor) {
let (x, y, z) = (ix[0], ix[1], ix[2]);
tensor[&ix] = x as i32 * 1337 - y as i32 * 87 + z as i32 * 13;
}
}

fn indexed_iter<D: Data>(tensor: &TensorBase<D>) -> impl Iterator<Item = Vec<usize>> {
let dim = tensor.dim() as usize;
let sizes = tensor
.sizes()
.iter()
.map(|&s| s as usize)
.collect::<Vec<_>>();
let mut coordinate = vec![0_usize; dim];
let mut remaining_elms = tensor.numel();
std::iter::from_fn(move || {
if remaining_elms <= 0 {
return None;
}
for j in (0..dim).rev() {
if coordinate[j] + 1 < sizes[j] {
coordinate[j] += 1;
break;
} else {
coordinate[j] = 0;
}
}
remaining_elms -= 1;
return Some(coordinate.clone());
})
}
}

0 comments on commit e3b7d08

Please sign in to comment.