diff --git a/internal/proxy/circuit_breaker.go b/internal/proxy/circuit_breaker.go new file mode 100644 index 00000000..6aa953ae --- /dev/null +++ b/internal/proxy/circuit_breaker.go @@ -0,0 +1,140 @@ +package proxy + +import ( + "fmt" + "log/slog" + "sync" + "time" +) + +const ( + // defaultBreakerThreshold is the number of consecutive fetch failures that + // opens the circuit for a provider. + defaultBreakerThreshold = 5 + + // defaultBreakerResetTimeout is how long the circuit stays open before + // allowing a single probe request through. + // + // 60 seconds is intentionally longer than negativeCacheTTL (30 seconds). + // The negative cache provides short-term backoff after individual failures. + // The circuit breaker provides medium-term protection after repeated failures. + // If both timeouts were equal, the negCache entry would expire and re-trigger + // a new fetch attempt just as the circuit was probing, interleaving two + // independent backoff mechanisms in a confusing way. + defaultBreakerResetTimeout = 60 * time.Second +) + +type cbState int + +const ( + cbClosed cbState = iota // normal; requests flow through + cbOpen // failing fast; no upstream calls + cbHalfOpen // one probe allowed; waiting for outcome +) + +func (s cbState) String() string { + switch s { + case cbClosed: + return "closed" + case cbOpen: + return "open" + case cbHalfOpen: + return "half-open" + default: + return fmt.Sprintf("cbState(%d)", int(s)) + } +} + +// circuitBreaker tracks consecutive attestation fetch failures for a single +// provider and opens the circuit after a configurable threshold. While open, +// allow() returns false immediately so no upstream HTTP call is made. After +// resetTimeout the circuit enters the half-open state and allows one probe +// request through. A successful probe closes the circuit; a failed probe +// reopens it. +// +// The breaker is keyed per-provider, not per-provider+model. A provider's +// attestation endpoint is shared across all its models; if it starts timing +// out it does so for every model, not selectively. Keying by (provider, model) +// would require threshold failures per model before any protection engaged, +// meaning tens of failures for a provider with many models before the first +// request was blocked. +// +// Caller contract: a caller that receives true from allow must call exactly +// one of success or failure when the upstream call completes. Failure to do so +// while the circuit is half-open will leave it permanently half-open. +type circuitBreaker struct { + mu sync.Mutex + state cbState + // failures counts fetch failures since the last success() call. It is only + // incremented in cbClosed (counting toward the threshold) and cbHalfOpen + // (counting probe failures). It is not incremented while cbOpen because the + // circuit is already tripped and the count is irrelevant until it resets. + failures int + openedAt time.Time + threshold int + resetTimeout time.Duration + now func() time.Time // injectable for tests; production uses time.Now +} + +// allow reports whether the caller should proceed with an upstream fetch. +// It is safe for concurrent use. +func (cb *circuitBreaker) allow() bool { + cb.mu.Lock() + defer cb.mu.Unlock() + switch cb.state { + case cbClosed: + return true + case cbOpen: + if cb.now().Sub(cb.openedAt) >= cb.resetTimeout { + cb.state = cbHalfOpen + return true // allow the single probe + } + return false + case cbHalfOpen: + return false // probe already in flight + default: + panic("proxy: unhandled cbState in allow()") + } +} + +// success closes the circuit and resets the failure counter. Call after a +// successful upstream fetch. +func (cb *circuitBreaker) success() { + cb.mu.Lock() + defer cb.mu.Unlock() + if cb.state != cbClosed { + slog.Warn("circuit breaker recovered", + "previous_state", cb.state, "failures_cleared", cb.failures) + } + cb.state = cbClosed + cb.failures = 0 +} + +// failure records a fetch failure. When the threshold is reached the circuit +// opens. A failure during the half-open probe reopens the circuit immediately. +// Failures while already open are discarded — the counter is only meaningful +// when counting toward the threshold or tracking probe outcomes. +func (cb *circuitBreaker) failure() { + cb.mu.Lock() + defer cb.mu.Unlock() + switch cb.state { + case cbClosed: + cb.failures++ + if cb.failures >= cb.threshold { + slog.Warn("circuit breaker open", + "failures", cb.failures, "threshold", cb.threshold) + cb.state = cbOpen + cb.openedAt = cb.now() + } + case cbHalfOpen: + cb.failures++ + slog.Warn("circuit breaker reopened after probe failure", + "failures", cb.failures) + cb.state = cbOpen + cb.openedAt = cb.now() + case cbOpen: + // Already open; discard. The counter is irrelevant until the circuit closes. + default: + panic("proxy: unhandled cbState in failure()") + } +} diff --git a/internal/proxy/circuit_breaker_test.go b/internal/proxy/circuit_breaker_test.go new file mode 100644 index 00000000..6d24f9a1 --- /dev/null +++ b/internal/proxy/circuit_breaker_test.go @@ -0,0 +1,119 @@ +package proxy + +import ( + "testing" + "time" +) + +// newTestBreaker creates a circuit breaker with controllable time. The returned +// advance function moves the clock forward by the given duration. +func newTestBreaker(threshold int) (cb *circuitBreaker, advance func(time.Duration)) { + now := time.Now() + cb = &circuitBreaker{ + threshold: threshold, + resetTimeout: time.Minute, + now: func() time.Time { return now }, + } + advance = func(d time.Duration) { now = now.Add(d) } + return +} + +func TestCircuitBreaker_InitialState(t *testing.T) { + cb, _ := newTestBreaker(3) + for range 5 { + if !cb.allow() { + t.Error("new circuit breaker should allow all requests") + } + } +} + +func TestCircuitBreaker_OpensAfterThreshold(t *testing.T) { + cb, _ := newTestBreaker(3) + + for i := range 2 { + cb.failure() + if !cb.allow() { + t.Errorf("circuit should still be closed after %d failure(s)", i+1) + } + } + + cb.failure() // 3rd — hits threshold + if cb.allow() { + t.Error("circuit should be open after threshold failures") + } +} + +func TestCircuitBreaker_OpenBlocksRequests(t *testing.T) { + cb, _ := newTestBreaker(1) + cb.failure() // open immediately (threshold = 1) + + for range 5 { + if cb.allow() { + t.Error("open circuit should block all requests") + } + } +} + +func TestCircuitBreaker_HalfOpenAfterReset(t *testing.T) { + cb, advance := newTestBreaker(1) + cb.failure() // open + + if cb.allow() { + t.Error("circuit should be blocked immediately after opening") + } + + advance(time.Minute + time.Second) // advance past resetTimeout + + if !cb.allow() { + t.Error("circuit should allow probe after reset timeout") + } + // Already half-open; second call should block (probe in flight). + if cb.allow() { + t.Error("only one probe should be allowed while half-open") + } +} + +func TestCircuitBreaker_RecoveryClosesCircuit(t *testing.T) { + cb, advance := newTestBreaker(1) + cb.failure() // open + advance(time.Minute + time.Second) // advance past reset + cb.allow() // transitions to half-open + cb.success() // probe succeeded → closed + + if !cb.allow() { + t.Error("circuit should be closed after successful probe") + } +} + +func TestCircuitBreaker_ProbeFailureReopens(t *testing.T) { + cb, advance := newTestBreaker(1) + cb.failure() // open + advance(time.Minute + time.Second) // advance past reset + cb.allow() // transitions to half-open + cb.failure() // probe failed → reopen + + if cb.allow() { + t.Error("circuit should be open after probe failure") + } +} + +func TestCircuitBreaker_SuccessResetsFailureCount(t *testing.T) { + cb, _ := newTestBreaker(3) + + for range 2 { + cb.failure() + } + cb.success() // reset + + // Need a full threshold of new failures to open again. + for i := range 2 { + cb.failure() + if !cb.allow() { + t.Errorf("circuit should still be closed after %d failure(s) post-reset", i+1) + } + } + cb.failure() // 3rd — threshold + if cb.allow() { + t.Error("circuit should open after new threshold failures post-reset") + } +} diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 5f0743d0..49351599 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -281,12 +281,13 @@ type Server struct { rekorClient *attestation.RekorClient nvidiaVerifier *attestation.NVIDIAVerifier mux *http.ServeMux - attestClient *http.Client // for attestation fetches - collateral trust.HTTPSGetter // for Intel PCS collateral fetches - verifyQuote attestation.TDXVerifier // constructed from cfg.Offline + collateral - upstreamClient *http.Client // for chat completions forwards - sseConns atomic.Int64 // active SSE /events connections - e2eeFailed sync.Map // cacheKey → true; tracks provider+model pairs with E2EE decryption failures + attestClient *http.Client // for attestation fetches + collateral trust.HTTPSGetter // for Intel PCS collateral fetches + verifyQuote attestation.TDXVerifier // constructed from cfg.Offline + collateral + upstreamClient *http.Client // for chat completions forwards + sseConns atomic.Int64 // active SSE /events connections + e2eeFailed sync.Map // cacheKey → true; tracks provider+model pairs with E2EE decryption failures + breakers map[string]*circuitBreaker // provider name → circuit breaker; fixed at startup, no sync needed stats stats } @@ -349,6 +350,15 @@ func New(cfg *config.Config) (*Server, error) { return nil, errors.New("no providers configured") } + s.breakers = make(map[string]*circuitBreaker, len(s.providers)) + for name := range s.providers { + s.breakers[name] = &circuitBreaker{ + threshold: defaultBreakerThreshold, + resetTimeout: defaultBreakerResetTimeout, + now: time.Now, + } + } + s.mux.HandleFunc("GET /{$}", s.handleIndex) s.mux.HandleFunc("GET /events", s.handleEvents) s.mux.HandleFunc("POST /v1/chat/completions", s.handleEndpoint(&chatEndpoint)) @@ -600,6 +610,18 @@ func (s *Server) fetchAndVerify(ctx context.Context, prov *provider.Provider, up return nil, nil } + cb := s.breakers[prov.Name] + if !cb.allow() { + slog.WarnContext(ctx, "circuit breaker open; skipping attestation fetch", "provider", prov.Name) + // Also record in negCache so that the next request for this specific + // (provider, model) pair is rejected at the negCache.IsBlocked check + // before it even reaches fetchAndVerify. This avoids repeated mutex + // acquisitions on the circuit breaker for the same model while the + // circuit is open. + s.negCache.Record(prov.Name, upstreamModel) + return nil, nil + } + totalStart := time.Now() nonce := attestation.NewNonce() @@ -608,9 +630,27 @@ func (s *Server) fetchAndVerify(ctx context.Context, prov *provider.Provider, up raw, err := prov.Attester.FetchAttestation(ctx, upstreamModel, nonce) if err != nil { slog.ErrorContext(ctx, "attestation fetch failed", "provider", prov.Name, "model", upstreamModel, "err", err) + // Do not count client-driven context terminations as provider failures. + // context.Canceled (client disconnect) and context.DeadlineExceeded + // (client-imposed timeout) both say nothing about provider health — + // the provider may be perfectly healthy. Counting them would let an + // attacker trip the circuit with threshold fire-and-disconnect or + // short-deadline requests. Use ctx.Err() to catch all context-error + // variants rather than enumerating specific sentinel values. + if ctx.Err() == nil { + cb.failure() + } s.negCache.Record(prov.Name, upstreamModel) return nil, nil } + // success() is called on a clean network fetch, not contingent on + // verification passing. Verification failure (bad TDX quote, failed + // supply-chain check) means the TEE returned wrong content, not that + // the endpoint is unavailable. Requiring a non-blocked report before + // closing the circuit would keep it open indefinitely for a provider + // that consistently fails verification — which is a policy problem, + // not a reliability problem. + cb.success() fetchDur := time.Since(fetchStart) slog.DebugContext(ctx, "attestation fetch complete", "provider", prov.Name, "elapsed", fetchDur) diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 74932315..98aa16e1 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -4346,3 +4346,69 @@ func TestDashboardEvents(t *testing.T) { } } } + +// TestCircuitBreaker_BlocksAfterThreshold verifies that after +// defaultBreakerThreshold consecutive attestation fetch failures the circuit +// opens and subsequent requests are rejected without contacting the attestation +// endpoint. +func TestCircuitBreaker_BlocksAfterThreshold(t *testing.T) { + // Must match defaultBreakerThreshold in proxy package. + const threshold = 5 + + var calls atomic.Int64 + attestSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/v1/tee/attestation" { + calls.Add(1) + t.Logf("attestation server: call #%d (model=%s)", calls.Load(), r.URL.Query().Get("model")) + // Use 400 (not 5xx) so RetryTransport does not retry the request. + // A 5xx response would cause 3 HTTP calls per FetchAttestation invocation, + // making the assertion on total call count incorrect. + http.Error(w, "bad request", http.StatusBadRequest) + return + } + // Model listing and other endpoints return an empty list. + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"data":[]}`)) + })) + defer attestSrv.Close() + + proxySrv := newProxyServer(t, buildConfig(attestSrv.URL, false)) + defer proxySrv.Close() + + // Fire `threshold` requests with distinct model names so each bypasses the + // negative cache (which is keyed per provider+model). Each failure + // increments the provider-level circuit breaker. + for i := range threshold { + model := fmt.Sprintf("cb-test-model-%d", i) + resp, err := postChat(t, proxySrv.URL, model, false) + if err != nil { + t.Fatalf("request %d: %v", i, err) + } + resp.Body.Close() + t.Logf("request %d (model=%s): status %d", i, model, resp.StatusCode) + if resp.StatusCode != http.StatusBadGateway { + t.Errorf("request %d: got status %d, want %d", i, resp.StatusCode, http.StatusBadGateway) + } + } + + if got := calls.Load(); got != int64(threshold) { + t.Errorf("expected %d attestation calls to open the circuit, got %d", threshold, got) + } + + // Send one more request with a brand-new model (no negCache entry). The + // circuit should be open, so the attestation endpoint must not be called. + newModel := fmt.Sprintf("cb-test-model-%d", threshold) + resp, err := postChat(t, proxySrv.URL, newModel, false) + if err != nil { + t.Fatalf("circuit-open request: %v", err) + } + resp.Body.Close() + t.Logf("circuit-open request (model=%s): status %d", newModel, resp.StatusCode) + if resp.StatusCode != http.StatusBadGateway { + t.Errorf("circuit-open request: got status %d, want %d", resp.StatusCode, http.StatusBadGateway) + } + + if got := calls.Load(); got != int64(threshold) { + t.Errorf("circuit open: attestation endpoint should not have received additional calls; got %d total, want %d", got, threshold) + } +}