Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 28 additions & 8 deletions vertex/vertex.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"golang.org/x/oauth2/google"
"google.golang.org/api/option"
"google.golang.org/api/transport"
thttp "google.golang.org/api/transport/http"

"github.com/anthropics/anthropic-sdk-go/internal/requestconfig"
sdkoption "github.com/anthropics/anthropic-sdk-go/option"
Expand Down Expand Up @@ -39,10 +40,6 @@ func WithGoogleAuth(ctx context.Context, region string, projectID string, scopes
// WithCredentials returns a request option which uses the provided credentials for Google Vertex AI and registers middleware that
// intercepts request to the Messages API.
func WithCredentials(ctx context.Context, region string, projectID string, creds *google.Credentials) sdkoption.RequestOption {
client, _, err := transport.NewHTTPClient(ctx, option.WithTokenSource(creds.TokenSource))
if err != nil {
panic(fmt.Errorf("failed to create HTTP client: %v", err))
}
middleware := vertexMiddleware(region, projectID)

var baseURL string
Expand All @@ -53,6 +50,24 @@ func WithCredentials(ctx context.Context, region string, projectID string, creds
}

return requestconfig.RequestOptionFunc(func(rc *requestconfig.RequestConfig) error {
getClient := func() (*http.Client, error) {
if rc.HTTPClient == nil || rc.HTTPClient.Transport == nil {
c, _, err := transport.NewHTTPClient(ctx, option.WithTokenSource(creds.TokenSource))
return c, err
}
transport, err := thttp.NewTransport(
ctx,
rc.HTTPClient.Transport,
option.WithTokenSource(creds.TokenSource),
)
return &http.Client{Transport: transport}, err
}

client, err := getClient()
if err != nil {
return fmt.Errorf("failed to create http client: %v", err)
}

return rc.Apply(
sdkoption.WithBaseURL(baseURL),
sdkoption.WithMiddleware(middleware),
Expand All @@ -68,13 +83,16 @@ func vertexMiddleware(region, projectID string) sdkoption.Middleware {
if err != nil {
return nil, err
}
r.Body.Close()
if err := r.Body.Close(); err != nil {
return nil, err
}

if !gjson.GetBytes(body, "anthropic_version").Exists() {
body, _ = sjson.SetBytes(body, "anthropic_version", DefaultVersion)
}

if r.URL.Path == "/v1/messages" && r.Method == http.MethodPost {
switch {
case r.URL.Path == "/v1/messages" && r.Method == http.MethodPost:
if projectID == "" {
return nil, fmt.Errorf("no projectId was given and it could not be resolved from credentials")
}
Expand All @@ -90,14 +108,16 @@ func vertexMiddleware(region, projectID string) sdkoption.Middleware {
}

r.URL.Path = fmt.Sprintf("/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s", projectID, region, model, specifier)
}

if r.URL.Path == "/v1/messages/count_tokens" && r.Method == http.MethodPost {
case r.URL.Path == "/v1/messages/count_tokens" && r.Method == http.MethodPost:
if projectID == "" {
return nil, fmt.Errorf("no projectId was given and it could not be resolved from credentials")
}

r.URL.Path = fmt.Sprintf("/v1/projects/%s/locations/%s/publishers/anthropic/models/count-tokens:rawPredict", projectID, region)

default:
return nil, fmt.Errorf("vertex middleware does not support %s %s", r.Method, r.URL.Path)
Comment on lines +119 to +120
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like if we miss any path, we probably forgot something in the code. A good idea to error here?

}

reader := bytes.NewReader(body)
Expand Down