Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
486 changes: 486 additions & 0 deletions mage/args_test.go

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions mage/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ Options:
_fmt.Println({{printf "%q" .Comment}})
_fmt.Println()
{{end}}
_fmt.Print("Usage:\n\n\t{{$.BinaryName}} {{lower .TargetName}}{{range .Args}} <{{.Name}}>{{end}}\n\n")
_fmt.Print("Usage:\n\n\t{{$.BinaryName}} {{lower .TargetName}}{{range .RequiredArgs}} <{{.Name}}>{{end}}{{range .OptionalArgs}} [-{{.Name}}=<{{.Type}}>]{{end}}\n\n")
var aliases []string
{{- $name := .Name -}}
{{- $recv := .Receiver -}}
Expand All @@ -391,7 +391,7 @@ Options:
_fmt.Println({{printf "%q" .Comment}})
_fmt.Println()
{{end}}
_fmt.Print("Usage:\n\n\t{{$.BinaryName}} {{lower .TargetName}}{{range .Args}} <{{.Name}}>{{end}}\n\n")
_fmt.Print("Usage:\n\n\t{{$.BinaryName}} {{lower .TargetName}}{{range .RequiredArgs}} <{{.Name}}>{{end}}{{range .OptionalArgs}} [-{{.Name}}=<{{.Type}}>]{{end}}\n\n")
var aliases []string
{{- $name := .Name -}}
{{- $recv := .Receiver -}}
Expand Down Expand Up @@ -445,7 +445,7 @@ Options:
switch _strings.ToLower(target) {
{{range .Funcs }}
case "{{lower .TargetName}}":
expected := x + {{len .Args}}
expected := x + {{.NumRequiredArgs}}
if expected > len(args.Args) {
// note that expected and args at this point include the arg for the target itself
// so we subtract 1 here to show the number of args without the target.
Expand All @@ -462,7 +462,7 @@ Options:
{{$imp := .}}
{{range .Info.Funcs }}
case "{{lower .TargetName}}":
expected := x + {{len .Args}}
expected := x + {{.NumRequiredArgs}}
if expected > len(args.Args) {
// note that expected and args at this point include the arg for the target itself
// so we subtract 1 here to show the number of args without the target.
Expand Down
101 changes: 101 additions & 0 deletions mage/testdata/optargs/magefile.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
//go:build mage
// +build mage

package main

import (
"context"
"fmt"
"strings"
"time"
)

// Greet greets someone with an optional greeting.
func Greet(name string, greeting *string) {
if greeting != nil {
fmt.Printf("%s, %s!\n", *greeting, name)
} else {
fmt.Printf("Hello, %s!\n", name)
}
}

// Add adds two numbers. The second number is optional and defaults to 0.
func Add(a int, b *int) {
if b != nil {
fmt.Println(a + *b)
} else {
fmt.Println(a)
}
}

// Scale scales a value by an optional factor.
func Scale(value float64, factor *float64) {
if factor != nil {
fmt.Printf("%.1f\n", value*(*factor))
} else {
fmt.Printf("%.1f\n", value)
}
}

// Run runs with an optional verbose flag.
func Run(verbose *bool) {
if verbose != nil && *verbose {
fmt.Println("running verbose")
} else {
fmt.Println("running quiet")
}
}

// Delay delays with an optional extra duration.
func Delay(base time.Duration, extra *time.Duration) {
if extra != nil {
fmt.Printf("delay %s + %s\n", base, *extra)
} else {
fmt.Printf("delay %s\n", base)
}
}

// AllOptional takes only optional args.
func AllOptional(a *string, b *int) {
if a != nil {
fmt.Printf("a=%s\n", *a)
} else {
fmt.Println("a=<nil>")
}
if b != nil {
fmt.Printf("b=%d\n", *b)
} else {
fmt.Println("b=<nil>")
}
}

// Say says the message with optional capitalization and repeat count.
func Say(msg string, cap *bool, count *int) {
if cap != nil && *cap {
msg = strings.ToUpper(msg)
}
repeat := 1
if count != nil {
repeat = *count
}
for i := 0; i < repeat; i++ {
fmt.Println(msg)
}
}

// Announce prints an announcement.
func Announce(msg string) {
fmt.Printf("Announcement: %s\n", msg)
}

// Mixed tests interleaved required and optional args with context.
func Mixed(ctx context.Context, name string, greeting *string, count int) error {
g := "Hello"
if greeting != nil {
g = *greeting
}
for i := 0; i < count; i++ {
fmt.Printf("%s, %s!\n", g, name)
}
return nil
}
10 changes: 0 additions & 10 deletions magefile.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,6 @@ import (
"github.com/magefile/mage/sh"
)

var Aliases = map[string]interface{}{
"Speak": Say,
}

// Say says something.
func Say(msg string, i int, b bool, d time.Duration) error {
_, err := fmt.Printf("%v(%T) %v(%T) %v(%T) %v(%T)\n", msg, msg, i, i, b, b, d, d)
return err
}

// Runs "go install" for mage. This generates the version info the binary.
func Install() error {
name := "mage"
Expand Down
164 changes: 162 additions & 2 deletions parse/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ func (s Functions) Swap(i, j int) {
// Arg is an argument to a Function.
type Arg struct {
Name, Type string
Optional bool
}

// ID returns user-readable information about where this function is defined.
Expand Down Expand Up @@ -101,6 +102,49 @@ func (f Function) TargetName() string {
return strings.Join(names, ":")
}

// NumRequiredArgs returns the number of non-optional arguments.
func (f Function) NumRequiredArgs() int {
n := 0
for _, a := range f.Args {
if !a.Optional {
n++
}
}
return n
}

// RequiredArgs returns only the non-optional arguments.
func (f Function) RequiredArgs() []Arg {
var out []Arg
for _, a := range f.Args {
if !a.Optional {
out = append(out, a)
}
}
return out
}

// OptionalArgs returns only the optional arguments.
func (f Function) OptionalArgs() []Arg {
var out []Arg
for _, a := range f.Args {
if a.Optional {
out = append(out, a)
}
}
return out
}

// HasOptionalArgs reports whether the function has any optional arguments.
func (f Function) HasOptionalArgs() bool {
for _, a := range f.Args {
if a.Optional {
return true
}
}
return false
}

// ExecCode returns code for the template switch to run the target.
// It wraps each target call to match the func(context.Context) error that
// runTarget requires.
Expand All @@ -114,7 +158,12 @@ func (f Function) ExecCode() string {
}

var parseargs string

// Phase 1: Parse positional (required) arguments
for x, arg := range f.Args {
if arg.Optional {
continue
}
switch arg.Type {
case "string":
parseargs += fmt.Sprintf(`
Expand Down Expand Up @@ -155,6 +204,107 @@ func (f Function) ExecCode() string {
}
}

// Phase 2: Declare optional argument variables (nil by default)
for x, arg := range f.Args {
if !arg.Optional {
continue
}
parseargs += fmt.Sprintf(`
var arg%d *%s`, x, arg.Type)
}

// Phase 3: Parse optional arguments from -name=value flags
if f.HasOptionalArgs() {
// Collect lowercase names of bool optional args for bare-flag support
var boolOptNames []string
for _, arg := range f.Args {
if arg.Optional && arg.Type == "bool" {
boolOptNames = append(boolOptNames, strings.ToLower(arg.Name))
}
}

parseargs += fmt.Sprintf(`
for x < len(args.Args) && _strings.HasPrefix(args.Args[x], "-") {
_optArg := args.Args[x]
_eqIdx := _strings.Index(_optArg, "=")
var _optName, _optVal string
if _eqIdx < 0 {
_optName = _strings.ToLower(_optArg[1:])
switch _optName {`)
// Generate cases for each bool optional arg
for _, bname := range boolOptNames {
parseargs += fmt.Sprintf(`
case %q:
_optVal = "true"`, bname)
}
parseargs += fmt.Sprintf(`
default:
logger.Printf("invalid option %%q for target \"%s\", expected -name=value format\n", _optArg)
os.Exit(2)
}
} else {
_optName = _strings.ToLower(_optArg[1:_eqIdx])
_optVal = _optArg[_eqIdx+1:]
}
switch _optName {`, f.TargetName())
for x, arg := range f.Args {
if !arg.Optional {
continue
}
lowerName := strings.ToLower(arg.Name)
switch arg.Type {
case "string":
parseargs += fmt.Sprintf(`
case %q:
_tmp%d := _optVal
arg%d = &_tmp%d`, lowerName, x, x, x)
case "int":
parseargs += fmt.Sprintf(`
case %q:
_tmp%d, err := strconv.Atoi(_optVal)
if err != nil {
logger.Printf("can't convert option %%q value %%q to int\n", _optName, _optVal)
os.Exit(2)
}
arg%d = &_tmp%d`, lowerName, x, x, x)
case "float64":
parseargs += fmt.Sprintf(`
case %q:
_tmp%d, err := strconv.ParseFloat(_optVal, 64)
if err != nil {
logger.Printf("can't convert option %%q value %%q to float64\n", _optName, _optVal)
os.Exit(2)
}
arg%d = &_tmp%d`, lowerName, x, x, x)
case "bool":
parseargs += fmt.Sprintf(`
case %q:
_tmp%d, err := strconv.ParseBool(_optVal)
if err != nil {
logger.Printf("can't convert option %%q value %%q to bool\n", _optName, _optVal)
os.Exit(2)
}
arg%d = &_tmp%d`, lowerName, x, x, x)
case "time.Duration":
parseargs += fmt.Sprintf(`
case %q:
_tmp%d, err := time.ParseDuration(_optVal)
if err != nil {
logger.Printf("can't convert option %%q value %%q to time.Duration\n", _optName, _optVal)
os.Exit(2)
}
arg%d = &_tmp%d`, lowerName, x, x, x)
}
}
parseargs += fmt.Sprintf(`
default:
logger.Printf("unknown option %%q for target \"%s\"\n", _optName)
os.Exit(2)
}
x++
}`, f.TargetName())
}

out := parseargs + `
wrapFn := func(ctx context.Context) error {
`
Expand Down Expand Up @@ -833,14 +983,24 @@ func funcType(ft *ast.FuncType) (*Function, error) {
}
for ; x < len(ft.Params.List); x++ {
param := ft.Params.List[x]
t := fmt.Sprint(param.Type)
optional := false
paramType := param.Type
// Check for pointer types (optional arguments)
if star, ok := param.Type.(*ast.StarExpr); ok {
optional = true
paramType = star.X
}
t := fmt.Sprint(paramType)
typ, ok := argTypes[t]
if !ok {
if optional {
return nil, fmt.Errorf("unsupported argument type: *%s", t)
}
return nil, fmt.Errorf("unsupported argument type: %s", t)
}
// support for foo, bar string
for _, name := range param.Names {
f.Args = append(f.Args, Arg{Name: name.Name, Type: typ})
f.Args = append(f.Args, Arg{Name: name.Name, Type: typ, Optional: optional})
}
}
return f, nil
Expand Down
Loading