Skip to content

Commit

Permalink
V0.9.6 performance (#67)
Browse files Browse the repository at this point in the history
* Added a bit of optimizations for Repeat

* There was a bug in prepDataSV and prepDataVS that led to unnecessary use of iterators.

* Changed the notion of an array's equality - it doesn't need to check
for whether the caps are the same.

* Added a check in reuse to Reshape if not correct, but the storage is correct

* Updated the file that generates array_getset.go

* Incorporated #68 which fixes some possibly unsafe slice issue
  • Loading branch information
chewxy authored Jun 1, 2020
1 parent fdbd1e5 commit 123d3a8
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 17 deletions.
10 changes: 5 additions & 5 deletions array_getset.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

77 changes: 71 additions & 6 deletions defaultengine_matop_misc.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,22 +137,87 @@ func (StdEng) denseRepeat(t, reuse DenseTensor, newShape Shape, axis, size int,
return d, nil
}

func (e StdEng) fastCopyDenseRepeat(t DenseTensor, d *Dense, outers, size, stride, newStride int, repeats []int) error {
func (e StdEng) fastCopyDenseRepeat(src DenseTensor, dest *Dense, outers, size, stride, newStride int, repeats []int) error {
sarr := src.arr()
darr := dest.arr()

var destStart, srcStart int
for i := 0; i < outers; i++ {
// faster shortcut for common case.
//
// Consider a case where:
// a := ⎡ 1 ⎤
// ⎢ 2 ⎥
// ⎢ 3 ⎥
// ⎣ 4 ⎦
// a has a shape of (4, 1). it is a *Dense.
//
// Now assume we want to repeat it on axis 1, 3 times. We want to repeat it into `b`,
// which is already allocated and zeroed, as shown below
//
// b := ⎡ 0 0 0 ⎤
// ⎢ 0 0 0 ⎥
// ⎢ 0 0 0 ⎥
// ⎣ 0 0 0 ⎦
//
// Now, both `a` and `b` have a stride of 1.
//
// The desired result is:
// b := ⎡ 1 1 1 ⎤
// ⎢ 2 2 2 ⎥
// ⎢ 3 3 3 ⎥
// ⎣ 4 4 4 ⎦
///
// Observe that this is simply broadcasting (copying) a[0] (a scalar value) to the row b[0], and so on and so forth.
// This can be done without knowing the full type - we simply copy the bytes over.
if stride == 1 && newStride == 1 {
for sz := 0; sz < size; sz++ {
tmp := repeats[sz]

// first we get the bounds of the src and the dest
// the srcStart and destStart are the indices assuming a flat array of []T
// we need to get the byte slice equivalent.
bSrcStart := srcStart * int(sarr.t.Size())
bSrcEnd := (srcStart + stride) * int(sarr.t.Size())
bDestStart := destStart * int(darr.t.Size())
bDestEnd := (destStart + tmp) * int(darr.t.Size())

// then we get the data as a slice of raw bytes
sBS := storage.AsByteSlice(&sarr.Header, sarr.t.Type)
dBS := storage.AsByteSlice(&darr.Header, darr.t.Type)

// recall that len(src) < len(dest)
// it's easier to understand if we define the ranges.
// Less prone to errors.
sRange := sBS[bSrcStart:bSrcEnd]
dRange := dBS[bDestStart:bDestEnd]

// finally we copy things.
for i := 0; i < len(dRange); i += len(sRange) {
copy(dRange[i:], sRange)
}
srcStart += stride
destStart += tmp
}

// we can straightaway broadcast

continue
}

for j := 0; j < size; j++ {
var tmp int
tmp = repeats[j]
var tSlice array2
tarr := t.arr()
tSlice = tarr.slice(srcStart, t.len())

tSlice = sarr.slice(srcStart, src.len())

for k := 0; k < tmp; k++ {
if srcStart >= t.len() || destStart+stride > d.len() {
if srcStart >= src.len() || destStart+stride > dest.len() {
break
}
arr := d.arr()
dSlice := arr.slice(destStart, d.len())

dSlice := darr.slice(destStart, destStart+newStride)

// THIS IS AN OPTIMIZATION. REVISIT WHEN NEEDED.
storage.Copy(dSlice.t.Type, &dSlice.Header, &tSlice.Header)
Expand Down
14 changes: 11 additions & 3 deletions defaultengine_prep.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@ func handleFuncOpts(expShape Shape, expType Dtype, o DataOrder, strict bool, opt
err = errors.Wrapf(err, "Cannot use reuse: shape mismatch - reuse.len() %v, expShape.TotalSize() %v", reuse.len(), expShape.TotalSize())
return
}
if !reuse.Shape().Eq(expShape) {
cloned := expShape.Clone()
if err = reuse.Reshape(cloned...); err != nil {
return

}
ReturnInts([]int(cloned))
}

if !incr && reuse != nil {
reuse.setDataOrder(o)
Expand Down Expand Up @@ -119,7 +127,6 @@ func prepDataVV(a, b Tensor, reuse Tensor) (dataA, dataB, dataReuse *storage.Hea
iit = reuse.Iterator()
}
}
// log.Printf("Use Itrer %v ", useIter)

// swap
if _, ok := a.(*CS); ok {
Expand All @@ -146,7 +153,7 @@ func prepDataVS(a Tensor, b interface{}, reuse Tensor) (dataA, dataB, dataReuse
}
useIter = a.RequiresIterator() ||
(reuse != nil && reuse.RequiresIterator()) ||
(reuse != nil && reuse.DataOrder().HasSameOrder(a.DataOrder()))
(reuse != nil && !reuse.DataOrder().HasSameOrder(a.DataOrder()))
if useIter {
ait = a.Iterator()
if reuse != nil {
Expand All @@ -170,7 +177,8 @@ func prepDataSV(a interface{}, b Tensor, reuse Tensor) (dataA, dataB, dataReuse
}
useIter = b.RequiresIterator() ||
(reuse != nil && reuse.RequiresIterator()) ||
(reuse != nil && reuse.DataOrder().HasSameOrder(b.DataOrder()))
(reuse != nil && !reuse.DataOrder().HasSameOrder(b.DataOrder()))

if useIter {
bit = b.Iterator()
if reuse != nil {
Expand Down
4 changes: 2 additions & 2 deletions dense_format.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ type fmtState struct {

meta bool
flat bool
ext bool
comp bool
ext bool // extended (i.e no elision)
comp bool // compact
c rune // c is here mainly for struct packing reasons

w, p int // width and precision
Expand Down
1 change: 1 addition & 0 deletions flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ func (f DataOrder) toggleColMajor() DataOrder { return f ^ (ColMajor) }

func (f DataOrder) clearTransposed() DataOrder { return f &^ (Transposed) }

// HasSameOrder returns true if both data orders are the same (either both are ColMajor or both are RowMajor)
func (f DataOrder) HasSameOrder(other DataOrder) bool {
return (f.IsColMajor() && other.IsColMajor()) || (f.IsRowMajor() && other.IsRowMajor())
}
Expand Down
3 changes: 2 additions & 1 deletion genlib2/array_getset.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,11 @@ func (a array) Eq(other interface{}) bool {
if oa.L != a.L {
return false
}
/*
if oa.C != a.C {
return false
}
*/
// same exact thing
if uintptr(oa.Ptr) == uintptr(a.Ptr){
Expand Down

0 comments on commit 123d3a8

Please sign in to comment.