cleanup printing usage

Signed-off-by: Jess Frazelle <acidburn@microsoft.com>
diff --git a/cli/cli.go b/cli/cli.go
index b29c3f4..fc85e06 100644
--- a/cli/cli.go
+++ b/cli/cli.go
@@ -84,46 +84,56 @@
 	ctx := p.defaultContext()
 
 	// Pass the os.Args through so we can more easily unit test.
-	printUsage, err := p.run(ctx, os.Args)
-	if err == nil && !printUsage {
+	err := p.run(ctx, os.Args)
+	if err == nil {
+		// Return early if there was no error.
 		return
 	}
 
-	if err != nil {
+	if err != flag.ErrHelp {
+		// We did not return the error to print the usage, so let's print the
+		// error and exit.
 		fmt.Fprintln(os.Stderr, err.Error())
+		os.Exit(1)
 	}
-	if printUsage {
-		if err != nil {
-			// Print an extra new line to seperate from the usage output.
-			fmt.Fprintln(os.Stderr)
-		}
-		p.usage(ctx)
-	}
+
+	// Print the usage.
+	p.FlagSet.Usage()
 	os.Exit(1)
 }
 
-func (p *Program) run(ctx context.Context, args []string) (bool, error) {
+func (p *Program) run(ctx context.Context, args []string) error {
 	// Append the version command to the list of commands by default.
 	p.Commands = append(p.Commands, &versionCommand{})
 
+	// Set the default flagset if our flagset is undefined.
+	if p.FlagSet == nil {
+		p.FlagSet = defaultFlagSet(p.Name)
+	}
+
+	// Override the usage text to something nicer.
+	p.FlagSet.Usage = func() {
+		p.usage(ctx)
+	}
+
 	// IF
 	// args is <nil>
 	// OR
-	// args is less than zero
+	// args is less than 1
 	// OR
 	// we have more than one arg and it equals help OR is a help flag
 	// THEN
-	// printUsage
+	// print the usage
 	if args == nil ||
 		len(args) < 1 ||
 		(len(args) > 1 && contains([]string{"-h", "--help", "help"}, args[1])) {
-		return true, flag.ErrHelp
+		return flag.ErrHelp
 	}
 
 	// If we do not have an action set and we have no commands, print the usage
 	// and exit.
 	if p.Action == nil && len(p.Commands) < 2 {
-		return true, nil
+		return flag.ErrHelp
 	}
 
 	// Check if the command exists.
@@ -143,20 +153,15 @@
 	// Return early if we didn't enter the single action logic and
 	// the command does not exist or we were passed no commands.
 	if len(args) < 2 {
-		return true, nil
+		return flag.ErrHelp
 	}
 	if !commandExists {
-		return true, fmt.Errorf("%s: no such command", args[1])
+		return fmt.Errorf("%s: no such command", args[1])
 	}
 
 	// Iterate over the commands in the program.
 	for _, command := range p.Commands {
 		if args[1] == command.Name() {
-			// Set the default flagset if our flagset is undefined.
-			if p.FlagSet == nil {
-				p.FlagSet = defaultFlagSet(p.Name)
-			}
-
 			// Register the subcommand flags in with the common/global flags.
 			command.Register(p.FlagSet)
 
@@ -165,22 +170,21 @@
 
 			// Parse the flags the user gave us.
 			if err := p.FlagSet.Parse(args[2:]); err != nil {
-				return false, err
+				return err
 			}
 
 			// Check that they didn't add a -h or --help flag after the subcommand's
 			// commands, like `cmd sub other thing -h`.
 			if contains([]string{"-h", "--help"}, args...) {
 				// Print the flag usage and exit.
-				p.FlagSet.Usage()
-				return false, flag.ErrHelp
+				return flag.ErrHelp
 			}
 
 			// Only execute the Before function for user-supplied commands.
 			// This excludes the version command we supply.
 			if p.Before != nil && command.Name() != "version" {
 				if err := p.Before(ctx); err != nil {
-					return false, err
+					return err
 				}
 			}
 
@@ -190,43 +194,33 @@
 					p.After(ctx)
 				}
 
-				return false, err
+				return err
 			}
 
 			// Run the after function.
 			if p.After != nil {
 				if err := p.After(ctx); err != nil {
-					return false, err
+					return err
 				}
 			}
 		}
 	}
 
 	// Done.
-	return false, nil
+	return nil
 }
 
