diff --git a/allways/chain_providers/base.py b/allways/chain_providers/base.py index 1f72c23..0f12bed 100644 --- a/allways/chain_providers/base.py +++ b/allways/chain_providers/base.py @@ -121,6 +121,11 @@ def verify_transaction( return tx_info + @abstractmethod + def get_current_block_height(self) -> Optional[int]: + """Chain tip block height. None on transient backend failure.""" + ... + @abstractmethod def get_balance(self, address: str) -> int: ... diff --git a/allways/chain_providers/bitcoin.py b/allways/chain_providers/bitcoin.py index 84dc396..07689da 100644 --- a/allways/chain_providers/bitcoin.py +++ b/allways/chain_providers/bitcoin.py @@ -322,6 +322,25 @@ def api_verify_transaction( bt.logging.error(f'Esplora tx lookup failed for {tx_hash}: {e}') return None + def get_current_block_height(self) -> Optional[int]: + """Bitcoin chain tip via RPC with Esplora fallback. None on failure.""" + if self.mode == 'node': + result = self.rpc_call('getblockcount', []) + if result is not None: + try: + return int(result) + except (TypeError, ValueError): + pass + try: + resp = self.btc_api_get('/blocks/tip/height', timeout=10) + if resp.ok: + return int(resp.text.strip()) + except (requests.ConnectionError, requests.Timeout) as e: + bt.logging.debug(f'BTC get_current_block_height: Esplora unreachable ({e})') + except Exception as e: + bt.logging.debug(f'BTC get_current_block_height failed: {e}') + return None + def get_balance(self, address: str) -> int: """Get balance for a Bitcoin address in satoshis via RPC with Esplora fallback.""" result = self.rpc_call('getreceivedbyaddress', [address, 0]) diff --git a/allways/chain_providers/subtensor.py b/allways/chain_providers/subtensor.py index 7814a9e..aa915a1 100644 --- a/allways/chain_providers/subtensor.py +++ b/allways/chain_providers/subtensor.py @@ -272,6 +272,13 @@ def match_transfer(ext, tx_hash: str, is_raw: bool) -> Optional[Tuple[str, int, return dest, amount, sender + def get_current_block_height(self) -> Optional[int]: + try: + return int(self.subtensor.get_current_block()) + except Exception as e: + bt.logging.debug(f'TAO get_current_block_height failed: {e}') + return None + def get_balance(self, address: str) -> int: """Get balance for a TAO address in rao.""" try: diff --git a/allways/validator/chain_verification.py b/allways/validator/chain_verification.py index 4555aad..0b205fe 100644 --- a/allways/validator/chain_verification.py +++ b/allways/validator/chain_verification.py @@ -17,6 +17,10 @@ class SwapVerifier: Rate and miner source address are stored on the swap struct at initiation, so verification is self-contained — no commitment lookup needed. + + Dest-tx replay defense: snapshots the dest chain's tip on first sighting + of a swap and rejects later dest txs whose block predates the snapshot — + a validator-side stand-in for a contract-level ``used_to_tx`` mirror. """ def __init__( @@ -30,6 +34,35 @@ def __init__( self.metagraph = metagraph self.last_logged_confs: Dict[str, int] = {} # swap_id:chain -> confs self.source_verified_ids: Set[int] = set() # source tx is final once confirmed + self.dest_tip_at_init: Dict[int, int] = {} # swap_id -> dest tip at first sighting (non-TAO only) + + def observe_initiation(self, swap: Swap) -> None: + """Snapshot the dest chain's tip on first sighting of a non-TAO swap. + Idempotent; fails open with a one-time warning on RPC error.""" + if swap.to_chain == 'tao' or swap.id in self.dest_tip_at_init: + return + provider = self.providers.get(swap.to_chain) + if provider is None: + return + # Broad except (vs verify_tx's re-raise of ProviderUnreachableError): + # this runs inside a forward-loop iteration and must not break it. + try: + tip = provider.get_current_block_height() + except Exception: + tip = None + if tip and tip > 0: + self.dest_tip_at_init[swap.id] = tip + else: + log_on_change( + f'snapshot_unavailable:{swap.id}', + True, + f'{self._label(swap)}: dest-tip snapshot failed on {swap.to_chain} — replay defense off until retry', + ) + + def prune_to_active(self, active_ids: Set[int]) -> None: + """Drop per-swap state for swaps no longer being tracked.""" + self.dest_tip_at_init = {sid: v for sid, v in self.dest_tip_at_init.items() if sid in active_ids} + self.source_verified_ids &= active_ids def _label(self, swap: Swap) -> str: return _swap_label(swap, self.metagraph) @@ -43,7 +76,7 @@ def verify_tx( expected_amount: int, block_hint: int = 0, expected_sender: str = '', - ) -> bool: + ) -> Optional[TransactionInfo]: """Verify a confirmed transaction on a specific chain. Defers tx lookup, amount, and sender checks to the provider's @@ -54,11 +87,11 @@ def verify_tx( provider = self.providers.get(chain) if not provider: bt.logging.warning(f'{self._label(swap)}: no provider for chain {chain}') - return False + return None if not tx_hash: bt.logging.debug(f'{self._label(swap)}: empty tx_hash for {chain}, skipping verification') - return False + return None try: tx_info = provider.verify_transaction( @@ -73,16 +106,31 @@ def verify_tx( f'{self._label(swap)}: verify_transaction returned None on {chain} ' f'(tx={tx_hash[:16]}... block_hint={block_hint})' ) - return False + return None if not tx_info.confirmed: self.log_confs_progress(swap, chain, tx_hash, tx_info, expected_recipient, expected_amount) - return False - return True + return None + return tx_info except ProviderUnreachableError: raise except Exception as e: bt.logging.error(f'{self._label(swap)}: verification error on {chain}: {e}') + return None + + def is_dest_tx_fresh(self, swap: Swap, dest_info: TransactionInfo) -> bool: + """Reject a dest tx mined before the swap was initiated (replay defense).""" + if dest_info.block_number is None: + return True + lower = swap.initiated_block if swap.to_chain == 'tao' else self.dest_tip_at_init.get(swap.id) + if lower is None: + return True # fail-open; observe_initiation already logged + if dest_info.block_number < lower: + bt.logging.warning( + f'{self._label(swap)}: dest tx at block {dest_info.block_number} < initiated {lower} — ' + f'rejecting as replay (tx={swap.to_tx_hash[:16]}...)' + ) return False + return True def log_confs_progress( self, @@ -124,7 +172,7 @@ async def verify_miner_fulfillment(self, swap: Swap) -> bool: if swap.id in self.source_verified_ids: source_ok = True else: - source_ok = await asyncio.to_thread( + source_info = await asyncio.to_thread( self.verify_tx, swap, swap.from_chain, @@ -133,10 +181,11 @@ async def verify_miner_fulfillment(self, swap: Swap) -> bool: swap.from_amount, swap.from_tx_block, ) + source_ok = source_info is not None if source_ok: self.source_verified_ids.add(swap.id) - dest_ok = await asyncio.to_thread( + dest_info = await asyncio.to_thread( self.verify_tx, swap, swap.to_chain, @@ -146,6 +195,7 @@ async def verify_miner_fulfillment(self, swap: Swap) -> bool: swap.to_tx_block, swap.miner_to_address, ) + dest_ok = dest_info is not None and self.is_dest_tx_fresh(swap, dest_info) if source_ok != dest_ok: log_on_change( diff --git a/allways/validator/forward.py b/allways/validator/forward.py index 1fc7f47..8794574 100644 --- a/allways/validator/forward.py +++ b/allways/validator/forward.py @@ -86,6 +86,11 @@ async def forward(self: Validator) -> None: await tracker.poll() bt.logging.info('forward: tracker polled') + # Snapshot dest-chain tip on first sighting for the dest-tx replay defense. + for swap in tracker.active.values(): + verifier.observe_initiation(swap) + verifier.prune_to_active(set(tracker.active.keys())) + # Verify FULFILLED swaps end-to-end and vote confirm_swap. The returned # set is swap IDs where the provider was unreachable this cycle, so the # timeout phase knows to skip them (transient outage shouldn't slash). diff --git a/tests/test_chain_verification.py b/tests/test_chain_verification.py new file mode 100644 index 0000000..100de2b --- /dev/null +++ b/tests/test_chain_verification.py @@ -0,0 +1,142 @@ +"""Recency-based replay defense for miner-supplied dest tx hashes. + +Locks down the dest-chain tip snapshot taken at swap observation and the +single comparison used to reject a dest tx mined before its swap was +initiated. Closes the gap left by the contract enforcing ``used_from_tx`` +only on the source side. +""" + +from unittest.mock import MagicMock + +from allways.chain_providers.base import TransactionInfo +from allways.classes import Swap, SwapStatus +from allways.validator.chain_verification import SwapVerifier + + +def make_swap(swap_id: int = 1, to_chain: str = 'btc', initiated_block: int = 100) -> Swap: + return Swap( + id=swap_id, + user_hotkey='user', + miner_hotkey='miner', + from_chain='tao' if to_chain == 'btc' else 'btc', + to_chain=to_chain, + from_amount=1, + to_amount=1, + tao_amount=1, + user_from_address='from', + user_to_address='to', + miner_from_address='miner-from', + miner_to_address='miner-to', + rate='100', + to_tx_hash='dest-hash', + status=SwapStatus.FULFILLED, + initiated_block=initiated_block, + ) + + +def tx_at(block_number) -> TransactionInfo: + return TransactionInfo( + tx_hash='dest-hash', + confirmed=True, + sender='miner-to', + recipient='to', + amount=1, + block_number=block_number, + confirmations=10, + ) + + +class TestObserveInitiation: + def test_snapshots_observed_tip(self): + btc = MagicMock() + btc.get_current_block_height.return_value = 850_000 + v = SwapVerifier(chain_providers={'btc': btc}) + + v.observe_initiation(make_swap(swap_id=1, to_chain='btc')) + + assert v.dest_tip_at_init[1] == 850_000 + + def test_idempotent(self): + btc = MagicMock() + btc.get_current_block_height.return_value = 850_000 + v = SwapVerifier(chain_providers={'btc': btc}) + + v.observe_initiation(make_swap(swap_id=1, to_chain='btc')) + v.observe_initiation(make_swap(swap_id=1, to_chain='btc')) + + btc.get_current_block_height.assert_called_once() + + def test_tao_dest_is_noop(self): + btc = MagicMock() + v = SwapVerifier(chain_providers={'btc': btc}) + + v.observe_initiation(make_swap(swap_id=5, to_chain='tao')) + + assert 5 not in v.dest_tip_at_init + btc.get_current_block_height.assert_not_called() + + def test_failed_snapshot_leaves_no_entry_so_retry_is_possible(self): + btc = MagicMock() + btc.get_current_block_height.return_value = None + v = SwapVerifier(chain_providers={'btc': btc}) + + v.observe_initiation(make_swap(swap_id=7, to_chain='btc')) + + assert 7 not in v.dest_tip_at_init + + # Next forward step the RPC recovers — snapshot is captured. + btc.get_current_block_height.return_value = 850_500 + v.observe_initiation(make_swap(swap_id=7, to_chain='btc')) + + assert v.dest_tip_at_init[7] == 850_500 + + def test_rpc_raises_treated_as_failure(self): + btc = MagicMock() + btc.get_current_block_height.side_effect = RuntimeError('boom') + v = SwapVerifier(chain_providers={'btc': btc}) + + v.observe_initiation(make_swap(swap_id=9, to_chain='btc')) + + assert 9 not in v.dest_tip_at_init + + +class TestIsDestTxFresh: + def test_tao_accepts_initiation_block_and_rejects_earlier(self): + v = SwapVerifier(chain_providers={}) + swap = make_swap(to_chain='tao', initiated_block=100) + assert v.is_dest_tx_fresh(swap, tx_at(100)) is True + assert v.is_dest_tx_fresh(swap, tx_at(99)) is False + + def test_btc_accepts_at_snapshot_rejects_older_replay(self): + v = SwapVerifier(chain_providers={}) + v.dest_tip_at_init[1] = 850_000 + swap = make_swap(swap_id=1, to_chain='btc') + + assert v.is_dest_tx_fresh(swap, tx_at(850_000)) is True + assert v.is_dest_tx_fresh(swap, tx_at(849_500)) is False + + def test_failopen_when_no_snapshot(self): + v = SwapVerifier(chain_providers={}) + swap = make_swap(swap_id=1, to_chain='btc') + # Even an obviously old tx is accepted — defense disabled for this swap. + assert v.is_dest_tx_fresh(swap, tx_at(1)) is True + + def test_missing_block_number_passes(self): + v = SwapVerifier(chain_providers={}) + v.dest_tip_at_init[1] = 850_000 + swap = make_swap(swap_id=1, to_chain='btc') + info = tx_at(850_000) + info.block_number = None + assert v.is_dest_tx_fresh(swap, info) is True + + +class TestPruneToActive: + def test_drops_inactive_swaps(self): + v = SwapVerifier(chain_providers={}) + v.dest_tip_at_init = {1: 100, 2: 200, 3: 300} + v.source_verified_ids = {1, 2, 3} + + v.prune_to_active({2}) + + assert v.dest_tip_at_init == {2: 200} + assert v.source_verified_ids == {2}