Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion codex-rs/core/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ pub struct ModelClient {
effort: Option<ReasoningEffortConfig>,
summary: ReasoningSummaryConfig,
session_source: SessionSource,
prompt_cache_key: String,
}

#[allow(clippy::too_many_arguments)]
Expand All @@ -101,6 +102,7 @@ impl ModelClient {
summary: ReasoningSummaryConfig,
conversation_id: ConversationId,
session_source: SessionSource,
prompt_cache_key: String,
) -> Self {
let client = create_client();

Expand All @@ -114,6 +116,7 @@ impl ModelClient {
effort,
summary,
session_source,
prompt_cache_key,
}
}

Expand Down Expand Up @@ -246,7 +249,7 @@ impl ModelClient {
store: azure_workaround,
stream: true,
include,
prompt_cache_key: Some(self.conversation_id.to_string()),
prompt_cache_key: Some(self.prompt_cache_key.clone()),
text,
};

Expand Down
24 changes: 18 additions & 6 deletions codex-rs/core/src/codex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ impl Codex {
auth_manager: Arc<AuthManager>,
conversation_history: InitialHistory,
session_source: SessionSource,
prompt_cache_key: Option<String>,
) -> CodexResult<CodexSpawnOk> {
let (tx_sub, rx_sub) = async_channel::bounded(SUBMISSION_CHANNEL_CAPACITY);
let (tx_event, rx_event) = async_channel::unbounded();
Expand Down Expand Up @@ -190,6 +191,7 @@ impl Codex {
tx_event.clone(),
conversation_history,
session_source_clone,
prompt_cache_key,
)
.await
.map_err(|e| {
Expand Down Expand Up @@ -253,6 +255,7 @@ pub(crate) struct Session {
pub(crate) active_turn: Mutex<Option<ActiveTurn>>,
pub(crate) services: SessionServices,
next_internal_sub_id: AtomicU64,
prompt_cache_key: String,
}

/// The context needed for a single turn of the conversation.
Expand Down Expand Up @@ -368,6 +371,7 @@ impl Session {
session_configuration: &SessionConfiguration,
conversation_id: ConversationId,
sub_id: String,
prompt_cache_key: String,
) -> TurnContext {
let config = session_configuration.original_config_do_not_use.clone();
let model_family = find_family_for_model(&session_configuration.model)
Expand Down Expand Up @@ -395,6 +399,7 @@ impl Session {
session_configuration.model_reasoning_summary,
conversation_id,
session_configuration.session_source.clone(),
prompt_cache_key,
);

let tools_config = ToolsConfig::new(&ToolsConfigParams {
Expand Down Expand Up @@ -425,6 +430,7 @@ impl Session {
tx_event: Sender<Event>,
initial_history: InitialHistory,
session_source: SessionSource,
prompt_cache_key: Option<String>,
) -> anyhow::Result<Arc<Self>> {
debug!(
"Configuring session: model={}; provider={:?}",
Expand Down Expand Up @@ -589,6 +595,7 @@ impl Session {
active_turn: Mutex::new(None),
services,
next_internal_sub_id: AtomicU64::new(0),
prompt_cache_key: prompt_cache_key.unwrap_or_else(|| conversation_id.to_string()),
});

// Dispatch the SessionConfiguredEvent first and then report any errors.
Expand Down Expand Up @@ -618,6 +625,10 @@ impl Session {
Ok(sess)
}

/// Returns an owned copy of this session's prompt cache key.
///
/// The key is fixed at session construction (defaulting to the
/// conversation id when none was supplied) and is handed to spawned
/// child sessions so related conversations share one cache key.
pub(crate) fn get_prompt_cache_key(&self) -> String {
    self.prompt_cache_key.to_owned()
}

pub(crate) fn get_tx_event(&self) -> Sender<Event> {
self.tx_event.clone()
}
Expand Down Expand Up @@ -703,6 +714,7 @@ impl Session {
&session_configuration,
self.conversation_id,
sub_id,
self.prompt_cache_key.clone(),
);
if let Some(final_schema) = updates.final_output_json_schema {
turn_context.final_output_json_schema = final_schema;
Expand Down Expand Up @@ -788,16 +800,11 @@ impl Session {
failure_message: Option<&str>,
) -> Option<SandboxCommandAssessment> {
let config = turn_context.client.config();
let provider = turn_context.client.provider().clone();
let auth_manager = Arc::clone(&self.services.auth_manager);
let otel = self.services.otel_event_manager.clone();
crate::sandboxing::assessment::assess_command(
config,
provider,
auth_manager,
&otel,
self.conversation_id,
turn_context.client.get_session_source(),
turn_context.client.clone(),
call_id,
command,
&turn_context.sandbox_policy,
Expand Down Expand Up @@ -1656,6 +1663,7 @@ async fn spawn_review_thread(
per_turn_config.model_reasoning_summary,
sess.conversation_id,
parent_turn_context.client.get_session_source(),
sess.prompt_cache_key.clone(),
);

let review_turn_context = TurnContext {
Expand Down Expand Up @@ -2529,6 +2537,7 @@ mod tests {
&session_configuration,
conversation_id,
"turn_id".to_string(),
conversation_id.to_string(),
);

let session = Session {
Expand All @@ -2538,6 +2547,7 @@ mod tests {
active_turn: Mutex::new(None),
services,
next_internal_sub_id: AtomicU64::new(0),
prompt_cache_key: conversation_id.to_string(),
};

(session, turn_context)
Expand Down Expand Up @@ -2603,6 +2613,7 @@ mod tests {
&session_configuration,
conversation_id,
"turn_id".to_string(),
conversation_id.to_string(),
));

let session = Arc::new(Session {
Expand All @@ -2612,6 +2623,7 @@ mod tests {
active_turn: Mutex::new(None),
services,
next_internal_sub_id: AtomicU64::new(0),
prompt_cache_key: conversation_id.to_string(),
});

(session, turn_context, rx_event)
Expand Down
1 change: 1 addition & 0 deletions codex-rs/core/src/codex_delegate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ pub(crate) async fn run_codex_conversation_interactive(
auth_manager,
InitialHistory::New,
SessionSource::SubAgent(SubAgentSource::Review),
Some(parent_session.get_prompt_cache_key()),
)
.await?;
let codex = Arc::new(codex);
Expand Down
12 changes: 11 additions & 1 deletion codex-rs/core/src/conversation_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ impl ConversationManager {
auth_manager,
InitialHistory::New,
self.session_source.clone(),
None,
)
.await?;
self.finalize_spawn(codex, conversation_id).await
Expand Down Expand Up @@ -150,6 +151,7 @@ impl ConversationManager {
auth_manager,
initial_history,
self.session_source.clone(),
None,
)
.await?;
self.finalize_spawn(codex, conversation_id).await
Expand All @@ -175,6 +177,7 @@ impl ConversationManager {
nth_user_message: usize,
config: Config,
path: PathBuf,
conversation_id: ConversationId,
) -> CodexResult<NewConversation> {
// Compute the prefix up to the cut point.
let history = RolloutRecorder::get_rollout_history(&path).await?;
Expand All @@ -185,7 +188,14 @@ impl ConversationManager {
let CodexSpawnOk {
codex,
conversation_id,
} = Codex::spawn(config, auth_manager, history, self.session_source.clone()).await?;
} = Codex::spawn(
config,
auth_manager,
history,
self.session_source.clone(),
Some(conversation_id.to_string()),
)
Comment on lines +193 to +197
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Reuse ancestor prompt cache key when forking

Passing Some(conversation_id.to_string()) here only works for the very first fork of a base conversation, where the prompt cache key equals the conversation id. Once you fork a fork (or any conversation whose session was started with a custom prompt cache key), the current conversation's cache key is inherited from its ancestor via Session::get_prompt_cache_key(), not from its own id. By hard-coding the new fork to use its immediate conversation id as the key, we break cache affinity for nested forks and lose the optimization this change is trying to introduce. We should look up the existing conversation's prompt cache key and propagate that through Codex::spawn instead of defaulting to the conversation id.

Useful? React with 👍 / 👎.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's true. I need to think about it more tomorrow.

.await?;

self.finalize_spawn(codex, conversation_id).await
}
Expand Down
23 changes: 1 addition & 22 deletions codex-rs/core/src/sandboxing/assessment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,16 @@ use std::sync::Arc;
use std::time::Duration;
use std::time::Instant;

use crate::AuthManager;
use crate::ModelProviderInfo;
use crate::client::ModelClient;
use crate::client_common::Prompt;
use crate::client_common::ResponseEvent;
use crate::config::Config;
use crate::protocol::SandboxPolicy;
use askama::Template;
use codex_otel::otel_event_manager::OtelEventManager;
use codex_protocol::ConversationId;
use codex_protocol::models::ContentItem;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::SandboxCommandAssessment;
use codex_protocol::protocol::SessionSource;
use futures::StreamExt;
use serde_json::json;
use tokio::time::timeout;
Expand Down Expand Up @@ -50,11 +46,8 @@ struct SandboxAssessmentPromptTemplate<'a> {
#[allow(clippy::too_many_arguments)]
pub(crate) async fn assess_command(
config: Arc<Config>,
provider: ModelProviderInfo,
auth_manager: Arc<AuthManager>,
parent_otel: &OtelEventManager,
conversation_id: ConversationId,
session_source: SessionSource,
client: ModelClient,
call_id: &str,
command: &[String],
sandbox_policy: &SandboxPolicy,
Expand Down Expand Up @@ -132,20 +125,6 @@ pub(crate) async fn assess_command(
output_schema: Some(sandbox_assessment_schema()),
};

let child_otel =
parent_otel.with_model(config.model.as_str(), config.model_family.slug.as_str());

let client = ModelClient::new(
Arc::clone(&config),
Some(auth_manager),
child_otel,
provider,
config.model_reasoning_effort,
config.model_reasoning_summary,
conversation_id,
session_source,
);

let start = Instant::now();
let assessment_result = timeout(SANDBOX_ASSESSMENT_TIMEOUT, async move {
let mut stream = client.stream(&prompt).await?;
Expand Down
1 change: 1 addition & 0 deletions codex-rs/core/tests/chat_completions_payload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ async fn run_request(input: Vec<ResponseItem>) -> Value {
summary,
conversation_id,
codex_protocol::protocol::SessionSource::Exec,
conversation_id.to_string(),
);

let mut prompt = Prompt::default();
Expand Down
1 change: 1 addition & 0 deletions codex-rs/core/tests/chat_completions_sse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ async fn run_stream_with_bytes(sse_body: &[u8]) -> Vec<ResponseEvent> {
summary,
conversation_id,
codex_protocol::protocol::SessionSource::Exec,
conversation_id.to_string(),
);

let mut prompt = Prompt::default();
Expand Down
1 change: 1 addition & 0 deletions codex-rs/core/tests/responses_headers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ async fn responses_stream_includes_task_type_header() {
summary,
conversation_id,
SessionSource::Exec,
conversation_id.to_string(),
);

let mut prompt = Prompt::default();
Expand Down
1 change: 1 addition & 0 deletions codex-rs/core/tests/suite/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,7 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() {
summary,
conversation_id,
codex_protocol::protocol::SessionSource::Exec,
conversation_id.to_string(),
);

let mut prompt = Prompt::default();
Expand Down
30 changes: 22 additions & 8 deletions codex-rs/core/tests/suite/compact_resume_fork.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use codex_core::config::OPENAI_DEFAULT_MODEL;
use codex_core::protocol::EventMsg;
use codex_core::protocol::Op;
use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
use codex_protocol::ConversationId;
use codex_protocol::user_input::UserInput;
use core_test_support::load_default_config_for_test;
use core_test_support::responses::ev_assistant_message;
Expand Down Expand Up @@ -78,7 +79,8 @@ async fn compact_resume_and_fork_preserve_model_history_view() {
mount_initial_flow(&server).await;

// 2. Start a new conversation and drive it through the compact/resume/fork steps.
let (_home, config, manager, base) = start_test_conversation(&server).await;
let (_home, config, manager, base, base_conversation_id) =
start_test_conversation(&server).await;

user_turn(&base, "hello world").await;
compact_conversation(&base).await;
Expand All @@ -97,7 +99,7 @@ async fn compact_resume_and_fork_preserve_model_history_view() {
"compact+resume test expects resumed path {resumed_path:?} to exist",
);

let forked = fork_conversation(&manager, &config, resumed_path, 2).await;
let forked = fork_conversation(&manager, &config, resumed_path, 2, base_conversation_id).await;
user_turn(&forked, "AFTER_FORK").await;

// 3. Capture the requests to the model and validate the history slices.
Expand Down Expand Up @@ -535,7 +537,8 @@ async fn compact_resume_after_second_compaction_preserves_history() {
mount_second_compact_flow(&server).await;

// 2. Drive the conversation through compact -> resume -> fork -> compact -> resume.
let (_home, config, manager, base) = start_test_conversation(&server).await;
let (_home, config, manager, base, base_conversation_id) =
start_test_conversation(&server).await;

user_turn(&base, "hello world").await;
compact_conversation(&base).await;
Expand All @@ -554,7 +557,7 @@ async fn compact_resume_after_second_compaction_preserves_history() {
"second compact test expects resumed path {resumed_path:?} to exist",
);

let forked = fork_conversation(&manager, &config, resumed_path, 3).await;
let forked = fork_conversation(&manager, &config, resumed_path, 3, base_conversation_id).await;
user_turn(&forked, "AFTER_FORK").await;

compact_conversation(&forked).await;
Expand Down Expand Up @@ -780,7 +783,13 @@ async fn mount_second_compact_flow(server: &MockServer) {

async fn start_test_conversation(
server: &MockServer,
) -> (TempDir, Config, ConversationManager, Arc<CodexConversation>) {
) -> (
TempDir,
Config,
ConversationManager,
Arc<CodexConversation>,
ConversationId,
) {
let model_provider = ModelProviderInfo {
base_url: Some(format!("{}/v1", server.uri())),
..built_in_model_providers()["openai"].clone()
Expand All @@ -790,12 +799,16 @@ async fn start_test_conversation(
config.model_provider = model_provider;

let manager = ConversationManager::with_auth(CodexAuth::from_api_key("dummy"));
let NewConversation { conversation, .. } = manager
let NewConversation {
conversation,
conversation_id,
..
} = manager
.new_conversation(config.clone())
.await
.expect("create conversation");

(home, config, manager, conversation)
(home, config, manager, conversation, conversation_id)
}

async fn user_turn(conversation: &Arc<CodexConversation>, text: &str) {
Expand Down Expand Up @@ -840,9 +853,10 @@ async fn fork_conversation(
config: &Config,
path: std::path::PathBuf,
nth_user_message: usize,
conversation_id: ConversationId,
) -> Arc<CodexConversation> {
let NewConversation { conversation, .. } = manager
.fork_conversation(nth_user_message, config.clone(), path)
.fork_conversation(nth_user_message, config.clone(), path, conversation_id)
.await
.expect("fork conversation");
conversation
Expand Down
Loading
Loading