-func (p *Program) runAction(ctx context.Context, args []string) (bool, error) {
-	// Set the default flagset if our flagset is undefined.
-	if p.FlagSet == nil {
-		p.FlagSet = defaultFlagSet(p.Name)
-	}
-
-	// Override the usage text to something nicer.
-	p.FlagSet.Usage = func() {
-		p.usage(ctx)
-	}
-
+func (p *Program) runAction(ctx context.Context, args []string) error {
 	// Parse the flags the user gave us.
 	if err := p.FlagSet.Parse(args[1:]); err != nil {
-		return true, nil
+		return flag.ErrHelp
 	}
 
 	// Run the main action _if_ we are not in the loop for the version command
 	// that is added by default.
 	if p.Before != nil {
 		if err := p.Before(ctx); err != nil {
-			return false, err
+			return err
 		}
 	}
 
@@ -237,18 +231,18 @@
 			p.After(ctx)
 		}
 
-		return false, err
+		return err
 	}
 
 	// Run the after function.
 	if p.After != nil {
 		if err := p.After(ctx); err != nil {
-			return false, err
+			return err
 		}
 	}
 
 	// Done.
-	return false, nil
+	return nil
 }
 
 func (p *Program) usage(ctx context.Context) error {
diff --git a/cli/cli_test.go b/cli/cli_test.go
index cbd9752..9f58868 100644
--- a/cli/cli_test.go
+++ b/cli/cli_test.go
@@ -85,187 +85,6 @@
 func (cmd *errorCommand) Register(fs *flag.FlagSet)                    {}
 func (cmd *errorCommand) Run(ctx context.Context, args []string) error { return errExpectedFromCommand }
 
-func testCasesEmpty() []testCase {
-	return []testCase{
-		{
-			description:      "nil",
-			shouldPrintUsage: true,
-			expectedErr:      flag.ErrHelp,
-		},
-		{
-			description:      "empty",
-			args:             []string{},
-			shouldPrintUsage: true,
-			expectedErr:      flag.ErrHelp,
-		},
-	}
-}
-
-func testCasesUndefinedCommand() []testCase {
-	return []testCase{
-		{
-			description:      "args: foo",
-			args:             []string{"foo"},
-			shouldPrintUsage: true,
-		},
-		{
-			description:      "args: foo bar",
-			args:             []string{"foo", "bar"},
-			shouldPrintUsage: true,
-			expectedErr:      errors.New("bar: no such command"),
-		},
-	}
-}
-
-func testCasesWithCommands() []testCase {
-	return []testCase{
-		{
-			description: "args: foo test",
-			args:        []string{"foo", "test"},
-		},
-		{
-			description: "args: foo test foo",
-			args:        []string{"foo", "test", "foo"},
-		},
-		{
-			description: "args: foo test foo bar",
-			args:        []string{"foo", "test", "foo", "bar"},
-		},
-		{
-			description: "args: foo error",
-			args:        []string{"foo", "error"},
-			expectedErr: errExpectedFromCommand,
-		},
-		{
-			description: "args: foo error foo",
-			args:        []string{"foo", "error", "foo"},
-			expectedErr: errExpectedFromCommand,
-		},
-		{
-			description: "args: foo error foo bar",
-			args:        []string{"foo", "error", "foo", "bar"},
-			expectedErr: errExpectedFromCommand,
-		},
-		{
-			description:    "args: foo version",
-			args:           []string{"foo", "version"},
-			expectedStdout: versionCommandExpectedStdout,
-		},
-		{
-			description:    "args: foo version foo",
-			args:           []string{"foo", "version", "foo"},
-			expectedStdout: versionCommandExpectedStdout,
-		},
-		{
-			description:    "args: foo version foo bar",
-			args:           []string{"foo", "version", "foo", "bar"},
-			expectedStdout: versionCommandExpectedStdout,
-		},
-	}
-}
-
-func testCasesHelp() []testCase {
-	return []testCase{
-		{
-			description:      "args: foo --help",
-			args:             []string{"foo", "--help"},
-			expectedErr:      flag.ErrHelp,
-			shouldPrintUsage: true,
-		},
-		{
-			description:      "args: foo help",
-			args:             []string{"foo", "help"},
-			expectedErr:      flag.ErrHelp,
-			shouldPrintUsage: true,
-		},
-		{
-			description:      "args: foo -h",
-			args:             []string{"foo", "-h"},
-			expectedErr:      flag.ErrHelp,
-			shouldPrintUsage: true,
-		},
-		{
-			description:      "args: foo -h test foo",
-			args:             []string{"foo", "-h", "test", "foo"},
-			expectedErr:      flag.ErrHelp,
-			shouldPrintUsage: true,
-		},
-		{
-			description:      "args: foo help bar --thing",
-			args:             []string{"foo", "help", "bar", "--thing"},
-			expectedErr:      flag.ErrHelp,
-			shouldPrintUsage: true,
-		},
-		{
-			description:      "args: foo bar --help",
-			args:             []string{"foo", "bar", "--help"},
-			expectedErr:      errors.New("bar: no such command"),
-			shouldPrintUsage: true,
-		},
-		{
-			description:    "args: foo test --help",
-			args:           []string{"foo", "test", "--help"},
-			expectedErr:    flag.ErrHelp,
-			expectedStderr: testCommandExpectedHelp,
-		},
-		{
-			description:    "args: foo error -h",
-			args:           []string{"foo", "error", "-h"},
-			expectedErr:    flag.ErrHelp,
-			expectedStderr: errorCommandExpectedHelp,
-		},
-		{
-			description:    "args: foo error foo --help",
-			args:           []string{"foo", "error", "foo", "--help"},
-			expectedErr:    flag.ErrHelp,
-			expectedStderr: errorCommandExpectedHelp,
-		},
-		{
-			description:    "args: foo error foo bar --help",
-			args:           []string{"foo", "error", "foo", "bar", "--help"},
-			expectedErr:    flag.ErrHelp,
-			expectedStderr: errorCommandExpectedHelp,
-		},
-		{
-			description:    "args: foo version --help",
-			args:           []string{"foo", "version", "--help"},
-			expectedErr:    flag.ErrHelp,
-			expectedStderr: versionCommandExpectedHelp,
-		},
-		{
-			description:    "args: foo version -h",
-			args:           []string{"foo", "version", "-h"},
-			expectedErr:    flag.ErrHelp,
-			expectedStderr: versionCommandExpectedHelp,
-		},
-		{
-			description:    "args: foo version --help another",
-			args:           []string{"foo", "version", "--help", "another"},
-			expectedErr:    flag.ErrHelp,
-			expectedStderr: versionCommandExpectedHelp,
-		},
-		{
-			description:    "args: foo version -h another",
-			args:           []string{"foo", "version", "-h", "another"},
-			expectedErr:    flag.ErrHelp,
-			expectedStderr: versionCommandExpectedHelp,
-		},
-	}
-}
-
-func testCasesWithAction() []testCase {
-	return []testCase{
-		{
-			description: "args: foo",
-			args:        []string{"foo"},
-		},
-		{
-			description: "args: foo bar",
-			args:        []string{"foo", "bar"},
-		},
-	}
-}
-
 func TestProgramUsage(t *testing.T) {
 	var (
 		debug  bool
@@ -388,12 +207,9 @@
 	for _, tc := range testCases {
 		t.Run(tc.description, func(t *testing.T) {
 			c := startCapture(t)
-			printUsage, err := p.run(p.defaultContext(), tc.args)
+			err := p.run(p.defaultContext(), tc.args)
 			stdout, stderr := c.finish()
 			compareErrors(t, err, tc.expectedErr)
-			if printUsage != tc.shouldPrintUsage {
-				t.Fatalf("expected printUsage to be %t, got: %t", tc.shouldPrintUsage, printUsage)
-			}
 			if tc.expectedStderr != stderr {
 				t.Fatalf("expected stderr: %q\ngot: %q", tc.expectedStderr, stderr)
 			}
@@ -587,9 +403,15 @@
 	return p.After != nil && p.After(context.Background()) != nil
 }
 
+func (tc *testCase) expectUsageToBePrintedBeforeBefore(p *Program) bool {
+	return tc.args == nil || len(tc.args) < 1 ||
+		(p.Action == nil && len(p.Commands) > 1 && len(tc.args) < 2) ||
+		(tc.expectedErr != nil && strings.Contains(tc.expectedErr.Error(), "no such command"))
+}
+
 func (p *Program) doTestRun(t *testing.T, tc testCase) {
 	c := startCapture(t)
-	printUsage, err := p.run(p.defaultContext(), tc.args)
+	err := p.run(p.defaultContext(), tc.args)
 	stdout, stderr := c.finish()
 	if len(stderr) > 0 {
 		t.Fatalf("expected no stderr, got: %s", stderr)
@@ -609,7 +431,7 @@
 	// THEN
 	// check we got the expected error defined in the testcase.
 	if (!p.isErrorOnAfter() && !p.isErrorOnBefore()) ||
-		tc.shouldPrintUsage ||
+		tc.expectUsageToBePrintedBeforeBefore(p) ||
 		(!p.isErrorOnBefore() && p.isErrorOnAfter() && tc.expectedErr == errExpectedFromCommand) {
 		compareErrors(t, err, tc.expectedErr)
 	}
@@ -624,11 +446,7 @@
 	// check we got the expected error from Before/After.
 	if (p.isErrorOnBefore() ||
 		(p.isErrorOnAfter() && tc.expectedErr != errExpectedFromCommand)) &&
-		!tc.shouldPrintUsage {
+		!tc.expectUsageToBePrintedBeforeBefore(p) {
 		compareErrors(t, err, errExpected)
 	}
-
-	if printUsage != tc.shouldPrintUsage {
-		t.Fatalf("expected printUsage to be %t got %t", tc.shouldPrintUsage, printUsage)
-	}
 }
diff --git a/cli/testcases_test.go b/cli/testcases_test.go
new file mode 100644
index 0000000..6939674
--- /dev/null
+++ b/cli/testcases_test.go
@@ -0,0 +1,176 @@
+package cli
+
+import (
+	"errors"
+	"flag"
+)
+
+func testCasesEmpty() []testCase {
+	return []testCase{
+		{
+			description: "nil",
+			expectedErr: flag.ErrHelp,
+		},
+		{
+			description: "empty",
+			args:        []string{},
+			expectedErr: flag.ErrHelp,
+		},
+	}
+}
+
+func testCasesUndefinedCommand() []testCase {
+	return []testCase{
+		{
+			description: "args: foo",
+			args:        []string{"foo"},
+			expectedErr: flag.ErrHelp,
+		},
+		{
+			description: "args: foo bar",
+			args:        []string{"foo", "bar"},
+			expectedErr: errors.New("bar: no such command"),
+		},
+	}
+}
+
+func testCasesWithCommands() []testCase {
+	return []testCase{
+		{
+			description: "args: foo test",
+			args:        []string{"foo", "test"},
+		},
+		{
+			description: "args: foo test foo",
+			args:        []string{"foo", "test", "foo"},
+		},
+		{
+			description: "args: foo test foo bar",
+			args:        []string{"foo", "test", "foo", "bar"},
+		},
+		{
+			description: "args: foo error",
+			args:        []string{"foo", "error"},
+			expectedErr: errExpectedFromCommand,
+		},
+		{
+			description: "args: foo error foo",
+			args:        []string{"foo", "error", "foo"},
+			expectedErr: errExpectedFromCommand,
+		},
+		{
+			description: "args: foo error foo bar",
+			args:        []string{"foo", "error", "foo", "bar"},
+			expectedErr: errExpectedFromCommand,
+		},
+		{
+			description:    "args: foo version",
+			args:           []string{"foo", "version"},
+			expectedStdout: versionCommandExpectedStdout,
+		},
+		{
+			description:    "args: foo version foo",
+			args:           []string{"foo", "version", "foo"},
+			expectedStdout: versionCommandExpectedStdout,
+		},
+		{
+			description:    "args: foo version foo bar",
+			args:           []string{"foo", "version", "foo", "bar"},
+			expectedStdout: versionCommandExpectedStdout,
+		},
+	}
+}
+
+func testCasesHelp() []testCase {
+	return []testCase{
+		{
+			description: "args: foo --help",
+			args:        []string{"foo", "--help"},
+			expectedErr: flag.ErrHelp,
+		},
+		{
+			description: "args: foo help",
+			args:        []string{"foo", "help"},
+			expectedErr: flag.ErrHelp,
+		},
+		{
+			description: "args: foo -h",
+			args:        []string{"foo", "-h"},
+			expectedErr: flag.ErrHelp,
+		},
+		{
+			description: "args: foo -h test foo",
+			args:        []string{"foo", "-h", "test", "foo"},
+			expectedErr: flag.ErrHelp,
+		},
+		{
+			description: "args: foo help bar --thing",
+			args:        []string{"foo", "help", "bar", "--thing"},
+			expectedErr: flag.ErrHelp,
+		},
+		{
+			description: "args: foo bar --help",
+			args:        []string{"foo", "bar", "--help"},
+			expectedErr: errors.New("bar: no such command"),
+		},
+		{
+			description:    "args: foo test --help",
+			args:           []string{"foo", "test", "--help"},
+			expectedErr:    flag.ErrHelp,
+			expectedStderr: testCommandExpectedHelp,
+		},
+		{
+			description:    "args: foo error -h",
+			args:           []string{"foo", "error", "-h"},
+			expectedErr:    flag.ErrHelp,
+			expectedStderr: errorCommandExpectedHelp,
+		},
+		{
+			description: "args: foo error foo --help",
+			args:        []string{"foo", "error", "foo", "--help"},
+			expectedErr: flag.ErrHelp,
+		},
+		{
+			description: "args: foo error foo bar --help",
+			args:        []string{"foo", "error", "foo", "bar", "--help"},
+			expectedErr: flag.ErrHelp,
+		},
+		{
+			description:    "args: foo version --help",
+			args:           []string{"foo", "version", "--help"},
+			expectedErr:    flag.ErrHelp,
+			expectedStderr: versionCommandExpectedHelp,
+		},
+		{
+			description:    "args: foo version -h",
+			args:           []string{"foo", "version", "-h"},
+			expectedErr:    flag.ErrHelp,
+			expectedStderr: versionCommandExpectedHelp,
+		},
+		{
+			description:    "args: foo version --help another",
+			args:           []string{"foo", "version", "--help", "another"},
+			expectedErr:    flag.ErrHelp,
+			expectedStderr: versionCommandExpectedHelp,
+		},
+		{
+			description:    "args: foo version -h another",
+			args:           []string{"foo", "version", "-h", "another"},
+			expectedErr:    flag.ErrHelp,
+			expectedStderr: versionCommandExpectedHelp,
+		},
+	}
+}
+
+func testCasesWithAction() []testCase {
+	return []testCase{
+		{
+			description: "args: foo",
+			args:        []string{"foo"},
+		},
+		{
+			description: "args: foo bar",
+			args:        []string{"foo", "bar"},
+		},
+	}
+}