diff --git a/src/ops/operators.rs b/src/ops/operators.rs
index 12e24ca0..ab4237f3 100644
--- a/src/ops/operators.rs
+++ b/src/ops/operators.rs
@@ -6,7 +6,8 @@ use rten_tensor::{DynLayout, NdLayout, NdTensorView, Tensor, TensorBase, TensorV
 use crate::number::{Identities, IsInt};
 use crate::ops::OpError;
 use crate::ops::{
-    arg_max, div, matmul, mul, pad, reduce_l2, reduce_mean, resize_image, softmax, topk,
+    arg_max, div, matmul, mul, pad, reduce_l2, reduce_max, reduce_mean, reduce_min, resize_image,
+    softmax, topk,
 };
 
 /// Trait which exposes ONNX operators as methods of tensors.
@@ -61,7 +62,9 @@ pub trait FloatOperators {
     fn matmul(&self, other: TensorView) -> Result<Tensor, OpError>;
 
     fn reduce_l2(&self, axes: Option<&[i32]>, keep_dims: bool) -> Result<Tensor, OpError>;
+    fn reduce_max(&self, axes: Option<&[i32]>, keep_dims: bool) -> Result<Tensor, OpError>;
     fn reduce_mean(&self, axes: Option<&[i32]>, keep_dims: bool) -> Result<Tensor, OpError>;
+    fn reduce_min(&self, axes: Option<&[i32]>, keep_dims: bool) -> Result<Tensor, OpError>;
 
     /// Resize an NCHW image tensor to a given `[height, width]` using bilinear
     /// interpolation.
@@ -180,6 +183,14 @@ impl<S: AsRef<[f32]>> FloatOperators for TensorBase<f32, S, DynLayout> {
         reduce_l2(self.view(), axes, keep_dims)
     }
 
+    fn reduce_max(&self, axes: Option<&[i32]>, keep_dims: bool) -> Result<Tensor, OpError> {
+        reduce_max(self.view(), axes, keep_dims)
+    }
+
+    fn reduce_min(&self, axes: Option<&[i32]>, keep_dims: bool) -> Result<Tensor, OpError> {
+        reduce_min(self.view(), axes, keep_dims)
+    }
+
     fn reduce_mean(&self, axes: Option<&[i32]>, keep_dims: bool) -> Result<Tensor, OpError> {
         reduce_mean(self.view(), axes, keep_dims)
     }
@@ -202,6 +213,14 @@ impl<S: AsRef<[f32]>, const N: usize> FloatOperators for TensorBase<f32, S, NdL
         reduce_l2(self.as_dyn(), axes, keep_dims)
     }
 
+    fn reduce_max(&self, axes: Option<&[i32]>, keep_dims: bool) -> Result<Tensor, OpError> {
+        reduce_max(self.as_dyn(), axes, keep_dims)
+    }
+
+    fn reduce_min(&self, axes: Option<&[i32]>, keep_dims: bool) -> Result<Tensor, OpError> {
+        reduce_min(self.as_dyn(), axes, keep_dims)
+    }
+
     fn reduce_mean(&self, axes: Option<&[i32]>, keep_dims: bool) -> Result<Tensor, OpError> {
         reduce_mean(self.as_dyn(), axes, keep_dims)
     }