diff --git a/src/ops/mod.rs b/src/ops/mod.rs index 07f620d3..8706f114 100644 --- a/src/ops/mod.rs +++ b/src/ops/mod.rs @@ -393,6 +393,13 @@ macro_rules! impl_output_conversions { } } + // NdTensor => Output + impl From> for Output { + fn from(t: NdTensor<$element_type, N>) -> Output { + Output::$variant(t.into_dyn()) + } + } + // Output => Tensor impl TryFrom for Tensor<$element_type> { type Error = OpError; @@ -533,7 +540,7 @@ impl<'a> Layout for InputOrOutput<'a> { } /// Trait for values that can be converted into the result type used by -/// `Operator::run`. +/// [`Operator::run`]. pub trait IntoOpResult { fn into_op_result(self) -> Result; } @@ -550,9 +557,9 @@ impl IntoOpResult for Output { } } -impl IntoOpResult for Tensor +impl IntoOpResult for TensorBase where - Output: From>, + Output: From>, { fn into_op_result(self) -> Result { let output: Output = self.into(); @@ -560,28 +567,18 @@ where } } -impl IntoOpResult for NdTensor -where - Output: From>, -{ - fn into_op_result(self) -> Result { - let output: Output = self.into_dyn().into(); - Ok([output].into()) - } -} - -impl IntoOpResult for Result, OpError> +impl IntoOpResult for Result, OpError> where - Output: From>, + Output: From>, { fn into_op_result(self) -> Result { self.map(|tensor| [tensor.into()].into()) } } -impl IntoOpResult for Result>, OpError> +impl IntoOpResult for Result, OpError> where - Output: From>, + Output: From, { fn into_op_result(self) -> Result { self.map(|tensors| tensors.into_iter().map(|t| t.into()).collect())