Skip to content

Commit a36b7e7

Browse files
authored
Batch request and spec conformance (#14)
* feat: batch req and spec conformance * fix: check if resps is empty too * lint: clippy * fix: make tests not serial * fix: remove ipc from shared.rs * fix: remove import
1 parent 5415e94 commit a36b7e7

File tree

18 files changed

+654
-275
lines changed

18 files changed

+654
-275
lines changed

Cargo.toml

+4-4
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ tokio-util = { version = "0.7.13", optional = true, features = ["io"] }
3737
tokio-tungstenite = { version = "0.26.1", features = ["rustls-tls-webpki-roots"], optional = true }
3838
futures-util = { version = "0.3.31", optional = true }
3939

40+
[dev-dependencies]
41+
tempfile = "3.15.0"
42+
tracing-subscriber = "0.3.19"
43+
4044
[features]
4145
default = ["axum", "ws", "ipc"]
4246
axum = ["dep:axum"]
@@ -66,7 +70,3 @@ inherits = "dev"
6670
strip = true
6771
debug = false
6872
incremental = false
69-
70-
[dev-dependencies]
71-
tempfile = "3.15.0"
72-
tracing-subscriber = "0.3.19"

src/axum.rs

+13-17
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
1-
use crate::{
2-
types::{Request, Response},
3-
HandlerArgs,
4-
};
1+
use crate::types::{InboundData, Response};
52
use axum::{extract::FromRequest, response::IntoResponse};
63
use bytes::Bytes;
74
use std::{future::Future, pin::Pin};
@@ -18,20 +15,19 @@ where
1815
return Box::<str>::from(Response::parse_error()).into_response();
1916
};
2017

21-
let Ok(req) = Request::try_from(bytes) else {
22-
return Box::<str>::from(Response::parse_error()).into_response();
23-
};
24-
25-
let args = HandlerArgs {
26-
ctx: Default::default(),
27-
req,
28-
};
29-
30-
// Default handler ctx does not allow for notifications, which is
31-
// what we want over HTTP.
32-
let response = unwrap_infallible!(self.call_with_state(args, state).await);
18+
// If the inbound data is not currently parsable, we
19+
// send an empty one it to the router, as the router enforces
20+
// the specification.
21+
let req = InboundData::try_from(bytes).unwrap_or_default();
3322

34-
Box::<str>::from(response).into_response()
23+
if let Some(response) = self
24+
.call_batch_with_state(Default::default(), req, state)
25+
.await
26+
{
27+
Box::<str>::from(response).into_response()
28+
} else {
29+
().into_response()
30+
}
3531
})
3632
}
3733
}

src/lib.rs

+5-3
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,8 @@ pub mod pubsub;
144144
pub use pubsub::ReadJsonStream;
145145

146146
mod routes;
147+
pub use routes::{BatchFuture, Handler, HandlerArgs, HandlerCtx, NotifyError, RouteFuture};
147148
pub(crate) use routes::{BoxedIntoRoute, ErasedIntoRoute, Method, Route};
148-
pub use routes::{Handler, HandlerArgs, HandlerCtx, NotifyError, RouteFuture};
149149

150150
mod router;
151151
pub use router::Router;
@@ -208,7 +208,8 @@ mod test {
208208
(),
209209
)
210210
.await
211-
.expect("infallible");
211+
.expect("infallible")
212+
.expect("request had ID, is not a notification");
212213

