diff --git a/common/image/image.go b/common/image/image.go index beebd0c66a..1c0fc8e828 100644 --- a/common/image/image.go +++ b/common/image/image.go @@ -3,12 +3,15 @@ package image import ( "bytes" "encoding/base64" + "fmt" "github.com/songquanpeng/one-api/common/client" + "github.com/songquanpeng/one-api/common/network" "image" _ "image/gif" _ "image/jpeg" _ "image/png" - "net/http" + "net" + "net/url" "regexp" "strings" "sync" @@ -19,6 +22,34 @@ import ( // Regex to match data URL pattern var dataURLPattern = regexp.MustCompile(`data:image/([^;]+);base64,(.*)`) +// validateImageUrl checks that a URL does not resolve to a private/reserved IP +// to prevent Server-Side Request Forgery (SSRF) attacks. +func validateImageUrl(rawUrl string) error { + parsedUrl, err := url.Parse(rawUrl) + if err != nil { + return fmt.Errorf("invalid URL: %w", err) + } + // Only allow http and https schemes + if parsedUrl.Scheme != "http" && parsedUrl.Scheme != "https" { + return fmt.Errorf("unsupported URL scheme: %s", parsedUrl.Scheme) + } + host := parsedUrl.Hostname() + if host == "" { + return fmt.Errorf("empty host in URL") + } + // Resolve hostname and check against private IP ranges + ips, err := net.LookupIP(host) + if err != nil { + return fmt.Errorf("failed to resolve host: %w", err) + } + for _, ip := range ips { + if network.IsPrivateIP(ip) { + return fmt.Errorf("URL resolves to a private/reserved IP address: %s", ip) + } + } + return nil +} + func IsImageUrl(url string) (bool, error) { resp, err := client.UserContentRequestHTTPClient.Head(url) if err != nil { @@ -57,11 +88,16 @@ func GetImageFromUrl(url string) (mimeType string, data string, err error) { return } + // Validate URL to prevent SSRF attacks + if err = validateImageUrl(url); err != nil { + return + } + isImage, err := IsImageUrl(url) if !isImage { return } - resp, err := http.Get(url) + resp, err := client.UserContentRequestHTTPClient.Get(url) if err != nil { return } diff --git a/common/network/ip.go b/common/network/ip.go index 0fbe5e6f63..511f9bcb76 100644 --- a/common/network/ip.go +++ b/common/network/ip.go @@ -50,3 +50,37 @@ func IsIpInSubnets(ctx context.Context, ip string, subnets string) bool { } return false } + +// IsPrivateIP checks if an IP address is in a private or reserved range +// (loopback, link-local, RFC 1918, IPv6 unique local, etc.) +func IsPrivateIP(ip net.IP) bool { + if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { + return true + } + if ip.IsUnspecified() { + return true + } + // IPv4 private ranges (RFC 1918) + privateRanges := []string{ + "10.0.0.0/8", + "172.16.0.0/12", + "192.168.0.0/16", + "127.0.0.0/8", + "169.254.0.0/16", + "0.0.0.0/8", + } + for _, cidr := range privateRanges { + _, network, _ := net.ParseCIDR(cidr) + if network.Contains(ip) { + return true + } + } + // IPv6 unique local and loopback + if ip.To4() == nil { + _, network, _ := net.ParseCIDR("fc00::/7") + if network.Contains(ip) { + return true + } + } + return false +}