Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 38 additions & 2 deletions common/image/image.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@ package image
import (
"bytes"
"encoding/base64"
"fmt"
"github.com/songquanpeng/one-api/common/client"
"github.com/songquanpeng/one-api/common/network"
"image"
_ "image/gif"
_ "image/jpeg"
_ "image/png"
"net/http"
"net"
"net/url"
"regexp"
"strings"
"sync"
Expand All @@ -19,6 +22,34 @@ import (
// Regex to match data URL pattern
var dataURLPattern = regexp.MustCompile(`data:image/([^;]+);base64,(.*)`)

// validateImageUrl checks that a URL does not resolve to a private/reserved IP
// to prevent Server-Side Request Forgery (SSRF) attacks.
func validateImageUrl(rawUrl string) error {
parsedUrl, err := url.Parse(rawUrl)
if err != nil {
return fmt.Errorf("invalid URL: %w", err)
}
// Only allow http and https schemes
if parsedUrl.Scheme != "http" && parsedUrl.Scheme != "https" {
return fmt.Errorf("unsupported URL scheme: %s", parsedUrl.Scheme)
}
host := parsedUrl.Hostname()
if host == "" {
return fmt.Errorf("empty host in URL")
}
// Resolve hostname and check against private IP ranges
ips, err := net.LookupIP(host)
if err != nil {
return fmt.Errorf("failed to resolve host: %w", err)
}
for _, ip := range ips {
if network.IsPrivateIP(ip) {
return fmt.Errorf("URL resolves to a private/reserved IP address: %s", ip)
}
}
return nil
}

func IsImageUrl(url string) (bool, error) {
resp, err := client.UserContentRequestHTTPClient.Head(url)
if err != nil {
Expand Down Expand Up @@ -57,11 +88,16 @@ func GetImageFromUrl(url string) (mimeType string, data string, err error) {
return
}

// Validate URL to prevent SSRF attacks
if err = validateImageUrl(url); err != nil {
return
}

isImage, err := IsImageUrl(url)
if !isImage {
return
}
resp, err := http.Get(url)
resp, err := client.UserContentRequestHTTPClient.Get(url)
if err != nil {
return
}
Expand Down
34 changes: 34 additions & 0 deletions common/network/ip.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,37 @@ func IsIpInSubnets(ctx context.Context, ip string, subnets string) bool {
}
return false
}

// IsPrivateIP checks if an IP address is in a private or reserved range
// (loopback, link-local, RFC 1918, IPv6 unique local, etc.)
func IsPrivateIP(ip net.IP) bool {
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
return true
}
if ip.IsUnspecified() {
return true
}
// IPv4 private ranges (RFC 1918)
privateRanges := []string{
"10.0.0.0/8",
"172.16.0.0/12",
"192.168.0.0/16",
"127.0.0.0/8",
"169.254.0.0/16",
"0.0.0.0/8",
}
for _, cidr := range privateRanges {
_, network, _ := net.ParseCIDR(cidr)
if network.Contains(ip) {
return true
}
}
// IPv6 unique local and loopback
if ip.To4() == nil {
_, network, _ := net.ParseCIDR("fc00::/7")
if network.Contains(ip) {
return true
}
}
return false
}