Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/invoker-impl/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ restate-errors = { workspace = true }
restate-futures-util = { workspace = true }
restate-invoker-api = { workspace = true }
restate-queue = { workspace = true }
restate-serde-util = { workspace = true }
restate-service-client = { workspace = true }
restate-service-protocol = { workspace = true, features = ["message", "codec"] }
restate-service-protocol-v4 = { workspace = true, features = ["message-codec", "entry-codec"] }
Expand Down
211 changes: 200 additions & 11 deletions crates/invoker-impl/src/invocation_task/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,21 @@ use std::task::{Context, Poll, ready};
use std::time::{Duration, Instant};

use bytes::Bytes;
use futures::{FutureExt, Stream};
use futures::{FutureExt, Stream, StreamExt};
use http::response::Parts as ResponseParts;
use http::{HeaderName, HeaderValue, Response};
use http_body::{Body, Frame};
use metrics::histogram;
use metrics::{counter, histogram};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tokio_util::task::AbortOnDropHandle;
use tracing::instrument;
use tracing::{debug, instrument};

use restate_invoker_api::invocation_reader::{InvocationReader, InvocationReaderTransaction};
use restate_invoker_api::invocation_reader::{
EagerState, InvocationReader, InvocationReaderTransaction,
};
use restate_invoker_api::{EntryEnricher, InvokeInputJournal};
use restate_serde_util::ByteCount;
use restate_service_client::{Request, ResponseBody, ServiceClient, ServiceClientError};
use restate_types::deployment::PinnedDeployment;
use restate_types::identifiers::{InvocationId, PartitionLeaderEpoch};
Expand All @@ -51,7 +54,7 @@ use restate_types::service_protocol::ServiceProtocolVersion;
use crate::TokenBucket;
use crate::error::InvokerError;
use crate::invocation_task::service_protocol_runner::ServiceProtocolRunner;
use crate::metric_definitions::{ID_LOOKUP, INVOKER_TASK_DURATION};
use crate::metric_definitions::{ID_LOOKUP, INVOKER_EAGER_STATE_TRUNCATED, INVOKER_TASK_DURATION};

// Clippy false positive, might be caused by Bytes contained within HeaderValue.
// https://github.com/rust-lang/rust/issues/40543#issuecomment-1212981256
Expand Down Expand Up @@ -82,6 +85,56 @@ const SERVICE_PROTOCOL_VERSION_V6: HeaderValue =
#[allow(clippy::declare_interior_mutable_const)]
const X_RESTATE_SERVER: HeaderName = HeaderName::from_static("x-restate-server");

/// Collects state entries from an [`EagerState`] stream, respecting a size limit.
///
/// Returns a tuple of `(is_partial, entries)` where:
/// - `is_partial` is true if the state was already partial or if collection stopped due to size limit
/// - `entries` contains the collected and mapped key-value bytes
///
/// If the first entry already exceeds the size limit, then an empty entries [`Vec`] is returned.
async fn collect_eager_state<S, E, T>(
state: Option<EagerState<S>>,
size_limit: usize,
mut mapper: impl FnMut((Bytes, Bytes)) -> T,
) -> Result<(bool, Vec<T>), InvokerError>
where
S: Stream<Item = Result<(Bytes, Bytes), E>> + Send,
E: std::error::Error + Send + Sync + 'static,
{
let Some(state) = state else {
return Ok((true, Vec::new()));
};

let mut is_partial = state.is_partial();
let mut entries = Vec::new();
let mut total_size: usize = 0;

let mut stream = std::pin::pin!(state.into_inner());
while let Some(result) = stream.next().await {
let (key, value) = result.map_err(|e| InvokerError::StateReader(e.into()))?;
let entry_size = key.len() + value.len();

// Check if adding this entry would exceed the limit
if total_size.saturating_add(entry_size) > size_limit {
debug!(
"Eager state size limit reached ({}, limit: {}), \
sending partial state with {} entries",
ByteCount::from(total_size),
ByteCount::from(size_limit),
entries.len()
);
counter!(INVOKER_EAGER_STATE_TRUNCATED).increment(1);
is_partial = true;
break;
}

total_size = total_size.saturating_add(entry_size);
entries.push(mapper((key, value)));
}

Ok((is_partial, entries))
}

