diff --git a/src/aleph/sdk/chains/remote.py b/src/aleph/sdk/chains/remote.py index 917cf39b..931b68f3 100644 --- a/src/aleph/sdk/chains/remote.py +++ b/src/aleph/sdk/chains/remote.py @@ -52,7 +52,7 @@ async def from_crypto_host( session = aiohttp.ClientSession(connector=connector) async with session.get(f"{host}/properties") as response: - await response.raise_for_status() + response.raise_for_status() data = await response.json() properties = AccountProperties(**data) @@ -75,7 +75,7 @@ def private_key(self): async def sign_message(self, message: Dict) -> Dict: """Sign a message inplace.""" async with self._session.post(f"{self._host}/sign", json=message) as response: - await response.raise_for_status() + response.raise_for_status() return await response.json() async def sign_raw(self, buffer: bytes) -> bytes: diff --git a/src/aleph/sdk/client/authenticated_http.py b/src/aleph/sdk/client/authenticated_http.py index 2975e112..ae4b6b04 100644 --- a/src/aleph/sdk/client/authenticated_http.py +++ b/src/aleph/sdk/client/authenticated_http.py @@ -38,6 +38,7 @@ from ..utils import extended_json_encoder, make_instance_content, make_program_content from .abstract import AuthenticatedAlephClient from .http import AlephHttpClient +from .services.authenticated_port_forwarder import AuthenticatedPortForwarder logger = logging.getLogger(__name__) @@ -81,6 +82,13 @@ def __init__( ) self.account = account + async def __aenter__(self): + await super().__aenter__() + # Override services with authenticated versions + self.port_forwarder = AuthenticatedPortForwarder(self) + + return self + async def ipfs_push(self, content: Mapping) -> str: """ Push arbitrary content as JSON to the IPFS service. @@ -392,7 +400,7 @@ async def create_store( if extra_fields is not None: values.update(extra_fields) - content = StoreContent.parse_obj(values) + content = StoreContent.model_validate(values) message, status, _ = await self.submit( content=content.model_dump(exclude_none=True), diff --git a/src/aleph/sdk/client/http.py b/src/aleph/sdk/client/http.py index bd3090e3..a433e48d 100644 --- a/src/aleph/sdk/client/http.py +++ b/src/aleph/sdk/client/http.py @@ -33,6 +33,12 @@ from aleph_message.status import MessageStatus from pydantic import ValidationError +from aleph.sdk.client.services.crn import Crn +from aleph.sdk.client.services.dns import DNS +from aleph.sdk.client.services.instance import Instance +from aleph.sdk.client.services.port_forwarder import PortForwarder +from aleph.sdk.client.services.scheduler import Scheduler + from ..conf import settings from ..exceptions import ( FileTooLarge, @@ -123,6 +129,13 @@ async def __aenter__(self): ) ) + # Initialize default services + self.dns = DNS(self) + self.port_forwarder = PortForwarder(self) + self.crn = Crn(self) + self.scheduler = Scheduler(self) + self.instance = Instance(self) + return self async def __aexit__(self, exc_type, exc_val, exc_tb): @@ -139,7 +152,8 @@ async def fetch_aggregate(self, address: str, key: str) -> Dict[str, Dict]: resp.raise_for_status() result = await resp.json() data = result.get("data", dict()) - return data.get(key) + final_result = data.get(key) + return final_result async def fetch_aggregates( self, address: str, keys: Optional[Iterable[str]] = None diff --git a/src/aleph/sdk/client/services/__init__.py b/src/aleph/sdk/client/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/aleph/sdk/client/services/authenticated_port_forwarder.py b/src/aleph/sdk/client/services/authenticated_port_forwarder.py new 
file mode 100644 index 00000000..765ac2f1 --- /dev/null +++ b/src/aleph/sdk/client/services/authenticated_port_forwarder.py @@ -0,0 +1,190 @@ +from typing import TYPE_CHECKING, Optional, Tuple + +from aleph_message.models import AggregateMessage, ItemHash +from aleph_message.status import MessageStatus + +from aleph.sdk.client.services.base import AggregateConfig +from aleph.sdk.client.services.port_forwarder import PortForwarder +from aleph.sdk.exceptions import MessageNotProcessed, NotAuthorize +from aleph.sdk.types import AllForwarders, Ports +from aleph.sdk.utils import safe_getattr + +if TYPE_CHECKING: + from aleph.sdk.client.abstract import AuthenticatedAlephClient + + +class AuthenticatedPortForwarder(PortForwarder): + """ + Authenticated port forwarder service with create and update capabilities + """ + + def __init__(self, client: "AuthenticatedAlephClient"): + super().__init__(client) + + async def _verify_status_processed_and_ownership( + self, item_hash: ItemHash + ) -> Tuple[AggregateMessage, MessageStatus]: + """ + Verify that the message has been processed (i.e. not rejected; it may still be pending). + This also verifies the ownership of the message. + """ + message: AggregateMessage + status: MessageStatus + message, status = await self._client.get_message( + item_hash=item_hash, + with_status=True, + ) + + # Ensure the message is not rejected (it might not be processed yet) + if status not in [MessageStatus.PROCESSED, MessageStatus.PENDING]: + raise MessageNotProcessed(item_hash=item_hash, status=status) + + message_content = safe_getattr(message, "content") + address = safe_getattr(message_content, "address") + + if ( + not hasattr(self._client, "account") + or address != self._client.account.get_address() + ): + current_address = ( + self._client.account.get_address() + if hasattr(self._client, "account") + else "unknown" + ) + raise NotAuthorize( + item_hash=item_hash, + target_address=address, + current_address=current_address, + ) + return message, status + + async def get_address_ports( + self, address: Optional[str] = None + ) -> AggregateConfig[AllForwarders]: + """ + Get all port forwarding configurations for an address + + Args: + address: The address to fetch configurations for. + If None, uses the authenticated client's account address. + + Returns: + Port forwarding configurations + """ + if address is None: + if not hasattr(self._client, "account") or not self._client.account: + raise ValueError("No account provided and client is not authenticated") + address = self._client.account.get_address() + + return await super().get_address_ports(address=address) + + async def get_ports( + self, item_hash: Optional[ItemHash] = None, address: Optional[str] = None + ) -> Optional[Ports]: + """ + Get port forwarding configuration for a specific item hash + + Args: + address: The address to fetch configurations for. + If None, uses the authenticated client's account address.
+ item_hash: The hash of the item to get configuration for + + Returns: + Port configuration if found, otherwise empty Ports object + """ + if address is None: + if not hasattr(self._client, "account") or not self._client.account: + raise ValueError("No account provided and client is not authenticated") + address = self._client.account.get_address() + + if item_hash is None: + raise ValueError("item_hash must be provided") + + return await super().get_ports(address=address, item_hash=item_hash) + + async def create_ports( + self, item_hash: ItemHash, ports: Ports + ) -> Tuple[AggregateMessage, MessageStatus]: + """ + Create a new port forwarding configuration for an item hash + + Args: + item_hash: The hash of the item (instance/program/IPFS website) + ports: Dictionary mapping port numbers to PortFlags + + Returns: + The aggregate message and its status + """ + if not hasattr(self._client, "account") or not self._client.account: + raise ValueError("An account is required for this operation") + + # Pre Check + # _, _ = await self._verify_status_processed_and_ownership(item_hash=item_hash) + + content = {str(item_hash): ports.model_dump()} + + # create_aggregate is only available on authenticated clients + return await self._client.create_aggregate( # type: ignore + key=self.aggregate_key, content=content + ) + + async def update_ports( + self, item_hash: ItemHash, ports: Ports + ) -> Tuple[AggregateMessage, MessageStatus]: + """ + Update an existing port forwarding configuration for an item hash + + Args: + item_hash: The hash of the item (instance/program/IPFS website) + ports: Dictionary mapping port numbers to PortFlags + + Returns: + The aggregate message and its status + """ + if not hasattr(self._client, "account") or not self._client.account: + raise ValueError("An account is required for this operation") + + # Pre Check + # _, _ = await self._verify_status_processed_and_ownership(item_hash=item_hash) + + content = {str(item_hash): ports.model_dump()} + + message, status = await self._client.create_aggregate( # type: ignore + key=self.aggregate_key, content=content + ) + + return message, status + + async def delete_ports( + self, item_hash: ItemHash + ) -> Tuple[AggregateMessage, MessageStatus]: + """ + Delete port forwarding configuration for an item hash + + Args: + item_hash: The hash of the item (instance/program/IPFS website) to delete configuration for + + Returns: + The aggregate message and its status + """ + if not hasattr(self._client, "account") or not self._client.account: + raise ValueError("An account is required for this operation") + + # Pre Check + # _, _ = await self._verify_status_processed_and_ownership(item_hash=item_hash) + + # Get the current port config of the item_hash + port: Optional[Ports] = await self.get_ports(item_hash=item_hash) + if not port: + raise ValueError(f"No port configuration found for {item_hash}") + + content = {str(item_hash): port.model_dump()} + + # Re-publish the aggregate for this item hash + message, status = await self._client.create_aggregate( # type: ignore + key=self.aggregate_key, content=content + ) + return message, status
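For context, a minimal usage sketch of the new authenticated port-forwarding service; the account setup and VM hash below are placeholders, not part of this diff:

```python
# Sketch only: assumes an existing ETHAccount and a deployed VM's item hash.
from aleph.sdk import AuthenticatedAlephHttpClient
from aleph.sdk.chains.ethereum import ETHAccount
from aleph.sdk.types import PortFlags, Ports

async def expose_http(account: ETHAccount, vm_hash: str) -> None:
    # Forward TCP port 80 of the VM identified by vm_hash
    ports = Ports(ports={80: PortFlags(tcp=True, udp=False)})
    async with AuthenticatedAlephHttpClient(account=account) as client:
        # __aenter__ swaps in the authenticated port forwarder
        message, status = await client.port_forwarder.create_ports(
            item_hash=vm_hash, ports=ports
        )
        print(status, message.item_hash)
```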
diff --git a/src/aleph/sdk/client/services/base.py b/src/aleph/sdk/client/services/base.py new file mode 100644 index 00000000..7459d7f6 --- /dev/null +++ b/src/aleph/sdk/client/services/base.py @@ -0,0 +1,42 @@ +from abc import ABC +from typing import TYPE_CHECKING, Generic, List, Optional, Type, TypeVar + +from pydantic import BaseModel + +if TYPE_CHECKING: + from aleph.sdk.client.http import AlephHttpClient + + +T = TypeVar("T", bound=BaseModel) + + +class AggregateConfig(BaseModel, Generic[T]): + """ + A generic container for "aggregate" data of type T. + - `data` will be either None or a list of T-instances. + """ + + data: Optional[List[T]] = None + + +class BaseService(ABC, Generic[T]): + aggregate_key: str + model_cls: Type[T] + + def __init__(self, client: "AlephHttpClient"): + self._client = client + # aggregate_key and model_cls are provided by the concrete subclass + + async def get_config(self, address: str) -> AggregateConfig[T]: + + aggregate_data = await self._client.fetch_aggregate( + address=address, key=self.aggregate_key + ) + + if aggregate_data: + model_instance = self.model_cls.model_validate(aggregate_data) + config = AggregateConfig[T](data=[model_instance]) + else: + config = AggregateConfig[T](data=None) + + return config
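To illustrate the pattern, here is how a concrete service plugs into `BaseService`; this mirrors the `DummyService` defined in `tests/unit/services/test_base_service.py`, and the key and model are illustrative only:

```python
from typing import Optional
from pydantic import BaseModel
from aleph.sdk.client.services.base import BaseService

class Greeting(BaseModel):
    foo: str
    bar: Optional[int] = None

class GreetingService(BaseService[Greeting]):
    aggregate_key = "greeting"  # aggregate key queried by get_config()
    model_cls = Greeting        # model used to validate the aggregate payload

# config = await GreetingService(client).get_config("0x...")
# -> AggregateConfig[Greeting] with data=[Greeting(...)] or data=None
```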
+ """ + # We want filter_inactive = (not only_active) + # Convert bool to string for the query parameter + filter_inactive_str = str(not only_active).lower() + params = {"filter_inactive": filter_inactive_str} + + # Create a new session for external domain requests + async with aiohttp.ClientSession() as session: + async with session.get( + sanitize_url(settings.CRN_LIST_URL), params=params + ) as resp: + resp.raise_for_status() + return await resp.json() + + async def get_active_vms_v2(self, crn_address: str) -> CrnV2List: + endpoint = "/v2/about/executions/list" + + full_url = sanitize_url(crn_address + endpoint) + + async with aiohttp.ClientSession() as session: + async with session.get(full_url) as resp: + resp.raise_for_status() + raw = await resp.json() + vm_mmap = CrnV2List.model_validate(raw) + return vm_mmap + + async def get_active_vms_v1(self, crn_address: str) -> CrnV1List: + endpoint = "/about/executions/list" + + full_url = sanitize_url(crn_address + endpoint) + + async with aiohttp.ClientSession() as session: + async with session.get(full_url) as resp: + resp.raise_for_status() + raw = await resp.json() + vm_map = CrnV1List.model_validate(raw) + return vm_map + + async def get_active_vms(self, crn_address: str) -> Union[CrnV2List, CrnV1List]: + try: + return await self.get_active_vms_v2(crn_address) + except ClientResponseError as e: + if e.status == 404: + return await self.get_active_vms_v1(crn_address) + raise + + async def get_vm( + self, crn_address: str, item_hash: ItemHash + ) -> Optional[Union[CrnExecutionV1, CrnExecutionV2]]: + vms = await self.get_active_vms(crn_address) + + vm_map: Dict[ItemHash, Union[CrnExecutionV1, CrnExecutionV2]] = vms.root + + if item_hash not in vm_map: + return None + + return vm_map[item_hash] + + async def update_instance_config(self, crn_address: str, item_hash: ItemHash): + vm = await self.get_vm(crn_address, item_hash) + + if not vm: + raise VmNotFoundOnHost(crn_url=crn_address, item_hash=item_hash) + + # CRN have two week to upgrade their node, + # So if the CRN does not have the update + # We can't update config + if isinstance(vm, CrnExecutionV1): + raise MethodNotAvailableOnCRN() + + full_url = sanitize_url(crn_address + f"/control/{item_hash}/update") + + async with aiohttp.ClientSession() as session: + async with session.post(full_url) as resp: + resp.raise_for_status() + return await resp.json() diff --git a/src/aleph/sdk/client/services/dns.py b/src/aleph/sdk/client/services/dns.py new file mode 100644 index 00000000..95132390 --- /dev/null +++ b/src/aleph/sdk/client/services/dns.py @@ -0,0 +1,54 @@ +from typing import TYPE_CHECKING, List, Optional + +import aiohttp +from aleph_message.models import ItemHash + +from aleph.sdk.conf import settings +from aleph.sdk.types import Dns, DnsListAdapter +from aleph.sdk.utils import sanitize_url + +if TYPE_CHECKING: + from aleph.sdk.client.http import AlephHttpClient + + +class DNS: + """ + This Service mostly made to get active dns for instance: + `https://api.dns.public.aleph.sh/instances/list` + """ + + def __init__(self, client: "AlephHttpClient"): + self._client = client + + async def get_public_dns(self) -> List[Dns]: + """ + Get all the public dns ha + """ + async with aiohttp.ClientSession() as session: + async with session.get(sanitize_url(settings.DNS_API)) as resp: + resp.raise_for_status() + raw = await resp.json() + + return DnsListAdapter.validate_json(raw) + + async def get_public_dns_by_host(self, crn_hostname): + """ + Get all the public dns with filter on crn_url + """ 
diff --git a/src/aleph/sdk/client/services/dns.py b/src/aleph/sdk/client/services/dns.py new file mode 100644 index 00000000..95132390 --- /dev/null +++ b/src/aleph/sdk/client/services/dns.py @@ -0,0 +1,54 @@ +from typing import TYPE_CHECKING, List, Optional + +import aiohttp +from aleph_message.models import ItemHash + +from aleph.sdk.conf import settings +from aleph.sdk.types import Dns, DnsListAdapter +from aleph.sdk.utils import sanitize_url + +if TYPE_CHECKING: + from aleph.sdk.client.http import AlephHttpClient + + +class DNS: + """ + This service is mostly made to fetch the active DNS entries of instances: + `https://api.dns.public.aleph.sh/instances/list` + """ + + def __init__(self, client: "AlephHttpClient"): + self._client = client + + async def get_public_dns(self) -> List[Dns]: + """ + Get all the public DNS entries + """ + async with aiohttp.ClientSession() as session: + async with session.get(sanitize_url(settings.DNS_API)) as resp: + resp.raise_for_status() + raw = await resp.json() + + # resp.json() returns parsed data, so validate with validate_python + return DnsListAdapter.validate_python(raw) + + async def get_public_dns_by_host(self, crn_hostname: str) -> List[Dns]: + """ + Get all the public DNS entries, filtered by crn_url + """ + async with aiohttp.ClientSession() as session: + async with session.get( + sanitize_url(settings.DNS_API), params={"crn_url": crn_hostname} + ) as resp: + resp.raise_for_status() + raw = await resp.json() + + return DnsListAdapter.validate_python(raw) + + async def get_dns_for_instance(self, vm_hash: ItemHash) -> Optional[List[Dns]]: + async with aiohttp.ClientSession() as session: + async with session.get( + sanitize_url(settings.DNS_API), params={"item_hash": vm_hash} + ) as resp: + resp.raise_for_status() + raw = await resp.json() + return DnsListAdapter.validate_python(raw)
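Usage sketch for the DNS service; the VM hash is a placeholder:

```python
from aleph.sdk import AlephHttpClient
from aleph_message.models import ItemHash

async def instance_dns(vm_hash: str) -> None:
    async with AlephHttpClient() as client:
        records = await client.dns.get_dns_for_instance(ItemHash(vm_hash))
        for dns in records or []:
            print(dns.name, dns.ipv6)
```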
diff --git a/src/aleph/sdk/client/services/instance.py b/src/aleph/sdk/client/services/instance.py new file mode 100644 index 00000000..1636cb62 --- /dev/null +++ b/src/aleph/sdk/client/services/instance.py @@ -0,0 +1,146 @@ +import asyncio +from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union + +from aleph_message.models import InstanceMessage, ItemHash, MessageType, PaymentType +from aleph_message.status import MessageStatus + +from aleph.sdk.query.filters import MessageFilter +from aleph.sdk.query.responses import MessagesResponse + +if TYPE_CHECKING: + from aleph.sdk.client.http import AlephHttpClient + +from aleph.sdk.types import ( + CrnExecutionV1, + CrnExecutionV2, + InstanceAllocationsInfo, + InstanceManual, + InstancesExecutionList, + InstanceWithScheduler, +) +from aleph.sdk.utils import safe_getattr, sanitize_url + + +class Instance: + """ + Utility functions shared by several services, e.g. getting info about + allocations / executions of any instance (hold or not). + """ + + def __init__(self, client: "AlephHttpClient"): + self._client = client + + async def get_name_of_executable(self, item_hash: ItemHash) -> Optional[str]: + try: + message: Any = await self._client.get_message(item_hash=item_hash) + if hasattr(message, "content") and hasattr(message.content, "metadata"): + return message.content.metadata.get("name") + elif isinstance(message, dict): + # Handle dictionary response format + if "content" in message and isinstance(message["content"], dict): + if "metadata" in message["content"] and isinstance( + message["content"]["metadata"], dict + ): + return message["content"]["metadata"].get("name") + return None + except Exception: + return None + + async def get_instance_allocation_info( + self, msg: InstanceMessage, crn_list: Union[dict, list] + ) -> Tuple[InstanceMessage, Union[InstanceManual, InstanceWithScheduler]]: + vm_hash = msg.item_hash + payment_type = safe_getattr(msg, "content.payment.type.value") + firmware = safe_getattr(msg, "content.environment.trusted_execution.firmware") + has_gpu = safe_getattr(msg, "content.requirements.gpu") + + is_hold = payment_type == PaymentType.hold.value + is_conf = bool(firmware and len(firmware) == 64) + + if is_hold and not is_conf and not has_gpu: + alloc = await self._client.scheduler.get_allocation(vm_hash) + info = InstanceWithScheduler(source="scheduler", allocations=alloc) + else: + crn_hash = safe_getattr(msg, "content.requirements.node.node_hash") + if isinstance(crn_list, list): + node = next((n for n in crn_list if n.get("hash") == crn_hash), None) + url = sanitize_url(node.get("address")) if node else "" + else: + node = crn_list.get(crn_hash) + url = sanitize_url(node.get("address")) if node else "" + + info = InstanceManual(source="manual", crn_url=url) + return msg, info + + async def get_instances(self, address: str) -> List[InstanceMessage]: + resp: MessagesResponse = await self._client.get_messages( + message_filter=MessageFilter( + message_types=[MessageType.instance], + addresses=[address], + ), + page_size=100, + ) + return resp.messages + + async def get_instances_allocations( + self, messages_list: List[InstanceMessage], only_processed: bool = True + ): + crn_list_response = await self._client.crn.get_crns_list() + crn_list = crn_list_response.get("crns", {}) + + tasks = [] + for msg in messages_list: + if only_processed: + status = await self._client.get_message_status(msg.item_hash) + if status not in (MessageStatus.PROCESSED, MessageStatus.REMOVING): + continue + tasks.append(self.get_instance_allocation_info(msg, crn_list)) + + results = await asyncio.gather(*tasks) + + mapping = {ItemHash(msg.item_hash): info for msg, info in results} + + return InstanceAllocationsInfo.model_validate(mapping) + + async def get_instance_executions_info( + self, instances: InstanceAllocationsInfo + ) -> InstancesExecutionList: + async def _fetch( + item_hash: ItemHash, + alloc: Union[InstanceManual, InstanceWithScheduler], + ) -> tuple[str, Optional[Union[CrnExecutionV1, CrnExecutionV2]]]: + """Retrieve the execution record for an item hash.""" + if isinstance(alloc, InstanceManual): + crn_url = sanitize_url(alloc.crn_url) + else: + crn_url = sanitize_url(alloc.allocations.node.url) + + if not crn_url: + return str(item_hash), None + + try: + execution = await self._client.crn.get_vm( + item_hash=item_hash, + crn_address=crn_url, + ) + return str(item_hash), execution + except Exception: + return str(item_hash), None + + fetch_tasks = [] + msg_hash_map = {} + + for item_hash, alloc in instances.root.items(): + fetch_tasks.append(_fetch(item_hash, alloc)) + msg_hash_map[str(item_hash)] = item_hash + + results = await asyncio.gather(*fetch_tasks) + + mapping = { + ItemHash(msg_hash): exec_info + for msg_hash, exec_info in results + if msg_hash is not None and exec_info is not None + } + + return InstancesExecutionList.model_validate(mapping)
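The three `Instance` helpers are designed to be chained; a sketch of the intended flow, with the address as a placeholder:

```python
from aleph.sdk import AlephHttpClient

async def where_are_my_vms(address: str) -> None:
    async with AlephHttpClient() as client:
        messages = await client.instance.get_instances(address)
        allocations = await client.instance.get_instances_allocations(messages)
        executions = await client.instance.get_instance_executions_info(allocations)
        for item_hash, execution in executions.root.items():
            print(item_hash, execution.networking)
```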
diff --git a/src/aleph/sdk/client/services/port_forwarder.py b/src/aleph/sdk/client/services/port_forwarder.py new file mode 100644 index 00000000..923d0931 --- /dev/null +++ b/src/aleph/sdk/client/services/port_forwarder.py @@ -0,0 +1,44 @@ +from typing import TYPE_CHECKING, Optional + +from aleph_message.models import ItemHash + +from aleph.sdk.client.services.base import AggregateConfig, BaseService +from aleph.sdk.types import AllForwarders, Ports + +if TYPE_CHECKING: + pass + + +class PortForwarder(BaseService[AllForwarders]): + """ + Port forwarding logic + """ + + aggregate_key = "port-forwarding" + model_cls = AllForwarders + + def __init__(self, client): + super().__init__(client=client) + + async def get_address_ports(self, address: str) -> AggregateConfig[AllForwarders]: + result = await self.get_config(address=address) + return result + + async def get_ports(self, item_hash: ItemHash, address: str) -> Optional[Ports]: + """ + Get the port forwards of an Instance / Program / IPFS website from the aggregate + """ + ports_config: AggregateConfig[AllForwarders] = await self.get_address_ports( + address=address + ) + + if ports_config.data is None: + return Ports(ports={}) + + for forwarder in ports_config.data: + ports_map = forwarder.root + + if str(item_hash) in ports_map: + return ports_map[str(item_hash)] + + return Ports(ports={}) diff --git a/src/aleph/sdk/client/services/scheduler.py b/src/aleph/sdk/client/services/scheduler.py new file mode 100644 index 00000000..765ee2bd --- /dev/null +++ b/src/aleph/sdk/client/services/scheduler.py @@ -0,0 +1,54 @@ +from typing import TYPE_CHECKING + +import aiohttp +from aleph_message.models import ItemHash + +from aleph.sdk.conf import settings +from aleph.sdk.types import AllocationItem, SchedulerNodes, SchedulerPlan +from aleph.sdk.utils import sanitize_url + +if TYPE_CHECKING: + from aleph.sdk.client.http import AlephHttpClient + + +class Scheduler: + """ + This service interacts with the scheduler API: + `https://scheduler.api.aleph.cloud/` + """ + + def __init__(self, client: "AlephHttpClient"): + self._client = client + + async def get_plan(self) -> SchedulerPlan: + url = f"{sanitize_url(settings.SCHEDULER_URL)}/api/v0/plan" + + async with aiohttp.ClientSession() as session: + async with session.get(url) as resp: + resp.raise_for_status() + raw = await resp.json() + + return SchedulerPlan.model_validate(raw) + + async def get_nodes(self) -> SchedulerNodes: + url = f"{sanitize_url(settings.SCHEDULER_URL)}/api/v0/nodes" + + async with aiohttp.ClientSession() as session: + async with session.get(url) as resp: + resp.raise_for_status() + raw = await resp.json() + + return SchedulerNodes.model_validate(raw) + + async def get_allocation(self, vm_hash: ItemHash) -> AllocationItem: + """ + Fetch allocation information for a given VM hash. + """ + url = f"{sanitize_url(settings.SCHEDULER_URL)}/api/v0/allocation/{vm_hash}" + + async with aiohttp.ClientSession() as session: + async with session.get(url) as resp: + resp.raise_for_status() + payload = await resp.json() + + return AllocationItem.model_validate(payload)
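Sketch of querying the scheduler for a hold-tier instance's allocation; the VM hash is a placeholder:

```python
from aleph.sdk import AlephHttpClient
from aleph_message.models import ItemHash

async def where_scheduled(vm_hash: str) -> None:
    async with AlephHttpClient() as client:
        allocation = await client.scheduler.get_allocation(ItemHash(vm_hash))
        print(allocation.node.url, allocation.period.start_timestamp)
```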
f"Operations not authorize on resources {self.item_hash} \nTarget address : {self.target_address} \nCurrent address : {self.current_address}" + ) + + +class VmNotFoundOnHost(Exception): + """ + The VM not found on the host, + The Might might not be processed yet / wrong CRN_URL + """ + + item_hash: str + crn_url: str + + def __init__( + self, + item_hash: str, + crn_url, + ): + self.item_hash = item_hash + self.crn_url = crn_url + + super().__init__(f"Vm : {self.item_hash} not found on crn : {self.crn_url}") + + +class MethodNotAvailableOnCRN(Exception): + """ + If this error appears that means CRN you trying to interact is outdated and does + not handle this feature + """ + + pass + + class BroadcastError(Exception): """ Data could not be broadcast to the aleph.im network. diff --git a/src/aleph/sdk/types.py b/src/aleph/sdk/types.py index cf23f19d..6c1ae561 100644 --- a/src/aleph/sdk/types.py +++ b/src/aleph/sdk/types.py @@ -1,8 +1,10 @@ from abc import abstractmethod +from datetime import datetime from enum import Enum -from typing import Dict, Optional, Protocol, TypeVar +from typing import Any, Dict, List, Literal, Optional, Protocol, TypeVar, Union -from pydantic import BaseModel, Field +from aleph_message.models import ItemHash +from pydantic import BaseModel, Field, RootModel, TypeAdapter, field_validator __all__ = ("StorageEnum", "Account", "AccountFromPrivateKey", "GenericMessage") @@ -100,3 +102,190 @@ class TokenType(str, Enum): GAS = "GAS" ALEPH = "ALEPH" + + +# Scheduler +class Period(BaseModel): + start_timestamp: datetime + duration_seconds: float + + +class PlanItem(BaseModel): + persistent_vms: List[ItemHash] = Field(default_factory=list) + instances: List[ItemHash] = Field(default_factory=list) + on_demand_vms: List[ItemHash] = Field(default_factory=list) + jobs: List[str] = Field(default_factory=list) # adjust type if needed + + @field_validator( + "persistent_vms", "instances", "on_demand_vms", "jobs", mode="before" + ) + @classmethod + def coerce_to_list(cls, v: Any) -> List[Any]: + # Treat None or empty dict as empty list + if v is None or (isinstance(v, dict) and not v): + return [] + return v + + +class SchedulerPlan(BaseModel): + period: Period + plan: Dict[str, PlanItem] + + model_config = { + "populate_by_name": True, + } + + +class NodeItem(BaseModel): + node_id: str + url: str + ipv6: Optional[str] = None + supports_ipv6: bool + + +class SchedulerNodes(BaseModel): + nodes: List[NodeItem] + + model_config = { + "populate_by_name": True, + } + + def get_url(self, node_id: str) -> Optional[str]: + """ + Return the URL for the given node_id, or None if not found. + """ + for node in self.nodes: + if node.node_id == node_id: + return node + return None + + +class AllocationItem(BaseModel): + vm_hash: ItemHash + vm_type: str + vm_ipv6: Optional[str] = None + period: Period + node: NodeItem + + model_config = { + "populate_by_name": True, + } + + +class InstanceWithScheduler(BaseModel): + source: Literal["scheduler"] + allocations: AllocationItem # Case Scheduler + + +class InstanceManual(BaseModel): + source: Literal["manual"] + crn_url: str # Case + + +class InstanceAllocationsInfo( + RootModel[Dict[ItemHash, Union[InstanceManual, InstanceWithScheduler]]] +): + """ + RootModel holding mapping ItemHash to its Allocations. + Uses item_hash as the key instead of InstanceMessage objects to avoid hashability issues. 
+ """ + + pass + + +# CRN Executions + + +class Networking(BaseModel): + ipv4: str + ipv6: str + + +class CrnExecutionV1(BaseModel): + networking: Networking + + +class PortMapping(BaseModel): + host: int + tcp: bool + udp: bool + + +class NetworkingV2(BaseModel): + ipv4_network: str + host_ipv4: str + ipv6_network: str + ipv6_ip: str + mapped_ports: Dict[str, PortMapping] + + +class VmStatus(BaseModel): + defined_at: Optional[datetime] + preparing_at: Optional[datetime] + prepared_at: Optional[datetime] + starting_at: Optional[datetime] + started_at: Optional[datetime] + stopping_at: Optional[datetime] + stopped_at: Optional[datetime] + + +class CrnExecutionV2(BaseModel): + networking: NetworkingV2 + status: VmStatus + running: bool + + +class CrnV1List(RootModel[Dict[ItemHash, CrnExecutionV1]]): + """ + V1: a dict whose keys are ItemHash (strings) + and whose values are VmItemV1 (just `networking`). + """ + + pass + + +class CrnV2List(RootModel[Dict[ItemHash, CrnExecutionV2]]): + """ + A RootModel whose root is a dict mapping each item‐hash (string) + to a CrnExecutionV2, exactly matching your JSON structure. + """ + + pass + + +class InstancesExecutionList( + RootModel[Dict[ItemHash, Union[CrnExecutionV1, CrnExecutionV2]]] +): + """ + A Root Model representing Instances Message hashes and their Executions. + Uses ItemHash as keys to avoid hashability issues with InstanceMessage objects. + """ + + pass + + +class IPV4(BaseModel): + public: str + local: str + + +class Dns(BaseModel): + name: str + item_hash: ItemHash + ipv4: Optional[IPV4] + ipv6: str + + +DnsListAdapter = TypeAdapter(list[Dns]) + + +class PortFlags(BaseModel): + tcp: bool + udp: bool + + +class Ports(BaseModel): + ports: Dict[int, PortFlags] + + +AllForwarders = RootModel[Dict[ItemHash, Ports]] diff --git a/src/aleph/sdk/utils.py b/src/aleph/sdk/utils.py index 31b2be8d..19a3aa57 100644 --- a/src/aleph/sdk/utils.py +++ b/src/aleph/sdk/utils.py @@ -25,6 +25,7 @@ Union, get_args, ) +from urllib.parse import urlparse from uuid import UUID from zipfile import BadZipFile, ZipFile @@ -591,3 +592,24 @@ def make_program_content( authorized_keys=[], payment=payment, ) + + +def sanitize_url(url: str) -> str: + """ + Sanitize a URL by removing the trailing slash and ensuring it's properly formatted. + + Args: + url: The URL to sanitize + + Returns: + The sanitized URL + """ + # Remove trailing slash if present + url = url.rstrip("/") + + # Ensure URL has a proper scheme + parsed = urlparse(url) + if not parsed.scheme: + url = f"https://{url}" + + return url diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 385d2836..3ad0a4ad 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -166,7 +166,7 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_val, exc_tb): ... - async def raise_for_status(self): ... + def raise_for_status(self): ... 
@property def status(self): diff --git a/tests/unit/services/__init__.py b/tests/unit/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/services/mocks.py b/tests/unit/services/mocks.py new file mode 100644 index 00000000..86f473b7 --- /dev/null +++ b/tests/unit/services/mocks.py @@ -0,0 +1,345 @@ +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from ..conftest import make_custom_mock_response + +FAKE_CRN_GPU_HASH = "abcabcabcabcabcabcabcabcabcabcabcabcabcabcabcabcabcabcabcabcabca" +FAKE_CRN_GPU_ADDRESS = "0xBCABCABCABCABCABCABCABCABCABCABCABCABCAB" +FAKE_CRN_GPU_URL = "https://test.gpu.crn.com" + +FAKE_CRN_CONF_HASH = "defdefdefdefdefdefdefdefdefdefdefdefdefdefdefdefdefdefdefdefdefd" +FAKE_CRN_CONF_ADDRESS = "0xDEfDEfDEfDEfDEfDEfDEfDEfDEfDEfDEfDEfDEfDEf" +FAKE_CRN_CONF_URL = "https://test.conf.crn" + +FAKE_CRN_BASIC_HASH = "aaaabbbbccccddddeeeeffff1111222233334444555566667777888899990000" +FAKE_CRN_BASIC_ADDRESS = "0xAAAABBBBCCCCDDDDEEEEFFFF1111222233334444" +FAKE_CRN_BASIC_URL = "https://test.basic.crn.com" + + +@pytest.fixture +def vm_status_v2(): + return { + "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef": { + "networking": { + "ipv4_network": "192.168.0.0/24", + "host_ipv4": "192.168.0.1", + "ipv6_network": "2001:db8::/64", + "ipv6_ip": "2001:db8::1", + "mapped_ports": {}, + }, + "status": { + "defined_at": "2023-01-01T00:00:00Z", + "started_at": "2023-01-01T00:00:00Z", + "preparing_at": "2023-01-01T00:00:00Z", + "prepared_at": "2023-01-01T00:00:00Z", + "starting_at": "2023-01-01T00:00:00Z", + "stopping_at": "2023-01-01T00:00:00Z", + "stopped_at": "2023-01-01T00:00:00Z", + }, + "running": True, + } + } + + +@pytest.fixture +def vm_status_v1(): + return { + "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef": { + "networking": {"ipv4": "192.168.0.1", "ipv6": "2001:db8::1"} + } + } + + +@pytest.fixture +def mock_crn_list(): + """Create a mock CRN list for testing.""" + return [ + { + "hash": FAKE_CRN_GPU_HASH, + "name": "Test GPU Instance", + "time": 1739525120.505, + "type": "compute", + "owner": FAKE_CRN_GPU_ADDRESS, + "score": 0.964502797686815, + "banner": "", + "locked": True, + "parent": FAKE_CRN_GPU_HASH, + "reward": FAKE_CRN_GPU_ADDRESS, + "status": "linked", + "address": FAKE_CRN_GPU_URL, + "manager": "", + "picture": "", + "authorized": "", + "description": "", + "performance": 0, + "multiaddress": "", + "score_updated": True, + "stream_reward": FAKE_CRN_GPU_ADDRESS, + "inactive_since": None, + "decentralization": 0.852680607762069, + "registration_url": "", + "terms_and_conditions": "", + "config_from_crn": True, + "debug_config_from_crn_at": "2025-06-18T12:09:03.843059+00:00", + "debug_config_from_crn_error": "None", + "debug_usage_from_crn_at": "2025-06-18T12:09:03.843059+00:00", + "usage_from_crn_error": "None", + "version": "1.6.0-rc1", + "payment_receiver_address": FAKE_CRN_GPU_ADDRESS, + "gpu_support": True, + "confidential_support": False, + "qemu_support": True, + "system_usage": { + "cpu": { + "count": 20, + "load_average": { + "load1": 0.357421875, + "load5": 0.31982421875, + "load15": 0.34912109375, + }, + "core_frequencies": {"min": 800, "max": 4280}, + }, + "mem": {"total_kB": 67219530, "available_kB": 61972037}, + "disk": {"total_kB": 1853812338, "available_kB": 1320664518}, + "period": { + "start_timestamp": "2025-06-18T12:09:00Z", + "duration_seconds": 60, + }, + "properties": { + "cpu": { + "architecture": "x86_64", + "vendor": "GenuineIntel", + "features": [], + } + 
}, + "gpu": { + "devices": [ + { + "vendor": "NVIDIA", + "model": "RTX 4000 ADA", + "device_name": "AD104GL [RTX 4000 SFF Ada Generation]", + "device_class": "0300", + "pci_host": "01:00.0", + "device_id": "10de:27b0", + "compatible": True, + } + ], + "available_devices": [ + { + "vendor": "NVIDIA", + "model": "RTX 4000 ADA", + "device_name": "AD104GL [RTX 4000 SFF Ada Generation]", + "device_class": "0300", + "pci_host": "01:00.0", + "device_id": "10de:27b0", + "compatible": True, + } + ], + }, + "active": True, + }, + "compatible_gpus": [ + { + "vendor": "NVIDIA", + "model": "RTX 4000 ADA", + "device_name": "AD104GL [RTX 4000 SFF Ada Generation]", + "device_class": "0300", + "pci_host": "01:00.0", + "device_id": "10de:27b0", + "compatible": True, + } + ], + "compatible_available_gpus": [ + { + "vendor": "NVIDIA", + "model": "RTX 4000 ADA", + "device_name": "AD104GL [RTX 4000 SFF Ada Generation]", + "device_class": "0300", + "pci_host": "01:00.0", + "device_id": "10de:27b0", + "compatible": True, + } + ], + "ipv6_check": {"host": True, "vm": True}, + }, + { + "hash": FAKE_CRN_CONF_HASH, + "name": "Test Conf CRN", + "time": 1739296606.021, + "type": "compute", + "owner": FAKE_CRN_CONF_ADDRESS, + "score": 0.964334395009276, + "banner": "", + "locked": False, + "parent": FAKE_CRN_CONF_HASH, + "reward": FAKE_CRN_CONF_ADDRESS, + "status": "linked", + "address": FAKE_CRN_CONF_URL, + "manager": "", + "picture": "", + "authorized": "", + "description": "", + "performance": 0, + "multiaddress": "", + "score_updated": False, + "stream_reward": FAKE_CRN_CONF_ADDRESS, + "inactive_since": None, + "decentralization": 0.994724704221032, + "registration_url": "", + "terms_and_conditions": "", + "config_from_crn": False, + "debug_config_from_crn_at": "2025-06-18T12:09:03.951298+00:00", + "debug_config_from_crn_error": "None", + "debug_usage_from_crn_at": "2025-06-18T12:09:03.951298+00:00", + "usage_from_crn_error": "None", + "version": "1.5.1", + "payment_receiver_address": FAKE_CRN_CONF_ADDRESS, + "gpu_support": False, + "confidential_support": True, + "qemu_support": True, + "system_usage": { + "cpu": { + "count": 224, + "load_average": { + "load1": 3.8466796875, + "load5": 3.9228515625, + "load15": 3.82080078125, + }, + "core_frequencies": {"min": 1500, "max": 2200}, + }, + "mem": {"total_kB": 807728145, "available_kB": 630166945}, + "disk": {"total_kB": 14971880235, "available_kB": 152975388}, + "period": { + "start_timestamp": "2025-06-18T12:09:00Z", + "duration_seconds": 60, + }, + "properties": { + "cpu": { + "architecture": "x86_64", + "vendor": "AuthenticAMD", + "features": ["sev", "sev_es"], + } + }, + "gpu": {"devices": [], "available_devices": []}, + "active": True, + }, + "compatible_gpus": [], + "compatible_available_gpus": [], + "ipv6_check": {"host": True, "vm": True}, + }, + { + "hash": FAKE_CRN_BASIC_HASH, + "name": "Test Basic CRN", + "time": 1687179700.242, + "type": "compute", + "owner": FAKE_CRN_BASIC_ADDRESS, + "score": 0.979808976368904, + "banner": FAKE_CRN_BASIC_HASH, + "locked": False, + "parent": FAKE_CRN_BASIC_HASH, + "reward": FAKE_CRN_BASIC_ADDRESS, + "status": "linked", + "address": FAKE_CRN_BASIC_URL, + "manager": FAKE_CRN_BASIC_ADDRESS, + "picture": FAKE_CRN_BASIC_HASH, + "authorized": "", + "description": "", + "performance": 0, + "multiaddress": "", + "score_updated": True, + "stream_reward": FAKE_CRN_BASIC_ADDRESS, + "inactive_since": None, + "decentralization": 0.93953628188216, + "registration_url": "", + "terms_and_conditions": "", + "config_from_crn": True, + 
"debug_config_from_crn_at": "2025-06-18T12:08:59.599676+00:00", + "debug_config_from_crn_error": "None", + "debug_usage_from_crn_at": "2025-06-18T12:08:59.599676+00:00", + "usage_from_crn_error": "None", + "version": "1.5.1", + "payment_receiver_address": FAKE_CRN_BASIC_ADDRESS, + "gpu_support": False, + "confidential_support": False, + "qemu_support": True, + "system_usage": { + "cpu": { + "count": 32, + "load_average": {"load1": 0, "load5": 0.01513671875, "load15": 0}, + "core_frequencies": {"min": 1200, "max": 3400}, + }, + "mem": {"total_kB": 270358832, "available_kB": 266152607}, + "disk": {"total_kB": 1005067972, "available_kB": 919488466}, + "period": { + "start_timestamp": "2025-06-18T12:09:00Z", + "duration_seconds": 60, + }, + "properties": { + "cpu": { + "architecture": "x86_64", + "vendor": "GenuineIntel", + "features": [], + } + }, + "gpu": {"devices": [], "available_devices": []}, + "active": True, + }, + "compatible_gpus": [], + "compatible_available_gpus": [], + "ipv6_check": {"host": True, "vm": False}, + }, + ] + + +def make_mock_aiohttp_session(mocked_json_response): + mock_response = AsyncMock() + mock_response.json.return_value = mocked_json_response + mock_response.raise_for_status.return_value = None + + session = MagicMock() + + session_cm = AsyncMock() + session_cm.__aenter__.return_value = session + + get_cm = AsyncMock() + get_cm.__aenter__.return_value = mock_response + + post_cm = AsyncMock() + post_cm.__aenter__.return_value = mock_response + + session.get = MagicMock(return_value=get_cm) + session.post = MagicMock(return_value=post_cm) + + return session_cm + + +def make_mock_get_active_vms_parametrized(v2_fails, expected_payload): + session = MagicMock() + + def get(url, *args, **kwargs): + mock_resp = None + if "/v2/about/executions/list" in url and v2_fails: + mock_resp = make_custom_mock_response(expected_payload, 404) + else: + mock_resp = make_custom_mock_response(expected_payload) + + mock_ctx = AsyncMock() + mock_ctx.__aenter__.return_value = mock_resp + return mock_ctx + + def post(url, *args, **kwargs): + if "/update" in url: + return make_custom_mock_response( + {"status": "ok", "msg": "VM not starting yet"}, 200 + ) + return None + + session.get = MagicMock(side_effect=get) + + session.post = MagicMock(side_effect=post) + + session_cm = AsyncMock() + session_cm.__aenter__.return_value = session + + return session_cm diff --git a/tests/unit/services/test_base_service.py b/tests/unit/services/test_base_service.py new file mode 100644 index 00000000..6c07dd50 --- /dev/null +++ b/tests/unit/services/test_base_service.py @@ -0,0 +1,46 @@ +from typing import Optional +from unittest.mock import AsyncMock + +import pytest +from pydantic import BaseModel + +from aleph.sdk.client.services.base import AggregateConfig, BaseService + + +class DummyModel(BaseModel): + foo: str + bar: Optional[int] + + +class DummyService(BaseService[DummyModel]): + aggregate_key = "dummy_key" + model_cls = DummyModel + + +@pytest.mark.asyncio +async def test_get_config_with_data(): + mock_client = AsyncMock() + mock_data = {"foo": "hello", "bar": 123} + mock_client.fetch_aggregate.return_value = mock_data + + service = DummyService(mock_client) + + result = await service.get_config("0xSOME_ADDRESS") + + assert isinstance(result, AggregateConfig) + assert result.data is not None + assert isinstance(result.data[0], DummyModel) + assert result.data[0].foo == "hello" + assert result.data[0].bar == 123 + + +@pytest.mark.asyncio +async def test_get_config_with_no_data(): + 
mock_client = AsyncMock() + mock_client.fetch_aggregate.return_value = None + + service = DummyService(mock_client) + result = await service.get_config("0xSOME_ADDRESS") + + assert isinstance(result, AggregateConfig) + assert result.data is None diff --git a/tests/unit/test_services.py b/tests/unit/test_services.py new file mode 100644 index 00000000..762fceea --- /dev/null +++ b/tests/unit/test_services.py @@ -0,0 +1,445 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import aiohttp +import pytest + +from aleph.sdk import AlephHttpClient, AuthenticatedAlephHttpClient +from aleph.sdk.client.services.authenticated_port_forwarder import ( + AuthenticatedPortForwarder, + PortForwarder, +) +from aleph.sdk.client.services.crn import Crn +from aleph.sdk.client.services.dns import DNS +from aleph.sdk.client.services.instance import Instance +from aleph.sdk.client.services.scheduler import Scheduler +from aleph.sdk.types import ( + IPV4, + AllocationItem, + Dns, + PortFlags, + Ports, + SchedulerNodes, + SchedulerPlan, +) + + +@pytest.mark.asyncio +async def test_aleph_http_client_services_loading(): + """Test that services are properly loaded in AlephHttpClient's __aenter__""" + with patch("aiohttp.ClientSession") as mock_session: + mock_session_instance = AsyncMock() + mock_session.return_value = mock_session_instance + + client = AlephHttpClient(api_server="http://localhost") + + async def mocked_aenter(): + client._http_session = mock_session_instance + client.dns = DNS(client) + client.port_forwarder = PortForwarder(client) + client.crn = Crn(client) + client.scheduler = Scheduler(client) + client.instance = Instance(client) + return client + + with patch.object(client, "__aenter__", mocked_aenter), patch.object( + client, "__aexit__", AsyncMock() + ): + async with client: + assert isinstance(client.dns, DNS) + assert isinstance(client.port_forwarder, PortForwarder) + assert isinstance(client.crn, Crn) + assert isinstance(client.scheduler, Scheduler) + assert isinstance(client.instance, Instance) + + assert client.dns._client == client + assert client.port_forwarder._client == client + assert client.crn._client == client + assert client.scheduler._client == client + assert client.instance._client == client + + +@pytest.mark.asyncio +async def test_authenticated_http_client_services_loading(ethereum_account): + """Test that authenticated services are properly loaded in AuthenticatedAlephHttpClient's __aenter__""" + with patch("aiohttp.ClientSession") as mock_session: + mock_session_instance = AsyncMock() + mock_session.return_value = mock_session_instance + + client = AuthenticatedAlephHttpClient( + account=ethereum_account, api_server="http://localhost" + ) + + async def mocked_aenter(): + client._http_session = mock_session_instance + client.dns = DNS(client) + client.port_forwarder = AuthenticatedPortForwarder(client) + client.crn = Crn(client) + client.scheduler = Scheduler(client) + client.instance = Instance(client) + return client + + with patch.object(client, "__aenter__", mocked_aenter), patch.object( + client, "__aexit__", AsyncMock() + ): + async with client: + assert isinstance(client.dns, DNS) + assert isinstance(client.port_forwarder, AuthenticatedPortForwarder) + assert isinstance(client.crn, Crn) + assert isinstance(client.scheduler, Scheduler) + assert isinstance(client.instance, Instance) + + assert client.dns._client == client + assert client.port_forwarder._client == client + assert client.crn._client == client + assert client.scheduler._client == client + 
assert client.instance._client == client + + +def mock_aiohttp_session(response_data, raise_error=False, error_status=404): + """ + Creates a mock for aiohttp.ClientSession that properly handles async context managers. + + Args: + response_data: The data to return from the response's json() method + raise_error: Whether to raise an aiohttp.ClientResponseError + error_status: The HTTP status code to use if raising an error + + Returns: + A tuple of (patch_target, mock_session_context, mock_session, mock_response) + """ + # Mock the response object + mock_response = MagicMock() + + if raise_error: + # Set up raise_for_status to raise an exception + error = aiohttp.ClientResponseError( + request_info=MagicMock(), + history=tuple(), + status=error_status, + message="Not Found" if error_status == 404 else "Error", + ) + mock_response.raise_for_status = MagicMock(side_effect=error) + else: + # Normal case - just return the data + mock_response.raise_for_status = MagicMock() + mock_response.json = AsyncMock(return_value=response_data) + + # Mock the context manager for session.get + mock_context_manager = MagicMock() + mock_context_manager.__aenter__ = AsyncMock(return_value=mock_response) + mock_context_manager.__aexit__ = AsyncMock(return_value=None) + + # Mock the session's get method to return our context manager + mock_session = MagicMock() + mock_session.get = MagicMock(return_value=mock_context_manager) + mock_session.post = MagicMock(return_value=mock_context_manager) + + # Mock the ClientSession context manager + mock_session_context = MagicMock() + mock_session_context.__aenter__ = AsyncMock(return_value=mock_session) + mock_session_context.__aexit__ = AsyncMock(return_value=None) + + return "aiohttp.ClientSession", mock_session_context, mock_session, mock_response + + +@pytest.mark.asyncio +async def test_authenticated_port_forwarder_create_port_forward(ethereum_account): + """Test the create_ports method in AuthenticatedPortForwarder""" + mock_client = MagicMock() + mock_client.http_session = AsyncMock() + mock_client.account = ethereum_account + + auth_port_forwarder = AuthenticatedPortForwarder(mock_client) + + ports = Ports(ports={80: PortFlags(tcp=True, udp=False)}) + + mock_message = MagicMock() + mock_status = MagicMock() + + # Setup the mock for create_aggregate + mock_client.create_aggregate = AsyncMock(return_value=(mock_message, mock_status)) + + # Mock the _verify_status_processed_and_ownership method + with patch.object( + auth_port_forwarder, + "_verify_status_processed_and_ownership", + AsyncMock(return_value=(mock_message, mock_status)), + ): + # Call the actual method + result_message, result_status = await auth_port_forwarder.create_ports( + item_hash="test_hash", ports=ports + ) + + # Verify create_aggregate was called + mock_client.create_aggregate.assert_called_once() + + # Check the parameters passed to create_aggregate + call_args = mock_client.create_aggregate.call_args + assert call_args[1]["key"] == "port-forwarding" + assert "test_hash" in call_args[1]["content"] + + # Verify the method returns what create_aggregate returns + assert result_message == mock_message + assert result_status == mock_status + + +@pytest.mark.asyncio +async def test_authenticated_port_forwarder_update_port(ethereum_account): + """Test the update_ports method in AuthenticatedPortForwarder""" + mock_client = MagicMock() + mock_client.http_session = AsyncMock() + mock_client.account = ethereum_account + + auth_port_forwarder = AuthenticatedPortForwarder(mock_client) + + ports =
Ports(ports={80: PortFlags(tcp=True, udp=False)}) + + mock_message = MagicMock() + mock_status = MagicMock() + + # Setup the mock for create_aggregate + mock_client.create_aggregate = AsyncMock(return_value=(mock_message, mock_status)) + + # Mock the _verify_status_processed_and_ownership method + with patch.object( + auth_port_forwarder, + "_verify_status_processed_and_ownership", + AsyncMock(return_value=(mock_message, mock_status)), + ): + # Call the actual method + result_message, result_status = await auth_port_forwarder.update_ports( + item_hash="test_hash", ports=ports + ) + + # Verify create_aggregate was called + mock_client.create_aggregate.assert_called_once() + + # Check the parameters passed to create_aggregate + call_args = mock_client.create_aggregate.call_args + assert call_args[1]["key"] == "port-forwarding" + assert "test_hash" in call_args[1]["content"] + + # Verify the method returns what create_aggregate returns + assert result_message == mock_message + assert result_status == mock_status + + +@pytest.mark.asyncio +async def test_dns_service_get_public_dns(): + """Test the DNS service get_public_dns method""" + mock_client = MagicMock() + dns_service = DNS(mock_client) + + # Mock the DnsListAdapter with a valid 64-character hash for ItemHash + mock_dns_list = [ + Dns( + name="test.aleph.sh", + item_hash="b236db23bf5ad005ad7f5d82eed08a68a925020f0755b2a59c03f784499198eb", + ipv6="2001:db8::1", + ipv4=IPV4(public="192.0.2.1", local="10.0.0.1"), + ) + ] + + # Patch DnsListAdapter.validate_python to return our mock DNS list + with patch( + "aleph.sdk.types.DnsListAdapter.validate_python", return_value=mock_dns_list + ): + # Set up mock for aiohttp.ClientSession (the payload is irrelevant + # here since the adapter is patched) + patch_target, mock_session_context, _, _ = mock_aiohttp_session( + '["dummy json string"]' + ) + + # Patch the ClientSession constructor + with patch(patch_target, return_value=mock_session_context): + result = await dns_service.get_public_dns() + + assert len(result) == 1 + assert result[0].name == "test.aleph.sh" + assert ( + result[0].item_hash + == "b236db23bf5ad005ad7f5d82eed08a68a925020f0755b2a59c03f784499198eb" + ) + assert result[0].ipv6 == "2001:db8::1" + assert result[0].ipv4 is not None and result[0].ipv4.public == "192.0.2.1" + + +@pytest.mark.asyncio +async def test_crn_service_get_last_crn_version(): + """Test the Crn service get_last_crn_version method""" + mock_client = MagicMock() + crn_service = Crn(mock_client) + + # Set up mock for aiohttp.ClientSession + patch_target, mock_session_context, _, _ = mock_aiohttp_session( + {"tag_name": "v1.2.3"} + ) + + # Patch the ClientSession constructor + with patch(patch_target, return_value=mock_session_context): + result = await crn_service.get_last_crn_version() + assert result == "v1.2.3" + + +@pytest.mark.asyncio +async def test_scheduler_service_get_plan(): + """Test the Scheduler service get_plan method""" + mock_client = MagicMock() + scheduler_service = Scheduler(mock_client) + + mock_plan_data = { + "period": {"start_timestamp": "2023-01-01T00:00:00Z", "duration_seconds": 3600}, + "plan": { + "node1": { + "persistent_vms": [ + "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", + "fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210", + ], + "instances": [], + "on_demand_vms": [], + "jobs": [], + } + }, + } + + # Set up mock for aiohttp.ClientSession + patch_target, mock_session_context, _, _ = mock_aiohttp_session(mock_plan_data) + + # Patch the ClientSession constructor
+ with patch(patch_target, return_value=mock_session_context): + result = await scheduler_service.get_plan() + assert isinstance(result, SchedulerPlan) + assert "node1" in result.plan + assert ( + "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" + in result.plan["node1"].persistent_vms + ) + + +@pytest.mark.asyncio +async def test_scheduler_service_get_scheduler_node(): + """Test the Scheduler service get_nodes method""" + mock_client = MagicMock() + scheduler_service = Scheduler(mock_client) + + mock_nodes_data = { + "nodes": [ + { + "node_id": "node1", + "url": "https://node1.aleph.im", + "ipv6": "2001:db8::1", + "supports_ipv6": True, + }, + { + "node_id": "node2", + "url": "https://node2.aleph.im", + "ipv6": None, + "supports_ipv6": False, + }, + ] + } + + # Set up mock for aiohttp.ClientSession + patch_target, mock_session_context, _, _ = mock_aiohttp_session(mock_nodes_data) + + # Patch the ClientSession constructor + with patch(patch_target, return_value=mock_session_context): + result = await scheduler_service.get_nodes() + assert isinstance(result, SchedulerNodes) + assert len(result.nodes) == 2 + assert result.nodes[0].node_id == "node1" + assert result.nodes[1].url == "https://node2.aleph.im" + + +@pytest.mark.asyncio +async def test_scheduler_service_get_allocation(): + """Test the Scheduler service get_allocation method""" + mock_client = MagicMock() + scheduler_service = Scheduler(mock_client) + + mock_allocation_data = { + "vm_hash": "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", + "vm_type": "instance", + "vm_ipv6": "2001:db8::1", + "period": {"start_timestamp": "2023-01-01T00:00:00Z", "duration_seconds": 3600}, + "node": { + "node_id": "node1", + "url": "https://node1.aleph.im", + "ipv6": "2001:db8::1", + "supports_ipv6": True, + }, + } + + # Set up mock for aiohttp.ClientSession + patch_target, mock_session_context, _, _ = mock_aiohttp_session( + mock_allocation_data + ) + + # Patch the ClientSession constructor + with patch(patch_target, return_value=mock_session_context): + result = await scheduler_service.get_allocation( + "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" + ) + assert isinstance(result, AllocationItem) + assert ( + result.vm_hash + == "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" + ) + assert result.node.node_id == "node1" + + +@pytest.mark.asyncio +async def test_utils_service_get_name_of_executable(): + """Test the Instance service get_name_of_executable method""" + mock_client = MagicMock() + utils_service = Instance(mock_client) + + # Mock a message with metadata.name + mock_message = MagicMock() + mock_message.content.metadata = {"name": "test-executable"} + + # Set up the client mock to return the message + mock_client.get_message = AsyncMock(return_value=mock_message) + + # Test successful case + result = await utils_service.get_name_of_executable("hash1") + assert result == "test-executable" + + # Test with dict response + mock_client.get_message = AsyncMock( + return_value={"content": {"metadata": {"name": "dict-executable"}}} + ) + + result = await utils_service.get_name_of_executable("hash2") + assert result == "dict-executable" + + # Test with exception + mock_client.get_message = AsyncMock(side_effect=Exception("Test exception")) + + result = await utils_service.get_name_of_executable("hash3") + assert result is None + + +@pytest.mark.asyncio +async def test_utils_service_get_instances(): + """Test the Instance service get_instances method""" + mock_client =
MagicMock() + utils_service = Instance(mock_client) + + # Mock messages response + mock_messages = [MagicMock(), MagicMock()] + mock_response = MagicMock() + mock_response.messages = mock_messages + + # Set up the client mock + mock_client.get_messages = AsyncMock(return_value=mock_response) + + result = await utils_service.get_instances("0xaddress") + + # Check that get_messages was called with correct parameters + mock_client.get_messages.assert_called_once() + call_args = mock_client.get_messages.call_args[1] + assert call_args["page_size"] == 100 + assert call_args["message_filter"].addresses == ["0xaddress"] + + # Check result + assert result == mock_messages
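The `mock_aiohttp_session` helper above generalizes to the other CRN calls; a sketch of a test for `get_crns_list` built the same way (it assumes it lives in the same module as the helper, and the payload is illustrative):

```python
from unittest.mock import MagicMock, patch

import pytest

from aleph.sdk.client.services.crn import Crn

@pytest.mark.asyncio
async def test_crn_service_get_crns_list_sketch():
    crn_service = Crn(MagicMock())
    patch_target, mock_session_context, mock_session, _ = mock_aiohttp_session(
        {"crns": []}
    )
    with patch(patch_target, return_value=mock_session_context):
        result = await crn_service.get_crns_list(only_active=False)
        assert result == {"crns": []}
        # only_active=False must translate to filter_inactive="true"
        _, kwargs = mock_session.get.call_args
        assert kwargs["params"] == {"filter_inactive": "true"}
```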