213214
assert_rv_eq(
214215
&res,
@@ -226,7 +227,8 @@ mod test {
226227
(),
227228
)
228229
.await
229-
.expect("infallible");
230+
.expect("infallible")
231+
.expect("request had ID, is not a notification");
230232

231233
assert_rv_eq(&res2, r#"{"jsonrpc":"2.0","id":1,"result":"{}"}"#);
232234
}

src/pubsub/shared.rs

+15-24
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@ use core::fmt;
22

33
use crate::{
44
pubsub::{In, JsonSink, Listener, Out},
5-
types::Request,
6-
HandlerArgs,
5+
types::InboundData,
76
};
87
use serde_json::value::RawValue;
98
use tokio::{
@@ -193,31 +192,25 @@ where
193192
select! {
194193
biased;
195194
_ = write_task.closed() => {
196-
debug!("IpcWriteTask has gone away");
195+
debug!("WriteTask has gone away");
197196
break;
198197
}
199198
item = requests.next() => {
200199
let Some(item) = item else {
201-
trace!("IPC read stream has closed");
200+
trace!("inbound read stream has closed");
202201
break;
203202
};
204203

205-
let req = match Request::try_from(item) {
206-
Ok(req) => req,
207-
Err(err) => {
208-
tracing::warn!(%err, "inbound request is malformatted");
209-
continue
210-
}
211-
};
204+
// If the inbound data is not currently parsable, we
205+
// send an empty one it to the router, as the router
206+
// enforces the specification.
207+
let reqs = InboundData::try_from(item).unwrap_or_default();
212208

213-
let span = debug_span!("ipc request handling", id = req.id(), method = req.method());
209+
let span = debug_span!("pubsub request handling", reqs = reqs.len());
214210

215-
let args = HandlerArgs {
216-
ctx: write_task.clone().into(),
217-
req,
218-
};
211+
let ctx = write_task.clone().into();
219212

220-
let fut = router.handle_request(args);
213+
let fut = router.handle_request_batch(ctx, reqs);
221214
let write_task = write_task.clone();
222215

223216
// Acquiring the permit before spawning the task means that
@@ -232,16 +225,14 @@ where
232225
// Run the future in a new task.
233226
tokio::spawn(
234227
async move {
235-
// Run the request handler and serialize the
236-
// response.
237-
let rv = fut.await.expect("infallible");
238-
239228
// Send the response to the write task.
240229
// we don't care if the receiver has gone away,
241230
// as the task is done regardless.
242-
let _ = permit.send(
243-
rv
244-
);
231+
if let Some(rv) = fut.await {
232+
let _ = permit.send(
233+
rv
234+
);
235+
}
245236
}
246237
.instrument(span)
247238
);

src/router.rs

+48-22
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
//! JSON-RPC router.
22
33
use crate::{
4-
routes::{MakeErasedHandler, RouteFuture},
5-
BoxedIntoRoute, ErasedIntoRoute, Handler, HandlerArgs, Method, MethodId, RegistrationError,
6-
Route,
4+
routes::{BatchFuture, MakeErasedHandler, RouteFuture},
5+
types::InboundData,
6+
BoxedIntoRoute, ErasedIntoRoute, Handler, HandlerArgs, HandlerCtx, Method, MethodId,
7+
RegistrationError, Route,
78
};
89
use core::fmt;
910
use serde_json::value::RawValue;
@@ -193,7 +194,7 @@ where
193194
where
194195
T: Service<
195196
HandlerArgs,
196-
Response = Box<RawValue>,
197+
Response = Option<Box<RawValue>>,
197198
Error = Infallible,
198199
Future: Send + 'static,
199200
> + Clone
@@ -299,15 +300,35 @@ where
299300
/// This is a convenience method, primarily for testing. Use in production
300301
/// code is discouraged. Routers should not be left in incomplete states.
301302
pub fn call_with_state(&self, args: HandlerArgs, state: S) -> RouteFuture {
302-
let id = args.req.id_owned();
303-
let method = args.req.method();
303+
let id = args.req().id_owned();
304+
let method = args.req().method();
304305

305-
let span = debug_span!("Router::call_with_state", %method, %id);
306-
trace!(params = args.req.params());
306+
let span = debug_span!("Router::call_with_state", %method, ?id);
307+
trace!(params = args.req().params());
307308

308309
self.inner.call_with_state(args, state).with_span(span)
309310
}
310311

312+
/// Call a method on the router, without providing state.
313+
pub fn call_batch_with_state(
314+
&self,
315+
ctx: HandlerCtx,
316+
inbound: InboundData,
317+
state: S,
318+
) -> BatchFuture {
319+
let mut fut = BatchFuture::new_with_capacity(inbound.single(), inbound.len());
320+
// According to spec, non-parsable requests should still receive a
321+
// response.
322+
for req in inbound.iter() {
323+
let req = req.map(|req| {
324+
let args = HandlerArgs::new(ctx.clone(), req);
325+
self.call_with_state(args, state.clone())
326+
});
327+
fut.push_parse_result(req);
328+
}
329+
fut
330+
}
331+
311332
/// Nest this router into a new Axum router, with the specified path.
312333
#[cfg(feature = "axum")]
313334
pub fn into_axum(self, path: &str) -> axum::Router<S> {
@@ -316,22 +337,27 @@ where
316337
}
317338

318339
impl Router<()> {
319-
// /// Serve the router over a connection. This method returns a
320-
// /// [`ServerShutdown`], which will shut down the server when dropped.
321-
// ///
322-
// /// [`ServerShutdown`]: crate::pubsub::ServerShutdown
323-
// #[cfg(feature = "pubsub")]
324-
// pub async fn serve_pubsub<C: crate::pubsub::Connect>(
325-
// self,
326-
// connect: C,
327-
// ) -> Result<crate::pubsub::ServerShutdown, C::Error> {
328-
// connect.run(self).await
329-
// }
340+
/// Serve the router over a connection. This method returns a
341+
/// [`ServerShutdown`], which will shut down the server when dropped.
342+
///
343+
/// [`ServerShutdown`]: crate::pubsub::ServerShutdown
344+
#[cfg(feature = "pubsub")]
345+
pub async fn serve_pubsub<C: crate::pubsub::Connect>(
346+
self,
347+
connect: C,
348+
) -> Result<crate::pubsub::ServerShutdown, C::Error> {
349+
connect.serve(self).await
350+
}
330351

331352
/// Call a method on the router.
332353
pub fn handle_request(&self, args: HandlerArgs) -> RouteFuture {
333354
self.call_with_state(args, ())
334355
}
356+
357+
/// Call a batch of methods on the router.
358+
pub fn handle_request_batch(&self, ctx: HandlerCtx, batch: InboundData) -> BatchFuture {
359+
self.call_batch_with_state(ctx, batch, ())
360+
}
335361
}
336362

337363
impl<S> fmt::Debug for Router<S> {
@@ -341,7 +367,7 @@ impl<S> fmt::Debug for Router<S> {
341367
}
342368

343369
impl tower::Service<HandlerArgs> for Router<()> {
344-
type Response = Box<RawValue>;
370+
type Response = Option<Box<RawValue>>;
345371
type Error = Infallible;
346372
type Future = RouteFuture;
347373

@@ -355,7 +381,7 @@ impl tower::Service<HandlerArgs> for Router<()> {
355381
}
356382

357383
impl tower::Service<HandlerArgs> for &Router<()> {
358-
type Response = Box<RawValue>;
384+
type Response = Option<Box<RawValue>>;
359385
type Error = Infallible;
360386
type Future = RouteFuture;
361387

@@ -517,7 +543,7 @@ impl<S> RouterInner<S> {
517543
/// Call a method on the router, with the provided state.
518544
#[track_caller]
519545
pub(crate) fn call_with_state(&self, args: HandlerArgs, state: S) -> RouteFuture {
520-
let method = args.req.method();
546+
let method = args.req().method();
521547
self.method_by_name(method)
522548
.unwrap_or(&self.fallback)
523549
.call_with_state(args, state)

src/routes/ctx.rs

+19-2
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,24 @@ impl HandlerCtx {
7979
#[derive(Debug, Clone)]
8080
pub struct HandlerArgs {
8181
/// The handler context.
82-
pub ctx: HandlerCtx,
82+
pub(crate) ctx: HandlerCtx,
8383
/// The JSON-RPC request.
84-
pub req: Request,
84+
pub(crate) req: Request,
85+
}
86+
87+
impl HandlerArgs {
88+
/// Create new handler arguments.
89+
pub const fn new(ctx: HandlerCtx, req: Request) -> Self {
90+
Self { ctx, req }
91+
}
92+
93+
/// Get a reference to the handler context.
94+
pub const fn ctx(&self) -> &HandlerCtx {
95+
&self.ctx
96+
}
97+
98+
/// Get a reference to the JSON-RPC request.
99+
pub const fn req(&self) -> &Request {
100+
&self.req
101+
}
85102
}

0 commit comments

Comments
 (0)