From 0fdc55dbdd845347c249ff0c68f0c62101007320 Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Tue, 15 Oct 2024 07:09:05 +0100 Subject: [PATCH] Use a macro to de-duplicate proxy layout impls --- src/ops/mod.rs | 162 +++++++++++++++---------------------------------- 1 file changed, 48 insertions(+), 114 deletions(-) diff --git a/src/ops/mod.rs b/src/ops/mod.rs index a57c74dc..6a8f07af 100644 --- a/src/ops/mod.rs +++ b/src/ops/mod.rs @@ -178,6 +178,51 @@ pub enum DataType { UInt8, } +/// Generate the body of a [`Layout`] impl for a type which wraps an +/// underlying layout. +macro_rules! impl_proxy_layout { + () => { + type Index<'b> = ::Index<'b>; + type Indices = ::Indices; + + fn ndim(&self) -> usize { + self.layout().ndim() + } + + fn try_offset(&self, index: Self::Index<'_>) -> Option { + self.layout().try_offset(index) + } + + fn len(&self) -> usize { + self.layout().len() + } + + fn is_empty(&self) -> bool { + self.layout().is_empty() + } + + fn shape(&self) -> Self::Index<'_> { + self.layout().shape() + } + + fn size(&self, dim: usize) -> usize { + self.layout().size(dim) + } + + fn strides(&self) -> Self::Index<'_> { + self.layout().strides() + } + + fn stride(&self, dim: usize) -> usize { + self.layout().stride(dim) + } + + fn indices(&self) -> Self::Indices { + self.layout().indices() + } + }; +} + /// Enum of the different types of tensor view that can be used as a model or /// operator input. #[derive(Clone)] @@ -209,44 +254,7 @@ impl<'a> Input<'a> { } impl<'a> Layout for Input<'a> { - type Index<'b> = ::Index<'b>; - type Indices = ::Indices; - - fn ndim(&self) -> usize { - self.layout().ndim() - } - - fn try_offset(&self, index: Self::Index<'_>) -> Option { - self.layout().try_offset(index) - } - - fn len(&self) -> usize { - self.layout().len() - } - - fn is_empty(&self) -> bool { - self.layout().is_empty() - } - - fn shape(&self) -> Self::Index<'_> { - self.layout().shape() - } - - fn size(&self, dim: usize) -> usize { - self.layout().size(dim) - } - - fn strides(&self) -> Self::Index<'_> { - self.layout().strides() - } - - fn stride(&self, dim: usize) -> usize { - self.layout().stride(dim) - } - - fn indices(&self) -> Self::Indices { - self.layout().indices() - } + impl_proxy_layout!(); } macro_rules! impl_input_conversions { @@ -372,44 +380,7 @@ impl Output { } impl Layout for Output { - type Index<'a> = ::Index<'a>; - type Indices = ::Indices; - - fn ndim(&self) -> usize { - self.layout().ndim() - } - - fn try_offset(&self, index: Self::Index<'_>) -> Option { - self.layout().try_offset(index) - } - - fn len(&self) -> usize { - self.layout().len() - } - - fn is_empty(&self) -> bool { - self.layout().is_empty() - } - - fn shape(&self) -> Self::Index<'_> { - self.layout().shape() - } - - fn size(&self, dim: usize) -> usize { - self.layout().size(dim) - } - - fn strides(&self) -> Self::Index<'_> { - self.layout().strides() - } - - fn stride(&self, dim: usize) -> usize { - self.layout().stride(dim) - } - - fn indices(&self) -> Self::Indices { - self.layout().indices() - } + impl_proxy_layout!(); } /// Declare conversions between `Output` and `Tensor` / `NdTensor`. @@ -558,44 +529,7 @@ impl<'a> From<&'a Output> for InputOrOutput<'a> { } impl<'a> Layout for InputOrOutput<'a> { - type Index<'b> = ::Index<'b>; - type Indices = ::Indices; - - fn ndim(&self) -> usize { - self.layout().ndim() - } - - fn try_offset(&self, index: Self::Index<'_>) -> Option { - self.layout().try_offset(index) - } - - fn len(&self) -> usize { - self.layout().len() - } - - fn is_empty(&self) -> bool { - self.layout().is_empty() - } - - fn shape(&self) -> Self::Index<'_> { - self.layout().shape() - } - - fn size(&self, dim: usize) -> usize { - self.layout().size(dim) - } - - fn strides(&self) -> Self::Index<'_> { - self.layout().strides() - } - - fn stride(&self, dim: usize) -> usize { - self.layout().stride(dim) - } - - fn indices(&self) -> Self::Indices { - self.layout().indices() - } + impl_proxy_layout!(); } /// Trait for values that can be converted into the result type used by