Skip to content

Commit

Permalink
Use a macro to de-duplicate proxy layout impls
Browse files Browse the repository at this point in the history
  • Loading branch information
robertknight committed Oct 15, 2024
1 parent ba13b33 commit 0fdc55d
Showing 1 changed file with 48 additions and 114 deletions.
162 changes: 48 additions & 114 deletions src/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,51 @@ pub enum DataType {
UInt8,
}

/// Generate the body of a [`Layout`] impl for a type which wraps an
/// underlying layout.
macro_rules! impl_proxy_layout {
() => {
type Index<'b> = <DynLayout as Layout>::Index<'b>;
type Indices = <DynLayout as Layout>::Indices;

fn ndim(&self) -> usize {
self.layout().ndim()
}

fn try_offset(&self, index: Self::Index<'_>) -> Option<usize> {
self.layout().try_offset(index)
}

fn len(&self) -> usize {
self.layout().len()
}

fn is_empty(&self) -> bool {
self.layout().is_empty()
}

fn shape(&self) -> Self::Index<'_> {
self.layout().shape()
}

fn size(&self, dim: usize) -> usize {
self.layout().size(dim)
}

fn strides(&self) -> Self::Index<'_> {
self.layout().strides()
}

fn stride(&self, dim: usize) -> usize {
self.layout().stride(dim)
}

fn indices(&self) -> Self::Indices {
self.layout().indices()
}
};
}

/// Enum of the different types of tensor view that can be used as a model or
/// operator input.
#[derive(Clone)]
Expand Down Expand Up @@ -209,44 +254,7 @@ impl<'a> Input<'a> {
}

impl<'a> Layout for Input<'a> {
type Index<'b> = <DynLayout as Layout>::Index<'b>;
type Indices = <DynLayout as Layout>::Indices;

fn ndim(&self) -> usize {
self.layout().ndim()
}

fn try_offset(&self, index: Self::Index<'_>) -> Option<usize> {
self.layout().try_offset(index)
}

fn len(&self) -> usize {
self.layout().len()
}

fn is_empty(&self) -> bool {
self.layout().is_empty()
}

fn shape(&self) -> Self::Index<'_> {
self.layout().shape()
}

fn size(&self, dim: usize) -> usize {
self.layout().size(dim)
}

fn strides(&self) -> Self::Index<'_> {
self.layout().strides()
}

fn stride(&self, dim: usize) -> usize {
self.layout().stride(dim)
}

fn indices(&self) -> Self::Indices {
self.layout().indices()
}
impl_proxy_layout!();
}

macro_rules! impl_input_conversions {
Expand Down Expand Up @@ -372,44 +380,7 @@ impl Output {
}

impl Layout for Output {
type Index<'a> = <DynLayout as Layout>::Index<'a>;
type Indices = <DynLayout as Layout>::Indices;

fn ndim(&self) -> usize {
self.layout().ndim()
}

fn try_offset(&self, index: Self::Index<'_>) -> Option<usize> {
self.layout().try_offset(index)
}

fn len(&self) -> usize {
self.layout().len()
}

fn is_empty(&self) -> bool {
self.layout().is_empty()
}

fn shape(&self) -> Self::Index<'_> {
self.layout().shape()
}

fn size(&self, dim: usize) -> usize {
self.layout().size(dim)
}

fn strides(&self) -> Self::Index<'_> {
self.layout().strides()
}

fn stride(&self, dim: usize) -> usize {
self.layout().stride(dim)
}

fn indices(&self) -> Self::Indices {
self.layout().indices()
}
impl_proxy_layout!();
}

/// Declare conversions between `Output` and `Tensor<T>` / `NdTensor<T, N>`.
Expand Down Expand Up @@ -558,44 +529,7 @@ impl<'a> From<&'a Output> for InputOrOutput<'a> {
}

impl<'a> Layout for InputOrOutput<'a> {
type Index<'b> = <DynLayout as Layout>::Index<'b>;
type Indices = <DynLayout as Layout>::Indices;

fn ndim(&self) -> usize {
self.layout().ndim()
}

fn try_offset(&self, index: Self::Index<'_>) -> Option<usize> {
self.layout().try_offset(index)
}

fn len(&self) -> usize {
self.layout().len()
}

fn is_empty(&self) -> bool {
self.layout().is_empty()
}

fn shape(&self) -> Self::Index<'_> {
self.layout().shape()
}

fn size(&self, dim: usize) -> usize {
self.layout().size(dim)
}

fn strides(&self) -> Self::Index<'_> {
self.layout().strides()
}

fn stride(&self, dim: usize) -> usize {
self.layout().stride(dim)
}

fn indices(&self) -> Self::Indices {
self.layout().indices()
}
impl_proxy_layout!();
}

/// Trait for values that can be converted into the result type used by
Expand Down

0 comments on commit 0fdc55d

Please sign in to comment.