Skip to content

Commit

Permalink
create fast paths for trivial tdims
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Sep 16, 2024
1 parent 9b98017 commit 9e8cf59
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 9 deletions.
8 changes: 7 additions & 1 deletion data/src/dim/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,13 @@ impl DimLike for TDim {
}

fn broadcast(self, other: Self) -> TractResult<Self> {
Ok(TDim::Broadcast(vec![self, other]).simplify())
if self.is_one() {
Ok(other)
} else if other.is_one() {
Ok(self)
} else {
Ok(TDim::Broadcast(vec![self, other]).simplify())
}
}

fn compatible_with(&self, other: &Self) -> bool {
Expand Down
66 changes: 58 additions & 8 deletions data/src/dim/tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ impl fmt::Display for TDim {
impl TDim {
#[inline]
pub fn is_one(&self) -> bool {
self == &Val(1)
matches!(self, Val(1))
}

#[inline]
Expand Down Expand Up @@ -881,13 +881,25 @@ impl<'a> From<&'a Symbol> for TDim {
impl ops::Neg for TDim {
type Output = Self;
fn neg(self) -> Self {
TDim::MulInt(-1, Box::new(self)).reduce()
if let Val(v) = self {
Val(-v)
} else {
TDim::MulInt(-1, Box::new(self)).reduce()
}
}
}

impl<'a> ops::AddAssign<&'a TDim> for TDim {
fn add_assign(&mut self, rhs: &'a TDim) {
*self = TDim::Add(vec![std::mem::take(self), rhs.clone()]).reduce()
if rhs.is_zero() {
()
} else if self.is_zero() {
*self = rhs.clone();
} else if let (Val(s), Val(o)) = (&mut *self, &rhs) {
*s += o;
} else {
*self = TDim::Add(vec![std::mem::take(self), rhs.clone()]).reduce()
}
}
}

Expand All @@ -897,7 +909,15 @@ where
{
fn add_assign(&mut self, rhs: I) {
let rhs = rhs.into();
*self += &rhs
if rhs.is_zero() {
()
} else if self.is_zero() {
*self = rhs;
} else if let (Val(s), Val(o)) = (&mut *self, &rhs) {
*s += o;
} else {
*self = TDim::Add(vec![std::mem::take(self), rhs]).reduce()
}
}
}

Expand All @@ -924,7 +944,15 @@ impl<'a> ops::Add<&'a TDim> for TDim {
impl<'a> ops::SubAssign<&'a TDim> for TDim {
fn sub_assign(&mut self, rhs: &'a TDim) {
use std::ops::Neg;
*self += rhs.clone().neg()
if rhs.is_zero() {
()
} else if self.is_zero() {
*self = rhs.clone().neg();
} else if let (Val(s), Val(o)) = (&mut *self, &rhs) {
*s -= o;
} else {
*self = TDim::Add(vec![std::mem::take(self), rhs.clone().neg()]).reduce()
}
}
}

Expand All @@ -933,7 +961,16 @@ where
I: Into<TDim>,
{
fn sub_assign(&mut self, rhs: I) {
*self -= &rhs.into()
let rhs = rhs.into();
if rhs.is_zero() {
()
} else if self.is_zero() {
*self = rhs.neg();
} else if let (Val(s), Val(o)) = (&mut *self, &rhs) {
*s -= o;
} else {
*self = TDim::Add(vec![std::mem::take(self), rhs.neg()]).reduce()
}
}
}

Expand All @@ -958,13 +995,26 @@ impl<'a> ops::Sub<&'a TDim> for TDim {

impl<I: Into<TDim>> ops::MulAssign<I> for TDim {
fn mul_assign(&mut self, rhs: I) {
*self = TDim::Mul(vec![rhs.into(), std::mem::take(self)]).reduce()
let rhs = rhs.into();
if self.is_one() {
*self = rhs
} else if rhs.is_one() {
()
} else {
*self = TDim::Mul(vec![rhs, std::mem::take(self)]).reduce()
}
}
}

impl<'a> ops::MulAssign<&'a TDim> for TDim {
fn mul_assign(&mut self, rhs: &'a TDim) {
*self = TDim::Mul(vec![std::mem::take(self), rhs.clone()]).reduce()
if self.is_one() {
*self = rhs.clone()
} else if rhs.is_one() {
()
} else {
*self = TDim::Mul(vec![std::mem::take(self), rhs.clone()]).reduce()
}
}
}

Expand Down

0 comments on commit 9e8cf59

Please sign in to comment.