From 12a21653b2dabe0243fb7cb7e4b3141961d8c9f5 Mon Sep 17 00:00:00 2001 From: Arondondon Date: Thu, 26 Sep 2024 19:11:52 +0300 Subject: [PATCH 1/2] Implemented channels caching mechanism. --- snet/sdk/__init__.py | 6 +- snet/sdk/mpe/mpe_contract.py | 3 - snet/sdk/mpe/payment_channel_provider.py | 130 +++++++++++++++++------ snet/sdk/service_client.py | 16 +-- 4 files changed, 108 insertions(+), 47 deletions(-) diff --git a/snet/sdk/__init__.py b/snet/sdk/__init__.py index fd52bf2..218c22f 100644 --- a/snet/sdk/__init__.py +++ b/snet/sdk/__init__.py @@ -7,6 +7,8 @@ import google.protobuf.internal.api_implementation +from snet.sdk.mpe.payment_channel_provider import PaymentChannelProvider + with warnings.catch_warnings(): # Suppress the eth-typing package`s warnings related to some new networks warnings.filterwarnings("ignore", "Network .* does not have a valid ChainId. eth-typing should be " @@ -76,6 +78,8 @@ def __init__(self, sdk_config: Config, metadata_provider=None): self.registry_contract = get_contract_object(self.web3, "Registry", _registry_contract_address) self.account = Account(self.web3, sdk_config, self.mpe_contract) + self.payment_channel_provider = PaymentChannelProvider(self.web3, self.mpe_contract) + self.payment_channel_provider.update_cache() def create_service_client(self, org_id: str, service_id: str, group_name=None, payment_channel_management_strategy=None, @@ -122,7 +126,7 @@ def create_service_client(self, org_id: str, service_id: str, group_name=None, pb2_module = self.get_module_by_keyword(org_id, service_id, keyword="pb2.py") service_client = ServiceClient(org_id, service_id, service_metadata, group, service_stub, strategy, - options, self.mpe_contract, self.account, self.web3, pb2_module) + options, self.mpe_contract, self.account, self.web3, pb2_module, self.payment_channel_provider) return service_client def get_service_stub(self, org_id: str, service_id: str) -> ServiceStub: diff --git a/snet/sdk/mpe/mpe_contract.py b/snet/sdk/mpe/mpe_contract.py index e740c9a..90a76ae 100644 --- a/snet/sdk/mpe/mpe_contract.py +++ b/snet/sdk/mpe/mpe_contract.py @@ -8,9 +8,6 @@ def __init__(self, w3, address=None): self.contract = get_contract_object(self.web3, "MultiPartyEscrow") else: self.contract = get_contract_object(self.web3, "MultiPartyEscrow", address) - self.event_topics = [self.web3.keccak( - text="ChannelOpen(uint256,uint256,address,address,address,bytes32,uint256,uint256)").hex()] - self.deployment_block = get_contract_deployment_block(self.web3, "MultiPartyEscrow") def balance(self, address): return self.contract.functions.balances(address).call() diff --git a/snet/sdk/mpe/payment_channel_provider.py b/snet/sdk/mpe/payment_channel_provider.py index 7ad8b90..9412faa 100644 --- a/snet/sdk/mpe/payment_channel_provider.py +++ b/snet/sdk/mpe/payment_channel_provider.py @@ -1,66 +1,126 @@ +from pathlib import Path + from web3._utils.events import get_event_data from eth_abi.codec import ABICodec +import pickle from snet.sdk.mpe.payment_channel import PaymentChannel from snet.contracts import get_contract_deployment_block BLOCKS_PER_BATCH = 5000 +CHANNELS_DIR = Path.home().joinpath(".snet", "channels") class PaymentChannelProvider(object): - def __init__(self, w3, payment_channel_state_service_client, mpe_contract): + def __init__(self, w3, mpe_contract): self.web3 = w3 self.mpe_contract = mpe_contract self.event_topics = [self.web3.keccak( text="ChannelOpen(uint256,uint256,address,address,address,bytes32,uint256,uint256)").hex()] self.deployment_block = get_contract_deployment_block(self.web3, "MultiPartyEscrow") - self.payment_channel_state_service_client = payment_channel_state_service_client - - def get_past_open_channels(self, account, payment_address, group_id, starting_block_number=0, to_block_number=None): - if to_block_number is None: - to_block_number = self.web3.eth.block_number - - if starting_block_number == 0: - starting_block_number = self.deployment_block - + self.mpe_address = mpe_contract.contract.address + self.channels_file = CHANNELS_DIR.joinpath(str(self.mpe_address), "channels.pickle") + + def update_cache(self): + channels = [] + last_read_block = self.deployment_block + + if not self.channels_file.exists(): + self.channels_file.parent.mkdir(parents=True, exist_ok=True) + with open(self.channels_file, "wb") as f: + empty_dict = { + "last_read_block": last_read_block, + "channels": channels + } + pickle.dump(empty_dict, f) + else: + with open(self.channels_file, "rb") as f: + load_dict = pickle.load(f) + last_read_block = load_dict["last_read_block"] + channels = load_dict["channels"] + + current_block_number = self.web3.eth.block_number + + if last_read_block < current_block_number: + new_channels = self._get_all_channels_from_blockchain_logs_to_dicts(last_read_block, current_block_number) + channels = channels + new_channels + last_read_block = current_block_number + + with open(self.channels_file, "wb") as f: + dict_to_save = { + "last_read_block": last_read_block, + "channels": channels + } + pickle.dump(dict_to_save, f) + + def _event_data_args_to_dict(self, event_data): + return { + "channel_id": event_data["channelId"], + "sender": event_data["sender"], + "signer": event_data["signer"], + "recipient": event_data["recipient"], + "group_id": event_data["groupId"], + } + + def _get_all_channels_from_blockchain_logs_to_dicts(self, starting_block_number, to_block_number): codec: ABICodec = self.web3.codec logs = [] from_block = starting_block_number while from_block <= to_block_number: to_block = min(from_block + BLOCKS_PER_BATCH, to_block_number) - logs = logs + self.web3.eth.get_logs({"fromBlock": from_block, "toBlock": to_block, - "address": self.mpe_contract.contract.address, - "topics": self.event_topics}) + logs = logs + self.web3.eth.get_logs({"fromBlock": from_block, + "toBlock": to_block, + "address": self.mpe_address, + "topics": self.event_topics}) from_block = to_block + 1 event_abi = self.mpe_contract.contract._find_matching_event_abi(event_name="ChannelOpen") - channels_opened = list(filter( - lambda - channel: (channel.sender == account.address or channel.signer == account.signer_address) and channel.recipient == - payment_address and channel.groupId == group_id, - - [get_event_data(codec, event_abi, l)["args"] for l in logs] - )) - return list(map(lambda channel: PaymentChannel(channel["channelId"], self.web3, account, - self.payment_channel_state_service_client, self.mpe_contract), - channels_opened)) - def open_channel(self, account, amount, expiration, payment_address, group_id): + event_data_list = [get_event_data(codec, event_abi, l)["args"] for l in logs] + channels_opened = list(map(self._event_data_args_to_dict, event_data_list)) - receipt = self.mpe_contract.open_channel(account, payment_address, group_id, amount, expiration) - return self._get_newly_opened_channel(receipt, account, payment_address, group_id) + return channels_opened + + def _get_channels_from_cache(self): + self.update_cache() + with open(self.channels_file, "rb") as f: + load_dict = pickle.load(f) + return load_dict["channels"] - def deposit_and_open_channel(self, account, amount, expiration, payment_address, group_id): - receipt = self.mpe_contract.deposit_and_open_channel(account, payment_address, group_id, amount, - expiration) - return self._get_newly_opened_channel(receipt, account, payment_address, group_id) + def get_past_open_channels(self, account, payment_address, group_id, payment_channel_state_service_client): - def _get_newly_opened_channel(self, receipt,account, payment_address, group_id): - open_channels = self.get_past_open_channels(account, payment_address, group_id, receipt["blockNumber"], - receipt["blockNumber"]) - if len(open_channels) == 0: + dict_channels = self._get_channels_from_cache() + + channels_opened = list(filter(lambda channel: (channel["sender"] == account.address + or channel["signer"] == account.signer_address) + and channel["recipient"] == payment_address + and channel["group_id"] == group_id, + dict_channels)) + + return list(map(lambda channel: PaymentChannel(channel["channel_id"], + self.web3, + account, + payment_channel_state_service_client, + self.mpe_contract), + channels_opened)) + + def open_channel(self, account, amount, expiration, payment_address, group_id, payment_channel_state_service_client): + receipt = self.mpe_contract.open_channel(account, payment_address, group_id, amount, expiration) + self.update_cache() + return self._get_newly_opened_channel(account, payment_address, group_id, receipt, payment_channel_state_service_client) + + def deposit_and_open_channel(self, account, amount, expiration, payment_address, group_id, payment_channel_state_service_client): + receipt = self.mpe_contract.deposit_and_open_channel(account, payment_address, group_id, amount, expiration) + self.update_cache() + return self._get_newly_opened_channel(account, payment_address, group_id, receipt, payment_channel_state_service_client) + + def _get_newly_opened_channel(self, account, payment_address, group_id, receipt, payment_channel_state_service_client): + open_channels = self.get_past_open_channels(account, payment_address, group_id, payment_channel_state_service_client) + n_channels = len(open_channels) + if n_channels == 0: raise Exception(f"Error while opening channel, please check transaction {receipt.transactionHash.hex()} ") - return open_channels[0] + return open_channels[n_channels - 1] + diff --git a/snet/sdk/service_client.py b/snet/sdk/service_client.py index 7a4b59f..1423266 100644 --- a/snet/sdk/service_client.py +++ b/snet/sdk/service_client.py @@ -13,7 +13,6 @@ from snet.sdk.utils.utils import RESOURCES_PATH, add_to_path, find_file_by_keyword import snet.sdk.generic_client_interceptor as generic_client_interceptor -from snet.sdk.mpe.payment_channel_provider import PaymentChannelProvider class _ClientCallDetails( @@ -26,7 +25,7 @@ class _ClientCallDetails( class ServiceClient: def __init__(self, org_id, service_id, service_metadata, group, service_stub, payment_strategy, - options, mpe_contract, account, sdk_web3, pb2_module): + options, mpe_contract, account, sdk_web3, pb2_module, payment_channel_provider): self.org_id = org_id self.service_id = service_id self.options = options @@ -38,9 +37,8 @@ def __init__(self, org_id, service_id, service_metadata, group, service_stub, pa self.__base_grpc_channel = self._get_grpc_channel() self.grpc_channel = grpc.intercept_channel(self.__base_grpc_channel, generic_client_interceptor.create(self._intercept_call)) - self.payment_channel_provider = PaymentChannelProvider(sdk_web3, - self._generate_payment_channel_state_service_client(), - mpe_contract) + self.payment_channel_provider = payment_channel_provider + self.payment_channel_state_service_client = self._generate_payment_channel_state_service_client() self.service = self._generate_grpc_stub(service_stub) self.pb2_module = importlib.import_module(pb2_module) if isinstance(pb2_module, str) else pb2_module self.payment_channels = [] @@ -122,7 +120,8 @@ def load_open_channels(self): payment_address = self.group["payment"]["payment_address"] group_id = base64.b64decode(str(self.group["group_id"])) new_payment_channels = self.payment_channel_provider.get_past_open_channels(self.account, payment_address, - group_id, self.last_read_block) + group_id, + self.payment_channel_state_service_client) self.payment_channels = self.payment_channels + \ self._filter_existing_channels_from_new_payment_channels(new_payment_channels) self.last_read_block = current_block_number @@ -150,13 +149,14 @@ def open_channel(self, amount, expiration): payment_address = self.group["payment"]["payment_address"] group_id = base64.b64decode(str(self.group["group_id"])) return self.payment_channel_provider.open_channel(self.account, amount, expiration, payment_address, - group_id) + group_id, self.payment_channel_state_service_client) def deposit_and_open_channel(self, amount, expiration): payment_address = self.group["payment"]["payment_address"] group_id = base64.b64decode(str(self.group["group_id"])) return self.payment_channel_provider.deposit_and_open_channel(self.account, amount, expiration, - payment_address, group_id) + payment_address, group_id, + self.payment_channel_state_service_client) def get_price(self): return self.group["pricing"][0]["price_in_cogs"] From 4a749c649e55c22fc439593967c668dd5b56a0f9 Mon Sep 17 00:00:00 2001 From: Arondondon Date: Fri, 27 Sep 2024 15:19:41 +0300 Subject: [PATCH 2/2] Changed cache directory. Also, now cache is updated only when accessing channels. --- snet/sdk/__init__.py | 1 - snet/sdk/mpe/payment_channel_provider.py | 10 ++++------ 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/snet/sdk/__init__.py b/snet/sdk/__init__.py index 218c22f..e07bcc1 100644 --- a/snet/sdk/__init__.py +++ b/snet/sdk/__init__.py @@ -79,7 +79,6 @@ def __init__(self, sdk_config: Config, metadata_provider=None): self.account = Account(self.web3, sdk_config, self.mpe_contract) self.payment_channel_provider = PaymentChannelProvider(self.web3, self.mpe_contract) - self.payment_channel_provider.update_cache() def create_service_client(self, org_id: str, service_id: str, group_name=None, payment_channel_management_strategy=None, diff --git a/snet/sdk/mpe/payment_channel_provider.py b/snet/sdk/mpe/payment_channel_provider.py index 9412faa..48ff5e6 100644 --- a/snet/sdk/mpe/payment_channel_provider.py +++ b/snet/sdk/mpe/payment_channel_provider.py @@ -9,7 +9,7 @@ BLOCKS_PER_BATCH = 5000 -CHANNELS_DIR = Path.home().joinpath(".snet", "channels") +CHANNELS_DIR = Path.home().joinpath(".snet", "cache", "mpe") class PaymentChannelProvider(object): @@ -28,6 +28,7 @@ def update_cache(self): last_read_block = self.deployment_block if not self.channels_file.exists(): + print(f"Channels cache is empty. Caching may take some time when first accessing channels.\nCaching in progress...") self.channels_file.parent.mkdir(parents=True, exist_ok=True) with open(self.channels_file, "wb") as f: empty_dict = { @@ -109,18 +110,15 @@ def get_past_open_channels(self, account, payment_address, group_id, payment_cha def open_channel(self, account, amount, expiration, payment_address, group_id, payment_channel_state_service_client): receipt = self.mpe_contract.open_channel(account, payment_address, group_id, amount, expiration) - self.update_cache() return self._get_newly_opened_channel(account, payment_address, group_id, receipt, payment_channel_state_service_client) def deposit_and_open_channel(self, account, amount, expiration, payment_address, group_id, payment_channel_state_service_client): receipt = self.mpe_contract.deposit_and_open_channel(account, payment_address, group_id, amount, expiration) - self.update_cache() return self._get_newly_opened_channel(account, payment_address, group_id, receipt, payment_channel_state_service_client) def _get_newly_opened_channel(self, account, payment_address, group_id, receipt, payment_channel_state_service_client): open_channels = self.get_past_open_channels(account, payment_address, group_id, payment_channel_state_service_client) - n_channels = len(open_channels) - if n_channels == 0: + if not open_channels: raise Exception(f"Error while opening channel, please check transaction {receipt.transactionHash.hex()} ") - return open_channels[n_channels - 1] + return open_channels[-1]