diff --git a/internal/cli/app.go b/internal/cli/app.go index 6e26370..d0cdd18 100644 --- a/internal/cli/app.go +++ b/internal/cli/app.go @@ -1,60 +1,43 @@ package cli import ( - "crypto/tls" "fmt" "io" - "net" "net/url" "os" - "strconv" "strings" - "time" "github.com/ryanfowler/fetch/internal/aws" + "github.com/ryanfowler/fetch/internal/config" "github.com/ryanfowler/fetch/internal/core" - "github.com/ryanfowler/fetch/internal/printer" ) // App represents the full configuration for a fetch invocation. type App struct { URL *url.URL - AWSSigv4 *aws.Config - Basic *core.KeyVal - Bearer string - BuildInfo bool - Color core.Color - Data io.Reader - DNSServer *url.URL - DryRun bool - Edit bool - Form []core.KeyVal - Format core.Format - Headers []core.KeyVal - Help bool - HTTP core.HTTPVersion - IgnoreStatus bool - Insecure bool - JSON bool - Method string - Multipart []core.KeyVal - NoEncode bool - NoPager bool - Output string - Proxy *url.URL - QueryParams []core.KeyVal - Redirects *int - Silent bool - Timeout time.Duration - TLS uint16 - Update bool - Verbose int - Version bool - XML bool + Cfg config.Config + + AWSSigv4 *aws.Config + Basic *core.KeyVal + Bearer string + BuildInfo bool + ConfigPath string + Data io.Reader + DryRun bool + Edit bool + Form []core.KeyVal + Help bool + JSON bool + Method string + Multipart []core.KeyVal + Output string + Update bool + Version bool + XML bool } -func (a *App) PrintHelp(p *printer.Printer) { +func (a *App) PrintHelp(p *core.Printer) { printHelp(a.CLI(), p) } @@ -87,15 +70,6 @@ func (a *App) CLI() *CLI { return fmt.Errorf("unsupported url scheme: %s", u.Scheme) } - if u.Scheme == "" { - host := u.Hostname() - if !strings.Contains(host, ".") || net.ParseIP(host) != nil { - u.Scheme = "http" - } else { - u.Scheme = "https" - } - } - a.URL = u return nil }, @@ -118,7 +92,7 @@ func (a *App) CLI() *CLI { region, service, ok := cut(value, "/") if !ok { const usage = "format must be " - return flagValueError("aws-sigv4", value, usage) + return core.NewValueError("aws-sigv4", value, usage, false) } accessKey := os.Getenv("AWS_ACCESS_KEY_ID") @@ -152,7 +126,7 @@ func (a *App) CLI() *CLI { user, pass, ok := cut(value, ":") if !ok { const usage = "format must be " - return flagValueError("basic", value, usage) + return core.NewValueError("basic", value, usage, false) } a.Basic = &core.KeyVal{Key: user, Val: pass} return nil @@ -195,20 +169,23 @@ func (a *App) CLI() *CLI { Default: "", Values: []string{"auto", "off", "on"}, IsSet: func() bool { - return a.Color != core.ColorUnknown + return a.Cfg.Color != core.ColorUnknown }, Fn: func(value string) error { - switch value { - case "auto": - a.Color = core.ColorAuto - case "off": - a.Color = core.ColorOff - case "on": - a.Color = core.ColorOn - default: - const usage = "must be one of [auto, off, on]" - return flagValueError("color", value, usage) - } + return a.Cfg.ParseColor(value) + }, + }, + { + Short: "c", + Long: "config", + Args: "PATH", + Description: "Path to config file", + Default: "", + IsSet: func() bool { + return a.ConfigPath != "" + }, + Fn: func(value string) error { + a.ConfigPath = value return nil }, }, @@ -254,35 +231,10 @@ func (a *App) CLI() *CLI { Description: "DNS server IP or DoH URL", Default: "", IsSet: func() bool { - return a.DNSServer != nil + return a.Cfg.DNSServer != nil }, Fn: func(value string) error { - if strings.HasPrefix(value, "https://") || strings.HasPrefix(value, "http://") { - u, err := url.Parse(value) - if err != nil { - return flagValueError("dns-server", value, "unable to parse DoH URL") - } - a.DNSServer = u - return nil - } - - port := "53" - host := value - const usage = "must be in the format " - if colons := strings.Count(value, ":"); colons == 1 || (colons > 1 && strings.HasPrefix(value, "[")) { - var err error - host, port, err = net.SplitHostPort(value) - if err != nil { - return flagValueError("dns-server", value, usage) - } - } - if net.ParseIP(host) == nil { - return flagValueError("dns-server", value, usage) - } - - u := url.URL{Host: host + ":" + port} - a.DNSServer = &u - return nil + return a.Cfg.ParseDNSServer(value) }, }, { @@ -336,21 +288,10 @@ func (a *App) CLI() *CLI { Default: "", Values: []string{"auto", "off", "on"}, IsSet: func() bool { - return a.Format != core.FormatUnknown + return a.Cfg.Format != core.FormatUnknown }, Fn: func(value string) error { - switch value { - case "auto": - a.Format = core.FormatAuto - case "off": - a.Format = core.FormatOff - case "on": - a.Format = core.FormatOn - default: - const usage = "must be one of [auto, off, on]" - return flagValueError("format", value, usage) - } - return nil + return a.Cfg.ParseFormat(value) }, }, { @@ -360,12 +301,10 @@ func (a *App) CLI() *CLI { Description: "Set headers for the request", Default: "", IsSet: func() bool { - return len(a.Headers) > 0 + return len(a.Cfg.Headers) > 0 }, Fn: func(value string) error { - key, val, _ := cut(value, ":") - a.Headers = append(a.Headers, core.KeyVal{Key: key, Val: val}) - return nil + return a.Cfg.ParseHeader(value) }, }, { @@ -390,19 +329,10 @@ func (a *App) CLI() *CLI { Default: "", Values: []string{"1", "2"}, IsSet: func() bool { - return a.HTTP != core.HTTPDefault + return a.Cfg.HTTP != core.HTTPDefault }, Fn: func(value string) error { - switch value { - case "1": - a.HTTP = core.HTTP1 - case "2": - a.HTTP = core.HTTP2 - default: - const usage = "must be one of [1, 2]" - return flagValueError("http", value, usage) - } - return nil + return a.Cfg.ParseHTTP(value) }, }, { @@ -412,10 +342,11 @@ func (a *App) CLI() *CLI { Description: "Exit code unaffected by HTTP status", Default: "", IsSet: func() bool { - return a.IgnoreStatus + return a.Cfg.IgnoreStatus != nil }, Fn: func(value string) error { - a.IgnoreStatus = true + v := true + a.Cfg.IgnoreStatus = &v return nil }, }, @@ -426,10 +357,11 @@ func (a *App) CLI() *CLI { Description: "Accept invalid TLS certificates - DANGER!", Default: "", IsSet: func() bool { - return a.Insecure + return a.Cfg.Insecure != nil }, Fn: func(value string) error { - a.Insecure = true + v := true + a.Cfg.Insecure = &v return nil }, }, @@ -496,10 +428,11 @@ func (a *App) CLI() *CLI { Description: "Avoid requesting gzip encoding", Default: "", IsSet: func() bool { - return a.NoEncode + return a.Cfg.NoEncode != nil }, Fn: func(value string) error { - a.NoEncode = true + v := true + a.Cfg.NoEncode = &v return nil }, }, @@ -510,10 +443,11 @@ func (a *App) CLI() *CLI { Description: "Avoid using a pager for the response body", Default: "", IsSet: func() bool { - return a.NoPager + return a.Cfg.NoPager != nil }, Fn: func(value string) error { - a.NoPager = true + v := true + a.Cfg.NoPager = &v return nil }, }, @@ -538,15 +472,10 @@ func (a *App) CLI() *CLI { Description: "Configure a proxy", Default: "", IsSet: func() bool { - return a.Proxy != nil + return a.Cfg.Proxy != nil }, Fn: func(value string) error { - proxy, err := url.Parse(value) - if err != nil { - return flagValueError("proxy", value, err.Error()) - } - a.Proxy = proxy - return nil + return a.Cfg.ParseProxy(value) }, }, { @@ -556,12 +485,10 @@ func (a *App) CLI() *CLI { Description: "Append query parameters to the url", Default: "", IsSet: func() bool { - return len(a.QueryParams) > 0 + return len(a.Cfg.QueryParams) > 0 }, Fn: func(value string) error { - key, val, _ := cut(value, "=") - a.QueryParams = append(a.QueryParams, core.KeyVal{Key: key, Val: val}) - return nil + return a.Cfg.ParseQuery(value) }, }, { @@ -571,16 +498,10 @@ func (a *App) CLI() *CLI { Description: "Maximum number of redirects", Default: "", IsSet: func() bool { - return a.Redirects != nil + return a.Cfg.Redirects != nil }, Fn: func(value string) error { - n, err := strconv.Atoi(value) - if err != nil || n < 0 { - const usage = "must be a positive integer" - return flagValueError("redirects", value, usage) - } - a.Redirects = &n - return nil + return a.Cfg.ParseRedirects(value) }, }, { @@ -590,10 +511,11 @@ func (a *App) CLI() *CLI { Description: "Print only errors to stderr", Default: "", IsSet: func() bool { - return a.Silent + return a.Cfg.Silent != nil }, Fn: func(value string) error { - a.Silent = true + v := true + a.Cfg.Silent = &v return nil }, }, @@ -604,16 +526,10 @@ func (a *App) CLI() *CLI { Description: "Timeout in seconds applied to the request", Default: "", IsSet: func() bool { - return a.Timeout != 0 + return a.Cfg.Timeout != nil }, Fn: func(value string) error { - secs, err := strconv.ParseFloat(value, 64) - if err != nil { - return flagValueError("timeout", value, "must be a valid number") - } - - a.Timeout = time.Duration(float64(time.Second) * secs) - return nil + return a.Cfg.ParseTimeout(value) }, }, { @@ -624,23 +540,10 @@ func (a *App) CLI() *CLI { Default: "", Values: []string{"1.0", "1.1", "1.2", "1.3"}, IsSet: func() bool { - return a.TLS != 0 + return a.Cfg.TLS != nil }, Fn: func(value string) error { - switch value { - case "1.0": - a.TLS = tls.VersionTLS10 - case "1.1": - a.TLS = tls.VersionTLS11 - case "1.2": - a.TLS = tls.VersionTLS12 - case "1.3": - a.TLS = tls.VersionTLS13 - default: - const usage = "must be one of [1.0, 1.1, 1.2, 1.3]" - return flagValueError("tls", value, usage) - } - return nil + return a.Cfg.ParseTLS(value) }, }, { @@ -664,10 +567,14 @@ func (a *App) CLI() *CLI { Description: "Verbosity of the output", Default: "", IsSet: func() bool { - return a.Verbose > 0 + return a.Cfg.Verbosity != nil }, Fn: func(value string) error { - a.Verbose += 1 + if a.Cfg.Verbosity == nil { + a.Cfg.Verbosity = core.PointerTo(1) + } else { + (*a.Cfg.Verbosity)++ + } return nil }, }, @@ -710,48 +617,6 @@ func cut(s, sep string) (string, string, bool) { return key, val, ok } -type FlagValueError struct { - Flag string - Value string - Usage string -} - -func flagValueError(flag, value, usage string) *FlagValueError { - return &FlagValueError{ - Flag: flag, - Value: value, - Usage: usage, - } -} - -func (err *FlagValueError) Error() string { - msg := fmt.Sprintf("invalid value '%s' for option '--%s'", err.Flag, err.Value) - if err.Usage == "" { - msg = fmt.Sprintf("%s: %s", msg, err.Usage) - } - return msg -} - -func (err *FlagValueError) PrintTo(p *printer.Printer) { - p.WriteString("invalid value '") - p.Set(printer.Yellow) - p.WriteString(err.Value) - p.Reset() - - p.WriteString("' for option '") - p.Set(printer.Bold) - p.WriteString("--") - p.WriteString(err.Flag) - p.Reset() - - p.WriteString("'") - - if err.Usage != "" { - p.WriteString(": ") - p.WriteString(err.Usage) - } -} - type MissingEnvVarError struct { EnvVar string Flag string @@ -768,14 +633,14 @@ func (err *MissingEnvVarError) Error() string { return fmt.Sprintf("missing environment variable '%s' required for option '--%s'", err.EnvVar, err.Flag) } -func (err *MissingEnvVarError) PrintTo(p *printer.Printer) { +func (err *MissingEnvVarError) PrintTo(p *core.Printer) { p.WriteString("missing environment variable '") - p.Set(printer.Yellow) + p.Set(core.Yellow) p.WriteString(err.EnvVar) p.Reset() p.WriteString("' required for option '") - p.Set(printer.Bold) + p.Set(core.Bold) p.WriteString("--") p.WriteString(err.Flag) p.Reset() diff --git a/internal/cli/cli.go b/internal/cli/cli.go index fc9fe52..40b1a12 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -4,7 +4,7 @@ import ( "fmt" "strings" - "github.com/ryanfowler/fetch/internal/printer" + "github.com/ryanfowler/fetch/internal/core" ) type CLI struct { @@ -219,17 +219,17 @@ func Parse(args []string) (*App, error) { return &app, nil } -func printHelp(cli *CLI, p *printer.Printer) { +func printHelp(cli *CLI, p *core.Printer) { p.WriteString(cli.Description) p.WriteString("\n\n") - p.Set(printer.Bold) - p.Set(printer.Underline) + p.Set(core.Bold) + p.Set(core.Underline) p.WriteString("Usage") p.Reset() p.WriteString(": ") - p.Set(printer.Bold) + p.Set(core.Bold) p.WriteString("fetch") p.Reset() @@ -247,8 +247,8 @@ func printHelp(cli *CLI, p *printer.Printer) { if len(cli.Args) > 0 { p.WriteString("\n") - p.Set(printer.Bold) - p.Set(printer.Underline) + p.Set(core.Bold) + p.Set(core.Underline) p.WriteString("Arguments") p.Reset() p.WriteString(":\n") @@ -265,8 +265,8 @@ func printHelp(cli *CLI, p *printer.Printer) { if len(cli.Flags) > 0 { p.WriteString("\n") - p.Set(printer.Bold) - p.Set(printer.Underline) + p.Set(core.Bold) + p.Set(core.Underline) p.WriteString("Options") p.Reset() p.WriteString(":\n") @@ -277,7 +277,7 @@ func printHelp(cli *CLI, p *printer.Printer) { continue } - p.Set(printer.Bold) + p.Set(core.Bold) p.WriteString(" ") if flag.Short == "" { diff --git a/internal/cli/cli_test.go b/internal/cli/cli_test.go index 71bd708..ede3c78 100644 --- a/internal/cli/cli_test.go +++ b/internal/cli/cli_test.go @@ -6,7 +6,6 @@ import ( "unicode/utf8" "github.com/ryanfowler/fetch/internal/core" - "github.com/ryanfowler/fetch/internal/printer" ) func TestCLI(t *testing.T) { @@ -14,7 +13,7 @@ func TestCLI(t *testing.T) { if err != nil { t.Fatalf("unable to parse cli: %s", err.Error()) } - p := printer.NewHandle(core.ColorOff).Stdout() + p := core.NewHandle(core.ColorOff).Stdout() // Verify that no line of the help command is over 80 characters. app.PrintHelp(p) diff --git a/internal/client/client.go b/internal/client/client.go index f7cebb4..dec53ae 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -6,6 +6,7 @@ import ( "crypto/tls" "fmt" "io" + "net" "net/http" "net/url" "strings" @@ -135,10 +136,23 @@ func (c *Client) NewRequest(ctx context.Context, cfg RequestConfig) (*http.Reque cfg.Body = cfg.Multipart } - // Create the initial HTTP request. + // If no scheme was provided, use various heuristics to choose between + // http and https. + if cfg.URL.Scheme == "" { + host := cfg.URL.Hostname() + if !strings.Contains(host, ".") || net.ParseIP(host) != nil { + cfg.URL.Scheme = "http" + } else { + cfg.URL.Scheme = "https" + } + } + + // If no method was provided, default to GET. if cfg.Method == "" { cfg.Method = "GET" } + + // Create the initial HTTP request. req, err := http.NewRequestWithContext(ctx, cfg.Method, cfg.URL.String(), cfg.Body) if err != nil { return nil, err diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..9617f04 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,322 @@ +package config + +import ( + "crypto/tls" + "fmt" + "net" + "net/url" + "strconv" + "strings" + "time" + + "github.com/ryanfowler/fetch/internal/core" +) + +// Config represents the configuration options for fetch. +type Config struct { + isFile bool + + Color core.Color + DNSServer *url.URL + Format core.Format + Headers []core.KeyVal + HTTP core.HTTPVersion + IgnoreStatus *bool + Insecure *bool + NoEncode *bool + NoPager *bool + Proxy *url.URL + QueryParams []core.KeyVal + Redirects *int + Silent *bool + Timeout *time.Duration + TLS *uint16 + Verbosity *int +} + +// Merge merges the two Configs together, with "c" taking priority. +func (c *Config) Merge(c2 *Config) { + if c.Color == core.ColorUnknown { + c.Color = c2.Color + } + if c.DNSServer == nil { + c.DNSServer = c2.DNSServer + } + if c.Format == core.FormatUnknown { + c.Format = c2.Format + } + if len(c2.Headers) > 0 { + c.Headers = append(c2.Headers, c.Headers...) + } + if c.HTTP == core.HTTPDefault { + c.HTTP = c2.HTTP + } + if c.IgnoreStatus == nil { + c.IgnoreStatus = c2.IgnoreStatus + } + if c.Insecure == nil { + c.Insecure = c2.Insecure + } + if c.NoEncode == nil { + c.NoEncode = c2.NoEncode + } + if c.NoPager == nil { + c.NoPager = c2.NoPager + } + if c.Proxy == nil { + c.Proxy = c2.Proxy + } + if len(c2.QueryParams) > 0 { + c.QueryParams = append(c2.QueryParams, c.QueryParams...) + } + if c.Redirects == nil { + c.Redirects = c2.Redirects + } + if c.Silent == nil { + c.Silent = c2.Silent + } + if c.Timeout == nil { + c.Timeout = c2.Timeout + } + if c.TLS == nil { + c.TLS = c2.TLS + } + if c.Verbosity == nil { + c.Verbosity = c2.Verbosity + } +} + +// Set sets the provided key and value pair, returning any error encountered. +func (c *Config) Set(key, val string) error { + var err error + switch key { + case "color": + err = c.ParseColor(val) + case "dns-server": + err = c.ParseDNSServer(val) + case "format": + err = c.ParseFormat(val) + case "header": + err = c.ParseHeader(val) + case "http": + err = c.ParseHTTP(val) + case "ignore-status": + err = c.ParseIgnoreStatus(val) + case "insecure": + err = c.ParseInsecure(val) + case "no-encode": + err = c.ParseNoEncode(val) + case "no-pager": + err = c.ParseNoPager(val) + case "proxy": + err = c.ParseProxy(val) + case "query": + err = c.ParseQuery(val) + case "redirects": + err = c.ParseRedirects(val) + case "silent": + err = c.ParseSilent(val) + case "timeout": + err = c.ParseTimeout(val) + case "tls": + err = c.ParseTLS(val) + case "verbosity": + err = c.ParseVerbosity(val) + default: + err = fmt.Errorf("invalid option '%s'", key) + } + return err +} + +func (c *Config) ParseColor(value string) error { + switch value { + case "auto": + c.Color = core.ColorAuto + case "off": + c.Color = core.ColorOff + case "on": + c.Color = core.ColorOn + default: + const usage = "must be one of [auto, off, on]" + return core.NewValueError("color", value, usage, c.isFile) + } + return nil +} + +func (c *Config) ParseDNSServer(value string) error { + if strings.HasPrefix(value, "https://") || strings.HasPrefix(value, "http://") { + u, err := url.Parse(value) + if err != nil { + return core.NewValueError("dns-server", value, "unable to parse DoH URL", c.isFile) + } + c.DNSServer = u + return nil + } + + port := "53" + host := value + const usage = "must be in the format " + if colons := strings.Count(value, ":"); colons == 1 || (colons > 1 && strings.HasPrefix(value, "[")) { + var err error + host, port, err = net.SplitHostPort(value) + if err != nil { + return core.NewValueError("dns-server", value, usage, c.isFile) + } + } + if net.ParseIP(host) == nil { + return core.NewValueError("dns-server", value, usage, c.isFile) + } + + u := url.URL{Host: net.JoinHostPort(host, port)} + c.DNSServer = &u + return nil +} + +func (c *Config) ParseFormat(value string) error { + switch value { + case "auto": + c.Format = core.FormatAuto + case "off": + c.Format = core.FormatOff + case "on": + c.Format = core.FormatOn + default: + const usage = "must be one of [auto, off, on]" + return core.NewValueError("format", value, usage, c.isFile) + } + return nil +} + +func (c *Config) ParseHeader(value string) error { + key, val, _ := cut(value, ":") + c.Headers = append(c.Headers, core.KeyVal{Key: key, Val: val}) + return nil + +} + +func (c *Config) ParseHTTP(value string) error { + switch value { + case "1": + c.HTTP = core.HTTP1 + case "2": + c.HTTP = core.HTTP2 + default: + const usage = "must be one of [1, 2]" + return core.NewValueError("http", value, usage, c.isFile) + } + return nil +} + +func (c *Config) ParseIgnoreStatus(value string) error { + v, err := strconv.ParseBool(value) + if err != nil { + return core.NewValueError("ignore-status", value, "must be a boolean", c.isFile) + } + c.IgnoreStatus = &v + return nil +} + +func (c *Config) ParseInsecure(value string) error { + v, err := strconv.ParseBool(value) + if err != nil { + return core.NewValueError("insecure", value, "must be a boolean", c.isFile) + } + c.Insecure = &v + return nil +} + +func (c *Config) ParseNoEncode(value string) error { + v, err := strconv.ParseBool(value) + if err != nil { + return core.NewValueError("no-encode", value, "must be a boolean", c.isFile) + } + c.NoEncode = &v + return nil +} + +func (c *Config) ParseNoPager(value string) error { + v, err := strconv.ParseBool(value) + if err != nil { + return core.NewValueError("no-pager", value, "must be a boolean", c.isFile) + } + c.NoPager = &v + return nil +} + +func (c *Config) ParseProxy(value string) error { + proxy, err := url.Parse(value) + if err != nil { + return core.NewValueError("proxy", value, err.Error(), c.isFile) + } + c.Proxy = proxy + return nil +} + +func (c *Config) ParseQuery(value string) error { + key, val, _ := cut(value, "=") + c.QueryParams = append(c.QueryParams, core.KeyVal{Key: key, Val: val}) + return nil +} + +func (c *Config) ParseRedirects(value string) error { + n, err := strconv.Atoi(value) + if err != nil || n < 0 { + const usage = "must be a positive integer" + return core.NewValueError("redirects", value, usage, c.isFile) + } + c.Redirects = &n + return nil +} + +func (c *Config) ParseSilent(value string) error { + v, err := strconv.ParseBool(value) + if err != nil { + return core.NewValueError("silent", value, "must be a boolean", c.isFile) + } + c.Silent = &v + return nil +} + +func (c *Config) ParseTimeout(value string) error { + secs, err := strconv.ParseFloat(value, 64) + if err != nil { + return core.NewValueError("timeout", value, "must be a valid number", c.isFile) + } + c.Timeout = core.PointerTo(time.Duration(float64(time.Second) * secs)) + return nil +} + +func (c *Config) ParseTLS(value string) error { + var version uint16 + switch value { + case "1.0": + version = tls.VersionTLS10 + case "1.1": + version = tls.VersionTLS11 + case "1.2": + version = tls.VersionTLS12 + case "1.3": + version = tls.VersionTLS13 + default: + const usage = "must be one of [1.0, 1.1, 1.2, 1.3]" + return core.NewValueError("tls", value, usage, c.isFile) + } + c.TLS = &version + return nil + +} + +func (c *Config) ParseVerbosity(value string) error { + v, err := strconv.Atoi(value) + if err != nil || v < 0 { + return core.NewValueError("verbosity", value, "must be a valid integer", c.isFile) + } + c.Verbosity = &v + return nil +} + +func cut(s, sep string) (string, string, bool) { + key, val, ok := strings.Cut(s, sep) + key, val = strings.TrimSpace(key), strings.TrimSpace(val) + return key, val, ok +} diff --git a/internal/config/file.go b/internal/config/file.go new file mode 100644 index 0000000..664f866 --- /dev/null +++ b/internal/config/file.go @@ -0,0 +1,168 @@ +package config + +import ( + "errors" + "fmt" + "iter" + "os" + "path/filepath" + "runtime" + "strconv" + "strings" + + "github.com/ryanfowler/fetch/internal/core" +) + +// File represents a configuration file. +type File struct { + Global *Config + Hosts map[string]*Config +} + +// GetFile returns a config File, or nil if one cannot be found. +func GetFile(path string) (*File, error) { + buf, err := getConfigFile(path) + if err != nil || buf == nil { + return nil, err + } + return parseFile(string(buf)) +} + +// getConfigFile searches for a local config file, returning the file contents +// if it exists. +func getConfigFile(path string) ([]byte, error) { + if path != "" { + // Direct config path was provided. + return os.ReadFile(path) + } + + if runtime.GOOS == "windows" { + appData := os.Getenv("AppData") + if appData == "" { + return nil, nil + } + d, err := os.ReadFile(filepath.Join(appData, "fetch", "config")) + if err != nil { + return nil, nil + } + return d, nil + } + + xdgHome := os.Getenv("XDG_CONFIG_HOME") + if xdgHome != "" { + f, err := os.ReadFile(xdgHome + "/fetch/config") + if err == nil { + return f, nil + } + } + + home := os.Getenv("HOME") + if home != "" { + f, err := os.ReadFile(home + "/.config/fetch/config") + if err == nil { + return f, nil + } + } + + return nil, nil +} + +// parseFile parses the provided File, returning any error encountered. +func parseFile(s string) (*File, error) { + f := File{Global: &Config{isFile: true}} + + config := f.Global + for num, line := range lines(s) { + line = strings.TrimSpace(line) + + if line == "" || line[0] == '#' { + // Skip empty lines and comments. + continue + } + + // Parse out a hostname. + if strings.HasPrefix(line, "[") && strings.HasSuffix(line, "]") { + hostStr := strings.TrimSpace(line[1 : len(line)-1]) + if hostStr == "" { + return nil, newFileError(num, errors.New("empty hostname")) + } + + config = &Config{isFile: true} + if f.Hosts == nil { + f.Hosts = make(map[string]*Config) + } + f.Hosts[hostStr] = config + continue + } + + // Pares a key and value pair. + key, val, ok := strings.Cut(line, "=") + if !ok { + return nil, newFileError(num, fmt.Errorf("invalid key/value pair: '%s'", line)) + } + key, val = strings.TrimSpace(key), strings.TrimSpace(val) + + err := config.Set(key, val) + if err != nil { + return nil, fileLineError{line: num, err: err} + } + } + + return &f, nil +} + +// lines returns an iterator over lines and line numbers. +func lines(s string) iter.Seq2[int, string] { + return func(yield func(int, string) bool) { + var num int + for len(s) > 0 { + num++ + + i := strings.IndexFunc(s, func(r rune) bool { + return r == '\n' || r == '\r' + }) + if i < 0 { + yield(num, s) + return + } + + if !yield(num, s[:i]) { + return + } + + n := 1 + if s[i] == '\r' && i+1 < len(s) && s[i+1] == '\n' { + n = 2 + } + s = s[i+n:] + } + } +} + +// fileLineError represents an error that prints a config file line with an err. +type fileLineError struct { + line int + err error +} + +func newFileError(line int, err error) fileLineError { + return fileLineError{line: line, err: err} +} + +func (err fileLineError) Error() string { + return fmt.Sprintf("config file: line %d: %s", err.line, err.err.Error()) +} + +func (err fileLineError) PrintTo(p *core.Printer) { + p.WriteString("config file: line ") + p.Set(core.Bold) + p.WriteString(strconv.Itoa(err.line)) + p.Reset() + p.WriteString(": ") + + if pt, ok := err.err.(core.PrinterTo); ok { + pt.PrintTo(p) + } else { + p.WriteString(err.err.Error()) + } +} diff --git a/internal/config/file_test.go b/internal/config/file_test.go new file mode 100644 index 0000000..0efbfa9 --- /dev/null +++ b/internal/config/file_test.go @@ -0,0 +1,90 @@ +package config + +import ( + "crypto/tls" + "reflect" + "strings" + "testing" + "time" + + "github.com/ryanfowler/fetch/internal/core" +) + +func TestParseFile(t *testing.T) { + tests := []struct { + name string + config string + expFile *File + expErr string + }{ + { + name: "successful parse", + config: ` + timeout = 10 + tls = 1.3`, + expFile: &File{ + Global: &Config{ + isFile: true, + Timeout: core.PointerTo(10 * time.Second), + TLS: core.PointerTo(uint16(tls.VersionTLS13)), + }, + }, + }, + { + name: "successful parse with hosts", + config: ` + # This is a comment + color = off + no-pager = true + + [example.com] + insecure = true + + [anotherhost.com] + ignore-status = true`, + expFile: &File{ + Global: &Config{ + isFile: true, + Color: core.ColorOff, + NoPager: core.PointerTo(true), + }, + Hosts: map[string]*Config{ + "example.com": { + isFile: true, + Insecure: core.PointerTo(true), + }, + "anotherhost.com": { + isFile: true, + IgnoreStatus: core.PointerTo(true), + }, + }, + }, + }, + { + name: "invalid key and value pair", + config: ` + color = off + invalidline`, + expErr: "line 3: invalid key/value pair: 'invalidline'", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + f, err := parseFile(test.config) + if err != nil { + if test.expErr == "" { + t.Fatalf("unexpected error: %s", err.Error()) + } + if !strings.Contains(err.Error(), test.expErr) { + t.Fatalf("unexpected error: %s", err.Error()) + } + return + } + + if !reflect.DeepEqual(f, test.expFile) { + t.Fatalf("unexpected file: %+v\n", *f) + } + }) + } +} diff --git a/internal/core/core.go b/internal/core/core.go index 4e2e973..8cfb6c2 100644 --- a/internal/core/core.go +++ b/internal/core/core.go @@ -43,3 +43,8 @@ const ( type KeyVal struct { Key, Val string } + +// PointerTo returns a pointer to the value provided. +func PointerTo[T any](t T) *T { + return &t +} diff --git a/internal/core/errors.go b/internal/core/errors.go index 6990547..3fc5a59 100644 --- a/internal/core/errors.go +++ b/internal/core/errors.go @@ -20,3 +20,52 @@ type SignalError string func (err SignalError) Error() string { return fmt.Sprintf("received signal: %s", string(err)) } + +type ValueError struct { + isFile bool + option string + value string + usage string +} + +func NewValueError(option, value, usage string, isFile bool) *ValueError { + return &ValueError{ + isFile: isFile, + option: option, + value: value, + usage: usage, + } +} + +func (err *ValueError) Error() string { + option := err.option + if !err.isFile { + option = "--" + option + } + msg := fmt.Sprintf("invalid value '%s' for option '%s'", err.value, option) + if err.usage == "" { + msg = fmt.Sprintf("%s: %s", msg, err.usage) + } + return msg +} + +func (err *ValueError) PrintTo(p *Printer) { + p.WriteString("invalid value '") + p.Set(Yellow) + p.WriteString(err.value) + p.Reset() + + p.WriteString("' for option '") + p.Set(Bold) + if !err.isFile { + p.WriteString("--") + } + p.WriteString(err.option) + p.Reset() + p.WriteString("'") + + if err.usage != "" { + p.WriteString(": ") + p.WriteString(err.usage) + } +} diff --git a/internal/printer/printer.go b/internal/core/printer.go similarity index 89% rename from internal/printer/printer.go rename to internal/core/printer.go index 0a5654a..40c7194 100644 --- a/internal/printer/printer.go +++ b/internal/core/printer.go @@ -1,11 +1,9 @@ -package printer +package core import ( "bytes" "io" "os" - - "github.com/ryanfowler/fetch/internal/core" ) // Sequence represents an ANSI escape sequence. @@ -43,10 +41,10 @@ type Handle struct { } // NewHandle returns a new Handle given the provided color configuration. -func NewHandle(c core.Color) *Handle { +func NewHandle(c Color) *Handle { return &Handle{ - stderr: newPrinter(os.Stderr, core.IsStderrTerm, c), - stdout: newPrinter(os.Stdout, core.IsStdoutTerm, c), + stderr: newPrinter(os.Stderr, IsStderrTerm, c), + stdout: newPrinter(os.Stdout, IsStdoutTerm, c), } } @@ -68,12 +66,12 @@ type Printer struct { useColor bool } -func newPrinter(file *os.File, isTerm bool, c core.Color) *Printer { +func newPrinter(file *os.File, isTerm bool, c Color) *Printer { var useColor bool switch c { - case core.ColorOn: + case ColorOn: useColor = true - case core.ColorOff: + case ColorOff: useColor = false default: // By default, set color settings based on whether the file is diff --git a/internal/fetch/fetch.go b/internal/fetch/fetch.go index 0b6d3b0..c31aca1 100644 --- a/internal/fetch/fetch.go +++ b/internal/fetch/fetch.go @@ -21,7 +21,6 @@ import ( "github.com/ryanfowler/fetch/internal/format" "github.com/ryanfowler/fetch/internal/image" "github.com/ryanfowler/fetch/internal/multipart" - "github.com/ryanfowler/fetch/internal/printer" ) type ContentType int @@ -46,7 +45,7 @@ type Request struct { NoEncode bool NoPager bool Output string - PrinterHandle *printer.Handle + PrinterHandle *core.Handle Redirects *int TLS uint16 Verbosity core.Verbosity @@ -74,8 +73,8 @@ func Fetch(ctx context.Context, r *Request) int { } p := r.PrinterHandle.Stderr() - p.Set(printer.Red) - p.Set(printer.Bold) + p.Set(core.Red) + p.Set(core.Bold) p.WriteString("error") p.Reset() p.WriteString(": ") @@ -211,7 +210,7 @@ func makeRequest(r *Request, c *client.Client, req *http.Request) (int, error) { return exitCode, nil } -func formatResponse(r *Request, resp *http.Response, p *printer.Printer) (io.Reader, error) { +func formatResponse(r *Request, resp *http.Response, p *core.Printer) (io.Reader, error) { if r.Output != "" && r.Output != "-" { f, err := os.Create(r.Output) if err != nil { @@ -303,7 +302,7 @@ func getContentType(headers http.Header) ContentType { return TypeUnknown } -func streamToStdout(r io.Reader, p *printer.Printer, forceOutput, noPager bool) error { +func streamToStdout(r io.Reader, p *core.Printer, forceOutput, noPager bool) error { // Check output to see if it's likely safe to print to stdout. if core.IsStdoutTerm && !forceOutput { var ok bool @@ -392,9 +391,9 @@ func isCertificateErr(err error) bool { return false } -func printInsecureMsg(p *printer.Printer) { +func printInsecureMsg(p *core.Printer) { p.WriteString("If you absolutely trust the server, try '") - p.Set(printer.Bold) + p.Set(core.Bold) p.WriteString("--insecure") p.Reset() p.WriteString("'.\n") diff --git a/internal/fetch/print.go b/internal/fetch/print.go index ebbde48..04cfad2 100644 --- a/internal/fetch/print.go +++ b/internal/fetch/print.go @@ -10,12 +10,11 @@ import ( "unicode/utf8" "github.com/ryanfowler/fetch/internal/core" - "github.com/ryanfowler/fetch/internal/printer" ) -func printRequestMetadata(p *printer.Printer, req *http.Request) { - p.Set(printer.Bold) - p.Set(printer.Yellow) +func printRequestMetadata(p *core.Printer, req *http.Request) { + p.Set(core.Bold) + p.Set(core.Yellow) p.WriteString(req.Method) p.Reset() @@ -25,22 +24,22 @@ func printRequestMetadata(p *printer.Printer, req *http.Request) { } p.WriteString(" ") - p.Set(printer.Bold) - p.Set(printer.Cyan) + p.Set(core.Bold) + p.Set(core.Cyan) p.WriteString(path) p.Reset() q := req.URL.RawQuery if req.URL.ForceQuery || q != "" { - p.Set(printer.Italic) - p.Set(printer.Cyan) + p.Set(core.Italic) + p.Set(core.Cyan) p.WriteString("?") p.WriteString(q) p.Reset() } p.WriteString(" ") - p.Set(printer.Dim) + p.Set(core.Dim) p.WriteString(req.Proto) p.Reset() @@ -52,8 +51,8 @@ func printRequestMetadata(p *printer.Printer, req *http.Request) { } for _, h := range headers { - p.Set(printer.Bold) - p.Set(printer.Blue) + p.Set(core.Bold) + p.Set(core.Blue) p.WriteString(h.Key) p.Reset() p.WriteString(": ") @@ -62,15 +61,15 @@ func printRequestMetadata(p *printer.Printer, req *http.Request) { } } -func printResponseMetadata(p *printer.Printer, v core.Verbosity, resp *http.Response) { - p.Set(printer.Dim) +func printResponseMetadata(p *core.Printer, v core.Verbosity, resp *http.Response) { + p.Set(core.Dim) p.WriteString(resp.Proto) p.Reset() p.WriteString(" ") statusColor := colorForStatus(resp.StatusCode) p.Set(statusColor) - p.Set(printer.Bold) + p.Set(core.Bold) p.WriteString(strconv.Itoa(resp.StatusCode)) text := http.StatusText(resp.StatusCode) @@ -91,7 +90,7 @@ func printResponseMetadata(p *printer.Printer, v core.Verbosity, resp *http.Resp p.WriteString("\n") } -func printResponseHeaders(p *printer.Printer, resp *http.Response) { +func printResponseHeaders(p *core.Printer, resp *http.Response) { headers := getHeaders(resp.Header) if resp.ContentLength >= 0 && resp.Header.Get("Content-Length") == "" { val := strconv.FormatInt(resp.ContentLength, 10) @@ -103,8 +102,8 @@ func printResponseHeaders(p *printer.Printer, resp *http.Response) { } for _, h := range headers { - p.Set(printer.Bold) - p.Set(printer.Cyan) + p.Set(core.Bold) + p.Set(core.Cyan) p.WriteString(h.Key) p.Reset() p.WriteString(": ") @@ -113,9 +112,9 @@ func printResponseHeaders(p *printer.Printer, resp *http.Response) { } } -func printBinaryWarning(p *printer.Printer) { - p.Set(printer.Bold) - p.Set(printer.Yellow) +func printBinaryWarning(p *core.Printer) { + p.Set(core.Bold) + p.Set(core.Yellow) p.WriteString("warning") p.Reset() p.WriteString(": the response body appears to be binary\n\n") @@ -123,14 +122,14 @@ func printBinaryWarning(p *printer.Printer) { p.Flush() } -func colorForStatus(code int) printer.Sequence { +func colorForStatus(code int) core.Sequence { switch { case code >= 200 && code < 300: - return printer.Green + return core.Green case code >= 300 && code < 400: - return printer.Yellow + return core.Yellow default: - return printer.Red + return core.Red } } diff --git a/internal/format/json.go b/internal/format/json.go index ba0e00d..bc61e1c 100644 --- a/internal/format/json.go +++ b/internal/format/json.go @@ -8,11 +8,11 @@ import ( "io" "strconv" - "github.com/ryanfowler/fetch/internal/printer" + "github.com/ryanfowler/fetch/internal/core" ) // FormatJSON formats the provided raw JSON data to the Printer. -func FormatJSON(buf []byte, p *printer.Printer) error { +func FormatJSON(buf []byte, p *core.Printer) error { err := formatJSON(bytes.NewReader(buf), p) if err != nil { p.Reset() @@ -20,7 +20,7 @@ func FormatJSON(buf []byte, p *printer.Printer) error { return err } -func formatJSON(r io.Reader, p *printer.Printer) error { +func formatJSON(r io.Reader, p *core.Printer) error { dec := json.NewDecoder(r) err := formatJSONValue(dec, p, 0) if err != nil { @@ -37,7 +37,7 @@ func formatJSON(r io.Reader, p *printer.Printer) error { return nil } -func formatJSONValue(dec *json.Decoder, p *printer.Printer, indent int) error { +func formatJSONValue(dec *json.Decoder, p *core.Printer, indent int) error { token, err := dec.Token() if err != nil { return err @@ -46,7 +46,7 @@ func formatJSONValue(dec *json.Decoder, p *printer.Printer, indent int) error { return formatJSONValueToken(dec, p, indent, token) } -func formatJSONValueToken(dec *json.Decoder, p *printer.Printer, indent int, token any) error { +func formatJSONValueToken(dec *json.Decoder, p *core.Printer, indent int, token any) error { switch t := token.(type) { case json.Delim: switch t { @@ -73,7 +73,7 @@ func formatJSONValueToken(dec *json.Decoder, p *printer.Printer, indent int, tok return nil } -func formatJSONObject(dec *json.Decoder, p *printer.Printer, indent int) error { +func formatJSONObject(dec *json.Decoder, p *core.Printer, indent int) error { p.WriteString("{") var hasFields bool @@ -113,7 +113,7 @@ func formatJSONObject(dec *json.Decoder, p *printer.Printer, indent int) error { } } -func formatJSONArray(dec *json.Decoder, p *printer.Printer, indent int) error { +func formatJSONArray(dec *json.Decoder, p *core.Printer, indent int) error { p.WriteString("[") var hasFields bool @@ -146,24 +146,24 @@ func formatJSONArray(dec *json.Decoder, p *printer.Printer, indent int) error { } } -func writeJSONKey(p *printer.Printer, s string) { +func writeJSONKey(p *core.Printer, s string) { p.WriteString("\"") - p.Set(printer.Blue) - p.Set(printer.Bold) + p.Set(core.Blue) + p.Set(core.Bold) escapeJSONString(p, s) p.Reset() p.WriteString("\": ") } -func writeJSONString(p *printer.Printer, s string) { +func writeJSONString(p *core.Printer, s string) { p.WriteString("\"") - p.Set(printer.Green) + p.Set(core.Green) escapeJSONString(p, s) p.Reset() p.WriteString("\"") } -func escapeJSONString(p *printer.Printer, s string) { +func escapeJSONString(p *core.Printer, s string) { for _, c := range s { switch c { case '\b': diff --git a/internal/format/ndjson.go b/internal/format/ndjson.go index 8118a9d..07e34f6 100644 --- a/internal/format/ndjson.go +++ b/internal/format/ndjson.go @@ -7,12 +7,12 @@ import ( "io" "strconv" - "github.com/ryanfowler/fetch/internal/printer" + "github.com/ryanfowler/fetch/internal/core" ) // FormatNDJSON streams the provided newline-delimited JSON to the Printer, // flushing every line. -func FormatNDJSON(r io.Reader, p *printer.Printer) error { +func FormatNDJSON(r io.Reader, p *core.Printer) error { dec := json.NewDecoder(r) for { err := formatNDJSONValue(dec, p) @@ -28,7 +28,7 @@ func FormatNDJSON(r io.Reader, p *printer.Printer) error { } } -func formatNDJSONValue(dec *json.Decoder, p *printer.Printer) error { +func formatNDJSONValue(dec *json.Decoder, p *core.Printer) error { token, err := dec.Token() if err != nil { return err @@ -37,7 +37,7 @@ func formatNDJSONValue(dec *json.Decoder, p *printer.Printer) error { return formatNDJSONValueToken(dec, p, token) } -func formatNDJSONValueToken(dec *json.Decoder, p *printer.Printer, token any) error { +func formatNDJSONValueToken(dec *json.Decoder, p *core.Printer, token any) error { switch t := token.(type) { case json.Delim: switch t { @@ -64,7 +64,7 @@ func formatNDJSONValueToken(dec *json.Decoder, p *printer.Printer, token any) er return nil } -func formatNDJSONObject(dec *json.Decoder, p *printer.Printer) error { +func formatNDJSONObject(dec *json.Decoder, p *core.Printer) error { p.WriteString("{") var hasFields bool @@ -102,7 +102,7 @@ func formatNDJSONObject(dec *json.Decoder, p *printer.Printer) error { } } -func formatNDJSONArray(dec *json.Decoder, p *printer.Printer) error { +func formatNDJSONArray(dec *json.Decoder, p *core.Printer) error { p.WriteString("[") var hasFields bool diff --git a/internal/format/sse.go b/internal/format/sse.go index 58cf55f..f8b2d72 100644 --- a/internal/format/sse.go +++ b/internal/format/sse.go @@ -9,12 +9,12 @@ import ( "iter" "strings" - "github.com/ryanfowler/fetch/internal/printer" + "github.com/ryanfowler/fetch/internal/core" ) // FormatEventStream formats the provided stream of server sent events to the // Printer, flushing after each event. -func FormatEventStream(r io.Reader, p *printer.Printer) error { +func FormatEventStream(r io.Reader, p *core.Printer) error { var written bool for ev, err := range streamEvents(r) { if err != nil { @@ -33,16 +33,16 @@ func FormatEventStream(r io.Reader, p *printer.Printer) error { return nil } -func writeEventStreamType(t string, p *printer.Printer) { +func writeEventStreamType(t string, p *core.Printer) { p.WriteString("[") - p.Set(printer.Bold) + p.Set(core.Bold) p.WriteString(t) p.Reset() p.WriteString("]\n") p.Flush() } -func writeEventStreamData(d string, p *printer.Printer) { +func writeEventStreamData(d string, p *core.Printer) { dec := json.NewDecoder(strings.NewReader(d)) if formatNDJSONValue(dec, p) == nil { // Ensure there are no more tokens in the event. diff --git a/internal/format/xml.go b/internal/format/xml.go index cde0c32..26db939 100644 --- a/internal/format/xml.go +++ b/internal/format/xml.go @@ -7,11 +7,11 @@ import ( "io" "unicode/utf8" - "github.com/ryanfowler/fetch/internal/printer" + "github.com/ryanfowler/fetch/internal/core" ) // FormatXML formats the provided XML to the Printer. -func FormatXML(buf []byte, w *printer.Printer) error { +func FormatXML(buf []byte, w *core.Printer) error { dec := xml.NewDecoder(bytes.NewReader(buf)) var stack []bool @@ -83,39 +83,39 @@ func FormatXML(buf []byte, w *printer.Printer) error { } } -func writeXMLTagName(p *printer.Printer, s string) { - p.Set(printer.Bold) - p.Set(printer.Blue) +func writeXMLTagName(p *core.Printer, s string) { + p.Set(core.Bold) + p.Set(core.Blue) escapeXMLString(p, s) p.Reset() } -func writeXMLAttrName(p *printer.Printer, s string) { - p.Set(printer.Cyan) +func writeXMLAttrName(p *core.Printer, s string) { + p.Set(core.Cyan) escapeXMLString(p, s) p.Reset() } -func writeXMLAttrVal(p *printer.Printer, s string) { - p.Set(printer.Green) +func writeXMLAttrVal(p *core.Printer, s string) { + p.Set(core.Green) escapeXMLString(p, s) p.Reset() } -func writeXMLText(p *printer.Printer, t []byte) { - p.Set(printer.Green) +func writeXMLText(p *core.Printer, t []byte) { + p.Set(core.Green) escapeXMLString(p, string(t)) p.Reset() } -func writeXMLDirective(p *printer.Printer, b []byte) { - p.Set(printer.Cyan) +func writeXMLDirective(p *core.Printer, b []byte) { + p.Set(core.Cyan) p.Write(b) p.Reset() } -func writeXMLComment(p *printer.Printer, b []byte) { - p.Set(printer.Dim) +func writeXMLComment(p *core.Printer, b []byte) { + p.Set(core.Dim) p.Write(b) p.Reset() } @@ -123,14 +123,14 @@ func writeXMLComment(p *printer.Printer, b []byte) { var equalChar = []byte("=") var quoteChar = []byte("\"") -func writeXMLProcInst(p *printer.Printer, inst []byte) { +func writeXMLProcInst(p *core.Printer, inst []byte) { // This isn't perfect, but should work in most cases. This will break // when a field contains whitespace. for pair := range bytes.FieldsSeq(inst) { p.WriteString(" ") key, val, ok := bytes.Cut(pair, equalChar) - p.Set(printer.Cyan) + p.Set(core.Cyan) p.Write(key) p.Reset() if !ok { @@ -143,14 +143,14 @@ func writeXMLProcInst(p *printer.Printer, inst []byte) { p.Write(quoteChar) val, ok = bytes.CutSuffix(val, quoteChar) if ok { - p.Set(printer.Green) + p.Set(core.Green) p.Write(val) p.Reset() p.Write(quoteChar) continue } } - p.Set(printer.Cyan) + p.Set(core.Cyan) p.Write(val) p.Reset() } @@ -158,7 +158,7 @@ func writeXMLProcInst(p *printer.Printer, inst []byte) { // Mostly taken from the Go encoding/xml package in the standard library: // https://cs.opensource.google/go/go/+/refs/tags/go1.24.0:src/encoding/xml/xml.go;l=1964-1999 -func escapeXMLString(p *printer.Printer, s string) { +func escapeXMLString(p *core.Printer, s string) { var esc string var last int for i := 0; i < len(s); { diff --git a/internal/update/update.go b/internal/update/update.go index 1741446..39b888c 100644 --- a/internal/update/update.go +++ b/internal/update/update.go @@ -16,19 +16,18 @@ import ( "github.com/ryanfowler/fetch/internal/client" "github.com/ryanfowler/fetch/internal/core" - "github.com/ryanfowler/fetch/internal/printer" ) // Update checks the API for the latest fetch version and upgrades the current // executable in-place, returning the exit code to use. -func Update(ctx context.Context, p *printer.Printer, timeout time.Duration, silent bool) bool { +func Update(ctx context.Context, p *core.Printer, timeout time.Duration, silent bool) bool { err := update(ctx, p, timeout, silent) if err == nil { return true } - p.Set(printer.Bold) - p.Set(printer.Red) + p.Set(core.Bold) + p.Set(core.Red) p.WriteString("error") p.Reset() p.WriteString(": ") @@ -38,7 +37,7 @@ func Update(ctx context.Context, p *printer.Printer, timeout time.Duration, sile return false } -func update(ctx context.Context, p *printer.Printer, timeout time.Duration, silent bool) error { +func update(ctx context.Context, p *core.Printer, timeout time.Duration, silent bool) error { c := client.NewClient(client.ClientConfig{}) if timeout > 0 { @@ -215,13 +214,13 @@ func getFetchFilename() string { return name } -func writeInfo(p *printer.Printer, silent bool, s string) { +func writeInfo(p *core.Printer, silent bool, s string) { if silent { return } - p.Set(printer.Bold) - p.Set(printer.Green) + p.Set(core.Bold) + p.Set(core.Green) p.WriteString("info") p.Reset() p.WriteString(": ") diff --git a/main.go b/main.go index b093c72..d1c8be8 100644 --- a/main.go +++ b/main.go @@ -9,11 +9,11 @@ import ( "syscall" "github.com/ryanfowler/fetch/internal/cli" + "github.com/ryanfowler/fetch/internal/config" "github.com/ryanfowler/fetch/internal/core" "github.com/ryanfowler/fetch/internal/fetch" "github.com/ryanfowler/fetch/internal/format" "github.com/ryanfowler/fetch/internal/multipart" - "github.com/ryanfowler/fetch/internal/printer" "github.com/ryanfowler/fetch/internal/update" ) @@ -30,12 +30,29 @@ func main() { // Parse the CLI args. app, err := cli.Parse(os.Args[1:]) if err != nil { - p := printer.NewHandle(app.Color).Stderr() + p := core.NewHandle(app.Cfg.Color).Stderr() writeCLIErr(p, err) os.Exit(1) } - printerHandle := printer.NewHandle(app.Color) + // Parse any config file, and merge with it. + file, err := config.GetFile(app.ConfigPath) + if err != nil { + p := core.NewHandle(app.Cfg.Color).Stderr() + writeCLIErr(p, err) + os.Exit(1) + } + if file != nil { + if app.URL != nil { + hostCfg, ok := file.Hosts[app.URL.Hostname()] + if ok { + app.Cfg.Merge(hostCfg) + } + } + app.Cfg.Merge(file.Global) + } + + printerHandle := core.NewHandle(app.Cfg.Color) verbosity := getVerbosity(app) // Print help to stdout. @@ -56,7 +73,7 @@ func main() { if app.BuildInfo { p := printerHandle.Stdout() info := core.GetBuildInfo() - if app.Format != core.FormatOff { + if app.Cfg.Format != core.FormatOff { format.FormatJSON(info, p) } else { p.Write(info) @@ -68,7 +85,8 @@ func main() { // Attempt to update the current executable. if app.Update { p := printerHandle.Stderr() - ok := update.Update(ctx, p, app.Timeout, verbosity == core.VSilent) + timeout := getValue(app.Cfg.Timeout) + ok := update.Update(ctx, p, timeout, verbosity == core.VSilent) if ok { os.Exit(0) } @@ -84,19 +102,19 @@ func main() { // Make the HTTP request using the parsed configuration. req := fetch.Request{ - DNSServer: app.DNSServer, + DNSServer: app.Cfg.DNSServer, DryRun: app.DryRun, Edit: app.Edit, - Format: app.Format, - HTTP: app.HTTP, - IgnoreStatus: app.IgnoreStatus, - Insecure: app.Insecure, - NoEncode: app.NoEncode, - NoPager: app.NoPager, + Format: app.Cfg.Format, + HTTP: app.Cfg.HTTP, + IgnoreStatus: getValue(app.Cfg.IgnoreStatus), + Insecure: getValue(app.Cfg.Insecure), + NoEncode: getValue(app.Cfg.NoEncode), + NoPager: getValue(app.Cfg.NoPager), Output: app.Output, PrinterHandle: printerHandle, - Redirects: app.Redirects, - TLS: app.TLS, + Redirects: app.Cfg.Redirects, + TLS: getValue(app.Cfg.TLS), Verbosity: verbosity, Method: app.Method, @@ -104,26 +122,34 @@ func main() { Body: app.Data, Form: app.Form, Multipart: multipart.NewMultipart(app.Multipart), - Headers: app.Headers, - QueryParams: app.QueryParams, + Headers: app.Cfg.Headers, + QueryParams: app.Cfg.QueryParams, AWSSigv4: app.AWSSigv4, Basic: app.Basic, Bearer: app.Bearer, JSON: app.JSON, XML: app.XML, - Proxy: app.Proxy, - Timeout: app.Timeout, + Proxy: app.Cfg.Proxy, + Timeout: getValue(app.Cfg.Timeout), } status := fetch.Fetch(ctx, &req) os.Exit(status) } +func getValue[T any](v *T) T { + if v == nil { + var t T + return t + } + return *v +} + // getVerbosity returns the Verbosity level based on the app configuration. func getVerbosity(app *cli.App) core.Verbosity { - if app.Silent { + if getValue(app.Cfg.Silent) { return core.VSilent } - switch app.Verbose { + switch getValue(app.Cfg.Verbosity) { case 0: return core.VNormal case 1: @@ -134,14 +160,14 @@ func getVerbosity(app *cli.App) core.Verbosity { } // writeCLIErr writes the provided CLI error to the Printer. -func writeCLIErr(p *printer.Printer, err error) { - p.Set(printer.Bold) - p.Set(printer.Red) +func writeCLIErr(p *core.Printer, err error) { + p.Set(core.Bold) + p.Set(core.Red) p.WriteString("error") p.Reset() p.WriteString(": ") - if pt, ok := err.(printer.PrinterTo); ok { + if pt, ok := err.(core.PrinterTo); ok { pt.PrintTo(p) } else { p.WriteString(err.Error()) @@ -149,7 +175,7 @@ func writeCLIErr(p *printer.Printer, err error) { p.WriteString("\n\nFor more information, try '") - p.Set(printer.Bold) + p.Set(core.Bold) p.WriteString("--help") p.Reset()