Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify static and dynamic-rank tensor implementations #43

Merged
merged 12 commits into from
Jan 22, 2024
Merged
4 changes: 2 additions & 2 deletions rten-examples/src/bert_qa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ fn extract_nbest_answers<'a>(
.iter()
.map(|tid| *tid as i32)
.collect::<Tensor<_>>()
.into_shape(&[1, query_context.token_ids().len()]);
.into_shape([1, query_context.token_ids().len()].as_slice());
let attention_mask = Tensor::full(&[batch, input_ids.len()], 1i32);

let input_ids_id = model.node_id("input_ids")?;
Expand All @@ -112,7 +112,7 @@ fn extract_nbest_answers<'a>(
.token_type_ids()
.map(|tid| tid as i32)
.collect::<Tensor<_>>()
.into_shape(&[1, query_context.token_ids().len()]);
.into_shape([1, query_context.token_ids().len()].as_slice());
inputs.push((type_ids_id, type_ids.view().into()));
}

Expand Down
4 changes: 2 additions & 2 deletions rten-examples/src/deeplab.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ fn main() -> Result<(), Box<dyn Error>> {

let mut image: Tensor = read_image(&args.image)?.into();
normalize_image(image.nd_view_mut());
image.insert_dim(0); // Add batch dim
image.insert_axis(0); // Add batch dim

// Resize image according to metadata in the model.
let input_shape = model
Expand All @@ -132,7 +132,7 @@ fn main() -> Result<(), Box<dyn Error>> {
output.permute(&[0, 2, 3, 1]); // (N,class,H,W) => (N,H,W,class)

let seg_classes: NdTensor<i32, 2> = output
.slice(0)
.slice_dyn(0)
.arg_max(-1, false /* keep_dims */)?
.try_into()?;
let [out_height, out_width] = seg_classes.shape();
Expand Down
2 changes: 1 addition & 1 deletion rten-examples/src/detr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ fn main() -> Result<(), Box<dyn Error>> {
let [_, image_height, image_width] = image.shape();

let mut image = image.as_dyn().to_tensor();
image.insert_dim(0); // Add batch dim
image.insert_axis(0); // Add batch dim

// Resize input image according to min/max side length constraints.
//
Expand Down
10 changes: 5 additions & 5 deletions rten-examples/src/jina_similarity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ fn embed_sentence_batch(
let token_ids = encoded.token_ids();
for (tid, input_id) in token_ids
.iter()
.zip(input_ids.slice_mut((i, ..token_ids.len())).iter_mut())
.zip(input_ids.slice_mut_dyn((i, ..token_ids.len())).iter_mut())
{
*input_id = *tid as i32;
}
Expand All @@ -114,7 +114,7 @@ fn embed_sentence_batch(
let mut attention_mask = Tensor::zeros(&[batch, max_sequence_len]);
for (i, encoded) in encoded.iter().enumerate() {
attention_mask
.slice_mut((i, ..encoded.token_ids().len()))
.slice_mut::<1, _>((i, ..encoded.token_ids().len()))
.fill(1i32);
}

Expand Down Expand Up @@ -147,7 +147,7 @@ fn embed_sentence_batch(
// Take the mean of the non-padding elements along the sequence
// dimension.
let seq_len = input.token_ids().len();
item.slice(..seq_len)
item.slice_dyn(..seq_len)
.reduce_mean(Some(&[0]), false /* keep_dims */)
.unwrap()
})
Expand All @@ -157,7 +157,7 @@ fn embed_sentence_batch(
.map(|mp| {
// Re-add batch dim.
let mut view = mp.view();
view.insert_dim(0);
view.insert_axis(0);
view
})
.collect();
Expand Down Expand Up @@ -241,7 +241,7 @@ fn main() -> Result<(), Box<dyn Error>> {
// all be "high" values (close to 1.0). They should be used only for
// comparison with other scores.
let mut scores: Vec<(usize, f32)> = similarities
.slice(0)
.slice_dyn(0)
.iter()
.copied()
.enumerate()
Expand Down
2 changes: 1 addition & 1 deletion rten-examples/src/wav2vec2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ fn main() -> Result<(), Box<dyn Error>> {
let samples = read_wav_file(&args.wav_file)?;

let mut sample_batch = Tensor::from_vec(samples);
sample_batch.insert_dim(0);
sample_batch.insert_axis(0);

let result: NdTensor<f32, 3> = model
.run_one(sample_batch.view().into(), None)?
Expand Down
2 changes: 1 addition & 1 deletion rten-examples/src/yolo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ fn main() -> Result<(), Box<dyn Error>> {
let [_, image_height, image_width] = image.shape();

let mut image = image.as_dyn().to_tensor();
image.insert_dim(0); // Add batch dim
image.insert_axis(0); // Add batch dim

let input_shape = model
.input_shape(0)
Expand Down
1 change: 1 addition & 0 deletions rten-imageproc/src/contours.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ pub fn find_contours(mask: NdTensorView<i32, 2>, mode: RetrievalMode) -> Polygon
mod tests {
use std::iter::zip;

use rten_tensor::prelude::*;
use rten_tensor::NdTensor;

use crate::tests::border_points;
Expand Down
Loading