Skip to content

Commit

Permalink
Replace collapse_axis with prefix_with
Browse files Browse the repository at this point in the history
  • Loading branch information
emricksinisonos committed Jul 19, 2024
1 parent 3377de2 commit 5de3f7f
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 64 deletions.
40 changes: 18 additions & 22 deletions core/src/ops/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -412,10 +412,10 @@ impl EvalOp for BinOpByScalar {

fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
let (a, b) = args_2!(inputs);
let mut a = a.into_tensor();
// Not a requirement, as TensorView doesn't need an owned tensor, but in reality
// "a" should be mutable (the `mut` is omitted here because the Rust compiler advises removing it)
let a = a.into_tensor();
let b_shape = b.shape();
let mut view = a.view_mut();
let b_view = b.view();

let first_unary_axis = b_shape
.iter()
Expand All @@ -426,21 +426,18 @@ impl EvalOp for BinOpByScalar {
.last()
.context("Cannot use by_scalar when no trailing dimensions are unary")?;

let iterating_shape = view.shape()[..first_unary_axis].to_vec();
let iterating_shape = a.shape()[..first_unary_axis].to_vec();
if !iterating_shape.is_empty() {
for it_coords in tract_data::internal::iter_indices(&iterating_shape) {
let mut view = view.clone();
let mut tmp_b_view = b_view.clone();

// Prepare array view to perform computation
for (axis, idx) in it_coords.iter().enumerate() {
view.collapse_axis(axis, *idx as isize);
tmp_b_view.collapse_axis(axis, *idx as isize);
}

self.0.eval_by_scalar(&mut view, &tmp_b_view)?;
let mut view = TensorView::at_prefix(&a, &it_coords)?;
let b_view = TensorView::at_prefix(&b, &it_coords)?;
debug_assert_eq!(b_view.shape().iter().product::<usize>(), 1);
self.0.eval_by_scalar(&mut view, &b_view)?;
}
} else {
let mut view = a.view();
let b_view = b.view();
debug_assert_eq!(b_view.shape().iter().product::<usize>(), 1);
self.0.eval_by_scalar(&mut view, &b_view)?;
}
Ok(tvec!(a.into_tvalue()))
Expand Down Expand Up @@ -518,25 +515,24 @@ impl EvalOp for BinOpUnicast {

fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
let (a, b) = args_2!(inputs);
let mut a = a.into_tensor();
// Not a requirement, as TensorView doesn't need an owned tensor, but in reality
// "a" should be mutable (the `mut` is omitted here because the Rust compiler advises removing it)
let a = a.into_tensor();
let b_shape = b.shape();
let mut view = a.view_mut();
let b_view = b.view();

let first_non_unary_axis =
b_shape.iter().enumerate().take_while(|&(_, &dim)| dim == 1).map(|(i, _)| i + 1).last();

if let Some(first_non_unary_axis) = first_non_unary_axis {
// Iterate over the outer dimensions and evaluate with unicast subviews
let iterating_shape = view.shape()[..first_non_unary_axis].to_vec();
let iterating_shape = a.shape()[..first_non_unary_axis].to_vec();
for it_coords in tract_data::internal::iter_indices(&iterating_shape) {
let mut view = view.clone();
it_coords.iter().enumerate().for_each(|(axis, idx)| {
view.collapse_axis(axis, *idx as isize);
});
let mut view = TensorView::at_prefix(&a, &it_coords)?;
debug_assert_eq!(view.shape(), &b_view.shape()[it_coords.len()..]);
self.0.eval_unicast(&mut view, &b_view)?;
}
} else {
let mut view = a.view();
debug_assert_eq!(view.shape(), b_view.shape());
self.0.eval_unicast(&mut view, &b_view)?;
}
Expand Down
45 changes: 12 additions & 33 deletions data/src/tensor/view.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@ use super::*;
use crate::internal::*;

#[derive(Clone, Debug)]
enum Indexing {
enum Indexing<'a> {
Prefix(usize),
Custom { shape: Vec<usize>, strides: Vec<isize> },
Custom { shape: &'a [usize], strides: &'a [isize] },
}

#[derive(Clone, Debug)]
pub struct TensorView<'a> {
pub tensor: &'a Tensor,
offset_bytes: isize,
indexing: Indexing,
indexing: Indexing<'a>,
}

impl<'a> TensorView<'a> {
Expand All @@ -24,7 +24,7 @@ impl<'a> TensorView<'a> {
TensorView {
tensor,
offset_bytes,
indexing: Indexing::Custom { shape: shape.to_vec(), strides: strides.to_vec() },
indexing: Indexing::Custom { shape, strides },
}
}

Expand All @@ -46,8 +46,8 @@ impl<'a> TensorView<'a> {
tensor,
offset_bytes,
indexing: Indexing::Custom {
shape: tensor.shape.to_vec(),
strides: tensor.strides.to_vec(),
shape: &tensor.shape,
strides: &tensor.strides,
},
}
}
Expand Down Expand Up @@ -236,29 +236,6 @@ impl<'a> TensorView<'a> {
unsafe { Ok(self.at_unchecked(coords)) }
}

#[inline]
pub fn collapse_axis(&mut self, axis: usize, index: isize) {
let stride = self.strides()[axis] * self.datum_type().size_of() as isize;
unsafe { self.offset_bytes(stride * index) };
match &mut self.indexing {
Indexing::Prefix(x) => {
if *x == 0 {
let mut new_shape = self.tensor.shape().to_owned();
new_shape[axis] = 1;
self.indexing = Indexing::Custom {
shape: new_shape,
strides: self.tensor.strides().to_owned(),
}
} else {
unimplemented!("TODO: understand how it is used")
}
}
Indexing::Custom { shape, .. } => {
shape[axis] = 1;
}
}
}

#[inline]
pub fn at_mut<T: Datum>(&mut self, coords: impl AsRef<[usize]>) -> TractResult<&mut T> {
self.check_dt::<T>()?;
Expand Down Expand Up @@ -288,13 +265,15 @@ impl<'a> TensorView<'a> {
#[cfg(test)]
mod test {
use crate::prelude::Tensor;
use super::TensorView;

#[test]
fn test_collapse_axis() {
fn test_at_prefix() {
let a = Tensor::from_shape(&[2, 2], &[1, 2, 3, 4]).unwrap();
let mut a_view = a.view();
a_view.collapse_axis(0, 1);
assert_eq!(a_view.shape(), &[1, 2]);
let a_view = TensorView::at_prefix(&a, &[1]).unwrap();
assert_eq!(a_view.shape(), &[2]);
assert_eq!(a_view.as_slice::<i32>().unwrap(), &[3, 4]);


}
}
47 changes: 38 additions & 9 deletions linalg/src/frame/unicast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,20 @@ std::thread_local! {
static TMP: std::cell::RefCell<(TempBuffer, TempBuffer)> = std::cell::RefCell::new((TempBuffer::default(), TempBuffer::default()));
}

/// Carve a leading "incomplete tile" out of `a` and `b` for processing unaligned data.
///
/// `a_prefix_len` / `b_prefix_len` are the number of leading elements of each slice
/// that sit before the next alignment boundary. Returns the mutable sub-slice of `a`,
/// the matching sub-slice of `b`, and the chosen tile length.
fn create_incomplete_tile<'a, T: LADatum>(a: &'a mut [T], b: &'a [T], a_prefix_len: usize, b_prefix_len: usize) -> (&'a mut [T], &'a [T], usize) {
    let effective_prefix = match (a_prefix_len, b_prefix_len) {
        // One of the two slices is already aligned: the tile must span the other
        // slice's whole unaligned prefix, i.e. the larger of the two values.
        (0, n) | (n, 0) => n,
        // Both slices are unaligned: the largest tile covering unaligned elements
        // of both `a` and `b` is the smaller of the two prefixes.
        (m, n) => m.min(n),
    };
    (&mut a[..effective_prefix], &b[..effective_prefix], effective_prefix)
}


pub(crate) fn unicast_with_alignment<T>(
a: &mut [T],
b: &[T],
Expand All @@ -127,18 +141,33 @@ where
f(tmp_a, tmp_b);
a.copy_from_slice(&tmp_a[..a.len()])
};
let prefix_len = a.as_ptr().align_offset(alignment_bytes).min(a.len());
if prefix_len > 0 {
compute_via_temp_buffer(&mut a[..prefix_len], &b[..prefix_len]);

let mut num_element_processed = 0;
let a_prefix_len = a.as_ptr().align_offset(alignment_bytes).min(a.len());
let b_prefix_len = b.as_ptr().align_offset(alignment_bytes).min(b.len());
let mut applied_prefix_len = 0;
if (a_prefix_len > 0) || (b_prefix_len > 0) {
// Incomplete tile needs to be created to process unaligned data.
let (mut sub_a, sub_b, applied_prefix) = create_incomplete_tile(a, b, a_prefix_len, b_prefix_len);
applied_prefix_len = applied_prefix;
compute_via_temp_buffer(&mut sub_a, &sub_b);
num_element_processed += applied_prefix_len;
}
let aligned_len = (a.len() - prefix_len) / nr * nr;
if aligned_len > 0 {
f(&mut a[prefix_len..][..aligned_len], &b[prefix_len..][..aligned_len]);

let num_complete_tiles = (a.len() - applied_prefix_len) / nr;
if num_complete_tiles > 0 {
// Process all tiles that are complete.
let mut sub_a = &mut a[applied_prefix_len..][..(num_complete_tiles * nr)];
let sub_b = &b[applied_prefix_len..][..(num_complete_tiles * nr)];
f(&mut sub_a, &sub_b);
num_element_processed += num_complete_tiles * nr;
}
if prefix_len + aligned_len < a.len() {

if num_element_processed < a.len() {
// Incomplete tile needs to be created to process remaining elements.
compute_via_temp_buffer(
&mut a[prefix_len + aligned_len..],
&b[prefix_len + aligned_len..],
&mut a[num_element_processed..],
&b[num_element_processed..],
);
}
})
Expand Down

0 comments on commit 5de3f7f

Please sign in to comment.