add 50% test coverage

Signed-off-by: Jess Frazelle <acidburn@microsoft.com>
diff --git a/Makefile b/Makefile
index 2b81050..83ad90c 100644
--- a/Makefile
+++ b/Makefile
@@ -58,7 +58,7 @@
 .PHONY: test
 test: ## Runs the go tests
 	@echo "+ $@"
-	@$(GO) test -v -tags "$(BUILDTAGS) cgo" $(shell $(GO) list ./... | grep -v vendor)
+	@$(GO) test -tags "$(BUILDTAGS) cgo" $(shell $(GO) list ./... | grep -v vendor)
 
 .PHONY: vet
 vet: ## Verifies `go vet` passes
diff --git a/cli/cli.go b/cli/cli.go
index 3e1802f..a9a5ce7 100644
--- a/cli/cli.go
+++ b/cli/cli.go
@@ -47,7 +47,8 @@
 	After func(context.Context) error
 
 	// Action is the function to execute when no subcommands are specified.
-	Action func(context.Context) error
+	// It gives the user back the arguments after the flags have been parsed.
+	Action func(context.Context, []string) error
 }
 
 // Command defines the interface for each command in a program.
@@ -83,63 +84,69 @@
 	ctx := context.WithValue(context.Background(), GitCommitKey, p.GitCommit)
 	ctx = context.WithValue(ctx, VersionKey, p.Version)
 
