Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions allways/chain_providers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,11 @@ def verify_transaction(

return tx_info

@abstractmethod
def get_current_block_height(self) -> Optional[int]:
"""Chain tip block height. None on transient backend failure."""
...

@abstractmethod
def get_balance(self, address: str) -> int: ...

Expand Down
19 changes: 19 additions & 0 deletions allways/chain_providers/bitcoin.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,25 @@ def api_verify_transaction(
bt.logging.error(f'Esplora tx lookup failed for {tx_hash}: {e}')
return None

def get_current_block_height(self) -> Optional[int]:
"""Bitcoin chain tip via RPC with Esplora fallback. None on failure."""
if self.mode == 'node':
result = self.rpc_call('getblockcount', [])
if result is not None:
try:
return int(result)
except (TypeError, ValueError):
pass
try:
resp = self.btc_api_get('/blocks/tip/height', timeout=10)
if resp.ok:
return int(resp.text.strip())
except (requests.ConnectionError, requests.Timeout) as e:
bt.logging.debug(f'BTC get_current_block_height: Esplora unreachable ({e})')
except Exception as e:
bt.logging.debug(f'BTC get_current_block_height failed: {e}')
return None

def get_balance(self, address: str) -> int:
"""Get balance for a Bitcoin address in satoshis via RPC with Esplora fallback."""
result = self.rpc_call('getreceivedbyaddress', [address, 0])
Expand Down
7 changes: 7 additions & 0 deletions allways/chain_providers/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,13 @@ def match_transfer(ext, tx_hash: str, is_raw: bool) -> Optional[Tuple[str, int,

return dest, amount, sender

def get_current_block_height(self) -> Optional[int]:
try:
return int(self.subtensor.get_current_block())
except Exception as e:
bt.logging.debug(f'TAO get_current_block_height failed: {e}')
return None

def get_balance(self, address: str) -> int:
"""Get balance for a TAO address in rao."""
try:
Expand Down
66 changes: 58 additions & 8 deletions allways/validator/chain_verification.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ class SwapVerifier:

Rate and miner source address are stored on the swap struct at initiation,
so verification is self-contained — no commitment lookup needed.

Dest-tx replay defense: snapshots the dest chain's tip on first sighting
of a swap and rejects later dest txs whose block predates the snapshot —
a validator-side stand-in for a contract-level ``used_to_tx`` mirror.
"""

def __init__(
Expand All @@ -30,6 +34,35 @@ def __init__(
self.metagraph = metagraph
self.last_logged_confs: Dict[str, int] = {} # swap_id:chain -> confs
self.source_verified_ids: Set[int] = set() # source tx is final once confirmed
self.dest_tip_at_init: Dict[int, int] = {} # swap_id -> dest tip at first sighting (non-TAO only)

def observe_initiation(self, swap: Swap) -> None:
"""Snapshot the dest chain's tip on first sighting of a non-TAO swap.
Idempotent; fails open with a one-time warning on RPC error."""
if swap.to_chain == 'tao' or swap.id in self.dest_tip_at_init:
return
provider = self.providers.get(swap.to_chain)
if provider is None:
return
# Broad except (vs verify_tx's re-raise of ProviderUnreachableError):
# this runs inside a forward-loop iteration and must not break it.
try:
tip = provider.get_current_block_height()
except Exception:
tip = None
if tip and tip > 0:
self.dest_tip_at_init[swap.id] = tip
else:
log_on_change(
f'snapshot_unavailable:{swap.id}',
True,
f'{self._label(swap)}: dest-tip snapshot failed on {swap.to_chain} — replay defense off until retry',
)

def prune_to_active(self, active_ids: Set[int]) -> None:
"""Drop per-swap state for swaps no longer being tracked."""
self.dest_tip_at_init = {sid: v for sid, v in self.dest_tip_at_init.items() if sid in active_ids}
self.source_verified_ids &= active_ids

def _label(self, swap: Swap) -> str:
return _swap_label(swap, self.metagraph)
Expand All @@ -43,7 +76,7 @@ def verify_tx(
expected_amount: int,
block_hint: int = 0,
expected_sender: str = '',
) -> bool:
) -> Optional[TransactionInfo]:
"""Verify a confirmed transaction on a specific chain.

Defers tx lookup, amount, and sender checks to the provider's
Expand All @@ -54,11 +87,11 @@ def verify_tx(
provider = self.providers.get(chain)
if not provider:
bt.logging.warning(f'{self._label(swap)}: no provider for chain {chain}')
return False
return None

if not tx_hash:
bt.logging.debug(f'{self._label(swap)}: empty tx_hash for {chain}, skipping verification')
return False
return None

try:
tx_info = provider.verify_transaction(
Expand All @@ -73,16 +106,31 @@ def verify_tx(
f'{self._label(swap)}: verify_transaction returned None on {chain} '
f'(tx={tx_hash[:16]}... block_hint={block_hint})'
)
return False
return None
if not tx_info.confirmed:
self.log_confs_progress(swap, chain, tx_hash, tx_info, expected_recipient, expected_amount)
return False
return True
return None
return tx_info
except ProviderUnreachableError:
raise
except Exception as e:
bt.logging.error(f'{self._label(swap)}: verification error on {chain}: {e}')
return None

def is_dest_tx_fresh(self, swap: Swap, dest_info: TransactionInfo) -> bool:
"""Reject a dest tx mined before the swap was initiated (replay defense)."""
if dest_info.block_number is None:
return True
lower = swap.initiated_block if swap.to_chain == 'tao' else self.dest_tip_at_init.get(swap.id)
if lower is None:
return True # fail-open; observe_initiation already logged
if dest_info.block_number < lower:
bt.logging.warning(
f'{self._label(swap)}: dest tx at block {dest_info.block_number} < initiated {lower} — '
f'rejecting as replay (tx={swap.to_tx_hash[:16]}...)'
)
return False
return True

def log_confs_progress(
self,
Expand Down Expand Up @@ -124,7 +172,7 @@ async def verify_miner_fulfillment(self, swap: Swap) -> bool:
if swap.id in self.source_verified_ids:
source_ok = True
else:
source_ok = await asyncio.to_thread(
source_info = await asyncio.to_thread(
self.verify_tx,
swap,
swap.from_chain,
Expand All @@ -133,10 +181,11 @@ async def verify_miner_fulfillment(self, swap: Swap) -> bool:
swap.from_amount,
swap.from_tx_block,
)
source_ok = source_info is not None
if source_ok:
self.source_verified_ids.add(swap.id)

dest_ok = await asyncio.to_thread(
dest_info = await asyncio.to_thread(
self.verify_tx,
swap,
swap.to_chain,
Expand All @@ -146,6 +195,7 @@ async def verify_miner_fulfillment(self, swap: Swap) -> bool:
swap.to_tx_block,
swap.miner_to_address,
)
dest_ok = dest_info is not None and self.is_dest_tx_fresh(swap, dest_info)

if source_ok != dest_ok:
log_on_change(
Expand Down
5 changes: 5 additions & 0 deletions allways/validator/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ async def forward(self: Validator) -> None:
await tracker.poll()
bt.logging.info('forward: tracker polled')

# Snapshot dest-chain tip on first sighting for the dest-tx replay defense.
for swap in tracker.active.values():
verifier.observe_initiation(swap)
verifier.prune_to_active(set(tracker.active.keys()))

# Verify FULFILLED swaps end-to-end and vote confirm_swap. The returned
# set is swap IDs where the provider was unreachable this cycle, so the
# timeout phase knows to skip them (transient outage shouldn't slash).
Expand Down
142 changes: 142 additions & 0 deletions tests/test_chain_verification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
"""Recency-based replay defense for miner-supplied dest tx hashes.

Locks down the dest-chain tip snapshot taken at swap observation and the
single comparison used to reject a dest tx mined before its swap was
initiated. Closes the gap left by the contract enforcing ``used_from_tx``
only on the source side.
"""

from unittest.mock import MagicMock

from allways.chain_providers.base import TransactionInfo
from allways.classes import Swap, SwapStatus
from allways.validator.chain_verification import SwapVerifier


def make_swap(swap_id: int = 1, to_chain: str = 'btc', initiated_block: int = 100) -> Swap:
return Swap(
id=swap_id,
user_hotkey='user',
miner_hotkey='miner',
from_chain='tao' if to_chain == 'btc' else 'btc',
to_chain=to_chain,
from_amount=1,
to_amount=1,
tao_amount=1,
user_from_address='from',
user_to_address='to',
miner_from_address='miner-from',
miner_to_address='miner-to',
rate='100',
to_tx_hash='dest-hash',
status=SwapStatus.FULFILLED,
initiated_block=initiated_block,
)


def tx_at(block_number) -> TransactionInfo:
return TransactionInfo(
tx_hash='dest-hash',
confirmed=True,
sender='miner-to',
recipient='to',
amount=1,
block_number=block_number,
confirmations=10,
)


class TestObserveInitiation:
def test_snapshots_observed_tip(self):
btc = MagicMock()
btc.get_current_block_height.return_value = 850_000
v = SwapVerifier(chain_providers={'btc': btc})

v.observe_initiation(make_swap(swap_id=1, to_chain='btc'))

assert v.dest_tip_at_init[1] == 850_000

def test_idempotent(self):
btc = MagicMock()
btc.get_current_block_height.return_value = 850_000
v = SwapVerifier(chain_providers={'btc': btc})

v.observe_initiation(make_swap(swap_id=1, to_chain='btc'))
v.observe_initiation(make_swap(swap_id=1, to_chain='btc'))

btc.get_current_block_height.assert_called_once()

def test_tao_dest_is_noop(self):
btc = MagicMock()
v = SwapVerifier(chain_providers={'btc': btc})

v.observe_initiation(make_swap(swap_id=5, to_chain='tao'))

assert 5 not in v.dest_tip_at_init
btc.get_current_block_height.assert_not_called()

def test_failed_snapshot_leaves_no_entry_so_retry_is_possible(self):
btc = MagicMock()
btc.get_current_block_height.return_value = None
v = SwapVerifier(chain_providers={'btc': btc})

v.observe_initiation(make_swap(swap_id=7, to_chain='btc'))

assert 7 not in v.dest_tip_at_init

# Next forward step the RPC recovers — snapshot is captured.
btc.get_current_block_height.return_value = 850_500
v.observe_initiation(make_swap(swap_id=7, to_chain='btc'))

assert v.dest_tip_at_init[7] == 850_500

def test_rpc_raises_treated_as_failure(self):
btc = MagicMock()
btc.get_current_block_height.side_effect = RuntimeError('boom')
v = SwapVerifier(chain_providers={'btc': btc})

v.observe_initiation(make_swap(swap_id=9, to_chain='btc'))

assert 9 not in v.dest_tip_at_init


class TestIsDestTxFresh:
def test_tao_accepts_initiation_block_and_rejects_earlier(self):
v = SwapVerifier(chain_providers={})
swap = make_swap(to_chain='tao', initiated_block=100)
assert v.is_dest_tx_fresh(swap, tx_at(100)) is True
assert v.is_dest_tx_fresh(swap, tx_at(99)) is False

def test_btc_accepts_at_snapshot_rejects_older_replay(self):
v = SwapVerifier(chain_providers={})
v.dest_tip_at_init[1] = 850_000
swap = make_swap(swap_id=1, to_chain='btc')

assert v.is_dest_tx_fresh(swap, tx_at(850_000)) is True
assert v.is_dest_tx_fresh(swap, tx_at(849_500)) is False

def test_failopen_when_no_snapshot(self):
v = SwapVerifier(chain_providers={})
swap = make_swap(swap_id=1, to_chain='btc')
# Even an obviously old tx is accepted — defense disabled for this swap.
assert v.is_dest_tx_fresh(swap, tx_at(1)) is True

def test_missing_block_number_passes(self):
v = SwapVerifier(chain_providers={})
v.dest_tip_at_init[1] = 850_000
swap = make_swap(swap_id=1, to_chain='btc')
info = tx_at(850_000)
info.block_number = None
assert v.is_dest_tx_fresh(swap, info) is True


class TestPruneToActive:
def test_drops_inactive_swaps(self):
v = SwapVerifier(chain_providers={})
v.dest_tip_at_init = {1: 100, 2: 200, 3: 300}
v.source_verified_ids = {1, 2, 3}

v.prune_to_active({2})

assert v.dest_tip_at_init == {2: 200}
assert v.source_verified_ids == {2}
Loading