diff --git a/assert/assertions.go b/assert/assertions.go index c540bb02d..97f44529c 100644 --- a/assert/assertions.go +++ b/assert/assertions.go @@ -17,8 +17,8 @@ import ( "unicode" "unicode/utf8" - "github.com/davecgh/go-spew/spew" "github.com/pmezard/go-difflib/difflib" + "github.com/stretchr/testify/internal/check" yaml "gopkg.in/yaml.v3" ) @@ -52,49 +52,6 @@ type Comparison func() (success bool) Helper functions */ -// ObjectsAreEqual determines if two objects are considered equal. -// -// This function does no assertion of any kind. -func ObjectsAreEqual(expected, actual interface{}) bool { - if expected == nil || actual == nil { - return expected == actual - } - - exp, ok := expected.([]byte) - if !ok { - return reflect.DeepEqual(expected, actual) - } - - act, ok := actual.([]byte) - if !ok { - return false - } - if exp == nil || act == nil { - return exp == nil && act == nil - } - return bytes.Equal(exp, act) -} - -// ObjectsAreEqualValues gets whether two objects are equal, or if their -// values are equal. -func ObjectsAreEqualValues(expected, actual interface{}) bool { - if ObjectsAreEqual(expected, actual) { - return true - } - - actualType := reflect.TypeOf(actual) - if actualType == nil { - return false - } - expectedValue := reflect.ValueOf(expected) - if expectedValue.IsValid() && expectedValue.Type().ConvertibleTo(actualType) { - // Attempt comparison after type conversion - return reflect.DeepEqual(expectedValue.Convert(actualType).Interface(), actual) - } - - return false -} - /* CallerInfo is necessary because the assert functions use the testing object internally, causing it to print the file:line of the assert method, rather than where the problem actually occurred in calling code.*/ @@ -317,7 +274,7 @@ func IsType(t TestingT, expectedType interface{}, object interface{}, msgAndArgs h.Helper() } - if !ObjectsAreEqual(reflect.TypeOf(object), reflect.TypeOf(expectedType)) { + if !check.ObjectsAreEqual(reflect.TypeOf(object), reflect.TypeOf(expectedType)) { return Fail(t, fmt.Sprintf("Object expected to be of type %v, but was %v", reflect.TypeOf(expectedType), reflect.TypeOf(object)), msgAndArgs...) } @@ -340,7 +297,7 @@ func Equal(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) expected, actual, err), msgAndArgs...) } - if !ObjectsAreEqual(expected, actual) { + if !check.ObjectsAreEqual(expected, actual) { diff := diff(expected, actual) expected, actual = formatUnequalValues(expected, actual) return Fail(t, fmt.Sprintf("Not equal: \n"+ @@ -461,7 +418,7 @@ func EqualValues(t TestingT, expected, actual interface{}, msgAndArgs ...interfa h.Helper() } - if !ObjectsAreEqualValues(expected, actual) { + if !check.ObjectsAreEqualValues(expected, actual) { diff := diff(expected, actual) expected, actual = formatUnequalValues(expected, actual) return Fail(t, fmt.Sprintf("Not equal: \n"+ @@ -551,41 +508,12 @@ func Nil(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { return Fail(t, fmt.Sprintf("Expected nil, but got: %#v", object), msgAndArgs...) } -// isEmpty gets whether the specified object is considered empty or not. -func isEmpty(object interface{}) bool { - - // get nil case out of the way - if object == nil { - return true - } - - objValue := reflect.ValueOf(object) - - switch objValue.Kind() { - // collection types are empty when they have no element - case reflect.Chan, reflect.Map, reflect.Slice: - return objValue.Len() == 0 - // pointers are empty if nil or if the value they point to is empty - case reflect.Ptr: - if objValue.IsNil() { - return true - } - deref := objValue.Elem().Interface() - return isEmpty(deref) - // for all other types, compare against the zero value - // array types are empty when they match their zero-initialized state - default: - zero := reflect.Zero(objValue.Type()) - return reflect.DeepEqual(object, zero.Interface()) - } -} - // Empty asserts that the specified object is empty. I.e. nil, "", false, 0 or either // a slice or a channel with len == 0. // // assert.Empty(t, obj) func Empty(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { - pass := isEmpty(object) + pass := check.IsEmpty(object) if !pass { if h, ok := t.(tHelper); ok { h.Helper() @@ -604,7 +532,7 @@ func Empty(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { // assert.Equal(t, "two", obj[1]) // } func NotEmpty(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { - pass := !isEmpty(object) + pass := !check.IsEmpty(object) if !pass { if h, ok := t.(tHelper); ok { h.Helper() @@ -692,7 +620,7 @@ func NotEqual(t TestingT, expected, actual interface{}, msgAndArgs ...interface{ expected, actual, err), msgAndArgs...) } - if ObjectsAreEqual(expected, actual) { + if check.ObjectsAreEqual(expected, actual) { return Fail(t, fmt.Sprintf("Should not be: %#v\n", actual), msgAndArgs...) } @@ -708,7 +636,7 @@ func NotEqualValues(t TestingT, expected, actual interface{}, msgAndArgs ...inte h.Helper() } - if ObjectsAreEqualValues(expected, actual) { + if check.ObjectsAreEqualValues(expected, actual) { return Fail(t, fmt.Sprintf("Should not be: %#v\n", actual), msgAndArgs...) } @@ -742,7 +670,7 @@ func containsElement(list interface{}, element interface{}) (ok, found bool) { if listKind == reflect.Map { mapKeys := listValue.MapKeys() for i := 0; i < len(mapKeys); i++ { - if ObjectsAreEqual(mapKeys[i].Interface(), element) { + if check.ObjectsAreEqual(mapKeys[i].Interface(), element) { return true, true } } @@ -750,7 +678,7 @@ func containsElement(list interface{}, element interface{}) (ok, found bool) { } for i := 0; i < listValue.Len(); i++ { - if ObjectsAreEqual(listValue.Index(i).Interface(), element) { + if check.ObjectsAreEqual(listValue.Index(i).Interface(), element) { return true, true } } @@ -843,7 +771,7 @@ func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok subsetElement := subsetValue.MapIndex(subsetKey).Interface() listElement := listValue.MapIndex(subsetKey).Interface() - if !ObjectsAreEqual(subsetElement, listElement) { + if !check.ObjectsAreEqual(subsetElement, listElement) { return Fail(t, fmt.Sprintf("\"%s\" does not contain \"%s\"", list, subsetElement), msgAndArgs...) } } @@ -904,7 +832,7 @@ func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) subsetElement := subsetValue.MapIndex(subsetKey).Interface() listElement := listValue.MapIndex(subsetKey).Interface() - if !ObjectsAreEqual(subsetElement, listElement) { + if !check.ObjectsAreEqual(subsetElement, listElement) { return true } } @@ -935,93 +863,14 @@ func ElementsMatch(t TestingT, listA, listB interface{}, msgAndArgs ...interface if h, ok := t.(tHelper); ok { h.Helper() } - if isEmpty(listA) && isEmpty(listB) { - return true - } - - if !isList(t, listA, msgAndArgs...) || !isList(t, listB, msgAndArgs...) { - return false - } - extraA, extraB := diffLists(listA, listB) - - if len(extraA) == 0 && len(extraB) == 0 { - return true + if err := check.ElementsMatch(listA, listB); err != nil { + return Fail(t, err.Error(), msgAndArgs...) } - return Fail(t, formatListDiff(listA, listB, extraA, extraB), msgAndArgs...) -} - -// isList checks that the provided value is array or slice. -func isList(t TestingT, list interface{}, msgAndArgs ...interface{}) (ok bool) { - kind := reflect.TypeOf(list).Kind() - if kind != reflect.Array && kind != reflect.Slice { - return Fail(t, fmt.Sprintf("%q has an unsupported type %s, expecting array or slice", list, kind), - msgAndArgs...) - } return true } -// diffLists diffs two arrays/slices and returns slices of elements that are only in A and only in B. -// If some element is present multiple times, each instance is counted separately (e.g. if something is 2x in A and -// 5x in B, it will be 0x in extraA and 3x in extraB). The order of items in both lists is ignored. -func diffLists(listA, listB interface{}) (extraA, extraB []interface{}) { - aValue := reflect.ValueOf(listA) - bValue := reflect.ValueOf(listB) - - aLen := aValue.Len() - bLen := bValue.Len() - - // Mark indexes in bValue that we already used - visited := make([]bool, bLen) - for i := 0; i < aLen; i++ { - element := aValue.Index(i).Interface() - found := false - for j := 0; j < bLen; j++ { - if visited[j] { - continue - } - if ObjectsAreEqual(bValue.Index(j).Interface(), element) { - visited[j] = true - found = true - break - } - } - if !found { - extraA = append(extraA, element) - } - } - - for j := 0; j < bLen; j++ { - if visited[j] { - continue - } - extraB = append(extraB, bValue.Index(j).Interface()) - } - - return -} - -func formatListDiff(listA, listB interface{}, extraA, extraB []interface{}) string { - var msg bytes.Buffer - - msg.WriteString("elements differ") - if len(extraA) > 0 { - msg.WriteString("\n\nextra elements in list A:\n") - msg.WriteString(spewConfig.Sdump(extraA)) - } - if len(extraB) > 0 { - msg.WriteString("\n\nextra elements in list B:\n") - msg.WriteString(spewConfig.Sdump(extraB)) - } - msg.WriteString("\n\nlistA:\n") - msg.WriteString(spewConfig.Sdump(listA)) - msg.WriteString("\n\nlistB:\n") - msg.WriteString(spewConfig.Sdump(listB)) - - return msg.String() -} - // Condition uses a Comparison to assert a complex condition. func Condition(t TestingT, comp Comparison, msgAndArgs ...interface{}) bool { if h, ok := t.(tHelper); ok { @@ -1673,11 +1522,11 @@ func diff(expected interface{}, actual interface{}) string { e = reflect.ValueOf(expected).String() a = reflect.ValueOf(actual).String() case reflect.TypeOf(time.Time{}): - e = spewConfigStringerEnabled.Sdump(expected) - a = spewConfigStringerEnabled.Sdump(actual) + e = check.SpewConfigStringerEnabled.Sdump(expected) + a = check.SpewConfigStringerEnabled.Sdump(actual) default: - e = spewConfig.Sdump(expected) - a = spewConfig.Sdump(actual) + e = check.SpewConfig.Sdump(expected) + a = check.SpewConfig.Sdump(actual) } diff, _ := difflib.GetUnifiedDiffString(difflib.UnifiedDiff{ @@ -1700,23 +1549,6 @@ func isFunction(arg interface{}) bool { return reflect.TypeOf(arg).Kind() == reflect.Func } -var spewConfig = spew.ConfigState{ - Indent: " ", - DisablePointerAddresses: true, - DisableCapacities: true, - SortKeys: true, - DisableMethods: true, - MaxDepth: 10, -} - -var spewConfigStringerEnabled = spew.ConfigState{ - Indent: " ", - DisablePointerAddresses: true, - DisableCapacities: true, - SortKeys: true, - MaxDepth: 10, -} - type tHelper interface { Helper() } diff --git a/assert/assertions_test.go b/assert/assertions_test.go index b5193a582..c1e28accb 100644 --- a/assert/assertions_test.go +++ b/assert/assertions_test.go @@ -16,6 +16,8 @@ import ( "testing" "time" "unsafe" + + "github.com/stretchr/testify/internal/check" ) var ( @@ -101,54 +103,6 @@ func (a *AssertionTesterConformingObject) TestMethod() { type AssertionTesterNonConformingObject struct { } -func TestObjectsAreEqual(t *testing.T) { - cases := []struct { - expected interface{} - actual interface{} - result bool - }{ - // cases that are expected to be equal - {"Hello World", "Hello World", true}, - {123, 123, true}, - {123.5, 123.5, true}, - {[]byte("Hello World"), []byte("Hello World"), true}, - {nil, nil, true}, - - // cases that are expected not to be equal - {map[int]int{5: 10}, map[int]int{10: 20}, false}, - {'x', "x", false}, - {"x", 'x', false}, - {0, 0.1, false}, - {0.1, 0, false}, - {time.Now, time.Now, false}, - {func() {}, func() {}, false}, - {uint32(10), int32(10), false}, - } - - for _, c := range cases { - t.Run(fmt.Sprintf("ObjectsAreEqual(%#v, %#v)", c.expected, c.actual), func(t *testing.T) { - res := ObjectsAreEqual(c.expected, c.actual) - - if res != c.result { - t.Errorf("ObjectsAreEqual(%#v, %#v) should return %#v", c.expected, c.actual, c.result) - } - - }) - } - - // Cases where type differ but values are equal - if !ObjectsAreEqualValues(uint32(10), int32(10)) { - t.Error("ObjectsAreEqualValues should return true") - } - if ObjectsAreEqualValues(0, nil) { - t.Fail() - } - if ObjectsAreEqualValues(nil, 0) { - t.Fail() - } - -} - func TestImplements(t *testing.T) { mockT := new(testing.T) @@ -841,90 +795,6 @@ func TestElementsMatch(t *testing.T) { } } -func TestDiffLists(t *testing.T) { - tests := []struct { - name string - listA interface{} - listB interface{} - extraA []interface{} - extraB []interface{} - }{ - { - name: "equal empty", - listA: []string{}, - listB: []string{}, - extraA: nil, - extraB: nil, - }, - { - name: "equal same order", - listA: []string{"hello", "world"}, - listB: []string{"hello", "world"}, - extraA: nil, - extraB: nil, - }, - { - name: "equal different order", - listA: []string{"hello", "world"}, - listB: []string{"world", "hello"}, - extraA: nil, - extraB: nil, - }, - { - name: "extra A", - listA: []string{"hello", "hello", "world"}, - listB: []string{"hello", "world"}, - extraA: []interface{}{"hello"}, - extraB: nil, - }, - { - name: "extra A twice", - listA: []string{"hello", "hello", "hello", "world"}, - listB: []string{"hello", "world"}, - extraA: []interface{}{"hello", "hello"}, - extraB: nil, - }, - { - name: "extra B", - listA: []string{"hello", "world"}, - listB: []string{"hello", "hello", "world"}, - extraA: nil, - extraB: []interface{}{"hello"}, - }, - { - name: "extra B twice", - listA: []string{"hello", "world"}, - listB: []string{"hello", "hello", "world", "hello"}, - extraA: nil, - extraB: []interface{}{"hello", "hello"}, - }, - { - name: "integers 1", - listA: []int{1, 2, 3, 4, 5}, - listB: []int{5, 4, 3, 2, 1}, - extraA: nil, - extraB: nil, - }, - { - name: "integers 2", - listA: []int{1, 2, 1, 2, 1}, - listB: []int{2, 1, 2, 1, 2}, - extraA: []interface{}{1}, - extraB: []interface{}{2}, - }, - } - for _, test := range tests { - test := test - t.Run(test.name, func(t *testing.T) { - actualExtraA, actualExtraB := diffLists(test.listA, test.listB) - Equal(t, test.extraA, actualExtraA, "extra A does not match for listA=%v listB=%v", - test.listA, test.listB) - Equal(t, test.extraB, actualExtraB, "extra B does not match for listA=%v listB=%v", - test.listA, test.listB) - }) - } -} - func TestCondition(t *testing.T) { mockT := new(testing.T) @@ -1146,33 +1016,6 @@ func TestErrorContains(t *testing.T) { "ErrorContains should return true") } -func Test_isEmpty(t *testing.T) { - - chWithValue := make(chan struct{}, 1) - chWithValue <- struct{}{} - - True(t, isEmpty("")) - True(t, isEmpty(nil)) - True(t, isEmpty([]string{})) - True(t, isEmpty(0)) - True(t, isEmpty(int32(0))) - True(t, isEmpty(int64(0))) - True(t, isEmpty(false)) - True(t, isEmpty(map[string]string{})) - True(t, isEmpty(new(time.Time))) - True(t, isEmpty(time.Time{})) - True(t, isEmpty(make(chan struct{}))) - True(t, isEmpty([1]int{})) - False(t, isEmpty("something")) - False(t, isEmpty(errors.New("something"))) - False(t, isEmpty([]string{"something"})) - False(t, isEmpty(1)) - False(t, isEmpty(true)) - False(t, isEmpty(map[string]string{"Hello": "World"})) - False(t, isEmpty(chWithValue)) - False(t, isEmpty([1]int{42})) -} - func TestEmpty(t *testing.T) { mockT := new(testing.T) @@ -2184,7 +2027,7 @@ func TestBytesEqual(t *testing.T) { {nil, make([]byte, 0)}, } for i, c := range cases { - Equal(t, reflect.DeepEqual(c.a, c.b), ObjectsAreEqual(c.a, c.b), "case %d failed", i+1) + Equal(t, reflect.DeepEqual(c.a, c.b), check.ObjectsAreEqual(c.a, c.b), "case %d failed", i+1) } } diff --git a/internal/check/checks.go b/internal/check/checks.go new file mode 100644 index 000000000..a4131dea0 --- /dev/null +++ b/internal/check/checks.go @@ -0,0 +1,192 @@ +package check + +import ( + "bytes" + "errors" + "fmt" + "reflect" + + "github.com/davecgh/go-spew/spew" +) + +var SpewConfig = spew.ConfigState{ + Indent: " ", + DisablePointerAddresses: true, + DisableCapacities: true, + SortKeys: true, + DisableMethods: true, + MaxDepth: 10, +} + +var SpewConfigStringerEnabled = spew.ConfigState{ + Indent: " ", + DisablePointerAddresses: true, + DisableCapacities: true, + SortKeys: true, + MaxDepth: 10, +} + +// IsEmpty gets whether the specified object is considered empty or not. +func IsEmpty(object interface{}) bool { + + // get nil case out of the way + if object == nil { + return true + } + + objValue := reflect.ValueOf(object) + + switch objValue.Kind() { + // collection types are empty when they have no element + case reflect.Chan, reflect.Map, reflect.Slice: + return objValue.Len() == 0 + // pointers are empty if nil or if the value they point to is empty + case reflect.Ptr: + if objValue.IsNil() { + return true + } + deref := objValue.Elem().Interface() + return IsEmpty(deref) + // for all other types, compare against the zero value + // array types are empty when they match their zero-initialized state + default: + zero := reflect.Zero(objValue.Type()) + return reflect.DeepEqual(object, zero.Interface()) + } +} + +// DiffLists diffs two arrays/slices and returns slices of elements that are only in A and only in B. +// If some element is present multiple times, each instance is counted separately (e.g. if something is 2x in A and +// 5x in B, it will be 0x in extraA and 3x in extraB). The order of items in both lists is ignored. +func DiffLists(listA, listB interface{}) (extraA, extraB []interface{}) { + aValue := reflect.ValueOf(listA) + bValue := reflect.ValueOf(listB) + + aLen := aValue.Len() + bLen := bValue.Len() + + // Mark indexes in bValue that we already used + visited := make([]bool, bLen) + for i := 0; i < aLen; i++ { + element := aValue.Index(i).Interface() + found := false + for j := 0; j < bLen; j++ { + if visited[j] { + continue + } + if ObjectsAreEqual(bValue.Index(j).Interface(), element) { + visited[j] = true + found = true + break + } + } + if !found { + extraA = append(extraA, element) + } + } + + for j := 0; j < bLen; j++ { + if visited[j] { + continue + } + extraB = append(extraB, bValue.Index(j).Interface()) + } + + return +} + +// ObjectsAreEqual determines if two objects are considered equal. +// +// This function does no assertion of any kind. +func ObjectsAreEqual(expected, actual interface{}) bool { + if expected == nil || actual == nil { + return expected == actual + } + + exp, ok := expected.([]byte) + if !ok { + return reflect.DeepEqual(expected, actual) + } + + act, ok := actual.([]byte) + if !ok { + return false + } + if exp == nil || act == nil { + return exp == nil && act == nil + } + return bytes.Equal(exp, act) +} + +// ObjectsAreEqualValues gets whether two objects are equal, or if their +// values are equal. +func ObjectsAreEqualValues(expected, actual interface{}) bool { + if ObjectsAreEqual(expected, actual) { + return true + } + + actualType := reflect.TypeOf(actual) + if actualType == nil { + return false + } + expectedValue := reflect.ValueOf(expected) + if expectedValue.IsValid() && expectedValue.Type().ConvertibleTo(actualType) { + // Attempt comparison after type conversion + return reflect.DeepEqual(expectedValue.Convert(actualType).Interface(), actual) + } + + return false +} + +// ElementsMatch returns nil if listA has the same number of the same elements +// as listB. It returns an error with a printable message if it does not. +func ElementsMatch(listA, listB interface{}) error { + if IsEmpty(listA) && IsEmpty(listB) { + return nil + } + + if err := IsList(listA); err != nil { + return err + } + if err := IsList(listB); err != nil { + return err + } + + extraA, extraB := DiffLists(listA, listB) + + if len(extraA) == 0 && len(extraB) == 0 { + return nil + } + + return errors.New(formatListDiff(listA, listB, extraA, extraB)) +} + +func formatListDiff(listA, listB interface{}, extraA, extraB []interface{}) string { + var msg bytes.Buffer + + msg.WriteString("elements differ") + if len(extraA) > 0 { + msg.WriteString("\n\nextra elements in list A:\n") + msg.WriteString(SpewConfig.Sdump(extraA)) + } + if len(extraB) > 0 { + msg.WriteString("\n\nextra elements in list B:\n") + msg.WriteString(SpewConfig.Sdump(extraB)) + } + msg.WriteString("\n\nlistA:\n") + msg.WriteString(SpewConfig.Sdump(listA)) + msg.WriteString("\n\nlistB:\n") + msg.WriteString(SpewConfig.Sdump(listB)) + + return msg.String() +} + +// IsList checks that the provided value is array or slice and returns a +// printable error if it is not +func IsList(list interface{}) error { + kind := reflect.TypeOf(list).Kind() + if kind != reflect.Array && kind != reflect.Slice { + return fmt.Errorf("%q has an unsupported type %s, expecting array or slice", list, kind) + } + return nil +} diff --git a/internal/check/checks_test.go b/internal/check/checks_test.go new file mode 100644 index 000000000..af76f63a2 --- /dev/null +++ b/internal/check/checks_test.go @@ -0,0 +1,170 @@ +package check_test + +import ( + "errors" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/internal/check" +) + +func Test_IsEmpty(t *testing.T) { + + chWithValue := make(chan struct{}, 1) + chWithValue <- struct{}{} + + assert.True(t, check.IsEmpty("")) + assert.True(t, check.IsEmpty(nil)) + assert.True(t, check.IsEmpty([]string{})) + assert.True(t, check.IsEmpty(0)) + assert.True(t, check.IsEmpty(int32(0))) + assert.True(t, check.IsEmpty(int64(0))) + assert.True(t, check.IsEmpty(false)) + assert.True(t, check.IsEmpty(map[string]string{})) + assert.True(t, check.IsEmpty(new(time.Time))) + assert.True(t, check.IsEmpty(time.Time{})) + assert.True(t, check.IsEmpty(make(chan struct{}))) + assert.True(t, check.IsEmpty([1]int{})) + assert.False(t, check.IsEmpty("something")) + assert.False(t, check.IsEmpty(errors.New("something"))) + assert.False(t, check.IsEmpty([]string{"something"})) + assert.False(t, check.IsEmpty(1)) + assert.False(t, check.IsEmpty(true)) + assert.False(t, check.IsEmpty(map[string]string{"Hello": "World"})) + assert.False(t, check.IsEmpty(chWithValue)) + assert.False(t, check.IsEmpty([1]int{42})) +} + +func TestObjectsAreEqual(t *testing.T) { + cases := []struct { + expected interface{} + actual interface{} + result bool + }{ + // cases that are expected to be equal + {"Hello World", "Hello World", true}, + {123, 123, true}, + {123.5, 123.5, true}, + {[]byte("Hello World"), []byte("Hello World"), true}, + {nil, nil, true}, + + // cases that are expected not to be equal + {map[int]int{5: 10}, map[int]int{10: 20}, false}, + {'x', "x", false}, + {"x", 'x', false}, + {0, 0.1, false}, + {0.1, 0, false}, + {time.Now, time.Now, false}, + {func() {}, func() {}, false}, + {uint32(10), int32(10), false}, + } + + for _, c := range cases { + t.Run(fmt.Sprintf("ObjectsAreEqual(%#v, %#v)", c.expected, c.actual), func(t *testing.T) { + res := check.ObjectsAreEqual(c.expected, c.actual) + + if res != c.result { + t.Errorf("ObjectsAreEqual(%#v, %#v) should return %#v", c.expected, c.actual, c.result) + } + + }) + } + + // Cases where type differ but values are equal + if !check.ObjectsAreEqualValues(uint32(10), int32(10)) { + t.Error("ObjectsAreEqualValues should return true") + } + if check.ObjectsAreEqualValues(0, nil) { + t.Fail() + } + if check.ObjectsAreEqualValues(nil, 0) { + t.Fail() + } + +} + +func TestDiffLists(t *testing.T) { + tests := []struct { + name string + listA interface{} + listB interface{} + extraA []interface{} + extraB []interface{} + }{ + { + name: "equal empty", + listA: []string{}, + listB: []string{}, + extraA: nil, + extraB: nil, + }, + { + name: "equal same order", + listA: []string{"hello", "world"}, + listB: []string{"hello", "world"}, + extraA: nil, + extraB: nil, + }, + { + name: "equal different order", + listA: []string{"hello", "world"}, + listB: []string{"world", "hello"}, + extraA: nil, + extraB: nil, + }, + { + name: "extra A", + listA: []string{"hello", "hello", "world"}, + listB: []string{"hello", "world"}, + extraA: []interface{}{"hello"}, + extraB: nil, + }, + { + name: "extra A twice", + listA: []string{"hello", "hello", "hello", "world"}, + listB: []string{"hello", "world"}, + extraA: []interface{}{"hello", "hello"}, + extraB: nil, + }, + { + name: "extra B", + listA: []string{"hello", "world"}, + listB: []string{"hello", "hello", "world"}, + extraA: nil, + extraB: []interface{}{"hello"}, + }, + { + name: "extra B twice", + listA: []string{"hello", "world"}, + listB: []string{"hello", "hello", "world", "hello"}, + extraA: nil, + extraB: []interface{}{"hello", "hello"}, + }, + { + name: "integers 1", + listA: []int{1, 2, 3, 4, 5}, + listB: []int{5, 4, 3, 2, 1}, + extraA: nil, + extraB: nil, + }, + { + name: "integers 2", + listA: []int{1, 2, 1, 2, 1}, + listB: []int{2, 1, 2, 1, 2}, + extraA: []interface{}{1}, + extraB: []interface{}{2}, + }, + } + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + actualExtraA, actualExtraB := check.DiffLists(test.listA, test.listB) + assert.Equal(t, test.extraA, actualExtraA, "extra A does not match for listA=%v listB=%v", + test.listA, test.listB) + assert.Equal(t, test.extraB, actualExtraB, "extra B does not match for listA=%v listB=%v", + test.listA, test.listB) + }) + } +} diff --git a/mock/mock.go b/mock/mock.go index e6ff8dfeb..f8b29f9bb 100644 --- a/mock/mock.go +++ b/mock/mock.go @@ -14,6 +14,7 @@ import ( "github.com/pmezard/go-difflib/difflib" "github.com/stretchr/objx" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/internal/check" ) // TestingT is an interface wrapper around *testing.T @@ -780,14 +781,19 @@ func IsType(t interface{}) *IsTypeArgument { return &IsTypeArgument{t: t} } -// argumentMatcher performs custom argument matching, returning whether or +type argumentMatcher interface { + Matches(interface{}) bool + String() string +} + +// matchedByFn performs custom argument matching, returning whether or // not the argument is matched by the expectation fixture function. -type argumentMatcher struct { +type matchedByFn struct { // fn is a function which accepts one argument, and returns a bool. fn reflect.Value } -func (f argumentMatcher) Matches(argument interface{}) bool { +func (f matchedByFn) Matches(argument interface{}) bool { expectType := f.fn.Type().In(0) expectTypeNilSupported := false switch expectType.Kind() { @@ -813,7 +819,7 @@ func (f argumentMatcher) Matches(argument interface{}) bool { return false } -func (f argumentMatcher) String() string { +func (f matchedByFn) String() string { return fmt.Sprintf("func(%s) bool", f.fn.Type().In(0).String()) } @@ -841,7 +847,30 @@ func MatchedBy(fn interface{}) argumentMatcher { panic(fmt.Sprintf("assert: arguments: %s does not return a bool", fn)) } - return argumentMatcher{fn: reflect.ValueOf(fn)} + return matchedByFn{fn: reflect.ValueOf(fn)} +} + +type elementsMatch struct { + list interface{} +} + +// ElementsMatch matches if the passed array or slice's elements match the +// elements in list. +// +// Example: +// m.On("Do", ElementsMatch([]int{5, 6, 7, 8})).Return() +func ElementsMatch(list interface{}) argumentMatcher { + return elementsMatch{ + list: list, + } +} + +func (e elementsMatch) Matches(list interface{}) bool { + return check.ElementsMatch(e.list, list) == nil +} + +func (e elementsMatch) String() string { + return fmt.Sprintf("(elements of %[1]T=%[1]v)", e.list) } // Get Returns the argument at the specified index. @@ -929,7 +958,7 @@ func (args Arguments) Diff(objects []interface{}) (string, int) { } else { // normal checking - if assert.ObjectsAreEqual(expected, Anything) || assert.ObjectsAreEqual(actual, Anything) || assert.ObjectsAreEqual(actual, expected) { + if check.ObjectsAreEqual(expected, Anything) || check.ObjectsAreEqual(actual, Anything) || check.ObjectsAreEqual(actual, expected) { // match output = fmt.Sprintf("%s\t%d: PASS: %s == %s\n", output, i, actualFmt, expectedFmt) } else { diff --git a/mock/mock_test.go b/mock/mock_test.go index 260bb9c4f..59244d49f 100644 --- a/mock/mock_test.go +++ b/mock/mock_test.go @@ -1705,6 +1705,26 @@ func Test_Arguments_Diff_WithIsTypeArgument_Failing(t *testing.T) { assert.Contains(t, diff, `string != type int - (int=123)`) } +func Test_Arguments_Diff_WithAnyOrder(t *testing.T) { + args := Arguments([]interface{}{ElementsMatch([]int{1, 2, 3})}) + _, count := args.Diff([]interface{}{[]int{2, 1, 3}}) + assert.Equal(t, 0, count) +} + +func Test_Arguments_Diff_WithAnyOrder_Failing(t *testing.T) { + args := Arguments([]interface{}{ElementsMatch([]int{1, 2, 3})}) + s, count := args.Diff([]interface{}{[]int{1, 2, 4}}) + assert.Equal(t, 1, count) + assert.Equal(t, "\n\t0: FAIL: ([]int=[1 2 4]) not matched by (elements of []int=[1 2 3])\n", s) +} + +func Test_Arguments_Diff_WithAnyOrder_WrongType(t *testing.T) { + args := Arguments([]interface{}{ElementsMatch([]string{"a", "b"})}) + s, count := args.Diff([]interface{}{1}) + assert.Equal(t, 1, count) + assert.Equal(t, "\n\t0: FAIL: (int=1) not matched by (elements of []string=[a b])\n", s) +} + func Test_Arguments_Diff_WithArgMatcher(t *testing.T) { matchFn := func(a int) bool { return a == 123 @@ -1916,7 +1936,7 @@ func TestArgumentMatcherToPrintMismatch(t *testing.T) { defer func() { if r := recover(); r != nil { matchingExp := regexp.MustCompile( - `\s+mock: Unexpected Method Call\s+-*\s+GetTime\(int\)\s+0: 1\s+The closest call I have is:\s+GetTime\(mock.argumentMatcher\)\s+0: mock.argumentMatcher\{.*?\}\s+Diff:.*\(int=1\) not matched by func\(int\) bool`) + `\s+mock: Unexpected Method Call\s+-*\s+GetTime\(int\)\s+0: 1\s+The closest call I have is:\s+GetTime\(mock.matchedByFn\)\s+0: mock.matchedByFn\{.*?\}\s+Diff:.*\(int=1\) not matched by func\(int\) bool`) assert.Regexp(t, matchingExp, r) } }() @@ -1933,7 +1953,7 @@ func TestArgumentMatcherToPrintMismatchWithReferenceType(t *testing.T) { defer func() { if r := recover(); r != nil { matchingExp := regexp.MustCompile( - `\s+mock: Unexpected Method Call\s+-*\s+GetTimes\(\[\]int\)\s+0: \[\]int\{1\}\s+The closest call I have is:\s+GetTimes\(mock.argumentMatcher\)\s+0: mock.argumentMatcher\{.*?\}\s+Diff:.*\(\[\]int=\[1\]\) not matched by func\(\[\]int\) bool`) + `\s+mock: Unexpected Method Call\s+-*\s+GetTimes\(\[\]int\)\s+0: \[\]int\{1\}\s+The closest call I have is:\s+GetTimes\(mock.matchedByFn\)\s+0: mock.matchedByFn\{.*?\}\s+Diff:.*\(\[\]int=\[1\]\) not matched by func\(\[\]int\) bool`) assert.Regexp(t, matchingExp, r) } }()