# Copyright (c) 2011-2012 Denis Bilenko. See LICENSE for details. cimport cares import sys from python cimport * from _socket import gaierror __all__ = ['channel'] cdef object string_types cdef object text_type if sys.version_info[0] >= 3: string_types = str, text_type = str else: string_types = __builtins__.basestring, text_type = __builtins__.unicode TIMEOUT = 1 DEF EV_READ = 1 DEF EV_WRITE = 2 cdef extern from "dnshelper.c": int AF_INET int AF_INET6 struct hostent: char* h_name int h_addrtype struct sockaddr_t "sockaddr": pass struct ares_channeldata: pass object parse_h_name(hostent*) object parse_h_aliases(hostent*) object parse_h_addr_list(hostent*) void* create_object_from_hostent(void*) # this imports _socket lazily object PyUnicode_FromString(char*) int PyTuple_Check(object) int PyArg_ParseTuple(object, char*, ...) except 0 struct sockaddr_in6: pass int gevent_make_sockaddr(char* hostp, int port, int flowinfo, int scope_id, sockaddr_in6* sa6) void* malloc(int) void free(void*) void memset(void*, int, int) ARES_SUCCESS = cares.ARES_SUCCESS ARES_ENODATA = cares.ARES_ENODATA ARES_EFORMERR = cares.ARES_EFORMERR ARES_ESERVFAIL = cares.ARES_ESERVFAIL ARES_ENOTFOUND = cares.ARES_ENOTFOUND ARES_ENOTIMP = cares.ARES_ENOTIMP ARES_EREFUSED = cares.ARES_EREFUSED ARES_EBADQUERY = cares.ARES_EBADQUERY ARES_EBADNAME = cares.ARES_EBADNAME ARES_EBADFAMILY = cares.ARES_EBADFAMILY ARES_EBADRESP = cares.ARES_EBADRESP ARES_ECONNREFUSED = cares.ARES_ECONNREFUSED ARES_ETIMEOUT = cares.ARES_ETIMEOUT ARES_EOF = cares.ARES_EOF ARES_EFILE = cares.ARES_EFILE ARES_ENOMEM = cares.ARES_ENOMEM ARES_EDESTRUCTION = cares.ARES_EDESTRUCTION ARES_EBADSTR = cares.ARES_EBADSTR ARES_EBADFLAGS = cares.ARES_EBADFLAGS ARES_ENONAME = cares.ARES_ENONAME ARES_EBADHINTS = cares.ARES_EBADHINTS ARES_ENOTINITIALIZED = cares.ARES_ENOTINITIALIZED ARES_ELOADIPHLPAPI = cares.ARES_ELOADIPHLPAPI ARES_EADDRGETNETWORKPARAMS = cares.ARES_EADDRGETNETWORKPARAMS ARES_ECANCELLED = cares.ARES_ECANCELLED ARES_FLAG_USEVC = cares.ARES_FLAG_USEVC ARES_FLAG_PRIMARY = cares.ARES_FLAG_PRIMARY ARES_FLAG_IGNTC = cares.ARES_FLAG_IGNTC ARES_FLAG_NORECURSE = cares.ARES_FLAG_NORECURSE ARES_FLAG_STAYOPEN = cares.ARES_FLAG_STAYOPEN ARES_FLAG_NOSEARCH = cares.ARES_FLAG_NOSEARCH ARES_FLAG_NOALIASES = cares.ARES_FLAG_NOALIASES ARES_FLAG_NOCHECKRESP = cares.ARES_FLAG_NOCHECKRESP _ares_errors = dict([ (cares.ARES_SUCCESS, 'ARES_SUCCESS'), (cares.ARES_ENODATA, 'ARES_ENODATA'), (cares.ARES_EFORMERR, 'ARES_EFORMERR'), (cares.ARES_ESERVFAIL, 'ARES_ESERVFAIL'), (cares.ARES_ENOTFOUND, 'ARES_ENOTFOUND'), (cares.ARES_ENOTIMP, 'ARES_ENOTIMP'), (cares.ARES_EREFUSED, 'ARES_EREFUSED'), (cares.ARES_EBADQUERY, 'ARES_EBADQUERY'), (cares.ARES_EBADNAME, 'ARES_EBADNAME'), (cares.ARES_EBADFAMILY, 'ARES_EBADFAMILY'), (cares.ARES_EBADRESP, 'ARES_EBADRESP'), (cares.ARES_ECONNREFUSED, 'ARES_ECONNREFUSED'), (cares.ARES_ETIMEOUT, 'ARES_ETIMEOUT'), (cares.ARES_EOF, 'ARES_EOF'), (cares.ARES_EFILE, 'ARES_EFILE'), (cares.ARES_ENOMEM, 'ARES_ENOMEM'), (cares.ARES_EDESTRUCTION, 'ARES_EDESTRUCTION'), (cares.ARES_EBADSTR, 'ARES_EBADSTR'), (cares.ARES_EBADFLAGS, 'ARES_EBADFLAGS'), (cares.ARES_ENONAME, 'ARES_ENONAME'), (cares.ARES_EBADHINTS, 'ARES_EBADHINTS'), (cares.ARES_ENOTINITIALIZED, 'ARES_ENOTINITIALIZED'), (cares.ARES_ELOADIPHLPAPI, 'ARES_ELOADIPHLPAPI'), (cares.ARES_EADDRGETNETWORKPARAMS, 'ARES_EADDRGETNETWORKPARAMS'), (cares.ARES_ECANCELLED, 'ARES_ECANCELLED')]) # maps c-ares flag to _socket module flag _cares_flag_map = None cdef _prepare_cares_flag_map(): global _cares_flag_map import _socket _cares_flag_map = [ (getattr(_socket, 'NI_NUMERICHOST', 1), cares.ARES_NI_NUMERICHOST), (getattr(_socket, 'NI_NUMERICSERV', 2), cares.ARES_NI_NUMERICSERV), (getattr(_socket, 'NI_NOFQDN', 4), cares.ARES_NI_NOFQDN), (getattr(_socket, 'NI_NAMEREQD', 8), cares.ARES_NI_NAMEREQD), (getattr(_socket, 'NI_DGRAM', 16), cares.ARES_NI_DGRAM)] cpdef _convert_cares_flags(int flags, int default=cares.ARES_NI_LOOKUPHOST|cares.ARES_NI_LOOKUPSERVICE): if _cares_flag_map is None: _prepare_cares_flag_map() for socket_flag, cares_flag in _cares_flag_map: if socket_flag & flags: default |= cares_flag flags &= ~socket_flag if not flags: return default raise gaierror(-1, "Bad value for ai_flags: 0x%x" % flags) cpdef strerror(code): return '%s: %s' % (_ares_errors.get(code) or code, cares.ares_strerror(code)) class InvalidIP(ValueError): pass cdef void gevent_sock_state_callback(void *data, int s, int read, int write): if not data: return cdef channel ch = data ch._sock_state_callback(s, read, write) cdef class result: cdef public object value cdef public object exception def __init__(self, object value=None, object exception=None): self.value = value self.exception = exception def __repr__(self): if self.exception is None: return '%s(%r)' % (self.__class__.__name__, self.value) elif self.value is None: return '%s(exception=%r)' % (self.__class__.__name__, self.exception) else: return '%s(value=%r, exception=%r)' % (self.__class__.__name__, self.value, self.exception) # add repr_recursive precaution def successful(self): return self.exception is None def get(self): if self.exception is not None: raise self.exception return self.value class ares_host_result(tuple): def __new__(cls, family, iterable): cdef object self = tuple.__new__(cls, iterable) self.family = family return self def __getnewargs__(self): return (self.family, tuple(self)) cdef void gevent_ares_host_callback(void *arg, int status, int timeouts, hostent* host): cdef channel channel cdef object callback channel, callback = arg Py_DECREF(arg) cdef object host_result try: if status or not host: callback(result(None, gaierror(status, strerror(status)))) else: try: host_result = ares_host_result(host.h_addrtype, (parse_h_name(host), parse_h_aliases(host), parse_h_addr_list(host))) except: callback(result(None, sys.exc_info()[1])) else: callback(result(host_result)) except: channel.loop.handle_error(callback, *sys.exc_info()) cdef void gevent_ares_nameinfo_callback(void *arg, int status, int timeouts, char *c_node, char *c_service): cdef channel channel cdef object callback channel, callback = arg Py_DECREF(arg) cdef object node cdef object service try: if status: callback(result(None, gaierror(status, strerror(status)))) else: if c_node: node = PyUnicode_FromString(c_node) else: node = None if c_service: service = PyUnicode_FromString(c_service) else: service = None callback(result((node, service))) except: channel.loop.handle_error(callback, *sys.exc_info()) cdef public class channel [object PyGeventAresChannelObject, type PyGeventAresChannel_Type]: cdef public object loop cdef ares_channeldata* channel cdef public dict _watchers cdef public object _timer def __init__(self, object loop, flags=None, timeout=None, tries=None, ndots=None, udp_port=None, tcp_port=None, servers=None): cdef ares_channeldata* channel = NULL cdef cares.ares_options options memset(&options, 0, sizeof(cares.ares_options)) cdef int optmask = cares.ARES_OPT_SOCK_STATE_CB options.sock_state_cb = gevent_sock_state_callback options.sock_state_cb_data = self if flags is not None: options.flags = int(flags) optmask |= cares.ARES_OPT_FLAGS if timeout is not None: options.timeout = int(float(timeout) * 1000) optmask |= cares.ARES_OPT_TIMEOUTMS if tries is not None: options.tries = int(tries) optmask |= cares.ARES_OPT_TRIES if ndots is not None: options.ndots = int(ndots) optmask |= cares.ARES_OPT_NDOTS if udp_port is not None: options.udp_port = int(udp_port) optmask |= cares.ARES_OPT_UDP_PORT if tcp_port is not None: options.tcp_port = int(tcp_port) optmask |= cares.ARES_OPT_TCP_PORT cdef int result = cares.ares_library_init(cares.ARES_LIB_INIT_ALL) # ARES_LIB_INIT_WIN32 -DUSE_WINSOCK? if result: raise gaierror(result, strerror(result)) result = cares.ares_init_options(&channel, &options, optmask) if result: raise gaierror(result, strerror(result)) self._timer = loop.timer(TIMEOUT, TIMEOUT) self._watchers = {} self.channel = channel try: if servers is not None: self.set_servers(servers) self.loop = loop except: self.destroy() raise def __repr__(self): args = (self.__class__.__name__, id(self), self._timer, len(self._watchers)) return '<%s at 0x%x _timer=%r _watchers[%s]>' % args def destroy(self): if self.channel: # XXX ares_library_cleanup? cares.ares_destroy(self.channel) self.channel = NULL self._watchers.clear() self._timer.stop() self.loop = None def __dealloc__(self): if self.channel: # XXX ares_library_cleanup? cares.ares_destroy(self.channel) self.channel = NULL def set_servers(self, servers=None): if not self.channel: raise gaierror(cares.ARES_EDESTRUCTION, 'this ares channel has been destroyed') if not servers: servers = [] if isinstance(servers, string_types): servers = servers.split(',') cdef int length = len(servers) cdef int result, index cdef char* string cdef cares.ares_addr_node* c_servers if length <= 0: result = cares.ares_set_servers(self.channel, NULL) else: c_servers = malloc(sizeof(cares.ares_addr_node) * length) if not c_servers: raise MemoryError try: index = 0 for server in servers: if isinstance(server, unicode): server = server.encode('ascii') string = server if cares.ares_inet_pton(AF_INET, string, &c_servers[index].addr) > 0: c_servers[index].family = AF_INET elif cares.ares_inet_pton(AF_INET6, string, &c_servers[index].addr) > 0: c_servers[index].family = AF_INET6 else: raise InvalidIP(repr(string)) c_servers[index].next = &c_servers[index] + 1 index += 1 if index >= length: break c_servers[length - 1].next = NULL index = cares.ares_set_servers(self.channel, c_servers) if index: raise ValueError(strerror(index)) finally: free(c_servers) # this crashes c-ares #def cancel(self): # cares.ares_cancel(self.channel) cdef _sock_state_callback(self, int socket, int read, int write): if not self.channel: return cdef object watcher = self._watchers.get(socket) cdef int events = 0 if read: events |= EV_READ if write: events |= EV_WRITE if watcher is None: if not events: return watcher = self.loop.io(socket, events) self._watchers[socket] = watcher elif events: if watcher.events == events: return watcher.stop() watcher.events = events else: watcher.stop() self._watchers.pop(socket, None) if not self._watchers: self._timer.stop() return watcher.start(self._process_fd, watcher, pass_events=True) self._timer.again(self._on_timer) def _on_timer(self): cares.ares_process_fd(self.channel, cares.ARES_SOCKET_BAD, cares.ARES_SOCKET_BAD) def _process_fd(self, int events, object watcher): if not self.channel: return cdef int read_fd = watcher.fd cdef int write_fd = read_fd if not (events & EV_READ): read_fd = cares.ARES_SOCKET_BAD if not (events & EV_WRITE): write_fd = cares.ARES_SOCKET_BAD cares.ares_process_fd(self.channel, read_fd, write_fd) def gethostbyname(self, object callback, char* name, int family=AF_INET): if not self.channel: raise gaierror(cares.ARES_EDESTRUCTION, 'this ares channel has been destroyed') # note that for file lookups still AF_INET can be returned for AF_INET6 request cdef object arg = (self, callback) Py_INCREF(arg) cares.ares_gethostbyname(self.channel, name, family, gevent_ares_host_callback, arg) def gethostbyaddr(self, object callback, char* addr): if not self.channel: raise gaierror(cares.ARES_EDESTRUCTION, 'this ares channel has been destroyed') # will guess the family cdef char addr_packed[16] cdef int family cdef int length if cares.ares_inet_pton(AF_INET, addr, addr_packed) > 0: family = AF_INET length = 4 elif cares.ares_inet_pton(AF_INET6, addr, addr_packed) > 0: family = AF_INET6 length = 16 else: raise InvalidIP(repr(addr)) cdef object arg = (self, callback) Py_INCREF(arg) cares.ares_gethostbyaddr(self.channel, addr_packed, length, family, gevent_ares_host_callback, arg) cpdef _getnameinfo(self, object callback, tuple sockaddr, int flags): if not self.channel: raise gaierror(cares.ARES_EDESTRUCTION, 'this ares channel has been destroyed') cdef char* hostp = NULL cdef int port = 0 cdef int flowinfo = 0 cdef int scope_id = 0 cdef sockaddr_in6 sa6 if not PyTuple_Check(sockaddr): raise TypeError('expected a tuple, got %r' % (sockaddr, )) PyArg_ParseTuple(sockaddr, "si|ii", &hostp, &port, &flowinfo, &scope_id) if port < 0 or port > 65535: raise gaierror(-8, 'Invalid value for port: %r' % port) cdef int length = gevent_make_sockaddr(hostp, port, flowinfo, scope_id, &sa6) if length <= 0: raise InvalidIP(repr(hostp)) cdef object arg = (self, callback) Py_INCREF(arg) cdef sockaddr_t* x = &sa6 cares.ares_getnameinfo(self.channel, x, length, flags, gevent_ares_nameinfo_callback, arg) def getnameinfo(self, object callback, tuple sockaddr, int flags): try: flags = _convert_cares_flags(flags) except gaierror: # The stdlib just ignores bad flags flags = 0 return self._getnameinfo(callback, sockaddr, flags)