diff --git a/docs/channel.rst b/docs/channel.rst index a629c85..88c2f36 100644 --- a/docs/channel.rst +++ b/docs/channel.rst @@ -247,7 +247,7 @@ :param callable callback: Callback to be called with the result of the query. - Tis function does the same as :py:meth:`query` but it will honor the ``domain`` and ``search`` directives in + This function does the same as :py:meth:`query` but it will honor the ``domain`` and ``search`` directives in ``resolv.conf``. .. py:method:: cancel() @@ -293,6 +293,18 @@ Process the given file descriptors for read and/or write events. + + .. py:method:: process_read_fd(read_fd) + + :param int read_fd: File descriptor ready to read from. + + Processes the given file file descriptors for read events + + .. py:method:: process_write_fd(write_fd) + :param int write_fd: File descriptor ready to write to. + + Processes the given file file descriptors for write events + .. py:method:: getsock() Return a tuple containing 2 lists with the file descriptors ready to read and write. @@ -305,15 +317,15 @@ If the ``max_timeout`` parameter is specified, it is stored on the channel and the appropriate value is then returned. - .. py:method:: set_local_ip(local_ip) + .. py:method:: set_local_ip(ip) - :param str local_ip: IP address. + :param str ip: IP address. Set the local IPv4 or IPv6 address from which the queries will be sent. - .. py:method:: set_local_dev(local_dev) + .. py:method:: set_local_dev(dev) - :param str local_dev: Network device name. + :param str dev: Network device name. Set the local ethernet device from which the queries will be sent. diff --git a/src/pycares/__init__.py b/src/pycares/__init__.py index dc4177d..7ed3bae 100644 --- a/src/pycares/__init__.py +++ b/src/pycares/__init__.py @@ -1,9 +1,9 @@ - from ._cares import ffi as _ffi, lib as _lib import _cffi_backend # hint for bundler tools + if _lib.ARES_SUCCESS != _lib.ares_library_init(_lib.ARES_LIB_INIT_ALL) or _ffi is None: - raise RuntimeError('Could not initialize c-ares') + raise RuntimeError("Could not initialize c-ares") from . import errno from .utils import ascii_bytes, maybe_str, parse_name @@ -13,9 +13,10 @@ import socket import threading from collections.abc import Callable, Iterable -from typing import Any, Callable, Final, Optional, Dict, Union +from typing import Any, Callable, Final, Optional, Dict, Union, Literal, overload from queue import SimpleQueue + IP4 = tuple[str, int] IP6 = tuple[str, int, int, int] @@ -49,28 +50,28 @@ ARES_NI_IDN_USE_STD3_ASCII_RULES = _lib.ARES_NI_IDN_USE_STD3_ASCII_RULES # Bad socket -ARES_SOCKET_BAD = _lib.ARES_SOCKET_BAD +ARES_SOCKET_BAD: int = _lib.ARES_SOCKET_BAD # Query types -QUERY_TYPE_A = _lib.T_A -QUERY_TYPE_AAAA = _lib.T_AAAA -QUERY_TYPE_ANY = _lib.T_ANY -QUERY_TYPE_CAA = _lib.T_CAA -QUERY_TYPE_CNAME = _lib.T_CNAME -QUERY_TYPE_MX = _lib.T_MX -QUERY_TYPE_NAPTR = _lib.T_NAPTR -QUERY_TYPE_NS = _lib.T_NS -QUERY_TYPE_PTR = _lib.T_PTR -QUERY_TYPE_SOA = _lib.T_SOA -QUERY_TYPE_SRV = _lib.T_SRV -QUERY_TYPE_TXT = _lib.T_TXT +QUERY_TYPE_A: Literal[1] = _lib.T_A +QUERY_TYPE_AAAA: Literal[28] = _lib.T_AAAA +QUERY_TYPE_ANY: Literal[255] = _lib.T_ANY +QUERY_TYPE_CAA: Literal[257] = _lib.T_CAA +QUERY_TYPE_CNAME: Literal[5] = _lib.T_CNAME +QUERY_TYPE_MX: Literal[15] = _lib.T_MX +QUERY_TYPE_NAPTR: Literal[35] = _lib.T_NAPTR +QUERY_TYPE_NS: Literal[2] = _lib.T_NS +QUERY_TYPE_PTR: Literal[12] = _lib.T_PTR +QUERY_TYPE_SOA: Literal[6] = _lib.T_SOA +QUERY_TYPE_SRV: Literal[33] = _lib.T_SRV +QUERY_TYPE_TXT: Literal[16] = _lib.T_TXT # Query classes -QUERY_CLASS_IN = _lib.C_IN -QUERY_CLASS_CHAOS = _lib.C_CHAOS -QUERY_CLASS_HS = _lib.C_HS -QUERY_CLASS_NONE = _lib.C_NONE -QUERY_CLASS_ANY = _lib.C_ANY +QUERY_CLASS_IN: int = _lib.C_IN +QUERY_CLASS_CHAOS: int = _lib.C_CHAOS +QUERY_CLASS_HS: int = _lib.C_HS +QUERY_CLASS_NONE: int = _lib.C_NONE +QUERY_CLASS_ANY: int = _lib.C_ANY ARES_VERSION = maybe_str(_ffi.string(_lib.ares_version(_ffi.NULL))) PYCARES_ADDRTTL_SIZE = 256 @@ -82,7 +83,9 @@ class AresError(Exception): # callback helpers -_handle_to_channel: Dict[Any, "Channel"] = {} # Maps handle to channel to prevent use-after-free +_handle_to_channel: Dict[ + Any, "Channel" +] = {} # Maps handle to channel to prevent use-after-free @_ffi.def_extern() @@ -92,6 +95,7 @@ def _sock_state_cb(data, socket_fd, readable, writable): sock_state_cb = _ffi.from_handle(data) sock_state_cb(socket_fd, readable, writable) + @_ffi.def_extern() def _host_cb(arg, status, timeouts, hostent): # Get callback data without removing the reference yet @@ -109,6 +113,7 @@ def _host_cb(arg, status, timeouts, hostent): callback(result, status) _handle_to_channel.pop(arg, None) + @_ffi.def_extern() def _nameinfo_cb(arg, status, timeouts, node, service): # Get callback data without removing the reference yet @@ -126,6 +131,7 @@ def _nameinfo_cb(arg, status, timeouts, node, service): callback(result, status) _handle_to_channel.pop(arg, None) + @_ffi.def_extern() def _query_cb(arg, status, timeouts, abuf, alen): # Get callback data without removing the reference yet @@ -137,7 +143,19 @@ def _query_cb(arg, status, timeouts, abuf, alen): if status == _lib.ARES_SUCCESS: if query_type == _lib.T_ANY: result = [] - for qtype in (_lib.T_A, _lib.T_AAAA, _lib.T_CAA, _lib.T_CNAME, _lib.T_MX, _lib.T_NAPTR, _lib.T_NS, _lib.T_PTR, _lib.T_SOA, _lib.T_SRV, _lib.T_TXT): + for qtype in ( + _lib.T_A, + _lib.T_AAAA, + _lib.T_CAA, + _lib.T_CNAME, + _lib.T_MX, + _lib.T_NAPTR, + _lib.T_NS, + _lib.T_PTR, + _lib.T_SOA, + _lib.T_SRV, + _lib.T_TXT, + ): r, status = parse_result(qtype, abuf, alen) if status not in (None, _lib.ARES_ENODATA, _lib.ARES_EBADRESP): result = None @@ -157,6 +175,7 @@ def _query_cb(arg, status, timeouts, abuf, alen): callback(result, status) _handle_to_channel.pop(arg, None) + @_ffi.def_extern() def _addrinfo_cb(arg, status, timeouts, res): # Get callback data without removing the reference yet @@ -174,11 +193,14 @@ def _addrinfo_cb(arg, status, timeouts, res): callback(result, status) _handle_to_channel.pop(arg, None) + def parse_result(query_type, abuf, alen): if query_type == _lib.T_A: addrttls = _ffi.new("struct ares_addrttl[]", PYCARES_ADDRTTL_SIZE) naddrttls = _ffi.new("int*", PYCARES_ADDRTTL_SIZE) - parse_status = _lib.ares_parse_a_reply(abuf, alen, _ffi.NULL, addrttls, naddrttls) + parse_status = _lib.ares_parse_a_reply( + abuf, alen, _ffi.NULL, addrttls, naddrttls + ) if parse_status != _lib.ARES_SUCCESS: result = None status = parse_status @@ -188,7 +210,9 @@ def parse_result(query_type, abuf, alen): elif query_type == _lib.T_AAAA: addrttls = _ffi.new("struct ares_addr6ttl[]", PYCARES_ADDRTTL_SIZE) naddrttls = _ffi.new("int*", PYCARES_ADDRTTL_SIZE) - parse_status = _lib.ares_parse_aaaa_reply(abuf, alen, _ffi.NULL, addrttls, naddrttls) + parse_status = _lib.ares_parse_aaaa_reply( + abuf, alen, _ffi.NULL, addrttls, naddrttls + ) if parse_status != _lib.ARES_SUCCESS: result = None status = parse_status @@ -264,7 +288,9 @@ def parse_result(query_type, abuf, alen): status = None elif query_type == _lib.T_PTR: hostent = _ffi.new("struct hostent **") - parse_status = _lib.ares_parse_ptr_reply(abuf, alen, _ffi.NULL, 0, socket.AF_UNSPEC, hostent) + parse_status = _lib.ares_parse_ptr_reply( + abuf, alen, _ffi.NULL, 0, socket.AF_UNSPEC, hostent + ) if parse_status != _lib.ARES_SUCCESS: result = None status = parse_status @@ -365,7 +391,9 @@ def start(self) -> None: if self._thread is not None: # Started by another thread while waiting for the lock return - self._thread = threading.Thread(target=self._run_safe_shutdown_loop, daemon=True) + self._thread = threading.Thread( + target=self._run_safe_shutdown_loop, daemon=True + ) self._thread.start() def destroy_channel(self, channel) -> None: @@ -385,27 +413,105 @@ def destroy_channel(self, channel) -> None: class Channel: - __qtypes__ = (_lib.T_A, _lib.T_AAAA, _lib.T_ANY, _lib.T_CAA, _lib.T_CNAME, _lib.T_MX, _lib.T_NAPTR, _lib.T_NS, _lib.T_PTR, _lib.T_SOA, _lib.T_SRV, _lib.T_TXT) + """ + The c-ares ``Channel`` provides asynchronous DNS operations. + + The Channel object is designed to handle an unlimited number of DNS queries efficiently. + Creating and destroying resolver instances repeatedly is resource-intensive and not + recommended. Instead, create a single resolver instance and reuse it throughout your + application's lifetime. + """ + + __qtypes__ = ( + _lib.T_A, + _lib.T_AAAA, + _lib.T_ANY, + _lib.T_CAA, + _lib.T_CNAME, + _lib.T_MX, + _lib.T_NAPTR, + _lib.T_NS, + _lib.T_PTR, + _lib.T_SOA, + _lib.T_SRV, + _lib.T_TXT, + ) __qclasses__ = (_lib.C_IN, _lib.C_CHAOS, _lib.C_HS, _lib.C_NONE, _lib.C_ANY) - def __init__(self, - flags: Optional[int] = None, - timeout: Optional[float] = None, - tries: Optional[int] = None, - ndots: Optional[int] = None, - tcp_port: Optional[int] = None, - udp_port: Optional[int] = None, - servers: Optional[Iterable[Union[str, bytes]]] = None, - domains: Optional[Iterable[Union[str, bytes]]] = None, - lookups: Union[str, bytes, None] = None, - sock_state_cb: Optional[Callable[[int, bool, bool], None]] = None, - socket_send_buffer_size: Optional[int] = None, - socket_receive_buffer_size: Optional[int] = None, - rotate: bool = False, - local_ip: Union[str, bytes, None] = None, - local_dev: Optional[str] = None, - resolvconf_path: Union[str, bytes, None] = None, - event_thread: bool = False) -> None: + def __init__( + self, + flags: Optional[int] = None, + timeout: Optional[float] = None, + tries: Optional[int] = None, + ndots: Optional[int] = None, + tcp_port: Optional[int] = None, + udp_port: Optional[int] = None, + servers: Optional[Iterable[Union[str, bytes]]] = None, + domains: Optional[Iterable[Union[str, bytes]]] = None, + lookups: Union[str, bytes, None] = None, + sock_state_cb: Optional[Callable[[int, bool, bool], None]] = None, + socket_send_buffer_size: Optional[int] = None, + socket_receive_buffer_size: Optional[int] = None, + rotate: bool = False, + local_ip: Union[str, bytes, None] = None, + local_dev: Optional[str] = None, + resolvconf_path: Union[str, bytes, None] = None, + event_thread: bool = False, + ) -> None: + """ + Args: + flags: Flags controlling the behavior of the resolver. + See ``constants`` for available values. + + timeout: The number of seconds each name server is given to respond to + a query on the first try. The default is five seconds. + + tries: The number of tries the resolver will try contacting each name + server before giving up. The default is four tries. + + ndots: The number of dots which must be present in a domain name for it + to be queried for "as is" prior to querying for it with the default domain + extensions appended. The default value is 1 unless set otherwise by resolv.conf + or the RES_OPTIONS environment variable. + + tcp_port: The (TCP) port to use for queries. The default is 53. + + udp_port: The (UDP) port to use for queries. The default is 53. + + servers: List of nameservers to be used to do the lookups. + + domains: The domains to search, instead of the domains specified + in resolv.conf or the domain derived from the kernel hostname variable. + + lookups: The lookups to perform for host queries. lookups should + be set to a string of the characters "b" or "f", where "b" indicates a \ + DNS lookup and "f" indicates a lookup in the hosts file. + + sock_state_cb: A callback function to be invoked when a + socket changes state. Callback signature: ``sock_state_cb(self, fd, readable, writable)`` + + This option is mutually exclusive with the ``event_thread`` option. + + event_thread: If set to True, c-ares will use its own thread + to process events. This is the recommended way to use c-ares, as it + allows for automatic reinitialization of the channel when the + system resolver configuration changes. Verify that c-ares was + compiled with thread-safety by calling `ares_threadsafety` + before using this option. This option is mutually exclusive with the + ``sock_state_cb`` option. + + socket_send_buffer_size: Size for the created socket's send buffer. + + socket_receive_buffer_size: Size for the created socket's receive buffer. + + rotate: If set to True, the nameservers are rotated when doing queries. + + local_ip: Sets the local IP address for DNS operations. + + local_dev: Sets the local network adapter to use for DNS operations. Linux only. + + resolvconf_path: Path to resolv.conf, defaults to /etc/resolv.conf. Unix only. + """ # Initialize _channel to None first to ensure __del__ doesn't fail self._channel = None @@ -424,33 +530,35 @@ def __init__(self, if tries is not None: options.tries = tries - optmask = optmask | _lib.ARES_OPT_TRIES + optmask = optmask | _lib.ARES_OPT_TRIES if ndots is not None: options.ndots = ndots - optmask = optmask | _lib.ARES_OPT_NDOTS + optmask = optmask | _lib.ARES_OPT_NDOTS if tcp_port is not None: options.tcp_port = tcp_port - optmask = optmask | _lib.ARES_OPT_TCP_PORT + optmask = optmask | _lib.ARES_OPT_TCP_PORT if udp_port is not None: options.udp_port = udp_port - optmask = optmask | _lib.ARES_OPT_UDP_PORT + optmask = optmask | _lib.ARES_OPT_UDP_PORT if socket_send_buffer_size is not None: options.socket_send_buffer_size = socket_send_buffer_size - optmask = optmask | _lib.ARES_OPT_SOCK_SNDBUF + optmask = optmask | _lib.ARES_OPT_SOCK_SNDBUF if socket_receive_buffer_size is not None: options.socket_receive_buffer_size = socket_receive_buffer_size - optmask = optmask | _lib.ARES_OPT_SOCK_RCVBUF + optmask = optmask | _lib.ARES_OPT_SOCK_RCVBUF if sock_state_cb: if not callable(sock_state_cb): raise TypeError("sock_state_cb is not callable") if event_thread: - raise RuntimeError("sock_state_cb and event_thread cannot be used together") + raise RuntimeError( + "sock_state_cb and event_thread cannot be used together" + ) userdata = _ffi.new_handle(sock_state_cb) @@ -459,40 +567,42 @@ def __init__(self, options.sock_state_cb = _lib._sock_state_cb options.sock_state_cb_data = userdata - optmask = optmask | _lib.ARES_OPT_SOCK_STATE_CB + optmask = optmask | _lib.ARES_OPT_SOCK_STATE_CB if event_thread: if not ares_threadsafety(): raise RuntimeError("c-ares is not built with thread safety") if sock_state_cb: - raise RuntimeError("sock_state_cb and event_thread cannot be used together") - optmask = optmask | _lib.ARES_OPT_EVENT_THREAD + raise RuntimeError( + "sock_state_cb and event_thread cannot be used together" + ) + optmask = optmask | _lib.ARES_OPT_EVENT_THREAD options.evsys = _lib.ARES_EVSYS_DEFAULT if lookups: - options.lookups = _ffi.new('char[]', ascii_bytes(lookups)) - optmask = optmask | _lib.ARES_OPT_LOOKUPS + options.lookups = _ffi.new("char[]", ascii_bytes(lookups)) + optmask = optmask | _lib.ARES_OPT_LOOKUPS if domains: strs = [_ffi.new("char[]", ascii_bytes(i)) for i in domains] c = _ffi.new("char *[%d]" % (len(domains) + 1)) for i in range(len(domains)): - c[i] = strs[i] + c[i] = strs[i] options.domains = c options.ndomains = len(domains) - optmask = optmask | _lib.ARES_OPT_DOMAINS + optmask = optmask | _lib.ARES_OPT_DOMAINS if rotate: - optmask = optmask | _lib.ARES_OPT_ROTATE + optmask = optmask | _lib.ARES_OPT_ROTATE if resolvconf_path is not None: - optmask = optmask | _lib.ARES_OPT_RESOLVCONF - options.resolvconf_path = _ffi.new('char[]', ascii_bytes(resolvconf_path)) + optmask = optmask | _lib.ARES_OPT_RESOLVCONF + options.resolvconf_path = _ffi.new("char[]", ascii_bytes(resolvconf_path)) r = _lib.ares_init_options(channel, options, optmask) if r != _lib.ARES_SUCCESS: - raise AresError('Failed to initialize c-ares channel') + raise AresError("Failed to initialize c-ares channel") # Initialize all attributes for consistency self._event_thread = event_thread @@ -513,7 +623,9 @@ def __del__(self) -> None: """Ensure the channel is destroyed when the object is deleted.""" self.close() - def _create_callback_handle(self, callback_data): + def _create_callback_handle( + self, callback_data: Union[Callable[..., None], tuple[Any, ...]] + ): """ Create a callback handle and register it for tracking. @@ -539,15 +651,33 @@ def _create_callback_handle(self, callback_data): return userdata def cancel(self) -> None: + """Cancel any pending query on this channel. All pending callbacks will be called with ARES_ECANCELLED errorno.""" _lib.ares_cancel(self._channel[0]) def reinit(self) -> None: + """ + Reinitialize the channel. + + For more details, see the `ares_reinit documentation `_. + + Raises: + AresError: If ``ares_reinit`` was unsuccessful + """ r = _lib.ares_reinit(self._channel[0]) if r != _lib.ARES_SUCCESS: raise AresError(r, errno.strerror(r)) @property def servers(self) -> list[str]: + """ + Obtains a list of current servers being used by this channel + Raises: + AresError: if C function `ares_get_servers `_ + was unsuccessful + + ValueError: When Setting new servers for this property if an + invalid IPV4 or IPV6 Address was given + """ servers = _ffi.new("struct ares_addr_node **") r = _lib.ares_get_servers(self._channel[0], servers) @@ -562,7 +692,9 @@ def servers(self) -> list[str]: ip = _ffi.new("char []", _lib.INET6_ADDRSTRLEN) s = server[0] - if _ffi.NULL != _lib.ares_inet_ntop(s.family, _ffi.addressof(s.addr), ip, _lib.INET6_ADDRSTRLEN): + if _ffi.NULL != _lib.ares_inet_ntop( + s.family, _ffi.addressof(s.addr), ip, _lib.INET6_ADDRSTRLEN + ): server_list.append(maybe_str(_ffi.string(ip, _lib.INET6_ADDRSTRLEN))) server = s.next @@ -573,9 +705,21 @@ def servers(self) -> list[str]: def servers(self, servers: Iterable[Union[str, bytes]]) -> None: c = _ffi.new("struct ares_addr_node[%d]" % len(servers)) for i, server in enumerate(servers): - if _lib.ares_inet_pton(socket.AF_INET, ascii_bytes(server), _ffi.addressof(c[i].addr.addr4)) == 1: + if ( + _lib.ares_inet_pton( + socket.AF_INET, ascii_bytes(server), _ffi.addressof(c[i].addr.addr4) + ) + == 1 + ): c[i].family = socket.AF_INET - elif _lib.ares_inet_pton(socket.AF_INET6, ascii_bytes(server), _ffi.addressof(c[i].addr.addr6)) == 1: + elif ( + _lib.ares_inet_pton( + socket.AF_INET6, + ascii_bytes(server), + _ffi.addressof(c[i].addr.addr6), + ) + == 1 + ): c[i].family = socket.AF_INET6 else: raise ValueError("invalid IP address") @@ -587,7 +731,11 @@ def servers(self, servers: Iterable[Union[str, bytes]]) -> None: if r != _lib.ARES_SUCCESS: raise AresError(r, errno.strerror(r)) - def getsock(self): + def getsock(self) -> tuple[list[int], list[int]]: + """ + Return a tuple containing 2 lists with the file descriptors + ready to read and write. + """ rfds = [] wfds = [] socks = _ffi.new("ares_socket_t [%d]" % _lib.ARES_GETSOCK_MAXNUM) @@ -601,23 +749,57 @@ def getsock(self): return rfds, wfds def process_fd(self, read_fd: int, write_fd: int) -> None: - _lib.ares_process_fd(self._channel[0], _ffi.cast("ares_socket_t", read_fd), _ffi.cast("ares_socket_t", write_fd)) + """Process the given file descriptors for read and/or write events. + Args: + read_fd: File descriptor ready to read from. - def process_read_fd(self, read_fd:int) -> None: - _lib.ares_process_fd(self._channel[0], _ffi.cast("ares_socket_t", read_fd), _ffi.cast("ares_socket_t", ARES_SOCKET_BAD)) + write_fd: File descriptor ready to write to. + """ + _lib.ares_process_fd( + self._channel[0], + _ffi.cast("ares_socket_t", read_fd), + _ffi.cast("ares_socket_t", write_fd), + ) + + def process_read_fd(self, read_fd: int) -> None: + """Processes the given file file descriptors for read events + Args: + read_fd: File descriptor ready to read from. + """ + _lib.ares_process_fd( + self._channel[0], + _ffi.cast("ares_socket_t", read_fd), + _ffi.cast("ares_socket_t", ARES_SOCKET_BAD), + ) + + def process_write_fd(self, write_fd: int) -> None: + """Processes the given file file descriptors for write events + Args: + write_fd: File descriptor ready to write to. + """ + _lib.ares_process_fd( + self._channel[0], + _ffi.cast("ares_socket_t", ARES_SOCKET_BAD), + _ffi.cast("ares_socket_t", write_fd), + ) - def process_write_fd(self, write_fd:int) -> None: - _lib.ares_process_fd(self._channel[0], _ffi.cast("ares_socket_t", ARES_SOCKET_BAD), _ffi.cast("ares_socket_t", write_fd)) + def timeout(self, max_timeout: Optional[float] = None) -> float: + """ + Determines the maximum time for which the caller should wait before invoking ``process_fd`` to process timeouts. + If the ``max_timeout`` parameter is specified, it is stored on the channel and the appropriate value is then + returned. - def timeout(self, t = None): + Args: + max_timeout: Maximum timeout. + """ maxtv = _ffi.NULL tv = _ffi.new("struct timeval*") - if t is not None: - if t >= 0.0: + if max_timeout is not None: + if max_timeout >= 0.0: maxtv = _ffi.new("struct timeval*") - maxtv.tv_sec = int(math.floor(t)) - maxtv.tv_usec = int(math.fmod(t, 1.0) * 1000000) + maxtv.tv_sec = int(math.floor(max_timeout)) + maxtv.tv_usec = int(math.fmod(max_timeout, 1.0) * 1000000) else: raise ValueError("timeout needs to be a positive number or None") @@ -626,9 +808,27 @@ def timeout(self, t = None): if tv == _ffi.NULL: return 0.0 - return (tv.tv_sec + tv.tv_usec / 1000000.0) + return tv.tv_sec + tv.tv_usec / 1000000.0 + + def gethostbyaddr( + self, + addr: str, + callback: Callable[[Optional["ares_nameinfo_result"], int], None], + ) -> None: + """ + Retrieves the host information corresponding to a network address. + + Args: + name: Name to query. + + callback: Callback to be called with the result of the query. + Retrieves the host information corresponding to a network address. + Callback signature: ``callback(result, errorno)`` + + Raises: + TypeError: if callback is not callable or IP address is invalid + """ - def gethostbyaddr(self, addr: str, callback: Callable[[Any, int], None]) -> None: if not callable(callback): raise TypeError("a callable is required") @@ -644,67 +844,421 @@ def gethostbyaddr(self, addr: str, callback: Callable[[Any, int], None]) -> None raise ValueError("invalid IP address") userdata = self._create_callback_handle(callback) - _lib.ares_gethostbyaddr(self._channel[0], address, _ffi.sizeof(address[0]), family, _lib._host_cb, userdata) + _lib.ares_gethostbyaddr( + self._channel[0], + address, + _ffi.sizeof(address[0]), + family, + _lib._host_cb, + userdata, + ) + + def gethostbyname( + self, + name: str, + family: socket.AddressFamily, + callback: Callable[[Optional["ares_nameinfo_result"], int], None], + ) -> None: + """ + Retrieves host information corresponding to a host name from a host database. + Callback signature: ``callback(result, errorno)`` + + Args: + name: Name to query. + + family: Socket family. + + callback: Callback to be called with the result of the query. + + Raises: + TypeError: if callback is not callable + """ - def gethostbyname(self, name: str, family: socket.AddressFamily, callback: Callable[[Any, int], None]) -> None: if not callable(callback): raise TypeError("a callable is required") userdata = self._create_callback_handle(callback) - _lib.ares_gethostbyname(self._channel[0], parse_name(name), family, _lib._host_cb, userdata) + _lib.ares_gethostbyname( + self._channel[0], parse_name(name), family, _lib._host_cb, userdata + ) def getaddrinfo( self, host: str, port: Optional[int], - callback: Callable[[Any, int], None], - family: socket.AddressFamily = 0, + callback: Callable[[Optional["ares_addrinfo_result"], int], None], + family: Union[socket.AddressFamily, Literal[0]] = 0, type: int = 0, proto: int = 0, - flags: int = 0 + flags: int = 0, ) -> None: + """ + The ``family``, ``type`` and ``proto`` arguments can be optionally specified in order to narrow the list of + addresses returned. Passing zero as a value for each of these arguments selects the full range of results. + The ``flags`` argument can be one or several of the ``AI_*`` constants, and will influence how results are + computed and returned. For example, ``AI_NUMERICHOST`` will disable domain name resolution. + + Translate the host/port argument into a sequence of 5-tuples that contain all the necessary arguments for + creating a socket connected to that service. + + Callback signature: ``callback(result, errorno)`` + + Args: + address: address tuple to get info about. + + flags: Query flags, see the NI flags section. + + callback: Callback to be called with the result of the query. + + Raises: + TypeError: if callable is not callable + + """ + if not callable(callback): raise TypeError("a callable is required") if port is None: service = _ffi.NULL elif isinstance(port, int): - service = str(port).encode('ascii') + service = str(port).encode("ascii") else: service = ascii_bytes(port) userdata = self._create_callback_handle(callback) - hints = _ffi.new('struct ares_addrinfo_hints*') + hints = _ffi.new("struct ares_addrinfo_hints*") hints.ai_flags = flags hints.ai_family = family hints.ai_socktype = type hints.ai_protocol = proto - _lib.ares_getaddrinfo(self._channel[0], parse_name(host), service, hints, _lib._addrinfo_cb, userdata) + _lib.ares_getaddrinfo( + self._channel[0], + parse_name(host), + service, + hints, + _lib._addrinfo_cb, + userdata, + ) + + @overload + def query( + self, + name: Union[str, bytes], + query_type: Literal[1], + callback: Callable[[list["ares_query_a_result"], int], None], + query_class: Optional[int] = ..., + ) -> None: ... + + @overload + def query( + self, + name: Union[str, bytes], + query_type: Literal[28], + callback: Callable[[list["ares_query_aaaa_result"], int], None], + query_class: Optional[int] = ..., + ) -> None: ... + + @overload + def query( + self, + name: Union[str, bytes], + query_type: Literal[255], + callback: Callable[["AresResult", int], None], + query_class: Optional[int] = ..., + ) -> None: ... + + @overload + def query( + self, + name: Union[str, bytes], + query_type: Literal[257], + callback: Callable[[list["ares_query_caa_result"], int], None], + query_class: Optional[int] = ..., + ) -> None: ... + + @overload + def query( + self, + name: Union[str, bytes], + query_type: Literal[5], + callback: Callable[["ares_query_cname_result", int], None], + query_class: Optional[int] = ..., + ) -> None: ... + + @overload + def query( + self, + name: Union[str, bytes], + query_type: Literal[15], + callback: Callable[[list["ares_query_mx_result"], int], None], + query_class: Optional[int] = ..., + ) -> None: ... + + @overload + def query( + self, + name: Union[str, bytes], + query_type: Literal[35], + callback: Callable[[list["ares_query_naptr_result"], int], None], + query_class: Optional[int] = ..., + ) -> None: ... + + @overload + def query( + self, + name: Union[str, bytes], + query_type: Literal[2], + callback: Callable[[list["ares_query_ns_result"], int], None], + query_class: Optional[int] = ..., + ) -> None: ... + + @overload + def query( + self, + name: Union[str, bytes], + query_type: Literal[12], + callback: Callable[[list["ares_query_ptr_result"], int], None], + query_class: Optional[int] = ..., + ) -> None: ... + + @overload + def query( + self, + name: Union[str, bytes], + query_type: Literal[6], + callback: Callable[["ares_query_soa_result", int], None], + query_class: Optional[int] = ..., + ) -> None: ... + + @overload + def query( + self, + name: Union[str, bytes], + query_type: Literal[33], + callback: Callable[[list["ares_query_srv_result"], int], None], + query_class: Optional[int] = ..., + ) -> None: ... + + @overload + def query( + self, + name: Union[str, bytes], + query_type: Literal[16], + callback: Callable[[list["ares_query_txt_result"], int], None], + query_class: Optional[int] = ..., + ) -> None: ... + + def query( + self, + name: Union[str, bytes], + query_type: int, + callback: Callable[["AresResult", int], None], + query_class: Optional[int] = None, + ) -> None: + """ + Do a DNS query of the specified type. + + Available types: + - ``QUERY_TYPE_A`` + - ``QUERY_TYPE_AAAA`` + - ``QUERY_TYPE_ANY`` + - ``QUERY_TYPE_CAA`` + - ``QUERY_TYPE_CNAME`` + - ``QUERY_TYPE_MX`` + - ``QUERY_TYPE_NAPTR`` + - ``QUERY_TYPE_NS`` + - ``QUERY_TYPE_PTR`` + - ``QUERY_TYPE_SOA`` + - ``QUERY_TYPE_SRV`` + - ``QUERY_TYPE_TXT`` + + Args: + name: Name to query. + + query_type: Type of query to perform. + + callback: Callback to be called with the result of the query. + + + Raises: + TypeError: if callaback is not callable + + ValueError: if invalid query type or class was specified + + """ + self._do_query( + _lib.ares_query, name, query_type, callback, query_class=query_class + ) + + @overload + def search( + self, + name: Union[str, bytes], + query_type: Literal[1], + callback: Callable[[list["ares_query_a_result"], int], None], + query_class: Optional[int] = ..., + ) -> None: ... + + @overload + def search( + self, + name: Union[str, bytes], + query_type: Literal[28], + callback: Callable[[list["ares_query_aaaa_result"], int], None], + query_class: Optional[int] = ..., + ) -> None: ... + + @overload + def search( + self, + name: Union[str, bytes], + query_type: Literal[255], + callback: Callable[["AresResult", int], None], + query_class: Optional[int] = ..., + ) -> None: ... + + @overload + def search( + self, + name: Union[str, bytes], + query_type: Literal[257], + callback: Callable[[list["ares_query_caa_result"], int], None], + query_class: Optional[int] = ..., + ) -> None: ... + + @overload + def search( + self, + name: Union[str, bytes], + query_type: Literal[5], + callback: Callable[["ares_query_cname_result", int], None], + query_class: Optional[int] = ..., + ) -> None: ... + + @overload + def search( + self, + name: Union[str, bytes], + query_type: Literal[15], + callback: Callable[[list["ares_query_mx_result"], int], None], + query_class: Optional[int] = ..., + ) -> None: ... + + @overload + def search( + self, + name: Union[str, bytes], + query_type: Literal[35], + callback: Callable[[list["ares_query_naptr_result"], int], None], + query_class: Optional[int] = ..., + ) -> None: ... + + @overload + def search( + self, + name: Union[str, bytes], + query_type: Literal[2], + callback: Callable[[list["ares_query_ns_result"], int], None], + query_class: Optional[int] = ..., + ) -> None: ... + + @overload + def search( + self, + name: Union[str, bytes], + query_type: Literal[12], + callback: Callable[[list["ares_query_ptr_result"], int], None], + query_class: Optional[int] = ..., + ) -> None: ... + + @overload + def search( + self, + name: Union[str, bytes], + query_type: Literal[6], + callback: Callable[["ares_query_soa_result", int], None], + query_class: Optional[int] = ..., + ) -> None: ... + + @overload + def search( + self, + name: Union[str, bytes], + query_type: Literal[33], + callback: Callable[[list["ares_query_srv_result"], int], None], + query_class: Optional[int] = ..., + ) -> None: ... + + @overload + def search( + self, + name: Union[str, bytes], + query_type: Literal[16], + callback: Callable[[list["ares_query_txt_result"], int], None], + query_class: Optional[int] = ..., + ) -> None: ... + + def search( + self, + name: Union[str, bytes], + query_type: int, + callback: Callable[["AresResult", int], None], + query_class: Optional[int] = None, + ): + """ + This function does the same as `query` but it will honor the ``domain`` and ``search`` directives in ``resolv.conf``. + Args: + name: Name to query. - def query(self, name: str, query_type: str, callback: Callable[[Any, int], None], query_class: Optional[str] = None) -> None: - self._do_query(_lib.ares_query, name, query_type, callback, query_class=query_class) + query_type: Type of query to perform. - def search(self, name, query_type, callback, query_class=None): - self._do_query(_lib.ares_search, name, query_type, callback, query_class=query_class) + callback: Callback to be called with the result of the query. - def _do_query(self, func, name, query_type, callback, query_class=None): + Raises: + TypeError: if callaback is not callable + + ValueError: if invalid query type or class was specified + """ + self._do_query( + _lib.ares_search, name, query_type, callback, query_class=query_class + ) + + def _do_query( + self, + func: Any, + name: Union[str, bytes], + query_type: int, + callback: Callable[[Optional["ares_addrinfo_result"], int], None], + query_class: Optional[Union[str, int]] = None, + ): if not callable(callback): - raise TypeError('a callable is required') + raise TypeError("a callable is required") if query_type not in self.__qtypes__: - raise ValueError('invalid query type specified') + raise ValueError("invalid query type specified") if query_class is None: query_class = _lib.C_IN if query_class not in self.__qclasses__: - raise ValueError('invalid query class specified') + raise ValueError("invalid query class specified") userdata = self._create_callback_handle((callback, query_type)) - func(self._channel[0], parse_name(name), query_class, query_type, _lib._query_cb, userdata) - - def set_local_ip(self, ip): + func( + self._channel[0], + parse_name(name), + query_class, + query_type, + _lib._query_cb, + userdata, + ) + + def set_local_ip(self, ip: Union[str, bytes]) -> None: + """Set the local IPv4 or IPv6 address from which the queries will be sent. + Args: + ip: IP Address. + """ addr4 = _ffi.new("struct in_addr*") addr6 = _ffi.new("struct ares_in6_addr*") if _lib.ares_inet_pton(socket.AF_INET, ascii_bytes(ip), addr4) == 1: @@ -714,14 +1268,44 @@ def set_local_ip(self, ip): else: raise ValueError("invalid IP address") - def getnameinfo(self, address: Union[IP4, IP6], flags: int, callback: Callable[[Any, int], None]) -> None: + def getnameinfo( + self, + address: Union[IP4, IP6], + flags: int, + callback: Callable[["ares_nameinfo_result", int], None], + ) -> None: + """ + Provides protocol-independent name resolution from an address to a host name and + from a port number to the service name. + + ``address`` must be a 2-item tuple for IPv4 or a 4-item tuple for IPv6. Format of + fields is the same as one returned by `getaddrinfo()`. + + Callback signature: ``callback(result, errorno)`` + + Args: + address: address tuple to get info about. + + flags: Query flags, see the NI flags section. + + callback: Callback to be called with the result of the query. + + Raises: + TypeError: if callback is not callable or address or address's host field is invalid + """ + if not callable(callback): raise TypeError("a callable is required") if len(address) == 2: (ip, port) = address sa4 = _ffi.new("struct sockaddr_in*") - if _lib.ares_inet_pton(socket.AF_INET, ascii_bytes(ip), _ffi.addressof(sa4.sin_addr)) != 1: + if ( + _lib.ares_inet_pton( + socket.AF_INET, ascii_bytes(ip), _ffi.addressof(sa4.sin_addr) + ) + != 1 + ): raise ValueError("Invalid IPv4 address %r" % ip) sa4.sin_family = socket.AF_INET sa4.sin_port = socket.htons(port) @@ -729,20 +1313,39 @@ def getnameinfo(self, address: Union[IP4, IP6], flags: int, callback: Callable[[ elif len(address) == 4: (ip, port, flowinfo, scope_id) = address sa6 = _ffi.new("struct sockaddr_in6*") - if _lib.ares_inet_pton(socket.AF_INET6, ascii_bytes(ip), _ffi.addressof(sa6.sin6_addr)) != 1: + if ( + _lib.ares_inet_pton( + socket.AF_INET6, ascii_bytes(ip), _ffi.addressof(sa6.sin6_addr) + ) + != 1 + ): raise ValueError("Invalid IPv6 address %r" % ip) sa6.sin6_family = socket.AF_INET6 sa6.sin6_port = socket.htons(port) - sa6.sin6_flowinfo = socket.htonl(flowinfo) # I'm unsure about byteorder here. - sa6.sin6_scope_id = scope_id # Yes, without htonl. + sa6.sin6_flowinfo = socket.htonl( + flowinfo + ) # I'm unsure about byteorder here. + sa6.sin6_scope_id = scope_id # Yes, without htonl. sa = sa6 else: raise ValueError("Invalid address argument") userdata = self._create_callback_handle(callback) - _lib.ares_getnameinfo(self._channel[0], _ffi.cast("struct sockaddr*", sa), _ffi.sizeof(sa[0]), flags, _lib._nameinfo_cb, userdata) - - def set_local_dev(self, dev): + _lib.ares_getnameinfo( + self._channel[0], + _ffi.cast("struct sockaddr*", sa), + _ffi.sizeof(sa[0]), + flags, + _lib._nameinfo_cb, + userdata, + ) + + def set_local_dev(self, dev: str): + """ + Set the local ethernet device from which the queries will be sent. + Args: + dev: Network device name. + """ _lib.ares_set_local_dev(self._channel[0], dev) def close(self) -> None: @@ -773,38 +1376,49 @@ class AresResult: __slots__ = () def __repr__(self): - attrs = ['%s=%s' % (a, getattr(self, a)) for a in self.__slots__] - return '<%s> %s' % (self.__class__.__name__, ', '.join(attrs)) + attrs = ["%s=%s" % (a, getattr(self, a)) for a in self.__slots__] + return "<%s> %s" % (self.__class__.__name__, ", ".join(attrs)) # DNS query result types # + class ares_query_a_result(AresResult): - __slots__ = ('host', 'ttl') - type: Final = 'A' + __slots__ = ("host", "ttl") + type: Final = "A" def __init__(self, ares_addrttl): buf = _ffi.new("char[]", _lib.INET6_ADDRSTRLEN) - _lib.ares_inet_ntop(socket.AF_INET, _ffi.addressof(ares_addrttl.ipaddr), buf, _lib.INET6_ADDRSTRLEN) + _lib.ares_inet_ntop( + socket.AF_INET, + _ffi.addressof(ares_addrttl.ipaddr), + buf, + _lib.INET6_ADDRSTRLEN, + ) self.host = maybe_str(_ffi.string(buf, _lib.INET6_ADDRSTRLEN)) self.ttl = ares_addrttl.ttl class ares_query_aaaa_result(AresResult): - __slots__ = ('host', 'ttl') - type: Final = 'AAAA' + __slots__ = ("host", "ttl") + type: Final = "AAAA" def __init__(self, ares_addrttl): buf = _ffi.new("char[]", _lib.INET6_ADDRSTRLEN) - _lib.ares_inet_ntop(socket.AF_INET6, _ffi.addressof(ares_addrttl.ip6addr), buf, _lib.INET6_ADDRSTRLEN) + _lib.ares_inet_ntop( + socket.AF_INET6, + _ffi.addressof(ares_addrttl.ip6addr), + buf, + _lib.INET6_ADDRSTRLEN, + ) self.host = maybe_str(_ffi.string(buf, _lib.INET6_ADDRSTRLEN)) self.ttl = ares_addrttl.ttl -class ares_query_caa_result(AresResult): - __slots__ = ('critical', 'property', 'value', 'ttl') - type: Final = 'CAA' +class ares_query_caa_result(AresResult): + __slots__ = ("critical", "property", "value", "ttl") + type: Final = "CAA" def __init__(self, caa): self.critical = caa.critical @@ -814,8 +1428,8 @@ def __init__(self, caa): class ares_query_cname_result(AresResult): - __slots__ = ('cname', 'ttl') - type: Final = 'CNAME' + __slots__ = ("cname", "ttl") + type: Final = "CNAME" def __init__(self, host): self.cname = maybe_str(_ffi.string(host.h_name)) @@ -823,8 +1437,8 @@ def __init__(self, host): class ares_query_mx_result(AresResult): - __slots__ = ('host', 'priority', 'ttl') - type: Final = 'MX' + __slots__ = ("host", "priority", "ttl") + type: Final = "MX" def __init__(self, mx): self.host = maybe_str(_ffi.string(mx.host)) @@ -833,8 +1447,16 @@ def __init__(self, mx): class ares_query_naptr_result(AresResult): - __slots__ = ('order', 'preference', 'flags', 'service', 'regex', 'replacement', 'ttl') - type: Final = 'NAPTR' + __slots__ = ( + "order", + "preference", + "flags", + "service", + "regex", + "replacement", + "ttl", + ) + type: Final = "NAPTR" def __init__(self, naptr): self.order = naptr.order @@ -847,8 +1469,8 @@ def __init__(self, naptr): class ares_query_ns_result(AresResult): - __slots__ = ('host', 'ttl') - type: Final = 'NS' + __slots__ = ("host", "ttl") + type: Final = "NS" def __init__(self, ns): self.host = maybe_str(_ffi.string(ns)) @@ -856,8 +1478,8 @@ def __init__(self, ns): class ares_query_ptr_result(AresResult): - __slots__ = ('name', 'ttl', 'aliases') - type: Final = 'PTR' + __slots__ = ("name", "ttl", "aliases") + type: Final = "PTR" def __init__(self, hostent, aliases): self.name = maybe_str(_ffi.string(hostent.h_name)) @@ -866,8 +1488,17 @@ def __init__(self, hostent, aliases): class ares_query_soa_result(AresResult): - __slots__ = ('nsname', 'hostmaster', 'serial', 'refresh', 'retry', 'expires', 'minttl', 'ttl') - type: Final = 'SOA' + __slots__ = ( + "nsname", + "hostmaster", + "serial", + "refresh", + "retry", + "expires", + "minttl", + "ttl", + ) + type: Final = "SOA" def __init__(self, soa): self.nsname = maybe_str(_ffi.string(soa.nsname)) @@ -880,9 +1511,9 @@ def __init__(self, soa): self.ttl = -1 -class ares_query_srv_result(AresResult): - __slots__ = ('host', 'port', 'priority', 'weight', 'ttl') - type: Final = 'SRV' +class ares_query_srv_result(AresResult): + __slots__ = ("host", "port", "priority", "weight", "ttl") + type: Final = "SRV" def __init__(self, srv): self.host = maybe_str(_ffi.string(srv.host)) @@ -893,8 +1524,8 @@ def __init__(self, srv): class ares_query_txt_result(AresResult): - __slots__ = ('text', 'ttl') - type: Final = 'TXT' + __slots__ = ("text", "ttl") + type: Final = "TXT" def __init__(self, txt_chunk): self.text = maybe_str(txt_chunk.text) @@ -902,8 +1533,8 @@ def __init__(self, txt_chunk): class ares_query_txt_result_chunk(AresResult): - __slots__ = ('text', 'ttl') - type: Final = 'TXT' + __slots__ = ("text", "ttl") + type: Final = "TXT" def __init__(self, txt): self.text = _ffi.string(txt.txt) @@ -913,8 +1544,9 @@ def __init__(self, txt): # Other result types # + class ares_host_result(AresResult): - __slots__ = ('name', 'aliases', 'addresses') + __slots__ = ("name", "aliases", "addresses") def __init__(self, hostent): self.name = maybe_str(_ffi.string(hostent.h_name)) @@ -928,13 +1560,17 @@ def __init__(self, hostent): i = 0 while hostent.h_addr_list[i] != _ffi.NULL: buf = _ffi.new("char[]", _lib.INET6_ADDRSTRLEN) - if _ffi.NULL != _lib.ares_inet_ntop(hostent.h_addrtype, hostent.h_addr_list[i], buf, _lib.INET6_ADDRSTRLEN): - self.addresses.append(maybe_str(_ffi.string(buf, _lib.INET6_ADDRSTRLEN))) + if _ffi.NULL != _lib.ares_inet_ntop( + hostent.h_addrtype, hostent.h_addr_list[i], buf, _lib.INET6_ADDRSTRLEN + ): + self.addresses.append( + maybe_str(_ffi.string(buf, _lib.INET6_ADDRSTRLEN)) + ) i += 1 class ares_nameinfo_result(AresResult): - __slots__ = ('node', 'service') + __slots__ = ("node", "service") def __init__(self, node, service): self.node = maybe_str(_ffi.string(node)) @@ -942,7 +1578,7 @@ def __init__(self, node, service): class ares_addrinfo_node_result(AresResult): - __slots__ = ('ttl', 'flags', 'family', 'socktype', 'protocol', 'addr') + __slots__ = ("ttl", "flags", "family", "socktype", "protocol", "addr") def __init__(self, ares_node): self.ttl = ares_node.ai_ttl @@ -956,21 +1592,33 @@ def __init__(self, ares_node): if addr.sa_family == socket.AF_INET: self.family = socket.AF_INET s = _ffi.cast("struct sockaddr_in*", addr) - if _ffi.NULL != _lib.ares_inet_ntop(s.sin_family, _ffi.addressof(s.sin_addr), ip, _lib.INET6_ADDRSTRLEN): + if _ffi.NULL != _lib.ares_inet_ntop( + s.sin_family, _ffi.addressof(s.sin_addr), ip, _lib.INET6_ADDRSTRLEN + ): # (address, port) 2-tuple for AF_INET - self.addr = (_ffi.string(ip, _lib.INET6_ADDRSTRLEN), socket.ntohs(s.sin_port)) + self.addr = ( + _ffi.string(ip, _lib.INET6_ADDRSTRLEN), + socket.ntohs(s.sin_port), + ) elif addr.sa_family == socket.AF_INET6: self.family = socket.AF_INET6 s = _ffi.cast("struct sockaddr_in6*", addr) - if _ffi.NULL != _lib.ares_inet_ntop(s.sin6_family, _ffi.addressof(s.sin6_addr), ip, _lib.INET6_ADDRSTRLEN): + if _ffi.NULL != _lib.ares_inet_ntop( + s.sin6_family, _ffi.addressof(s.sin6_addr), ip, _lib.INET6_ADDRSTRLEN + ): # (address, port, flow info, scope id) 4-tuple for AF_INET6 - self.addr = (_ffi.string(ip, _lib.INET6_ADDRSTRLEN), socket.ntohs(s.sin6_port), s.sin6_flowinfo, s.sin6_scope_id) + self.addr = ( + _ffi.string(ip, _lib.INET6_ADDRSTRLEN), + socket.ntohs(s.sin6_port), + s.sin6_flowinfo, + s.sin6_scope_id, + ) else: raise ValueError("invalid sockaddr family") class ares_addrinfo_cname_result(AresResult): - __slots__ = ('ttl', 'alias', 'name') + __slots__ = ("ttl", "alias", "name") def __init__(self, ares_cname): self.ttl = ares_cname.ttl @@ -979,7 +1627,7 @@ def __init__(self, ares_cname): class ares_addrinfo_result(AresResult): - __slots__ = ('cnames', 'nodes') + __slots__ = ("cnames", "nodes") def __init__(self, ares_addrinfo): self.cnames = [] @@ -1004,6 +1652,7 @@ def ares_threadsafety() -> bool: """ return bool(_lib.ares_threadsafety()) + __all__ = ( "ARES_FLAG_USEVC", "ARES_FLAG_PRIMARY", @@ -1015,7 +1664,6 @@ def ares_threadsafety() -> bool: "ARES_FLAG_NOCHECKRESP", "ARES_FLAG_EDNS", "ARES_FLAG_NO_DFLT_SVR", - # Nameinfo flag values "ARES_NI_NOFQDN", "ARES_NI_NUMERICHOST", @@ -1032,11 +1680,8 @@ def ares_threadsafety() -> bool: "ARES_NI_IDN", "ARES_NI_IDN_ALLOW_UNASSIGNED", "ARES_NI_IDN_USE_STD3_ASCII_RULES", - # Bad socket "ARES_SOCKET_BAD", - - # Query types "QUERY_TYPE_A", "QUERY_TYPE_AAAA", @@ -1050,19 +1695,16 @@ def ares_threadsafety() -> bool: "QUERY_TYPE_SOA", "QUERY_TYPE_SRV", "QUERY_TYPE_TXT", - # Query classes "QUERY_CLASS_IN", "QUERY_CLASS_CHAOS", "QUERY_CLASS_HS", "QUERY_CLASS_NONE", "QUERY_CLASS_ANY", - - "ARES_VERSION", "AresError", "Channel", "ares_threadsafety", "errno", - "__version__" + "__version__", ) diff --git a/src/pycares/__main__.py b/src/pycares/__main__.py index e2adce8..fc919fa 100644 --- a/src/pycares/__main__.py +++ b/src/pycares/__main__.py @@ -1,4 +1,3 @@ - import collections.abc import pycares import select @@ -24,60 +23,85 @@ def wait_channel(channel): def cb(result, error): if error is not None: - print('Error: (%d) %s' % (error, pycares.errno.strerror(error))) + print("Error: (%d) %s" % (error, pycares.errno.strerror(error))) else: parts = [ - ';; QUESTION SECTION:', - ';%s\t\t\tIN\t%s' % (hostname, qtype.upper()), - '', - ';; ANSWER SECTION:' + ";; QUESTION SECTION:", + ";%s\t\t\tIN\t%s" % (hostname, qtype.upper()), + "", + ";; ANSWER SECTION:", ] if not isinstance(result, collections.abc.Iterable): result = [result] for r in result: - txt = '%s\t\t%d\tIN\t%s' % (hostname, r.ttl, r.type) - if r.type in ('A', 'AAAA'): - parts.append('%s\t%s' % (txt, r.host)) - elif r.type == 'CAA': + txt = "%s\t\t%d\tIN\t%s" % (hostname, r.ttl, r.type) + if r.type in ("A", "AAAA"): + parts.append("%s\t%s" % (txt, r.host)) + elif r.type == "CAA": parts.append('%s\t%d %s "%s"' % (txt, r.critical, r.property, r.value)) - elif r.type == 'CNAME': - parts.append('%s\t%s' % (txt, r.cname)) - elif r.type == 'MX': - parts.append('%s\t%d %s' % (txt, r.priority, r.host)) - elif r.type == 'NAPTR': - parts.append('%s\t%d %d "%s" "%s" "%s" %s' % (txt, r.order, r.preference, r.flags, r.service, r.regex, r.replacement)) - elif r.type == 'NS': - parts.append('%s\t%s' % (txt, r.host)) - elif r.type == 'PTR': - parts.append('%s\t%s' % (txt, r.name)) - elif r.type == 'SOA': - parts.append('%s\t%s %s %d %d %d %d %d' % (txt, r.nsname, r.hostmaster, r.serial, r.refresh, r.retry, r.expires, r.minttl)) - elif r.type == 'SRV': - parts.append('%s\t%d %d %d %s' % (txt, r.priority, r.weight, r.port, r.host)) - elif r.type == 'TXT': + elif r.type == "CNAME": + parts.append("%s\t%s" % (txt, r.cname)) + elif r.type == "MX": + parts.append("%s\t%d %s" % (txt, r.priority, r.host)) + elif r.type == "NAPTR": + parts.append( + '%s\t%d %d "%s" "%s" "%s" %s' + % ( + txt, + r.order, + r.preference, + r.flags, + r.service, + r.regex, + r.replacement, + ) + ) + elif r.type == "NS": + parts.append("%s\t%s" % (txt, r.host)) + elif r.type == "PTR": + parts.append("%s\t%s" % (txt, r.name)) + elif r.type == "SOA": + parts.append( + "%s\t%s %s %d %d %d %d %d" + % ( + txt, + r.nsname, + r.hostmaster, + r.serial, + r.refresh, + r.retry, + r.expires, + r.minttl, + ) + ) + elif r.type == "SRV": + parts.append( + "%s\t%d %d %d %s" % (txt, r.priority, r.weight, r.port, r.host) + ) + elif r.type == "TXT": parts.append('%s\t"%s"' % (txt, r.text)) - print('\n'.join(parts)) + print("\n".join(parts)) channel = pycares.Channel() if len(sys.argv) not in (2, 3): - print('Invalid arguments! Usage: python -m pycares [query_type] hostname') + print("Invalid arguments! Usage: python -m pycares [query_type] hostname") sys.exit(1) if len(sys.argv) == 2: _, hostname = sys.argv - qtype = 'A' + qtype = "A" else: _, qtype, hostname = sys.argv try: - query_type = getattr(pycares, 'QUERY_TYPE_%s' % qtype.upper()) + query_type = getattr(pycares, "QUERY_TYPE_%s" % qtype.upper()) except Exception: - print('Invalid query type: %s' % qtype) + print("Invalid query type: %s" % qtype) sys.exit(1) channel.query(hostname, query_type, cb) diff --git a/src/pycares/_version.py b/src/pycares/_version.py index 9b002d3..897e6be 100644 --- a/src/pycares/_version.py +++ b/src/pycares/_version.py @@ -1,2 +1 @@ - -__version__ = '4.10.0' +__version__ = "4.10.0" diff --git a/src/pycares/utils.py b/src/pycares/utils.py index b843bb6..bb37a05 100644 --- a/src/pycares/utils.py +++ b/src/pycares/utils.py @@ -1,4 +1,3 @@ - from typing import Union try: @@ -7,51 +6,50 @@ idna2008 = None -def ascii_bytes(data): +def ascii_bytes(data: Union[str, bytes]) -> bytes: if isinstance(data, str): - return data.encode('ascii') + return data.encode("ascii") if isinstance(data, bytes): return data - raise TypeError('only str (ascii encoding) and bytes are supported') + raise TypeError("only str (ascii encoding) and bytes are supported") -def maybe_str(data): +def maybe_str(data: Union[str, bytes]) -> Union[str, bytes]: if isinstance(data, str): return data if isinstance(data, bytes): try: - return data.decode('ascii') + return data.decode("ascii") except UnicodeDecodeError: return data - raise TypeError('only str (ascii encoding) and bytes are supported') + raise TypeError("only str (ascii encoding) and bytes are supported") -def parse_name_idna2008(name: str) -> str: - parts = name.split('.') +def parse_name_idna2008(name: str) -> bytes: + parts = name.split(".") r = [] for part in parts: if part.isascii(): - r.append(part.encode('ascii')) + r.append(part.encode("ascii")) elif len(part) > 253: raise RuntimeError( f"domains can only be less than 253 characters in length not {len(name)}" ) else: r.append(idna2008.encode(part)) - return b'.'.join(r) + return b".".join(r) -def parse_name(name: Union[str, bytes]) -> bytes: +def parse_name(name: Union[str, bytes]) -> Union[bytes, str]: if isinstance(name, str): if name.isascii(): - return name.encode('ascii') + return name.encode("ascii") if idna2008 is not None: return parse_name_idna2008(name) - return name.encode('idna') + return name.encode("idna") if isinstance(name, bytes): return name - raise TypeError('only str and bytes are supported') - + raise TypeError("only str and bytes are supported") -__all__ = ['ascii_bytes', 'maybe_str', 'parse_name'] +__all__ = ["ascii_bytes", "maybe_str", "parse_name"]