diff --git a/cmdfactory/builder.go b/cmdfactory/builder.go index 10ddd3be7..6d4570ac4 100644 --- a/cmdfactory/builder.go +++ b/cmdfactory/builder.go @@ -20,8 +20,6 @@ import ( "github.com/spf13/cobra" "github.com/spf13/pflag" - - "kraftkit.sh/internal/set" ) var ( @@ -93,8 +91,12 @@ func expandRegisteredFlags(cmd *cobra.Command) { continue } - for _, flag := range flags { - subCmd.Flags().AddFlag(flag) + if subCmd != nil && subCmd.Flags() != nil { + for _, flag := range flags { + if subCmd.Flags().Lookup(flag.Name) == nil { + subCmd.Flags().AddFlag(flag) + } + } } } } @@ -106,9 +108,9 @@ func filterOutRegisteredFlags(cmd *cobra.Command, args []string) (filteredArgs [ continue } - registeredFlagsNames := set.NewStringSet() + registeredFlagsNames := map[string]*pflag.Flag{} for _, flag := range flags { - registeredFlagsNames.Add(flag.Name) + registeredFlagsNames[flag.Name] = flag } for len(args) > 0 { @@ -125,8 +127,8 @@ func filterOutRegisteredFlags(cmd *cobra.Command, args []string) (filteredArgs [ subs := strings.SplitN(arg, "=", 2) flagName := strings.TrimPrefix(subs[0], "--") - if registeredFlagsNames.ContainsExactly(flagName) { - if len(subs) == 1 { + if flag, ok := registeredFlagsNames[flagName]; ok { + if flag.Value.Type() != "bool" && len(subs) == 1 { args = args[1:] } continue @@ -139,7 +141,7 @@ func filterOutRegisteredFlags(cmd *cobra.Command, args []string) (filteredArgs [ subs := strings.SplitN(arg, "=", 2) flagName := strings.TrimPrefix(subs[0], "-") - if registeredFlagsNames.ContainsExactly(flagName) { + if _, ok := registeredFlagsNames[flagName]; ok { if len(subs) == 1 { args = args[1:] } @@ -461,9 +463,9 @@ func AttributeFlags(c *cobra.Command, obj any, args ...string) error { // If any arguments are passed, parse them immediately if len(args) > 0 { - // Some kraft commands accept flags which registration is delayed using - // RegisterFlag. Parsing these here would result in a failure. - args = filterOutRegisteredFlags(c, args) + // Expand all registered flags pre-emptively such that they can be correctly + // parsed. + expandRegisteredFlags(c) if err := c.ParseFlags(args); err != nil && !errors.Is(err, pflag.ErrHelp) { return err