Skip to content
81 changes: 79 additions & 2 deletions crates/rpc-client/src/poller.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use serde::Serialize;
use serde_json::value::RawValue;
use std::{
borrow::Cow,
fmt::Debug,
marker::PhantomData,
ops::{Deref, DerefMut},
pin::Pin,
Expand All @@ -22,6 +23,18 @@ use wasmtimer::tokio::{sleep, Sleep};
#[cfg(not(target_family = "wasm"))]
use tokio::time::{sleep, Sleep};

/// A function that creates new parameters when reconnection is needed
type ReconnectFn = Box<
dyn Fn(
WeakClient,
) -> alloy_transport::Pbf<
'static,
Box<RawValue>,
alloy_transport::RpcError<alloy_transport::TransportErrorKind>,
> + Send
+ Sync,
>;

/// A poller task builder.
///
/// This builder is used to create a poller task that repeatedly polls a method on a client and
Expand Down Expand Up @@ -58,7 +71,6 @@ use tokio::time::{sleep, Sleep};
/// # Ok(())
/// # }
/// ```
#[derive(Debug)]
#[must_use = "this builder does nothing unless you call `spawn` or `into_stream`"]
pub struct PollerBuilder<Params, Resp> {
/// The client to poll with.
Expand All @@ -67,6 +79,7 @@ pub struct PollerBuilder<Params, Resp> {
/// Request Method
method: Cow<'static, str>,
params: Params,
reconnect_fn: Option<ReconnectFn>,

// config options
channel_size: usize,
Expand All @@ -76,6 +89,14 @@ pub struct PollerBuilder<Params, Resp> {
_pd: PhantomData<fn() -> Resp>,
}

impl<Params, Resp> std::fmt::Debug for PollerBuilder<Params, Resp> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PollerBuilder")
.field("reconnect_factory", &self.reconnect_fn.as_ref().map(|_| "<function>"))
.finish_non_exhaustive()
}
}

impl<Params, Resp> PollerBuilder<Params, Resp>
where
Params: RpcSend + 'static,
Expand All @@ -87,6 +108,7 @@ where
client.upgrade().map_or_else(|| Duration::from_secs(7), |c| c.poll_interval());
Self {
client,
reconnect_fn: None,
method: method.into(),
params,
channel_size: 16,
Expand All @@ -112,6 +134,30 @@ where
self
}

/// Sets the reconnect function which updates the poller params and runs when filter drops.
pub fn with_reconnect<F, Fut>(mut self, reconnect_fn: F) -> Self
where
F: Fn(WeakClient) -> Fut + Send + Sync + 'static,
Fut: Future<
Output = Result<
Box<RawValue>,
alloy_transport::RpcError<alloy_transport::TransportErrorKind>,
>,
> + Send
+ 'static,
{
let boxed_reconnect_fn: ReconnectFn = Box::new(move |client| {
Box::pin(reconnect_fn(client))
as alloy_transport::Pbf<
'static,
Box<RawValue>,
alloy_transport::RpcError<alloy_transport::TransportErrorKind>,
>
});
self.reconnect_fn = Some(boxed_reconnect_fn);
self
}

/// Returns the limit on the number of successful polls.
pub const fn limit(&self) -> usize {
self.limit
Expand Down Expand Up @@ -172,6 +218,7 @@ where
/// Note that this does not spawn the poller on a separate task, thus all responses will be
/// polled on the current thread once this stream is polled.
pub fn into_stream(self) -> PollerStream<Resp> {
println!("STtrEAMINg STARting");
PollerStream::new(self)
}

Expand All @@ -195,6 +242,14 @@ enum PollState<Resp> {
alloy_transport::RpcError<alloy_transport::TransportErrorKind>,
>,
),
/// Attempting to reconnect (re-establish filter)
Reconnecting(
alloy_transport::Pbf<
'static,
Box<RawValue>,
alloy_transport::RpcError<alloy_transport::TransportErrorKind>,
>,
),
/// Sleeping between polls.
Sleeping(Pin<Box<Sleep>>),

Expand Down Expand Up @@ -232,6 +287,7 @@ pub struct PollerStream<Resp, Output = Resp, Map = fn(Resp) -> Output> {
method: Cow<'static, str>,
params: Box<RawValue>,
poll_interval: Duration,
reconnect_fn: Option<ReconnectFn>,
limit: usize,
poll_count: usize,
state: PollState<Resp>,
Expand All @@ -247,12 +303,13 @@ impl<Resp, Output, Map> std::fmt::Debug for PollerStream<Resp, Output, Map> {
.field("poll_interval", &self.poll_interval)
.field("limit", &self.limit)
.field("poll_count", &self.poll_count)
.field("reconnect_factory", &self.reconnect_fn.as_ref().map(|_| "<function>"))
.finish_non_exhaustive()
}
}

impl<Resp> PollerStream<Resp> {
fn new<Params: Serialize>(builder: PollerBuilder<Params, Resp>) -> Self {
fn new<Params: Serialize + RpcSend + 'static>(builder: PollerBuilder<Params, Resp>) -> Self {
let span = debug_span!("poller", method = %builder.method);

// Serialize params once
Expand All @@ -267,6 +324,7 @@ impl<Resp> PollerStream<Resp> {
method: builder.method,
params,
poll_interval: builder.poll_interval,
reconnect_fn: builder.reconnect_fn,
limit: builder.limit,
poll_count: 0,
state: PollState::Waiting,
Expand Down Expand Up @@ -316,6 +374,7 @@ where
poll_count: self.poll_count,
state: self.state,
span: self.span,
reconnect_fn: self.reconnect_fn,
map,
_pd: PhantomData,
}
Expand All @@ -335,6 +394,16 @@ where

loop {
match &mut this.state {
PollState::Reconnecting(fut) => match ready!(fut.poll_unpin(cx)) {
Ok(resp) => {
this.params = resp;
this.state = PollState::Waiting;
}
Err(err) => {
error!("reconnect failed: {}", err);
this.state = PollState::Finished;
}
},
PollState::Paused => return Poll::Pending,
PollState::Waiting => {
// Check if we've reached the limit
Expand Down Expand Up @@ -375,9 +444,17 @@ where
// the poller. Error codes are not consistent
// across reth/geth/nethermind, so we check
// just the message.
// The rpc call that returns this is eth_getFilterChanges
if let Some(resp) = err.as_error_resp() {
if resp.message.contains("filter not found") {
warn!("server has dropped the filter, stopping poller");
if let Some(ref reconnect_fn) = this.reconnect_fn {
warn!("server has dropped the filter, attempting to reconnect");

let reconnect_fut = reconnect_fn(this.client.clone());
this.state = PollState::Reconnecting(reconnect_fut);
continue;
}
this.state = PollState::Finished;
continue;
}
Expand Down