Skip to content

Commit

Permalink
Add reduce_max, reduce_min to FloatOperators trait
Browse files Browse the repository at this point in the history
  • Loading branch information
robertknight committed Jan 23, 2024
1 parent 6600a76 commit eea65f4
Showing 1 changed file with 20 additions and 1 deletion.
21 changes: 20 additions & 1 deletion src/ops/operators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ use rten_tensor::{DynLayout, NdLayout, NdTensorView, Tensor, TensorBase, TensorV
use crate::number::{Identities, IsInt};
use crate::ops::OpError;
use crate::ops::{
arg_max, div, matmul, mul, pad, reduce_l2, reduce_mean, resize_image, softmax, topk,
arg_max, div, matmul, mul, pad, reduce_l2, reduce_max, reduce_mean, reduce_min, resize_image,
softmax, topk,
};

/// Trait which exposes ONNX operators as methods of tensors.
Expand Down Expand Up @@ -61,7 +62,9 @@ pub trait FloatOperators {
fn matmul(&self, other: TensorView) -> Result<Tensor, OpError>;

fn reduce_l2(&self, axes: Option<&[i32]>, keep_dims: bool) -> Result<Tensor, OpError>;
fn reduce_max(&self, axes: Option<&[i32]>, keep_dims: bool) -> Result<Tensor, OpError>;
fn reduce_mean(&self, axes: Option<&[i32]>, keep_dims: bool) -> Result<Tensor, OpError>;
fn reduce_min(&self, axes: Option<&[i32]>, keep_dims: bool) -> Result<Tensor, OpError>;

/// Resize an NCHW image tensor to a given `[height, width]` using bilinear
/// interpolation.
Expand Down Expand Up @@ -180,6 +183,14 @@ impl<S: AsRef<[f32]>> FloatOperators for TensorBase<f32, S, DynLayout> {
reduce_l2(self.view(), axes, keep_dims)
}

fn reduce_max(&self, axes: Option<&[i32]>, keep_dims: bool) -> Result<Tensor, OpError> {
reduce_max(self.view(), axes, keep_dims)
}

fn reduce_min(&self, axes: Option<&[i32]>, keep_dims: bool) -> Result<Tensor, OpError> {
reduce_min(self.view(), axes, keep_dims)
}

fn reduce_mean(&self, axes: Option<&[i32]>, keep_dims: bool) -> Result<Tensor, OpError> {
reduce_mean(self.view(), axes, keep_dims)
}
Expand All @@ -202,6 +213,14 @@ impl<S: AsRef<[f32]>, const N: usize> FloatOperators for TensorBase<f32, S, NdLa
reduce_l2(self.as_dyn(), axes, keep_dims)
}

fn reduce_max(&self, axes: Option<&[i32]>, keep_dims: bool) -> Result<Tensor, OpError> {
reduce_max(self.as_dyn(), axes, keep_dims)
}

fn reduce_min(&self, axes: Option<&[i32]>, keep_dims: bool) -> Result<Tensor, OpError> {
reduce_min(self.as_dyn(), axes, keep_dims)
}

fn reduce_mean(&self, axes: Option<&[i32]>, keep_dims: bool) -> Result<Tensor, OpError> {
reduce_mean(self.as_dyn(), axes, keep_dims)
}
Expand Down

0 comments on commit eea65f4

Please sign in to comment.