+	// Pass the os.Args through so we can more easily unit test.
+	printUsage, err := p.run(ctx, os.Args)
+	if err == nil && !printUsage {
+		return
+	}
+
+	if err != nil {
+		fmt.Fprintln(os.Stderr, err.Error())
+	}
+	if printUsage {
+		if err != nil {
+			// Print an extra new line to seperate from the usage output.
+			fmt.Fprintln(os.Stderr)
+		}
+		p.usage(ctx)
+	}
+	os.Exit(1)
+}
+
+func (p *Program) run(ctx context.Context, args []string) (bool, error) {
+	// TODO(jessfraz): Find a better way to tell that they passed -h through as a flag.
+	if len(args) > 1 &&
+		(strings.Contains(strings.ToLower(args[1]), "help") ||
+			strings.ToLower(args[1]) == "-h") ||
+		args == nil || len(args) < 1 {
+		return true, nil
+	}
+
+	// If we do not have an action set and we have no commands, print the usage
+	// and exit.
+	if p.Action == nil && len(p.Commands) < 1 {
+		return true, nil
+	}
+
 	// Append the version command to the list of commands by default.
 	p.Commands = append(p.Commands, &versionCommand{})
 
-	// TODO(jessfraz): Find a better way to tell that they passed -h through as a flag.
-	if len(os.Args) > 1 &&
-		(strings.Contains(strings.ToLower(os.Args[1]), "help") ||
-			strings.ToLower(os.Args[1]) == "-h") {
-		p.usage(ctx)
-		os.Exit(1)
-	}
-
-	// Set the default action to print the usage if it is undefined.
-	if p.Action == nil {
-		p.Action = p.usage
+	// Check if the command exists.
+	var commandExists bool
+	if len(args) > 1 && in(args[1], p.Commands) {
+		commandExists = true
 	}
 
 	// If we are not running a commands we know, then automatically
 	// run the main action of the program instead.
 	// Also enter this loop if we weren't passed any arguments.
-	if len(os.Args) < 2 || !in(os.Args[1], p.Commands) {
-		// Set the default flagset if our flagset is undefined.
-		if p.FlagSet == nil {
-			p.FlagSet = defaultFlagSet(p.Name)
-		}
+	if p.Action != nil &&
+		(len(args) < 2 || !commandExists) {
+		return p.runAction(ctx, args)
+	}
 
-		// Override the usage text to something nicer.
-		p.FlagSet.Usage = func() {
-			p.usage(ctx)
-		}
-
-		// Parse the flags the user gave us.
-		if err := p.FlagSet.Parse(os.Args[1:]); err != nil {
-			p.usage(ctx)
-			os.Exit(1)
-		}
-
-		// Run the main action _if_ we are not in the loop for the version command
-		// that is added by default.
-		if p.Before != nil {
-			if err := p.Before(ctx); err != nil {
-				fmt.Fprintf(os.Stderr, "%v\n", err)
-				os.Exit(1)
-			}
-		}
-
-		if err := p.Action(ctx); err != nil {
-			fmt.Fprintf(os.Stderr, "%v\n", err)
-			os.Exit(1)
-		}
-
-		// Done.
-		return
+	// Return early if we didn't enter the single action logic and
+	// the command does not exist or we were passed no commands.
+	if len(args) < 2 {
+		return true, nil
+	}
+	if !commandExists {
+		return true, fmt.Errorf("%s: no such command", args[1])
 	}
 
 	// Iterate over the commands in the program.
 	for _, command := range p.Commands {
-		if os.Args[1] == command.Name() {
+		if args[1] == command.Name() {
 			// Set the default flagset if our flagset is undefined.
 			if p.FlagSet == nil {
 				p.FlagSet = defaultFlagSet(p.Name)
@@ -152,47 +159,81 @@
 			p.resetCommandUsage(command)
 
 			// Parse the flags the user gave us.
-			if err := p.FlagSet.Parse(os.Args[2:]); err != nil {
-				p.FlagSet.Usage()
-				os.Exit(1)
+			if err := p.FlagSet.Parse(args[2:]); err != nil {
+				return false, err
 			}
 
 			if p.Before != nil {
 				if err := p.Before(ctx); err != nil {
-					fmt.Fprintf(os.Stderr, "%v\n", err)
-					os.Exit(1)
+					return false, err
 				}
 			}
 
 			// Run the command with the context and post-flag-processing args.
 			if err := command.Run(ctx, p.FlagSet.Args()); err != nil {
-				fmt.Fprintf(os.Stderr, "%v\n", err)
-
 				if p.After != nil {
-					if err := p.After(ctx); err != nil {
-						fmt.Fprintf(os.Stderr, "%v\n", err)
-					}
+					p.After(ctx)
 				}
 
-				os.Exit(1)
+				return false, err
 			}
 
 			// Run the after function.
 			if p.After != nil {
 				if err := p.After(ctx); err != nil {
-					fmt.Fprintf(os.Stderr, "%v\n", err)
-					os.Exit(1)
+					return false, err
 				}
 			}
-
-			// Done.
-			return
 		}
 	}
 
-	fmt.Fprintf(os.Stderr, "%s: no such command\n\n", os.Args[1])
-	p.usage(ctx)
-	os.Exit(1)
+	// Done.
+	return false, nil
+}
+
+func (p *Program) runAction(ctx context.Context, args []string) (bool, error) {
+	// Set the default flagset if our flagset is undefined.
+	if p.FlagSet == nil {
+		p.FlagSet = defaultFlagSet(p.Name)
+	}
+
+	// Override the usage text to something nicer.
+	p.FlagSet.Usage = func() {
+		p.usage(ctx)
+	}
+
+	// Parse the flags the user gave us.
+	if err := p.FlagSet.Parse(args[1:]); err != nil {
+		return true, nil
+	}
+
+	// Run the main action _if_ we are not in the loop for the version command
+	// that is added by default.
+	if p.Before != nil {
+		if err := p.Before(ctx); err != nil {
+			return false, err
+		}
+	}
+
+	// Run the action with the context and post-flag-processing args.
+	if err := p.Action(ctx, p.FlagSet.Args()); err != nil {
+		// Run the after function.
+		if p.After != nil {
+			p.After(ctx)
+		}
+
+		return false, err
+	}
+
+	// Run the after function.
+	if p.After != nil {
+		if err := p.After(ctx); err != nil {
+			return false, err
+		}
+	}
+
+	// Done.
+	return false, nil
 }
 
 func (p *Program) usage(ctx context.Context) error {
diff --git a/cli/cli_test.go b/cli/cli_test.go
new file mode 100644
index 0000000..ef7f040
--- /dev/null
+++ b/cli/cli_test.go
@@ -0,0 +1,493 @@
+package cli
+
+import (
+	"context"
+	"errors"
+	"flag"
+	"fmt"
+	"testing"
+)
+
+const (
+	testHelp = `Show the test information.`
+)
+
+var (
+	nilFunction = func(ctx context.Context) error {
+		return nil
+	}
+	nilActionFunction = func(ctx context.Context, args []string) error {
+		return nil
+	}
+
+	errExpected            = errors.New("expected error")
+	errExpectedFromCommand = errors.New("expected error command error")
+	errFunction            = func(ctx context.Context) error {
+		return errExpected
+	}
+)
+
+func (cmd *testCommand) Name() string      { return "test" }
+func (cmd *testCommand) Args() string      { return "" }
+func (cmd *testCommand) ShortHelp() string { return testHelp }
+func (cmd *testCommand) LongHelp() string  { return testHelp }
+func (cmd *testCommand) Hidden() bool      { return false }
+
+func (cmd *testCommand) Register(fs *flag.FlagSet) {}
+
+type testCommand struct{}
+
+func (cmd *testCommand) Run(ctx context.Context, args []string) error {
+	return nil
+}
+
+func (cmd *errorCommand) Name() string      { return "error" }
+func (cmd *errorCommand) Args() string      { return "" }
+func (cmd *errorCommand) ShortHelp() string { return testHelp }
+func (cmd *errorCommand) LongHelp() string  { return testHelp }
+func (cmd *errorCommand) Hidden() bool      { return false }
+
+func (cmd *errorCommand) Register(fs *flag.FlagSet) {}
+
+type errorCommand struct{}
+
+func (cmd *errorCommand) Run(ctx context.Context, args []string) error {
+	return errExpectedFromCommand
+}
+
+func TestProgramWithNoCommandsOrFlagsOrAction(t *testing.T) {
+	p := NewProgram()
+	testCases := []struct {
+		description string
+		args        []string
+	}{
+		{
+			description: "nil",
+		},
+		{
+			description: "empty",
+			args:        []string{},
+		},
+		{
+			description: "args: foo",
+			args:        []string{"foo"},
+		},
+		{
+			description: "args: foo bar",
+			args:        []string{"foo", "bar"},
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.description, func(t *testing.T) {
+			printUsage, err := p.run(context.Background(), tc.args)
+			if err != nil {
+				t.Fatalf("expected no error, got: %v", err)
+			}
+
+			if !printUsage {
+				t.Fatal("expected behavior was to print the usage")
+			}
+		})
+	}
+}
+
+func TestProgramWithNoCommandsOrFlags(t *testing.T) {
+	p := NewProgram()
+	p.Action = nilActionFunction
+	testCases := []struct {
+		description      string
+		args             []string
+		shouldPrintUsage bool
+	}{
+		{
+			description:      "nil",
+			shouldPrintUsage: true,
+		},
+		{
+			description:      "empty",
+			args:             []string{},
+			shouldPrintUsage: true,
+		},
+		{
+			description:      "args: foo",
+			args:             []string{"foo"},
+			shouldPrintUsage: false,
+		},
+		{
+			description:      "args: foo bar",
+			args:             []string{"foo", "bar"},
+			shouldPrintUsage: false,
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.description, func(t *testing.T) {
+			printUsage, err := p.run(context.Background(), tc.args)
+			if err != nil {
+				t.Fatalf("expected no error, got: %v", err)
+			}
+
+			if printUsage != tc.shouldPrintUsage {
+				t.Fatalf("expected printUsage to be %t got %t", tc.shouldPrintUsage, printUsage)
+			}
+		})
+	}
+}
+
+func TestProgramWithCommandsAndAction(t *testing.T) {
+	p := NewProgram()
+	p.Commands = []Command{
+		&errorCommand{},
+		&testCommand{},
+	}
+	p.Action = nilActionFunction
+	testCases := []struct {
+		description      string
+		args             []string
+		shouldPrintUsage bool
+		expectedErr      error
+	}{
+		{
+			description:      "nil",
+			shouldPrintUsage: true,
+		},
+		{
+			description:      "empty",
+			args:             []string{},
+			shouldPrintUsage: true,
+		},
+		{
+			description:      "args: foo",
+			args:             []string{"foo"},
+			shouldPrintUsage: false,
+		},
+		{
+			description:      "args: foo bar",
+			args:             []string{"foo", "bar"},
+			shouldPrintUsage: false,
+		},
+		{
+			description:      "args: foo test",
+			args:             []string{"foo", "test"},
+			shouldPrintUsage: false,
+		},
+		{
+			description:      "args: foo test foo",
+			args:             []string{"foo", "test", "foo"},
+			shouldPrintUsage: false,
+		},
+		{
+			description:      "args: foo test foo bar",
+			args:             []string{"foo", "test", "foo", "bar"},
+			shouldPrintUsage: false,
+		},
+		{
+			description:      "args: foo error",
+			args:             []string{"foo", "error"},
+			shouldPrintUsage: false,
+			expectedErr:      errExpectedFromCommand,
+		},
+		{
+			description:      "args: foo error foo",
+			args:             []string{"foo", "error", "foo"},
+			shouldPrintUsage: false,
+			expectedErr:      errExpectedFromCommand,
+		},
+		{
+			description:      "args: foo error foo bar",
+			args:             []string{"foo", "error", "foo", "bar"},
+			shouldPrintUsage: false,
+			expectedErr:      errExpectedFromCommand,
+		},
+		{
+			description:      "args: foo version",
+			args:             []string{"foo", "version"},
+			shouldPrintUsage: false,
+		},
+		{
+			description:      "args: foo version foo",
+			args:             []string{"foo", "version", "foo"},
+			shouldPrintUsage: false,
+		},
+		{
+			description:      "args: foo version foo bar",
+			args:             []string{"foo", "version", "foo", "bar"},
+			shouldPrintUsage: false,
+		},
+		/*{
+			description:      "args: foo version --help",
+			args:             []string{"foo", "version", "--help"},
+			shouldPrintUsage: true,
+		},*/
+	}
+
+	// Create the context with the values we need to pass to the version command.
+	ctx := context.WithValue(context.Background(), GitCommitKey, p.GitCommit)
+	ctx = context.WithValue(ctx, VersionKey, p.Version)
+
+	for _, tc := range testCases {
+		t.Run(tc.description, func(t *testing.T) {
+			printUsage, err := p.run(ctx, tc.args)
+			compareErrors(t, err, tc.expectedErr)
+
+			if printUsage != tc.shouldPrintUsage {
+				t.Fatalf("expected printUsage to be %t got %t", tc.shouldPrintUsage, printUsage)
+			}
+		})
+	}
+
+	// Add a Before.
+	p.Before = nilFunction
+	for _, tc := range testCases {
+		t.Run(fmt.Sprintf("with Successful Before -> %s", tc.description), func(t *testing.T) {
+			printUsage, err := p.run(ctx, tc.args)
+			compareErrors(t, err, tc.expectedErr)
+
+			if printUsage != tc.shouldPrintUsage {
+				t.Fatalf("expected printUsage to be %t got %t", tc.shouldPrintUsage, printUsage)
+			}
+		})
+	}
+
+	// Add an After.
+	p.After = nilFunction
+	for _, tc := range testCases {
+		t.Run(fmt.Sprintf("with successful After -> %s", tc.description), func(t *testing.T) {
+			printUsage, err := p.run(ctx, tc.args)
+			compareErrors(t, err, tc.expectedErr)
+
+			if printUsage != tc.shouldPrintUsage {
+				t.Fatalf("expected printUsage to be %t got %t", tc.shouldPrintUsage, printUsage)
+			}
+		})
+	}
+
+	// Test program with an error on After.
+	p.After = errFunction
+	for _, tc := range testCases {
+		t.Run(fmt.Sprintf("with error on After -> %s", tc.description), func(t *testing.T) {
+			printUsage, err := p.run(ctx, tc.args)
+			// When we print the usage for nil and empty, we never hit
+			// the After function.
+			if !tc.shouldPrintUsage {
+				// If we are at the point where the command should fail, we should
+				// expect that error.
+				if tc.expectedErr == errExpectedFromCommand {
+					compareErrors(t, err, errExpectedFromCommand)
+				} else {
+					compareErrors(t, err, errExpected)
+				}
+			} else {
+				compareErrors(t, err, tc.expectedErr)
+			}
+
+			if printUsage != tc.shouldPrintUsage {
+				t.Fatalf("expected printUsage to be %t got %t", tc.shouldPrintUsage, printUsage)
+			}
+		})
+	}
+
+	// Test program with an error on Before.
+	p.Before = errFunction
+	for _, tc := range testCases {
+		t.Run(fmt.Sprintf("with error on Before -> %s", tc.description), func(t *testing.T) {
+			printUsage, err := p.run(ctx, tc.args)
+			// When we print the usage for nil and empty, we never hit
+			// the After function.
+			if !tc.shouldPrintUsage {
+				compareErrors(t, err, errExpected)
+			} else {
+				compareErrors(t, err, tc.expectedErr)
+			}
+
+			if printUsage != tc.shouldPrintUsage {
+				t.Fatalf("expected printUsage to be %t got %t", tc.shouldPrintUsage, printUsage)
+			}
+		})
+	}
+}
+
+func TestProgramWithCommands(t *testing.T) {
+	p := NewProgram()
+	p.Commands = []Command{
+		&errorCommand{},
+		&testCommand{},
+	}
+	testCases := []struct {
+		description      string
+		args             []string
+		shouldPrintUsage bool
+		expectedErr      error
+	}{
+		{
+			description:      "nil",
+			shouldPrintUsage: true,
+		},
+		{
+			description:      "empty",
+			args:             []string{},
+			shouldPrintUsage: true,
+		},
+		{
+			description:      "args: foo",
+			args:             []string{"foo"},
+			shouldPrintUsage: true,
+		},
+		{
+			description:      "args: foo bar",
+			args:             []string{"foo", "bar"},
+			shouldPrintUsage: true,
+			expectedErr:      errors.New("bar: no such command"),
+		},
+		{
+			description:      "args: foo test",
+			args:             []string{"foo", "test"},
+			shouldPrintUsage: false,
+		},
+		{
+			description:      "args: foo test foo",
+			args:             []string{"foo", "test", "foo"},
+			shouldPrintUsage: false,
+		},
+		{
+			description:      "args: foo test foo bar",
+			args:             []string{"foo", "test", "foo", "bar"},
+			shouldPrintUsage: false,
+		},
+		{
+			description:      "args: foo error",
+			args:             []string{"foo", "error"},
+			shouldPrintUsage: false,
+			expectedErr:      errExpectedFromCommand,
+		},
+		{
+			description:      "args: foo error foo",
+			args:             []string{"foo", "error", "foo"},
+			shouldPrintUsage: false,
+			expectedErr:      errExpectedFromCommand,
+		},
+		{
+			description:      "args: foo error foo bar",
+			args:             []string{"foo", "error", "foo", "bar"},
+			shouldPrintUsage: false,
+			expectedErr:      errExpectedFromCommand,
+		},
+		{
+			description:      "args: foo version",
+			args:             []string{"foo", "version"},
+			shouldPrintUsage: false,
+		},
+		{
+			description:      "args: foo version foo",
+			args:             []string{"foo", "version", "foo"},
+			shouldPrintUsage: false,
+		},
+		{
+			description:      "args: foo version foo bar",
+			args:             []string{"foo", "version", "foo", "bar"},
+			shouldPrintUsage: false,
+		},
+	}
+
+	// Create the context with the values we need to pass to the version command.
+	ctx := context.WithValue(context.Background(), GitCommitKey, p.GitCommit)
+	ctx = context.WithValue(ctx, VersionKey, p.Version)
+
+	for _, tc := range testCases {
+		t.Run(tc.description, func(t *testing.T) {
+			printUsage, err := p.run(ctx, tc.args)
+			compareErrors(t, err, tc.expectedErr)
+
+			if printUsage != tc.shouldPrintUsage {
+				t.Fatalf("expected printUsage to be %t got %t", tc.shouldPrintUsage, printUsage)
+			}
+		})
+	}
+
+	// Add a Before.
+	p.Before = nilFunction
+	for _, tc := range testCases {
+		t.Run(fmt.Sprintf("with Successful Before -> %s", tc.description), func(t *testing.T) {
+			printUsage, err := p.run(ctx, tc.args)
+			compareErrors(t, err, tc.expectedErr)
+
+			if printUsage != tc.shouldPrintUsage {
+				t.Fatalf("expected printUsage to be %t got %t", tc.shouldPrintUsage, printUsage)
+			}
+		})
+	}
+
+	// Add an After.
+	p.After = nilFunction
+	for _, tc := range testCases {
+		t.Run(fmt.Sprintf("with successful After -> %s", tc.description), func(t *testing.T) {
+			printUsage, err := p.run(ctx, tc.args)
+			compareErrors(t, err, tc.expectedErr)
+
+			if printUsage != tc.shouldPrintUsage {
+				t.Fatalf("expected printUsage to be %t got %t", tc.shouldPrintUsage, printUsage)
+			}
+		})
+	}
+
+	// Test program with an error on After.
+	p.After = errFunction
+	for _, tc := range testCases {
+		t.Run(fmt.Sprintf("with error on After -> %s", tc.description), func(t *testing.T) {
+			printUsage, err := p.run(ctx, tc.args)
+			// When we print the usage for nil and empty, we never hit
+			// the After function.
+			if !tc.shouldPrintUsage {
+				// If we are at the point where the command should fail, we should
+				// expect that error.
+				if tc.expectedErr == errExpectedFromCommand {
+					compareErrors(t, err, errExpectedFromCommand)
+				} else {
+					compareErrors(t, err, errExpected)
+				}
+			} else {
+				compareErrors(t, err, tc.expectedErr)
+			}
+
+			if printUsage != tc.shouldPrintUsage {
+				t.Fatalf("expected printUsage to be %t got %t", tc.shouldPrintUsage, printUsage)
+			}
+		})
+	}
+
+	// Test program with an error on Before.
+	p.Before = errFunction
+	for _, tc := range testCases {
+		t.Run(fmt.Sprintf("with error on Before -> %s", tc.description), func(t *testing.T) {
+			printUsage, err := p.run(context.Background(), tc.args)
+			// When we print the usage for nil and empty, we never hit
+			// the After function.
+			if !tc.shouldPrintUsage {
+				compareErrors(t, err, errExpected)
+			} else {
+				compareErrors(t, err, tc.expectedErr)
+			}
+
+			if printUsage != tc.shouldPrintUsage {
+				t.Fatalf("expected printUsage to be %t got %t", tc.shouldPrintUsage, printUsage)
+			}
+		})
+	}
+}
+
+func compareErrors(t *testing.T, err, expectedErr error) {
+	if expectedErr != nil {
+		if err == nil || err.Error() != expectedErr.Error() {
+			t.Fatalf("expected error %#v got: %#v", expectedErr, err)
+		}
+
+		return
+	}
+
+	if err != expectedErr {
+		t.Fatalf("expected error %#v got: %#v", expectedErr, err)
+	}
+
+	return
+}