diff --git a/vertex/vertex.go b/vertex/vertex.go index fb40806a..f5706194 100644 --- a/vertex/vertex.go +++ b/vertex/vertex.go @@ -10,6 +10,7 @@ import ( "golang.org/x/oauth2/google" "google.golang.org/api/option" "google.golang.org/api/transport" + thttp "google.golang.org/api/transport/http" "github.com/anthropics/anthropic-sdk-go/internal/requestconfig" sdkoption "github.com/anthropics/anthropic-sdk-go/option" @@ -39,10 +40,6 @@ func WithGoogleAuth(ctx context.Context, region string, projectID string, scopes // WithCredentials returns a request option which uses the provided credentials for Google Vertex AI and registers middleware that // intercepts request to the Messages API. func WithCredentials(ctx context.Context, region string, projectID string, creds *google.Credentials) sdkoption.RequestOption { - client, _, err := transport.NewHTTPClient(ctx, option.WithTokenSource(creds.TokenSource)) - if err != nil { - panic(fmt.Errorf("failed to create HTTP client: %v", err)) - } middleware := vertexMiddleware(region, projectID) var baseURL string @@ -53,6 +50,24 @@ func WithCredentials(ctx context.Context, region string, projectID string, creds } return requestconfig.RequestOptionFunc(func(rc *requestconfig.RequestConfig) error { + getClient := func() (*http.Client, error) { + if rc.HTTPClient == nil || rc.HTTPClient.Transport == nil { + c, _, err := transport.NewHTTPClient(ctx, option.WithTokenSource(creds.TokenSource)) + return c, err + } + transport, err := thttp.NewTransport( + ctx, + rc.HTTPClient.Transport, + option.WithTokenSource(creds.TokenSource), + ) + return &http.Client{Transport: transport}, err + } + + client, err := getClient() + if err != nil { + return fmt.Errorf("failed to create http client: %v", err) + } + return rc.Apply( sdkoption.WithBaseURL(baseURL), sdkoption.WithMiddleware(middleware), @@ -68,13 +83,16 @@ func vertexMiddleware(region, projectID string) sdkoption.Middleware { if err != nil { return nil, err } - r.Body.Close() + if err := r.Body.Close(); err != nil { + return nil, err + } if !gjson.GetBytes(body, "anthropic_version").Exists() { body, _ = sjson.SetBytes(body, "anthropic_version", DefaultVersion) } - if r.URL.Path == "/v1/messages" && r.Method == http.MethodPost { + switch { + case r.URL.Path == "/v1/messages" && r.Method == http.MethodPost: if projectID == "" { return nil, fmt.Errorf("no projectId was given and it could not be resolved from credentials") } @@ -90,14 +108,16 @@ func vertexMiddleware(region, projectID string) sdkoption.Middleware { } r.URL.Path = fmt.Sprintf("/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s", projectID, region, model, specifier) - } - if r.URL.Path == "/v1/messages/count_tokens" && r.Method == http.MethodPost { + case r.URL.Path == "/v1/messages/count_tokens" && r.Method == http.MethodPost: if projectID == "" { return nil, fmt.Errorf("no projectId was given and it could not be resolved from credentials") } r.URL.Path = fmt.Sprintf("/v1/projects/%s/locations/%s/publishers/anthropic/models/count-tokens:rawPredict", projectID, region) + + default: + return nil, fmt.Errorf("vertex middleware does not support %s %s", r.Method, r.URL.Path) } reader := bytes.NewReader(body)