Skip to content

Commit

Permalink
feat(mockgen): respect custom function implementations
Browse files Browse the repository at this point in the history
Prior to this change, a user of mockgen was free to include any function
declarations that they wished, but it was possible that they collided
with generated functions, leading to a broken build.

This change performs a scan of the existing packages for any colliding
function names or functions that already provide the call to
mock.Expect.  We naively assume that the result of the mock.Expect call
is returned from the user defined function.  If there is more than one
call to a mock.Expect function within a single custom function then we
ignore the custom function and generate functions as per normal.  If
there is a function that collides with a generated function name but
does not make a call to a mock.Expect function then the generated
function falls back to a more verbose name.
  • Loading branch information
au-phiware committed Dec 21, 2023
1 parent 0e5cd36 commit b75a1e6
Show file tree
Hide file tree
Showing 8 changed files with 746 additions and 19 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ go 1.20

require (
github.com/google/subcommands v1.2.0
golang.org/x/text v0.14.0
golang.org/x/tools v0.16.0
rsc.io/script v0.0.2-0.20231205190631-334f6c18cff3
)
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3
golang.org/x/mod v0.14.0 h1:dGoOF9QVLYng8IHTm7BAyWqCqSheQ5pYWGhzW00YJr0=
golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/tools v0.16.0 h1:GO788SKMRunPIBCXiQyo2AaexLstOrVhuAL5YwsckQM=
golang.org/x/tools v0.16.0/go.mod h1:kYVVN6I1mBNoB1OX+noeBjbRk4IUEPa7JJ+TJMEooJ0=
rsc.io/script v0.0.2-0.20231205190631-334f6c18cff3 h1:2vM6uMBq2/Dou/Wzu2p+yUFkuI3lgMbX0UYfVnzh0ck=
Expand Down
167 changes: 148 additions & 19 deletions internal/mock/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import (
"strconv"
"strings"

"golang.org/x/text/cases"
"golang.org/x/text/language"
"golang.org/x/tools/go/packages"
)

