better help tests and functionality

Signed-off-by: Jess Frazelle <acidburn@microsoft.com>
diff --git a/cli/cli.go b/cli/cli.go
index 25cf03d..f17233b 100644
--- a/cli/cli.go
+++ b/cli/cli.go
@@ -117,7 +117,7 @@
 	if args == nil ||
 		len(args) < 1 ||
 		(len(args) > 1 && contains([]string{"-h", "--help", "help"}, args[1])) {
-		return true, nil
+		return true, flag.ErrHelp
 	}
 
 	// If we do not have an action set and we have no commands, print the usage
@@ -168,6 +168,14 @@
 				return false, err
 			}
 
+			// Check that they didn't add a -h or --help flag after the subcommand's
+			// commands, like `cmd sub other thing -h`.
+			if contains([]string{"-h", "--help"}, args...) {
+				// Print the flag usage and exit.
+				p.FlagSet.Usage()
+				return false, flag.ErrHelp
+			}
+
 			if p.Before != nil {
 				if err := p.Before(ctx); err != nil {
 					return false, err
@@ -371,11 +379,14 @@
 	return false
 }
 
-func contains(match []string, s string) bool {
-	// Iterate over the items to match.
-	for _, m := range match {
-		if s == m {
-			return true
+func contains(match []string, a ...string) bool {
+	// Iterate over the items in the slice.
+	for _, s := range a {
+		// Iterate over the items to match.
+		for _, m := range match {
+			if s == m {
+				return true
+			}
 		}
 	}
 	return false
diff --git a/cli/cli_test.go b/cli/cli_test.go
index e101b62..8a360a1 100644
--- a/cli/cli_test.go
+++ b/cli/cli_test.go
@@ -13,7 +13,21 @@
 )
 
 const (
-	testHelp = `Show the test information.`
+	testCommandExpectedHelp = `Usage: yo test` + " " + `
+
+Show the test information.
+
+`
+	errorCommandExpectedHelp = `Usage: yo error` + " " + `
+
+Show the error information.
+
+`
+	versionCommandExpectedHelp = `Usage: yo version` + " " + `
+
+Show the version information.
+
+`
 )
 
 var (
@@ -39,6 +53,8 @@
 	shouldPrintUsage   bool
 	shouldPrintVersion bool
 	expectedErr        error
+	expectedStderr     string
+	expectedStdout     string
 }
 
 // Define the testCommand.
@@ -46,8 +62,8 @@
 
 func (cmd *testCommand) Name() string                                 { return "test" }
 func (cmd *testCommand) Args() string                                 { return "" }
-func (cmd *testCommand) ShortHelp() string                            { return testHelp }
-func (cmd *testCommand) LongHelp() string                             { return testHelp }
+func (cmd *testCommand) ShortHelp() string                            { return "Show the test information." }
+func (cmd *testCommand) LongHelp() string                             { return "Show the test information." }
 func (cmd *testCommand) Hidden() bool                                 { return false }
 func (cmd *testCommand) Register(fs *flag.FlagSet)                    {}
 func (cmd *testCommand) Run(ctx context.Context, args []string) error { return nil }
@@ -57,8 +73,8 @@
 
 func (cmd *errorCommand) Name() string                                 { return "error" }
 func (cmd *errorCommand) Args() string                                 { return "" }
-func (cmd *errorCommand) ShortHelp() string                            { return testHelp }
-func (cmd *errorCommand) LongHelp() string                             { return testHelp }
+func (cmd *errorCommand) ShortHelp() string                            { return "Show the error information." }
+func (cmd *errorCommand) LongHelp() string                             { return "Show the error information." }
 func (cmd *errorCommand) Hidden() bool                                 { return false }
 func (cmd *errorCommand) Register(fs *flag.FlagSet)                    {}
 func (cmd *errorCommand) Run(ctx context.Context, args []string) error { return errExpectedFromCommand }
@@ -68,11 +84,13 @@
 		{
 			description:      "nil",
 			shouldPrintUsage: true,
+			expectedErr:      flag.ErrHelp,
 		},
 		{
 			description:      "empty",
 			args:             []string{},
 			shouldPrintUsage: true,
+			expectedErr:      flag.ErrHelp,
 		},
 	}
 }
@@ -143,48 +161,88 @@
 func testCasesHelp() []testCase {
 	return []testCase{
 		{
-			description: "args: foo --help",
-			args:        []string{"foo", "--help"},
+			description:      "args: foo --help",
+			args:             []string{"foo", "--help"},
+			expectedErr:      flag.ErrHelp,
+			shouldPrintUsage: true,
 		},
 		{
-			description: "args: foo help",
-			args:        []string{"foo", "help"},
+			description:      "args: foo help",
+			args:             []string{"foo", "help"},
+			expectedErr:      flag.ErrHelp,
+			shouldPrintUsage: true,
 		},
 		{
-			description: "args: foo help bar --thing",
-			args:        []string{"foo", "help", "bar", "--thing"},
+			description:      "args: foo -h",
+			args:             []string{"foo", "-h"},
+			expectedErr:      flag.ErrHelp,
+			shouldPrintUsage: true,
 		},
 		{
-			description: "args: foo bar --help",
-			args:        []string{"foo", "bar", "--help"},
+			description:      "args: foo -h test foo",
+			args:             []string{"foo", "-h", "test", "foo"},
+			expectedErr:      flag.ErrHelp,
+			shouldPrintUsage: true,
 		},
 		{
-			description: "args: foo test --help",
-			args:        []string{"foo", "test", "--help"},
+			description:      "args: foo help bar --thing",
+			args:             []string{"foo", "help", "bar", "--thing"},
+			expectedErr:      flag.ErrHelp,
+			shouldPrintUsage: true,
 		},
 		{
-			description: "args: foo -h test foo",
-			args:        []string{"foo", "-h", "test", "foo", "--help"},
+			description:      "args: foo bar --help",
+			args:             []string{"foo", "bar", "--help"},
+			expectedErr:      errors.New("bar: no such command"),
+			shouldPrintUsage: true,
 		},
 		{
-			description: "args: foo error -h",
-			args:        []string{"foo", "error", "-h"},
+			description:    "args: foo test --help",
+			args:           []string{"foo", "test", "--help"},
+			expectedErr:    flag.ErrHelp,
+			expectedStderr: testCommandExpectedHelp,
 		},
 		{
-			description: "args: foo error foo --help",
-			args:        []string{"foo", "error", "foo", "--help"},
+			description:    "args: foo error -h",
+			args:           []string{"foo", "error", "-h"},
+			expectedErr:    flag.ErrHelp,
+			expectedStderr: errorCommandExpectedHelp,
 		},
 		{
-			description: "args: foo error foo bar --help",
-			args:        []string{"foo", "error", "foo", "bar", "--help"},
+			description:    "args: foo error foo --help",
+			args:           []string{"foo", "error", "foo", "--help"},
+			expectedErr:    flag.ErrHelp,
+			expectedStderr: errorCommandExpectedHelp,
 		},
 		{
-			description: "args: foo version --help",
-			args:        []string{"foo", "version", "--help"},
+			description:    "args: foo error foo bar --help",
+			args:           []string{"foo", "error", "foo", "bar", "--help"},
+			expectedErr:    flag.ErrHelp,
+			expectedStderr: errorCommandExpectedHelp,
 		},
 		{
-			description: "args: foo version -h",
-			args:        []string{"foo", "version", "-h"},
+			description:    "args: foo version --help",
+			args:           []string{"foo", "version", "--help"},
+			expectedErr:    flag.ErrHelp,
+			expectedStderr: versionCommandExpectedHelp,
+		},
+		{
+			description:    "args: foo version -h",
+			args:           []string{"foo", "version", "-h"},
+			expectedErr:    flag.ErrHelp,
+			expectedStderr: versionCommandExpectedHelp,
+		},
+		{
+			description:    "args: foo version --help another",
+			args:           []string{"foo", "version", "--help", "another"},
+			expectedErr:    flag.ErrHelp,
+			expectedStderr: versionCommandExpectedHelp,
+		},
+		{
+			description:    "args: foo version -h another",
+			args:           []string{"foo", "version", "-h", "another"},
+			expectedErr:    flag.ErrHelp,
+			expectedStderr: versionCommandExpectedHelp,
 		},
 	}
 }
@@ -221,7 +279,7 @@
 
 Commands:
 
-  error    Show the test information.
+  error    Show the error information.
   test     Show the test information.
   version  Show the version information.
 
@@ -315,6 +373,10 @@
 func TestProgramHelpFlag(t *testing.T) {
 	p := NewProgram()
 	p.FlagSet = flag.NewFlagSet("global", flag.ContinueOnError)
+	p.Commands = []Command{
+		&testCommand{},
+		&errorCommand{},
+	}
 	testCases := testCasesHelp()
 
 	for _, tc := range testCases {
@@ -322,17 +384,15 @@
 			c := startCapture(t)
 			printUsage, err := p.run(p.defaultContext(), tc.args)
 			stdout, stderr := c.finish()
-			if strings.Contains(stdout, versionExpected) {
-				t.Fatalf("did not expect version information to print, got %s", stdout)
+			compareErrors(t, err, tc.expectedErr)
+			if printUsage != tc.shouldPrintUsage {
+				t.Fatalf("expected printUsage to be %t, got: %t", tc.shouldPrintUsage, printUsage)
 			}
-			if err != nil {
-				t.Fatalf("expected no error from run, got %v", err)
+			if tc.expectedStderr != stderr {
+				t.Fatalf("expected stderr: %q\ngot: %q", tc.expectedStderr, stderr)
 			}
-			if !printUsage {
-				t.Fatal("expected printUsage to be true")
-			}
-			if len(stderr) > 0 {
-				t.Fatalf("expected no stderr, got: %s", stderr)
+			if tc.expectedStdout != stdout {
+				t.Fatalf("expected stdout: %q\ngot: %q", tc.expectedStdout, stdout)
 			}
 		})
 	}