Skip to content

Commit

Permalink
Merge pull request #386 from robertknight/generalize-into-op-result
Browse files Browse the repository at this point in the history
Generalize `IntoOpResult` impls
  • Loading branch information
robertknight authored Oct 15, 2024
2 parents e30751d + 11bb482 commit b557d7a
Showing 1 changed file with 14 additions and 17 deletions.
31 changes: 14 additions & 17 deletions src/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,13 @@ macro_rules! impl_output_conversions {
}
}

// NdTensor<T> => Output
impl<const N: usize> From<NdTensor<$element_type, N>> for Output {
fn from(t: NdTensor<$element_type, N>) -> Output {
Output::$variant(t.into_dyn())
}
}

// Output => Tensor<T>
impl TryFrom<Output> for Tensor<$element_type> {
type Error = OpError;
Expand Down Expand Up @@ -533,7 +540,7 @@ impl<'a> Layout for InputOrOutput<'a> {
}

/// Trait for values that can be converted into the result type used by
/// `Operator::run`.
/// [`Operator::run`].
pub trait IntoOpResult {
fn into_op_result(self) -> Result<OutputList, OpError>;
}
Expand All @@ -550,38 +557,28 @@ impl IntoOpResult for Output {
}
}

impl<T> IntoOpResult for Tensor<T>
impl<S: Storage, L: MutLayout> IntoOpResult for TensorBase<S, L>
where
Output: From<Tensor<T>>,
Output: From<TensorBase<S, L>>,
{
fn into_op_result(self) -> Result<OutputList, OpError> {
let output: Output = self.into();
Ok([output].into())
}
}

impl<T, const N: usize> IntoOpResult for NdTensor<T, N>
where
Output: From<Tensor<T>>,
{
fn into_op_result(self) -> Result<OutputList, OpError> {
let output: Output = self.into_dyn().into();
Ok([output].into())
}
}

impl<T> IntoOpResult for Result<Tensor<T>, OpError>
impl<S: Storage, L: MutLayout> IntoOpResult for Result<TensorBase<S, L>, OpError>
where
Output: From<Tensor<T>>,
Output: From<TensorBase<S, L>>,
{
fn into_op_result(self) -> Result<OutputList, OpError> {
self.map(|tensor| [tensor.into()].into())
}
}

impl<T> IntoOpResult for Result<Vec<Tensor<T>>, OpError>
impl<T> IntoOpResult for Result<Vec<T>, OpError>
where
Output: From<Tensor<T>>,
Output: From<T>,
{
fn into_op_result(self) -> Result<OutputList, OpError> {
self.map(|tensors| tensors.into_iter().map(|t| t.into()).collect())
Expand Down

0 comments on commit b557d7a

Please sign in to comment.