pub(super) struct InvocationTaskOutput {
pub(super) partition: PartitionLeaderEpoch,
pub(super) invocation_id: InvocationId,
Expand Down Expand Up @@ -143,7 +196,7 @@ pub(super) struct InvocationTask<EE, DMR> {
invocation_target: InvocationTarget,
inactivity_timeout: Duration,
abort_timeout: Duration,
disable_eager_state: bool,
eager_state_size_limit: usize,
message_size_warning: NonZeroUsize,
message_size_limit: NonZeroUsize,
retry_count_since_last_stored_entry: u32,
Expand Down Expand Up @@ -213,7 +266,7 @@ where
invocation_target: InvocationTarget,
default_inactivity_timeout: Duration,
default_abort_timeout: Duration,
disable_eager_state: bool,
eager_state_size_limit: usize,
message_size_warning: NonZeroUsize,
message_size_limit: NonZeroUsize,
retry_count_since_last_stored_entry: u32,
Expand All @@ -230,7 +283,7 @@ where
invocation_target,
inactivity_timeout: default_inactivity_timeout,
abort_timeout: default_abort_timeout,
disable_eager_state,
eager_state_size_limit,
entry_enricher,
schemas: deployment_metadata_resolver,
invoker_tx,
Expand Down Expand Up @@ -389,10 +442,17 @@ where
)));
}

// Determine if we need to read state
// Resolve the effective eager state size limit:
// Per-handler/service override takes precedence over server-level config.
// 0 means "disable eager state", non-zero values are clamped to the message size limit.
if let Some(limit) = invocation_attempt_options.eager_state_size_limit {
let limit = limit.as_usize();
self.eager_state_size_limit = limit.min(self.message_size_limit.get());
}

// Determine if we need to read state (0 means lazy state / no eager state)
let keyed_service_id = if self.invocation_target.as_keyed_service_id().is_some()
&& invocation_attempt_options.enable_lazy_state != Some(true)
&& !self.disable_eager_state
&& self.eager_state_size_limit > 0
{
self.invocation_target.as_keyed_service_id()
} else {
Expand Down Expand Up @@ -548,3 +608,132 @@ impl Stream for ResponseStream {
}
}
}

#[cfg(test)]
mod tests {
use bytes::Bytes;
use futures::stream;

use restate_invoker_api::invocation_reader::EagerState;

use super::collect_eager_state;

type StateResult = Result<(Bytes, Bytes), std::io::Error>;

// Helper to create a (Bytes, Bytes) pair of known sizes
fn entry(key_size: usize, value_size: usize) -> StateResult {
Ok((
Bytes::from(vec![b'k'; key_size]),
Bytes::from(vec![b'v'; value_size]),
))
}

#[tokio::test]
async fn collect_eager_state_no_state_returns_partial() {
let (is_partial, entries) = collect_eager_state::<
stream::Empty<Result<(Bytes, Bytes), std::io::Error>>,
_,
_,
>(None, 1024, std::convert::identity)
.await
.unwrap();

assert!(is_partial, "no state should be reported as partial");
assert!(entries.is_empty());
}

#[tokio::test]
async fn collect_eager_state_complete_within_limit() {
let items = vec![entry(10, 20), entry(5, 15)];
let state = EagerState::new_complete(stream::iter(items));

let (is_partial, entries) = collect_eager_state(Some(state), 1024, std::convert::identity)
.await
.unwrap();

assert!(!is_partial, "all entries fit within limit");
assert_eq!(entries.len(), 2);
}

#[tokio::test]
async fn collect_eager_state_preserves_partial_flag() {
// Stream is pre-flagged as partial even though all entries fit
let items = vec![entry(10, 10)];
let state = EagerState::new_partial(stream::iter(items));

let (is_partial, entries) = collect_eager_state(Some(state), 1024, std::convert::identity)
.await
.unwrap();

assert!(is_partial, "partial flag should be preserved from source");
assert_eq!(entries.len(), 1);
}

#[tokio::test]
async fn collect_eager_state_truncates_at_limit() {
// 3 entries of 50 bytes each, limit of 120 bytes => should fit 2
let items = vec![entry(25, 25), entry(25, 25), entry(25, 25)];
let state = EagerState::new_complete(stream::iter(items));

let (is_partial, entries) = collect_eager_state(Some(state), 120, std::convert::identity)
.await
.unwrap();

assert!(is_partial, "should be partial after truncation");
assert_eq!(entries.len(), 2, "only 2 entries should fit (100 bytes)");
}

#[tokio::test]
async fn collect_eager_state_first_entry_always_included() {
// Single entry larger than the limit — should return empty entries
let items = vec![entry(100, 101)];
let state = EagerState::new_complete(stream::iter(items));

let (is_partial, entries) = collect_eager_state(Some(state), 200, std::convert::identity)
.await
.unwrap();

assert!(is_partial, "first entry exceeded limit so partial state");
assert!(
entries.is_empty(),
"first entry exceeded limit so empty entries"
);
}

#[tokio::test]
async fn collect_eager_state_stream_error_propagated() {
let items: Vec<StateResult> = vec![Err(std::io::Error::other("boom"))];
let state = EagerState::new_complete(stream::iter(items));

let result = collect_eager_state(Some(state), 1024, std::convert::identity).await;
assert!(result.is_err(), "stream error should be propagated");
}

#[tokio::test]
async fn collect_eager_state_exact_boundary() {
// 2 entries of exactly 50 bytes each, limit of 100 => both should fit
let items = vec![entry(25, 25), entry(25, 25)];
let state = EagerState::new_complete(stream::iter(items));

let (is_partial, entries) = collect_eager_state(Some(state), 100, std::convert::identity)
.await
.unwrap();

assert!(!is_partial, "entries exactly at limit should fit");
assert_eq!(entries.len(), 2);
}

#[tokio::test]
async fn collect_eager_state_one_byte_over_limit() {
// 2 entries of 50 bytes each, limit of 99 => only first should fit
let items = vec![entry(25, 25), entry(25, 25)];
let state = EagerState::new_complete(stream::iter(items));

let (is_partial, entries) = collect_eager_state(Some(state), 99, std::convert::identity)
.await
.unwrap();

assert!(is_partial, "should be partial when 1 byte over");
assert_eq!(entries.len(), 1, "only first entry should fit");
}
}
24 changes: 9 additions & 15 deletions crates/invoker-impl/src/invocation_task/service_protocol_runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use std::convert::Infallible;
use std::time::Duration;

