diff options
Diffstat (limited to 'python/gevent/_ssl3.py')
-rw-r--r-- | python/gevent/_ssl3.py | 135 |
1 files changed, 93 insertions, 42 deletions
diff --git a/python/gevent/_ssl3.py b/python/gevent/_ssl3.py index 81b709c..71e2f76 100644 --- a/python/gevent/_ssl3.py +++ b/python/gevent/_ssl3.py @@ -8,11 +8,12 @@ This module implements cooperative SSL socket wrappers. """ # Our import magic sadly makes this warning useless # pylint: disable=undefined-variable +# pylint:disable=no-member from __future__ import absolute_import import ssl as __ssl__ -_ssl = __ssl__._ssl # pylint:disable=no-member +_ssl = __ssl__._ssl import errno from gevent.socket import socket, timeout_default @@ -44,20 +45,25 @@ orig_SSLContext = __ssl__.SSLContext # pylint:disable=no-member class SSLContext(orig_SSLContext): + + # Added in Python 3.7 + sslsocket_class = None # SSLSocket is assigned later + def wrap_socket(self, sock, server_side=False, do_handshake_on_connect=True, suppress_ragged_eofs=True, server_hostname=None, session=None): - # pylint:disable=arguments-differ + # pylint:disable=arguments-differ,not-callable # (3.6 adds session) # Sadly, using *args and **kwargs doesn't work - return SSLSocket(sock=sock, server_side=server_side, - do_handshake_on_connect=do_handshake_on_connect, - suppress_ragged_eofs=suppress_ragged_eofs, - server_hostname=server_hostname, - _context=self, - _session=session) + return self.sslsocket_class( + sock=sock, server_side=server_side, + do_handshake_on_connect=do_handshake_on_connect, + suppress_ragged_eofs=suppress_ragged_eofs, + server_hostname=server_hostname, + _context=self, + _session=session) if not hasattr(orig_SSLContext, 'check_hostname'): # Python 3.3 lacks this @@ -82,6 +88,16 @@ class SSLContext(orig_SSLContext): def verify_mode(self, value): super(orig_SSLContext, orig_SSLContext).verify_mode.__set__(self, value) + if hasattr(orig_SSLContext, 'minimum_version'): + # Like the above, added in 3.7 + @orig_SSLContext.minimum_version.setter + def minimum_version(self, value): + super(orig_SSLContext, orig_SSLContext).minimum_version.__set__(self, value) + + @orig_SSLContext.maximum_version.setter + def maximum_version(self, value): + super(orig_SSLContext, orig_SSLContext).maximum_version.__set__(self, value) + class _contextawaresock(socket._gevent_sock_class): # Python 2: pylint:disable=slots-on-old-class # We have to pass the raw stdlib socket to SSLContext.wrap_socket. @@ -122,6 +138,23 @@ class _contextawaresock(socket._gevent_sock_class): # Python 2: pylint:disable=s pass raise AttributeError(name) +try: + _SSLObject_factory = SSLObject +except NameError: + # 3.4 and below do not have SSLObject, something + # we magically import through copy_globals + pass +else: + if hasattr(SSLObject, '_create'): + # 3.7 is making thing difficult and won't let you + # actually construct an object + def _SSLObject_factory(sslobj, owner=None, session=None): + s = SSLObject.__new__(SSLObject) + s._sslobj = sslobj + s._sslobj.owner = owner or s + if session is not None: + s._sslobj.session = session + return s class SSLSocket(socket): """ @@ -142,6 +175,7 @@ class SSLSocket(socket): server_hostname=None, _session=None, # 3.6 _context=None): + # pylint:disable=too-many-locals,too-many-statements,too-many-branches if _context: self._context = _context @@ -218,8 +252,9 @@ class SSLSocket(socket): try: self._sslobj = self._context._wrap_socket(self._sock, server_side, server_hostname) - if _session is not None: # 3.6 - self._sslobj = SSLObject(self._sslobj, owner=self, session=self._session) + if _session is not None: # 3.6+ + self._sslobj = _SSLObject_factory(self._sslobj, owner=self, + session=self._session) if do_handshake_on_connect: timeout = self.gettimeout() if timeout == 0.0: @@ -305,8 +340,7 @@ class SSLSocket(socket): if buffer is None: return b'' return 0 - else: - raise + raise def write(self, data): """Write DATA to the underlying SSL channel. Returns @@ -455,8 +489,7 @@ class SSLSocket(socket): # Python #23804 return b'' return self.read(buflen) - else: - return socket.recv(self, buflen, flags) + return socket.recv(self, buflen, flags) def recv_into(self, buffer, nbytes=None, flags=0): self._checkClosed() @@ -468,8 +501,7 @@ class SSLSocket(socket): if flags != 0: raise ValueError("non-zero flags not allowed in calls to recv_into() on %s" % self.__class__) return self.read(nbytes, buffer) - else: - return socket.recv_into(self, buffer, nbytes, flags) + return socket.recv_into(self, buffer, nbytes, flags) def recvfrom(self, buflen=1024, flags=0): self._checkClosed() @@ -507,31 +539,41 @@ class SSLSocket(socket): socket.shutdown(self, how) def unwrap(self): - if self._sslobj: - while True: - try: - s = self._sslobj.shutdown() - break - except SSLWantReadError: - if self.timeout == 0.0: - return 0 - self._wait(self._read_event) - except SSLWantWriteError: - if self.timeout == 0.0: - return 0 - self._wait(self._write_event) - - self._sslobj = None - # The return value of shutting down the SSLObject is the - # original wrapped socket, i.e., _contextawaresock. But that - # object doesn't have the gevent wrapper around it so it can't - # be used. We have to wrap it back up with a gevent wrapper. - sock = socket(family=s.family, type=s.type, proto=s.proto, fileno=s.fileno()) - s.detach() - return sock - else: + if not self._sslobj: raise ValueError("No SSL wrapper around " + str(self)) + while True: + try: + s = self._sslobj.shutdown() + break + except SSLWantReadError: + # Callers of this method expect to get a socket + # back, so we can't simply return 0, we have + # to let these be raised + if self.timeout == 0.0: + raise + self._wait(self._read_event) + except SSLWantWriteError: + if self.timeout == 0.0: + raise + self._wait(self._write_event) + + self._sslobj = None + + # The return value of shutting down the SSLObject is the + # original wrapped socket passed to _wrap_socket, i.e., + # _contextawaresock. But that object doesn't have the + # gevent wrapper around it so it can't be used. We have to + # wrap it back up with a gevent wrapper. + assert s is self._sock + # In the stdlib, SSLSocket subclasses socket.socket and passes itself + # to _wrap_socket, so it gets itself back. We can't do that, we have to + # pass our subclass of _socket.socket, _contextawaresock. + # So ultimately we should return ourself. + + # See test_ftplib.py:TestTLS_FTPClass.test_ccc + return self + def _real_close(self): self._sslobj = None # self._closed = True @@ -553,7 +595,9 @@ class SSLSocket(socket): raise self._wait(self._write_event, timeout_exc=_SSLErrorHandshakeTimeout) - if self._context.check_hostname: + if sys.version_info[:2] < (3, 7) and self._context.check_hostname: + # In Python 3.7, the underlying OpenSSL name matching is used. + # The version implemented in Python doesn't understand IDNA encoding. if not self.server_hostname: raise ValueError("check_hostname needs server_hostname " "argument") @@ -567,8 +611,8 @@ class SSLSocket(socket): if self._connected: raise ValueError("attempt to connect already-connected SSLSocket!") self._sslobj = self._context._wrap_socket(self._sock, False, self.server_hostname) - if self._session is not None: # 3.6 - self._sslobj = SSLObject(self._sslobj, owner=self, session=self._session) + if self._session is not None: # 3.6+ + self._sslobj = _SSLObject_factory(self._sslobj, owner=self, session=self._session) try: if connect_ex: rc = socket.connect_ex(self, addr) @@ -600,6 +644,7 @@ class SSLSocket(socket): SSL channel, and the address of the remote client.""" newsock, addr = socket.accept(self) + newsock._drop_events() newsock = self._context.wrap_socket(newsock, do_handshake_on_connect=self.do_handshake_on_connect, suppress_ragged_eofs=self.suppress_ragged_eofs, @@ -611,6 +656,9 @@ class SSLSocket(socket): if the requested `cb_type` is not supported. Return bytes of the data or None if the data is not available (e.g. before the handshake). """ + if hasattr(self._sslobj, 'get_channel_binding'): + # 3.7+, and sslobj is not None + return self._sslobj.get_channel_binding(cb_type) if cb_type not in CHANNEL_BINDING_TYPES: raise ValueError("Unsupported channel binding type") if cb_type != "tls-unique": @@ -620,6 +668,9 @@ class SSLSocket(socket): return self._sslobj.tls_unique_cb() +# Python does not support forward declaration of types +SSLContext.sslsocket_class = SSLSocket + # Python 3.2 onwards raise normal timeout errors, not SSLError. # See https://bugs.python.org/issue10272 _SSLErrorReadTimeout = _socket_timeout('The read operation timed out') |