Skip to content

Commit

Permalink
Merge pull request #18 from gorgonia/v0.8.0-working
Browse files Browse the repository at this point in the history
Committing changes made by @siquus in gorgornia : siquus/gorgonia@403…
  • Loading branch information
chewxy authored Dec 17, 2017
2 parents 287edfb + fcc24d4 commit 863b768
Showing 1 changed file with 41 additions and 15 deletions.
56 changes: 41 additions & 15 deletions api_arith.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,21 +134,45 @@ func Mul(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) {
oe = at.standardEngine()
switch bt := b.(type) {
case Tensor:
if oe != nil {
return oe.Mul(at, bt, opts...)
}
if oe = bt.standardEngine(); oe != nil {
return oe.Mul(at, bt, opts...)
}
if muler, ok = at.Engine().(Muler); ok {
return muler.Mul(at, bt, opts...)
}
if muler, ok = bt.Engine().(Muler); ok {
return muler.Mul(at, bt, opts...)
if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor multiplication
if oe != nil {
return oe.Mul(at, bt, opts...)
}
if oe = bt.standardEngine(); oe != nil {
return oe.Mul(at, bt, opts...)
}
if muler, ok = at.Engine().(Muler); ok {
return muler.Mul(at, bt, opts...)
}
if muler, ok = bt.Engine().(Muler); ok {
return muler.Mul(at, bt, opts...)
}
return nil, errors.New("Neither engines of either operand support Mul")

} else { // one of the operands is a scalar
var leftTensor bool
if at.Shape().IsScalar() {
leftTensor = false // a Scalar-Tensor * b Tensor
} else {
leftTensor = true // a Tensor * b Scalar-Tensor
}

if oe != nil {
return oe.MulScalar(at, bt, leftTensor, opts...)
}
if oe = bt.standardEngine(); oe != nil {
return oe.MulScalar(at, bt, leftTensor, opts...)
}
if muler, ok = at.Engine().(Muler); ok {
return muler.MulScalar(at, bt, leftTensor, opts...)
}
if muler, ok = bt.Engine().(Muler); ok {
return muler.MulScalar(at, bt, leftTensor, opts...)
}
return nil, errors.New("Neither engines of either operand support Mul")
}
return nil, errors.New("Neither engines of either operand support Mul")

default:
default: // a Tensor * b interface
if oe != nil {
return oe.MulScalar(at, bt, true, opts...)
}
Expand All @@ -157,17 +181,19 @@ func Mul(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) {
}
return nil, errors.New("Operand A's engine does not support Mul")
}

default:
switch bt := b.(type) {
case Tensor:
case Tensor: // b Tensor * a interface
if oe = bt.standardEngine(); oe != nil {
return oe.MulScalar(bt, at, false, opts...)
}
if muler, ok = bt.Engine().(Muler); ok {
return muler.MulScalar(bt, at, false, opts...)
}
return nil, errors.New("Operand B's engine does not support Mul")
default:

default: // b interface * a interface
return nil, errors.Errorf("Cannot perform Mul of %T and %T", a, b)
}
}
Expand Down

0 comments on commit 863b768

Please sign in to comment.