Expand Down Expand Up @@ -186,6 +188,7 @@ func Generate(ctx context.Context, patterns []string, opts GenerateOptions) ([]G
if len(errs) > 0 {
return nil, errs
}

generated := make([]GenerateResult, len(pkgs))
for i, pkg := range pkgs {
generated[i].PkgPath = pkg.PkgPath
Expand All @@ -194,18 +197,22 @@ func Generate(ctx context.Context, patterns []string, opts GenerateOptions) ([]G
generated[i].Errs = append(generated[i].Errs, err)
continue
}

outputFile := opts.PrefixOutputFile + "mock_gen"
if strings.HasSuffix(pkg.Name, "_test") {
outputFile += "_test"
}
outputFile += ".go"
generated[i].OutputPath = filepath.Join(outDir, outputFile)

g := newGen(pkg)
findFunctions(g, pkg)
errs := generateMocks(g, pkg)
if len(errs) > 0 {
generated[i].Errs = errs
continue
}

goSrc := g.frame(opts.Tags)
if len(opts.Header) > 0 {
goSrc = append(opts.Header, goSrc...)
Expand Down Expand Up @@ -251,13 +258,80 @@ func isMockStub(syntax *ast.File) bool {
return false
}

func findFunctions(g *gen, pkg *packages.Package) {
pkgName, _ := g.resolvePackageName("github.com/Versent/go-mock")
for _, syntax := range pkg.Syntax {
for _, decl := range syntax.Decls {
funcDecl, ok := g.addFunc(decl)
if !ok {
continue
}
if pkgName == "" {
// mock is not imported,
// so there cannot not be any custom functions
continue
}
if funcDecl.Recv != nil || funcDecl.Body == nil {
continue
}
// search for calls to mock.Expect or mock.ExpectMany
var funcName, structName, methodName string
ast.Inspect(funcDecl.Body, func(node ast.Node) (next bool) {
next = true
switch stmt := node.(type) {
case *ast.CallExpr:
var (
ok bool
index *ast.IndexExpr
sel *ast.SelectorExpr
ident *ast.Ident
lit *ast.BasicLit
)
if index, ok = stmt.Fun.(*ast.IndexExpr); !ok {
return
}
if sel, ok = index.X.(*ast.SelectorExpr); !ok {
return
}
if sel.Sel.Name != "Expect" && sel.Sel.Name != "ExpectMany" {
return
}
if ident, ok := sel.X.(*ast.Ident); !ok || ident.Name != pkgName {
return
}
funcName = sel.Sel.Name
if ident, ok = index.Index.(*ast.Ident); !ok {
return
}
structName = ident.Name
if len(stmt.Args) == 0 {
return
}
if lit, ok = stmt.Args[0].(*ast.BasicLit); !ok || lit.Kind != token.STRING {
return
}
if methodName != "" {
methodName = ""
return false
}
methodName = lit.Value
}
return
})
if methodName == "" {
continue
}
specName := fmt.Sprintf("%s[%s](%s)", funcName, structName, methodName)
g.funcs[specName] = struct{}{}
}
}
}

func generateMocks(g *gen, pkg *packages.Package) (errs []error) {
for _, syntax := range pkg.Syntax {
if !isMockStub(syntax) {
continue
}
// filename := pkg.Fset.File(syntax.Pos()).Name()
// ast.Print(pkg.Fset, syntax)

// Iterate over all declarations in the file
for _, decl := range syntax.Decls {
Expand Down Expand Up @@ -377,28 +451,23 @@ func generateMockMethods(g *gen, iface *types.Interface, structName string) erro
methodName := method.Name()
sig := method.Type().(*types.Signature)

methDecl := makeMockMethod(g, structName, methodName, sig)
expDecl := makeExpectFunc(g, "Expect", structName, methodName, sig)
manyDecl := makeExpectFunc(g, "ExpectMany", structName, methodName, sig)

// Generate the source code for the function
if err := g.addDecl(expDecl.Name, expDecl); err != nil {
if err := addExpectFunc(g, "Expect", structName, methodName, sig); err != nil {
return err
}
if err := g.addDecl(manyDecl.Name, manyDecl); err != nil {
if err := addExpectFunc(g, "ExpectMany", structName, methodName, sig); err != nil {
return err
}
if err := g.addDecl(methDecl.Name, methDecl); err != nil {
if err := addMockMethod(g, structName, methodName, sig); err != nil {
return err
}
}

return nil
}

func makeMockMethod(g *gen, structName, methodName string, sig *types.Signature) (methDecl *ast.FuncDecl) {
func addMockMethod(g *gen, structName, methodName string, sig *types.Signature) (err error) {
// Start building the function declaration
methDecl = &ast.FuncDecl{
methDecl := &ast.FuncDecl{
Recv: &ast.FieldList{
List: []*ast.Field{
{
Expand All @@ -413,6 +482,11 @@ func makeMockMethod(g *gen, structName, methodName string, sig *types.Signature)
Type: &ast.FuncType{},
}

if _, ok := g.funcs[g.keyForFunc(methDecl)]; ok {
// Method already exists
return
}

methDecl.Type.Params = fieldList("v", sig.Variadic(), sig.Params())
methDecl.Type.Results = fieldList("", false, sig.Results())

Expand Down Expand Up @@ -449,10 +523,33 @@ func makeMockMethod(g *gen, structName, methodName string, sig *types.Signature)
})
}

return
// Generate the source code for the function
return g.addDecl(methDecl.Name, methDecl)
}

func makeExpectFunc(g *gen, funcName, structName, methodName string, sig *types.Signature) (funcDecl *ast.FuncDecl) {
func addExpectFunc(g *gen, funcName, structName, methodName string, sig *types.Signature) error {
specName := fmt.Sprintf("%s[%s](%q)", funcName, structName, methodName)
if _, ok := g.funcs[specName]; ok {
// Custom implementation already exists
return nil
}

// Disambiguate the function name
name := ast.NewIdent(funcName + methodName)
if _, ok := g.funcs[name.Name]; ok {
if token.IsExported(structName) {
name = ast.NewIdent(funcName + structName + methodName)
} else {
name = ast.NewIdent(funcName + cases.Title(language.AmericanEnglish, cases.NoLower).String(structName) + methodName)
}
}
if _, ok := g.funcs[name.Name]; ok {
name = ast.NewIdent(name.Name + "T")
}
if _, ok := g.funcs[name.Name]; ok {
return fmt.Errorf("unable to disambiguate function name %q", name.Name)
}

delegateType := &ast.FuncType{
Params: &ast.FieldList{
List: []*ast.Field{{
Expand All @@ -473,8 +570,8 @@ func makeExpectFunc(g *gen, funcName, structName, methodName string, sig *types.
},
})
}
funcDecl = &ast.FuncDecl{
Name: ast.NewIdent(funcName + methodName),
funcDecl := &ast.FuncDecl{
Name: name,
Type: &ast.FuncType{
Results: &ast.FieldList{
List: []*ast.Field{{
Expand Down Expand Up @@ -532,7 +629,11 @@ func makeExpectFunc(g *gen, funcName, structName, methodName string, sig *types.
}
delegateType.Results.List = append(delegateType.Results.List, field)
})
return

g.funcs[specName] = struct{}{}

// Generate the source code for the function
return g.addDecl(funcDecl.Name, funcDecl)
}

func forTuple(prefix string, tuple *types.Tuple, f func(int, string, *types.Var)) {
Expand Down Expand Up @@ -590,6 +691,7 @@ type gen struct {
imports map[string]importInfo
anonImports map[string]bool
values map[ast.Expr]string
funcs map[string]struct{}
}

func newGen(pkg *packages.Package) *gen {
Expand All @@ -598,12 +700,12 @@ func newGen(pkg *packages.Package) *gen {
anonImports: make(map[string]bool),
imports: make(map[string]importInfo),
values: make(map[ast.Expr]string),
funcs: make(map[string]struct{}),
}
}

func (g *gen) addDecl(name fmt.Stringer, decl ast.Decl) error {
genDecl, ok := decl.(*ast.GenDecl)
if ok && genDecl.Tok == token.IMPORT {
if genDecl, ok := decl.(*ast.GenDecl); ok && genDecl.Tok == token.IMPORT {
for _, spec := range genDecl.Specs {
importSpec := spec.(*ast.ImportSpec)
var name string
Expand Down Expand Up @@ -631,6 +733,7 @@ func (g *gen) addDecl(name fmt.Stringer, decl ast.Decl) error {
}
}
}
g.addFunc(decl)
var buf bytes.Buffer
if err := format.Node(&buf, g.pkg.Fset, decl); err != nil {
if name == nil {
Expand All @@ -643,6 +746,32 @@ func (g *gen) addDecl(name fmt.Stringer, decl ast.Decl) error {
return nil
}

func (g *gen) keyForFunc(funcDecl *ast.FuncDecl) (key string) {
if funcDecl.Recv == nil {
return funcDecl.Name.String()
} else if len(funcDecl.Recv.List) == 1 {
recv := bytes.Buffer{}
err := format.Node(&recv, g.pkg.Fset, funcDecl.Recv.List[0].Type)
if err != nil {
return
}
return recv.String() + "." + funcDecl.Name.String()
}
return
}

func (g *gen) addFunc(decl ast.Decl) (funcDecl *ast.FuncDecl, ok bool) {
if funcDecl, ok = decl.(*ast.FuncDecl); ok {
key := g.keyForFunc(funcDecl)
if key == "" {
ok = false
return
}
g.funcs[key] = struct{}{}
}
return
}

func (g *gen) resolvePackageName(path string) (string, bool) {
for _, pkg := range g.pkg.Imports {
if pkg.PkgPath == path {
Expand Down
Loading

0 comments on commit b75a1e6

Please sign in to comment.