Skip to content

Commit bb06f93

Browse files
committed
fix: anthropic compression
1 parent b585e81 commit bb06f93

File tree

3 files changed

+131
-29
lines changed

3 files changed

+131
-29
lines changed

internal/domain/services/chat_service.go

Lines changed: 95 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,13 @@ func (s *chatService) SendMessage(ctx context.Context, id string, message *entit
329329
Role: "system",
330330
Content: agent.FullSystemPrompt(),
331331
}
332-
systemTokens := estimateTokens(systemMessage)
332+
333+
// Use provider-specific token estimation for system message
334+
systemEstimateFunc := estimateTokens
335+
if provider.Type == entities.ProviderAnthropic {
336+
systemEstimateFunc = estimateAnthropicTokens
337+
}
338+
systemTokens := systemEstimateFunc(systemMessage)
333339
if systemTokens > tokenLimit {
334340
return nil, errors.InternalErrorf("system prompt too large for the context window")
335341
}
@@ -348,8 +354,15 @@ func (s *chatService) SendMessage(ctx context.Context, id string, message *entit
348354

349355
// Check if we need to compress messages
350356
totalMessageTokens := systemTokens
357+
// Use provider-specific token estimation
358+
tokenEstimator := estimateTokens
359+
if provider.Type == entities.ProviderAnthropic {
360+
tokenEstimator = estimateAnthropicTokens
361+
s.logger.Debug("Using Anthropic-specific token estimation")
362+
}
363+
351364
for i := range chat.Messages {
352-
totalMessageTokens += estimateTokens(&chat.Messages[i])
365+
totalMessageTokens += tokenEstimator(&chat.Messages[i])
353366
}
354367

355368
s.logger.Debug("Total message tokens: ", zap.Float64("total_message_tokens", float64(totalMessageTokens)), zap.Float64("compression_threshold", compressionThreshold))
@@ -436,25 +449,35 @@ func (s *chatService) SendMessage(ctx context.Context, id string, message *entit
436449
return nil, errors.InternalErrorf("no messages to send")
437450
}
438451

452+
// Use provider-specific token estimation for pre-flight check
453+
var estimateFunc func(*entities.Message) int = estimateTokens
454+
if provider.Type == entities.ProviderAnthropic {
455+
estimateFunc = estimateAnthropicTokens
456+
}
457+
439458
totalTokens := 0
440459
for _, msg := range messagesToSend {
441460
if msg == nil {
442461
s.logger.Error("Nil message found in messagesToSend")
443462
continue
444463
}
445-
totalTokens += estimateTokens(msg)
464+
totalTokens += estimateFunc(msg)
446465
}
447466

448467
// More aggressive pre-flight compression at 75% to prevent API errors
449468
preFlightLimit := int(float64(tokenLimit) * 0.75)
450469
if totalTokens > preFlightLimit {
451-
s.logger.Warn("Messages exceed pre-flight limit, attempting compression", zap.Int("total_tokens", totalTokens), zap.Int("pre_flight_limit", preFlightLimit))
470+
s.logger.Warn("Messages exceed pre-flight limit, attempting compression",
471+
zap.Int("total_tokens", totalTokens),
472+
zap.Int("pre_flight_limit", preFlightLimit),
473+
zap.Int("token_limit", tokenLimit))
452474

453-
// Try compression first
454-
compressedMessages, originalMessagesReplaced, err := s.compressMessages(ctx, chat, model, provider, resolvedAPIKey, tokenLimit)
475+
// Try compression with the pre-flight limit as target
476+
compressedMessages, originalMessagesReplaced, err := s.compressMessages(ctx, chat, model, provider, resolvedAPIKey, preFlightLimit)
455477
if err != nil {
456478
s.logger.Warn("Pre-flight compression failed, falling back to trimming", zap.Error(err))
457-
messagesToSend = s.trimMessagesToLimit(messagesToSend, preFlightLimit)
479+
messagesToSend = s.trimMessagesToLimit(messagesToSend, preFlightLimit, provider.Type)
480+
s.logger.Info("Pre-flight trimming applied", zap.Int("original_count", len(messagesToSend)), zap.Int("trimmed_count", len(messagesToSend)))
458481
} else {
459482
if originalMessagesReplaced {
460483
if err := s.chatRepo.UpdateChat(ctx, chat); err != nil {
@@ -463,7 +486,10 @@ func (s *chatService) SendMessage(ctx context.Context, id string, message *entit
463486
}
464487
// Replace messagesToSend with compressed version
465488
messagesToSend = append([]*entities.Message{systemMessage}, compressedMessages...)
466-
s.logger.Info("Pre-flight compression successful", zap.Int("original_count", len(chat.Messages)), zap.Int("compressed_count", len(compressedMessages)))
489+
s.logger.Info("Pre-flight compression successful",
490+
zap.Int("original_count", len(chat.Messages)),
491+
zap.Int("compressed_count", len(compressedMessages)),
492+
zap.Int("target_tokens", preFlightLimit))
467493
}
468494
}
469495

@@ -518,7 +544,7 @@ func (s *chatService) SendMessage(ctx context.Context, id string, message *entit
518544
compressedMessages, originalMessagesReplaced, err := s.compressMessages(ctx, chat, model, provider, resolvedAPIKey, compressionTarget)
519545
if err != nil {
520546
s.logger.Warn("Failed progressive compression, using fallback trimming", zap.Error(err), zap.Int("target_tokens", compressionTarget))
521-
compressedMessages = s.trimMessagesToLimit(messagesToSend, compressionTarget)
547+
compressedMessages = s.trimMessagesToLimit(messagesToSend, compressionTarget, provider.Type)
522548
originalMessagesReplaced = false
523549
} else if originalMessagesReplaced {
524550
if err := s.chatRepo.UpdateChat(ctx, chat); err != nil {
@@ -673,6 +699,28 @@ func estimateTokens(msg *entities.Message) int {
673699
return len(tokens)
674700
}
675701

702+
// estimateAnthropicTokens provides a rough token estimate for Anthropic models
703+
// Anthropic uses different tokenization than OpenAI, approximately 4 chars per token
704+
func estimateAnthropicTokens(msg *entities.Message) int {
705+
if msg == nil {
706+
return 0
707+
}
708+
709+
// Rough approximation: ~4 characters per token for English text
710+
charCount := len(msg.Content)
711+
tokenEstimate := charCount / 4
712+
713+
// Add some padding for safety and to account for tokenization differences
714+
tokenEstimate = int(float64(tokenEstimate) * 1.1)
715+
716+
// Minimum of 1 token
717+
if tokenEstimate < 1 {
718+
tokenEstimate = 1
719+
}
720+
721+
return tokenEstimate
722+
}
723+
676724
// isContextError checks if an error is related to context window limits
677725
func isContextError(err error) bool {
678726
if err == nil {
@@ -1132,12 +1180,37 @@ func (s *chatService) compressMessages(
11321180
apiKey string,
11331181
tokenLimit int,
11341182
) ([]*entities.Message, bool, error) {
1135-
// Calculate how many messages to summarize (approx 50% of older messages)
1136-
numMessagesToKeep := int(float64(len(chat.Messages)) * 0.5)
1183+
// Use provider-specific token estimation
1184+
var estimateFunc func(*entities.Message) int = estimateTokens
1185+
if provider.Type == entities.ProviderAnthropic {
1186+
estimateFunc = estimateAnthropicTokens
1187+
}
1188+
// Calculate current total tokens to determine compression aggressiveness
1189+
currentTokens := 0
1190+
for _, msg := range chat.Messages {
1191+
currentTokens += estimateFunc(&msg)
1192+
}
1193+
1194+
// If we're way over the limit, be more aggressive with compression
1195+
compressionRatio := 0.5 // Default: keep 50% of messages
1196+
if currentTokens > tokenLimit*2 {
1197+
compressionRatio = 0.3 // If 2x over limit, keep only 30%
1198+
} else if currentTokens > int(float64(tokenLimit)*1.5) {
1199+
compressionRatio = 0.4 // If 1.5x over limit, keep 40%
1200+
}
1201+
1202+
numMessagesToKeep := int(float64(len(chat.Messages)) * compressionRatio)
11371203
if numMessagesToKeep < 1 {
11381204
numMessagesToKeep = 1 // Always keep at least the most recent message
11391205
}
11401206

1207+
s.logger.Debug("Compression calculation",
1208+
zap.Int("current_tokens", currentTokens),
1209+
zap.Int("token_limit", tokenLimit),
1210+
zap.Float64("compression_ratio", compressionRatio),
1211+
zap.Int("messages_total", len(chat.Messages)),
1212+
zap.Int("messages_to_keep", numMessagesToKeep))
1213+
11411214
// Tentative split point
11421215
summarizeEndIdx := len(chat.Messages) - numMessagesToKeep
11431216
if summarizeEndIdx < 1 {
@@ -1235,13 +1308,13 @@ func (s *chatService) compressMessages(
12351308
chat.Messages = append([]entities.Message{*summaryMsg}, recentMessagesToKeep...)
12361309

12371310
// Verify we're not exceeding token limit
1238-
currentTokens := estimateTokens(summaryMsg)
1311+
currentTokens = estimateFunc(summaryMsg)
12391312
var finalMessages []*entities.Message
12401313
finalMessages = append(finalMessages, summaryMsg)
12411314

12421315
// Add as many of the recent messages as possible within token limit
12431316
for i := range recentMessagesToKeep {
1244-
msgTokens := estimateTokens(&recentMessagesToKeep[i])
1317+
msgTokens := estimateFunc(&recentMessagesToKeep[i])
12451318
if currentTokens+msgTokens > tokenLimit {
12461319
break
12471320
}
@@ -1253,7 +1326,7 @@ func (s *chatService) compressMessages(
12531326
}
12541327

12551328
// trimMessagesToLimit removes oldest messages until under token limit
1256-
func (s *chatService) trimMessagesToLimit(messages []*entities.Message, maxTokens int) []*entities.Message {
1329+
func (s *chatService) trimMessagesToLimit(messages []*entities.Message, maxTokens int, providerType entities.ProviderType) []*entities.Message {
12571330
if messages == nil || len(messages) <= 1 {
12581331
return messages // Always keep at least system message
12591332
}
@@ -1263,8 +1336,14 @@ func (s *chatService) trimMessagesToLimit(messages []*entities.Message, maxToken
12631336
return messages
12641337
}
12651338

1339+
// Use provider-specific token estimation
1340+
estimateFunc := estimateTokens
1341+
if providerType == entities.ProviderAnthropic {
1342+
estimateFunc = estimateAnthropicTokens
1343+
}
1344+
12661345
var result []*entities.Message
1267-
totalTokens := estimateTokens(messages[0]) // Always include system message
1346+
totalTokens := estimateFunc(messages[0]) // Always include system message
12681347

12691348
result = append(result, messages[0])
12701349

@@ -1274,7 +1353,7 @@ func (s *chatService) trimMessagesToLimit(messages []*entities.Message, maxToken
12741353
s.logger.Warn("Skipping nil message in trimMessagesToLimit", zap.Int("index", i))
12751354
continue
12761355
}
1277-
msgTokens := estimateTokens(messages[i])
1356+
msgTokens := estimateFunc(messages[i])
12781357
if totalTokens+msgTokens > maxTokens {
12791358
break
12801359
}

internal/domain/services/chat_service_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ func TestTrimMessagesToLimit(t *testing.T) {
2727
}
2828

2929
// Test with a very low token limit to force trimming
30-
result := cs.trimMessagesToLimit(messages, 10) // Very low limit
30+
result := cs.trimMessagesToLimit(messages, 10, entities.ProviderOpenAI) // Very low limit
3131

3232
// Should always keep at least the system message
3333
if len(result) == 0 {
@@ -52,14 +52,14 @@ func TestTrimMessagesToLimitNilInput(t *testing.T) {
5252
cs := &chatService{logger: logger}
5353

5454
// Test with nil input
55-
result := cs.trimMessagesToLimit(nil, 100)
55+
result := cs.trimMessagesToLimit(nil, 100, entities.ProviderOpenAI)
5656
if result != nil {
5757
t.Error("trimMessagesToLimit should return nil for nil input")
5858
}
5959

6060
// Test with empty slice
6161
empty := []*entities.Message{}
62-
result = cs.trimMessagesToLimit(empty, 100)
62+
result = cs.trimMessagesToLimit(empty, 100, entities.ProviderOpenAI)
6363
if len(result) != 0 {
6464
t.Error("trimMessagesToLimit should return empty slice for empty input")
6565
}

internal/impl/integrations/anthropic.go

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -207,11 +207,9 @@ func (m *AnthropicIntegration) GenerateResponse(ctx context.Context, messages []
207207
zap.Int("status_code", resp.StatusCode),
208208
zap.String("body", string(body)))
209209

210-
// Check for context window errors
211-
if resp.StatusCode == http.StatusBadRequest {
212-
if contextErr := m.parseAnthropicContextError(body); contextErr != nil {
213-
return nil, contextErr
214-
}
210+
// Check for context window errors on any error status
211+
if contextErr := m.parseAnthropicContextError(body); contextErr != nil {
212+
return nil, contextErr
215213
}
216214

217215
return nil, fmt.Errorf("unexpected status %d: %s", resp.StatusCode, string(body))
@@ -498,6 +496,7 @@ func (m *AnthropicIntegration) GetLastUsage() (*entities.Usage, error) {
498496

499497
// parseAnthropicContextError checks if the error response is related to context window limits
500498
func (m *AnthropicIntegration) parseAnthropicContextError(body []byte) error {
499+
// Try to parse as structured error first
501500
var errorResp struct {
502501
Type string `json:"type"`
503502
Error struct {
@@ -506,18 +505,42 @@ func (m *AnthropicIntegration) parseAnthropicContextError(body []byte) error {
506505
} `json:"error"`
507506
}
508507

509-
if err := json.Unmarshal(body, &errorResp); err != nil {
510-
return nil // Not a valid JSON error response, return nil to let caller handle
511-
}
508+
if err := json.Unmarshal(body, &errorResp); err == nil {
509+
m.logger.Debug("Parsed structured Anthropic error", zap.String("type", errorResp.Type), zap.String("error_type", errorResp.Error.Type), zap.String("message", errorResp.Error.Message))
512510

513-
if errorResp.Type == "error" && errorResp.Error.Type == "invalid_request_error" {
511+
// Check for context-related errors regardless of error type
514512
errMsg := strings.ToLower(errorResp.Error.Message)
515513
if strings.Contains(errMsg, "too long") ||
516514
strings.Contains(errMsg, "token limit") ||
517515
strings.Contains(errMsg, "context") ||
518-
strings.Contains(errMsg, "maximum length") {
516+
strings.Contains(errMsg, "maximum length") ||
517+
strings.Contains(errMsg, "context_length_exceeded") ||
518+
strings.Contains(errMsg, "prompt is too long") ||
519+
strings.Contains(errMsg, "input too long") {
519520
return errors.ContextWindowErrorf("Anthropic context window exceeded: %s", errorResp.Error.Message)
520521
}
522+
523+
// Also check for system errors that might indicate context issues
524+
if errorResp.Type == "error" && (errorResp.Error.Type == "system_error" || errorResp.Error.Type == "internal_error") {
525+
if strings.Contains(errMsg, "context") || strings.Contains(errMsg, "token") || strings.Contains(errMsg, "length") {
526+
return errors.ContextWindowErrorf("Anthropic system error (likely context): %s", errorResp.Error.Message)
527+
}
528+
}
529+
} else {
530+
// If not structured JSON, check if it's a raw error message that contains context-related text
531+
bodyStr := strings.ToLower(string(body))
532+
m.logger.Debug("Checking raw Anthropic error for context issues", zap.String("body", bodyStr))
533+
534+
if strings.Contains(bodyStr, "too long") ||
535+
strings.Contains(bodyStr, "token limit") ||
536+
strings.Contains(bodyStr, "context") ||
537+
strings.Contains(bodyStr, "maximum length") ||
538+
strings.Contains(bodyStr, "context_length_exceeded") ||
539+
strings.Contains(bodyStr, "prompt is too long") ||
540+
strings.Contains(bodyStr, "input too long") ||
541+
strings.Contains(bodyStr, "context window") {
542+
return errors.ContextWindowErrorf("Anthropic context window exceeded (raw error): %s", string(body))
543+
}
521544
}
522545

523546
return nil

0 commit comments

Comments
 (0)