Skip to content

Commit

Permalink
Bezineb5 fix interface scalar (#73)
Browse files Browse the repository at this point in the history
* Fixed an issue with the leftTensor parameter leading to a bug for scalars.

* OK this works

* Fixes #70 and #72.
Though this patch is quite at the surface. I haven't really got the time to dig in why the behaviour is as such, given that I'm feeling quite ill atm. I will come back and fix it if need be in the future

Co-authored-by: Benjamin <>
Co-authored-by: wzzhu <>
  • Loading branch information
chewxy authored Jun 30, 2020
1 parent 123d3a8 commit 2452c8b
Show file tree
Hide file tree
Showing 7 changed files with 330 additions and 32 deletions.
124 changes: 124 additions & 0 deletions api_arith_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,24 @@ func TestMulScalarScalar(t *testing.T) {
t.Fatalf("Error: %v", err)
}
assert.Equal(t, correct, res.Data())

// Interface - tensor
ai := 2.0
b = NewDense(Float64, Shape{1, 1}, WithBacking([]float64{3}))
correct = []float64{6.0}

res, err = Mul(ai, b)
if err != nil {
t.Fatalf("Error: %v", err)
}
assert.Equal(t, correct, res.Data())

// Commutativity
res, err = Mul(b, ai)
if err != nil {
t.Fatalf("Error: %v", err)
}
assert.Equal(t, correct, res.Data())
}

func TestDivScalarScalar(t *testing.T) {
Expand Down Expand Up @@ -253,6 +271,28 @@ func TestDivScalarScalar(t *testing.T) {
t.Fatalf("Error: %v", err)
}
assert.Equal(t, correct, res.Data())

// interface-scalar
ai := 6.0
b = New(WithBacking([]float64{2}))
correct = 3.0

res, err = Div(ai, b)
if err != nil {
t.Fatalf("Error: %v", err)
}
assert.Equal(t, correct, res.Data())

// scalar-interface
a = New(WithBacking([]float64{6}))
bi := 2.0
correct = 3.0

res, err = Div(a, bi)
if err != nil {
t.Fatalf("Error: %v", err)
}
assert.Equal(t, correct, res.Data())
}

func TestAddScalarScalar(t *testing.T) {
Expand Down Expand Up @@ -309,6 +349,24 @@ func TestAddScalarScalar(t *testing.T) {
t.Fatalf("Error: %v", err)
}
assert.Equal(t, correct, res.Data())

// interface-scalar
ai := 2.0
b = New(WithBacking([]float64{3}))
correct = 5.0

res, err = Add(ai, b)
if err != nil {
t.Fatalf("Error: %v", err)
}
assert.Equal(t, correct, res.Data())

// Test commutativity
res, err = Add(b, ai)
if err != nil {
t.Fatalf("Error: %v", err)
}
assert.Equal(t, correct, res.Data())
}

func TestSubScalarScalar(t *testing.T) {
Expand Down Expand Up @@ -355,6 +413,28 @@ func TestSubScalarScalar(t *testing.T) {
t.Fatalf("Error: %v", err)
}
assert.Equal(t, correct, res.Data())

// interface-scalar
ai := 6.0
b = New(WithBacking([]float64{2}))
correct = 4.0

res, err = Sub(ai, b)
if err != nil {
t.Fatalf("Error: %v", err)
}
assert.Equal(t, correct, res.Data())

// scalar-interface
a = New(WithBacking([]float64{6}))
bi := 2.0
correct = 4.0

res, err = Sub(a, bi)
if err != nil {
t.Fatalf("Error: %v", err)
}
assert.Equal(t, correct, res.Data())
}

func TestModScalarScalar(t *testing.T) {
Expand Down Expand Up @@ -401,6 +481,28 @@ func TestModScalarScalar(t *testing.T) {
t.Fatalf("Error: %v", err)
}
assert.Equal(t, correct, res.Data())

// interface-scalar
ai := 5.0
b = New(WithBacking([]float64{2}))
correct = 1.0

res, err = Mod(ai, b)
if err != nil {
t.Fatalf("Error: %v", err)
}
assert.Equal(t, correct, res.Data())

// scalar-interface
a = New(WithBacking([]float64{5}))
bi := 2.0
correct = 1.0

res, err = Mod(a, bi)
if err != nil {
t.Fatalf("Error: %v", err)
}
assert.Equal(t, correct, res.Data())
}

func TestPowScalarScalar(t *testing.T) {
Expand Down Expand Up @@ -447,4 +549,26 @@ func TestPowScalarScalar(t *testing.T) {
t.Fatalf("Error: %v", err)
}
assert.Equal(t, correct, res.Data())

// interface-scalar
ai := 6.0
b = New(WithBacking([]float64{2}))
correct = 36.0

res, err = Pow(ai, b)
if err != nil {
t.Fatalf("Error: %v", err)
}
assert.Equal(t, correct, res.Data())

// scalar-interface
a = New(WithBacking([]float64{6}))
bi := 2.0
correct = 36.0

res, err = Pow(a, bi)
if err != nil {
t.Fatalf("Error: %v", err)
}
assert.Equal(t, correct, res.Data())
}
1 change: 1 addition & 0 deletions array_getset.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

78 changes: 54 additions & 24 deletions defaultengine_arith.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 2452c8b

Please sign in to comment.