diff options
Diffstat (limited to 'python/gevent/_sslgte279.py')
-rw-r--r-- | python/gevent/_sslgte279.py | 714 |
1 files changed, 714 insertions, 0 deletions
diff --git a/python/gevent/_sslgte279.py b/python/gevent/_sslgte279.py new file mode 100644 index 0000000..92280ad --- /dev/null +++ b/python/gevent/_sslgte279.py @@ -0,0 +1,714 @@ +# Wrapper module for _ssl. Written by Bill Janssen. +# Ported to gevent by Denis Bilenko. +"""SSL wrapper for socket objects on Python 2.7.9 and above. + +For the documentation, refer to :mod:`ssl` module manual. + +This module implements cooperative SSL socket wrappers. +""" + +from __future__ import absolute_import +# Our import magic sadly makes this warning useless +# pylint: disable=undefined-variable +# pylint: disable=too-many-instance-attributes,too-many-locals,too-many-statements,too-many-branches +# pylint: disable=arguments-differ,too-many-public-methods + +import ssl as __ssl__ + +_ssl = __ssl__._ssl # pylint:disable=no-member + +import errno +from gevent._socket2 import socket +from gevent.socket import timeout_default +from gevent.socket import create_connection +from gevent.socket import error as socket_error +from gevent.socket import timeout as _socket_timeout +from gevent._compat import PYPY +from gevent._util import copy_globals + +__implements__ = [ + 'SSLContext', + 'SSLSocket', + 'wrap_socket', + 'get_server_certificate', + 'create_default_context', + '_create_unverified_context', + '_create_default_https_context', + '_create_stdlib_context', +] + +# Import all symbols from Python's ssl.py, except those that we are implementing +# and "private" symbols. +__imports__ = copy_globals(__ssl__, globals(), + # SSLSocket *must* subclass gevent.socket.socket; see issue 597 and 801 + names_to_ignore=__implements__ + ['socket', 'create_connection'], + dunder_names_to_keep=()) + +try: + _delegate_methods +except NameError: # PyPy doesn't expose this detail + _delegate_methods = ('recv', 'recvfrom', 'recv_into', 'recvfrom_into', 'send', 'sendto') + +__all__ = __implements__ + __imports__ +if 'namedtuple' in __all__: + __all__.remove('namedtuple') + +orig_SSLContext = __ssl__.SSLContext # pylint: disable=no-member + + +class SSLContext(orig_SSLContext): + def wrap_socket(self, sock, server_side=False, + do_handshake_on_connect=True, + suppress_ragged_eofs=True, + server_hostname=None): + return SSLSocket(sock=sock, server_side=server_side, + do_handshake_on_connect=do_handshake_on_connect, + suppress_ragged_eofs=suppress_ragged_eofs, + server_hostname=server_hostname, + _context=self) + + +def create_default_context(purpose=Purpose.SERVER_AUTH, cafile=None, + capath=None, cadata=None): + """Create a SSLContext object with default settings. + + NOTE: The protocol and settings may change anytime without prior + deprecation. The values represent a fair balance between maximum + compatibility and security. + """ + if not isinstance(purpose, _ASN1Object): + raise TypeError(purpose) + + context = SSLContext(PROTOCOL_SSLv23) + + # SSLv2 considered harmful. + context.options |= OP_NO_SSLv2 + + # SSLv3 has problematic security and is only required for really old + # clients such as IE6 on Windows XP + context.options |= OP_NO_SSLv3 + + # disable compression to prevent CRIME attacks (OpenSSL 1.0+) + context.options |= getattr(_ssl, "OP_NO_COMPRESSION", 0) + + if purpose == Purpose.SERVER_AUTH: + # verify certs and host name in client mode + context.verify_mode = CERT_REQUIRED + context.check_hostname = True # pylint: disable=attribute-defined-outside-init + elif purpose == Purpose.CLIENT_AUTH: + # Prefer the server's ciphers by default so that we get stronger + # encryption + context.options |= getattr(_ssl, "OP_CIPHER_SERVER_PREFERENCE", 0) + + # Use single use keys in order to improve forward secrecy + context.options |= getattr(_ssl, "OP_SINGLE_DH_USE", 0) + context.options |= getattr(_ssl, "OP_SINGLE_ECDH_USE", 0) + + # disallow ciphers with known vulnerabilities + context.set_ciphers(_RESTRICTED_SERVER_CIPHERS) + + if cafile or capath or cadata: + context.load_verify_locations(cafile, capath, cadata) + elif context.verify_mode != CERT_NONE: + # no explicit cafile, capath or cadata but the verify mode is + # CERT_OPTIONAL or CERT_REQUIRED. Let's try to load default system + # root CA certificates for the given purpose. This may fail silently. + context.load_default_certs(purpose) + return context + +def _create_unverified_context(protocol=PROTOCOL_SSLv23, cert_reqs=None, + check_hostname=False, purpose=Purpose.SERVER_AUTH, + certfile=None, keyfile=None, + cafile=None, capath=None, cadata=None): + """Create a SSLContext object for Python stdlib modules + + All Python stdlib modules shall use this function to create SSLContext + objects in order to keep common settings in one place. The configuration + is less restrict than create_default_context()'s to increase backward + compatibility. + """ + if not isinstance(purpose, _ASN1Object): + raise TypeError(purpose) + + context = SSLContext(protocol) + # SSLv2 considered harmful. + context.options |= OP_NO_SSLv2 + # SSLv3 has problematic security and is only required for really old + # clients such as IE6 on Windows XP + context.options |= OP_NO_SSLv3 + + if cert_reqs is not None: + context.verify_mode = cert_reqs + context.check_hostname = check_hostname # pylint: disable=attribute-defined-outside-init + + if keyfile and not certfile: + raise ValueError("certfile must be specified") + if certfile or keyfile: + context.load_cert_chain(certfile, keyfile) + + # load CA root certs + if cafile or capath or cadata: + context.load_verify_locations(cafile, capath, cadata) + elif context.verify_mode != CERT_NONE: + # no explicit cafile, capath or cadata but the verify mode is + # CERT_OPTIONAL or CERT_REQUIRED. Let's try to load default system + # root CA certificates for the given purpose. This may fail silently. + context.load_default_certs(purpose) + + return context + +# Used by http.client if no context is explicitly passed. +_create_default_https_context = create_default_context + + +# Backwards compatibility alias, even though it's not a public name. +_create_stdlib_context = _create_unverified_context + +class SSLSocket(socket): + """ + gevent `ssl.SSLSocket <https://docs.python.org/2/library/ssl.html#ssl-sockets>`_ + for Pythons >= 2.7.9 but less than 3. + """ + + def __init__(self, sock=None, keyfile=None, certfile=None, + server_side=False, cert_reqs=CERT_NONE, + ssl_version=PROTOCOL_SSLv23, ca_certs=None, + do_handshake_on_connect=True, + family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None, + suppress_ragged_eofs=True, npn_protocols=None, ciphers=None, + server_hostname=None, + _context=None): + # fileno is ignored + # pylint: disable=unused-argument + if _context: + self._context = _context + else: + if server_side and not certfile: + raise ValueError("certfile must be specified for server-side " + "operations") + if keyfile and not certfile: + raise ValueError("certfile must be specified") + if certfile and not keyfile: + keyfile = certfile + self._context = SSLContext(ssl_version) + self._context.verify_mode = cert_reqs + if ca_certs: + self._context.load_verify_locations(ca_certs) + if certfile: + self._context.load_cert_chain(certfile, keyfile) + if npn_protocols: + self._context.set_npn_protocols(npn_protocols) + if ciphers: + self._context.set_ciphers(ciphers) + self.keyfile = keyfile + self.certfile = certfile + self.cert_reqs = cert_reqs + self.ssl_version = ssl_version + self.ca_certs = ca_certs + self.ciphers = ciphers + # Can't use sock.type as other flags (such as SOCK_NONBLOCK) get + # mixed in. + if sock.getsockopt(SOL_SOCKET, SO_TYPE) != SOCK_STREAM: + raise NotImplementedError("only stream sockets are supported") + + if PYPY: + socket.__init__(self, _sock=sock) + sock._drop() + else: + # CPython: XXX: Must pass the underlying socket, not our + # potential wrapper; test___example_servers fails the SSL test + # with a client-side EOF error. (Why?) + socket.__init__(self, _sock=sock._sock) + + # The initializer for socket overrides the methods send(), recv(), etc. + # in the instance, which we don't need -- but we want to provide the + # methods defined in SSLSocket. + for attr in _delegate_methods: + try: + delattr(self, attr) + except AttributeError: + pass + if server_side and server_hostname: + raise ValueError("server_hostname can only be specified " + "in client mode") + if self._context.check_hostname and not server_hostname: + raise ValueError("check_hostname requires server_hostname") + self.server_side = server_side + self.server_hostname = server_hostname + self.do_handshake_on_connect = do_handshake_on_connect + self.suppress_ragged_eofs = suppress_ragged_eofs + self.settimeout(sock.gettimeout()) + + # See if we are connected + try: + self.getpeername() + except socket_error as e: + if e.errno != errno.ENOTCONN: + raise + connected = False + else: + connected = True + + self._makefile_refs = 0 + self._closed = False + self._sslobj = None + self._connected = connected + if connected: + # create the SSL object + try: + self._sslobj = self._context._wrap_socket(self._sock, server_side, + server_hostname, ssl_sock=self) + if do_handshake_on_connect: + timeout = self.gettimeout() + if timeout == 0.0: + # non-blocking + raise ValueError("do_handshake_on_connect should not be specified for non-blocking sockets") + self.do_handshake() + + except socket_error as x: + self.close() + raise x + + + @property + def context(self): + return self._context + + @context.setter + def context(self, ctx): + self._context = ctx + self._sslobj.context = ctx + + def dup(self): + raise NotImplementedError("Can't dup() %s instances" % + self.__class__.__name__) + + def _checkClosed(self, msg=None): + # raise an exception here if you wish to check for spurious closes + pass + + def _check_connected(self): + if not self._connected: + # getpeername() will raise ENOTCONN if the socket is really + # not connected; note that we can be connected even without + # _connected being set, e.g. if connect() first returned + # EAGAIN. + self.getpeername() + + def read(self, len=1024, buffer=None): + """Read up to LEN bytes and return them. + Return zero-length string on EOF.""" + self._checkClosed() + + while 1: + if not self._sslobj: + raise ValueError("Read on closed or unwrapped SSL socket.") + if len == 0: + return b'' if buffer is None else 0 + if len < 0 and buffer is None: + # This is handled natively in python 2.7.12+ + raise ValueError("Negative read length") + try: + if buffer is not None: + return self._sslobj.read(len, buffer) + return self._sslobj.read(len or 1024) + except SSLWantReadError: + if self.timeout == 0.0: + raise + self._wait(self._read_event, timeout_exc=_SSLErrorReadTimeout) + except SSLWantWriteError: + if self.timeout == 0.0: + raise + # note: using _SSLErrorReadTimeout rather than _SSLErrorWriteTimeout below is intentional + self._wait(self._write_event, timeout_exc=_SSLErrorReadTimeout) + except SSLError as ex: + if ex.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs: + if buffer is not None: + return 0 + return b'' + else: + raise + + def write(self, data): + """Write DATA to the underlying SSL channel. Returns + number of bytes of DATA actually transmitted.""" + self._checkClosed() + + while 1: + if not self._sslobj: + raise ValueError("Write on closed or unwrapped SSL socket.") + + try: + return self._sslobj.write(data) + except SSLError as ex: + if ex.args[0] == SSL_ERROR_WANT_READ: + if self.timeout == 0.0: + raise + self._wait(self._read_event, timeout_exc=_SSLErrorWriteTimeout) + elif ex.args[0] == SSL_ERROR_WANT_WRITE: + if self.timeout == 0.0: + raise + self._wait(self._write_event, timeout_exc=_SSLErrorWriteTimeout) + else: + raise + + def getpeercert(self, binary_form=False): + """Returns a formatted version of the data in the + certificate provided by the other end of the SSL channel. + Return None if no certificate was provided, {} if a + certificate was provided, but not validated.""" + + self._checkClosed() + self._check_connected() + return self._sslobj.peer_certificate(binary_form) + + def selected_npn_protocol(self): + self._checkClosed() + if not self._sslobj or not _ssl.HAS_NPN: + return None + return self._sslobj.selected_npn_protocol() + + if hasattr(_ssl, 'HAS_ALPN'): + # 2.7.10+ + def selected_alpn_protocol(self): + self._checkClosed() + if not self._sslobj or not _ssl.HAS_ALPN: # pylint:disable=no-member + return None + return self._sslobj.selected_alpn_protocol() + + def cipher(self): + self._checkClosed() + if not self._sslobj: + return None + return self._sslobj.cipher() + + def compression(self): + self._checkClosed() + if not self._sslobj: + return None + return self._sslobj.compression() + + def __check_flags(self, meth, flags): + if flags != 0: + raise ValueError( + "non-zero flags not allowed in calls to %s on %s" % + (meth, self.__class__)) + + def send(self, data, flags=0, timeout=timeout_default): + self._checkClosed() + self.__check_flags('send', flags) + + if timeout is timeout_default: + timeout = self.timeout + + if not self._sslobj: + return socket.send(self, data, flags, timeout) + + while True: + try: + return self._sslobj.write(data) + except SSLWantReadError: + if self.timeout == 0.0: + return 0 + self._wait(self._read_event) + except SSLWantWriteError: + if self.timeout == 0.0: + return 0 + self._wait(self._write_event) + + def sendto(self, data, flags_or_addr, addr=None): + self._checkClosed() + if self._sslobj: + raise ValueError("sendto not allowed on instances of %s" % + self.__class__) + elif addr is None: + return socket.sendto(self, data, flags_or_addr) + else: + return socket.sendto(self, data, flags_or_addr, addr) + + def sendmsg(self, *args, **kwargs): + # Ensure programs don't send data unencrypted if they try to + # use this method. + raise NotImplementedError("sendmsg not allowed on instances of %s" % + self.__class__) + + def sendall(self, data, flags=0): + self._checkClosed() + self.__check_flags('sendall', flags) + + try: + socket.sendall(self, data) + except _socket_timeout as ex: + if self.timeout == 0.0: + # Python 2 simply *hangs* in this case, which is bad, but + # Python 3 raises SSLWantWriteError. We do the same. + raise SSLWantWriteError("The operation did not complete (write)") + # Convert the socket.timeout back to the sslerror + raise SSLError(*ex.args) + + def recv(self, buflen=1024, flags=0): + self._checkClosed() + if self._sslobj: + if flags != 0: + raise ValueError( + "non-zero flags not allowed in calls to recv() on %s" % + self.__class__) + if buflen == 0: + return b'' + return self.read(buflen) + else: + return socket.recv(self, buflen, flags) + + def recv_into(self, buffer, nbytes=None, flags=0): + self._checkClosed() + if buffer is not None and (nbytes is None): + # Fix for python bug #23804: bool(bytearray()) is False, + # but we should read 0 bytes. + nbytes = len(buffer) + elif nbytes is None: + nbytes = 1024 + if self._sslobj: + if flags != 0: + raise ValueError( + "non-zero flags not allowed in calls to recv_into() on %s" % + self.__class__) + return self.read(nbytes, buffer) + else: + return socket.recv_into(self, buffer, nbytes, flags) + + def recvfrom(self, buflen=1024, flags=0): + self._checkClosed() + if self._sslobj: + raise ValueError("recvfrom not allowed on instances of %s" % + self.__class__) + else: + return socket.recvfrom(self, buflen, flags) + + def recvfrom_into(self, buffer, nbytes=None, flags=0): + self._checkClosed() + if self._sslobj: + raise ValueError("recvfrom_into not allowed on instances of %s" % + self.__class__) + else: + return socket.recvfrom_into(self, buffer, nbytes, flags) + + def recvmsg(self, *args, **kwargs): + raise NotImplementedError("recvmsg not allowed on instances of %s" % + self.__class__) + + def recvmsg_into(self, *args, **kwargs): + raise NotImplementedError("recvmsg_into not allowed on instances of " + "%s" % self.__class__) + + def pending(self): + self._checkClosed() + if self._sslobj: + return self._sslobj.pending() + return 0 + + def shutdown(self, how): + self._checkClosed() + self._sslobj = None + socket.shutdown(self, how) + + def close(self): + if self._makefile_refs < 1: + self._sslobj = None + socket.close(self) + else: + self._makefile_refs -= 1 + + if PYPY: + + def _reuse(self): + self._makefile_refs += 1 + + def _drop(self): + if self._makefile_refs < 1: + self.close() + else: + self._makefile_refs -= 1 + + def _sslobj_shutdown(self): + while True: + try: + return self._sslobj.shutdown() + except SSLError as ex: + if ex.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs: + return '' + elif ex.args[0] == SSL_ERROR_WANT_READ: + if self.timeout == 0.0: + raise + sys.exc_clear() + self._wait(self._read_event, timeout_exc=_SSLErrorReadTimeout) + elif ex.args[0] == SSL_ERROR_WANT_WRITE: + if self.timeout == 0.0: + raise + sys.exc_clear() + self._wait(self._write_event, timeout_exc=_SSLErrorWriteTimeout) + else: + raise + + def unwrap(self): + if self._sslobj: + s = self._sslobj_shutdown() + self._sslobj = None + return socket(_sock=s) # match _ssl2; critical to drop/reuse here on PyPy + else: + raise ValueError("No SSL wrapper around " + str(self)) + + def _real_close(self): + self._sslobj = None + socket._real_close(self) # pylint: disable=no-member + + def do_handshake(self): + """Perform a TLS/SSL handshake.""" + self._check_connected() + while True: + try: + self._sslobj.do_handshake() + break + except SSLWantReadError: + if self.timeout == 0.0: + raise + self._wait(self._read_event, timeout_exc=_SSLErrorHandshakeTimeout) + except SSLWantWriteError: + if self.timeout == 0.0: + raise + self._wait(self._write_event, timeout_exc=_SSLErrorHandshakeTimeout) + + if self._context.check_hostname: + if not self.server_hostname: + raise ValueError("check_hostname needs server_hostname " + "argument") + match_hostname(self.getpeercert(), self.server_hostname) + + def _real_connect(self, addr, connect_ex): + if self.server_side: + raise ValueError("can't connect in server-side mode") + # Here we assume that the socket is client-side, and not + # connected at the time of the call. We connect it, then wrap it. + if self._connected: + raise ValueError("attempt to connect already-connected SSLSocket!") + self._sslobj = self._context._wrap_socket(self._sock, False, self.server_hostname, ssl_sock=self) + try: + if connect_ex: + rc = socket.connect_ex(self, addr) + else: + rc = None + socket.connect(self, addr) + if not rc: + self._connected = True + if self.do_handshake_on_connect: + self.do_handshake() + return rc + except socket_error: + self._sslobj = None + raise + + def connect(self, addr): + """Connects to remote ADDR, and then wraps the connection in + an SSL channel.""" + self._real_connect(addr, False) + + def connect_ex(self, addr): + """Connects to remote ADDR, and then wraps the connection in + an SSL channel.""" + return self._real_connect(addr, True) + + def accept(self): + """Accepts a new connection from a remote client, and returns + a tuple containing that new connection wrapped with a server-side + SSL channel, and the address of the remote client.""" + + newsock, addr = socket.accept(self) + newsock = self._context.wrap_socket(newsock, + do_handshake_on_connect=self.do_handshake_on_connect, + suppress_ragged_eofs=self.suppress_ragged_eofs, + server_side=True) + return newsock, addr + + def makefile(self, mode='r', bufsize=-1): + + """Make and return a file-like object that + works with the SSL connection. Just use the code + from the socket module.""" + if not PYPY: + self._makefile_refs += 1 + # close=True so as to decrement the reference count when done with + # the file-like object. + return _fileobject(self, mode, bufsize, close=True) + + def get_channel_binding(self, cb_type="tls-unique"): + """Get channel binding data for current connection. Raise ValueError + if the requested `cb_type` is not supported. Return bytes of the data + or None if the data is not available (e.g. before the handshake). + """ + if cb_type not in CHANNEL_BINDING_TYPES: + raise ValueError("Unsupported channel binding type") + if cb_type != "tls-unique": + raise NotImplementedError( + "{0} channel binding type not implemented" + .format(cb_type)) + if self._sslobj is None: + return None + return self._sslobj.tls_unique_cb() + + def version(self): + """ + Return a string identifying the protocol version used by the + current SSL channel, or None if there is no established channel. + """ + if self._sslobj is None: + return None + return self._sslobj.version() + +if PYPY or not hasattr(SSLSocket, 'timeout'): + # PyPy (and certain versions of CPython) doesn't have a direct + # 'timeout' property on raw sockets, because that's not part of + # the documented specification. We may wind up wrapping a raw + # socket (when ssl is used with PyWSGI) or a gevent socket, which + # does have a read/write timeout property as an alias for + # get/settimeout, so make sure that's always the case because + # pywsgi can depend on that. + SSLSocket.timeout = property(lambda self: self.gettimeout(), + lambda self, value: self.settimeout(value)) + + + +_SSLErrorReadTimeout = SSLError('The read operation timed out') +_SSLErrorWriteTimeout = SSLError('The write operation timed out') +_SSLErrorHandshakeTimeout = SSLError('The handshake operation timed out') + +def wrap_socket(sock, keyfile=None, certfile=None, + server_side=False, cert_reqs=CERT_NONE, + ssl_version=PROTOCOL_SSLv23, ca_certs=None, + do_handshake_on_connect=True, + suppress_ragged_eofs=True, + ciphers=None): + + return SSLSocket(sock=sock, keyfile=keyfile, certfile=certfile, + server_side=server_side, cert_reqs=cert_reqs, + ssl_version=ssl_version, ca_certs=ca_certs, + do_handshake_on_connect=do_handshake_on_connect, + suppress_ragged_eofs=suppress_ragged_eofs, + ciphers=ciphers) + +def get_server_certificate(addr, ssl_version=PROTOCOL_SSLv23, ca_certs=None): + """Retrieve the certificate from the server at the specified address, + and return it as a PEM-encoded string. + If 'ca_certs' is specified, validate the server cert against it. + If 'ssl_version' is specified, use it in the connection attempt.""" + + _, _ = addr + if ca_certs is not None: + cert_reqs = CERT_REQUIRED + else: + cert_reqs = CERT_NONE + context = _create_stdlib_context(ssl_version, + cert_reqs=cert_reqs, + cafile=ca_certs) + with closing(create_connection(addr)) as sock: + with closing(context.wrap_socket(sock)) as sslsock: + dercert = sslsock.getpeercert(True) + return DER_cert_to_PEM_cert(dercert) |