Skip to content

Commit

Permalink
fix: flag attribution to struct fields of composite type (#545)
Browse files Browse the repository at this point in the history
Reviewed-by: Cezar Craciunoiu <[email protected]>
Approved-by: Cezar Craciunoiu <[email protected]>
  • Loading branch information
craciunoiuc authored Jul 18, 2023
2 parents 0bc6d7b + fa50a22 commit 8b08d25
Show file tree
Hide file tree
Showing 3 changed files with 300 additions and 34 deletions.
185 changes: 168 additions & 17 deletions cmdfactory/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package cmdfactory
import (
"context"
"errors"
"flag"
"fmt"
"os"
"reflect"
Expand Down Expand Up @@ -172,12 +173,149 @@ func isSameCommand(cmd *cobra.Command, cmdline string) bool {
return par.Name() == cmdFields[len(cmdFields)-2] && cmd.Name() == cmdFields[len(cmdFields)-1]
}

func execute(c *cobra.Command, a []string) (err error) {
if len(c.Deprecated) > 0 {
c.Printf("command %q is deprecated, %s\n", c.Name(), c.Deprecated)
}

// If help is called, regardless of other flags, return we want help.
// Also say we need help if the command isn't runnable.
if helpVal, err := c.Flags().GetBool("help"); err == nil && helpVal {
return flag.ErrHelp
}

if !c.Runnable() {
return flag.ErrHelp
}

argWoFlags := c.Flags().Args()
if c.DisableFlagParsing {
argWoFlags = a
}

if err := c.ValidateArgs(argWoFlags); err != nil {
return err
}

for p := c; p != nil; p = p.Parent() {
if p.PersistentPreRunE != nil {
if err := p.PersistentPreRunE(c, argWoFlags); err != nil {
return err
}
break
} else if p.PersistentPreRun != nil {
p.PersistentPreRun(c, argWoFlags)
break
}
}
if c.PreRunE != nil {
if err := c.PreRunE(c, argWoFlags); err != nil {
return err
}
} else if c.PreRun != nil {
c.PreRun(c, argWoFlags)
}

if err := c.ValidateRequiredFlags(); err != nil {
return err
}
if err := c.ValidateFlagGroups(); err != nil {
return err
}

if c.RunE != nil {
if err := c.RunE(c, argWoFlags); err != nil {
return err
}
} else {
c.Run(c, argWoFlags)
}
if c.PostRunE != nil {
if err := c.PostRunE(c, argWoFlags); err != nil {
return err
}
} else if c.PostRun != nil {
c.PostRun(c, argWoFlags)
}
for p := c; p != nil; p = p.Parent() {
if p.PersistentPostRunE != nil {
if err := p.PersistentPostRunE(c, argWoFlags); err != nil {
return err
}
break
} else if p.PersistentPostRun != nil {
p.PersistentPostRun(c, argWoFlags)
break
}
}

return nil
}

func executeC(c *cobra.Command) (cmd *cobra.Command, err error) {
// Regardless of what command execute is called on, run on Root only
if c.HasParent() {
return executeC(c.Root())
}

args := os.Args[1:]

var flags []string
if c.TraverseChildren {
cmd, flags, err = c.Traverse(args)
} else {
cmd, flags, err = c.Find(args)
}
if err != nil {
// If found parse to a subcommand and then failed, talk about the subcommand
if cmd != nil {
c = cmd
}
if !c.SilenceErrors {
c.PrintErrln("Error:", err.Error())
c.PrintErrf("Run '%v --help' for usage.\n", c.CommandPath())
}
return c, err
}

// We have to pass global context to children command
// if context is present on the parent command.
if cmd.Context() == nil {
cmd.SetContext(c.Context())
}

if err = execute(cmd, flags); err != nil {
// Always show help if requested, even if SilenceErrors is in
// effect
if errors.Is(err, flag.ErrHelp) {
cmd.HelpFunc()(cmd, args)
return cmd, nil
}

// If root command has SilenceErrors flagged,
// all subcommands should respect it
if !cmd.SilenceErrors && !c.SilenceErrors {
c.PrintErrln("Error:", err.Error())
}

// If root command has SilenceUsage flagged,
// all subcommands should respect it
if !cmd.SilenceUsage && !c.SilenceUsage {
c.Println(cmd.UsageString())
}
}

return cmd, err
}

// Main executes the given command
func Main(ctx context.Context, cmd *cobra.Command) {
// Expand flag all dynamically registered flag overrides.
expandRegisteredFlags(cmd)

if err := cmd.ExecuteContext(ctx); err != nil {
cmd.SetContext(ctx)

if _, err := executeC(cmd); err != nil {
fmt.Println(err)
os.Exit(1)
}
Expand Down Expand Up @@ -265,14 +403,26 @@ func AttributeFlags(c *cobra.Command, obj any, args ...string) error {
switch fieldType.Tag.Get("split") {
case "false":
arrays[name] = v
flags.StringArrayP(name, alias, nil, usage)
if ptr := (*[]string)(unsafe.Pointer(v.Addr().Pointer())); *ptr != nil {
flags.StringArrayVarP(ptr, name, alias, *ptr, usage)
} else {
flags.StringArrayP(name, alias, nil, usage)
}
default:
slices[name] = v
flags.StringSliceP(name, alias, nil, usage)
if ptr := (*[]string)(unsafe.Pointer(v.Addr().Pointer())); *ptr != nil {
flags.StringSliceVarP(ptr, name, alias, *ptr, usage)
} else {
flags.StringSliceP(name, alias, nil, usage)
}
}
case reflect.Map:
maps[name] = v
flags.StringSliceP(name, alias, nil, usage)
if ptr := (*[]string)(unsafe.Pointer(v.Addr().Pointer())); *ptr != nil {
flags.StringSliceVarP(ptr, name, alias, *ptr, usage)
} else {
flags.StringSliceP(name, alias, nil, usage)
}
case reflect.Pointer:
switch fieldType.Type.Elem().Kind() {
case reflect.Int, reflect.Int64:
Expand Down Expand Up @@ -310,16 +460,12 @@ func AttributeFlags(c *cobra.Command, obj any, args ...string) error {
}

// If any arguments are passed, parse them immediately
subC, args, err := c.Find(args)
if err != nil {
return err
}
if len(args) > 0 {
// Some kraft commands accept flags which registration is delayed using
// RegisterFlag. Parsing these here would result in a failure.
args = filterOutRegisteredFlags(subC, args)
args = filterOutRegisteredFlags(c, args)

if err := subC.ParseFlags(args); err != nil && !errors.Is(err, pflag.ErrHelp) {
if err := c.ParseFlags(args); err != nil && !errors.Is(err, pflag.ErrHelp) {
return err
}
}
Expand Down Expand Up @@ -351,11 +497,16 @@ func New(obj Runnable, cmd cobra.Command) (*cobra.Command, error) {
c.SilenceErrors = true
c.SilenceUsage = true
c.DisableFlagsInUseLine = true
c.RunE = obj.Run
c.InitDefaultHelpFlag()
c.InitDefaultCompletionCmd()

if obj != nil {
c.RunE = obj.Run

// Parse the attributes of this object into addressable flags for this command
if err := AttributeFlags(&c, obj); err != nil {
return nil, err
// Parse the attributes of this object into addressable flags for this command
if err := AttributeFlags(&c, obj); err != nil {
return nil, err
}
}

// Set help and usage methods
Expand Down Expand Up @@ -418,7 +569,7 @@ func assignMaps(app *cobra.Command, maps map[string]reflect.Value) error {
k = contextKey(k)
s, err := app.Flags().GetStringSlice(k)
if err != nil {
return err
continue
}
if s != nil {
values := map[string]string{}
Expand All @@ -441,7 +592,7 @@ func assignSlices(app *cobra.Command, slices map[string]reflect.Value) error {
k = contextKey(k)
s, err := app.Flags().GetStringSlice(k)
if err != nil {
return err
continue
}
a := app.Flags().Lookup(k)
if a.Changed && len(s) == 0 {
Expand All @@ -459,7 +610,7 @@ func assignArrays(app *cobra.Command, arrays map[string]reflect.Value) error {
k = contextKey(k)
s, err := app.Flags().GetStringArray(k)
if err != nil {
return err
continue
}
a := app.Flags().Lookup(k)
if a.Changed && len(s) == 0 {
Expand Down
Loading

0 comments on commit 8b08d25

Please sign in to comment.