Skip to content

Commit

Permalink
Reduce some duplication in bilinear resize tests
Browse files Browse the repository at this point in the history
  • Loading branch information
robertknight committed Jan 23, 2024
1 parent 898519e commit 6600a76
Showing 1 changed file with 14 additions and 11 deletions.
25 changes: 14 additions & 11 deletions src/ops/resize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ mod tests {

use rten_tensor::prelude::*;
use rten_tensor::test_util::expect_equal;
use rten_tensor::Tensor;
use rten_tensor::{NdTensor, NdTensorView, Tensor};

use crate::ops::tests::expect_eq_1e4;
use crate::ops::{
Expand Down Expand Up @@ -527,24 +527,27 @@ mod tests {

#[test]
fn test_resize_bilinear() -> Result<(), Box<dyn Error>> {
struct Case {
image: Tensor,
struct Case<'a> {
image: NdTensorView<'a, f32, 4>,
scales: Vec<f32>,
expected: Tensor,
coord_transform_mode: Option<CoordTransformMode>,
}

let image = NdTensor::from([0.2, 0.7, 0.3, 0.8]).into_shape([1, 1, 2, 2]);
let image = image.view();

let cases = [
// Scale width and height by 0x
Case {
image: Tensor::from_data(&[1, 1, 2, 2], vec![0.2, 0.7, 0.3, 0.8]),
image,
scales: vec![1., 1., 0., 0.],
coord_transform_mode: None,
expected: Tensor::from_data(&[1, 1, 0, 0], vec![]),
},
// Scale width and height by 0.5x
Case {
image: Tensor::from_data(&[1, 1, 2, 2], vec![0.2, 0.7, 0.3, 0.8]),
image,
scales: vec![1., 1., 0.5, 0.5],
coord_transform_mode: None,

Expand All @@ -556,14 +559,14 @@ mod tests {
},
// Scale width and height by 1x
Case {
image: Tensor::from_data(&[1, 1, 2, 2], vec![0.2, 0.7, 0.3, 0.8]),
image,
scales: vec![1., 1., 1., 1.],
coord_transform_mode: None,
expected: Tensor::from_data(&[1, 1, 2, 2], vec![0.2, 0.7, 0.3, 0.8]),
},
// Scale width and height by 1.5x
Case {
image: Tensor::from_data(&[1, 1, 2, 2], vec![0.2, 0.7, 0.3, 0.8]),
image,
scales: vec![1., 1., 1.5, 1.5],
coord_transform_mode: None,
expected: Tensor::from_data(
Expand All @@ -577,7 +580,7 @@ mod tests {
},
// Scale width and height by 2x
Case {
image: Tensor::from_data(&[1, 1, 2, 2], vec![0.2, 0.7, 0.3, 0.8]),
image,
scales: vec![1., 1., 2., 2.],
coord_transform_mode: None,
expected: Tensor::from_data(
Expand All @@ -592,7 +595,7 @@ mod tests {
},
// Scale width and height by 2x, align corners.
Case {
image: Tensor::from_data(&[1, 1, 2, 2], vec![0.2, 0.7, 0.3, 0.8]),
image,
scales: vec![1., 1., 2., 2.],
coord_transform_mode: Some(CoordTransformMode::AlignCorners),

Expand All @@ -608,7 +611,7 @@ mod tests {
},
// Scale width and height by 3x
Case {
image: Tensor::from_data(&[1, 1, 2, 2], vec![0.2, 0.7, 0.3, 0.8]),
image,
scales: vec![1., 1., 3., 3.],
coord_transform_mode: None,
expected: Tensor::from_data(
Expand All @@ -627,7 +630,7 @@ mod tests {

for case in cases {
let result = resize(
case.image.view(),
case.image.as_dyn(),
ResizeTarget::Scales(case.scales.as_slice().into()),
ResizeMode::Linear,
case.coord_transform_mode
Expand Down

0 comments on commit 6600a76

Please sign in to comment.