Add fast path for TensorBase::map for contiguous tensors
This fast path already existed for `TensorBase::apply`.

In a simple transformer model with ~25M params, this reduced the combined time of the 9 `Pow`
operations used to square inputs of 128*512 elements from ~1.2ms to ~0.2ms.
robertknight committed Feb 4, 2024
1 parent 5e654ea commit 5562fd2
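The pattern behind the fast path, sketched below as a small self-contained Rust program (the `ToyTensor` type and its methods are illustrative stand-ins, not the rten-tensor API): when the layout is contiguous, `map` can walk the backing slice directly; otherwise it has to fall back to the strided element iterator.

struct ToyTensor {
    data: Vec<f32>,
    contiguous: bool,
}

impl ToyTensor {
    /// Returns the backing slice only when the layout is contiguous.
    fn data(&self) -> Option<&[f32]> {
        self.contiguous.then_some(self.data.as_slice())
    }

    /// Stand-in for the strided element iterator used by non-contiguous views.
    fn iter(&self) -> std::slice::Iter<'_, f32> {
        self.data.iter()
    }

    fn map<F: Fn(&f32) -> f32>(&self, f: F) -> Vec<f32> {
        if let Some(data) = self.data() {
            // Fast path: iterate the contiguous slice directly.
            data.iter().map(&f).collect()
        } else {
            // Slow path: go through the element iterator.
            self.iter().map(&f).collect()
        }
    }
}

fn main() {
    let t = ToyTensor { data: vec![1.0, 2.0, 3.0], contiguous: true };
    assert_eq!(t.map(|x| x * x), vec![1.0, 4.0, 9.0]);
}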
Showing 2 changed files with 24 additions and 5 deletions.
25 changes: 22 additions & 3 deletions rten-tensor/src/tensor.rs
@@ -385,8 +385,9 @@ impl<T, S: AsRef<[T]> + AsMut<[T]>, L: MutLayout> TensorBase<T, S, L> {
     /// Replace each element in this tensor with the result of applying `f` to
     /// the element.
     pub fn apply<F: Fn(&T) -> T>(&mut self, f: F) {
-        if self.is_contiguous() {
-            self.data.as_mut().iter_mut().for_each(|x| *x = f(x));
+        if let Some(data) = self.data_mut() {
+            // Fast path for contiguous tensors.
+            data.iter_mut().for_each(|x| *x = f(x));
         } else {
             self.iter_mut().for_each(|x| *x = f(x));
         }
@@ -1086,7 +1087,12 @@ impl<T, S: AsRef<[T]>, L: MutLayout + Clone> AsView for TensorBase<T, S, L> {
     where
         F: Fn(&Self::Elem) -> U,
     {
-        let data: Vec<_> = self.iter().map(f).collect();
+        let data: Vec<U> = if let Some(data) = self.data() {
+            // Fast path for contiguous tensors.
+            data.iter().map(f).collect()
+        } else {
+            self.iter().map(f).collect()
+        };
         TensorBase::from_data(self.shape(), data)
     }

@@ -1545,9 +1551,16 @@ mod tests {
     #[test]
     fn test_apply() {
         let data = vec![1., 2., 3., 4.];
+
+        // Contiguous tensor.
         let mut tensor = NdTensor::from_data([2, 2], data);
         tensor.apply(|x| *x * 2.);
         assert_eq!(tensor.to_vec(), &[2., 4., 6., 8.]);
+
+        // Non-contiguous tensor
+        tensor.transpose();
+        tensor.apply(|x| *x / 2.);
+        assert_eq!(tensor.to_vec(), &[1., 3., 2., 4.]);
     }
 
     #[test]
@@ -2198,8 +2211,14 @@
     fn test_map() {
         let data = vec![1., 2., 3., 4.];
         let tensor = NdTensor::from_data([2, 2], data);
+
+        // Contiguous tensor
         let doubled = tensor.map(|x| x * 2.);
         assert_eq!(doubled.to_vec(), &[2., 4., 6., 8.]);
+
+        // Non-contiguous tensor
+        let halved = doubled.transposed().map(|x| x / 2.);
+        assert_eq!(halved.to_vec(), &[1., 3., 2., 4.]);
     }
 
     #[test]
4 changes: 2 additions & 2 deletions src/ops/binary_elementwise.rs
@@ -657,8 +657,8 @@ fn powf(x: f32, y: f32) -> f32 {
 
 /// Raise elements of `a` to powers of corresponding elements in `b`.
 pub fn pow(a: TensorView, b: TensorView) -> Result<Tensor, OpError> {
-    if let Some(exp) = b.item() {
-        Ok(a.map(|x| powf(*x, *exp)))
+    if let Some(&exp) = b.item() {
+        Ok(a.map(|x| powf(*x, exp)))
     } else {
         binary_op(a, b, |x, y| x.powf(y))
     }
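Side note on the `pow` hunk above: when `b` is a scalar, `pow` goes through `a.map(...)`, which is the call that benefits from the new contiguous fast path. The `Some(&exp)` binding simply copies the `f32` out of the `Option<&f32>` returned by `item()`, so the closure no longer needs `*exp`. A minimal illustration of that binding pattern in plain Rust (no rten types):

fn main() {
    let b: Option<&f32> = Some(&2.0);
    if let Some(&exp) = b {
        // `exp` is an owned f32 here, so no dereference is needed below.
        let squared: Vec<f32> = [1.0_f32, 3.0].iter().map(|x| x.powf(exp)).collect();
        assert!((squared[1] - 9.0).abs() < 1e-6);
    }
}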
