diff --git a/README.md b/README.md index ce17281..414a9af 100644 --- a/README.md +++ b/README.md @@ -115,6 +115,31 @@ func applyCmd() *cli.Command { } ``` +Another `pflag.FlagSet` can be accessed using `*Command.PersistentFlags()`. Contrary to the +basic flags, flags set via the persistent flag set will be passed down to the children of the +command. + +```go +func applyCmd() *cli.Command { + cmd := &cli.Command{ + Use: "apply", + Short: "apply the changes" + } + force := cmd.PersistentFlags().BoolP("force", "f", false, "skip checks") + childCmd := &cli.Command{ + Use: "now" + Short: "do it now" + } + cmd.AddCommand(childCmd) + childCmd.Run = func(cmd *cli.Command, args []string) error { + fmt.Println("applied now", args[0]) + if *force { + fmt.Println("The force was with us.") + } + } +} +``` + ## Aliases To make the `apply` subcommand also available as `make` and `do`: diff --git a/children.go b/children.go index e5faf21..f985c72 100644 --- a/children.go +++ b/children.go @@ -1,12 +1,21 @@ package cli -import "fmt" +import ( + "fmt" + + "github.com/spf13/pflag" +) // AddCommand adds the supplied commands as subcommands. +// Persistent flags are passed down to the child. // This command is set as the parent of the new children. func (c *Command) AddCommand(children ...*Command) { for _, child := range children { child.parentPtr = c + if c.persistentFlags != nil { + child.persistentFlags = pflag.NewFlagSet(child.Name(), pflag.ContinueOnError) + child.PersistentFlags().AddFlagSet(c.persistentFlags) + } c.children = append(c.children, child) } } diff --git a/command.go b/command.go index a045673..f4843d0 100644 --- a/command.go +++ b/command.go @@ -43,9 +43,10 @@ type Command struct { Args Arguments // internal fields - children []*Command - flags *pflag.FlagSet - parentPtr *Command + children []*Command + flags *pflag.FlagSet + persistentFlags *pflag.FlagSet + parentPtr *Command } // Execute runs the application. It should be run on the most outer level @@ -90,6 +91,7 @@ func (c *Command) execute(args []string) error { } // parse flags + c.Flags().AddFlagSet(c.PersistentFlags()) if err := c.Flags().Parse(args); err != nil { return c.help(err) } diff --git a/flags.go b/flags.go index f14f162..93bfbe4 100644 --- a/flags.go +++ b/flags.go @@ -14,6 +14,15 @@ func (c *Command) Flags() *pflag.FlagSet { return c.flags } +// PersistentFlags returns the `*pflag.FlagSet` with the persistent flags of this command. +// Persistent flags are passed to subcommands. +func (c *Command) PersistentFlags() *pflag.FlagSet { + if c.persistentFlags == nil { + c.persistentFlags = pflag.NewFlagSet(c.Name(), pflag.ContinueOnError) + } + return c.persistentFlags +} + // stripFlags removes flags from the argument line, leaving only subcommands and // positional arguments. func stripFlags(args []string, c *Command) []string { diff --git a/flags_test.go b/flags_test.go index 2377c0c..c16e196 100644 --- a/flags_test.go +++ b/flags_test.go @@ -77,3 +77,17 @@ func TestStripFlags(t *testing.T) { } } } + +func TestPersistentFlags(t *testing.T) { + parent := &Command{} + parent.PersistentFlags().String("persistent", "", "") + parent.Flags().String("non-persistent", "", "") + child := &Command{} + parent.AddCommand(child) + if child.PersistentFlags().Lookup("persistent") == nil { + t.Error("expected persistent flag to be passed to child") + } + if child.Flags().Lookup("non-persistent") != nil { + t.Error("expected non-persistent flag to not be passed to child") + } +}