Skip to content

Commit

Permalink
BinOpX are slower ..
Browse files Browse the repository at this point in the history
  • Loading branch information
emricksinisonos committed Jul 19, 2024
1 parent d4fceff commit 3377de2
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 31 deletions.
11 changes: 10 additions & 1 deletion core/src/ops/binary.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use crate::internal::*;
use downcast_rs::Downcast;
use tract_itertools::Itertools;
use std::fmt;
use tract_data::itertools::izip;
use tract_itertools::Itertools;

use super::cast::cast;

Expand Down Expand Up @@ -397,6 +397,11 @@ impl Op for BinOpByScalar {
format!("{}ByScalar", self.0.name()).into()
}

fn same_as(&self, other: &dyn Op) -> bool {
let Some(other) = other.downcast_ref::<BinOpByScalar>() else { return false };
self.0.same_as(&*other.0)
}

op_as_typed_op!();
}

Expand Down Expand Up @@ -499,6 +504,10 @@ impl Op for BinOpUnicast {
format!("{}Unicast", self.0.name()).into()
}

fn same_as(&self, other: &dyn Op) -> bool {
let Some(other) = other.downcast_ref::<BinOpUnicast>() else { return false };
self.0.same_as(&*other.0)
}
op_as_typed_op!();
}

Expand Down
46 changes: 16 additions & 30 deletions core/src/ops/math/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,45 +82,31 @@ bin_to_super_type!(mul, Mul,
},
linalg: Mul,
uniform_in_place: |a: &Tensor, b: &mut Tensor| -> TractResult<bool> {
if b.datum_type() == f32::datum_type() {
let a = a.to_scalar::<f32>()?;
let slice = b.as_slice_mut::<f32>()?;
(tract_linalg::ops().mul_by_scalar_f32)().run_with_params(slice, *a)?;
Ok(true)
} else if b.datum_type() == f16::datum_type() {
let a = a.to_scalar::<f16>()?;
let slice = b.as_slice_mut::<f16>()?;
(tract_linalg::ops().mul_by_scalar_f16)().run_with_params(slice, *a)?;
Ok(true)
} else {
Ok(false)
}
let mut slice = b.view_mut();
let scalar = a.view();
let res = tract_linalg::bin_by_scalar(a.datum_type(), tract_linalg::BinOp::Mul)
.and_then(move |func| (func)(&mut slice, &scalar).ok())
.is_some();
Ok(res)
},
unicast_in_place: |a: &Tensor, b: &mut Tensor| -> TractResult<bool> {
if b.datum_type() == f32::datum_type() {
let a = a.as_slice::<f32>()?;
let slice = b.as_slice_mut::<f32>()?;
(tract_linalg::ops().unicast_mul_f32)().run(slice, a)?;
Ok(true)
} else if b.datum_type() == f16::datum_type() {
let a = a.as_slice::<f16>()?;
let slice = b.as_slice_mut::<f16>()?;
(tract_linalg::ops().unicast_mul_f16)().run(slice, a)?;
Ok(true)
} else {
Ok(false)
}
let mut slice = b.view_mut();
let other = a.view();
let res = tract_linalg::bin_unicast(a.datum_type(), tract_linalg::BinOp::Mul)
.and_then(move |func| (func)(&mut slice, &other).ok())
.is_some();
Ok(res)
},
eval_by_scalar: |a: &mut TensorView, b: &TensorView | -> TractResult<bool> {
let res = tract_linalg::bin_by_scalar(a.datum_type(), tract_linalg::BinOp::Mul)
.context("unimplemented mul by scalar")?(a, b)
.is_ok();
.and_then(move |func| (func)(a, b).ok())
.is_some();
Ok(res)
},
eval_unicast: |a: &mut TensorView, b: &TensorView | -> TractResult<bool> {
let res = tract_linalg::bin_unicast(a.datum_type(), tract_linalg::BinOp::Mul)
.context("unimplemented mul unicast")?(a, b)
.is_ok();
.and_then(move |func| (func)(a, b).ok())
.is_some();
Ok(res)
},
neutral_element: 1,
Expand Down

0 comments on commit 3377de2

Please sign in to comment.