Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mock.ElementsMatch argument matcher #1347

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 18 additions & 186 deletions assert/assertions.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ import (
"unicode"
"unicode/utf8"

"github.com/davecgh/go-spew/spew"
"github.com/pmezard/go-difflib/difflib"
"github.com/stretchr/testify/internal/check"
yaml "gopkg.in/yaml.v3"
)

Expand Down Expand Up @@ -52,49 +52,6 @@ type Comparison func() (success bool)
Helper functions
*/

// ObjectsAreEqual determines if two objects are considered equal.
//
// This function does no assertion of any kind.
func ObjectsAreEqual(expected, actual interface{}) bool {
if expected == nil || actual == nil {
return expected == actual
}

exp, ok := expected.([]byte)
if !ok {
return reflect.DeepEqual(expected, actual)
}

act, ok := actual.([]byte)
if !ok {
return false
}
if exp == nil || act == nil {
return exp == nil && act == nil
}
return bytes.Equal(exp, act)
}

// ObjectsAreEqualValues gets whether two objects are equal, or if their
// values are equal.
func ObjectsAreEqualValues(expected, actual interface{}) bool {
if ObjectsAreEqual(expected, actual) {
return true
}

actualType := reflect.TypeOf(actual)
if actualType == nil {
return false
}
expectedValue := reflect.ValueOf(expected)
if expectedValue.IsValid() && expectedValue.Type().ConvertibleTo(actualType) {
// Attempt comparison after type conversion
return reflect.DeepEqual(expectedValue.Convert(actualType).Interface(), actual)
}

return false
}

/* CallerInfo is necessary because the assert functions use the testing object
internally, causing it to print the file:line of the assert method, rather than where
the problem actually occurred in calling code.*/
Expand Down Expand Up @@ -317,7 +274,7 @@ func IsType(t TestingT, expectedType interface{}, object interface{}, msgAndArgs
h.Helper()
}

if !ObjectsAreEqual(reflect.TypeOf(object), reflect.TypeOf(expectedType)) {
if !check.ObjectsAreEqual(reflect.TypeOf(object), reflect.TypeOf(expectedType)) {
return Fail(t, fmt.Sprintf("Object expected to be of type %v, but was %v", reflect.TypeOf(expectedType), reflect.TypeOf(object)), msgAndArgs...)
}

Expand All @@ -340,7 +297,7 @@ func Equal(t TestingT, expected, actual interface{}, msgAndArgs ...interface{})
expected, actual, err), msgAndArgs...)
}

