-
Notifications
You must be signed in to change notification settings - Fork 1
Hotpath fixes #15
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Hotpath fixes #15
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,7 +16,7 @@ | |
| from .events import EventBus, ObservabilityEventType | ||
| from .logging import logger | ||
| from .retry import RetryManager | ||
| from .state import append_token, create_state, mark_completed, update_checkpoint | ||
| from .state import append_token, create_state, flush_content, mark_completed, update_checkpoint | ||
| from .types import ( | ||
| AwaitableStreamFactory, | ||
| CheckIntervals, | ||
|
|
@@ -482,6 +482,7 @@ async def run_stream() -> AsyncIterator[Event]: | |
| # Checkpoint invalid - start fresh | ||
| logger.debug("Checkpoint validation failed, starting fresh") | ||
| state.content = "" | ||
| state._content_buffer.clear() | ||
| state.token_count = 0 | ||
| pending_checkpoint = None | ||
|
|
||
|
|
@@ -718,82 +719,94 @@ async def emit_buffered_tool_calls() -> AsyncIterator[Event]: | |
| state.token_count, | ||
| ) | ||
|
|
||
| # Fire on_event callback for token events | ||
| _fire_callback(cb.on_event, event) | ||
|
|
||
| # Fire on_token callback | ||
| _fire_callback(cb.on_token, token_text) | ||
| # Fire per-token callbacks (skip function call overhead when None) | ||
| if cb.on_event is not None: | ||
| _fire_callback(cb.on_event, event) | ||
| if cb.on_token is not None: | ||
| _fire_callback(cb.on_token, token_text) | ||
|
|
||
| # Check guardrails periodically | ||
| if ( | ||
| state.token_count % guardrail_interval == 0 | ||
| and guardrails | ||
| ): | ||
| phase_start_time = time.perf_counter() | ||
| event_bus.emit( | ||
| ObservabilityEventType.GUARDRAIL_PHASE_START, | ||
| phase="post", | ||
| ruleCount=len(guardrails), | ||
| ) | ||
| flush_content(state) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. P2: Flushing the full content string inside periodic checks makes long streams O(n²) again. Prompt for AI agents |
||
| _has_obs = event_bus._handler is not None | ||
|
|
||
| all_violations = [] | ||
| for idx, rule in enumerate(guardrails): | ||
| callback_id = _next_callback_id() | ||
| rule_start_time = time.perf_counter() | ||
| if _has_obs: | ||
| phase_start_time = time.perf_counter() | ||
| event_bus.emit( | ||
| ObservabilityEventType.GUARDRAIL_RULE_START, | ||
| index=idx, | ||
| ruleId=rule.name, | ||
| callbackId=callback_id, | ||
| ObservabilityEventType.GUARDRAIL_PHASE_START, | ||
| phase="post", | ||
| ruleCount=len(guardrails), | ||
| ) | ||
|
|
||
| all_violations = [] | ||
| for idx, rule in enumerate(guardrails): | ||
| if _has_obs: | ||
| callback_id = _next_callback_id() | ||
| rule_start_time = time.perf_counter() | ||
| event_bus.emit( | ||
| ObservabilityEventType.GUARDRAIL_RULE_START, | ||
| index=idx, | ||
| ruleId=rule.name, | ||
| callbackId=callback_id, | ||
| ) | ||
|
|
||
| rule_violations = rule.check(state) | ||
| passed = len(rule_violations) == 0 | ||
| rule_duration_ms = int( | ||
| (time.perf_counter() - rule_start_time) * 1000 | ||
| ) | ||
| # Emit result for each rule | ||
| event_bus.emit( | ||
| ObservabilityEventType.GUARDRAIL_RULE_RESULT, | ||
| index=idx, | ||
| ruleId=rule.name, | ||
| passed=passed, | ||
| violation=rule_violations[0].__dict__ | ||
| if rule_violations | ||
| else None, | ||
| ) | ||
|
|
||
| if _has_obs: | ||
| passed = len(rule_violations) == 0 | ||
| rule_duration_ms = int( | ||
| (time.perf_counter() - rule_start_time) * 1000 | ||
| ) | ||
| event_bus.emit( | ||
| ObservabilityEventType.GUARDRAIL_RULE_RESULT, | ||
| index=idx, | ||
| ruleId=rule.name, | ||
| passed=passed, | ||
| violation=rule_violations[0].__dict__ | ||
| if rule_violations | ||
| else None, | ||
| ) | ||
|
|
||
| if rule_violations: | ||
| all_violations.extend(rule_violations) | ||
| event_bus.emit( | ||
| ObservabilityEventType.GUARDRAIL_RULE_END, | ||
| index=idx, | ||
| ruleId=rule.name, | ||
| passed=passed, | ||
| callbackId=callback_id, | ||
| durationMs=rule_duration_ms, | ||
| ) | ||
|
|
||
| if _has_obs: | ||
| event_bus.emit( | ||
| ObservabilityEventType.GUARDRAIL_RULE_END, | ||
| index=idx, | ||
| ruleId=rule.name, | ||
| passed=passed, | ||
| callbackId=callback_id, | ||
| durationMs=rule_duration_ms, | ||
| ) | ||
|
|
||
| if all_violations: | ||
| state.violations.extend(all_violations) | ||
| # Fire on_violation callback for each violation | ||
| for v in all_violations: | ||
| _fire_callback(cb.on_violation, v) | ||
|
|
||
| phase_duration_ms = int( | ||
| (time.perf_counter() - phase_start_time) * 1000 | ||
| ) | ||
| event_bus.emit( | ||
| ObservabilityEventType.GUARDRAIL_PHASE_END, | ||
| phase="post", | ||
| passed=len(all_violations) == 0, | ||
| violations=[v.__dict__ for v in all_violations], | ||
| durationMs=phase_duration_ms, | ||
| ) | ||
| if _has_obs: | ||
| phase_duration_ms = int( | ||
| (time.perf_counter() - phase_start_time) * 1000 | ||
| ) | ||
| event_bus.emit( | ||
| ObservabilityEventType.GUARDRAIL_PHASE_END, | ||
| phase="post", | ||
| passed=len(all_violations) == 0, | ||
| violations=[v.__dict__ for v in all_violations], | ||
| durationMs=phase_duration_ms, | ||
| ) | ||
|
|
||
| # Check drift periodically | ||
| if ( | ||
| drift_detector is not None | ||
| and state.token_count % drift_interval == 0 | ||
| ): | ||
| flush_content(state) | ||
| drift_result = drift_detector.check( | ||
| state.content, token_text | ||
| ) | ||
|
|
@@ -1189,6 +1202,7 @@ async def emit_buffered_tool_calls() -> AsyncIterator[Event]: | |
| else: | ||
| # Reset state for fresh retry (no continuation) | ||
| state.content = "" | ||
| state._content_buffer.clear() | ||
| state.token_count = 0 | ||
| state.checkpoint = "" | ||
| state.completed = False | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,23 +12,33 @@ def create_state() -> State: | |
| return State() | ||
|
|
||
|
|
||
| def flush_content(state: State) -> None: | ||
| """Materialize content buffer into content string. Call before reading state.content.""" | ||
| buf = state._content_buffer | ||
| if buf: | ||
| state.content = state.content + "".join(buf) | ||
| buf.clear() | ||
|
|
||
|
|
||
| def update_checkpoint(state: State) -> None: | ||
| """Save current content as checkpoint.""" | ||
| flush_content(state) | ||
| state.checkpoint = state.content | ||
|
|
||
|
|
||
| def append_token(state: State, token: str) -> None: | ||
| """Append token to content and update timing.""" | ||
| """Append token to content buffer and update timing.""" | ||
| now = time.time() | ||
| if state.first_token_at is None: | ||
| state.first_token_at = now | ||
| state.last_token_at = now | ||
| state.content += token | ||
| state._content_buffer.append(token) | ||
|
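There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. P1: Buffering tokens here leaves `state.content` stale until `flush_content(state)` is called, so any code that reads `state.content` directly between flushes sees outdated text. Prompt for AI agents |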
||
| state.token_count += 1 | ||
|
|
||
|
|
||
| def mark_completed(state: State) -> None: | ||
| """Mark stream as completed and calculate duration.""" | ||
| flush_content(state) | ||
| state.completed = True | ||
| if state.first_token_at is not None: | ||
| state.duration = (state.last_token_at or time.time()) - state.first_token_at | ||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
P2:
`get_history()["last_content"]` now returns only the sliding window, which silently truncates long outputs behind the existing API name. Prompt for AI agents