use bytes::Bytes;
use futures::{Stream, StreamExt, TryStreamExt, stream};
use futures::{Stream, StreamExt, stream};
use http::uri::PathAndQuery;
use http::{HeaderMap, HeaderName, HeaderValue, StatusCode};
use http_body::Frame;
Expand Down Expand Up @@ -48,7 +48,7 @@ use crate::Notification;
use crate::error::{InvocationErrorRelatedEntry, InvokerError, SdkInvocationError};
use crate::invocation_task::{
InvocationTask, InvocationTaskOutputInner, InvokerBodyStream, InvokerRequestStreamSender,
ResponseChunk, ResponseStream, TerminalLoopState, X_RESTATE_SERVER,
ResponseChunk, ResponseStream, TerminalLoopState, X_RESTATE_SERVER, collect_eager_state,
invocation_id_to_header_value, service_protocol_version_to_header_value,
};

Expand Down Expand Up @@ -484,19 +484,13 @@ where
S: Stream<Item = Result<(Bytes, Bytes), E>> + Send,
E: std::error::Error + Send + Sync + 'static,
{
// Collect state if present, mapping to StateEntry while collecting
let (partial_state, state_map) = if let Some(state) = state {
let is_partial = state.is_partial();
let entries: Vec<StateEntry> = state
.into_inner()
.map_ok(|(key, value)| StateEntry { key, value })
.try_collect()
.await
.map_err(|e| InvokerError::StateReader(e.into()))?;
(is_partial, entries)
} else {
(true, Vec::new())
};
// Collect state entries with size limit
let (partial_state, state_map) = collect_eager_state(
state,
self.invocation_task.eager_state_size_limit,
|(key, value)| StateEntry { key, value },
)
.await?;

// Send the invoke frame
self.write(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use std::time::Duration;

use bytes::Bytes;
use bytestring::ByteString;
use futures::{Stream, StreamExt, TryStreamExt, stream};
use futures::{Stream, StreamExt, stream};
use gardal::futures::StreamExt as GardalStreamExt;
use http::uri::PathAndQuery;
use http::{HeaderMap, HeaderName, HeaderValue, StatusCode};
Expand Down Expand Up @@ -64,7 +64,7 @@ use crate::error::{
};
use crate::invocation_task::{
InvocationTask, InvocationTaskOutputInner, InvokerBodyStream, InvokerRequestStreamSender,
ResponseChunk, ResponseStream, TerminalLoopState, X_RESTATE_SERVER,
ResponseChunk, ResponseStream, TerminalLoopState, X_RESTATE_SERVER, collect_eager_state,
invocation_id_to_header_value, retry_after, service_protocol_version_to_header_value,
};

Expand Down Expand Up @@ -599,19 +599,13 @@ where
S: Stream<Item = Result<(Bytes, Bytes), E>> + Send,
E: std::error::Error + Send + Sync + 'static,
{
// Collect state if present, mapping to StateEntry while collecting
let (partial_state, state_map) = if let Some(state) = state {
let is_partial = state.is_partial();
let entries: Vec<StateEntry> = state
.into_inner()
.map_ok(|(key, value)| StateEntry { key, value })
.try_collect()
.await
.map_err(|e| InvokerError::StateReader(e.into()))?;
(is_partial, entries)
} else {
(true, Vec::new())
};
// Collect state entries with size limit
let (partial_state, state_map) = collect_eager_state(
state,
self.invocation_task.eager_state_size_limit,
|(key, value)| StateEntry { key, value },
)
.await?;

// Send the invoke frame
self.write(
Expand Down
2 changes: 1 addition & 1 deletion crates/invoker-impl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ where
invocation_target,
opts.inactivity_timeout.into(),
opts.abort_timeout.into(),
opts.disable_eager_state,
opts.eager_state_size_limit(),
opts.message_size_warning.as_non_zero_usize(),
opts.message_size_limit(),
retry_count_since_last_stored_entry,
Expand Down
Loading
Loading