From 2177517e5886c43990f836a66ef5084c9d071103 Mon Sep 17 00:00:00 2001 From: Ahmed Farghal Date: Wed, 22 Apr 2026 12:49:25 +0100 Subject: [PATCH] Introduce restate-sharding crate Part of the effort to decompose the restate-types monolith into focused, composable utility crates. Add a new `restate-sharding` crate (in `crates/sharding/`) that defines the foundational sharding types: - `KeyRange`: a compact `Copy` newtype over `std::range::RangeInclusive`. 16 bytes (vs 24 for std::ops::RangeInclusive), wire-format compatible serde and bilrost encoding. Provides iter(), is_overlapping(), and split() helpers. - `PartitionKey`: type alias for u64 - `WithPartitionKey`: trait for types that carry a partition key The crate is re-exported from `restate-types::sharding` for ergonomic access. Co-Authored-By: Claude Opus 4.6 (1M context) --- Cargo.lock | 14 + Cargo.toml | 1 + crates/encoding/Cargo.toml | 1 + crates/sharding/Cargo.toml | 26 ++ crates/sharding/src/key_range.rs | 441 +++++++++++++++++++++++++++++ crates/sharding/src/lib.rs | 43 +++ crates/types/Cargo.toml | 1 + crates/types/src/identifiers.rs | 24 +- crates/types/src/invocation/mod.rs | 24 ++ crates/types/src/lib.rs | 5 + crates/wal-protocol/src/lib.rs | 8 +- 11 files changed, 568 insertions(+), 20 deletions(-) create mode 100644 crates/sharding/Cargo.toml create mode 100644 crates/sharding/src/key_range.rs create mode 100644 crates/sharding/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 1b1da154fe..237f311f8b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7180,6 +7180,7 @@ dependencies = [ "bytestring", "rand", "restate-encoding-derive", + "restate-sharding", "restate-workspace-hack", "static_assertions", ] @@ -8027,6 +8028,18 @@ dependencies = [ "typify", ] +[[package]] +name = "restate-sharding" +version = "1.6.3-dev" +dependencies = [ + "bilrost", + "flexbuffers", + "restate-sharding", + "restate-workspace-hack", + "serde", + "serde_json", +] + [[package]] name = "restate-storage-api" version = "1.6.3-dev" @@ -8239,6 +8252,7 @@ dependencies = [ "restate-errors", "restate-memory", "restate-serde-util", + "restate-sharding", "restate-test-util", "restate-time-util", "restate-types", diff --git a/Cargo.toml b/Cargo.toml index e587d4574a..c81f1d0835 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -83,6 +83,7 @@ restate-service-protocol = { path = "crates/service-protocol" } restate-service-protocol-v4 = { path = "crates/service-protocol-v4" } restate-storage-api = { path = "crates/storage-api" } restate-storage-query-datafusion = { path = "crates/storage-query-datafusion" } +restate-sharding = { path = "crates/sharding" } restate-util-string = { path = "util/string" } restate-test-util = { path = "crates/test-util" } restate-time-util = { path = "crates/time-util" } diff --git a/crates/encoding/Cargo.toml b/crates/encoding/Cargo.toml index f576007962..390344433e 100644 --- a/crates/encoding/Cargo.toml +++ b/crates/encoding/Cargo.toml @@ -18,4 +18,5 @@ bytestring = { workspace = true } [dev-dependencies] rand = { workspace = true } +restate-sharding = { workspace = true, features = ["bilrost"] } static_assertions = { workspace = true } \ No newline at end of file diff --git a/crates/sharding/Cargo.toml b/crates/sharding/Cargo.toml new file mode 100644 index 0000000000..061fb76413 --- /dev/null +++ b/crates/sharding/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "restate-sharding" +version.workspace = true +authors.workspace = true +edition.workspace = true +rust-version.workspace = true +license.workspace = true +publish = false + +[features] +default = [] +serde = ["dep:serde"] +bilrost = ["dep:bilrost"] + +[dependencies] +restate-workspace-hack = { workspace = true } + +bilrost = { workspace = true, optional = true } +serde = { workspace = true, optional = true } + +[dev-dependencies] +restate-sharding = { path = ".", default-features = false, features = ["serde", "bilrost"] } + +flexbuffers = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } diff --git a/crates/sharding/src/key_range.rs b/crates/sharding/src/key_range.rs new file mode 100644 index 0000000000..336654b513 --- /dev/null +++ b/crates/sharding/src/key_range.rs @@ -0,0 +1,441 @@ +// Copyright (c) 2023 - 2026 Restate Software, Inc., Restate GmbH. +// All rights reserved. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0. + +use std::fmt; +use std::ops::{Bound, RangeBounds}; +use std::range::RangeInclusiveIter; + +/// An inclusive range of partition keys `[start, end]`. +/// +/// This is a compact, `Copy` representation of an inclusive key range backed by +/// [`std::range::RangeInclusive`]. Compared to [`std::ops::RangeInclusive`]: +/// +/// Wire-format compatible with `std::ops::RangeInclusive` for both serde and bilrost. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +#[repr(transparent)] +pub struct KeyRange(std::range::RangeInclusive); + +impl KeyRange { + /// The full key space `[0, u64::MAX]`. + pub const FULL: Self = Self::new(u64::MIN, u64::MAX); + + /// Creates a new inclusive key range `[start, end]`. + pub const fn new(start: u64, end: u64) -> Self { + Self(std::range::RangeInclusive { start, last: end }) + } + + /// The inclusive lower bound. + #[inline] + pub const fn start(&self) -> u64 { + self.0.start + } + + /// The inclusive upper bound. + #[inline] + pub const fn end(&self) -> u64 { + self.0.last + } + + /// Returns an iterator over all keys in this range. + #[inline] + pub fn iter(&self) -> RangeInclusiveIter { + self.0.iter() + } + + /// Returns `true` if `self` and `other` share at least one key. + #[inline] + pub const fn is_overlapping(&self, other: &KeyRange) -> bool { + self.start() <= other.end() && other.start() <= self.end() + } + + /// Splits this range into two roughly equal halves at the midpoint. + /// + /// Returns `None` if the range contains a single key (cannot be split). + /// Otherwise returns `(left, right)` where `left.end() + 1 == right.start()`. + pub const fn split(&self) -> Option<(KeyRange, KeyRange)> { + if self.start() == self.end() { + return None; + } + let mid = self.start() + (self.end() - self.start()) / 2; + Some(( + KeyRange::new(self.start(), mid), + KeyRange::new(mid + 1, self.end()), + )) + } +} + +impl RangeBounds for KeyRange { + #[inline] + fn start_bound(&self) -> Bound<&u64> { + self.0.start_bound() + } + + #[inline] + fn end_bound(&self) -> Bound<&u64> { + self.0.end_bound() + } +} + +impl From> for KeyRange { + #[inline] + fn from(range: std::ops::RangeInclusive) -> Self { + Self::new(*range.start(), *range.end()) + } +} + +impl From for std::ops::RangeInclusive { + #[inline] + fn from(range: KeyRange) -> Self { + range.start()..=range.end() + } +} + +impl From> for KeyRange { + #[inline] + fn from(range: std::range::RangeInclusive) -> Self { + Self(range) + } +} + +impl From for std::range::RangeInclusive { + #[inline] + fn from(range: KeyRange) -> Self { + range.0 + } +} + +impl fmt::Display for KeyRange { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "[{}, {}]", self.start(), self.end()) + } +} + +#[cfg(feature = "serde")] +mod serde_impl { + use super::KeyRange; + + use serde::ser::SerializeStruct; + use serde::{Deserialize, Deserializer, Serialize, Serializer}; + + /// Serializes as `{"start": , "end": }` to match `std::ops::RangeInclusive`'s + /// serde format. + impl Serialize for KeyRange { + fn serialize(&self, serializer: S) -> Result { + let mut state = serializer.serialize_struct("KeyRange", 2)?; + state.serialize_field("start", &self.start())?; + state.serialize_field("end", &self.end())?; + state.end() + } + } + + /// Deserializes from `{"start": , "end": }` to match `std::ops::RangeInclusive`'s + /// serde format. + impl<'de> Deserialize<'de> for KeyRange { + fn deserialize>(deserializer: D) -> Result { + #[derive(Deserialize)] + #[serde(deny_unknown_fields)] + struct Raw { + start: u64, + end: u64, + } + let raw = Raw::deserialize(deserializer)?; + Ok(KeyRange::new(raw.start, raw.end)) + } + } +} + +#[cfg(feature = "bilrost")] +mod bilrost_impl { + use super::KeyRange; + + use bilrost::DecodeErrorKind; + use bilrost::encoding::{EmptyState, ForOverwrite, Proxiable}; + + impl Proxiable for KeyRange { + type Proxy = (u64, u64); + + fn encode_proxy(&self) -> Self::Proxy { + (self.start(), self.end()) + } + + fn decode_proxy(&mut self, proxy: Self::Proxy) -> Result<(), DecodeErrorKind> { + *self = KeyRange::new(proxy.0, proxy.1); + Ok(()) + } + } + + impl ForOverwrite<(), KeyRange> for () { + fn for_overwrite() -> KeyRange { + KeyRange::new(0, 0) + } + } + + impl EmptyState<(), KeyRange> for () { + fn empty() -> KeyRange { + KeyRange::new(0, 0) + } + + fn is_empty(val: &KeyRange) -> bool { + val.start() == 0 && val.end() == 0 + } + + fn clear(val: &mut KeyRange) { + *val = <() as EmptyState<(), KeyRange>>::empty(); + } + } + + bilrost::delegate_proxied_encoding!( + use encoding (bilrost::encoding::General) + to encode proxied type (KeyRange) + with general encodings + ); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn basics() { + let r = KeyRange::new(10, 20); + assert_eq!(r.start(), 10); + assert_eq!(r.end(), 20); + assert!(r.contains(&10)); + assert!(r.contains(&15)); + assert!(r.contains(&20)); + assert!(!r.contains(&9)); + assert!(!r.contains(&21)); + + // Copy + let r2 = r; + assert_eq!(r, r2); + + // Display + assert_eq!(format!("{r}"), "[10, 20]"); + } + + #[test] + fn full_range() { + assert_eq!(KeyRange::FULL.start(), u64::MIN); + assert_eq!(KeyRange::FULL.end(), u64::MAX); + assert!(KeyRange::FULL.contains(&0)); + assert!(KeyRange::FULL.contains(&u64::MAX)); + } + + #[test] + fn from_ops_range_inclusive() { + let ops_range: std::ops::RangeInclusive = 5..=10; + let kr = KeyRange::from(ops_range.clone()); + assert_eq!(kr.start(), 5); + assert_eq!(kr.end(), 10); + + let back: std::ops::RangeInclusive = kr.into(); + assert_eq!(back, ops_range); + } + + #[test] + fn from_range_inclusive() { + let range = std::range::RangeInclusive { start: 3, last: 7 }; + let kr = KeyRange::from(range); + assert_eq!(kr.start(), 3); + assert_eq!(kr.end(), 7); + + let back: std::range::RangeInclusive = kr.into(); + assert_eq!(back, range); + } + + #[test] + fn size_and_copy() { + assert_eq!(size_of::(), 16); + assert_eq!( + size_of::(), + size_of::>() + ); + // Old type is larger due to exhausted flag + padding + assert!(size_of::>() > size_of::()); + } + + #[test] + fn bilrost_roundtrip() { + use bilrost::{Message, OwnedMessage}; + + #[derive(Debug, PartialEq, bilrost::Message)] + struct TestMsg { + #[bilrost(1)] + range: KeyRange, + } + + let msg = TestMsg { + range: KeyRange::new(100, 200), + }; + let encoded = msg.encode_to_vec(); + let decoded = TestMsg::decode(&*encoded).unwrap(); + assert_eq!(msg, decoded); + } + + #[test] + fn iter() { + let r = KeyRange::new(5, 9); + let keys: Vec = r.iter().collect(); + assert_eq!(keys, vec![5, 6, 7, 8, 9]); + + let single = KeyRange::new(42, 42); + let keys: Vec = single.iter().collect(); + assert_eq!(keys, vec![42]); + } + + #[test] + fn is_overlapping() { + let a = KeyRange::new(0, 10); + let b = KeyRange::new(5, 15); + assert!(a.is_overlapping(&b)); + assert!(b.is_overlapping(&a)); + + // Adjacent ranges don't overlap + let c = KeyRange::new(11, 20); + assert!(!a.is_overlapping(&c)); + + // Touching at boundary + let d = KeyRange::new(10, 20); + assert!(a.is_overlapping(&d)); + + // Contained + let e = KeyRange::new(2, 8); + assert!(a.is_overlapping(&e)); + + // Same range + assert!(a.is_overlapping(&a)); + + // Disjoint + let f = KeyRange::new(100, 200); + assert!(!a.is_overlapping(&f)); + } + + #[test] + fn split() { + // Even range [0, 9] → [0, 4] and [5, 9] + let r = KeyRange::new(0, 9); + let (left, right) = r.split().unwrap(); + assert_eq!(left, KeyRange::new(0, 4)); + assert_eq!(right, KeyRange::new(5, 9)); + + // Odd range [0, 10] → [0, 5] and [6, 10] + let r = KeyRange::new(0, 10); + let (left, right) = r.split().unwrap(); + assert_eq!(left, KeyRange::new(0, 5)); + assert_eq!(right, KeyRange::new(6, 10)); + + // Two-element range [5, 6] → [5, 5] and [6, 6] + let r = KeyRange::new(5, 6); + let (left, right) = r.split().unwrap(); + assert_eq!(left, KeyRange::new(5, 5)); + assert_eq!(right, KeyRange::new(6, 6)); + + // Single element cannot be split + assert!(KeyRange::new(42, 42).split().is_none()); + + // Full range + let (left, right) = KeyRange::FULL.split().unwrap(); + assert_eq!(left.start(), 0); + assert_eq!(right.end(), u64::MAX); + assert_eq!(left.end() + 1, right.start()); + } + + #[test] + fn json_wire_compat_with_ops_range_inclusive() { + // KeyRange must be wire-compatible with std::ops::RangeInclusive in JSON. + let kr = KeyRange::new(100, 200); + let ops = std::ops::RangeInclusive::new(100u64, 200u64); + + let kr_json = serde_json::to_string(&kr).unwrap(); + let ops_json = serde_json::to_string(&ops).unwrap(); + assert_eq!(kr_json, ops_json, "serialized forms must be identical"); + + // Cross-deserialize in both directions + let kr_from_ops: KeyRange = serde_json::from_str(&ops_json).unwrap(); + assert_eq!(kr_from_ops, kr); + + let ops_from_kr: std::ops::RangeInclusive = serde_json::from_str(&kr_json).unwrap(); + assert_eq!(ops_from_kr, ops); + } + + #[test] + fn json_wire_compat_boundary_values() { + // Verify wire compat at boundary values (0, u64::MAX) + for (start, end) in [(0u64, 0u64), (0, u64::MAX), (u64::MAX, u64::MAX)] { + let kr = KeyRange::new(start, end); + let ops = std::ops::RangeInclusive::new(start, end); + + let kr_json = serde_json::to_string(&kr).unwrap(); + let ops_json = serde_json::to_string(&ops).unwrap(); + assert_eq!(kr_json, ops_json, "mismatch at ({start}, {end})"); + + let roundtrip: KeyRange = serde_json::from_str(&ops_json).unwrap(); + assert_eq!(roundtrip, kr); + } + } + + #[test] + fn flexbuffers_wire_compat_with_ops_range_inclusive() { + // KeyRange must be wire-compatible with std::ops::RangeInclusive in flexbuffers. + let kr = KeyRange::new(100, 200); + let ops = std::ops::RangeInclusive::new(100u64, 200u64); + + let kr_buf = flexbuffers::to_vec(kr).unwrap(); + let ops_buf = flexbuffers::to_vec(&ops).unwrap(); + assert_eq!(kr_buf, ops_buf, "serialized forms must be identical"); + + // Cross-deserialize in both directions + let kr_from_ops: KeyRange = flexbuffers::from_slice(&ops_buf).unwrap(); + assert_eq!(kr_from_ops, kr); + + let ops_from_kr: std::ops::RangeInclusive = flexbuffers::from_slice(&kr_buf).unwrap(); + assert_eq!(ops_from_kr, ops); + } + + #[test] + fn flexbuffers_wire_compat_boundary_values() { + for (start, end) in [(0u64, 0u64), (0, u64::MAX), (u64::MAX, u64::MAX)] { + let kr = KeyRange::new(start, end); + let ops = std::ops::RangeInclusive::new(start, end); + + let kr_buf = flexbuffers::to_vec(kr).unwrap(); + let ops_buf = flexbuffers::to_vec(&ops).unwrap(); + assert_eq!(kr_buf, ops_buf, "mismatch at ({start}, {end})"); + + let roundtrip: KeyRange = flexbuffers::from_slice(&ops_buf).unwrap(); + assert_eq!(roundtrip, kr); + } + } + + #[test] + fn bilrost_compat_with_ops_range() { + use bilrost::{Message, OwnedMessage}; + + // The bilrost crate's existing encoding for RangeInclusive uses RestateEncoding + // which is defined in restate-encoding. We can't test cross-type bilrost compat here + // since RestateEncoding is not available in this crate. That test belongs in + // restate-encoding's test suite. + // + // Here we verify that KeyRange's bilrost encoding is self-consistent. + #[derive(Debug, PartialEq, bilrost::Message)] + struct Msg { + #[bilrost(1)] + range: KeyRange, + } + + let msg = Msg { + range: KeyRange::new(u64::MIN, u64::MAX), + }; + let encoded = msg.encode_to_vec(); + let decoded = Msg::decode(&*encoded).unwrap(); + assert_eq!(msg, decoded); + } +} diff --git a/crates/sharding/src/lib.rs b/crates/sharding/src/lib.rs new file mode 100644 index 0000000000..b2223b25f4 --- /dev/null +++ b/crates/sharding/src/lib.rs @@ -0,0 +1,43 @@ +// Copyright (c) 2023 - 2026 Restate Software, Inc., Restate GmbH. +// All rights reserved. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0. + +mod key_range; + +use std::sync::Arc; + +pub use key_range::KeyRange; + +/// Identifying to which partition a key belongs. This is unlike the [`PartitionId`] +/// which identifies a consecutive range of partition keys. +pub type PartitionKey = u64; + +/// Trait for data structures that have a partition key +pub trait WithPartitionKey { + /// Returns the partition key + fn partition_key(&self) -> PartitionKey; +} + +impl WithPartitionKey for &T +where + T: WithPartitionKey, +{ + fn partition_key(&self) -> PartitionKey { + (*self).partition_key() + } +} + +impl WithPartitionKey for Arc +where + T: WithPartitionKey, +{ + fn partition_key(&self) -> PartitionKey { + self.as_ref().partition_key() + } +} diff --git a/crates/types/Cargo.toml b/crates/types/Cargo.toml index 9491e2a894..3496bfe0ae 100644 --- a/crates/types/Cargo.toml +++ b/crates/types/Cargo.toml @@ -30,6 +30,7 @@ restate-memory = { workspace = true } restate-serde-util = { workspace = true } restate-test-util = { workspace = true, optional = true } restate-time-util = { workspace = true, features = ["serde", "serde_with"] } +restate-sharding = { workspace = true, features = ["serde", "bilrost"] } restate-util-string = { workspace = true, features = ["serde", "bilrost"] } restate-utoipa = { workspace = true } diff --git a/crates/types/src/identifiers.rs b/crates/types/src/identifiers.rs index 591de1172a..5a145457ba 100644 --- a/crates/types/src/identifiers.rs +++ b/crates/types/src/identifiers.rs @@ -13,6 +13,7 @@ mod partitioned; pub use partitioned::PartitionedResourceId; +pub use restate_sharding::{PartitionKey, WithPartitionKey}; use std::cell::RefCell; use std::fmt::{self, Display, Formatter}; @@ -155,10 +156,6 @@ pub type PartitionLeaderEpoch = (PartitionId, LeaderEpoch); // Just an alias pub type EntryIndex = u32; -/// Identifying to which partition a key belongs. This is unlike the [`PartitionId`] -/// which identifies a consecutive range of partition keys. -pub type PartitionKey = u64; - /// Returns the partition key computed from either the service_key, or idempotency_key, if possible fn deterministic_partition_key( service_key: Option<&str>, @@ -203,13 +200,6 @@ fn random_unscoped_service_partition_key(service_name: &str) -> PartitionKey { let bucket = rand::rng().random_range(0..UNSCOPED_SERVICE_PARTITION_KEY_FANOUT); unscoped_service_partition_key(service_name, bucket) } - -/// Trait for data structures that have a partition key -pub trait WithPartitionKey { - /// Returns the partition key - fn partition_key(&self) -> PartitionKey; -} - /// A family of resource identifiers that tracks the timestamp of its creation. pub trait TimestampAwareId { /// The timestamp when this ID was created. @@ -656,12 +646,6 @@ impl WithPartitionKey for InvocationId { } } -impl WithPartitionKey for T { - fn partition_key(&self) -> PartitionKey { - self.invocation_id().partition_key - } -} - impl Display for InvocationId { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { // encode the id such that it is possible to do a string prefix search for a @@ -816,6 +800,12 @@ impl From<(InvocationId, EntryIndex)> for JournalEntryId { } } +impl WithPartitionKey for JournalEntryId { + fn partition_key(&self) -> PartitionKey { + self.invocation_id.partition_key() + } +} + impl WithInvocationId for JournalEntryId { fn invocation_id(&self) -> InvocationId { self.invocation_id diff --git a/crates/types/src/invocation/mod.rs b/crates/types/src/invocation/mod.rs index 84b528ea44..7b1252f6b2 100644 --- a/crates/types/src/invocation/mod.rs +++ b/crates/types/src/invocation/mod.rs @@ -427,6 +427,12 @@ impl InvocationRequest { } } +impl WithPartitionKey for InvocationRequest { + fn partition_key(&self) -> PartitionKey { + self.header.invocation_id().partition_key() + } +} + impl WithInvocationId for InvocationRequest { fn invocation_id(&self) -> InvocationId { self.header.invocation_id() @@ -590,6 +596,12 @@ impl JournalCompletionTarget { } } +impl WithPartitionKey for JournalCompletionTarget { + fn partition_key(&self) -> PartitionKey { + self.caller_id.partition_key() + } +} + impl WithInvocationId for JournalCompletionTarget { fn invocation_id(&self) -> InvocationId { self.caller_id @@ -607,6 +619,12 @@ pub struct InvocationResponse { pub result: ResponseResult, } +impl WithPartitionKey for InvocationResponse { + fn partition_key(&self) -> PartitionKey { + self.target.invocation_id().partition_key() + } +} + impl WithInvocationId for InvocationResponse { fn invocation_id(&self) -> InvocationId { self.target.invocation_id() @@ -1361,6 +1379,12 @@ pub struct NotifySignalRequest { pub signal: Signal, } +impl WithPartitionKey for NotifySignalRequest { + fn partition_key(&self) -> PartitionKey { + self.invocation_id.partition_key() + } +} + impl WithInvocationId for NotifySignalRequest { fn invocation_id(&self) -> InvocationId { self.invocation_id diff --git a/crates/types/src/lib.rs b/crates/types/src/lib.rs index 3b68d0354d..8df2b343a6 100644 --- a/crates/types/src/lib.rs +++ b/crates/types/src/lib.rs @@ -81,6 +81,11 @@ pub mod memory { pub use restate_memory::*; } +// Re-export restate-sharding crate for key range utilities. +pub mod sharding { + pub use restate_sharding::*; +} + // Re-export metrics' SharedString (Space-efficient Cow + RefCounted variant) pub type SharedString = metrics::SharedString; diff --git a/crates/wal-protocol/src/lib.rs b/crates/wal-protocol/src/lib.rs index 0754f5ab64..0d5ccf96d2 100644 --- a/crates/wal-protocol/src/lib.rs +++ b/crates/wal-protocol/src/lib.rs @@ -239,15 +239,17 @@ impl HasRecordKeys for Envelope { Command::TruncateOutbox(_) => Keys::Single(self.partition_key()), Command::ProxyThrough(_) => Keys::Single(self.partition_key()), Command::AttachInvocation(_) => Keys::Single(self.partition_key()), - Command::ResumeInvocation(req) => Keys::Single(req.partition_key()), - Command::RestartAsNewInvocation(req) => Keys::Single(req.partition_key()), + Command::ResumeInvocation(req) => Keys::Single(req.invocation_id.partition_key()), + Command::RestartAsNewInvocation(req) => Keys::Single(req.invocation_id.partition_key()), // todo: Handle journal entries that request cross-partition invocations Command::InvokerEffect(effect) => Keys::Single(effect.invocation_id.partition_key()), Command::Timer(timer) => Keys::Single(timer.invocation_id().partition_key()), Command::ScheduleTimer(timer) => Keys::Single(timer.invocation_id().partition_key()), Command::InvocationResponse(response) => Keys::Single(response.partition_key()), Command::NotifySignal(sig) => Keys::Single(sig.partition_key()), - Command::NotifyGetInvocationOutputResponse(res) => Keys::Single(res.partition_key()), + Command::NotifyGetInvocationOutputResponse(res) => { + Keys::Single(res.target.partition_key()) + } Command::UpsertSchema(schema) => schema.partition_key_range.clone(), Command::VQSchedulerDecisions(_) => Keys::Single(self.partition_key()), }