if !ObjectsAreEqual(expected, actual) {
if !check.ObjectsAreEqual(expected, actual) {
diff := diff(expected, actual)
expected, actual = formatUnequalValues(expected, actual)
return Fail(t, fmt.Sprintf("Not equal: \n"+
Expand Down Expand Up @@ -461,7 +418,7 @@ func EqualValues(t TestingT, expected, actual interface{}, msgAndArgs ...interfa
h.Helper()
}

if !ObjectsAreEqualValues(expected, actual) {
if !check.ObjectsAreEqualValues(expected, actual) {
diff := diff(expected, actual)
expected, actual = formatUnequalValues(expected, actual)
return Fail(t, fmt.Sprintf("Not equal: \n"+
Expand Down Expand Up @@ -551,41 +508,12 @@ func Nil(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
return Fail(t, fmt.Sprintf("Expected nil, but got: %#v", object), msgAndArgs...)
}

// isEmpty gets whether the specified object is considered empty or not.
func isEmpty(object interface{}) bool {

// get nil case out of the way
if object == nil {
return true
}

objValue := reflect.ValueOf(object)

switch objValue.Kind() {
// collection types are empty when they have no element
case reflect.Chan, reflect.Map, reflect.Slice:
return objValue.Len() == 0
// pointers are empty if nil or if the value they point to is empty
case reflect.Ptr:
if objValue.IsNil() {
return true
}
deref := objValue.Elem().Interface()
return isEmpty(deref)
// for all other types, compare against the zero value
// array types are empty when they match their zero-initialized state
default:
zero := reflect.Zero(objValue.Type())
return reflect.DeepEqual(object, zero.Interface())
}
}

// Empty asserts that the specified object is empty. I.e. nil, "", false, 0 or either
// a slice or a channel with len == 0.
//
// assert.Empty(t, obj)
func Empty(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
pass := isEmpty(object)
pass := check.IsEmpty(object)
if !pass {
if h, ok := t.(tHelper); ok {
h.Helper()
Expand All @@ -604,7 +532,7 @@ func Empty(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
// assert.Equal(t, "two", obj[1])
// }
func NotEmpty(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
pass := !isEmpty(object)
pass := !check.IsEmpty(object)
if !pass {
if h, ok := t.(tHelper); ok {
h.Helper()
Expand Down Expand Up @@ -692,7 +620,7 @@ func NotEqual(t TestingT, expected, actual interface{}, msgAndArgs ...interface{
expected, actual, err), msgAndArgs...)
}

if ObjectsAreEqual(expected, actual) {
if check.ObjectsAreEqual(expected, actual) {
return Fail(t, fmt.Sprintf("Should not be: %#v\n", actual), msgAndArgs...)
}

Expand All @@ -708,7 +636,7 @@ func NotEqualValues(t TestingT, expected, actual interface{}, msgAndArgs ...inte
h.Helper()
}

if ObjectsAreEqualValues(expected, actual) {
if check.ObjectsAreEqualValues(expected, actual) {
return Fail(t, fmt.Sprintf("Should not be: %#v\n", actual), msgAndArgs...)
}

Expand Down Expand Up @@ -742,15 +670,15 @@ func containsElement(list interface{}, element interface{}) (ok, found bool) {
if listKind == reflect.Map {
mapKeys := listValue.MapKeys()
for i := 0; i < len(mapKeys); i++ {
if ObjectsAreEqual(mapKeys[i].Interface(), element) {
if check.ObjectsAreEqual(mapKeys[i].Interface(), element) {
return true, true
}
}
return true, false
}

for i := 0; i < listValue.Len(); i++ {
if ObjectsAreEqual(listValue.Index(i).Interface(), element) {
if check.ObjectsAreEqual(listValue.Index(i).Interface(), element) {
return true, true
}
}
Expand Down Expand Up @@ -843,7 +771,7 @@ func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok
subsetElement := subsetValue.MapIndex(subsetKey).Interface()
listElement := listValue.MapIndex(subsetKey).Interface()

if !ObjectsAreEqual(subsetElement, listElement) {
if !check.ObjectsAreEqual(subsetElement, listElement) {
return Fail(t, fmt.Sprintf("\"%s\" does not contain \"%s\"", list, subsetElement), msgAndArgs...)
}
}
Expand Down Expand Up @@ -904,7 +832,7 @@ func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{})
subsetElement := subsetValue.MapIndex(subsetKey).Interface()
listElement := listValue.MapIndex(subsetKey).Interface()

if !ObjectsAreEqual(subsetElement, listElement) {
if !check.ObjectsAreEqual(subsetElement, listElement) {
return true
}
}
Expand Down Expand Up @@ -935,93 +863,14 @@ func ElementsMatch(t TestingT, listA, listB interface{}, msgAndArgs ...interface
if h, ok := t.(tHelper); ok {
h.Helper()
}
if isEmpty(listA) && isEmpty(listB) {
return true
}

if !isList(t, listA, msgAndArgs...) || !isList(t, listB, msgAndArgs...) {
return false
}

extraA, extraB := diffLists(listA, listB)

if len(extraA) == 0 && len(extraB) == 0 {
return true
if err := check.ElementsMatch(listA, listB); err != nil {
return Fail(t, err.Error(), msgAndArgs...)
}

return Fail(t, formatListDiff(listA, listB, extraA, extraB), msgAndArgs...)
}

// isList checks that the provided value is array or slice.
func isList(t TestingT, list interface{}, msgAndArgs ...interface{}) (ok bool) {
kind := reflect.TypeOf(list).Kind()
if kind != reflect.Array && kind != reflect.Slice {
return Fail(t, fmt.Sprintf("%q has an unsupported type %s, expecting array or slice", list, kind),
msgAndArgs...)
}
return true
}

// diffLists diffs two arrays/slices and returns slices of elements that are only in A and only in B.
// If some element is present multiple times, each instance is counted separately (e.g. if something is 2x in A and
// 5x in B, it will be 0x in extraA and 3x in extraB). The order of items in both lists is ignored.
func diffLists(listA, listB interface{}) (extraA, extraB []interface{}) {
aValue := reflect.ValueOf(listA)
bValue := reflect.ValueOf(listB)

aLen := aValue.Len()
bLen := bValue.Len()

// Mark indexes in bValue that we already used
visited := make([]bool, bLen)
for i := 0; i < aLen; i++ {
element := aValue.Index(i).Interface()
found := false
for j := 0; j < bLen; j++ {
if visited[j] {
continue
}
if ObjectsAreEqual(bValue.Index(j).Interface(), element) {
visited[j] = true
found = true
break
}
}
if !found {
extraA = append(extraA, element)
}
}

for j := 0; j < bLen; j++ {
if visited[j] {
continue
}
extraB = append(extraB, bValue.Index(j).Interface())
}

return
}

func formatListDiff(listA, listB interface{}, extraA, extraB []interface{}) string {
var msg bytes.Buffer

msg.WriteString("elements differ")
if len(extraA) > 0 {
msg.WriteString("\n\nextra elements in list A:\n")
msg.WriteString(spewConfig.Sdump(extraA))
}
if len(extraB) > 0 {
msg.WriteString("\n\nextra elements in list B:\n")
msg.WriteString(spewConfig.Sdump(extraB))
}
msg.WriteString("\n\nlistA:\n")
msg.WriteString(spewConfig.Sdump(listA))
msg.WriteString("\n\nlistB:\n")
msg.WriteString(spewConfig.Sdump(listB))

return msg.String()
}

// Condition uses a Comparison to assert a complex condition.
func Condition(t TestingT, comp Comparison, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
Expand Down Expand Up @@ -1673,11 +1522,11 @@ func diff(expected interface{}, actual interface{}) string {
e = reflect.ValueOf(expected).String()
a = reflect.ValueOf(actual).String()
case reflect.TypeOf(time.Time{}):
e = spewConfigStringerEnabled.Sdump(expected)
a = spewConfigStringerEnabled.Sdump(actual)
e = check.SpewConfigStringerEnabled.Sdump(expected)
a = check.SpewConfigStringerEnabled.Sdump(actual)
default:
e = spewConfig.Sdump(expected)
a = spewConfig.Sdump(actual)
e = check.SpewConfig.Sdump(expected)
a = check.SpewConfig.Sdump(actual)
}

diff, _ := difflib.GetUnifiedDiffString(difflib.UnifiedDiff{
Expand All @@ -1700,23 +1549,6 @@ func isFunction(arg interface{}) bool {
return reflect.TypeOf(arg).Kind() == reflect.Func
}

var spewConfig = spew.ConfigState{
Indent: " ",
DisablePointerAddresses: true,
DisableCapacities: true,
SortKeys: true,
DisableMethods: true,
MaxDepth: 10,
}

var spewConfigStringerEnabled = spew.ConfigState{
Indent: " ",
DisablePointerAddresses: true,
DisableCapacities: true,
SortKeys: true,
MaxDepth: 10,
}

type tHelper interface {
Helper()
}
Expand Down
Loading