diff --git a/data/src/dim/mod.rs b/data/src/dim/mod.rs index 7ef9e4e03e..34a89ac692 100644 --- a/data/src/dim/mod.rs +++ b/data/src/dim/mod.rs @@ -163,7 +163,13 @@ impl DimLike for TDim { } fn broadcast(self, other: Self) -> TractResult { - Ok(TDim::Broadcast(vec![self, other]).simplify()) + if self.is_one() { + Ok(other) + } else if other.is_one() { + Ok(self) + } else { + Ok(TDim::Broadcast(vec![self, other]).simplify()) + } } fn compatible_with(&self, other: &Self) -> bool { diff --git a/data/src/dim/tree.rs b/data/src/dim/tree.rs index eb193671b4..26af9f09c0 100644 --- a/data/src/dim/tree.rs +++ b/data/src/dim/tree.rs @@ -91,7 +91,7 @@ impl fmt::Display for TDim { impl TDim { #[inline] pub fn is_one(&self) -> bool { - self == &Val(1) + matches!(self, Val(1)) } #[inline] @@ -881,13 +881,25 @@ impl<'a> From<&'a Symbol> for TDim { impl ops::Neg for TDim { type Output = Self; fn neg(self) -> Self { - TDim::MulInt(-1, Box::new(self)).reduce() + if let Val(v) = self { + Val(-v) + } else { + TDim::MulInt(-1, Box::new(self)).reduce() + } } } impl<'a> ops::AddAssign<&'a TDim> for TDim { fn add_assign(&mut self, rhs: &'a TDim) { - *self = TDim::Add(vec![std::mem::take(self), rhs.clone()]).reduce() + if rhs.is_zero() { + () + } else if self.is_zero() { + *self = rhs.clone(); + } else if let (Val(s), Val(o)) = (&mut *self, &rhs) { + *s += o; + } else { + *self = TDim::Add(vec![std::mem::take(self), rhs.clone()]).reduce() + } } } @@ -897,7 +909,15 @@ where { fn add_assign(&mut self, rhs: I) { let rhs = rhs.into(); - *self += &rhs + if rhs.is_zero() { + () + } else if self.is_zero() { + *self = rhs; + } else if let (Val(s), Val(o)) = (&mut *self, &rhs) { + *s += o; + } else { + *self = TDim::Add(vec![std::mem::take(self), rhs]).reduce() + } } } @@ -924,7 +944,15 @@ impl<'a> ops::Add<&'a TDim> for TDim { impl<'a> ops::SubAssign<&'a TDim> for TDim { fn sub_assign(&mut self, rhs: &'a TDim) { use std::ops::Neg; - *self += rhs.clone().neg() + if rhs.is_zero() { + () + } else if self.is_zero() { + *self = rhs.clone().neg(); + } else if let (Val(s), Val(o)) = (&mut *self, &rhs) { + *s -= o; + } else { + *self = TDim::Add(vec![std::mem::take(self), rhs.clone().neg()]).reduce() + } } } @@ -933,7 +961,16 @@ where I: Into, { fn sub_assign(&mut self, rhs: I) { - *self -= &rhs.into() + let rhs = rhs.into(); + if rhs.is_zero() { + () + } else if self.is_zero() { + *self = rhs.neg(); + } else if let (Val(s), Val(o)) = (&mut *self, &rhs) { + *s -= o; + } else { + *self = TDim::Add(vec![std::mem::take(self), rhs.neg()]).reduce() + } } } @@ -958,13 +995,26 @@ impl<'a> ops::Sub<&'a TDim> for TDim { impl> ops::MulAssign for TDim { fn mul_assign(&mut self, rhs: I) { - *self = TDim::Mul(vec![rhs.into(), std::mem::take(self)]).reduce() + let rhs = rhs.into(); + if self.is_one() { + *self = rhs + } else if rhs.is_one() { + () + } else { + *self = TDim::Mul(vec![rhs, std::mem::take(self)]).reduce() + } } } impl<'a> ops::MulAssign<&'a TDim> for TDim { fn mul_assign(&mut self, rhs: &'a TDim) { - *self = TDim::Mul(vec![std::mem::take(self), rhs.clone()]).reduce() + if self.is_one() { + *self = rhs.clone() + } else if rhs.is_one() { + () + } else { + *self = TDim::Mul(vec![std::mem::take(self), rhs.clone()]).reduce() + } } }