Skip to content

Commit

Permalink
handleReuse: add safe flag to skip expensive call to BorrowInt (#107)
Browse files Browse the repository at this point in the history
* handleReuse: add unsafe flag to skip expensive call to BorrowInt

* handleReuse: add safe flag to skip expensive call to BorrowInt
  • Loading branch information
khezen authored Mar 11, 2021
1 parent d5ff158 commit e3b127e
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions dense_linalg.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func (t *Dense) MatVecMul(other Tensor, opts ...FuncOpt) (retVal *Dense, err err
// check whether retVal has the same size as the resulting matrix would be: mx1
fo := ParseFuncOpts(opts...)
defer returnOpOpt(fo)
if retVal, err = handleReuse(fo.Reuse(), expectedShape); err != nil {
if retVal, err = handleReuse(fo.Reuse(), expectedShape, fo.Safe()); err != nil {
err = errors.Wrapf(err, opFail, "MatVecMul")
return
}
Expand Down Expand Up @@ -131,7 +131,7 @@ func (t *Dense) MatMul(other Tensor, opts ...FuncOpt) (retVal *Dense, err error)

fo := ParseFuncOpts(opts...)
defer returnOpOpt(fo)
if retVal, err = handleReuse(fo.Reuse(), expectedShape); err != nil {
if retVal, err = handleReuse(fo.Reuse(), expectedShape, fo.Safe()); err != nil {
err = errors.Wrapf(err, opFail, "MatMul")
return
}
Expand Down Expand Up @@ -170,7 +170,7 @@ func (t *Dense) Outer(other Tensor, opts ...FuncOpt) (retVal *Dense, err error)

fo := ParseFuncOpts(opts...)
defer returnOpOpt(fo)
if retVal, err = handleReuse(fo.Reuse(), expectedShape); err != nil {
if retVal, err = handleReuse(fo.Reuse(), expectedShape, fo.Safe()); err != nil {
err = errors.Wrapf(err, opFail, "Outer")
return
}
Expand Down Expand Up @@ -380,13 +380,15 @@ func (t *Dense) SVD(uv, full bool) (s, u, v *Dense, err error) {
/* UTILITY FUNCTIONS */

// handleReuse extracts a *Dense from Tensor, and checks the shape of the reuse Tensor
func handleReuse(reuse Tensor, expectedShape Shape) (retVal *Dense, err error) {
func handleReuse(reuse Tensor, expectedShape Shape, safe bool) (retVal *Dense, err error) {
if reuse != nil {
if retVal, err = assertDense(reuse); err != nil {
err = errors.Wrapf(err, opFail, "handling reuse")
return
}

if !safe {
return
}
if err = reuseCheckShape(retVal, expectedShape); err != nil {
err = errors.Wrapf(err, "Unable to process reuse *Dense Tensor. Shape error.")
return
Expand Down

0 comments on commit e3b127e

Please sign in to comment.