diff --git a/apps/argus/src/adapters/contract.rs b/apps/argus/src/adapters/contract.rs index f1645caf1b..0a26bd8f4a 100644 --- a/apps/argus/src/adapters/contract.rs +++ b/apps/argus/src/adapters/contract.rs @@ -10,21 +10,39 @@ use std::collections::HashMap; #[async_trait] pub trait GetChainPrices { - async fn get_price_unsafe( + async fn get_all_prices_for_subscription( &self, subscription_id: SubscriptionId, - feed_id: &PriceId, - ) -> Result>; + ) -> Result>; + + async fn get_prices_for_subscription( + &self, + subscription_id: SubscriptionId, + price_ids: &Vec, + ) -> Result>; } #[async_trait] impl GetChainPrices for PythPulse { - async fn get_price_unsafe( + async fn get_all_prices_for_subscription( &self, - _subscription_id: SubscriptionId, - _feed_id: &PriceId, - ) -> Result> { - todo!() + subscription_id: SubscriptionId, + ) -> Result> { + let price_ids = self.get_prices_unsafe(subscription_id, vec![]).await?; + Ok(price_ids.into_iter().map(From::from).collect()) + } + async fn get_prices_for_subscription( + &self, + subscription_id: SubscriptionId, + price_ids: &Vec, + ) -> Result> { + let price_ids = self + .get_prices_unsafe( + subscription_id, + price_ids.into_iter().map(|id| id.to_bytes()).collect(), + ) + .await?; + Ok(price_ids.into_iter().map(From::from).collect()) } } #[async_trait] diff --git a/apps/argus/src/adapters/types.rs b/apps/argus/src/adapters/types.rs index 20d634c706..8ead074cf9 100644 --- a/apps/argus/src/adapters/types.rs +++ b/apps/argus/src/adapters/types.rs @@ -3,3 +3,20 @@ use pyth_sdk::PriceIdentifier; pub type PriceId = PriceIdentifier; pub type SubscriptionId = U256; + +use crate::adapters::ethereum::pyth_pulse::Price as ContractPrice; // ABI-generated Price +use pyth_sdk::Price as SdkPrice; // pyth_sdk::Price + +impl From for SdkPrice { + fn from(contract_price: ContractPrice) -> Self { + SdkPrice { + price: contract_price.price, + conf: contract_price.conf, + expo: contract_price.expo, + publish_time: contract_price + .publish_time + .try_into() + .expect("Failed to convert publish_time from U256 to i64 (UnixTimestamp)"), + } + } +} diff --git a/apps/argus/src/command/run.rs b/apps/argus/src/command/run.rs index ef7900bd9b..348e1f037f 100644 --- a/apps/argus/src/command/run.rs +++ b/apps/argus/src/command/run.rs @@ -142,6 +142,7 @@ pub async fn run_keeper_for_chain( contract.clone(), config.keeper.chain_price_poll_interval, state.chain_price_state.clone(), + state.subscription_state.clone(), ); let price_pusher_service = PricePusherService::new( @@ -157,6 +158,7 @@ pub async fn run_keeper_for_chain( state.subscription_state.clone(), state.pyth_price_state.clone(), state.chain_price_state.clone(), + price_pusher_service.request_sender(), ); let services: Vec> = vec![ diff --git a/apps/argus/src/services/chain_price_service.rs b/apps/argus/src/services/chain_price_service.rs index 959c70ed6c..3aa178bcbb 100644 --- a/apps/argus/src/services/chain_price_service.rs +++ b/apps/argus/src/services/chain_price_service.rs @@ -7,6 +7,8 @@ use anyhow::Result; use async_trait::async_trait; +use pyth_sdk::Price; +use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; use tokio::sync::watch; @@ -14,9 +16,11 @@ use tokio::time; use tracing; use crate::adapters::contract::GetChainPrices; +use crate::adapters::types::PriceId; use crate::services::Service; use crate::state::ChainName; use crate::state::ChainPriceState; +use crate::state::SubscriptionState; pub struct ChainPriceService { chain_name: ChainName, @@ -24,6 +28,7 @@ pub struct ChainPriceService { contract: Arc, poll_interval: Duration, chain_price_state: Arc, + subscription_state: Arc, } impl ChainPriceService { @@ -32,6 +37,7 @@ impl ChainPriceService { contract: Arc, poll_interval: Duration, chain_price_state: Arc, + subscription_state: Arc, ) -> Self { Self { chain_name: chain_name.clone(), @@ -39,17 +45,62 @@ impl ChainPriceService { contract, poll_interval, chain_price_state, + subscription_state, } } - async fn poll_prices(&self, state: Arc) { - let feed_ids = state.get_feed_ids(); + #[tracing::instrument(skip_all, fields(task = self.name, chain_name = self.chain_name))] + async fn poll_prices(&self) -> Result<()> { + // Get all active subscriptions + let subscriptions = self.subscription_state.get_subscriptions(); - tracing::debug!( - service = self.name, - feed_count = feed_ids.len(), - "Polled for on-chain price updates" - ); + // For each subscription, query the chain for the price of each feed + for item in subscriptions.iter() { + let subscription_id = item.key().clone(); + let subscription_params = item.value().clone(); + + // TODO: do this in parallel using tokio tasks? + let price_ids = subscription_params + .price_ids + .into_iter() + .map(|id| PriceId::new(id)) + .collect::>(); + + match self + .contract + .get_prices_for_subscription(subscription_id, &price_ids) + .await + { + Ok(prices) => { + let prices_map: HashMap = price_ids + .clone() + .into_iter() + .zip(prices.into_iter()) + .collect(); + + tracing::debug!( + price_ids = ?price_ids, + subscription_id = %subscription_id, + "Got prices for subscription" + ); + + // Store the latest price feeds for the subscription + self.chain_price_state + .update_prices(subscription_id, prices_map); + } + Err(e) => { + // If we failed to get prices for a subscription, we'll retry on the next poll interval. + // Continue to the next subscription. + tracing::error!( + subscription_id = %subscription_id, + error = %e, + "Failed to get prices for subscription" + ); + continue; + } + } + } + Ok(()) } } @@ -64,7 +115,13 @@ impl Service for ChainPriceService { loop { tokio::select! { _ = interval.tick() => { - self.poll_prices(self.chain_price_state.clone()).await; + if let Err(e) = self.poll_prices().await { + tracing::error!( + service = self.name, + error = %e, + "Failed to poll chain prices" + ); + } } _ = stop_rx.changed() => { if *stop_rx.borrow() { diff --git a/apps/argus/src/services/controller_service.rs b/apps/argus/src/services/controller_service.rs index c35427738c..c6d9e68acb 100644 --- a/apps/argus/src/services/controller_service.rs +++ b/apps/argus/src/services/controller_service.rs @@ -9,22 +9,26 @@ use anyhow::Result; use async_trait::async_trait; use std::sync::Arc; use std::time::Duration; -use tokio::sync::watch; +use tokio::sync::{mpsc, watch}; use tokio::time; use tracing; +use crate::adapters::ethereum::UpdateCriteria; use crate::adapters::types::{PriceId, SubscriptionId}; use crate::services::types::PushRequest; use crate::services::Service; use crate::state::ChainName; use crate::state::{ChainPriceState, PythPriceState, SubscriptionState}; +use pyth_sdk::Price; pub struct ControllerService { name: String, + chain_name: ChainName, update_interval: Duration, subscription_state: Arc, pyth_price_state: Arc, chain_price_state: Arc, + price_pusher_tx: mpsc::Sender, } impl ControllerService { @@ -34,17 +38,21 @@ impl ControllerService { subscription_state: Arc, pyth_price_state: Arc, chain_price_state: Arc, + price_pusher_tx: mpsc::Sender, ) -> Self { Self { name: format!("ControllerService-{}", chain_name), + chain_name, update_interval, subscription_state, pyth_price_state, chain_price_state, + price_pusher_tx, } } - async fn perform_update(&self) { + #[tracing::instrument(skip_all, fields(task = self.name, chain_name = self.chain_name))] + async fn perform_update(&self) -> Result<()> { let subscriptions = self.subscription_state.get_subscriptions(); tracing::debug!( @@ -53,29 +61,47 @@ impl ControllerService { "Checking subscriptions for updates" ); - for (sub_id, params) in subscriptions { - let mut _needs_update = false; - let mut feed_ids: Vec = Vec::new(); + for item in subscriptions.iter() { + let sub_id = item.key().clone(); + let params = item.value().clone(); + let price_ids: Vec = params + .price_ids + .iter() + .map(|id| PriceId::new(*id)) + .collect(); - for feed_id in ¶ms.price_ids { - let feed_id = PriceId::new(*feed_id); - let pyth_price = self.pyth_price_state.get_price(&feed_id); - let chain_price = self.chain_price_state.get_price(&feed_id); + // Check each feed until we find one that needs updating + for feed_id in price_ids.iter() { + let pyth_price_opt = self.pyth_price_state.get_price(&feed_id); + let chain_price_opt = self.chain_price_state.get_price(&sub_id, &feed_id); - if pyth_price.is_none() || chain_price.is_none() { + if pyth_price_opt.is_none() { + tracing::warn!("No Pyth price found for feed, skipping"); continue; } - feed_ids.push(feed_id); - } + let pyth_price = pyth_price_opt.as_ref().unwrap(); - if _needs_update && !feed_ids.is_empty() { - self.trigger_update(sub_id, feed_ids).await; + if needs_update( + pyth_price, + chain_price_opt.as_ref(), + ¶ms.update_criteria, + ) { + // If any feed needs updating, trigger update for all feeds in this subscription + // and move on to next subscription + self.trigger_update(sub_id, price_ids.clone()).await?; + break; + } } } + Ok(()) } - async fn trigger_update(&self, subscription_id: SubscriptionId, price_ids: Vec) { + async fn trigger_update( + &self, + subscription_id: SubscriptionId, + price_ids: Vec, + ) -> Result<()> { tracing::info!( service = self.name, subscription_id = subscription_id.to_string(), @@ -83,7 +109,7 @@ impl ControllerService { "Triggering price update" ); - let _request = PushRequest { + let request = PushRequest { subscription_id, price_ids, }; @@ -93,6 +119,10 @@ impl ControllerService { "Would push update for subscription {}", subscription_id ); + + self.price_pusher_tx.send(request).await?; + + Ok(()) } } @@ -108,7 +138,13 @@ impl Service for ControllerService { loop { tokio::select! { _ = interval.tick() => { - self.perform_update().await; + if let Err(e) = self.perform_update().await { + tracing::error!( + service = self.name, + error = %e, + "Failed to perform price update" + ); + } } _ = stop_rx.changed() => { if *stop_rx.borrow() { @@ -125,3 +161,595 @@ impl Service for ControllerService { Ok(()) } } + +/// Determines if an on-chain price update is needed based on the latest Pyth price, +/// the current on-chain price (if available), and the subscription's update criteria. +#[tracing::instrument()] +fn needs_update( + pyth_price: &Price, + chain_price_opt: Option<&Price>, + update_criteria: &UpdateCriteria, +) -> bool { + // If there's no price currently on the chain for this feed, an update is always needed. + let chain_price = match chain_price_opt { + None => { + tracing::debug!("Update criteria met: No chain price available."); + return true; + } + Some(cp) => cp, + }; + + // 1. Heartbeat Check: + // Updates if `update_on_heartbeat` is enabled and the Pyth price is newer than or equal to + // the chain price plus `heartbeat_seconds`. + if update_criteria.update_on_heartbeat { + if pyth_price.publish_time + >= chain_price.publish_time + (update_criteria.heartbeat_seconds as i64) + { + tracing::debug!( + "Heartbeat criteria met: Pyth price is sufficiently newer or same age with met delta." + ); + return true; + } + } + + // 2. Deviation Check: + // If `update_on_deviation` is enabled, checks if the Pyth price has deviated from the chain price + // by more than `deviation_threshold_bps`. + // Example: If chain_price is 100 and deviation_threshold_bps is 50 (0.5%), + // then threshold_value = (100 * 50) / 10000 = 0.5 + // This means a price difference of more than 0.5 would trigger an update + if update_criteria.update_on_deviation { + // Critical assumption: The `expo` fields of `pyth_price` and `chain_price` are identical, + // since we directly compare the `price` fields. + if chain_price.price == 0 { + if pyth_price.price != 0 { + tracing::debug!( + "Deviation criteria met: Chain price is 0, Pyth price is non-zero." + ); + return true; + } + } else { + let price_diff = pyth_price.price.abs_diff(chain_price.price); + let threshold_val = (chain_price.price.abs() as u64 + * update_criteria.deviation_threshold_bps as u64) + / 10000; + + if price_diff > threshold_val { + tracing::debug!( + abs_price_diff = price_diff, + threshold = threshold_val, + "Deviation criteria met: Price difference exceeds threshold." + ); + return true; + } + } + } + + // If neither heartbeat nor deviation criteria were met. + tracing::debug!("No update criteria met."); + false +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::adapters::ethereum::{SubscriptionParams, UpdateCriteria}; + use crate::state::{ChainPriceState, PythPriceState, SubscriptionState}; + use ethers::types::U256; + use pyth_sdk::{Price, PriceIdentifier}; + use std::collections::HashMap; + use std::time::Duration; + use tokio::sync::mpsc; + + /// Helper to create a mock Price + fn mock_price(price: i64, conf: u64, expo: i32, publish_time: i64) -> Price { + Price { + price, + conf, + expo, + publish_time, + } + } + + /// Helper to create mock UpdateCriteria + fn mock_criteria( + update_on_heartbeat: bool, + heartbeat_seconds: u32, + update_on_deviation: bool, + deviation_threshold_bps: u32, + ) -> UpdateCriteria { + UpdateCriteria { + update_on_heartbeat, + heartbeat_seconds, + update_on_deviation, + deviation_threshold_bps, + } + } + + /// Helper function to create a default SubscriptionParams for tests + fn mock_subscription_params( + price_ids_bytes: Vec<[u8; 32]>, + update_criteria: UpdateCriteria, + ) -> SubscriptionParams { + SubscriptionParams { + price_ids: price_ids_bytes, + update_criteria, + // Initialize all fields of SubscriptionParams + reader_whitelist: vec![], // Default to empty list + whitelist_enabled: false, // Default to false + is_active: true, // Default to true, important for processing in tests + is_permanent: false, // Default to false + } + } + + struct TestControllerSetup { + controller: ControllerService, + pyth_state: Arc, + chain_state: Arc, + sub_state: Arc, + push_request_rx: mpsc::Receiver, + } + + fn setup_test_controller() -> TestControllerSetup { + let sub_state = Arc::new(SubscriptionState::new()); + let pyth_state = Arc::new(PythPriceState::new()); + let chain_state = Arc::new(ChainPriceState::new()); + let (pusher_tx, push_request_rx) = mpsc::channel(10); // Small buffer for tests + + let controller = ControllerService::new( + "test_chain".to_string(), + Duration::from_millis(100), // Interval doesn't really matter for perform_update direct call + sub_state.clone(), + pyth_state.clone(), + chain_state.clone(), + pusher_tx, + ); + + TestControllerSetup { + controller, + pyth_state, + chain_state, + sub_state, + push_request_rx, + } + } + + #[tokio::test] + async fn test_perform_update_no_subscriptions() { + let TestControllerSetup { + controller, + mut push_request_rx, + .. + } = setup_test_controller(); + + controller + .perform_update() + .await + .expect("perform_update should not fail"); + + // Expect no requests to be sent + assert!( + push_request_rx.try_recv().is_err(), + "Should be no push requests if no subscriptions" + ); + } + + #[tokio::test] + async fn test_perform_update_subscription_no_feed_ids() { + let TestControllerSetup { + controller, + sub_state, + mut push_request_rx, + .. + } = setup_test_controller(); + + let sub_id = U256::from(1); + let criteria = mock_criteria(true, 60, true, 100); + let params = mock_subscription_params(vec![], criteria); // Empty feed_ids + + let mut subs_map = HashMap::new(); + subs_map.insert(sub_id, params); + sub_state.update_subscriptions(subs_map); + + controller + .perform_update() + .await + .expect("perform_update should not fail"); + assert!( + push_request_rx.try_recv().is_err(), + "Should be no push requests if subscription has no feed IDs" + ); + } + + #[tokio::test] + async fn test_perform_update_single_sub_single_feed_update_needed_heartbeat() { + let TestControllerSetup { + controller, + sub_state, + pyth_state, + chain_state, + mut push_request_rx, + } = setup_test_controller(); + + let sub_id = U256::from(123); + let feed_id_bytes = [1u8; 32]; + let feed_id = PriceIdentifier::new(feed_id_bytes); + let criteria = mock_criteria(true, 60, false, 0); // Heartbeat only + let params = mock_subscription_params(vec![feed_id_bytes], criteria); + + sub_state.update_subscriptions(HashMap::from([(sub_id, params)])); + pyth_state.update_price(feed_id, mock_price(100, 10, -2, 1000)); // Pyth price @ t=1000 + chain_state.update_price(&sub_id, feed_id, mock_price(100, 10, -2, 900)); // Chain price @ t=900 (1000 >= 900 + 60) + + controller + .perform_update() + .await + .expect("perform_update failed"); + + let request = push_request_rx + .recv() + .await + .expect("Should receive a PushRequest"); + assert_eq!(request.subscription_id, sub_id); + assert_eq!(request.price_ids.len(), 1); + assert_eq!(request.price_ids[0], feed_id); + assert!( + push_request_rx.try_recv().is_err(), + "Should be no more requests" + ); + } + + #[tokio::test] + async fn test_perform_update_single_sub_single_feed_update_needed_deviation() { + let TestControllerSetup { + controller, + sub_state, + pyth_state, + chain_state, + mut push_request_rx, + } = setup_test_controller(); + + let sub_id = U256::from(456); + let feed_id_bytes = [2u8; 32]; + let feed_id = PriceIdentifier::new(feed_id_bytes); + let criteria = mock_criteria(false, 0, true, 100); // Deviation only, 100 bps = 1% + let params = mock_subscription_params(vec![feed_id_bytes], criteria); + + sub_state.update_subscriptions(HashMap::from([(sub_id, params)])); + pyth_state.update_price(feed_id, mock_price(102, 10, -2, 1000)); // Pyth price 102 + chain_state.update_price(&sub_id, feed_id, mock_price(100, 10, -2, 1000)); // Chain price 100. Diff 2 > (100*100)/10000=1 + + controller + .perform_update() + .await + .expect("perform_update failed"); + + let request = push_request_rx + .recv() + .await + .expect("Should receive a PushRequest"); + assert_eq!(request.subscription_id, sub_id); + assert_eq!(request.price_ids.len(), 1); + assert_eq!(request.price_ids[0], feed_id); + assert!( + push_request_rx.try_recv().is_err(), + "Should be no more requests" + ); + } + + #[tokio::test] + async fn test_perform_update_single_sub_single_feed_no_update_needed() { + let TestControllerSetup { + controller, + sub_state, + pyth_state, + chain_state, + mut push_request_rx, + } = setup_test_controller(); + + let sub_id = U256::from(789); + let feed_id_bytes = [3u8; 32]; + let feed_id = PriceIdentifier::new(feed_id_bytes); + // Criteria: heartbeat 60s, deviation 100bps (1%) + let criteria = mock_criteria(true, 60, true, 100); + let params = mock_subscription_params(vec![feed_id_bytes], criteria); + + sub_state.update_subscriptions(HashMap::from([(sub_id, params)])); + // Pyth price: t=950, val=100. Chain price: t=900, val=100 + // Heartbeat not met: 950 is not >= 900 + 60 (960) + // Deviation not met: 100 vs 100 is 0 diff. + pyth_state.update_price(feed_id, mock_price(100, 10, -2, 950)); + chain_state.update_price(&sub_id, feed_id, mock_price(100, 10, -2, 900)); + + controller + .perform_update() + .await + .expect("perform_update failed"); + assert!( + push_request_rx.try_recv().is_err(), + "Should be no push requests if no update needed" + ); + } + + #[tokio::test] + async fn test_perform_update_single_sub_multiple_feeds_mixed_updates() { + let TestControllerSetup { + controller, + sub_state, + pyth_state, + chain_state, + mut push_request_rx, + } = setup_test_controller(); + + let sub_id = U256::from(111); + let feed1_bytes = [11u8; 32]; + let feed1_id = PriceIdentifier::new(feed1_bytes); + let feed2_bytes = [22u8; 32]; + let feed2_id = PriceIdentifier::new(feed2_bytes); + let feed3_bytes = [33u8; 32]; + let feed3_id = PriceIdentifier::new(feed3_bytes); + + let criteria = mock_criteria(true, 60, true, 100); // Heartbeat 60s, Dev 1% + let params = + mock_subscription_params(vec![feed1_bytes, feed2_bytes, feed3_bytes], criteria); + sub_state.update_subscriptions(HashMap::from([(sub_id, params)])); + + // Feed 1: Needs update (heartbeat) + pyth_state.update_price(feed1_id, mock_price(100, 10, -2, 1000)); + chain_state.update_price(&sub_id, feed1_id, mock_price(100, 10, -2, 900)); // 1000 >= 900 + 60 + + // Feed 2: Needs update (deviation) + pyth_state.update_price(feed2_id, mock_price(102, 10, -2, 950)); + chain_state.update_price(&sub_id, feed2_id, mock_price(100, 10, -2, 950)); // Diff 2 > 1 (1% of 100) + + // Feed 3: No update needed + pyth_state.update_price(feed3_id, mock_price(100, 10, -2, 950)); + chain_state.update_price(&sub_id, feed3_id, mock_price(100, 10, -2, 900)); // No criteria met + + controller + .perform_update() + .await + .expect("perform_update failed"); + + let request = push_request_rx + .recv() + .await + .expect("Should receive a PushRequest"); + assert_eq!(request.subscription_id, sub_id); + assert_eq!( + request.price_ids.len(), + 3, + "Expected all 3 feeds to be in the request as one or more needed an update" + ); + assert!(request.price_ids.contains(&feed1_id)); + assert!(request.price_ids.contains(&feed2_id)); + assert!(request.price_ids.contains(&feed3_id)); // Feed3 should now be included + assert!( + push_request_rx.try_recv().is_err(), + "Should be no more requests" + ); + } + + #[tokio::test] + async fn test_perform_update_no_pyth_price_for_feed() { + let TestControllerSetup { + controller, + sub_state, + pyth_state, + chain_state, + mut push_request_rx, + } = setup_test_controller(); + + let sub_id = U256::from(222); + let feed1_bytes = [44u8; 32]; // Pyth price will be missing for this one + let feed1_id = PriceIdentifier::new(feed1_bytes); + let feed2_bytes = [55u8; 32]; // This one will have Pyth price and need update + let feed2_id = PriceIdentifier::new(feed2_bytes); + + let criteria = mock_criteria(true, 60, false, 0); + let params = mock_subscription_params(vec![feed1_bytes, feed2_bytes], criteria); + sub_state.update_subscriptions(HashMap::from([(sub_id, params)])); + + // No Pyth price for feed1_id + // Pyth price for feed2_id, needs update by heartbeat + pyth_state.update_price(feed2_id, mock_price(100, 10, -2, 1000)); + chain_state.update_price(&sub_id, feed2_id, mock_price(100, 10, -2, 900)); + // Optionally set chain price for feed1 too, though it won't matter without Pyth price + chain_state.update_price(&sub_id, feed1_id, mock_price(200, 10, -2, 900)); + + controller + .perform_update() + .await + .expect("perform_update failed"); + + let request = push_request_rx + .recv() + .await + .expect("Should receive one PushRequest for feed2"); + assert_eq!(request.subscription_id, sub_id); + assert_eq!( + request.price_ids.len(), + 2, // Expecting both feeds from the subscription + "Expected all feeds from the subscription to be in the request as one needed an update" + ); + assert!(request.price_ids.contains(&feed1_id)); // The one with no pyth price initially + assert!(request.price_ids.contains(&feed2_id)); // The one that triggered the update + assert!( + push_request_rx.try_recv().is_err(), + "Should be no more requests" + ); + } + + #[tokio::test] + async fn test_perform_update_error_on_send() { + let TestControllerSetup { + controller, + sub_state, + pyth_state, + chain_state, + mut push_request_rx, + } = setup_test_controller(); + + let sub_id = U256::from(333); + let feed_id_bytes = [66u8; 32]; + let feed_id = PriceIdentifier::new(feed_id_bytes); + let criteria = mock_criteria(true, 60, false, 0); + let params = mock_subscription_params(vec![feed_id_bytes], criteria); + + sub_state.update_subscriptions(HashMap::from([(sub_id, params)])); + pyth_state.update_price(feed_id, mock_price(100, 10, -2, 1000)); + chain_state.update_price(&sub_id, feed_id, mock_price(100, 10, -2, 900)); + + // Close the receiver end to simulate a send error + push_request_rx.close(); + + let result = controller.perform_update().await; + assert!( + result.is_err(), + "perform_update should return an error if send fails" + ); + // Further assertions on the specific error type could be added if desired + // e.g., assert_matches!(result.unwrap_err().downcast_ref::>(), Some(_)); + } + + // ================================ + // UNIT TESTS FOR `needs_update` + // ================================ + + #[test] + fn test_needs_update_no_chain_price() { + let pyth_price = mock_price(100, 10, -2, 1000); + let criteria = mock_criteria(true, 60, true, 100); + assert!(needs_update(&pyth_price, None, &criteria)); + } + + #[test] + fn test_needs_update_heartbeat_triggered() { + let pyth_price = mock_price(100, 10, -2, 1000); + let chain_price = mock_price(100, 10, -2, 900); + let criteria = mock_criteria(true, 60, false, 0); + assert!(needs_update(&pyth_price, Some(&chain_price), &criteria)); + } + + #[test] + fn test_needs_update_heartbeat_not_triggered_too_soon() { + let pyth_price = mock_price(100, 10, -2, 950); + let chain_price = mock_price(100, 10, -2, 900); + let criteria = mock_criteria(true, 60, false, 0); + assert!(!needs_update(&pyth_price, Some(&chain_price), &criteria)); + } + + #[test] + fn test_needs_update_heartbeat_exact_time_triggered() { + let pyth_price = mock_price(100, 10, -2, 960); + let chain_price = mock_price(100, 10, -2, 900); + let criteria = mock_criteria(true, 60, false, 0); + assert!(needs_update(&pyth_price, Some(&chain_price), &criteria)); + } + + #[test] + fn test_needs_update_heartbeat_disabled() { + let pyth_price = mock_price(100, 10, -2, 1000); + let chain_price = mock_price(100, 10, -2, 900); + let criteria = mock_criteria(false, 60, false, 0); + assert!(!needs_update(&pyth_price, Some(&chain_price), &criteria)); + } + + #[test] + fn test_needs_update_deviation_triggered_positive_diff() { + let pyth_price = mock_price(105, 10, -2, 1000); + let chain_price = mock_price(100, 10, -2, 1000); + let criteria = mock_criteria(false, 0, true, 100); + assert!(needs_update(&pyth_price, Some(&chain_price), &criteria)); + } + + #[test] + fn test_needs_update_deviation_triggered_negative_diff() { + let pyth_price = mock_price(95, 10, -2, 1000); + let chain_price = mock_price(100, 10, -2, 1000); + let criteria = mock_criteria(false, 0, true, 100); + assert!(needs_update(&pyth_price, Some(&chain_price), &criteria)); + } + + #[test] + fn test_needs_update_deviation_not_triggered_within_threshold() { + let pyth_price = mock_price(100, 10, -2, 1000); + let chain_price = mock_price(100, 10, -2, 1000); + let criteria = mock_criteria(false, 0, true, 100); + assert!(!needs_update(&pyth_price, Some(&chain_price), &criteria)); + + let pyth_price_slight_dev = mock_price(1005, 10, -3, 1000); + let chain_price_slight_dev = mock_price(1000, 10, -3, 1000); + let criteria_5_percent = mock_criteria(false, 0, true, 500); + assert!(!needs_update( + &pyth_price_slight_dev, + Some(&chain_price_slight_dev), + &criteria_5_percent + )); + } + + #[test] + fn test_needs_update_deviation_exact_threshold_not_triggered() { + let pyth_price = mock_price(101, 10, -2, 1000); + let chain_price = mock_price(100, 10, -2, 1000); + let criteria = mock_criteria(false, 0, true, 100); + assert!(!needs_update(&pyth_price, Some(&chain_price), &criteria)); + } + + #[test] + fn test_needs_update_deviation_disabled() { + let pyth_price = mock_price(150, 10, -2, 1000); + let chain_price = mock_price(100, 10, -2, 1000); + let criteria = mock_criteria(false, 0, false, 100); + assert!(!needs_update(&pyth_price, Some(&chain_price), &criteria)); + } + + #[test] + fn test_needs_update_deviation_chain_price_zero_pyth_nonzero() { + let pyth_price = mock_price(10, 10, -2, 1000); + let chain_price = mock_price(0, 0, -2, 1000); + let criteria = mock_criteria(false, 0, true, 1000); + assert!(needs_update(&pyth_price, Some(&chain_price), &criteria)); + } + + #[test] + fn test_needs_update_deviation_both_prices_zero() { + let pyth_price = mock_price(0, 0, -2, 1000); + let chain_price = mock_price(0, 0, -2, 1000); + let criteria = mock_criteria(false, 0, true, 100); + assert!(!needs_update(&pyth_price, Some(&chain_price), &criteria)); + } + + #[test] + fn test_needs_update_heartbeat_and_deviation_triggered() { + let pyth_price = mock_price(105, 10, -2, 1000); + let chain_price = mock_price(100, 10, -2, 900); + let criteria = mock_criteria(true, 60, true, 100); + assert!(needs_update(&pyth_price, Some(&chain_price), &criteria)); + } + + #[test] + fn test_needs_update_heartbeat_triggered_deviation_not_due_to_time() { + let pyth_price = mock_price(100, 10, -2, 1000); + let chain_price = mock_price(100, 10, -2, 900); + let criteria = mock_criteria(true, 60, true, 100); + assert!(needs_update(&pyth_price, Some(&chain_price), &criteria)); + } + + #[test] + fn test_needs_update_deviation_triggered_heartbeat_not_due_to_price() { + let pyth_price = mock_price(105, 10, -2, 950); + let chain_price = mock_price(100, 10, -2, 900); + let criteria = mock_criteria(true, 60, true, 100); + assert!(needs_update(&pyth_price, Some(&chain_price), &criteria)); + } + + #[test] + fn test_needs_update_neither_triggered() { + let pyth_price = mock_price(100, 10, -2, 950); + let chain_price = mock_price(100, 10, -2, 900); + let criteria = mock_criteria(true, 60, true, 100); + assert!(!needs_update(&pyth_price, Some(&chain_price), &criteria)); + } +} diff --git a/apps/argus/src/services/price_pusher_service.rs b/apps/argus/src/services/price_pusher_service.rs index b116ea005d..9d2bfbd5f7 100644 --- a/apps/argus/src/services/price_pusher_service.rs +++ b/apps/argus/src/services/price_pusher_service.rs @@ -56,12 +56,12 @@ impl PricePusherService { #[tracing::instrument( skip(self), fields( - name = "handle_request", task = self.name, + chain_name = self.chain_name, subscription_id = request.subscription_id.to_string() ) )] - async fn handle_request(&self, request: PushRequest) { + async fn handle_push_request(&self, request: PushRequest) { let price_ids = request.price_ids.clone(); match self.pyth_price_client.get_latest_prices(&price_ids).await { @@ -73,16 +73,12 @@ impl PricePusherService { { Ok(tx_hash) => { tracing::info!( - service = self.name, - subscription_id = request.subscription_id.to_string(), tx_hash = tx_hash.to_string(), "Successfully pushed price updates" ); } Err(e) => { tracing::error!( - service = self.name, - subscription_id = request.subscription_id.to_string(), error = %e, "Failed to push price updates" ); @@ -91,8 +87,6 @@ impl PricePusherService { } Err(e) => { tracing::error!( - service = self.name, - subscription_id = request.subscription_id.to_string(), error = %e, "Failed to get Pyth price update data" ); @@ -118,7 +112,7 @@ impl Service for PricePusherService { loop { tokio::select! { Some(request) = receiver.recv() => { - self.handle_request(request).await; + self.handle_push_request(request).await; } _ = exit_rx.changed() => { if *exit_rx.borrow() { diff --git a/apps/argus/src/services/subscription_service.rs b/apps/argus/src/services/subscription_service.rs index 8211d1f398..33efcc6c61 100644 --- a/apps/argus/src/services/subscription_service.rs +++ b/apps/argus/src/services/subscription_service.rs @@ -60,7 +60,6 @@ impl SubscriptionService { let feed_ids = self.subscription_state.get_feed_ids(); self.pyth_price_state.update_feed_ids(feed_ids.clone()); - self.chain_price_state.update_feed_ids(feed_ids); Ok(()) } diff --git a/apps/argus/src/state.rs b/apps/argus/src/state.rs index 6ed2ca87e6..b477b757e2 100644 --- a/apps/argus/src/state.rs +++ b/apps/argus/src/state.rs @@ -41,12 +41,8 @@ impl SubscriptionState { subscriptions: DashMap::new(), } } - - pub fn get_subscriptions(&self) -> HashMap { - self.subscriptions - .iter() - .map(|r| (*r.key(), r.value().clone())) - .collect() + pub fn get_subscriptions(&self) -> &DashMap { + &self.subscriptions } pub fn get_subscription(&self, id: &SubscriptionId) -> Option { @@ -115,40 +111,46 @@ impl PythPriceState { /// Stores the latest on-chain prices for a given set of price feeds. /// Updated by the ChainPriceService. pub struct ChainPriceState { - prices: DashMap, - feed_ids: DashMap, + subscription_feed_prices: DashMap>, } impl ChainPriceState { pub fn new() -> Self { Self { - prices: DashMap::new(), - feed_ids: DashMap::new(), + subscription_feed_prices: DashMap::new(), } } - pub fn get_price(&self, feed_id: &PriceId) -> Option { - self.prices.get(feed_id).map(|r| r.value().clone()) - } - - pub fn update_price(&self, feed_id: PriceId, price: Price) { - self.prices.insert(feed_id, price); - } - - pub fn update_prices(&self, prices: HashMap) { + pub fn get_price(&self, subscription_id: &SubscriptionId, feed_id: &PriceId) -> Option { + Some( + self.subscription_feed_prices + .get(subscription_id)? + .get(feed_id)? + .value() + .clone(), + ) + } + pub fn update_price(&self, subscription_id: &SubscriptionId, feed_id: PriceId, price: Price) { + let subscription_feeds = self + .subscription_feed_prices + .entry(subscription_id.clone()) + .or_insert_with(DashMap::new); + subscription_feeds.insert(feed_id, price); + } + + pub fn update_prices(&self, subscription_id: SubscriptionId, prices: HashMap) { + let subscription_map = self + .subscription_feed_prices + .entry(subscription_id) + .or_insert_with(DashMap::new); for (feed_id, price) in prices { - self.prices.insert(feed_id, price); - } - } - - pub fn update_feed_ids(&self, feed_ids: HashSet) { - self.feed_ids.clear(); - for feed_id in feed_ids { - self.feed_ids.insert(feed_id, ()); + subscription_map.insert(feed_id, price); } } - pub fn get_feed_ids(&self) -> HashSet { - self.feed_ids.iter().map(|r| *r.key()).collect() + pub fn get_subscription_feed_prices( + &self, + ) -> &DashMap> { + &self.subscription_feed_prices } }