pass command through

Signed-off-by: Jess Frazelle <acidburn@microsoft.com>
diff --git a/cli/cli.go b/cli/cli.go
index 82e022d..ae218c1 100644
--- a/cli/cli.go
+++ b/cli/cli.go
@@ -41,11 +41,11 @@
 	// Before defines a function to execute before any subcommands are run,
 	// but after the context is ready.
 	// If a non-nil error is returned, no subcommands are run.
-	Before func(context.Context) error
+	Before func(context.Context, Command) error
 	// After defines a function to execute after any commands or action is run
 	// and has finished.
 	// It is run _only_ if the subcommand exits without an error.
-	After func(context.Context) error
+	After func(context.Context, Command) error
 
 	// Action is the function to execute when no subcommands are specified.
 	// It gives the user back the arguments after the flags have been parsed.
@@ -168,7 +168,7 @@
 		// Run the main action _if_ we are not in the loop for the version command
 		// that is added by default.
 		if p.Before != nil {
-			if err := p.Before(ctx); err != nil {
+			if err := p.Before(ctx, nil); err != nil {
 				return err
 			}
 		}
@@ -201,7 +201,7 @@
 		// Only execute the Before function for user-supplied commands.
 		// This excludes the version command we supply.
 		if p.Before != nil && command.Name() != "version" {
-			if err := p.Before(ctx); err != nil {
+			if err := p.Before(ctx, command); err != nil {
 				return err
 			}
 		}
@@ -214,7 +214,7 @@
 
 	// Run the after function.
 	if p.After != nil {
-		if err := p.After(ctx); err != nil {
+		if err := p.After(ctx, command); err != nil {
 			return err
 		}
 	}
diff --git a/cli/cli_test.go b/cli/cli_test.go
index 9f58868..f5dab56 100644
--- a/cli/cli_test.go
+++ b/cli/cli_test.go
@@ -32,7 +32,7 @@
 )
 
 var (
-	nilFunction = func(ctx context.Context) error {
+	nilFunction = func(ctx context.Context, cmd Command) error {
 		return nil
 	}
 	nilActionFunction = func(ctx context.Context, args []string) error {
@@ -41,7 +41,7 @@
 
 	errExpected            = errors.New("expected error")
 	errExpectedFromCommand = errors.New("expected error command error")
-	errFunction            = func(ctx context.Context) error {
+	errFunction            = func(ctx context.Context, cmd Command) error {
 		return errExpected
 	}
 
@@ -396,11 +396,11 @@
 }
 
 func (p *Program) isErrorOnBefore() bool {
-	return p.Before != nil && p.Before(context.Background()) != nil
+	return p.Before != nil && p.Before(context.Background(), nil) != nil
 }
 
 func (p *Program) isErrorOnAfter() bool {
-	return p.After != nil && p.After(context.Background()) != nil
+	return p.After != nil && p.After(context.Background(), nil) != nil
 }
 
 func (tc *testCase) expectUsageToBePrintedBeforeBefore(p *Program) bool {
diff --git a/cli/example_action_test.go b/cli/example_action_test.go
index b800b5a..2112a03 100644
--- a/cli/example_action_test.go
+++ b/cli/example_action_test.go
@@ -49,7 +49,7 @@
 	p.FlagSet.BoolVar(&debug, "d", false, "enable debug logging")
 
 	// Set the before function.
-	p.Before = func(ctx context.Context) error {
+	p.Before = func(ctx context.Context, cmd cli.Command) error {
 		// Set the log level.
 		if debug {
 			// Setup your logger here...
diff --git a/cli/example_command_test.go b/cli/example_command_test.go
index 5dfccdb..898cf62 100644
--- a/cli/example_command_test.go
+++ b/cli/example_command_test.go
@@ -44,7 +44,7 @@
 	p.FlagSet.BoolVar(&debug, "d", false, "enable debug logging")
 
 	// Set the before function.
-	p.Before = func(ctx context.Context) error {
+	p.Before = func(ctx context.Context, cmd cli.Command) error {
 		// Set the log level.
 		if debug {
 			// Setup your logger here...