Skip to content

Commit

Permalink
Fixed Engine stuff. Also Go 1.14 fixed (#61)
Browse files Browse the repository at this point in the history
* Fixed Engine stuff. Also Go 1.14 fixed

* Fixed Travis (Gonum only supports the two most recent versions of Go)
  • Loading branch information
chewxy authored Mar 1, 2020
1 parent ac10895 commit 6848ca2
Show file tree
Hide file tree
Showing 10 changed files with 44 additions and 31 deletions.
4 changes: 1 addition & 3 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@ branches:
only:
- master
go:
- 1.10.x
- 1.11.x
- 1.12.x
- 1.13.x
- 1.14.x
- tip

env:
Expand Down
1 change: 1 addition & 0 deletions consopt.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ func FromScalar(x interface{}, argMask ...[]bool) ConsOpt {
// Memory must be manually managed by the caller.
// Tensors called with this construction option will not be returned to any pool - rather, all references to the pointers will be null'd.
// Use with caution.
//go:nocheckptr
func FromMemory(ptr uintptr, memsize uintptr) ConsOpt {
f := func(t Tensor) {
switch tt := t.(type) {
Expand Down
2 changes: 1 addition & 1 deletion defaultengine_linalg.go
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ func (e StdEng) SVD(a Tensor, uv, full bool) (s, u, v Tensor, err error) {

// extract values
var um, vm mat.Dense
s = recycledDense(Float64, Shape{MinInt(t.Shape()[0], t.Shape()[1])})
s = recycledDense(Float64, Shape{MinInt(t.Shape()[0], t.Shape()[1])}, WithEngine(e))
svd.Values(s.Data().([]float64))
if uv {
svd.UTo(&um)
Expand Down
4 changes: 2 additions & 2 deletions defaultengine_matop_misc.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func (StdEng) denseRepeat(t DenseTensor, axis int, repeats []int) (retVal DenseT
axis = 0
}

d := recycledDense(t.Dtype(), newShape)
d := recycledDense(t.Dtype(), newShape, WithEngine(StdEng{}))

var outers int
if t.IsScalar() {
Expand Down Expand Up @@ -155,7 +155,7 @@ func (e StdEng) denseConcat(a DenseTensor, axis int, Ts []DenseTensor) (DenseTen
return nil, errors.Wrap(err, "Unable to find new shape that results from concatenation")
}

retVal := recycledDense(a.Dtype(), newShape)
retVal := recycledDense(a.Dtype(), newShape, WithEngine(e))
if isMasked {
retVal.makeMask()
}
Expand Down
6 changes: 3 additions & 3 deletions dense_linalg.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func (t *Dense) MatVecMul(other Tensor, opts ...FuncOpt) (retVal *Dense, err err
}

if retVal == nil {
retVal = recycledDense(t.t, expectedShape)
retVal = recycledDense(t.t, expectedShape, WithEngine(t.e))
if t.o.IsColMajor() {
AsFortran(nil)(retVal)
}
Expand Down Expand Up @@ -137,7 +137,7 @@ func (t *Dense) MatMul(other Tensor, opts ...FuncOpt) (retVal *Dense, err error)
}

if retVal == nil {
retVal = recycledDense(t.t, expectedShape)
retVal = recycledDense(t.t, expectedShape, WithEngine(t.e))
if t.o.IsColMajor() {
AsFortran(nil)(retVal)
}
Expand Down Expand Up @@ -176,7 +176,7 @@ func (t *Dense) Outer(other Tensor, opts ...FuncOpt) (retVal *Dense, err error)
}

if retVal == nil {
retVal = recycledDense(t.t, expectedShape)
retVal = recycledDense(t.t, expectedShape, WithEngine(t.e))
if t.o.IsColMajor() {
AsFortran(nil)(retVal)
}
Expand Down
2 changes: 1 addition & 1 deletion dense_matop.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func (t *Dense) SafeT(axes ...int) (retVal *Dense, err error) {
}
}

retVal = recycledDense(t.t, Shape{t.len()})
retVal = recycledDense(t.t, Shape{t.len()}, WithEngine(t.e))
copyDense(retVal, t)

retVal.e = t.e
Expand Down
19 changes: 0 additions & 19 deletions dense_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"testing"
"testing/quick"
"time"
"unsafe"

"github.com/stretchr/testify/assert"
)
Expand Down Expand Up @@ -90,24 +89,6 @@ func TestFromScalar(t *testing.T) {
assert.Equal(t, []float64{3.14}, data)
}

func TestFromMemory(t *testing.T) {
// dummy memory - this could be an externally malloc'd memory, or a mmap'ed file.
// but here we're just gonna let Go manage memory.
s := make([]float64, 100)
ptr := uintptr(unsafe.Pointer(&s[0]))
size := uintptr(100 * 8)

T := New(Of(Float32), WithShape(50, 4), FromMemory(ptr, size))
if len(T.Float32s()) != 200 {
t.Error("expected 200 Float32s")
}
assert.Equal(t, make([]float32, 200), T.Data())
assert.True(t, T.IsManuallyManaged(), "Unamanged %v |%v | q: %v", ManuallyManaged, T.flag, (T.flag>>ManuallyManaged)&MemoryFlag(1))

fail := func() { New(FromMemory(ptr, size), Of(Float32)) }
assert.Panics(t, fail, "Expected bad New() call to panic")
}

func Test_recycledDense(t *testing.T) {
T := recycledDense(Float64, ScalarShape())
assert.Equal(t, float64(0), T.Data())
Expand Down
2 changes: 1 addition & 1 deletion dense_views.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ func (t *Dense) Materialize() Tensor {
return t
}

retVal := recycledDense(t.t, t.shape.Clone())
retVal := recycledDense(t.t, t.shape.Clone(), WithEngine(t.e))
copyDenseIter(retVal, t, nil, nil)
retVal.e = t.e
retVal.oe = t.oe
Expand Down
33 changes: 33 additions & 0 deletions known_race_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// +build !race

package tensor

import (
"testing"
"unsafe"

"github.com/stretchr/testify/assert"
)

// This test will fail the `go test -race`.
//
// This is because FromMemory() will use uintptr in a way that is incorrect according to the checkptr directive of Go 1.14+
//
// Though it's incorrect, it's the only way to use heterogenous, readable memory (i.e. CUDA).
func TestFromMemory(t *testing.T) {
// dummy memory - this could be an externally malloc'd memory, or a mmap'ed file.
// but here we're just gonna let Go manage memory.
s := make([]float64, 100)
ptr := uintptr(unsafe.Pointer(&s[0]))
size := uintptr(100 * 8)

T := New(Of(Float32), WithShape(50, 4), FromMemory(ptr, size))
if len(T.Float32s()) != 200 {
t.Error("expected 200 Float32s")
}
assert.Equal(t, make([]float32, 200), T.Data())
assert.True(t, T.IsManuallyManaged(), "Unamanged %v |%v | q: %v", ManuallyManaged, T.flag, (T.flag>>ManuallyManaged)&MemoryFlag(1))

fail := func() { New(FromMemory(ptr, size), Of(Float32)) }
assert.Panics(t, fail, "Expected bad New() call to panic")
}
2 changes: 1 addition & 1 deletion sparse.go
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ func (t *CS) Dense() *Dense {
// use
}

d := recycledDense(t.t, t.Shape().Clone())
d := recycledDense(t.t, t.Shape().Clone(), WithEngine(t.e))
if t.o.IsColMajor() {
for i := 0; i < len(t.indptr)-1; i++ {
for j := t.indptr[i]; j < t.indptr[i+1]; j++ {
Expand Down

0 comments on commit 6848ca2

Please sign in to comment.