diff options
| author | James Taylor <user234683@users.noreply.github.com> | 2019-06-01 23:23:18 -0700 | 
|---|---|---|
| committer | James Taylor <user234683@users.noreply.github.com> | 2019-06-02 02:25:39 -0700 | 
| commit | af9c4e0554c3475d959014e9e7cef78eff88afa5 (patch) | |
| tree | ced7a2ccd6d0ab8e9d251dcd61bba09f3bb87074 | |
| parent | 3905e7e64059b45479894ba1fdfb0ef9cef64475 (diff) | |
| parent | 9f93b9429c77e631972186049fbc7518e2cf5d4b (diff) | |
| download | yt-local-af9c4e0554c3475d959014e9e7cef78eff88afa5.tar.lz yt-local-af9c4e0554c3475d959014e9e7cef78eff88afa5.tar.xz yt-local-af9c4e0554c3475d959014e9e7cef78eff88afa5.zip | |
Bring up to date with master
54 files changed, 9934 insertions, 552 deletions
| @@ -1,7 +1,7 @@  # youtube-local   -youtube-local is a browser-based client written in Python for watching Youtube anonymously and without the lag of the javascript-heavy page used by Youtube. One of the primary features is that all requests are routed through Tor, except for the video file at googlevideo.com. This is analogous to what HookTube does, except that you do not have to trust a third-party to respect your privacy. The assumption here is that Google won't put the effort in to incorporate the video file requests into their survelliance systems, as it's not worth pursuing the incredibly small number of users who care about privacy. Using Tor is optional; when not routing through Tor, video pages load *faster* than they do with Youtube's laggy javascript page (for me atleast). +youtube-local is a browser-based client written in Python for watching Youtube anonymously and without the lag of the javascript-heavy page used by Youtube. One of the primary features is that all requests are routed through Tor, except for the video file at googlevideo.com. This is analogous to what HookTube does, except that you do not have to trust a third-party to respect your privacy. The assumption here is that Google won't put the effort in to incorporate the video file requests into their survelliance systems, as it's not worth pursuing the incredibly small number of users who care about privacy. Tor has high latency, so this will not be as fast as regular Youtube. However, using Tor is optional; when not routing through Tor, video pages may load faster than they do with Youtube's laggy javascript page depending on your browser.  The Youtube API is not used, so no keys or anything are needed. It uses the same requests as the Youtube webpage. No javascript is used either. @@ -15,9 +15,9 @@ Download the zip file under the Releases page. Unzip it anywhere you choose.  ### Linux/MacOS -Ensure you have python 3.5 or later installed. Then, install gevent, brotli, and PySocks by running +Ensure you have python 3.5 or later installed. Then, install gevent, brotli, PySocks, and urllib3 by running  ``` -pip3 install gevent brotli pysocks +pip3 install gevent brotli pysocks urllib3  ```  **Note**: If pip isn't installed, install it according to [this answer](https://unix.stackexchange.com/a/182467), but make sure you run `python3 get-pip.py` instead of `python get-pip.py` diff --git a/python/urllib3/__init__.py b/python/urllib3/__init__.py new file mode 100644 index 0000000..148a9c3 --- /dev/null +++ b/python/urllib3/__init__.py @@ -0,0 +1,92 @@ +""" +urllib3 - Thread-safe connection pooling and re-using. +""" + +from __future__ import absolute_import +import warnings + +from .connectionpool import ( +    HTTPConnectionPool, +    HTTPSConnectionPool, +    connection_from_url +) + +from . import exceptions +from .filepost import encode_multipart_formdata +from .poolmanager import PoolManager, ProxyManager, proxy_from_url +from .response import HTTPResponse +from .util.request import make_headers +from .util.url import get_host +from .util.timeout import Timeout +from .util.retry import Retry + + +# Set default logging handler to avoid "No handler found" warnings. +import logging +from logging import NullHandler + +__author__ = 'Andrey Petrov (andrey.petrov@shazow.net)' +__license__ = 'MIT' +__version__ = '1.24.1' + +__all__ = ( +    'HTTPConnectionPool', +    'HTTPSConnectionPool', +    'PoolManager', +    'ProxyManager', +    'HTTPResponse', +    'Retry', +    'Timeout', +    'add_stderr_logger', +    'connection_from_url', +    'disable_warnings', +    'encode_multipart_formdata', +    'get_host', +    'make_headers', +    'proxy_from_url', +) + +logging.getLogger(__name__).addHandler(NullHandler()) + + +def add_stderr_logger(level=logging.DEBUG): +    """ +    Helper for quickly adding a StreamHandler to the logger. Useful for +    debugging. + +    Returns the handler after adding it. +    """ +    # This method needs to be in this __init__.py to get the __name__ correct +    # even if urllib3 is vendored within another package. +    logger = logging.getLogger(__name__) +    handler = logging.StreamHandler() +    handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s')) +    logger.addHandler(handler) +    logger.setLevel(level) +    logger.debug('Added a stderr logging handler to logger: %s', __name__) +    return handler + + +# ... Clean up. +del NullHandler + + +# All warning filters *must* be appended unless you're really certain that they +# shouldn't be: otherwise, it's very hard for users to use most Python +# mechanisms to silence them. +# SecurityWarning's always go off by default. +warnings.simplefilter('always', exceptions.SecurityWarning, append=True) +# SubjectAltNameWarning's should go off once per host +warnings.simplefilter('default', exceptions.SubjectAltNameWarning, append=True) +# InsecurePlatformWarning's don't vary between requests, so we keep it default. +warnings.simplefilter('default', exceptions.InsecurePlatformWarning, +                      append=True) +# SNIMissingWarnings should go off only once. +warnings.simplefilter('default', exceptions.SNIMissingWarning, append=True) + + +def disable_warnings(category=exceptions.HTTPWarning): +    """ +    Helper for quickly disabling all urllib3 warnings. +    """ +    warnings.simplefilter('ignore', category) diff --git a/python/urllib3/_collections.py b/python/urllib3/_collections.py new file mode 100644 index 0000000..34f2381 --- /dev/null +++ b/python/urllib3/_collections.py @@ -0,0 +1,329 @@ +from __future__ import absolute_import +try: +    from collections.abc import Mapping, MutableMapping +except ImportError: +    from collections import Mapping, MutableMapping +try: +    from threading import RLock +except ImportError:  # Platform-specific: No threads available +    class RLock: +        def __enter__(self): +            pass + +        def __exit__(self, exc_type, exc_value, traceback): +            pass + + +from collections import OrderedDict +from .exceptions import InvalidHeader +from .packages.six import iterkeys, itervalues, PY3 + + +__all__ = ['RecentlyUsedContainer', 'HTTPHeaderDict'] + + +_Null = object() + + +class RecentlyUsedContainer(MutableMapping): +    """ +    Provides a thread-safe dict-like container which maintains up to +    ``maxsize`` keys while throwing away the least-recently-used keys beyond +    ``maxsize``. + +    :param maxsize: +        Maximum number of recent elements to retain. + +    :param dispose_func: +        Every time an item is evicted from the container, +        ``dispose_func(value)`` is called.  Callback which will get called +    """ + +    ContainerCls = OrderedDict + +    def __init__(self, maxsize=10, dispose_func=None): +        self._maxsize = maxsize +        self.dispose_func = dispose_func + +        self._container = self.ContainerCls() +        self.lock = RLock() + +    def __getitem__(self, key): +        # Re-insert the item, moving it to the end of the eviction line. +        with self.lock: +            item = self._container.pop(key) +            self._container[key] = item +            return item + +    def __setitem__(self, key, value): +        evicted_value = _Null +        with self.lock: +            # Possibly evict the existing value of 'key' +            evicted_value = self._container.get(key, _Null) +            self._container[key] = value + +            # If we didn't evict an existing value, we might have to evict the +            # least recently used item from the beginning of the container. +            if len(self._container) > self._maxsize: +                _key, evicted_value = self._container.popitem(last=False) + +        if self.dispose_func and evicted_value is not _Null: +            self.dispose_func(evicted_value) + +    def __delitem__(self, key): +        with self.lock: +            value = self._container.pop(key) + +        if self.dispose_func: +            self.dispose_func(value) + +    def __len__(self): +        with self.lock: +            return len(self._container) + +    def __iter__(self): +        raise NotImplementedError('Iteration over this class is unlikely to be threadsafe.') + +    def clear(self): +        with self.lock: +            # Copy pointers to all values, then wipe the mapping +            values = list(itervalues(self._container)) +            self._container.clear() + +        if self.dispose_func: +            for value in values: +                self.dispose_func(value) + +    def keys(self): +        with self.lock: +            return list(iterkeys(self._container)) + + +class HTTPHeaderDict(MutableMapping): +    """ +    :param headers: +        An iterable of field-value pairs. Must not contain multiple field names +        when compared case-insensitively. + +    :param kwargs: +        Additional field-value pairs to pass in to ``dict.update``. + +    A ``dict`` like container for storing HTTP Headers. + +    Field names are stored and compared case-insensitively in compliance with +    RFC 7230. Iteration provides the first case-sensitive key seen for each +    case-insensitive pair. + +    Using ``__setitem__`` syntax overwrites fields that compare equal +    case-insensitively in order to maintain ``dict``'s api. For fields that +    compare equal, instead create a new ``HTTPHeaderDict`` and use ``.add`` +    in a loop. + +    If multiple fields that are equal case-insensitively are passed to the +    constructor or ``.update``, the behavior is undefined and some will be +    lost. + +    >>> headers = HTTPHeaderDict() +    >>> headers.add('Set-Cookie', 'foo=bar') +    >>> headers.add('set-cookie', 'baz=quxx') +    >>> headers['content-length'] = '7' +    >>> headers['SET-cookie'] +    'foo=bar, baz=quxx' +    >>> headers['Content-Length'] +    '7' +    """ + +    def __init__(self, headers=None, **kwargs): +        super(HTTPHeaderDict, self).__init__() +        self._container = OrderedDict() +        if headers is not None: +            if isinstance(headers, HTTPHeaderDict): +                self._copy_from(headers) +            else: +                self.extend(headers) +        if kwargs: +            self.extend(kwargs) + +    def __setitem__(self, key, val): +        self._container[key.lower()] = [key, val] +        return self._container[key.lower()] + +    def __getitem__(self, key): +        val = self._container[key.lower()] +        return ', '.join(val[1:]) + +    def __delitem__(self, key): +        del self._container[key.lower()] + +    def __contains__(self, key): +        return key.lower() in self._container + +    def __eq__(self, other): +        if not isinstance(other, Mapping) and not hasattr(other, 'keys'): +            return False +        if not isinstance(other, type(self)): +            other = type(self)(other) +        return (dict((k.lower(), v) for k, v in self.itermerged()) == +                dict((k.lower(), v) for k, v in other.itermerged())) + +    def __ne__(self, other): +        return not self.__eq__(other) + +    if not PY3:  # Python 2 +        iterkeys = MutableMapping.iterkeys +        itervalues = MutableMapping.itervalues + +    __marker = object() + +    def __len__(self): +        return len(self._container) + +    def __iter__(self): +        # Only provide the originally cased names +        for vals in self._container.values(): +            yield vals[0] + +    def pop(self, key, default=__marker): +        '''D.pop(k[,d]) -> v, remove specified key and return the corresponding value. +          If key is not found, d is returned if given, otherwise KeyError is raised. +        ''' +        # Using the MutableMapping function directly fails due to the private marker. +        # Using ordinary dict.pop would expose the internal structures. +        # So let's reinvent the wheel. +        try: +            value = self[key] +        except KeyError: +            if default is self.__marker: +                raise +            return default +        else: +            del self[key] +            return value + +    def discard(self, key): +        try: +            del self[key] +        except KeyError: +            pass + +    def add(self, key, val): +        """Adds a (name, value) pair, doesn't overwrite the value if it already +        exists. + +        >>> headers = HTTPHeaderDict(foo='bar') +        >>> headers.add('Foo', 'baz') +        >>> headers['foo'] +        'bar, baz' +        """ +        key_lower = key.lower() +        new_vals = [key, val] +        # Keep the common case aka no item present as fast as possible +        vals = self._container.setdefault(key_lower, new_vals) +        if new_vals is not vals: +            vals.append(val) + +    def extend(self, *args, **kwargs): +        """Generic import function for any type of header-like object. +        Adapted version of MutableMapping.update in order to insert items +        with self.add instead of self.__setitem__ +        """ +        if len(args) > 1: +            raise TypeError("extend() takes at most 1 positional " +                            "arguments ({0} given)".format(len(args))) +        other = args[0] if len(args) >= 1 else () + +        if isinstance(other, HTTPHeaderDict): +            for key, val in other.iteritems(): +                self.add(key, val) +        elif isinstance(other, Mapping): +            for key in other: +                self.add(key, other[key]) +        elif hasattr(other, "keys"): +            for key in other.keys(): +                self.add(key, other[key]) +        else: +            for key, value in other: +                self.add(key, value) + +        for key, value in kwargs.items(): +            self.add(key, value) + +    def getlist(self, key, default=__marker): +        """Returns a list of all the values for the named field. Returns an +        empty list if the key doesn't exist.""" +        try: +            vals = self._container[key.lower()] +        except KeyError: +            if default is self.__marker: +                return [] +            return default +        else: +            return vals[1:] + +    # Backwards compatibility for httplib +    getheaders = getlist +    getallmatchingheaders = getlist +    iget = getlist + +    # Backwards compatibility for http.cookiejar +    get_all = getlist + +    def __repr__(self): +        return "%s(%s)" % (type(self).__name__, dict(self.itermerged())) + +    def _copy_from(self, other): +        for key in other: +            val = other.getlist(key) +            if isinstance(val, list): +                # Don't need to convert tuples +                val = list(val) +            self._container[key.lower()] = [key] + val + +    def copy(self): +        clone = type(self)() +        clone._copy_from(self) +        return clone + +    def iteritems(self): +        """Iterate over all header lines, including duplicate ones.""" +        for key in self: +            vals = self._container[key.lower()] +            for val in vals[1:]: +                yield vals[0], val + +    def itermerged(self): +        """Iterate over all headers, merging duplicate ones together.""" +        for key in self: +            val = self._container[key.lower()] +            yield val[0], ', '.join(val[1:]) + +    def items(self): +        return list(self.iteritems()) + +    @classmethod +    def from_httplib(cls, message):  # Python 2 +        """Read headers from a Python 2 httplib message object.""" +        # python2.7 does not expose a proper API for exporting multiheaders +        # efficiently. This function re-reads raw lines from the message +        # object and extracts the multiheaders properly. +        obs_fold_continued_leaders = (' ', '\t') +        headers = [] + +        for line in message.headers: +            if line.startswith(obs_fold_continued_leaders): +                if not headers: +                    # We received a header line that starts with OWS as described +                    # in RFC-7230 S3.2.4. This indicates a multiline header, but +                    # there exists no previous header to which we can attach it. +                    raise InvalidHeader( +                        'Header continuation with no previous header: %s' % line +                    ) +                else: +                    key, value = headers[-1] +                    headers[-1] = (key, value + ' ' + line.strip()) +                    continue + +            key, value = line.split(':', 1) +            headers.append((key, value.strip())) + +        return cls(headers) diff --git a/python/urllib3/connection.py b/python/urllib3/connection.py new file mode 100644 index 0000000..02b3665 --- /dev/null +++ b/python/urllib3/connection.py @@ -0,0 +1,391 @@ +from __future__ import absolute_import +import datetime +import logging +import os +import socket +from socket import error as SocketError, timeout as SocketTimeout +import warnings +from .packages import six +from .packages.six.moves.http_client import HTTPConnection as _HTTPConnection +from .packages.six.moves.http_client import HTTPException  # noqa: F401 + +try:  # Compiled with SSL? +    import ssl +    BaseSSLError = ssl.SSLError +except (ImportError, AttributeError):  # Platform-specific: No SSL. +    ssl = None + +    class BaseSSLError(BaseException): +        pass + + +try:  # Python 3: +    # Not a no-op, we're adding this to the namespace so it can be imported. +    ConnectionError = ConnectionError +except NameError:  # Python 2: +    class ConnectionError(Exception): +        pass + + +from .exceptions import ( +    NewConnectionError, +    ConnectTimeoutError, +    SubjectAltNameWarning, +    SystemTimeWarning, +) +from .packages.ssl_match_hostname import match_hostname, CertificateError + +from .util.ssl_ import ( +    resolve_cert_reqs, +    resolve_ssl_version, +    assert_fingerprint, +    create_urllib3_context, +    ssl_wrap_socket +) + + +from .util import connection + +from ._collections import HTTPHeaderDict + +log = logging.getLogger(__name__) + +port_by_scheme = { +    'http': 80, +    'https': 443, +} + +# When updating RECENT_DATE, move it to within two years of the current date, +# and not less than 6 months ago. +# Example: if Today is 2018-01-01, then RECENT_DATE should be any date on or +# after 2016-01-01 (today - 2 years) AND before 2017-07-01 (today - 6 months) +RECENT_DATE = datetime.date(2017, 6, 30) + + +class DummyConnection(object): +    """Used to detect a failed ConnectionCls import.""" +    pass + + +class HTTPConnection(_HTTPConnection, object): +    """ +    Based on httplib.HTTPConnection but provides an extra constructor +    backwards-compatibility layer between older and newer Pythons. + +    Additional keyword parameters are used to configure attributes of the connection. +    Accepted parameters include: + +      - ``strict``: See the documentation on :class:`urllib3.connectionpool.HTTPConnectionPool` +      - ``source_address``: Set the source address for the current connection. +      - ``socket_options``: Set specific options on the underlying socket. If not specified, then +        defaults are loaded from ``HTTPConnection.default_socket_options`` which includes disabling +        Nagle's algorithm (sets TCP_NODELAY to 1) unless the connection is behind a proxy. + +        For example, if you wish to enable TCP Keep Alive in addition to the defaults, +        you might pass:: + +            HTTPConnection.default_socket_options + [ +                (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1), +            ] + +        Or you may want to disable the defaults by passing an empty list (e.g., ``[]``). +    """ + +    default_port = port_by_scheme['http'] + +    #: Disable Nagle's algorithm by default. +    #: ``[(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)]`` +    default_socket_options = [(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)] + +    #: Whether this connection verifies the host's certificate. +    is_verified = False + +    def __init__(self, *args, **kw): +        if six.PY3:  # Python 3 +            kw.pop('strict', None) + +        # Pre-set source_address. +        self.source_address = kw.get('source_address') + +        #: The socket options provided by the user. If no options are +        #: provided, we use the default options. +        self.socket_options = kw.pop('socket_options', self.default_socket_options) + +        _HTTPConnection.__init__(self, *args, **kw) + +    @property +    def host(self): +        """ +        Getter method to remove any trailing dots that indicate the hostname is an FQDN. + +        In general, SSL certificates don't include the trailing dot indicating a +        fully-qualified domain name, and thus, they don't validate properly when +        checked against a domain name that includes the dot. In addition, some +        servers may not expect to receive the trailing dot when provided. + +        However, the hostname with trailing dot is critical to DNS resolution; doing a +        lookup with the trailing dot will properly only resolve the appropriate FQDN, +        whereas a lookup without a trailing dot will search the system's search domain +        list. Thus, it's important to keep the original host around for use only in +        those cases where it's appropriate (i.e., when doing DNS lookup to establish the +        actual TCP connection across which we're going to send HTTP requests). +        """ +        return self._dns_host.rstrip('.') + +    @host.setter +    def host(self, value): +        """ +        Setter for the `host` property. + +        We assume that only urllib3 uses the _dns_host attribute; httplib itself +        only uses `host`, and it seems reasonable that other libraries follow suit. +        """ +        self._dns_host = value + +    def _new_conn(self): +        """ Establish a socket connection and set nodelay settings on it. + +        :return: New socket connection. +        """ +        extra_kw = {} +        if self.source_address: +            extra_kw['source_address'] = self.source_address + +        if self.socket_options: +            extra_kw['socket_options'] = self.socket_options + +        try: +            conn = connection.create_connection( +                (self._dns_host, self.port), self.timeout, **extra_kw) + +        except SocketTimeout as e: +            raise ConnectTimeoutError( +                self, "Connection to %s timed out. (connect timeout=%s)" % +                (self.host, self.timeout)) + +        except SocketError as e: +            raise NewConnectionError( +                self, "Failed to establish a new connection: %s" % e) + +        return conn + +    def _prepare_conn(self, conn): +        self.sock = conn +        if self._tunnel_host: +            # TODO: Fix tunnel so it doesn't depend on self.sock state. +            self._tunnel() +            # Mark this connection as not reusable +            self.auto_open = 0 + +    def connect(self): +        conn = self._new_conn() +        self._prepare_conn(conn) + +    def request_chunked(self, method, url, body=None, headers=None): +        """ +        Alternative to the common request method, which sends the +        body with chunked encoding and not as one block +        """ +        headers = HTTPHeaderDict(headers if headers is not None else {}) +        skip_accept_encoding = 'accept-encoding' in headers +        skip_host = 'host' in headers +        self.putrequest( +            method, +            url, +            skip_accept_encoding=skip_accept_encoding, +            skip_host=skip_host +        ) +        for header, value in headers.items(): +            self.putheader(header, value) +        if 'transfer-encoding' not in headers: +            self.putheader('Transfer-Encoding', 'chunked') +        self.endheaders() + +        if body is not None: +            stringish_types = six.string_types + (bytes,) +            if isinstance(body, stringish_types): +                body = (body,) +            for chunk in body: +                if not chunk: +                    continue +                if not isinstance(chunk, bytes): +                    chunk = chunk.encode('utf8') +                len_str = hex(len(chunk))[2:] +                self.send(len_str.encode('utf-8')) +                self.send(b'\r\n') +                self.send(chunk) +                self.send(b'\r\n') + +        # After the if clause, to always have a closed body +        self.send(b'0\r\n\r\n') + + +class HTTPSConnection(HTTPConnection): +    default_port = port_by_scheme['https'] + +    ssl_version = None + +    def __init__(self, host, port=None, key_file=None, cert_file=None, +                 strict=None, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, +                 ssl_context=None, server_hostname=None, **kw): + +        HTTPConnection.__init__(self, host, port, strict=strict, +                                timeout=timeout, **kw) + +        self.key_file = key_file +        self.cert_file = cert_file +        self.ssl_context = ssl_context +        self.server_hostname = server_hostname + +        # Required property for Google AppEngine 1.9.0 which otherwise causes +        # HTTPS requests to go out as HTTP. (See Issue #356) +        self._protocol = 'https' + +    def connect(self): +        conn = self._new_conn() +        self._prepare_conn(conn) + +        if self.ssl_context is None: +            self.ssl_context = create_urllib3_context( +                ssl_version=resolve_ssl_version(None), +                cert_reqs=resolve_cert_reqs(None), +            ) + +        self.sock = ssl_wrap_socket( +            sock=conn, +            keyfile=self.key_file, +            certfile=self.cert_file, +            ssl_context=self.ssl_context, +            server_hostname=self.server_hostname +        ) + + +class VerifiedHTTPSConnection(HTTPSConnection): +    """ +    Based on httplib.HTTPSConnection but wraps the socket with +    SSL certification. +    """ +    cert_reqs = None +    ca_certs = None +    ca_cert_dir = None +    ssl_version = None +    assert_fingerprint = None + +    def set_cert(self, key_file=None, cert_file=None, +                 cert_reqs=None, ca_certs=None, +                 assert_hostname=None, assert_fingerprint=None, +                 ca_cert_dir=None): +        """ +        This method should only be called once, before the connection is used. +        """ +        # If cert_reqs is not provided, we can try to guess. If the user gave +        # us a cert database, we assume they want to use it: otherwise, if +        # they gave us an SSL Context object we should use whatever is set for +        # it. +        if cert_reqs is None: +            if ca_certs or ca_cert_dir: +                cert_reqs = 'CERT_REQUIRED' +            elif self.ssl_context is not None: +                cert_reqs = self.ssl_context.verify_mode + +        self.key_file = key_file +        self.cert_file = cert_file +        self.cert_reqs = cert_reqs +        self.assert_hostname = assert_hostname +        self.assert_fingerprint = assert_fingerprint +        self.ca_certs = ca_certs and os.path.expanduser(ca_certs) +        self.ca_cert_dir = ca_cert_dir and os.path.expanduser(ca_cert_dir) + +    def connect(self): +        # Add certificate verification +        conn = self._new_conn() +        hostname = self.host + +        if self._tunnel_host: +            self.sock = conn +            # Calls self._set_hostport(), so self.host is +            # self._tunnel_host below. +            self._tunnel() +            # Mark this connection as not reusable +            self.auto_open = 0 + +            # Override the host with the one we're requesting data from. +            hostname = self._tunnel_host + +        server_hostname = hostname +        if self.server_hostname is not None: +            server_hostname = self.server_hostname + +        is_time_off = datetime.date.today() < RECENT_DATE +        if is_time_off: +            warnings.warn(( +                'System time is way off (before {0}). This will probably ' +                'lead to SSL verification errors').format(RECENT_DATE), +                SystemTimeWarning +            ) + +        # Wrap socket using verification with the root certs in +        # trusted_root_certs +        if self.ssl_context is None: +            self.ssl_context = create_urllib3_context( +                ssl_version=resolve_ssl_version(self.ssl_version), +                cert_reqs=resolve_cert_reqs(self.cert_reqs), +            ) + +        context = self.ssl_context +        context.verify_mode = resolve_cert_reqs(self.cert_reqs) +        self.sock = ssl_wrap_socket( +            sock=conn, +            keyfile=self.key_file, +            certfile=self.cert_file, +            ca_certs=self.ca_certs, +            ca_cert_dir=self.ca_cert_dir, +            server_hostname=server_hostname, +            ssl_context=context) + +        if self.assert_fingerprint: +            assert_fingerprint(self.sock.getpeercert(binary_form=True), +                               self.assert_fingerprint) +        elif context.verify_mode != ssl.CERT_NONE \ +                and not getattr(context, 'check_hostname', False) \ +                and self.assert_hostname is not False: +            # While urllib3 attempts to always turn off hostname matching from +            # the TLS library, this cannot always be done. So we check whether +            # the TLS Library still thinks it's matching hostnames. +            cert = self.sock.getpeercert() +            if not cert.get('subjectAltName', ()): +                warnings.warn(( +                    'Certificate for {0} has no `subjectAltName`, falling back to check for a ' +                    '`commonName` for now. This feature is being removed by major browsers and ' +                    'deprecated by RFC 2818. (See https://github.com/shazow/urllib3/issues/497 ' +                    'for details.)'.format(hostname)), +                    SubjectAltNameWarning +                ) +            _match_hostname(cert, self.assert_hostname or server_hostname) + +        self.is_verified = ( +            context.verify_mode == ssl.CERT_REQUIRED or +            self.assert_fingerprint is not None +        ) + + +def _match_hostname(cert, asserted_hostname): +    try: +        match_hostname(cert, asserted_hostname) +    except CertificateError as e: +        log.error( +            'Certificate did not match expected hostname: %s. ' +            'Certificate: %s', asserted_hostname, cert +        ) +        # Add cert to exception and reraise so client code can inspect +        # the cert when catching the exception, if they want to +        e._peer_cert = cert +        raise + + +if ssl: +    # Make a copy for testing. +    UnverifiedHTTPSConnection = HTTPSConnection +    HTTPSConnection = VerifiedHTTPSConnection +else: +    HTTPSConnection = DummyConnection diff --git a/python/urllib3/connectionpool.py b/python/urllib3/connectionpool.py new file mode 100644 index 0000000..f7a8f19 --- /dev/null +++ b/python/urllib3/connectionpool.py @@ -0,0 +1,896 @@ +from __future__ import absolute_import +import errno +import logging +import sys +import warnings + +from socket import error as SocketError, timeout as SocketTimeout +import socket + + +from .exceptions import ( +    ClosedPoolError, +    ProtocolError, +    EmptyPoolError, +    HeaderParsingError, +    HostChangedError, +    LocationValueError, +    MaxRetryError, +    ProxyError, +    ReadTimeoutError, +    SSLError, +    TimeoutError, +    InsecureRequestWarning, +    NewConnectionError, +) +from .packages.ssl_match_hostname import CertificateError +from .packages import six +from .packages.six.moves import queue +from .connection import ( +    port_by_scheme, +    DummyConnection, +    HTTPConnection, HTTPSConnection, VerifiedHTTPSConnection, +    HTTPException, BaseSSLError, +) +from .request import RequestMethods +from .response import HTTPResponse + +from .util.connection import is_connection_dropped +from .util.request import set_file_position +from .util.response import assert_header_parsing +from .util.retry import Retry +from .util.timeout import Timeout +from .util.url import get_host, Url, NORMALIZABLE_SCHEMES +from .util.queue import LifoQueue + + +xrange = six.moves.xrange + +log = logging.getLogger(__name__) + +_Default = object() + + +# Pool objects +class ConnectionPool(object): +    """ +    Base class for all connection pools, such as +    :class:`.HTTPConnectionPool` and :class:`.HTTPSConnectionPool`. +    """ + +    scheme = None +    QueueCls = LifoQueue + +    def __init__(self, host, port=None): +        if not host: +            raise LocationValueError("No host specified.") + +        self.host = _ipv6_host(host, self.scheme) +        self._proxy_host = host.lower() +        self.port = port + +    def __str__(self): +        return '%s(host=%r, port=%r)' % (type(self).__name__, +                                         self.host, self.port) + +    def __enter__(self): +        return self + +    def __exit__(self, exc_type, exc_val, exc_tb): +        self.close() +        # Return False to re-raise any potential exceptions +        return False + +    def close(self): +        """ +        Close all pooled connections and disable the pool. +        """ +        pass + + +# This is taken from http://hg.python.org/cpython/file/7aaba721ebc0/Lib/socket.py#l252 +_blocking_errnos = {errno.EAGAIN, errno.EWOULDBLOCK} + + +class HTTPConnectionPool(ConnectionPool, RequestMethods): +    """ +    Thread-safe connection pool for one host. + +    :param host: +        Host used for this HTTP Connection (e.g. "localhost"), passed into +        :class:`httplib.HTTPConnection`. + +    :param port: +        Port used for this HTTP Connection (None is equivalent to 80), passed +        into :class:`httplib.HTTPConnection`. + +    :param strict: +        Causes BadStatusLine to be raised if the status line can't be parsed +        as a valid HTTP/1.0 or 1.1 status line, passed into +        :class:`httplib.HTTPConnection`. + +        .. note:: +           Only works in Python 2. This parameter is ignored in Python 3. + +    :param timeout: +        Socket timeout in seconds for each individual connection. This can +        be a float or integer, which sets the timeout for the HTTP request, +        or an instance of :class:`urllib3.util.Timeout` which gives you more +        fine-grained control over request timeouts. After the constructor has +        been parsed, this is always a `urllib3.util.Timeout` object. + +    :param maxsize: +        Number of connections to save that can be reused. More than 1 is useful +        in multithreaded situations. If ``block`` is set to False, more +        connections will be created but they will not be saved once they've +        been used. + +    :param block: +        If set to True, no more than ``maxsize`` connections will be used at +        a time. When no free connections are available, the call will block +        until a connection has been released. This is a useful side effect for +        particular multithreaded situations where one does not want to use more +        than maxsize connections per host to prevent flooding. + +    :param headers: +        Headers to include with all requests, unless other headers are given +        explicitly. + +    :param retries: +        Retry configuration to use by default with requests in this pool. + +    :param _proxy: +        Parsed proxy URL, should not be used directly, instead, see +        :class:`urllib3.connectionpool.ProxyManager`" + +    :param _proxy_headers: +        A dictionary with proxy headers, should not be used directly, +        instead, see :class:`urllib3.connectionpool.ProxyManager`" + +    :param \\**conn_kw: +        Additional parameters are used to create fresh :class:`urllib3.connection.HTTPConnection`, +        :class:`urllib3.connection.HTTPSConnection` instances. +    """ + +    scheme = 'http' +    ConnectionCls = HTTPConnection +    ResponseCls = HTTPResponse + +    def __init__(self, host, port=None, strict=False, +                 timeout=Timeout.DEFAULT_TIMEOUT, maxsize=1, block=False, +                 headers=None, retries=None, +                 _proxy=None, _proxy_headers=None, +                 **conn_kw): +        ConnectionPool.__init__(self, host, port) +        RequestMethods.__init__(self, headers) + +        self.strict = strict + +        if not isinstance(timeout, Timeout): +            timeout = Timeout.from_float(timeout) + +        if retries is None: +            retries = Retry.DEFAULT + +        self.timeout = timeout +        self.retries = retries + +        self.pool = self.QueueCls(maxsize) +        self.block = block + +        self.proxy = _proxy +        self.proxy_headers = _proxy_headers or {} + +        # Fill the queue up so that doing get() on it will block properly +        for _ in xrange(maxsize): +            self.pool.put(None) + +        # These are mostly for testing and debugging purposes. +        self.num_connections = 0 +        self.num_requests = 0 +        self.conn_kw = conn_kw + +        if self.proxy: +            # Enable Nagle's algorithm for proxies, to avoid packet fragmentation. +            # We cannot know if the user has added default socket options, so we cannot replace the +            # list. +            self.conn_kw.setdefault('socket_options', []) + +    def _new_conn(self): +        """ +        Return a fresh :class:`HTTPConnection`. +        """ +        self.num_connections += 1 +        log.debug("Starting new HTTP connection (%d): %s:%s", +                  self.num_connections, self.host, self.port or "80") + +        conn = self.ConnectionCls(host=self.host, port=self.port, +                                  timeout=self.timeout.connect_timeout, +                                  strict=self.strict, **self.conn_kw) +        return conn + +    def _get_conn(self, timeout=None): +        """ +        Get a connection. Will return a pooled connection if one is available. + +        If no connections are available and :prop:`.block` is ``False``, then a +        fresh connection is returned. + +        :param timeout: +            Seconds to wait before giving up and raising +            :class:`urllib3.exceptions.EmptyPoolError` if the pool is empty and +            :prop:`.block` is ``True``. +        """ +        conn = None +        try: +            conn = self.pool.get(block=self.block, timeout=timeout) + +        except AttributeError:  # self.pool is None +            raise ClosedPoolError(self, "Pool is closed.") + +        except queue.Empty: +            if self.block: +                raise EmptyPoolError(self, +                                     "Pool reached maximum size and no more " +                                     "connections are allowed.") +            pass  # Oh well, we'll create a new connection then + +        # If this is a persistent connection, check if it got disconnected +        if conn and is_connection_dropped(conn): +            log.debug("Resetting dropped connection: %s", self.host) +            conn.close() +            if getattr(conn, 'auto_open', 1) == 0: +                # This is a proxied connection that has been mutated by +                # httplib._tunnel() and cannot be reused (since it would +                # attempt to bypass the proxy) +                conn = None + +        return conn or self._new_conn() + +    def _put_conn(self, conn): +        """ +        Put a connection back into the pool. + +        :param conn: +            Connection object for the current host and port as returned by +            :meth:`._new_conn` or :meth:`._get_conn`. + +        If the pool is already full, the connection is closed and discarded +        because we exceeded maxsize. If connections are discarded frequently, +        then maxsize should be increased. + +        If the pool is closed, then the connection will be closed and discarded. +        """ +        try: +            self.pool.put(conn, block=False) +            return  # Everything is dandy, done. +        except AttributeError: +            # self.pool is None. +            pass +        except queue.Full: +            # This should never happen if self.block == True +            log.warning( +                "Connection pool is full, discarding connection: %s", +                self.host) + +        # Connection never got put back into the pool, close it. +        if conn: +            conn.close() + +    def _validate_conn(self, conn): +        """ +        Called right before a request is made, after the socket is created. +        """ +        pass + +    def _prepare_proxy(self, conn): +        # Nothing to do for HTTP connections. +        pass + +    def _get_timeout(self, timeout): +        """ Helper that always returns a :class:`urllib3.util.Timeout` """ +        if timeout is _Default: +            return self.timeout.clone() + +        if isinstance(timeout, Timeout): +            return timeout.clone() +        else: +            # User passed us an int/float. This is for backwards compatibility, +            # can be removed later +            return Timeout.from_float(timeout) + +    def _raise_timeout(self, err, url, timeout_value): +        """Is the error actually a timeout? Will raise a ReadTimeout or pass""" + +        if isinstance(err, SocketTimeout): +            raise ReadTimeoutError(self, url, "Read timed out. (read timeout=%s)" % timeout_value) + +        # See the above comment about EAGAIN in Python 3. In Python 2 we have +        # to specifically catch it and throw the timeout error +        if hasattr(err, 'errno') and err.errno in _blocking_errnos: +            raise ReadTimeoutError(self, url, "Read timed out. (read timeout=%s)" % timeout_value) + +        # Catch possible read timeouts thrown as SSL errors. If not the +        # case, rethrow the original. We need to do this because of: +        # http://bugs.python.org/issue10272 +        if 'timed out' in str(err) or 'did not complete (read)' in str(err):  # Python < 2.7.4 +            raise ReadTimeoutError(self, url, "Read timed out. (read timeout=%s)" % timeout_value) + +    def _make_request(self, conn, method, url, timeout=_Default, chunked=False, +                      **httplib_request_kw): +        """ +        Perform a request on a given urllib connection object taken from our +        pool. + +        :param conn: +            a connection from one of our connection pools + +        :param timeout: +            Socket timeout in seconds for the request. This can be a +            float or integer, which will set the same timeout value for +            the socket connect and the socket read, or an instance of +            :class:`urllib3.util.Timeout`, which gives you more fine-grained +            control over your timeouts. +        """ +        self.num_requests += 1 + +        timeout_obj = self._get_timeout(timeout) +        timeout_obj.start_connect() +        conn.timeout = timeout_obj.connect_timeout + +        # Trigger any extra validation we need to do. +        try: +            self._validate_conn(conn) +        except (SocketTimeout, BaseSSLError) as e: +            # Py2 raises this as a BaseSSLError, Py3 raises it as socket timeout. +            self._raise_timeout(err=e, url=url, timeout_value=conn.timeout) +            raise + +        # conn.request() calls httplib.*.request, not the method in +        # urllib3.request. It also calls makefile (recv) on the socket. +        if chunked: +            conn.request_chunked(method, url, **httplib_request_kw) +        else: +            conn.request(method, url, **httplib_request_kw) + +        # Reset the timeout for the recv() on the socket +        read_timeout = timeout_obj.read_timeout + +        # App Engine doesn't have a sock attr +        if getattr(conn, 'sock', None): +            # In Python 3 socket.py will catch EAGAIN and return None when you +            # try and read into the file pointer created by http.client, which +            # instead raises a BadStatusLine exception. Instead of catching +            # the exception and assuming all BadStatusLine exceptions are read +            # timeouts, check for a zero timeout before making the request. +            if read_timeout == 0: +                raise ReadTimeoutError( +                    self, url, "Read timed out. (read timeout=%s)" % read_timeout) +            if read_timeout is Timeout.DEFAULT_TIMEOUT: +                conn.sock.settimeout(socket.getdefaulttimeout()) +            else:  # None or a value +                conn.sock.settimeout(read_timeout) + +        # Receive the response from the server +        try: +            try:  # Python 2.7, use buffering of HTTP responses +                httplib_response = conn.getresponse(buffering=True) +            except TypeError:  # Python 3 +                try: +                    httplib_response = conn.getresponse() +                except Exception as e: +                    # Remove the TypeError from the exception chain in Python 3; +                    # otherwise it looks like a programming error was the cause. +                    six.raise_from(e, None) +        except (SocketTimeout, BaseSSLError, SocketError) as e: +            self._raise_timeout(err=e, url=url, timeout_value=read_timeout) +            raise + +        # AppEngine doesn't have a version attr. +        http_version = getattr(conn, '_http_vsn_str', 'HTTP/?') +        log.debug("%s://%s:%s \"%s %s %s\" %s %s", self.scheme, self.host, self.port, +                  method, url, http_version, httplib_response.status, +                  httplib_response.length) + +        try: +            assert_header_parsing(httplib_response.msg) +        except (HeaderParsingError, TypeError) as hpe:  # Platform-specific: Python 3 +            log.warning( +                'Failed to parse headers (url=%s): %s', +                self._absolute_url(url), hpe, exc_info=True) + +        return httplib_response + +    def _absolute_url(self, path): +        return Url(scheme=self.scheme, host=self.host, port=self.port, path=path).url + +    def close(self): +        """ +        Close all pooled connections and disable the pool. +        """ +        if self.pool is None: +            return +        # Disable access to the pool +        old_pool, self.pool = self.pool, None + +        try: +            while True: +                conn = old_pool.get(block=False) +                if conn: +                    conn.close() + +        except queue.Empty: +            pass  # Done. + +    def is_same_host(self, url): +        """ +        Check if the given ``url`` is a member of the same host as this +        connection pool. +        """ +        if url.startswith('/'): +            return True + +        # TODO: Add optional support for socket.gethostbyname checking. +        scheme, host, port = get_host(url) + +        host = _ipv6_host(host, self.scheme) + +        # Use explicit default port for comparison when none is given +        if self.port and not port: +            port = port_by_scheme.get(scheme) +        elif not self.port and port == port_by_scheme.get(scheme): +            port = None + +        return (scheme, host, port) == (self.scheme, self.host, self.port) + +    def urlopen(self, method, url, body=None, headers=None, retries=None, +                redirect=True, assert_same_host=True, timeout=_Default, +                pool_timeout=None, release_conn=None, chunked=False, +                body_pos=None, **response_kw): +        """ +        Get a connection from the pool and perform an HTTP request. This is the +        lowest level call for making a request, so you'll need to specify all +        the raw details. + +        .. note:: + +           More commonly, it's appropriate to use a convenience method provided +           by :class:`.RequestMethods`, such as :meth:`request`. + +        .. note:: + +           `release_conn` will only behave as expected if +           `preload_content=False` because we want to make +           `preload_content=False` the default behaviour someday soon without +           breaking backwards compatibility. + +        :param method: +            HTTP request method (such as GET, POST, PUT, etc.) + +        :param body: +            Data to send in the request body (useful for creating +            POST requests, see HTTPConnectionPool.post_url for +            more convenience). + +        :param headers: +            Dictionary of custom headers to send, such as User-Agent, +            If-None-Match, etc. If None, pool headers are used. If provided, +            these headers completely replace any pool-specific headers. + +        :param retries: +            Configure the number of retries to allow before raising a +            :class:`~urllib3.exceptions.MaxRetryError` exception. + +            Pass ``None`` to retry until you receive a response. Pass a +            :class:`~urllib3.util.retry.Retry` object for fine-grained control +            over different types of retries. +            Pass an integer number to retry connection errors that many times, +            but no other types of errors. Pass zero to never retry. + +            If ``False``, then retries are disabled and any exception is raised +            immediately. Also, instead of raising a MaxRetryError on redirects, +            the redirect response will be returned. + +        :type retries: :class:`~urllib3.util.retry.Retry`, False, or an int. + +        :param redirect: +            If True, automatically handle redirects (status codes 301, 302, +            303, 307, 308). Each redirect counts as a retry. Disabling retries +            will disable redirect, too. + +        :param assert_same_host: +            If ``True``, will make sure that the host of the pool requests is +            consistent else will raise HostChangedError. When False, you can +            use the pool on an HTTP proxy and request foreign hosts. + +        :param timeout: +            If specified, overrides the default timeout for this one +            request. It may be a float (in seconds) or an instance of +            :class:`urllib3.util.Timeout`. + +        :param pool_timeout: +            If set and the pool is set to block=True, then this method will +            block for ``pool_timeout`` seconds and raise EmptyPoolError if no +            connection is available within the time period. + +        :param release_conn: +            If False, then the urlopen call will not release the connection +            back into the pool once a response is received (but will release if +            you read the entire contents of the response such as when +            `preload_content=True`). This is useful if you're not preloading +            the response's content immediately. You will need to call +            ``r.release_conn()`` on the response ``r`` to return the connection +            back into the pool. If None, it takes the value of +            ``response_kw.get('preload_content', True)``. + +        :param chunked: +            If True, urllib3 will send the body using chunked transfer +            encoding. Otherwise, urllib3 will send the body using the standard +            content-length form. Defaults to False. + +        :param int body_pos: +            Position to seek to in file-like body in the event of a retry or +            redirect. Typically this won't need to be set because urllib3 will +            auto-populate the value when needed. + +        :param \\**response_kw: +            Additional parameters are passed to +            :meth:`urllib3.response.HTTPResponse.from_httplib` +        """ +        if headers is None: +            headers = self.headers + +        if not isinstance(retries, Retry): +            retries = Retry.from_int(retries, redirect=redirect, default=self.retries) + +        if release_conn is None: +            release_conn = response_kw.get('preload_content', True) + +        # Check host +        if assert_same_host and not self.is_same_host(url): +            raise HostChangedError(self, url, retries) + +        conn = None + +        # Track whether `conn` needs to be released before +        # returning/raising/recursing. Update this variable if necessary, and +        # leave `release_conn` constant throughout the function. That way, if +        # the function recurses, the original value of `release_conn` will be +        # passed down into the recursive call, and its value will be respected. +        # +        # See issue #651 [1] for details. +        # +        # [1] <https://github.com/shazow/urllib3/issues/651> +        release_this_conn = release_conn + +        # Merge the proxy headers. Only do this in HTTP. We have to copy the +        # headers dict so we can safely change it without those changes being +        # reflected in anyone else's copy. +        if self.scheme == 'http': +            headers = headers.copy() +            headers.update(self.proxy_headers) + +        # Must keep the exception bound to a separate variable or else Python 3 +        # complains about UnboundLocalError. +        err = None + +        # Keep track of whether we cleanly exited the except block. This +        # ensures we do proper cleanup in finally. +        clean_exit = False + +        # Rewind body position, if needed. Record current position +        # for future rewinds in the event of a redirect/retry. +        body_pos = set_file_position(body, body_pos) + +        try: +            # Request a connection from the queue. +            timeout_obj = self._get_timeout(timeout) +            conn = self._get_conn(timeout=pool_timeout) + +            conn.timeout = timeout_obj.connect_timeout + +            is_new_proxy_conn = self.proxy is not None and not getattr(conn, 'sock', None) +            if is_new_proxy_conn: +                self._prepare_proxy(conn) + +            # Make the request on the httplib connection object. +            httplib_response = self._make_request(conn, method, url, +                                                  timeout=timeout_obj, +                                                  body=body, headers=headers, +                                                  chunked=chunked) + +            # If we're going to release the connection in ``finally:``, then +            # the response doesn't need to know about the connection. Otherwise +            # it will also try to release it and we'll have a double-release +            # mess. +            response_conn = conn if not release_conn else None + +            # Pass method to Response for length checking +            response_kw['request_method'] = method + +            # Import httplib's response into our own wrapper object +            response = self.ResponseCls.from_httplib(httplib_response, +                                                     pool=self, +                                                     connection=response_conn, +                                                     retries=retries, +                                                     **response_kw) + +            # Everything went great! +            clean_exit = True + +        except queue.Empty: +            # Timed out by queue. +            raise EmptyPoolError(self, "No pool connections are available.") + +        except (TimeoutError, HTTPException, SocketError, ProtocolError, +                BaseSSLError, SSLError, CertificateError) as e: +            # Discard the connection for these exceptions. It will be +            # replaced during the next _get_conn() call. +            clean_exit = False +            if isinstance(e, (BaseSSLError, CertificateError)): +                e = SSLError(e) +            elif isinstance(e, (SocketError, NewConnectionError)) and self.proxy: +                e = ProxyError('Cannot connect to proxy.', e) +            elif isinstance(e, (SocketError, HTTPException)): +                e = ProtocolError('Connection aborted.', e) + +            retries = retries.increment(method, url, error=e, _pool=self, +                                        _stacktrace=sys.exc_info()[2]) +            retries.sleep() + +            # Keep track of the error for the retry warning. +            err = e + +        finally: +            if not clean_exit: +                # We hit some kind of exception, handled or otherwise. We need +                # to throw the connection away unless explicitly told not to. +                # Close the connection, set the variable to None, and make sure +                # we put the None back in the pool to avoid leaking it. +                conn = conn and conn.close() +                release_this_conn = True + +            if release_this_conn: +                # Put the connection back to be reused. If the connection is +                # expired then it will be None, which will get replaced with a +                # fresh connection during _get_conn. +                self._put_conn(conn) + +        if not conn: +            # Try again +            log.warning("Retrying (%r) after connection " +                        "broken by '%r': %s", retries, err, url) +            return self.urlopen(method, url, body, headers, retries, +                                redirect, assert_same_host, +                                timeout=timeout, pool_timeout=pool_timeout, +                                release_conn=release_conn, body_pos=body_pos, +                                **response_kw) + +        def drain_and_release_conn(response): +            try: +                # discard any remaining response body, the connection will be +                # released back to the pool once the entire response is read +                response.read() +            except (TimeoutError, HTTPException, SocketError, ProtocolError, +                    BaseSSLError, SSLError) as e: +                pass + +        # Handle redirect? +        redirect_location = redirect and response.get_redirect_location() +        if redirect_location: +            if response.status == 303: +                method = 'GET' + +            try: +                retries = retries.increment(method, url, response=response, _pool=self) +            except MaxRetryError: +                if retries.raise_on_redirect: +                    # Drain and release the connection for this response, since +                    # we're not returning it to be released manually. +                    drain_and_release_conn(response) +                    raise +                return response + +            # drain and return the connection to the pool before recursing +            drain_and_release_conn(response) + +            retries.sleep_for_retry(response) +            log.debug("Redirecting %s -> %s", url, redirect_location) +            return self.urlopen( +                method, redirect_location, body, headers, +                retries=retries, redirect=redirect, +                assert_same_host=assert_same_host, +                timeout=timeout, pool_timeout=pool_timeout, +                release_conn=release_conn, body_pos=body_pos, +                **response_kw) + +        # Check if we should retry the HTTP response. +        has_retry_after = bool(response.getheader('Retry-After')) +        if retries.is_retry(method, response.status, has_retry_after): +            try: +                retries = retries.increment(method, url, response=response, _pool=self) +            except MaxRetryError: +                if retries.raise_on_status: +                    # Drain and release the connection for this response, since +                    # we're not returning it to be released manually. +                    drain_and_release_conn(response) +                    raise +                return response + +            # drain and return the connection to the pool before recursing +            drain_and_release_conn(response) + +            retries.sleep(response) +            log.debug("Retry: %s", url) +            return self.urlopen( +                method, url, body, headers, +                retries=retries, redirect=redirect, +                assert_same_host=assert_same_host, +                timeout=timeout, pool_timeout=pool_timeout, +                release_conn=release_conn, +                body_pos=body_pos, **response_kw) + +        return response + + +class HTTPSConnectionPool(HTTPConnectionPool): +    """ +    Same as :class:`.HTTPConnectionPool`, but HTTPS. + +    When Python is compiled with the :mod:`ssl` module, then +    :class:`.VerifiedHTTPSConnection` is used, which *can* verify certificates, +    instead of :class:`.HTTPSConnection`. + +    :class:`.VerifiedHTTPSConnection` uses one of ``assert_fingerprint``, +    ``assert_hostname`` and ``host`` in this order to verify connections. +    If ``assert_hostname`` is False, no verification is done. + +    The ``key_file``, ``cert_file``, ``cert_reqs``, ``ca_certs``, +    ``ca_cert_dir``, and ``ssl_version`` are only used if :mod:`ssl` is +    available and are fed into :meth:`urllib3.util.ssl_wrap_socket` to upgrade +    the connection socket into an SSL socket. +    """ + +    scheme = 'https' +    ConnectionCls = HTTPSConnection + +    def __init__(self, host, port=None, +                 strict=False, timeout=Timeout.DEFAULT_TIMEOUT, maxsize=1, +                 block=False, headers=None, retries=None, +                 _proxy=None, _proxy_headers=None, +                 key_file=None, cert_file=None, cert_reqs=None, +                 ca_certs=None, ssl_version=None, +                 assert_hostname=None, assert_fingerprint=None, +                 ca_cert_dir=None, **conn_kw): + +        HTTPConnectionPool.__init__(self, host, port, strict, timeout, maxsize, +                                    block, headers, retries, _proxy, _proxy_headers, +                                    **conn_kw) + +        if ca_certs and cert_reqs is None: +            cert_reqs = 'CERT_REQUIRED' + +        self.key_file = key_file +        self.cert_file = cert_file +        self.cert_reqs = cert_reqs +        self.ca_certs = ca_certs +        self.ca_cert_dir = ca_cert_dir +        self.ssl_version = ssl_version +        self.assert_hostname = assert_hostname +        self.assert_fingerprint = assert_fingerprint + +    def _prepare_conn(self, conn): +        """ +        Prepare the ``connection`` for :meth:`urllib3.util.ssl_wrap_socket` +        and establish the tunnel if proxy is used. +        """ + +        if isinstance(conn, VerifiedHTTPSConnection): +            conn.set_cert(key_file=self.key_file, +                          cert_file=self.cert_file, +                          cert_reqs=self.cert_reqs, +                          ca_certs=self.ca_certs, +                          ca_cert_dir=self.ca_cert_dir, +                          assert_hostname=self.assert_hostname, +                          assert_fingerprint=self.assert_fingerprint) +            conn.ssl_version = self.ssl_version +        return conn + +    def _prepare_proxy(self, conn): +        """ +        Establish tunnel connection early, because otherwise httplib +        would improperly set Host: header to proxy's IP:port. +        """ +        conn.set_tunnel(self._proxy_host, self.port, self.proxy_headers) +        conn.connect() + +    def _new_conn(self): +        """ +        Return a fresh :class:`httplib.HTTPSConnection`. +        """ +        self.num_connections += 1 +        log.debug("Starting new HTTPS connection (%d): %s:%s", +                  self.num_connections, self.host, self.port or "443") + +        if not self.ConnectionCls or self.ConnectionCls is DummyConnection: +            raise SSLError("Can't connect to HTTPS URL because the SSL " +                           "module is not available.") + +        actual_host = self.host +        actual_port = self.port +        if self.proxy is not None: +            actual_host = self.proxy.host +            actual_port = self.proxy.port + +        conn = self.ConnectionCls(host=actual_host, port=actual_port, +                                  timeout=self.timeout.connect_timeout, +                                  strict=self.strict, **self.conn_kw) + +        return self._prepare_conn(conn) + +    def _validate_conn(self, conn): +        """ +        Called right before a request is made, after the socket is created. +        """ +        super(HTTPSConnectionPool, self)._validate_conn(conn) + +        # Force connect early to allow us to validate the connection. +        if not getattr(conn, 'sock', None):  # AppEngine might not have  `.sock` +            conn.connect() + +        if not conn.is_verified: +            warnings.warn(( +                'Unverified HTTPS request is being made. ' +                'Adding certificate verification is strongly advised. See: ' +                'https://urllib3.readthedocs.io/en/latest/advanced-usage.html' +                '#ssl-warnings'), +                InsecureRequestWarning) + + +def connection_from_url(url, **kw): +    """ +    Given a url, return an :class:`.ConnectionPool` instance of its host. + +    This is a shortcut for not having to parse out the scheme, host, and port +    of the url before creating an :class:`.ConnectionPool` instance. + +    :param url: +        Absolute URL string that must include the scheme. Port is optional. + +    :param \\**kw: +        Passes additional parameters to the constructor of the appropriate +        :class:`.ConnectionPool`. Useful for specifying things like +        timeout, maxsize, headers, etc. + +    Example:: + +        >>> conn = connection_from_url('http://google.com/') +        >>> r = conn.request('GET', '/') +    """ +    scheme, host, port = get_host(url) +    port = port or port_by_scheme.get(scheme, 80) +    if scheme == 'https': +        return HTTPSConnectionPool(host, port=port, **kw) +    else: +        return HTTPConnectionPool(host, port=port, **kw) + + +def _ipv6_host(host, scheme): +    """ +    Process IPv6 address literals +    """ + +    # httplib doesn't like it when we include brackets in IPv6 addresses +    # Specifically, if we include brackets but also pass the port then +    # httplib crazily doubles up the square brackets on the Host header. +    # Instead, we need to make sure we never pass ``None`` as the port. +    # However, for backward compatibility reasons we can't actually +    # *assert* that.  See http://bugs.python.org/issue28539 +    # +    # Also if an IPv6 address literal has a zone identifier, the +    # percent sign might be URIencoded, convert it back into ASCII +    if host.startswith('[') and host.endswith(']'): +        host = host.replace('%25', '%').strip('[]') +    if scheme in NORMALIZABLE_SCHEMES: +        host = host.lower() +    return host diff --git a/python/urllib3/contrib/__init__.py b/python/urllib3/contrib/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/python/urllib3/contrib/__init__.py diff --git a/python/urllib3/contrib/_appengine_environ.py b/python/urllib3/contrib/_appengine_environ.py new file mode 100644 index 0000000..f3e0094 --- /dev/null +++ b/python/urllib3/contrib/_appengine_environ.py @@ -0,0 +1,30 @@ +""" +This module provides means to detect the App Engine environment. +""" + +import os + + +def is_appengine(): +    return (is_local_appengine() or +            is_prod_appengine() or +            is_prod_appengine_mvms()) + + +def is_appengine_sandbox(): +    return is_appengine() and not is_prod_appengine_mvms() + + +def is_local_appengine(): +    return ('APPENGINE_RUNTIME' in os.environ and +            'Development/' in os.environ['SERVER_SOFTWARE']) + + +def is_prod_appengine(): +    return ('APPENGINE_RUNTIME' in os.environ and +            'Google App Engine/' in os.environ['SERVER_SOFTWARE'] and +            not is_prod_appengine_mvms()) + + +def is_prod_appengine_mvms(): +    return os.environ.get('GAE_VM', False) == 'true' diff --git a/python/urllib3/contrib/_securetransport/__init__.py b/python/urllib3/contrib/_securetransport/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/python/urllib3/contrib/_securetransport/__init__.py diff --git a/python/urllib3/contrib/_securetransport/bindings.py b/python/urllib3/contrib/_securetransport/bindings.py new file mode 100644 index 0000000..bcf41c0 --- /dev/null +++ b/python/urllib3/contrib/_securetransport/bindings.py @@ -0,0 +1,593 @@ +""" +This module uses ctypes to bind a whole bunch of functions and constants from +SecureTransport. The goal here is to provide the low-level API to +SecureTransport. These are essentially the C-level functions and constants, and +they're pretty gross to work with. + +This code is a bastardised version of the code found in Will Bond's oscrypto +library. An enormous debt is owed to him for blazing this trail for us. For +that reason, this code should be considered to be covered both by urllib3's +license and by oscrypto's: + +    Copyright (c) 2015-2016 Will Bond <will@wbond.net> + +    Permission is hereby granted, free of charge, to any person obtaining a +    copy of this software and associated documentation files (the "Software"), +    to deal in the Software without restriction, including without limitation +    the rights to use, copy, modify, merge, publish, distribute, sublicense, +    and/or sell copies of the Software, and to permit persons to whom the +    Software is furnished to do so, subject to the following conditions: + +    The above copyright notice and this permission notice shall be included in +    all copies or substantial portions of the Software. + +    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +    DEALINGS IN THE SOFTWARE. +""" +from __future__ import absolute_import + +import platform +from ctypes.util import find_library +from ctypes import ( +    c_void_p, c_int32, c_char_p, c_size_t, c_byte, c_uint32, c_ulong, c_long, +    c_bool +) +from ctypes import CDLL, POINTER, CFUNCTYPE + + +security_path = find_library('Security') +if not security_path: +    raise ImportError('The library Security could not be found') + + +core_foundation_path = find_library('CoreFoundation') +if not core_foundation_path: +    raise ImportError('The library CoreFoundation could not be found') + + +version = platform.mac_ver()[0] +version_info = tuple(map(int, version.split('.'))) +if version_info < (10, 8): +    raise OSError( +        'Only OS X 10.8 and newer are supported, not %s.%s' % ( +            version_info[0], version_info[1] +        ) +    ) + +Security = CDLL(security_path, use_errno=True) +CoreFoundation = CDLL(core_foundation_path, use_errno=True) + +Boolean = c_bool +CFIndex = c_long +CFStringEncoding = c_uint32 +CFData = c_void_p +CFString = c_void_p +CFArray = c_void_p +CFMutableArray = c_void_p +CFDictionary = c_void_p +CFError = c_void_p +CFType = c_void_p +CFTypeID = c_ulong + +CFTypeRef = POINTER(CFType) +CFAllocatorRef = c_void_p + +OSStatus = c_int32 + +CFDataRef = POINTER(CFData) +CFStringRef = POINTER(CFString) +CFArrayRef = POINTER(CFArray) +CFMutableArrayRef = POINTER(CFMutableArray) +CFDictionaryRef = POINTER(CFDictionary) +CFArrayCallBacks = c_void_p +CFDictionaryKeyCallBacks = c_void_p +CFDictionaryValueCallBacks = c_void_p + +SecCertificateRef = POINTER(c_void_p) +SecExternalFormat = c_uint32 +SecExternalItemType = c_uint32 +SecIdentityRef = POINTER(c_void_p) +SecItemImportExportFlags = c_uint32 +SecItemImportExportKeyParameters = c_void_p +SecKeychainRef = POINTER(c_void_p) +SSLProtocol = c_uint32 +SSLCipherSuite = c_uint32 +SSLContextRef = POINTER(c_void_p) +SecTrustRef = POINTER(c_void_p) +SSLConnectionRef = c_uint32 +SecTrustResultType = c_uint32 +SecTrustOptionFlags = c_uint32 +SSLProtocolSide = c_uint32 +SSLConnectionType = c_uint32 +SSLSessionOption = c_uint32 + + +try: +    Security.SecItemImport.argtypes = [ +        CFDataRef, +        CFStringRef, +        POINTER(SecExternalFormat), +        POINTER(SecExternalItemType), +        SecItemImportExportFlags, +        POINTER(SecItemImportExportKeyParameters), +        SecKeychainRef, +        POINTER(CFArrayRef), +    ] +    Security.SecItemImport.restype = OSStatus + +    Security.SecCertificateGetTypeID.argtypes = [] +    Security.SecCertificateGetTypeID.restype = CFTypeID + +    Security.SecIdentityGetTypeID.argtypes = [] +    Security.SecIdentityGetTypeID.restype = CFTypeID + +    Security.SecKeyGetTypeID.argtypes = [] +    Security.SecKeyGetTypeID.restype = CFTypeID + +    Security.SecCertificateCreateWithData.argtypes = [ +        CFAllocatorRef, +        CFDataRef +    ] +    Security.SecCertificateCreateWithData.restype = SecCertificateRef + +    Security.SecCertificateCopyData.argtypes = [ +        SecCertificateRef +    ] +    Security.SecCertificateCopyData.restype = CFDataRef + +    Security.SecCopyErrorMessageString.argtypes = [ +        OSStatus, +        c_void_p +    ] +    Security.SecCopyErrorMessageString.restype = CFStringRef + +    Security.SecIdentityCreateWithCertificate.argtypes = [ +        CFTypeRef, +        SecCertificateRef, +        POINTER(SecIdentityRef) +    ] +    Security.SecIdentityCreateWithCertificate.restype = OSStatus + +    Security.SecKeychainCreate.argtypes = [ +        c_char_p, +        c_uint32, +        c_void_p, +        Boolean, +        c_void_p, +        POINTER(SecKeychainRef) +    ] +    Security.SecKeychainCreate.restype = OSStatus + +    Security.SecKeychainDelete.argtypes = [ +        SecKeychainRef +    ] +    Security.SecKeychainDelete.restype = OSStatus + +    Security.SecPKCS12Import.argtypes = [ +        CFDataRef, +        CFDictionaryRef, +        POINTER(CFArrayRef) +    ] +    Security.SecPKCS12Import.restype = OSStatus + +    SSLReadFunc = CFUNCTYPE(OSStatus, SSLConnectionRef, c_void_p, POINTER(c_size_t)) +    SSLWriteFunc = CFUNCTYPE(OSStatus, SSLConnectionRef, POINTER(c_byte), POINTER(c_size_t)) + +    Security.SSLSetIOFuncs.argtypes = [ +        SSLContextRef, +        SSLReadFunc, +        SSLWriteFunc +    ] +    Security.SSLSetIOFuncs.restype = OSStatus + +    Security.SSLSetPeerID.argtypes = [ +        SSLContextRef, +        c_char_p, +        c_size_t +    ] +    Security.SSLSetPeerID.restype = OSStatus + +    Security.SSLSetCertificate.argtypes = [ +        SSLContextRef, +        CFArrayRef +    ] +    Security.SSLSetCertificate.restype = OSStatus + +    Security.SSLSetCertificateAuthorities.argtypes = [ +        SSLContextRef, +        CFTypeRef, +        Boolean +    ] +    Security.SSLSetCertificateAuthorities.restype = OSStatus + +    Security.SSLSetConnection.argtypes = [ +        SSLContextRef, +        SSLConnectionRef +    ] +    Security.SSLSetConnection.restype = OSStatus + +    Security.SSLSetPeerDomainName.argtypes = [ +        SSLContextRef, +        c_char_p, +        c_size_t +    ] +    Security.SSLSetPeerDomainName.restype = OSStatus + +    Security.SSLHandshake.argtypes = [ +        SSLContextRef +    ] +    Security.SSLHandshake.restype = OSStatus + +    Security.SSLRead.argtypes = [ +        SSLContextRef, +        c_char_p, +        c_size_t, +        POINTER(c_size_t) +    ] +    Security.SSLRead.restype = OSStatus + +    Security.SSLWrite.argtypes = [ +        SSLContextRef, +        c_char_p, +        c_size_t, +        POINTER(c_size_t) +    ] +    Security.SSLWrite.restype = OSStatus + +    Security.SSLClose.argtypes = [ +        SSLContextRef +    ] +    Security.SSLClose.restype = OSStatus + +    Security.SSLGetNumberSupportedCiphers.argtypes = [ +        SSLContextRef, +        POINTER(c_size_t) +    ] +    Security.SSLGetNumberSupportedCiphers.restype = OSStatus + +    Security.SSLGetSupportedCiphers.argtypes = [ +        SSLContextRef, +        POINTER(SSLCipherSuite), +        POINTER(c_size_t) +    ] +    Security.SSLGetSupportedCiphers.restype = OSStatus + +    Security.SSLSetEnabledCiphers.argtypes = [ +        SSLContextRef, +        POINTER(SSLCipherSuite), +        c_size_t +    ] +    Security.SSLSetEnabledCiphers.restype = OSStatus + +    Security.SSLGetNumberEnabledCiphers.argtype = [ +        SSLContextRef, +        POINTER(c_size_t) +    ] +    Security.SSLGetNumberEnabledCiphers.restype = OSStatus + +    Security.SSLGetEnabledCiphers.argtypes = [ +        SSLContextRef, +        POINTER(SSLCipherSuite), +        POINTER(c_size_t) +    ] +    Security.SSLGetEnabledCiphers.restype = OSStatus + +    Security.SSLGetNegotiatedCipher.argtypes = [ +        SSLContextRef, +        POINTER(SSLCipherSuite) +    ] +    Security.SSLGetNegotiatedCipher.restype = OSStatus + +    Security.SSLGetNegotiatedProtocolVersion.argtypes = [ +        SSLContextRef, +        POINTER(SSLProtocol) +    ] +    Security.SSLGetNegotiatedProtocolVersion.restype = OSStatus + +    Security.SSLCopyPeerTrust.argtypes = [ +        SSLContextRef, +        POINTER(SecTrustRef) +    ] +    Security.SSLCopyPeerTrust.restype = OSStatus + +    Security.SecTrustSetAnchorCertificates.argtypes = [ +        SecTrustRef, +        CFArrayRef +    ] +    Security.SecTrustSetAnchorCertificates.restype = OSStatus + +    Security.SecTrustSetAnchorCertificatesOnly.argstypes = [ +        SecTrustRef, +        Boolean +    ] +    Security.SecTrustSetAnchorCertificatesOnly.restype = OSStatus + +    Security.SecTrustEvaluate.argtypes = [ +        SecTrustRef, +        POINTER(SecTrustResultType) +    ] +    Security.SecTrustEvaluate.restype = OSStatus + +    Security.SecTrustGetCertificateCount.argtypes = [ +        SecTrustRef +    ] +    Security.SecTrustGetCertificateCount.restype = CFIndex + +    Security.SecTrustGetCertificateAtIndex.argtypes = [ +        SecTrustRef, +        CFIndex +    ] +    Security.SecTrustGetCertificateAtIndex.restype = SecCertificateRef + +    Security.SSLCreateContext.argtypes = [ +        CFAllocatorRef, +        SSLProtocolSide, +        SSLConnectionType +    ] +    Security.SSLCreateContext.restype = SSLContextRef + +    Security.SSLSetSessionOption.argtypes = [ +        SSLContextRef, +        SSLSessionOption, +        Boolean +    ] +    Security.SSLSetSessionOption.restype = OSStatus + +    Security.SSLSetProtocolVersionMin.argtypes = [ +        SSLContextRef, +        SSLProtocol +    ] +    Security.SSLSetProtocolVersionMin.restype = OSStatus + +    Security.SSLSetProtocolVersionMax.argtypes = [ +        SSLContextRef, +        SSLProtocol +    ] +    Security.SSLSetProtocolVersionMax.restype = OSStatus + +    Security.SecCopyErrorMessageString.argtypes = [ +        OSStatus, +        c_void_p +    ] +    Security.SecCopyErrorMessageString.restype = CFStringRef + +    Security.SSLReadFunc = SSLReadFunc +    Security.SSLWriteFunc = SSLWriteFunc +    Security.SSLContextRef = SSLContextRef +    Security.SSLProtocol = SSLProtocol +    Security.SSLCipherSuite = SSLCipherSuite +    Security.SecIdentityRef = SecIdentityRef +    Security.SecKeychainRef = SecKeychainRef +    Security.SecTrustRef = SecTrustRef +    Security.SecTrustResultType = SecTrustResultType +    Security.SecExternalFormat = SecExternalFormat +    Security.OSStatus = OSStatus + +    Security.kSecImportExportPassphrase = CFStringRef.in_dll( +        Security, 'kSecImportExportPassphrase' +    ) +    Security.kSecImportItemIdentity = CFStringRef.in_dll( +        Security, 'kSecImportItemIdentity' +    ) + +    # CoreFoundation time! +    CoreFoundation.CFRetain.argtypes = [ +        CFTypeRef +    ] +    CoreFoundation.CFRetain.restype = CFTypeRef + +    CoreFoundation.CFRelease.argtypes = [ +        CFTypeRef +    ] +    CoreFoundation.CFRelease.restype = None + +    CoreFoundation.CFGetTypeID.argtypes = [ +        CFTypeRef +    ] +    CoreFoundation.CFGetTypeID.restype = CFTypeID + +    CoreFoundation.CFStringCreateWithCString.argtypes = [ +        CFAllocatorRef, +        c_char_p, +        CFStringEncoding +    ] +    CoreFoundation.CFStringCreateWithCString.restype = CFStringRef + +    CoreFoundation.CFStringGetCStringPtr.argtypes = [ +        CFStringRef, +        CFStringEncoding +    ] +    CoreFoundation.CFStringGetCStringPtr.restype = c_char_p + +    CoreFoundation.CFStringGetCString.argtypes = [ +        CFStringRef, +        c_char_p, +        CFIndex, +        CFStringEncoding +    ] +    CoreFoundation.CFStringGetCString.restype = c_bool + +    CoreFoundation.CFDataCreate.argtypes = [ +        CFAllocatorRef, +        c_char_p, +        CFIndex +    ] +    CoreFoundation.CFDataCreate.restype = CFDataRef + +    CoreFoundation.CFDataGetLength.argtypes = [ +        CFDataRef +    ] +    CoreFoundation.CFDataGetLength.restype = CFIndex + +    CoreFoundation.CFDataGetBytePtr.argtypes = [ +        CFDataRef +    ] +    CoreFoundation.CFDataGetBytePtr.restype = c_void_p + +    CoreFoundation.CFDictionaryCreate.argtypes = [ +        CFAllocatorRef, +        POINTER(CFTypeRef), +        POINTER(CFTypeRef), +        CFIndex, +        CFDictionaryKeyCallBacks, +        CFDictionaryValueCallBacks +    ] +    CoreFoundation.CFDictionaryCreate.restype = CFDictionaryRef + +    CoreFoundation.CFDictionaryGetValue.argtypes = [ +        CFDictionaryRef, +        CFTypeRef +    ] +    CoreFoundation.CFDictionaryGetValue.restype = CFTypeRef + +    CoreFoundation.CFArrayCreate.argtypes = [ +        CFAllocatorRef, +        POINTER(CFTypeRef), +        CFIndex, +        CFArrayCallBacks, +    ] +    CoreFoundation.CFArrayCreate.restype = CFArrayRef + +    CoreFoundation.CFArrayCreateMutable.argtypes = [ +        CFAllocatorRef, +        CFIndex, +        CFArrayCallBacks +    ] +    CoreFoundation.CFArrayCreateMutable.restype = CFMutableArrayRef + +    CoreFoundation.CFArrayAppendValue.argtypes = [ +        CFMutableArrayRef, +        c_void_p +    ] +    CoreFoundation.CFArrayAppendValue.restype = None + +    CoreFoundation.CFArrayGetCount.argtypes = [ +        CFArrayRef +    ] +    CoreFoundation.CFArrayGetCount.restype = CFIndex + +    CoreFoundation.CFArrayGetValueAtIndex.argtypes = [ +        CFArrayRef, +        CFIndex +    ] +    CoreFoundation.CFArrayGetValueAtIndex.restype = c_void_p + +    CoreFoundation.kCFAllocatorDefault = CFAllocatorRef.in_dll( +        CoreFoundation, 'kCFAllocatorDefault' +    ) +    CoreFoundation.kCFTypeArrayCallBacks = c_void_p.in_dll(CoreFoundation, 'kCFTypeArrayCallBacks') +    CoreFoundation.kCFTypeDictionaryKeyCallBacks = c_void_p.in_dll( +        CoreFoundation, 'kCFTypeDictionaryKeyCallBacks' +    ) +    CoreFoundation.kCFTypeDictionaryValueCallBacks = c_void_p.in_dll( +        CoreFoundation, 'kCFTypeDictionaryValueCallBacks' +    ) + +    CoreFoundation.CFTypeRef = CFTypeRef +    CoreFoundation.CFArrayRef = CFArrayRef +    CoreFoundation.CFStringRef = CFStringRef +    CoreFoundation.CFDictionaryRef = CFDictionaryRef + +except (AttributeError): +    raise ImportError('Error initializing ctypes') + + +class CFConst(object): +    """ +    A class object that acts as essentially a namespace for CoreFoundation +    constants. +    """ +    kCFStringEncodingUTF8 = CFStringEncoding(0x08000100) + + +class SecurityConst(object): +    """ +    A class object that acts as essentially a namespace for Security constants. +    """ +    kSSLSessionOptionBreakOnServerAuth = 0 + +    kSSLProtocol2 = 1 +    kSSLProtocol3 = 2 +    kTLSProtocol1 = 4 +    kTLSProtocol11 = 7 +    kTLSProtocol12 = 8 + +    kSSLClientSide = 1 +    kSSLStreamType = 0 + +    kSecFormatPEMSequence = 10 + +    kSecTrustResultInvalid = 0 +    kSecTrustResultProceed = 1 +    # This gap is present on purpose: this was kSecTrustResultConfirm, which +    # is deprecated. +    kSecTrustResultDeny = 3 +    kSecTrustResultUnspecified = 4 +    kSecTrustResultRecoverableTrustFailure = 5 +    kSecTrustResultFatalTrustFailure = 6 +    kSecTrustResultOtherError = 7 + +    errSSLProtocol = -9800 +    errSSLWouldBlock = -9803 +    errSSLClosedGraceful = -9805 +    errSSLClosedNoNotify = -9816 +    errSSLClosedAbort = -9806 + +    errSSLXCertChainInvalid = -9807 +    errSSLCrypto = -9809 +    errSSLInternal = -9810 +    errSSLCertExpired = -9814 +    errSSLCertNotYetValid = -9815 +    errSSLUnknownRootCert = -9812 +    errSSLNoRootCert = -9813 +    errSSLHostNameMismatch = -9843 +    errSSLPeerHandshakeFail = -9824 +    errSSLPeerUserCancelled = -9839 +    errSSLWeakPeerEphemeralDHKey = -9850 +    errSSLServerAuthCompleted = -9841 +    errSSLRecordOverflow = -9847 + +    errSecVerifyFailed = -67808 +    errSecNoTrustSettings = -25263 +    errSecItemNotFound = -25300 +    errSecInvalidTrustSettings = -25262 + +    # Cipher suites. We only pick the ones our default cipher string allows. +    TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 = 0xC02C +    TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 = 0xC030 +    TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 = 0xC02B +    TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 = 0xC02F +    TLS_DHE_DSS_WITH_AES_256_GCM_SHA384 = 0x00A3 +    TLS_DHE_RSA_WITH_AES_256_GCM_SHA384 = 0x009F +    TLS_DHE_DSS_WITH_AES_128_GCM_SHA256 = 0x00A2 +    TLS_DHE_RSA_WITH_AES_128_GCM_SHA256 = 0x009E +    TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384 = 0xC024 +    TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384 = 0xC028 +    TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA = 0xC00A +    TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA = 0xC014 +    TLS_DHE_RSA_WITH_AES_256_CBC_SHA256 = 0x006B +    TLS_DHE_DSS_WITH_AES_256_CBC_SHA256 = 0x006A +    TLS_DHE_RSA_WITH_AES_256_CBC_SHA = 0x0039 +    TLS_DHE_DSS_WITH_AES_256_CBC_SHA = 0x0038 +    TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 = 0xC023 +    TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 = 0xC027 +    TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA = 0xC009 +    TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA = 0xC013 +    TLS_DHE_RSA_WITH_AES_128_CBC_SHA256 = 0x0067 +    TLS_DHE_DSS_WITH_AES_128_CBC_SHA256 = 0x0040 +    TLS_DHE_RSA_WITH_AES_128_CBC_SHA = 0x0033 +    TLS_DHE_DSS_WITH_AES_128_CBC_SHA = 0x0032 +    TLS_RSA_WITH_AES_256_GCM_SHA384 = 0x009D +    TLS_RSA_WITH_AES_128_GCM_SHA256 = 0x009C +    TLS_RSA_WITH_AES_256_CBC_SHA256 = 0x003D +    TLS_RSA_WITH_AES_128_CBC_SHA256 = 0x003C +    TLS_RSA_WITH_AES_256_CBC_SHA = 0x0035 +    TLS_RSA_WITH_AES_128_CBC_SHA = 0x002F +    TLS_AES_128_GCM_SHA256 = 0x1301 +    TLS_AES_256_GCM_SHA384 = 0x1302 +    TLS_CHACHA20_POLY1305_SHA256 = 0x1303 diff --git a/python/urllib3/contrib/_securetransport/low_level.py b/python/urllib3/contrib/_securetransport/low_level.py new file mode 100644 index 0000000..b13cd9e --- /dev/null +++ b/python/urllib3/contrib/_securetransport/low_level.py @@ -0,0 +1,346 @@ +""" +Low-level helpers for the SecureTransport bindings. + +These are Python functions that are not directly related to the high-level APIs +but are necessary to get them to work. They include a whole bunch of low-level +CoreFoundation messing about and memory management. The concerns in this module +are almost entirely about trying to avoid memory leaks and providing +appropriate and useful assistance to the higher-level code. +""" +import base64 +import ctypes +import itertools +import re +import os +import ssl +import tempfile + +from .bindings import Security, CoreFoundation, CFConst + + +# This regular expression is used to grab PEM data out of a PEM bundle. +_PEM_CERTS_RE = re.compile( +    b"-----BEGIN CERTIFICATE-----\n(.*?)\n-----END CERTIFICATE-----", re.DOTALL +) + + +def _cf_data_from_bytes(bytestring): +    """ +    Given a bytestring, create a CFData object from it. This CFData object must +    be CFReleased by the caller. +    """ +    return CoreFoundation.CFDataCreate( +        CoreFoundation.kCFAllocatorDefault, bytestring, len(bytestring) +    ) + + +def _cf_dictionary_from_tuples(tuples): +    """ +    Given a list of Python tuples, create an associated CFDictionary. +    """ +    dictionary_size = len(tuples) + +    # We need to get the dictionary keys and values out in the same order. +    keys = (t[0] for t in tuples) +    values = (t[1] for t in tuples) +    cf_keys = (CoreFoundation.CFTypeRef * dictionary_size)(*keys) +    cf_values = (CoreFoundation.CFTypeRef * dictionary_size)(*values) + +    return CoreFoundation.CFDictionaryCreate( +        CoreFoundation.kCFAllocatorDefault, +        cf_keys, +        cf_values, +        dictionary_size, +        CoreFoundation.kCFTypeDictionaryKeyCallBacks, +        CoreFoundation.kCFTypeDictionaryValueCallBacks, +    ) + + +def _cf_string_to_unicode(value): +    """ +    Creates a Unicode string from a CFString object. Used entirely for error +    reporting. + +    Yes, it annoys me quite a lot that this function is this complex. +    """ +    value_as_void_p = ctypes.cast(value, ctypes.POINTER(ctypes.c_void_p)) + +    string = CoreFoundation.CFStringGetCStringPtr( +        value_as_void_p, +        CFConst.kCFStringEncodingUTF8 +    ) +    if string is None: +        buffer = ctypes.create_string_buffer(1024) +        result = CoreFoundation.CFStringGetCString( +            value_as_void_p, +            buffer, +            1024, +            CFConst.kCFStringEncodingUTF8 +        ) +        if not result: +            raise OSError('Error copying C string from CFStringRef') +        string = buffer.value +    if string is not None: +        string = string.decode('utf-8') +    return string + + +def _assert_no_error(error, exception_class=None): +    """ +    Checks the return code and throws an exception if there is an error to +    report +    """ +    if error == 0: +        return + +    cf_error_string = Security.SecCopyErrorMessageString(error, None) +    output = _cf_string_to_unicode(cf_error_string) +    CoreFoundation.CFRelease(cf_error_string) + +    if output is None or output == u'': +        output = u'OSStatus %s' % error + +    if exception_class is None: +        exception_class = ssl.SSLError + +    raise exception_class(output) + + +def _cert_array_from_pem(pem_bundle): +    """ +    Given a bundle of certs in PEM format, turns them into a CFArray of certs +    that can be used to validate a cert chain. +    """ +    # Normalize the PEM bundle's line endings. +    pem_bundle = pem_bundle.replace(b"\r\n", b"\n") + +    der_certs = [ +        base64.b64decode(match.group(1)) +        for match in _PEM_CERTS_RE.finditer(pem_bundle) +    ] +    if not der_certs: +        raise ssl.SSLError("No root certificates specified") + +    cert_array = CoreFoundation.CFArrayCreateMutable( +        CoreFoundation.kCFAllocatorDefault, +        0, +        ctypes.byref(CoreFoundation.kCFTypeArrayCallBacks) +    ) +    if not cert_array: +        raise ssl.SSLError("Unable to allocate memory!") + +    try: +        for der_bytes in der_certs: +            certdata = _cf_data_from_bytes(der_bytes) +            if not certdata: +                raise ssl.SSLError("Unable to allocate memory!") +            cert = Security.SecCertificateCreateWithData( +                CoreFoundation.kCFAllocatorDefault, certdata +            ) +            CoreFoundation.CFRelease(certdata) +            if not cert: +                raise ssl.SSLError("Unable to build cert object!") + +            CoreFoundation.CFArrayAppendValue(cert_array, cert) +            CoreFoundation.CFRelease(cert) +    except Exception: +        # We need to free the array before the exception bubbles further. +        # We only want to do that if an error occurs: otherwise, the caller +        # should free. +        CoreFoundation.CFRelease(cert_array) + +    return cert_array + + +def _is_cert(item): +    """ +    Returns True if a given CFTypeRef is a certificate. +    """ +    expected = Security.SecCertificateGetTypeID() +    return CoreFoundation.CFGetTypeID(item) == expected + + +def _is_identity(item): +    """ +    Returns True if a given CFTypeRef is an identity. +    """ +    expected = Security.SecIdentityGetTypeID() +    return CoreFoundation.CFGetTypeID(item) == expected + + +def _temporary_keychain(): +    """ +    This function creates a temporary Mac keychain that we can use to work with +    credentials. This keychain uses a one-time password and a temporary file to +    store the data. We expect to have one keychain per socket. The returned +    SecKeychainRef must be freed by the caller, including calling +    SecKeychainDelete. + +    Returns a tuple of the SecKeychainRef and the path to the temporary +    directory that contains it. +    """ +    # Unfortunately, SecKeychainCreate requires a path to a keychain. This +    # means we cannot use mkstemp to use a generic temporary file. Instead, +    # we're going to create a temporary directory and a filename to use there. +    # This filename will be 8 random bytes expanded into base64. We also need +    # some random bytes to password-protect the keychain we're creating, so we +    # ask for 40 random bytes. +    random_bytes = os.urandom(40) +    filename = base64.b16encode(random_bytes[:8]).decode('utf-8') +    password = base64.b16encode(random_bytes[8:])  # Must be valid UTF-8 +    tempdirectory = tempfile.mkdtemp() + +    keychain_path = os.path.join(tempdirectory, filename).encode('utf-8') + +    # We now want to create the keychain itself. +    keychain = Security.SecKeychainRef() +    status = Security.SecKeychainCreate( +        keychain_path, +        len(password), +        password, +        False, +        None, +        ctypes.byref(keychain) +    ) +    _assert_no_error(status) + +    # Having created the keychain, we want to pass it off to the caller. +    return keychain, tempdirectory + + +def _load_items_from_file(keychain, path): +    """ +    Given a single file, loads all the trust objects from it into arrays and +    the keychain. +    Returns a tuple of lists: the first list is a list of identities, the +    second a list of certs. +    """ +    certificates = [] +    identities = [] +    result_array = None + +    with open(path, 'rb') as f: +        raw_filedata = f.read() + +    try: +        filedata = CoreFoundation.CFDataCreate( +            CoreFoundation.kCFAllocatorDefault, +            raw_filedata, +            len(raw_filedata) +        ) +        result_array = CoreFoundation.CFArrayRef() +        result = Security.SecItemImport( +            filedata,  # cert data +            None,  # Filename, leaving it out for now +            None,  # What the type of the file is, we don't care +            None,  # what's in the file, we don't care +            0,  # import flags +            None,  # key params, can include passphrase in the future +            keychain,  # The keychain to insert into +            ctypes.byref(result_array)  # Results +        ) +        _assert_no_error(result) + +        # A CFArray is not very useful to us as an intermediary +        # representation, so we are going to extract the objects we want +        # and then free the array. We don't need to keep hold of keys: the +        # keychain already has them! +        result_count = CoreFoundation.CFArrayGetCount(result_array) +        for index in range(result_count): +            item = CoreFoundation.CFArrayGetValueAtIndex( +                result_array, index +            ) +            item = ctypes.cast(item, CoreFoundation.CFTypeRef) + +            if _is_cert(item): +                CoreFoundation.CFRetain(item) +                certificates.append(item) +            elif _is_identity(item): +                CoreFoundation.CFRetain(item) +                identities.append(item) +    finally: +        if result_array: +            CoreFoundation.CFRelease(result_array) + +        CoreFoundation.CFRelease(filedata) + +    return (identities, certificates) + + +def _load_client_cert_chain(keychain, *paths): +    """ +    Load certificates and maybe keys from a number of files. Has the end goal +    of returning a CFArray containing one SecIdentityRef, and then zero or more +    SecCertificateRef objects, suitable for use as a client certificate trust +    chain. +    """ +    # Ok, the strategy. +    # +    # This relies on knowing that macOS will not give you a SecIdentityRef +    # unless you have imported a key into a keychain. This is a somewhat +    # artificial limitation of macOS (for example, it doesn't necessarily +    # affect iOS), but there is nothing inside Security.framework that lets you +    # get a SecIdentityRef without having a key in a keychain. +    # +    # So the policy here is we take all the files and iterate them in order. +    # Each one will use SecItemImport to have one or more objects loaded from +    # it. We will also point at a keychain that macOS can use to work with the +    # private key. +    # +    # Once we have all the objects, we'll check what we actually have. If we +    # already have a SecIdentityRef in hand, fab: we'll use that. Otherwise, +    # we'll take the first certificate (which we assume to be our leaf) and +    # ask the keychain to give us a SecIdentityRef with that cert's associated +    # key. +    # +    # We'll then return a CFArray containing the trust chain: one +    # SecIdentityRef and then zero-or-more SecCertificateRef objects. The +    # responsibility for freeing this CFArray will be with the caller. This +    # CFArray must remain alive for the entire connection, so in practice it +    # will be stored with a single SSLSocket, along with the reference to the +    # keychain. +    certificates = [] +    identities = [] + +    # Filter out bad paths. +    paths = (path for path in paths if path) + +    try: +        for file_path in paths: +            new_identities, new_certs = _load_items_from_file( +                keychain, file_path +            ) +            identities.extend(new_identities) +            certificates.extend(new_certs) + +        # Ok, we have everything. The question is: do we have an identity? If +        # not, we want to grab one from the first cert we have. +        if not identities: +            new_identity = Security.SecIdentityRef() +            status = Security.SecIdentityCreateWithCertificate( +                keychain, +                certificates[0], +                ctypes.byref(new_identity) +            ) +            _assert_no_error(status) +            identities.append(new_identity) + +            # We now want to release the original certificate, as we no longer +            # need it. +            CoreFoundation.CFRelease(certificates.pop(0)) + +        # We now need to build a new CFArray that holds the trust chain. +        trust_chain = CoreFoundation.CFArrayCreateMutable( +            CoreFoundation.kCFAllocatorDefault, +            0, +            ctypes.byref(CoreFoundation.kCFTypeArrayCallBacks), +        ) +        for item in itertools.chain(identities, certificates): +            # ArrayAppendValue does a CFRetain on the item. That's fine, +            # because the finally block will release our other refs to them. +            CoreFoundation.CFArrayAppendValue(trust_chain, item) + +        return trust_chain +    finally: +        for obj in itertools.chain(identities, certificates): +            CoreFoundation.CFRelease(obj) diff --git a/python/urllib3/contrib/appengine.py b/python/urllib3/contrib/appengine.py new file mode 100644 index 0000000..2952f11 --- /dev/null +++ b/python/urllib3/contrib/appengine.py @@ -0,0 +1,289 @@ +""" +This module provides a pool manager that uses Google App Engine's +`URLFetch Service <https://cloud.google.com/appengine/docs/python/urlfetch>`_. + +Example usage:: + +    from urllib3 import PoolManager +    from urllib3.contrib.appengine import AppEngineManager, is_appengine_sandbox + +    if is_appengine_sandbox(): +        # AppEngineManager uses AppEngine's URLFetch API behind the scenes +        http = AppEngineManager() +    else: +        # PoolManager uses a socket-level API behind the scenes +        http = PoolManager() + +    r = http.request('GET', 'https://google.com/') + +There are `limitations <https://cloud.google.com/appengine/docs/python/\ +urlfetch/#Python_Quotas_and_limits>`_ to the URLFetch service and it may not be +the best choice for your application. There are three options for using +urllib3 on Google App Engine: + +1. You can use :class:`AppEngineManager` with URLFetch. URLFetch is +   cost-effective in many circumstances as long as your usage is within the +   limitations. +2. You can use a normal :class:`~urllib3.PoolManager` by enabling sockets. +   Sockets also have `limitations and restrictions +   <https://cloud.google.com/appengine/docs/python/sockets/\ +   #limitations-and-restrictions>`_ and have a lower free quota than URLFetch. +   To use sockets, be sure to specify the following in your ``app.yaml``:: + +        env_variables: +            GAE_USE_SOCKETS_HTTPLIB : 'true' + +3. If you are using `App Engine Flexible +<https://cloud.google.com/appengine/docs/flexible/>`_, you can use the standard +:class:`PoolManager` without any configuration or special environment variables. +""" + +from __future__ import absolute_import +import io +import logging +import warnings +from ..packages.six.moves.urllib.parse import urljoin + +from ..exceptions import ( +    HTTPError, +    HTTPWarning, +    MaxRetryError, +    ProtocolError, +    TimeoutError, +    SSLError +) + +from ..request import RequestMethods +from ..response import HTTPResponse +from ..util.timeout import Timeout +from ..util.retry import Retry +from . import _appengine_environ + +try: +    from google.appengine.api import urlfetch +except ImportError: +    urlfetch = None + + +log = logging.getLogger(__name__) + + +class AppEnginePlatformWarning(HTTPWarning): +    pass + + +class AppEnginePlatformError(HTTPError): +    pass + + +class AppEngineManager(RequestMethods): +    """ +    Connection manager for Google App Engine sandbox applications. + +    This manager uses the URLFetch service directly instead of using the +    emulated httplib, and is subject to URLFetch limitations as described in +    the App Engine documentation `here +    <https://cloud.google.com/appengine/docs/python/urlfetch>`_. + +    Notably it will raise an :class:`AppEnginePlatformError` if: +        * URLFetch is not available. +        * If you attempt to use this on App Engine Flexible, as full socket +          support is available. +        * If a request size is more than 10 megabytes. +        * If a response size is more than 32 megabtyes. +        * If you use an unsupported request method such as OPTIONS. + +    Beyond those cases, it will raise normal urllib3 errors. +    """ + +    def __init__(self, headers=None, retries=None, validate_certificate=True, +                 urlfetch_retries=True): +        if not urlfetch: +            raise AppEnginePlatformError( +                "URLFetch is not available in this environment.") + +        if is_prod_appengine_mvms(): +            raise AppEnginePlatformError( +                "Use normal urllib3.PoolManager instead of AppEngineManager" +                "on Managed VMs, as using URLFetch is not necessary in " +                "this environment.") + +        warnings.warn( +            "urllib3 is using URLFetch on Google App Engine sandbox instead " +            "of sockets. To use sockets directly instead of URLFetch see " +            "https://urllib3.readthedocs.io/en/latest/reference/urllib3.contrib.html.", +            AppEnginePlatformWarning) + +        RequestMethods.__init__(self, headers) +        self.validate_certificate = validate_certificate +        self.urlfetch_retries = urlfetch_retries + +        self.retries = retries or Retry.DEFAULT + +    def __enter__(self): +        return self + +    def __exit__(self, exc_type, exc_val, exc_tb): +        # Return False to re-raise any potential exceptions +        return False + +    def urlopen(self, method, url, body=None, headers=None, +                retries=None, redirect=True, timeout=Timeout.DEFAULT_TIMEOUT, +                **response_kw): + +        retries = self._get_retries(retries, redirect) + +        try: +            follow_redirects = ( +                    redirect and +                    retries.redirect != 0 and +                    retries.total) +            response = urlfetch.fetch( +                url, +                payload=body, +                method=method, +                headers=headers or {}, +                allow_truncated=False, +                follow_redirects=self.urlfetch_retries and follow_redirects, +                deadline=self._get_absolute_timeout(timeout), +                validate_certificate=self.validate_certificate, +            ) +        except urlfetch.DeadlineExceededError as e: +            raise TimeoutError(self, e) + +        except urlfetch.InvalidURLError as e: +            if 'too large' in str(e): +                raise AppEnginePlatformError( +                    "URLFetch request too large, URLFetch only " +                    "supports requests up to 10mb in size.", e) +            raise ProtocolError(e) + +        except urlfetch.DownloadError as e: +            if 'Too many redirects' in str(e): +                raise MaxRetryError(self, url, reason=e) +            raise ProtocolError(e) + +        except urlfetch.ResponseTooLargeError as e: +            raise AppEnginePlatformError( +                "URLFetch response too large, URLFetch only supports" +                "responses up to 32mb in size.", e) + +        except urlfetch.SSLCertificateError as e: +            raise SSLError(e) + +        except urlfetch.InvalidMethodError as e: +            raise AppEnginePlatformError( +                "URLFetch does not support method: %s" % method, e) + +        http_response = self._urlfetch_response_to_http_response( +            response, retries=retries, **response_kw) + +        # Handle redirect? +        redirect_location = redirect and http_response.get_redirect_location() +        if redirect_location: +            # Check for redirect response +            if (self.urlfetch_retries and retries.raise_on_redirect): +                raise MaxRetryError(self, url, "too many redirects") +            else: +                if http_response.status == 303: +                    method = 'GET' + +                try: +                    retries = retries.increment(method, url, response=http_response, _pool=self) +                except MaxRetryError: +                    if retries.raise_on_redirect: +                        raise MaxRetryError(self, url, "too many redirects") +                    return http_response + +                retries.sleep_for_retry(http_response) +                log.debug("Redirecting %s -> %s", url, redirect_location) +                redirect_url = urljoin(url, redirect_location) +                return self.urlopen( +                    method, redirect_url, body, headers, +                    retries=retries, redirect=redirect, +                    timeout=timeout, **response_kw) + +        # Check if we should retry the HTTP response. +        has_retry_after = bool(http_response.getheader('Retry-After')) +        if retries.is_retry(method, http_response.status, has_retry_after): +            retries = retries.increment( +                method, url, response=http_response, _pool=self) +            log.debug("Retry: %s", url) +            retries.sleep(http_response) +            return self.urlopen( +                method, url, +                body=body, headers=headers, +                retries=retries, redirect=redirect, +                timeout=timeout, **response_kw) + +        return http_response + +    def _urlfetch_response_to_http_response(self, urlfetch_resp, **response_kw): + +        if is_prod_appengine(): +            # Production GAE handles deflate encoding automatically, but does +            # not remove the encoding header. +            content_encoding = urlfetch_resp.headers.get('content-encoding') + +            if content_encoding == 'deflate': +                del urlfetch_resp.headers['content-encoding'] + +        transfer_encoding = urlfetch_resp.headers.get('transfer-encoding') +        # We have a full response's content, +        # so let's make sure we don't report ourselves as chunked data. +        if transfer_encoding == 'chunked': +            encodings = transfer_encoding.split(",") +            encodings.remove('chunked') +            urlfetch_resp.headers['transfer-encoding'] = ','.join(encodings) + +        original_response = HTTPResponse( +            # In order for decoding to work, we must present the content as +            # a file-like object. +            body=io.BytesIO(urlfetch_resp.content), +            msg=urlfetch_resp.header_msg, +            headers=urlfetch_resp.headers, +            status=urlfetch_resp.status_code, +            **response_kw +        ) + +        return HTTPResponse( +            body=io.BytesIO(urlfetch_resp.content), +            headers=urlfetch_resp.headers, +            status=urlfetch_resp.status_code, +            original_response=original_response, +            **response_kw +        ) + +    def _get_absolute_timeout(self, timeout): +        if timeout is Timeout.DEFAULT_TIMEOUT: +            return None  # Defer to URLFetch's default. +        if isinstance(timeout, Timeout): +            if timeout._read is not None or timeout._connect is not None: +                warnings.warn( +                    "URLFetch does not support granular timeout settings, " +                    "reverting to total or default URLFetch timeout.", +                    AppEnginePlatformWarning) +            return timeout.total +        return timeout + +    def _get_retries(self, retries, redirect): +        if not isinstance(retries, Retry): +            retries = Retry.from_int( +                retries, redirect=redirect, default=self.retries) + +        if retries.connect or retries.read or retries.redirect: +            warnings.warn( +                "URLFetch only supports total retries and does not " +                "recognize connect, read, or redirect retry parameters.", +                AppEnginePlatformWarning) + +        return retries + + +# Alias methods from _appengine_environ to maintain public API interface. + +is_appengine = _appengine_environ.is_appengine +is_appengine_sandbox = _appengine_environ.is_appengine_sandbox +is_local_appengine = _appengine_environ.is_local_appengine +is_prod_appengine = _appengine_environ.is_prod_appengine +is_prod_appengine_mvms = _appengine_environ.is_prod_appengine_mvms diff --git a/python/urllib3/contrib/ntlmpool.py b/python/urllib3/contrib/ntlmpool.py new file mode 100644 index 0000000..8ea127c --- /dev/null +++ b/python/urllib3/contrib/ntlmpool.py @@ -0,0 +1,111 @@ +""" +NTLM authenticating pool, contributed by erikcederstran + +Issue #10, see: http://code.google.com/p/urllib3/issues/detail?id=10 +""" +from __future__ import absolute_import + +from logging import getLogger +from ntlm import ntlm + +from .. import HTTPSConnectionPool +from ..packages.six.moves.http_client import HTTPSConnection + + +log = getLogger(__name__) + + +class NTLMConnectionPool(HTTPSConnectionPool): +    """ +    Implements an NTLM authentication version of an urllib3 connection pool +    """ + +    scheme = 'https' + +    def __init__(self, user, pw, authurl, *args, **kwargs): +        """ +        authurl is a random URL on the server that is protected by NTLM. +        user is the Windows user, probably in the DOMAIN\\username format. +        pw is the password for the user. +        """ +        super(NTLMConnectionPool, self).__init__(*args, **kwargs) +        self.authurl = authurl +        self.rawuser = user +        user_parts = user.split('\\', 1) +        self.domain = user_parts[0].upper() +        self.user = user_parts[1] +        self.pw = pw + +    def _new_conn(self): +        # Performs the NTLM handshake that secures the connection. The socket +        # must be kept open while requests are performed. +        self.num_connections += 1 +        log.debug('Starting NTLM HTTPS connection no. %d: https://%s%s', +                  self.num_connections, self.host, self.authurl) + +        headers = {'Connection': 'Keep-Alive'} +        req_header = 'Authorization' +        resp_header = 'www-authenticate' + +        conn = HTTPSConnection(host=self.host, port=self.port) + +        # Send negotiation message +        headers[req_header] = ( +            'NTLM %s' % ntlm.create_NTLM_NEGOTIATE_MESSAGE(self.rawuser)) +        log.debug('Request headers: %s', headers) +        conn.request('GET', self.authurl, None, headers) +        res = conn.getresponse() +        reshdr = dict(res.getheaders()) +        log.debug('Response status: %s %s', res.status, res.reason) +        log.debug('Response headers: %s', reshdr) +        log.debug('Response data: %s [...]', res.read(100)) + +        # Remove the reference to the socket, so that it can not be closed by +        # the response object (we want to keep the socket open) +        res.fp = None + +        # Server should respond with a challenge message +        auth_header_values = reshdr[resp_header].split(', ') +        auth_header_value = None +        for s in auth_header_values: +            if s[:5] == 'NTLM ': +                auth_header_value = s[5:] +        if auth_header_value is None: +            raise Exception('Unexpected %s response header: %s' % +                            (resp_header, reshdr[resp_header])) + +        # Send authentication message +        ServerChallenge, NegotiateFlags = \ +            ntlm.parse_NTLM_CHALLENGE_MESSAGE(auth_header_value) +        auth_msg = ntlm.create_NTLM_AUTHENTICATE_MESSAGE(ServerChallenge, +                                                         self.user, +                                                         self.domain, +                                                         self.pw, +                                                         NegotiateFlags) +        headers[req_header] = 'NTLM %s' % auth_msg +        log.debug('Request headers: %s', headers) +        conn.request('GET', self.authurl, None, headers) +        res = conn.getresponse() +        log.debug('Response status: %s %s', res.status, res.reason) +        log.debug('Response headers: %s', dict(res.getheaders())) +        log.debug('Response data: %s [...]', res.read()[:100]) +        if res.status != 200: +            if res.status == 401: +                raise Exception('Server rejected request: wrong ' +                                'username or password') +            raise Exception('Wrong server response: %s %s' % +                            (res.status, res.reason)) + +        res.fp = None +        log.debug('Connection established') +        return conn + +    def urlopen(self, method, url, body=None, headers=None, retries=3, +                redirect=True, assert_same_host=True): +        if headers is None: +            headers = {} +        headers['Connection'] = 'Keep-Alive' +        return super(NTLMConnectionPool, self).urlopen(method, url, body, +                                                       headers, retries, +                                                       redirect, +                                                       assert_same_host) diff --git a/python/urllib3/contrib/pyopenssl.py b/python/urllib3/contrib/pyopenssl.py new file mode 100644 index 0000000..7c0e946 --- /dev/null +++ b/python/urllib3/contrib/pyopenssl.py @@ -0,0 +1,466 @@ +""" +SSL with SNI_-support for Python 2. Follow these instructions if you would +like to verify SSL certificates in Python 2. Note, the default libraries do +*not* do certificate checking; you need to do additional work to validate +certificates yourself. + +This needs the following packages installed: + +* pyOpenSSL (tested with 16.0.0) +* cryptography (minimum 1.3.4, from pyopenssl) +* idna (minimum 2.0, from cryptography) + +However, pyopenssl depends on cryptography, which depends on idna, so while we +use all three directly here we end up having relatively few packages required. + +You can install them with the following command: + +    pip install pyopenssl cryptography idna + +To activate certificate checking, call +:func:`~urllib3.contrib.pyopenssl.inject_into_urllib3` from your Python code +before you begin making HTTP requests. This can be done in a ``sitecustomize`` +module, or at any other time before your application begins using ``urllib3``, +like this:: + +    try: +        import urllib3.contrib.pyopenssl +        urllib3.contrib.pyopenssl.inject_into_urllib3() +    except ImportError: +        pass + +Now you can use :mod:`urllib3` as you normally would, and it will support SNI +when the required modules are installed. + +Activating this module also has the positive side effect of disabling SSL/TLS +compression in Python 2 (see `CRIME attack`_). + +If you want to configure the default list of supported cipher suites, you can +set the ``urllib3.contrib.pyopenssl.DEFAULT_SSL_CIPHER_LIST`` variable. + +.. _sni: https://en.wikipedia.org/wiki/Server_Name_Indication +.. _crime attack: https://en.wikipedia.org/wiki/CRIME_(security_exploit) +""" +from __future__ import absolute_import + +import OpenSSL.SSL +from cryptography import x509 +from cryptography.hazmat.backends.openssl import backend as openssl_backend +from cryptography.hazmat.backends.openssl.x509 import _Certificate +try: +    from cryptography.x509 import UnsupportedExtension +except ImportError: +    # UnsupportedExtension is gone in cryptography >= 2.1.0 +    class UnsupportedExtension(Exception): +        pass + +from socket import timeout, error as SocketError +from io import BytesIO + +try:  # Platform-specific: Python 2 +    from socket import _fileobject +except ImportError:  # Platform-specific: Python 3 +    _fileobject = None +    from ..packages.backports.makefile import backport_makefile + +import logging +import ssl +from ..packages import six +import sys + +from .. import util + +__all__ = ['inject_into_urllib3', 'extract_from_urllib3'] + +# SNI always works. +HAS_SNI = True + +# Map from urllib3 to PyOpenSSL compatible parameter-values. +_openssl_versions = { +    ssl.PROTOCOL_SSLv23: OpenSSL.SSL.SSLv23_METHOD, +    ssl.PROTOCOL_TLSv1: OpenSSL.SSL.TLSv1_METHOD, +} + +if hasattr(ssl, 'PROTOCOL_TLSv1_1') and hasattr(OpenSSL.SSL, 'TLSv1_1_METHOD'): +    _openssl_versions[ssl.PROTOCOL_TLSv1_1] = OpenSSL.SSL.TLSv1_1_METHOD + +if hasattr(ssl, 'PROTOCOL_TLSv1_2') and hasattr(OpenSSL.SSL, 'TLSv1_2_METHOD'): +    _openssl_versions[ssl.PROTOCOL_TLSv1_2] = OpenSSL.SSL.TLSv1_2_METHOD + +try: +    _openssl_versions.update({ssl.PROTOCOL_SSLv3: OpenSSL.SSL.SSLv3_METHOD}) +except AttributeError: +    pass + +_stdlib_to_openssl_verify = { +    ssl.CERT_NONE: OpenSSL.SSL.VERIFY_NONE, +    ssl.CERT_OPTIONAL: OpenSSL.SSL.VERIFY_PEER, +    ssl.CERT_REQUIRED: +        OpenSSL.SSL.VERIFY_PEER + OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT, +} +_openssl_to_stdlib_verify = dict( +    (v, k) for k, v in _stdlib_to_openssl_verify.items() +) + +# OpenSSL will only write 16K at a time +SSL_WRITE_BLOCKSIZE = 16384 + +orig_util_HAS_SNI = util.HAS_SNI +orig_util_SSLContext = util.ssl_.SSLContext + + +log = logging.getLogger(__name__) + + +def inject_into_urllib3(): +    'Monkey-patch urllib3 with PyOpenSSL-backed SSL-support.' + +    _validate_dependencies_met() + +    util.ssl_.SSLContext = PyOpenSSLContext +    util.HAS_SNI = HAS_SNI +    util.ssl_.HAS_SNI = HAS_SNI +    util.IS_PYOPENSSL = True +    util.ssl_.IS_PYOPENSSL = True + + +def extract_from_urllib3(): +    'Undo monkey-patching by :func:`inject_into_urllib3`.' + +    util.ssl_.SSLContext = orig_util_SSLContext +    util.HAS_SNI = orig_util_HAS_SNI +    util.ssl_.HAS_SNI = orig_util_HAS_SNI +    util.IS_PYOPENSSL = False +    util.ssl_.IS_PYOPENSSL = False + + +def _validate_dependencies_met(): +    """ +    Verifies that PyOpenSSL's package-level dependencies have been met. +    Throws `ImportError` if they are not met. +    """ +    # Method added in `cryptography==1.1`; not available in older versions +    from cryptography.x509.extensions import Extensions +    if getattr(Extensions, "get_extension_for_class", None) is None: +        raise ImportError("'cryptography' module missing required functionality.  " +                          "Try upgrading to v1.3.4 or newer.") + +    # pyOpenSSL 0.14 and above use cryptography for OpenSSL bindings. The _x509 +    # attribute is only present on those versions. +    from OpenSSL.crypto import X509 +    x509 = X509() +    if getattr(x509, "_x509", None) is None: +        raise ImportError("'pyOpenSSL' module missing required functionality. " +                          "Try upgrading to v0.14 or newer.") + + +def _dnsname_to_stdlib(name): +    """ +    Converts a dNSName SubjectAlternativeName field to the form used by the +    standard library on the given Python version. + +    Cryptography produces a dNSName as a unicode string that was idna-decoded +    from ASCII bytes. We need to idna-encode that string to get it back, and +    then on Python 3 we also need to convert to unicode via UTF-8 (the stdlib +    uses PyUnicode_FromStringAndSize on it, which decodes via UTF-8). + +    If the name cannot be idna-encoded then we return None signalling that +    the name given should be skipped. +    """ +    def idna_encode(name): +        """ +        Borrowed wholesale from the Python Cryptography Project. It turns out +        that we can't just safely call `idna.encode`: it can explode for +        wildcard names. This avoids that problem. +        """ +        import idna + +        try: +            for prefix in [u'*.', u'.']: +                if name.startswith(prefix): +                    name = name[len(prefix):] +                    return prefix.encode('ascii') + idna.encode(name) +            return idna.encode(name) +        except idna.core.IDNAError: +            return None + +    name = idna_encode(name) +    if name is None: +        return None +    elif sys.version_info >= (3, 0): +        name = name.decode('utf-8') +    return name + + +def get_subj_alt_name(peer_cert): +    """ +    Given an PyOpenSSL certificate, provides all the subject alternative names. +    """ +    # Pass the cert to cryptography, which has much better APIs for this. +    if hasattr(peer_cert, "to_cryptography"): +        cert = peer_cert.to_cryptography() +    else: +        # This is technically using private APIs, but should work across all +        # relevant versions before PyOpenSSL got a proper API for this. +        cert = _Certificate(openssl_backend, peer_cert._x509) + +    # We want to find the SAN extension. Ask Cryptography to locate it (it's +    # faster than looping in Python) +    try: +        ext = cert.extensions.get_extension_for_class( +            x509.SubjectAlternativeName +        ).value +    except x509.ExtensionNotFound: +        # No such extension, return the empty list. +        return [] +    except (x509.DuplicateExtension, UnsupportedExtension, +            x509.UnsupportedGeneralNameType, UnicodeError) as e: +        # A problem has been found with the quality of the certificate. Assume +        # no SAN field is present. +        log.warning( +            "A problem was encountered with the certificate that prevented " +            "urllib3 from finding the SubjectAlternativeName field. This can " +            "affect certificate validation. The error was %s", +            e, +        ) +        return [] + +    # We want to return dNSName and iPAddress fields. We need to cast the IPs +    # back to strings because the match_hostname function wants them as +    # strings. +    # Sadly the DNS names need to be idna encoded and then, on Python 3, UTF-8 +    # decoded. This is pretty frustrating, but that's what the standard library +    # does with certificates, and so we need to attempt to do the same. +    # We also want to skip over names which cannot be idna encoded. +    names = [ +        ('DNS', name) for name in map(_dnsname_to_stdlib, ext.get_values_for_type(x509.DNSName)) +        if name is not None +    ] +    names.extend( +        ('IP Address', str(name)) +        for name in ext.get_values_for_type(x509.IPAddress) +    ) + +    return names + + +class WrappedSocket(object): +    '''API-compatibility wrapper for Python OpenSSL's Connection-class. + +    Note: _makefile_refs, _drop() and _reuse() are needed for the garbage +    collector of pypy. +    ''' + +    def __init__(self, connection, socket, suppress_ragged_eofs=True): +        self.connection = connection +        self.socket = socket +        self.suppress_ragged_eofs = suppress_ragged_eofs +        self._makefile_refs = 0 +        self._closed = False + +    def fileno(self): +        return self.socket.fileno() + +    # Copy-pasted from Python 3.5 source code +    def _decref_socketios(self): +        if self._makefile_refs > 0: +            self._makefile_refs -= 1 +        if self._closed: +            self.close() + +    def recv(self, *args, **kwargs): +        try: +            data = self.connection.recv(*args, **kwargs) +        except OpenSSL.SSL.SysCallError as e: +            if self.suppress_ragged_eofs and e.args == (-1, 'Unexpected EOF'): +                return b'' +            else: +                raise SocketError(str(e)) +        except OpenSSL.SSL.ZeroReturnError as e: +            if self.connection.get_shutdown() == OpenSSL.SSL.RECEIVED_SHUTDOWN: +                return b'' +            else: +                raise +        except OpenSSL.SSL.WantReadError: +            if not util.wait_for_read(self.socket, self.socket.gettimeout()): +                raise timeout('The read operation timed out') +            else: +                return self.recv(*args, **kwargs) +        else: +            return data + +    def recv_into(self, *args, **kwargs): +        try: +            return self.connection.recv_into(*args, **kwargs) +        except OpenSSL.SSL.SysCallError as e: +            if self.suppress_ragged_eofs and e.args == (-1, 'Unexpected EOF'): +                return 0 +            else: +                raise SocketError(str(e)) +        except OpenSSL.SSL.ZeroReturnError as e: +            if self.connection.get_shutdown() == OpenSSL.SSL.RECEIVED_SHUTDOWN: +                return 0 +            else: +                raise +        except OpenSSL.SSL.WantReadError: +            if not util.wait_for_read(self.socket, self.socket.gettimeout()): +                raise timeout('The read operation timed out') +            else: +                return self.recv_into(*args, **kwargs) + +    def settimeout(self, timeout): +        return self.socket.settimeout(timeout) + +    def _send_until_done(self, data): +        while True: +            try: +                return self.connection.send(data) +            except OpenSSL.SSL.WantWriteError: +                if not util.wait_for_write(self.socket, self.socket.gettimeout()): +                    raise timeout() +                continue +            except OpenSSL.SSL.SysCallError as e: +                raise SocketError(str(e)) + +    def sendall(self, data): +        total_sent = 0 +        while total_sent < len(data): +            sent = self._send_until_done(data[total_sent:total_sent + SSL_WRITE_BLOCKSIZE]) +            total_sent += sent + +    def shutdown(self): +        # FIXME rethrow compatible exceptions should we ever use this +        self.connection.shutdown() + +    def close(self): +        if self._makefile_refs < 1: +            try: +                self._closed = True +                return self.connection.close() +            except OpenSSL.SSL.Error: +                return +        else: +            self._makefile_refs -= 1 + +    def getpeercert(self, binary_form=False): +        x509 = self.connection.get_peer_certificate() + +        if not x509: +            return x509 + +        if binary_form: +            return OpenSSL.crypto.dump_certificate( +                OpenSSL.crypto.FILETYPE_ASN1, +                x509) + +        return { +            'subject': ( +                (('commonName', x509.get_subject().CN),), +            ), +            'subjectAltName': get_subj_alt_name(x509) +        } + +    def _reuse(self): +        self._makefile_refs += 1 + +    def _drop(self): +        if self._makefile_refs < 1: +            self.close() +        else: +            self._makefile_refs -= 1 + + +if _fileobject:  # Platform-specific: Python 2 +    def makefile(self, mode, bufsize=-1): +        self._makefile_refs += 1 +        return _fileobject(self, mode, bufsize, close=True) +else:  # Platform-specific: Python 3 +    makefile = backport_makefile + +WrappedSocket.makefile = makefile + + +class PyOpenSSLContext(object): +    """ +    I am a wrapper class for the PyOpenSSL ``Context`` object. I am responsible +    for translating the interface of the standard library ``SSLContext`` object +    to calls into PyOpenSSL. +    """ +    def __init__(self, protocol): +        self.protocol = _openssl_versions[protocol] +        self._ctx = OpenSSL.SSL.Context(self.protocol) +        self._options = 0 +        self.check_hostname = False + +    @property +    def options(self): +        return self._options + +    @options.setter +    def options(self, value): +        self._options = value +        self._ctx.set_options(value) + +    @property +    def verify_mode(self): +        return _openssl_to_stdlib_verify[self._ctx.get_verify_mode()] + +    @verify_mode.setter +    def verify_mode(self, value): +        self._ctx.set_verify( +            _stdlib_to_openssl_verify[value], +            _verify_callback +        ) + +    def set_default_verify_paths(self): +        self._ctx.set_default_verify_paths() + +    def set_ciphers(self, ciphers): +        if isinstance(ciphers, six.text_type): +            ciphers = ciphers.encode('utf-8') +        self._ctx.set_cipher_list(ciphers) + +    def load_verify_locations(self, cafile=None, capath=None, cadata=None): +        if cafile is not None: +            cafile = cafile.encode('utf-8') +        if capath is not None: +            capath = capath.encode('utf-8') +        self._ctx.load_verify_locations(cafile, capath) +        if cadata is not None: +            self._ctx.load_verify_locations(BytesIO(cadata)) + +    def load_cert_chain(self, certfile, keyfile=None, password=None): +        self._ctx.use_certificate_chain_file(certfile) +        if password is not None: +            self._ctx.set_passwd_cb(lambda max_length, prompt_twice, userdata: password) +        self._ctx.use_privatekey_file(keyfile or certfile) + +    def wrap_socket(self, sock, server_side=False, +                    do_handshake_on_connect=True, suppress_ragged_eofs=True, +                    server_hostname=None): +        cnx = OpenSSL.SSL.Connection(self._ctx, sock) + +        if isinstance(server_hostname, six.text_type):  # Platform-specific: Python 3 +            server_hostname = server_hostname.encode('utf-8') + +        if server_hostname is not None: +            cnx.set_tlsext_host_name(server_hostname) + +        cnx.set_connect_state() + +        while True: +            try: +                cnx.do_handshake() +            except OpenSSL.SSL.WantReadError: +                if not util.wait_for_read(sock, sock.gettimeout()): +                    raise timeout('select timed out') +                continue +            except OpenSSL.SSL.Error as e: +                raise ssl.SSLError('bad handshake: %r' % e) +            break + +        return WrappedSocket(cnx, sock) + + +def _verify_callback(cnx, x509, err_no, err_depth, return_code): +    return err_no == 0 diff --git a/python/urllib3/contrib/securetransport.py b/python/urllib3/contrib/securetransport.py new file mode 100644 index 0000000..77cb59e --- /dev/null +++ b/python/urllib3/contrib/securetransport.py @@ -0,0 +1,804 @@ +""" +SecureTranport support for urllib3 via ctypes. + +This makes platform-native TLS available to urllib3 users on macOS without the +use of a compiler. This is an important feature because the Python Package +Index is moving to become a TLSv1.2-or-higher server, and the default OpenSSL +that ships with macOS is not capable of doing TLSv1.2. The only way to resolve +this is to give macOS users an alternative solution to the problem, and that +solution is to use SecureTransport. + +We use ctypes here because this solution must not require a compiler. That's +because pip is not allowed to require a compiler either. + +This is not intended to be a seriously long-term solution to this problem. +The hope is that PEP 543 will eventually solve this issue for us, at which +point we can retire this contrib module. But in the short term, we need to +solve the impending tire fire that is Python on Mac without this kind of +contrib module. So...here we are. + +To use this module, simply import and inject it:: + +    import urllib3.contrib.securetransport +    urllib3.contrib.securetransport.inject_into_urllib3() + +Happy TLSing! +""" +from __future__ import absolute_import + +import contextlib +import ctypes +import errno +import os.path +import shutil +import socket +import ssl +import threading +import weakref + +from .. import util +from ._securetransport.bindings import ( +    Security, SecurityConst, CoreFoundation +) +from ._securetransport.low_level import ( +    _assert_no_error, _cert_array_from_pem, _temporary_keychain, +    _load_client_cert_chain +) + +try:  # Platform-specific: Python 2 +    from socket import _fileobject +except ImportError:  # Platform-specific: Python 3 +    _fileobject = None +    from ..packages.backports.makefile import backport_makefile + +__all__ = ['inject_into_urllib3', 'extract_from_urllib3'] + +# SNI always works +HAS_SNI = True + +orig_util_HAS_SNI = util.HAS_SNI +orig_util_SSLContext = util.ssl_.SSLContext + +# This dictionary is used by the read callback to obtain a handle to the +# calling wrapped socket. This is a pretty silly approach, but for now it'll +# do. I feel like I should be able to smuggle a handle to the wrapped socket +# directly in the SSLConnectionRef, but for now this approach will work I +# guess. +# +# We need to lock around this structure for inserts, but we don't do it for +# reads/writes in the callbacks. The reasoning here goes as follows: +# +#    1. It is not possible to call into the callbacks before the dictionary is +#       populated, so once in the callback the id must be in the dictionary. +#    2. The callbacks don't mutate the dictionary, they only read from it, and +#       so cannot conflict with any of the insertions. +# +# This is good: if we had to lock in the callbacks we'd drastically slow down +# the performance of this code. +_connection_refs = weakref.WeakValueDictionary() +_connection_ref_lock = threading.Lock() + +# Limit writes to 16kB. This is OpenSSL's limit, but we'll cargo-cult it over +# for no better reason than we need *a* limit, and this one is right there. +SSL_WRITE_BLOCKSIZE = 16384 + +# This is our equivalent of util.ssl_.DEFAULT_CIPHERS, but expanded out to +# individual cipher suites. We need to do this because this is how +# SecureTransport wants them. +CIPHER_SUITES = [ +    SecurityConst.TLS_AES_256_GCM_SHA384, +    SecurityConst.TLS_CHACHA20_POLY1305_SHA256, +    SecurityConst.TLS_AES_128_GCM_SHA256, +    SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, +    SecurityConst.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, +    SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, +    SecurityConst.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, +    SecurityConst.TLS_DHE_DSS_WITH_AES_256_GCM_SHA384, +    SecurityConst.TLS_DHE_RSA_WITH_AES_256_GCM_SHA384, +    SecurityConst.TLS_DHE_DSS_WITH_AES_128_GCM_SHA256, +    SecurityConst.TLS_DHE_RSA_WITH_AES_128_GCM_SHA256, +    SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384, +    SecurityConst.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384, +    SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, +    SecurityConst.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, +    SecurityConst.TLS_DHE_RSA_WITH_AES_256_CBC_SHA256, +    SecurityConst.TLS_DHE_DSS_WITH_AES_256_CBC_SHA256, +    SecurityConst.TLS_DHE_RSA_WITH_AES_256_CBC_SHA, +    SecurityConst.TLS_DHE_DSS_WITH_AES_256_CBC_SHA, +    SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, +    SecurityConst.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, +    SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, +    SecurityConst.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, +    SecurityConst.TLS_DHE_RSA_WITH_AES_128_CBC_SHA256, +    SecurityConst.TLS_DHE_DSS_WITH_AES_128_CBC_SHA256, +    SecurityConst.TLS_DHE_RSA_WITH_AES_128_CBC_SHA, +    SecurityConst.TLS_DHE_DSS_WITH_AES_128_CBC_SHA, +    SecurityConst.TLS_RSA_WITH_AES_256_GCM_SHA384, +    SecurityConst.TLS_RSA_WITH_AES_128_GCM_SHA256, +    SecurityConst.TLS_RSA_WITH_AES_256_CBC_SHA256, +    SecurityConst.TLS_RSA_WITH_AES_128_CBC_SHA256, +    SecurityConst.TLS_RSA_WITH_AES_256_CBC_SHA, +    SecurityConst.TLS_RSA_WITH_AES_128_CBC_SHA, +] + +# Basically this is simple: for PROTOCOL_SSLv23 we turn it into a low of +# TLSv1 and a high of TLSv1.2. For everything else, we pin to that version. +_protocol_to_min_max = { +    ssl.PROTOCOL_SSLv23: (SecurityConst.kTLSProtocol1, SecurityConst.kTLSProtocol12), +} + +if hasattr(ssl, "PROTOCOL_SSLv2"): +    _protocol_to_min_max[ssl.PROTOCOL_SSLv2] = ( +        SecurityConst.kSSLProtocol2, SecurityConst.kSSLProtocol2 +    ) +if hasattr(ssl, "PROTOCOL_SSLv3"): +    _protocol_to_min_max[ssl.PROTOCOL_SSLv3] = ( +        SecurityConst.kSSLProtocol3, SecurityConst.kSSLProtocol3 +    ) +if hasattr(ssl, "PROTOCOL_TLSv1"): +    _protocol_to_min_max[ssl.PROTOCOL_TLSv1] = ( +        SecurityConst.kTLSProtocol1, SecurityConst.kTLSProtocol1 +    ) +if hasattr(ssl, "PROTOCOL_TLSv1_1"): +    _protocol_to_min_max[ssl.PROTOCOL_TLSv1_1] = ( +        SecurityConst.kTLSProtocol11, SecurityConst.kTLSProtocol11 +    ) +if hasattr(ssl, "PROTOCOL_TLSv1_2"): +    _protocol_to_min_max[ssl.PROTOCOL_TLSv1_2] = ( +        SecurityConst.kTLSProtocol12, SecurityConst.kTLSProtocol12 +    ) +if hasattr(ssl, "PROTOCOL_TLS"): +    _protocol_to_min_max[ssl.PROTOCOL_TLS] = _protocol_to_min_max[ssl.PROTOCOL_SSLv23] + + +def inject_into_urllib3(): +    """ +    Monkey-patch urllib3 with SecureTransport-backed SSL-support. +    """ +    util.ssl_.SSLContext = SecureTransportContext +    util.HAS_SNI = HAS_SNI +    util.ssl_.HAS_SNI = HAS_SNI +    util.IS_SECURETRANSPORT = True +    util.ssl_.IS_SECURETRANSPORT = True + + +def extract_from_urllib3(): +    """ +    Undo monkey-patching by :func:`inject_into_urllib3`. +    """ +    util.ssl_.SSLContext = orig_util_SSLContext +    util.HAS_SNI = orig_util_HAS_SNI +    util.ssl_.HAS_SNI = orig_util_HAS_SNI +    util.IS_SECURETRANSPORT = False +    util.ssl_.IS_SECURETRANSPORT = False + + +def _read_callback(connection_id, data_buffer, data_length_pointer): +    """ +    SecureTransport read callback. This is called by ST to request that data +    be returned from the socket. +    """ +    wrapped_socket = None +    try: +        wrapped_socket = _connection_refs.get(connection_id) +        if wrapped_socket is None: +            return SecurityConst.errSSLInternal +        base_socket = wrapped_socket.socket + +        requested_length = data_length_pointer[0] + +        timeout = wrapped_socket.gettimeout() +        error = None +        read_count = 0 + +        try: +            while read_count < requested_length: +                if timeout is None or timeout >= 0: +                    if not util.wait_for_read(base_socket, timeout): +                        raise socket.error(errno.EAGAIN, 'timed out') + +                remaining = requested_length - read_count +                buffer = (ctypes.c_char * remaining).from_address( +                    data_buffer + read_count +                ) +                chunk_size = base_socket.recv_into(buffer, remaining) +                read_count += chunk_size +                if not chunk_size: +                    if not read_count: +                        return SecurityConst.errSSLClosedGraceful +                    break +        except (socket.error) as e: +            error = e.errno + +            if error is not None and error != errno.EAGAIN: +                data_length_pointer[0] = read_count +                if error == errno.ECONNRESET or error == errno.EPIPE: +                    return SecurityConst.errSSLClosedAbort +                raise + +        data_length_pointer[0] = read_count + +        if read_count != requested_length: +            return SecurityConst.errSSLWouldBlock + +        return 0 +    except Exception as e: +        if wrapped_socket is not None: +            wrapped_socket._exception = e +        return SecurityConst.errSSLInternal + + +def _write_callback(connection_id, data_buffer, data_length_pointer): +    """ +    SecureTransport write callback. This is called by ST to request that data +    actually be sent on the network. +    """ +    wrapped_socket = None +    try: +        wrapped_socket = _connection_refs.get(connection_id) +        if wrapped_socket is None: +            return SecurityConst.errSSLInternal +        base_socket = wrapped_socket.socket + +        bytes_to_write = data_length_pointer[0] +        data = ctypes.string_at(data_buffer, bytes_to_write) + +        timeout = wrapped_socket.gettimeout() +        error = None +        sent = 0 + +        try: +            while sent < bytes_to_write: +                if timeout is None or timeout >= 0: +                    if not util.wait_for_write(base_socket, timeout): +                        raise socket.error(errno.EAGAIN, 'timed out') +                chunk_sent = base_socket.send(data) +                sent += chunk_sent + +                # This has some needless copying here, but I'm not sure there's +                # much value in optimising this data path. +                data = data[chunk_sent:] +        except (socket.error) as e: +            error = e.errno + +            if error is not None and error != errno.EAGAIN: +                data_length_pointer[0] = sent +                if error == errno.ECONNRESET or error == errno.EPIPE: +                    return SecurityConst.errSSLClosedAbort +                raise + +        data_length_pointer[0] = sent + +        if sent != bytes_to_write: +            return SecurityConst.errSSLWouldBlock + +        return 0 +    except Exception as e: +        if wrapped_socket is not None: +            wrapped_socket._exception = e +        return SecurityConst.errSSLInternal + + +# We need to keep these two objects references alive: if they get GC'd while +# in use then SecureTransport could attempt to call a function that is in freed +# memory. That would be...uh...bad. Yeah, that's the word. Bad. +_read_callback_pointer = Security.SSLReadFunc(_read_callback) +_write_callback_pointer = Security.SSLWriteFunc(_write_callback) + + +class WrappedSocket(object): +    """ +    API-compatibility wrapper for Python's OpenSSL wrapped socket object. + +    Note: _makefile_refs, _drop(), and _reuse() are needed for the garbage +    collector of PyPy. +    """ +    def __init__(self, socket): +        self.socket = socket +        self.context = None +        self._makefile_refs = 0 +        self._closed = False +        self._exception = None +        self._keychain = None +        self._keychain_dir = None +        self._client_cert_chain = None + +        # We save off the previously-configured timeout and then set it to +        # zero. This is done because we use select and friends to handle the +        # timeouts, but if we leave the timeout set on the lower socket then +        # Python will "kindly" call select on that socket again for us. Avoid +        # that by forcing the timeout to zero. +        self._timeout = self.socket.gettimeout() +        self.socket.settimeout(0) + +    @contextlib.contextmanager +    def _raise_on_error(self): +        """ +        A context manager that can be used to wrap calls that do I/O from +        SecureTransport. If any of the I/O callbacks hit an exception, this +        context manager will correctly propagate the exception after the fact. +        This avoids silently swallowing those exceptions. + +        It also correctly forces the socket closed. +        """ +        self._exception = None + +        # We explicitly don't catch around this yield because in the unlikely +        # event that an exception was hit in the block we don't want to swallow +        # it. +        yield +        if self._exception is not None: +            exception, self._exception = self._exception, None +            self.close() +            raise exception + +    def _set_ciphers(self): +        """ +        Sets up the allowed ciphers. By default this matches the set in +        util.ssl_.DEFAULT_CIPHERS, at least as supported by macOS. This is done +        custom and doesn't allow changing at this time, mostly because parsing +        OpenSSL cipher strings is going to be a freaking nightmare. +        """ +        ciphers = (Security.SSLCipherSuite * len(CIPHER_SUITES))(*CIPHER_SUITES) +        result = Security.SSLSetEnabledCiphers( +            self.context, ciphers, len(CIPHER_SUITES) +        ) +        _assert_no_error(result) + +    def _custom_validate(self, verify, trust_bundle): +        """ +        Called when we have set custom validation. We do this in two cases: +        first, when cert validation is entirely disabled; and second, when +        using a custom trust DB. +        """ +        # If we disabled cert validation, just say: cool. +        if not verify: +            return + +        # We want data in memory, so load it up. +        if os.path.isfile(trust_bundle): +            with open(trust_bundle, 'rb') as f: +                trust_bundle = f.read() + +        cert_array = None +        trust = Security.SecTrustRef() + +        try: +            # Get a CFArray that contains the certs we want. +            cert_array = _cert_array_from_pem(trust_bundle) + +            # Ok, now the hard part. We want to get the SecTrustRef that ST has +            # created for this connection, shove our CAs into it, tell ST to +            # ignore everything else it knows, and then ask if it can build a +            # chain. This is a buuuunch of code. +            result = Security.SSLCopyPeerTrust( +                self.context, ctypes.byref(trust) +            ) +            _assert_no_error(result) +            if not trust: +                raise ssl.SSLError("Failed to copy trust reference") + +            result = Security.SecTrustSetAnchorCertificates(trust, cert_array) +            _assert_no_error(result) + +            result = Security.SecTrustSetAnchorCertificatesOnly(trust, True) +            _assert_no_error(result) + +            trust_result = Security.SecTrustResultType() +            result = Security.SecTrustEvaluate( +                trust, ctypes.byref(trust_result) +            ) +            _assert_no_error(result) +        finally: +            if trust: +                CoreFoundation.CFRelease(trust) + +            if cert_array is not None: +                CoreFoundation.CFRelease(cert_array) + +        # Ok, now we can look at what the result was. +        successes = ( +            SecurityConst.kSecTrustResultUnspecified, +            SecurityConst.kSecTrustResultProceed +        ) +        if trust_result.value not in successes: +            raise ssl.SSLError( +                "certificate verify failed, error code: %d" % +                trust_result.value +            ) + +    def handshake(self, +                  server_hostname, +                  verify, +                  trust_bundle, +                  min_version, +                  max_version, +                  client_cert, +                  client_key, +                  client_key_passphrase): +        """ +        Actually performs the TLS handshake. This is run automatically by +        wrapped socket, and shouldn't be needed in user code. +        """ +        # First, we do the initial bits of connection setup. We need to create +        # a context, set its I/O funcs, and set the connection reference. +        self.context = Security.SSLCreateContext( +            None, SecurityConst.kSSLClientSide, SecurityConst.kSSLStreamType +        ) +        result = Security.SSLSetIOFuncs( +            self.context, _read_callback_pointer, _write_callback_pointer +        ) +        _assert_no_error(result) + +        # Here we need to compute the handle to use. We do this by taking the +        # id of self modulo 2**31 - 1. If this is already in the dictionary, we +        # just keep incrementing by one until we find a free space. +        with _connection_ref_lock: +            handle = id(self) % 2147483647 +            while handle in _connection_refs: +                handle = (handle + 1) % 2147483647 +            _connection_refs[handle] = self + +        result = Security.SSLSetConnection(self.context, handle) +        _assert_no_error(result) + +        # If we have a server hostname, we should set that too. +        if server_hostname: +            if not isinstance(server_hostname, bytes): +                server_hostname = server_hostname.encode('utf-8') + +            result = Security.SSLSetPeerDomainName( +                self.context, server_hostname, len(server_hostname) +            ) +            _assert_no_error(result) + +        # Setup the ciphers. +        self._set_ciphers() + +        # Set the minimum and maximum TLS versions. +        result = Security.SSLSetProtocolVersionMin(self.context, min_version) +        _assert_no_error(result) +        result = Security.SSLSetProtocolVersionMax(self.context, max_version) +        _assert_no_error(result) + +        # If there's a trust DB, we need to use it. We do that by telling +        # SecureTransport to break on server auth. We also do that if we don't +        # want to validate the certs at all: we just won't actually do any +        # authing in that case. +        if not verify or trust_bundle is not None: +            result = Security.SSLSetSessionOption( +                self.context, +                SecurityConst.kSSLSessionOptionBreakOnServerAuth, +                True +            ) +            _assert_no_error(result) + +        # If there's a client cert, we need to use it. +        if client_cert: +            self._keychain, self._keychain_dir = _temporary_keychain() +            self._client_cert_chain = _load_client_cert_chain( +                self._keychain, client_cert, client_key +            ) +            result = Security.SSLSetCertificate( +                self.context, self._client_cert_chain +            ) +            _assert_no_error(result) + +        while True: +            with self._raise_on_error(): +                result = Security.SSLHandshake(self.context) + +                if result == SecurityConst.errSSLWouldBlock: +                    raise socket.timeout("handshake timed out") +                elif result == SecurityConst.errSSLServerAuthCompleted: +                    self._custom_validate(verify, trust_bundle) +                    continue +                else: +                    _assert_no_error(result) +                    break + +    def fileno(self): +        return self.socket.fileno() + +    # Copy-pasted from Python 3.5 source code +    def _decref_socketios(self): +        if self._makefile_refs > 0: +            self._makefile_refs -= 1 +        if self._closed: +            self.close() + +    def recv(self, bufsiz): +        buffer = ctypes.create_string_buffer(bufsiz) +        bytes_read = self.recv_into(buffer, bufsiz) +        data = buffer[:bytes_read] +        return data + +    def recv_into(self, buffer, nbytes=None): +        # Read short on EOF. +        if self._closed: +            return 0 + +        if nbytes is None: +            nbytes = len(buffer) + +        buffer = (ctypes.c_char * nbytes).from_buffer(buffer) +        processed_bytes = ctypes.c_size_t(0) + +        with self._raise_on_error(): +            result = Security.SSLRead( +                self.context, buffer, nbytes, ctypes.byref(processed_bytes) +            ) + +        # There are some result codes that we want to treat as "not always +        # errors". Specifically, those are errSSLWouldBlock, +        # errSSLClosedGraceful, and errSSLClosedNoNotify. +        if (result == SecurityConst.errSSLWouldBlock): +            # If we didn't process any bytes, then this was just a time out. +            # However, we can get errSSLWouldBlock in situations when we *did* +            # read some data, and in those cases we should just read "short" +            # and return. +            if processed_bytes.value == 0: +                # Timed out, no data read. +                raise socket.timeout("recv timed out") +        elif result in (SecurityConst.errSSLClosedGraceful, SecurityConst.errSSLClosedNoNotify): +            # The remote peer has closed this connection. We should do so as +            # well. Note that we don't actually return here because in +            # principle this could actually be fired along with return data. +            # It's unlikely though. +            self.close() +        else: +            _assert_no_error(result) + +        # Ok, we read and probably succeeded. We should return whatever data +        # was actually read. +        return processed_bytes.value + +    def settimeout(self, timeout): +        self._timeout = timeout + +    def gettimeout(self): +        return self._timeout + +    def send(self, data): +        processed_bytes = ctypes.c_size_t(0) + +        with self._raise_on_error(): +            result = Security.SSLWrite( +                self.context, data, len(data), ctypes.byref(processed_bytes) +            ) + +        if result == SecurityConst.errSSLWouldBlock and processed_bytes.value == 0: +            # Timed out +            raise socket.timeout("send timed out") +        else: +            _assert_no_error(result) + +        # We sent, and probably succeeded. Tell them how much we sent. +        return processed_bytes.value + +    def sendall(self, data): +        total_sent = 0 +        while total_sent < len(data): +            sent = self.send(data[total_sent:total_sent + SSL_WRITE_BLOCKSIZE]) +            total_sent += sent + +    def shutdown(self): +        with self._raise_on_error(): +            Security.SSLClose(self.context) + +    def close(self): +        # TODO: should I do clean shutdown here? Do I have to? +        if self._makefile_refs < 1: +            self._closed = True +            if self.context: +                CoreFoundation.CFRelease(self.context) +                self.context = None +            if self._client_cert_chain: +                CoreFoundation.CFRelease(self._client_cert_chain) +                self._client_cert_chain = None +            if self._keychain: +                Security.SecKeychainDelete(self._keychain) +                CoreFoundation.CFRelease(self._keychain) +                shutil.rmtree(self._keychain_dir) +                self._keychain = self._keychain_dir = None +            return self.socket.close() +        else: +            self._makefile_refs -= 1 + +    def getpeercert(self, binary_form=False): +        # Urgh, annoying. +        # +        # Here's how we do this: +        # +        # 1. Call SSLCopyPeerTrust to get hold of the trust object for this +        #    connection. +        # 2. Call SecTrustGetCertificateAtIndex for index 0 to get the leaf. +        # 3. To get the CN, call SecCertificateCopyCommonName and process that +        #    string so that it's of the appropriate type. +        # 4. To get the SAN, we need to do something a bit more complex: +        #    a. Call SecCertificateCopyValues to get the data, requesting +        #       kSecOIDSubjectAltName. +        #    b. Mess about with this dictionary to try to get the SANs out. +        # +        # This is gross. Really gross. It's going to be a few hundred LoC extra +        # just to repeat something that SecureTransport can *already do*. So my +        # operating assumption at this time is that what we want to do is +        # instead to just flag to urllib3 that it shouldn't do its own hostname +        # validation when using SecureTransport. +        if not binary_form: +            raise ValueError( +                "SecureTransport only supports dumping binary certs" +            ) +        trust = Security.SecTrustRef() +        certdata = None +        der_bytes = None + +        try: +            # Grab the trust store. +            result = Security.SSLCopyPeerTrust( +                self.context, ctypes.byref(trust) +            ) +            _assert_no_error(result) +            if not trust: +                # Probably we haven't done the handshake yet. No biggie. +                return None + +            cert_count = Security.SecTrustGetCertificateCount(trust) +            if not cert_count: +                # Also a case that might happen if we haven't handshaked. +                # Handshook? Handshaken? +                return None + +            leaf = Security.SecTrustGetCertificateAtIndex(trust, 0) +            assert leaf + +            # Ok, now we want the DER bytes. +            certdata = Security.SecCertificateCopyData(leaf) +            assert certdata + +            data_length = CoreFoundation.CFDataGetLength(certdata) +            data_buffer = CoreFoundation.CFDataGetBytePtr(certdata) +            der_bytes = ctypes.string_at(data_buffer, data_length) +        finally: +            if certdata: +                CoreFoundation.CFRelease(certdata) +            if trust: +                CoreFoundation.CFRelease(trust) + +        return der_bytes + +    def _reuse(self): +        self._makefile_refs += 1 + +    def _drop(self): +        if self._makefile_refs < 1: +            self.close() +        else: +            self._makefile_refs -= 1 + + +if _fileobject:  # Platform-specific: Python 2 +    def makefile(self, mode, bufsize=-1): +        self._makefile_refs += 1 +        return _fileobject(self, mode, bufsize, close=True) +else:  # Platform-specific: Python 3 +    def makefile(self, mode="r", buffering=None, *args, **kwargs): +        # We disable buffering with SecureTransport because it conflicts with +        # the buffering that ST does internally (see issue #1153 for more). +        buffering = 0 +        return backport_makefile(self, mode, buffering, *args, **kwargs) + +WrappedSocket.makefile = makefile + + +class SecureTransportContext(object): +    """ +    I am a wrapper class for the SecureTransport library, to translate the +    interface of the standard library ``SSLContext`` object to calls into +    SecureTransport. +    """ +    def __init__(self, protocol): +        self._min_version, self._max_version = _protocol_to_min_max[protocol] +        self._options = 0 +        self._verify = False +        self._trust_bundle = None +        self._client_cert = None +        self._client_key = None +        self._client_key_passphrase = None + +    @property +    def check_hostname(self): +        """ +        SecureTransport cannot have its hostname checking disabled. For more, +        see the comment on getpeercert() in this file. +        """ +        return True + +    @check_hostname.setter +    def check_hostname(self, value): +        """ +        SecureTransport cannot have its hostname checking disabled. For more, +        see the comment on getpeercert() in this file. +        """ +        pass + +    @property +    def options(self): +        # TODO: Well, crap. +        # +        # So this is the bit of the code that is the most likely to cause us +        # trouble. Essentially we need to enumerate all of the SSL options that +        # users might want to use and try to see if we can sensibly translate +        # them, or whether we should just ignore them. +        return self._options + +    @options.setter +    def options(self, value): +        # TODO: Update in line with above. +        self._options = value + +    @property +    def verify_mode(self): +        return ssl.CERT_REQUIRED if self._verify else ssl.CERT_NONE + +    @verify_mode.setter +    def verify_mode(self, value): +        self._verify = True if value == ssl.CERT_REQUIRED else False + +    def set_default_verify_paths(self): +        # So, this has to do something a bit weird. Specifically, what it does +        # is nothing. +        # +        # This means that, if we had previously had load_verify_locations +        # called, this does not undo that. We need to do that because it turns +        # out that the rest of the urllib3 code will attempt to load the +        # default verify paths if it hasn't been told about any paths, even if +        # the context itself was sometime earlier. We resolve that by just +        # ignoring it. +        pass + +    def load_default_certs(self): +        return self.set_default_verify_paths() + +    def set_ciphers(self, ciphers): +        # For now, we just require the default cipher string. +        if ciphers != util.ssl_.DEFAULT_CIPHERS: +            raise ValueError( +                "SecureTransport doesn't support custom cipher strings" +            ) + +    def load_verify_locations(self, cafile=None, capath=None, cadata=None): +        # OK, we only really support cadata and cafile. +        if capath is not None: +            raise ValueError( +                "SecureTransport does not support cert directories" +            ) + +        self._trust_bundle = cafile or cadata + +    def load_cert_chain(self, certfile, keyfile=None, password=None): +        self._client_cert = certfile +        self._client_key = keyfile +        self._client_cert_passphrase = password + +    def wrap_socket(self, sock, server_side=False, +                    do_handshake_on_connect=True, suppress_ragged_eofs=True, +                    server_hostname=None): +        # So, what do we do here? Firstly, we assert some properties. This is a +        # stripped down shim, so there is some functionality we don't support. +        # See PEP 543 for the real deal. +        assert not server_side +        assert do_handshake_on_connect +        assert suppress_ragged_eofs + +        # Ok, we're good to go. Now we want to create the wrapped socket object +        # and store it in the appropriate place. +        wrapped_socket = WrappedSocket(sock) + +        # Now we can handshake +        wrapped_socket.handshake( +            server_hostname, self._verify, self._trust_bundle, +            self._min_version, self._max_version, self._client_cert, +            self._client_key, self._client_key_passphrase +        ) +        return wrapped_socket diff --git a/python/urllib3/contrib/socks.py b/python/urllib3/contrib/socks.py new file mode 100644 index 0000000..811e312 --- /dev/null +++ b/python/urllib3/contrib/socks.py @@ -0,0 +1,192 @@ +# -*- coding: utf-8 -*- +""" +This module contains provisional support for SOCKS proxies from within +urllib3. This module supports SOCKS4 (specifically the SOCKS4A variant) and +SOCKS5. To enable its functionality, either install PySocks or install this +module with the ``socks`` extra. + +The SOCKS implementation supports the full range of urllib3 features. It also +supports the following SOCKS features: + +- SOCKS4 +- SOCKS4a +- SOCKS5 +- Usernames and passwords for the SOCKS proxy + +Known Limitations: + +- Currently PySocks does not support contacting remote websites via literal +  IPv6 addresses. Any such connection attempt will fail. You must use a domain +  name. +- Currently PySocks does not support IPv6 connections to the SOCKS proxy. Any +  such connection attempt will fail. +""" +from __future__ import absolute_import + +try: +    import socks +except ImportError: +    import warnings +    from ..exceptions import DependencyWarning + +    warnings.warn(( +        'SOCKS support in urllib3 requires the installation of optional ' +        'dependencies: specifically, PySocks.  For more information, see ' +        'https://urllib3.readthedocs.io/en/latest/contrib.html#socks-proxies' +        ), +        DependencyWarning +    ) +    raise + +from socket import error as SocketError, timeout as SocketTimeout + +from ..connection import ( +    HTTPConnection, HTTPSConnection +) +from ..connectionpool import ( +    HTTPConnectionPool, HTTPSConnectionPool +) +from ..exceptions import ConnectTimeoutError, NewConnectionError +from ..poolmanager import PoolManager +from ..util.url import parse_url + +try: +    import ssl +except ImportError: +    ssl = None + + +class SOCKSConnection(HTTPConnection): +    """ +    A plain-text HTTP connection that connects via a SOCKS proxy. +    """ +    def __init__(self, *args, **kwargs): +        self._socks_options = kwargs.pop('_socks_options') +        super(SOCKSConnection, self).__init__(*args, **kwargs) + +    def _new_conn(self): +        """ +        Establish a new connection via the SOCKS proxy. +        """ +        extra_kw = {} +        if self.source_address: +            extra_kw['source_address'] = self.source_address + +        if self.socket_options: +            extra_kw['socket_options'] = self.socket_options + +        try: +            conn = socks.create_connection( +                (self.host, self.port), +                proxy_type=self._socks_options['socks_version'], +                proxy_addr=self._socks_options['proxy_host'], +                proxy_port=self._socks_options['proxy_port'], +                proxy_username=self._socks_options['username'], +                proxy_password=self._socks_options['password'], +                proxy_rdns=self._socks_options['rdns'], +                timeout=self.timeout, +                **extra_kw +            ) + +        except SocketTimeout as e: +            raise ConnectTimeoutError( +                self, "Connection to %s timed out. (connect timeout=%s)" % +                (self.host, self.timeout)) + +        except socks.ProxyError as e: +            # This is fragile as hell, but it seems to be the only way to raise +            # useful errors here. +            if e.socket_err: +                error = e.socket_err +                if isinstance(error, SocketTimeout): +                    raise ConnectTimeoutError( +                        self, +                        "Connection to %s timed out. (connect timeout=%s)" % +                        (self.host, self.timeout) +                    ) +                else: +                    raise NewConnectionError( +                        self, +                        "Failed to establish a new connection: %s" % error +                    ) +            else: +                raise NewConnectionError( +                    self, +                    "Failed to establish a new connection: %s" % e +                ) + +        except SocketError as e:  # Defensive: PySocks should catch all these. +            raise NewConnectionError( +                self, "Failed to establish a new connection: %s" % e) + +        return conn + + +# We don't need to duplicate the Verified/Unverified distinction from +# urllib3/connection.py here because the HTTPSConnection will already have been +# correctly set to either the Verified or Unverified form by that module. This +# means the SOCKSHTTPSConnection will automatically be the correct type. +class SOCKSHTTPSConnection(SOCKSConnection, HTTPSConnection): +    pass + + +class SOCKSHTTPConnectionPool(HTTPConnectionPool): +    ConnectionCls = SOCKSConnection + + +class SOCKSHTTPSConnectionPool(HTTPSConnectionPool): +    ConnectionCls = SOCKSHTTPSConnection + + +class SOCKSProxyManager(PoolManager): +    """ +    A version of the urllib3 ProxyManager that routes connections via the +    defined SOCKS proxy. +    """ +    pool_classes_by_scheme = { +        'http': SOCKSHTTPConnectionPool, +        'https': SOCKSHTTPSConnectionPool, +    } + +    def __init__(self, proxy_url, username=None, password=None, +                 num_pools=10, headers=None, **connection_pool_kw): +        parsed = parse_url(proxy_url) + +        if username is None and password is None and parsed.auth is not None: +            split = parsed.auth.split(':') +            if len(split) == 2: +                username, password = split +        if parsed.scheme == 'socks5': +            socks_version = socks.PROXY_TYPE_SOCKS5 +            rdns = False +        elif parsed.scheme == 'socks5h': +            socks_version = socks.PROXY_TYPE_SOCKS5 +            rdns = True +        elif parsed.scheme == 'socks4': +            socks_version = socks.PROXY_TYPE_SOCKS4 +            rdns = False +        elif parsed.scheme == 'socks4a': +            socks_version = socks.PROXY_TYPE_SOCKS4 +            rdns = True +        else: +            raise ValueError( +                "Unable to determine SOCKS version from %s" % proxy_url +            ) + +        self.proxy_url = proxy_url + +        socks_options = { +            'socks_version': socks_version, +            'proxy_host': parsed.host, +            'proxy_port': parsed.port, +            'username': username, +            'password': password, +            'rdns': rdns +        } +        connection_pool_kw['_socks_options'] = socks_options + +        super(SOCKSProxyManager, self).__init__( +            num_pools, headers, **connection_pool_kw +        ) + +        self.pool_classes_by_scheme = SOCKSProxyManager.pool_classes_by_scheme diff --git a/python/urllib3/exceptions.py b/python/urllib3/exceptions.py new file mode 100644 index 0000000..7bbaa98 --- /dev/null +++ b/python/urllib3/exceptions.py @@ -0,0 +1,246 @@ +from __future__ import absolute_import +from .packages.six.moves.http_client import ( +    IncompleteRead as httplib_IncompleteRead +) +# Base Exceptions + + +class HTTPError(Exception): +    "Base exception used by this module." +    pass + + +class HTTPWarning(Warning): +    "Base warning used by this module." +    pass + + +class PoolError(HTTPError): +    "Base exception for errors caused within a pool." +    def __init__(self, pool, message): +        self.pool = pool +        HTTPError.__init__(self, "%s: %s" % (pool, message)) + +    def __reduce__(self): +        # For pickling purposes. +        return self.__class__, (None, None) + + +class RequestError(PoolError): +    "Base exception for PoolErrors that have associated URLs." +    def __init__(self, pool, url, message): +        self.url = url +        PoolError.__init__(self, pool, message) + +    def __reduce__(self): +        # For pickling purposes. +        return self.__class__, (None, self.url, None) + + +class SSLError(HTTPError): +    "Raised when SSL certificate fails in an HTTPS connection." +    pass + + +class ProxyError(HTTPError): +    "Raised when the connection to a proxy fails." +    pass + + +class DecodeError(HTTPError): +    "Raised when automatic decoding based on Content-Type fails." +    pass + + +class ProtocolError(HTTPError): +    "Raised when something unexpected happens mid-request/response." +    pass + + +#: Renamed to ProtocolError but aliased for backwards compatibility. +ConnectionError = ProtocolError + + +# Leaf Exceptions + +class MaxRetryError(RequestError): +    """Raised when the maximum number of retries is exceeded. + +    :param pool: The connection pool +    :type pool: :class:`~urllib3.connectionpool.HTTPConnectionPool` +    :param string url: The requested Url +    :param exceptions.Exception reason: The underlying error + +    """ + +    def __init__(self, pool, url, reason=None): +        self.reason = reason + +        message = "Max retries exceeded with url: %s (Caused by %r)" % ( +            url, reason) + +        RequestError.__init__(self, pool, url, message) + + +class HostChangedError(RequestError): +    "Raised when an existing pool gets a request for a foreign host." + +    def __init__(self, pool, url, retries=3): +        message = "Tried to open a foreign host with url: %s" % url +        RequestError.__init__(self, pool, url, message) +        self.retries = retries + + +class TimeoutStateError(HTTPError): +    """ Raised when passing an invalid state to a timeout """ +    pass + + +class TimeoutError(HTTPError): +    """ Raised when a socket timeout error occurs. + +    Catching this error will catch both :exc:`ReadTimeoutErrors +    <ReadTimeoutError>` and :exc:`ConnectTimeoutErrors <ConnectTimeoutError>`. +    """ +    pass + + +class ReadTimeoutError(TimeoutError, RequestError): +    "Raised when a socket timeout occurs while receiving data from a server" +    pass + + +# This timeout error does not have a URL attached and needs to inherit from the +# base HTTPError +class ConnectTimeoutError(TimeoutError): +    "Raised when a socket timeout occurs while connecting to a server" +    pass + + +class NewConnectionError(ConnectTimeoutError, PoolError): +    "Raised when we fail to establish a new connection. Usually ECONNREFUSED." +    pass + + +class EmptyPoolError(PoolError): +    "Raised when a pool runs out of connections and no more are allowed." +    pass + + +class ClosedPoolError(PoolError): +    "Raised when a request enters a pool after the pool has been closed." +    pass + + +class LocationValueError(ValueError, HTTPError): +    "Raised when there is something wrong with a given URL input." +    pass + + +class LocationParseError(LocationValueError): +    "Raised when get_host or similar fails to parse the URL input." + +    def __init__(self, location): +        message = "Failed to parse: %s" % location +        HTTPError.__init__(self, message) + +        self.location = location + + +class ResponseError(HTTPError): +    "Used as a container for an error reason supplied in a MaxRetryError." +    GENERIC_ERROR = 'too many error responses' +    SPECIFIC_ERROR = 'too many {status_code} error responses' + + +class SecurityWarning(HTTPWarning): +    "Warned when performing security reducing actions" +    pass + + +class SubjectAltNameWarning(SecurityWarning): +    "Warned when connecting to a host with a certificate missing a SAN." +    pass + + +class InsecureRequestWarning(SecurityWarning): +    "Warned when making an unverified HTTPS request." +    pass + + +class SystemTimeWarning(SecurityWarning): +    "Warned when system time is suspected to be wrong" +    pass + + +class InsecurePlatformWarning(SecurityWarning): +    "Warned when certain SSL configuration is not available on a platform." +    pass + + +class SNIMissingWarning(HTTPWarning): +    "Warned when making a HTTPS request without SNI available." +    pass + + +class DependencyWarning(HTTPWarning): +    """ +    Warned when an attempt is made to import a module with missing optional +    dependencies. +    """ +    pass + + +class ResponseNotChunked(ProtocolError, ValueError): +    "Response needs to be chunked in order to read it as chunks." +    pass + + +class BodyNotHttplibCompatible(HTTPError): +    """ +    Body should be httplib.HTTPResponse like (have an fp attribute which +    returns raw chunks) for read_chunked(). +    """ +    pass + + +class IncompleteRead(HTTPError, httplib_IncompleteRead): +    """ +    Response length doesn't match expected Content-Length + +    Subclass of http_client.IncompleteRead to allow int value +    for `partial` to avoid creating large objects on streamed +    reads. +    """ +    def __init__(self, partial, expected): +        super(IncompleteRead, self).__init__(partial, expected) + +    def __repr__(self): +        return ('IncompleteRead(%i bytes read, ' +                '%i more expected)' % (self.partial, self.expected)) + + +class InvalidHeader(HTTPError): +    "The header provided was somehow invalid." +    pass + + +class ProxySchemeUnknown(AssertionError, ValueError): +    "ProxyManager does not support the supplied scheme" +    # TODO(t-8ch): Stop inheriting from AssertionError in v2.0. + +    def __init__(self, scheme): +        message = "Not supported proxy scheme %s" % scheme +        super(ProxySchemeUnknown, self).__init__(message) + + +class HeaderParsingError(HTTPError): +    "Raised by assert_header_parsing, but we convert it to a log.warning statement." +    def __init__(self, defects, unparsed_data): +        message = '%s, unparsed data: %r' % (defects or 'Unknown', unparsed_data) +        super(HeaderParsingError, self).__init__(message) + + +class UnrewindableBodyError(HTTPError): +    "urllib3 encountered an error when trying to rewind a body" +    pass diff --git a/python/urllib3/fields.py b/python/urllib3/fields.py new file mode 100644 index 0000000..37fe64a --- /dev/null +++ b/python/urllib3/fields.py @@ -0,0 +1,178 @@ +from __future__ import absolute_import +import email.utils +import mimetypes + +from .packages import six + + +def guess_content_type(filename, default='application/octet-stream'): +    """ +    Guess the "Content-Type" of a file. + +    :param filename: +        The filename to guess the "Content-Type" of using :mod:`mimetypes`. +    :param default: +        If no "Content-Type" can be guessed, default to `default`. +    """ +    if filename: +        return mimetypes.guess_type(filename)[0] or default +    return default + + +def format_header_param(name, value): +    """ +    Helper function to format and quote a single header parameter. + +    Particularly useful for header parameters which might contain +    non-ASCII values, like file names. This follows RFC 2231, as +    suggested by RFC 2388 Section 4.4. + +    :param name: +        The name of the parameter, a string expected to be ASCII only. +    :param value: +        The value of the parameter, provided as a unicode string. +    """ +    if not any(ch in value for ch in '"\\\r\n'): +        result = '%s="%s"' % (name, value) +        try: +            result.encode('ascii') +        except (UnicodeEncodeError, UnicodeDecodeError): +            pass +        else: +            return result +    if not six.PY3 and isinstance(value, six.text_type):  # Python 2: +        value = value.encode('utf-8') +    value = email.utils.encode_rfc2231(value, 'utf-8') +    value = '%s*=%s' % (name, value) +    return value + + +class RequestField(object): +    """ +    A data container for request body parameters. + +    :param name: +        The name of this request field. +    :param data: +        The data/value body. +    :param filename: +        An optional filename of the request field. +    :param headers: +        An optional dict-like object of headers to initially use for the field. +    """ +    def __init__(self, name, data, filename=None, headers=None): +        self._name = name +        self._filename = filename +        self.data = data +        self.headers = {} +        if headers: +            self.headers = dict(headers) + +    @classmethod +    def from_tuples(cls, fieldname, value): +        """ +        A :class:`~urllib3.fields.RequestField` factory from old-style tuple parameters. + +        Supports constructing :class:`~urllib3.fields.RequestField` from +        parameter of key/value strings AND key/filetuple. A filetuple is a +        (filename, data, MIME type) tuple where the MIME type is optional. +        For example:: + +            'foo': 'bar', +            'fakefile': ('foofile.txt', 'contents of foofile'), +            'realfile': ('barfile.txt', open('realfile').read()), +            'typedfile': ('bazfile.bin', open('bazfile').read(), 'image/jpeg'), +            'nonamefile': 'contents of nonamefile field', + +        Field names and filenames must be unicode. +        """ +        if isinstance(value, tuple): +            if len(value) == 3: +                filename, data, content_type = value +            else: +                filename, data = value +                content_type = guess_content_type(filename) +        else: +            filename = None +            content_type = None +            data = value + +        request_param = cls(fieldname, data, filename=filename) +        request_param.make_multipart(content_type=content_type) + +        return request_param + +    def _render_part(self, name, value): +        """ +        Overridable helper function to format a single header parameter. + +        :param name: +            The name of the parameter, a string expected to be ASCII only. +        :param value: +            The value of the parameter, provided as a unicode string. +        """ +        return format_header_param(name, value) + +    def _render_parts(self, header_parts): +        """ +        Helper function to format and quote a single header. + +        Useful for single headers that are composed of multiple items. E.g., +        'Content-Disposition' fields. + +        :param header_parts: +            A sequence of (k, v) tuples or a :class:`dict` of (k, v) to format +            as `k1="v1"; k2="v2"; ...`. +        """ +        parts = [] +        iterable = header_parts +        if isinstance(header_parts, dict): +            iterable = header_parts.items() + +        for name, value in iterable: +            if value is not None: +                parts.append(self._render_part(name, value)) + +        return '; '.join(parts) + +    def render_headers(self): +        """ +        Renders the headers for this request field. +        """ +        lines = [] + +        sort_keys = ['Content-Disposition', 'Content-Type', 'Content-Location'] +        for sort_key in sort_keys: +            if self.headers.get(sort_key, False): +                lines.append('%s: %s' % (sort_key, self.headers[sort_key])) + +        for header_name, header_value in self.headers.items(): +            if header_name not in sort_keys: +                if header_value: +                    lines.append('%s: %s' % (header_name, header_value)) + +        lines.append('\r\n') +        return '\r\n'.join(lines) + +    def make_multipart(self, content_disposition=None, content_type=None, +                       content_location=None): +        """ +        Makes this request field into a multipart request field. + +        This method overrides "Content-Disposition", "Content-Type" and +        "Content-Location" headers to the request parameter. + +        :param content_type: +            The 'Content-Type' of the request body. +        :param content_location: +            The 'Content-Location' of the request body. + +        """ +        self.headers['Content-Disposition'] = content_disposition or 'form-data' +        self.headers['Content-Disposition'] += '; '.join([ +            '', self._render_parts( +                (('name', self._name), ('filename', self._filename)) +            ) +        ]) +        self.headers['Content-Type'] = content_type +        self.headers['Content-Location'] = content_location diff --git a/python/urllib3/filepost.py b/python/urllib3/filepost.py new file mode 100644 index 0000000..78f1e19 --- /dev/null +++ b/python/urllib3/filepost.py @@ -0,0 +1,98 @@ +from __future__ import absolute_import +import binascii +import codecs +import os + +from io import BytesIO + +from .packages import six +from .packages.six import b +from .fields import RequestField + +writer = codecs.lookup('utf-8')[3] + + +def choose_boundary(): +    """ +    Our embarrassingly-simple replacement for mimetools.choose_boundary. +    """ +    boundary = binascii.hexlify(os.urandom(16)) +    if six.PY3: +        boundary = boundary.decode('ascii') +    return boundary + + +def iter_field_objects(fields): +    """ +    Iterate over fields. + +    Supports list of (k, v) tuples and dicts, and lists of +    :class:`~urllib3.fields.RequestField`. + +    """ +    if isinstance(fields, dict): +        i = six.iteritems(fields) +    else: +        i = iter(fields) + +    for field in i: +        if isinstance(field, RequestField): +            yield field +        else: +            yield RequestField.from_tuples(*field) + + +def iter_fields(fields): +    """ +    .. deprecated:: 1.6 + +    Iterate over fields. + +    The addition of :class:`~urllib3.fields.RequestField` makes this function +    obsolete. Instead, use :func:`iter_field_objects`, which returns +    :class:`~urllib3.fields.RequestField` objects. + +    Supports list of (k, v) tuples and dicts. +    """ +    if isinstance(fields, dict): +        return ((k, v) for k, v in six.iteritems(fields)) + +    return ((k, v) for k, v in fields) + + +def encode_multipart_formdata(fields, boundary=None): +    """ +    Encode a dictionary of ``fields`` using the multipart/form-data MIME format. + +    :param fields: +        Dictionary of fields or list of (key, :class:`~urllib3.fields.RequestField`). + +    :param boundary: +        If not specified, then a random boundary will be generated using +        :func:`urllib3.filepost.choose_boundary`. +    """ +    body = BytesIO() +    if boundary is None: +        boundary = choose_boundary() + +    for field in iter_field_objects(fields): +        body.write(b('--%s\r\n' % (boundary))) + +        writer(body).write(field.render_headers()) +        data = field.data + +        if isinstance(data, int): +            data = str(data)  # Backwards compatibility + +        if isinstance(data, six.text_type): +            writer(body).write(data) +        else: +            body.write(data) + +        body.write(b'\r\n') + +    body.write(b('--%s--\r\n' % (boundary))) + +    content_type = str('multipart/form-data; boundary=%s' % boundary) + +    return body.getvalue(), content_type diff --git a/python/urllib3/packages/__init__.py b/python/urllib3/packages/__init__.py new file mode 100644 index 0000000..170e974 --- /dev/null +++ b/python/urllib3/packages/__init__.py @@ -0,0 +1,5 @@ +from __future__ import absolute_import + +from . import ssl_match_hostname + +__all__ = ('ssl_match_hostname', ) diff --git a/python/urllib3/packages/backports/__init__.py b/python/urllib3/packages/backports/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/python/urllib3/packages/backports/__init__.py diff --git a/python/urllib3/packages/backports/makefile.py b/python/urllib3/packages/backports/makefile.py new file mode 100644 index 0000000..740db37 --- /dev/null +++ b/python/urllib3/packages/backports/makefile.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +""" +backports.makefile +~~~~~~~~~~~~~~~~~~ + +Backports the Python 3 ``socket.makefile`` method for use with anything that +wants to create a "fake" socket object. +""" +import io + +from socket import SocketIO + + +def backport_makefile(self, mode="r", buffering=None, encoding=None, +                      errors=None, newline=None): +    """ +    Backport of ``socket.makefile`` from Python 3.5. +    """ +    if not set(mode) <= {"r", "w", "b"}: +        raise ValueError( +            "invalid mode %r (only r, w, b allowed)" % (mode,) +        ) +    writing = "w" in mode +    reading = "r" in mode or not writing +    assert reading or writing +    binary = "b" in mode +    rawmode = "" +    if reading: +        rawmode += "r" +    if writing: +        rawmode += "w" +    raw = SocketIO(self, rawmode) +    self._makefile_refs += 1 +    if buffering is None: +        buffering = -1 +    if buffering < 0: +        buffering = io.DEFAULT_BUFFER_SIZE +    if buffering == 0: +        if not binary: +            raise ValueError("unbuffered streams must be binary") +        return raw +    if reading and writing: +        buffer = io.BufferedRWPair(raw, raw, buffering) +    elif reading: +        buffer = io.BufferedReader(raw, buffering) +    else: +        assert writing +        buffer = io.BufferedWriter(raw, buffering) +    if binary: +        return buffer +    text = io.TextIOWrapper(buffer, encoding, errors, newline) +    text.mode = mode +    return text diff --git a/python/urllib3/packages/six.py b/python/urllib3/packages/six.py new file mode 100644 index 0000000..190c023 --- /dev/null +++ b/python/urllib3/packages/six.py @@ -0,0 +1,868 @@ +"""Utilities for writing code that runs on Python 2 and 3""" + +# Copyright (c) 2010-2015 Benjamin Peterson +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from __future__ import absolute_import + +import functools +import itertools +import operator +import sys +import types + +__author__ = "Benjamin Peterson <benjamin@python.org>" +__version__ = "1.10.0" + + +# Useful for very coarse version differentiation. +PY2 = sys.version_info[0] == 2 +PY3 = sys.version_info[0] == 3 +PY34 = sys.version_info[0:2] >= (3, 4) + +if PY3: +    string_types = str, +    integer_types = int, +    class_types = type, +    text_type = str +    binary_type = bytes + +    MAXSIZE = sys.maxsize +else: +    string_types = basestring, +    integer_types = (int, long) +    class_types = (type, types.ClassType) +    text_type = unicode +    binary_type = str + +    if sys.platform.startswith("java"): +        # Jython always uses 32 bits. +        MAXSIZE = int((1 << 31) - 1) +    else: +        # It's possible to have sizeof(long) != sizeof(Py_ssize_t). +        class X(object): + +            def __len__(self): +                return 1 << 31 +        try: +            len(X()) +        except OverflowError: +            # 32-bit +            MAXSIZE = int((1 << 31) - 1) +        else: +            # 64-bit +            MAXSIZE = int((1 << 63) - 1) +        del X + + +def _add_doc(func, doc): +    """Add documentation to a function.""" +    func.__doc__ = doc + + +def _import_module(name): +    """Import module, returning the module after the last dot.""" +    __import__(name) +    return sys.modules[name] + + +class _LazyDescr(object): + +    def __init__(self, name): +        self.name = name + +    def __get__(self, obj, tp): +        result = self._resolve() +        setattr(obj, self.name, result)  # Invokes __set__. +        try: +            # This is a bit ugly, but it avoids running this again by +            # removing this descriptor. +            delattr(obj.__class__, self.name) +        except AttributeError: +            pass +        return result + + +class MovedModule(_LazyDescr): + +    def __init__(self, name, old, new=None): +        super(MovedModule, self).__init__(name) +        if PY3: +            if new is None: +                new = name +            self.mod = new +        else: +            self.mod = old + +    def _resolve(self): +        return _import_module(self.mod) + +    def __getattr__(self, attr): +        _module = self._resolve() +        value = getattr(_module, attr) +        setattr(self, attr, value) +        return value + + +class _LazyModule(types.ModuleType): + +    def __init__(self, name): +        super(_LazyModule, self).__init__(name) +        self.__doc__ = self.__class__.__doc__ + +    def __dir__(self): +        attrs = ["__doc__", "__name__"] +        attrs += [attr.name for attr in self._moved_attributes] +        return attrs + +    # Subclasses should override this +    _moved_attributes = [] + + +class MovedAttribute(_LazyDescr): + +    def __init__(self, name, old_mod, new_mod, old_attr=None, new_attr=None): +        super(MovedAttribute, self).__init__(name) +        if PY3: +            if new_mod is None: +                new_mod = name +            self.mod = new_mod +            if new_attr is None: +                if old_attr is None: +                    new_attr = name +                else: +                    new_attr = old_attr +            self.attr = new_attr +        else: +            self.mod = old_mod +            if old_attr is None: +                old_attr = name +            self.attr = old_attr + +    def _resolve(self): +        module = _import_module(self.mod) +        return getattr(module, self.attr) + + +class _SixMetaPathImporter(object): + +    """ +    A meta path importer to import six.moves and its submodules. + +    This class implements a PEP302 finder and loader. It should be compatible +    with Python 2.5 and all existing versions of Python3 +    """ + +    def __init__(self, six_module_name): +        self.name = six_module_name +        self.known_modules = {} + +    def _add_module(self, mod, *fullnames): +        for fullname in fullnames: +            self.known_modules[self.name + "." + fullname] = mod + +    def _get_module(self, fullname): +        return self.known_modules[self.name + "." + fullname] + +    def find_module(self, fullname, path=None): +        if fullname in self.known_modules: +            return self +        return None + +    def __get_module(self, fullname): +        try: +            return self.known_modules[fullname] +        except KeyError: +            raise ImportError("This loader does not know module " + fullname) + +    def load_module(self, fullname): +        try: +            # in case of a reload +            return sys.modules[fullname] +        except KeyError: +            pass +        mod = self.__get_module(fullname) +        if isinstance(mod, MovedModule): +            mod = mod._resolve() +        else: +            mod.__loader__ = self +        sys.modules[fullname] = mod +        return mod + +    def is_package(self, fullname): +        """ +        Return true, if the named module is a package. + +        We need this method to get correct spec objects with +        Python 3.4 (see PEP451) +        """ +        return hasattr(self.__get_module(fullname), "__path__") + +    def get_code(self, fullname): +        """Return None + +        Required, if is_package is implemented""" +        self.__get_module(fullname)  # eventually raises ImportError +        return None +    get_source = get_code  # same as get_code + +_importer = _SixMetaPathImporter(__name__) + + +class _MovedItems(_LazyModule): + +    """Lazy loading of moved objects""" +    __path__ = []  # mark as package + + +_moved_attributes = [ +    MovedAttribute("cStringIO", "cStringIO", "io", "StringIO"), +    MovedAttribute("filter", "itertools", "builtins", "ifilter", "filter"), +    MovedAttribute("filterfalse", "itertools", "itertools", "ifilterfalse", "filterfalse"), +    MovedAttribute("input", "__builtin__", "builtins", "raw_input", "input"), +    MovedAttribute("intern", "__builtin__", "sys"), +    MovedAttribute("map", "itertools", "builtins", "imap", "map"), +    MovedAttribute("getcwd", "os", "os", "getcwdu", "getcwd"), +    MovedAttribute("getcwdb", "os", "os", "getcwd", "getcwdb"), +    MovedAttribute("range", "__builtin__", "builtins", "xrange", "range"), +    MovedAttribute("reload_module", "__builtin__", "importlib" if PY34 else "imp", "reload"), +    MovedAttribute("reduce", "__builtin__", "functools"), +    MovedAttribute("shlex_quote", "pipes", "shlex", "quote"), +    MovedAttribute("StringIO", "StringIO", "io"), +    MovedAttribute("UserDict", "UserDict", "collections"), +    MovedAttribute("UserList", "UserList", "collections"), +    MovedAttribute("UserString", "UserString", "collections"), +    MovedAttribute("xrange", "__builtin__", "builtins", "xrange", "range"), +    MovedAttribute("zip", "itertools", "builtins", "izip", "zip"), +    MovedAttribute("zip_longest", "itertools", "itertools", "izip_longest", "zip_longest"), +    MovedModule("builtins", "__builtin__"), +    MovedModule("configparser", "ConfigParser"), +    MovedModule("copyreg", "copy_reg"), +    MovedModule("dbm_gnu", "gdbm", "dbm.gnu"), +    MovedModule("_dummy_thread", "dummy_thread", "_dummy_thread"), +    MovedModule("http_cookiejar", "cookielib", "http.cookiejar"), +    MovedModule("http_cookies", "Cookie", "http.cookies"), +    MovedModule("html_entities", "htmlentitydefs", "html.entities"), +    MovedModule("html_parser", "HTMLParser", "html.parser"), +    MovedModule("http_client", "httplib", "http.client"), +    MovedModule("email_mime_multipart", "email.MIMEMultipart", "email.mime.multipart"), +    MovedModule("email_mime_nonmultipart", "email.MIMENonMultipart", "email.mime.nonmultipart"), +    MovedModule("email_mime_text", "email.MIMEText", "email.mime.text"), +    MovedModule("email_mime_base", "email.MIMEBase", "email.mime.base"), +    MovedModule("BaseHTTPServer", "BaseHTTPServer", "http.server"), +    MovedModule("CGIHTTPServer", "CGIHTTPServer", "http.server"), +    MovedModule("SimpleHTTPServer", "SimpleHTTPServer", "http.server"), +    MovedModule("cPickle", "cPickle", "pickle"), +    MovedModule("queue", "Queue"), +    MovedModule("reprlib", "repr"), +    MovedModule("socketserver", "SocketServer"), +    MovedModule("_thread", "thread", "_thread"), +    MovedModule("tkinter", "Tkinter"), +    MovedModule("tkinter_dialog", "Dialog", "tkinter.dialog"), +    MovedModule("tkinter_filedialog", "FileDialog", "tkinter.filedialog"), +    MovedModule("tkinter_scrolledtext", "ScrolledText", "tkinter.scrolledtext"), +    MovedModule("tkinter_simpledialog", "SimpleDialog", "tkinter.simpledialog"), +    MovedModule("tkinter_tix", "Tix", "tkinter.tix"), +    MovedModule("tkinter_ttk", "ttk", "tkinter.ttk"), +    MovedModule("tkinter_constants", "Tkconstants", "tkinter.constants"), +    MovedModule("tkinter_dnd", "Tkdnd", "tkinter.dnd"), +    MovedModule("tkinter_colorchooser", "tkColorChooser", +                "tkinter.colorchooser"), +    MovedModule("tkinter_commondialog", "tkCommonDialog", +                "tkinter.commondialog"), +    MovedModule("tkinter_tkfiledialog", "tkFileDialog", "tkinter.filedialog"), +    MovedModule("tkinter_font", "tkFont", "tkinter.font"), +    MovedModule("tkinter_messagebox", "tkMessageBox", "tkinter.messagebox"), +    MovedModule("tkinter_tksimpledialog", "tkSimpleDialog", +                "tkinter.simpledialog"), +    MovedModule("urllib_parse", __name__ + ".moves.urllib_parse", "urllib.parse"), +    MovedModule("urllib_error", __name__ + ".moves.urllib_error", "urllib.error"), +    MovedModule("urllib", __name__ + ".moves.urllib", __name__ + ".moves.urllib"), +    MovedModule("urllib_robotparser", "robotparser", "urllib.robotparser"), +    MovedModule("xmlrpc_client", "xmlrpclib", "xmlrpc.client"), +    MovedModule("xmlrpc_server", "SimpleXMLRPCServer", "xmlrpc.server"), +] +# Add windows specific modules. +if sys.platform == "win32": +    _moved_attributes += [ +        MovedModule("winreg", "_winreg"), +    ] + +for attr in _moved_attributes: +    setattr(_MovedItems, attr.name, attr) +    if isinstance(attr, MovedModule): +        _importer._add_module(attr, "moves." + attr.name) +del attr + +_MovedItems._moved_attributes = _moved_attributes + +moves = _MovedItems(__name__ + ".moves") +_importer._add_module(moves, "moves") + + +class Module_six_moves_urllib_parse(_LazyModule): + +    """Lazy loading of moved objects in six.moves.urllib_parse""" + + +_urllib_parse_moved_attributes = [ +    MovedAttribute("ParseResult", "urlparse", "urllib.parse"), +    MovedAttribute("SplitResult", "urlparse", "urllib.parse"), +    MovedAttribute("parse_qs", "urlparse", "urllib.parse"), +    MovedAttribute("parse_qsl", "urlparse", "urllib.parse"), +    MovedAttribute("urldefrag", "urlparse", "urllib.parse"), +    MovedAttribute("urljoin", "urlparse", "urllib.parse"), +    MovedAttribute("urlparse", "urlparse", "urllib.parse"), +    MovedAttribute("urlsplit", "urlparse", "urllib.parse"), +    MovedAttribute("urlunparse", "urlparse", "urllib.parse"), +    MovedAttribute("urlunsplit", "urlparse", "urllib.parse"), +    MovedAttribute("quote", "urllib", "urllib.parse"), +    MovedAttribute("quote_plus", "urllib", "urllib.parse"), +    MovedAttribute("unquote", "urllib", "urllib.parse"), +    MovedAttribute("unquote_plus", "urllib", "urllib.parse"), +    MovedAttribute("urlencode", "urllib", "urllib.parse"), +    MovedAttribute("splitquery", "urllib", "urllib.parse"), +    MovedAttribute("splittag", "urllib", "urllib.parse"), +    MovedAttribute("splituser", "urllib", "urllib.parse"), +    MovedAttribute("uses_fragment", "urlparse", "urllib.parse"), +    MovedAttribute("uses_netloc", "urlparse", "urllib.parse"), +    MovedAttribute("uses_params", "urlparse", "urllib.parse"), +    MovedAttribute("uses_query", "urlparse", "urllib.parse"), +    MovedAttribute("uses_relative", "urlparse", "urllib.parse"), +] +for attr in _urllib_parse_moved_attributes: +    setattr(Module_six_moves_urllib_parse, attr.name, attr) +del attr + +Module_six_moves_urllib_parse._moved_attributes = _urllib_parse_moved_attributes + +_importer._add_module(Module_six_moves_urllib_parse(__name__ + ".moves.urllib_parse"), +                      "moves.urllib_parse", "moves.urllib.parse") + + +class Module_six_moves_urllib_error(_LazyModule): + +    """Lazy loading of moved objects in six.moves.urllib_error""" + + +_urllib_error_moved_attributes = [ +    MovedAttribute("URLError", "urllib2", "urllib.error"), +    MovedAttribute("HTTPError", "urllib2", "urllib.error"), +    MovedAttribute("ContentTooShortError", "urllib", "urllib.error"), +] +for attr in _urllib_error_moved_attributes: +    setattr(Module_six_moves_urllib_error, attr.name, attr) +del attr + +Module_six_moves_urllib_error._moved_attributes = _urllib_error_moved_attributes + +_importer._add_module(Module_six_moves_urllib_error(__name__ + ".moves.urllib.error"), +                      "moves.urllib_error", "moves.urllib.error") + + +class Module_six_moves_urllib_request(_LazyModule): + +    """Lazy loading of moved objects in six.moves.urllib_request""" + + +_urllib_request_moved_attributes = [ +    MovedAttribute("urlopen", "urllib2", "urllib.request"), +    MovedAttribute("install_opener", "urllib2", "urllib.request"), +    MovedAttribute("build_opener", "urllib2", "urllib.request"), +    MovedAttribute("pathname2url", "urllib", "urllib.request"), +    MovedAttribute("url2pathname", "urllib", "urllib.request"), +    MovedAttribute("getproxies", "urllib", "urllib.request"), +    MovedAttribute("Request", "urllib2", "urllib.request"), +    MovedAttribute("OpenerDirector", "urllib2", "urllib.request"), +    MovedAttribute("HTTPDefaultErrorHandler", "urllib2", "urllib.request"), +    MovedAttribute("HTTPRedirectHandler", "urllib2", "urllib.request"), +    MovedAttribute("HTTPCookieProcessor", "urllib2", "urllib.request"), +    MovedAttribute("ProxyHandler", "urllib2", "urllib.request"), +    MovedAttribute("BaseHandler", "urllib2", "urllib.request"), +    MovedAttribute("HTTPPasswordMgr", "urllib2", "urllib.request"), +    MovedAttribute("HTTPPasswordMgrWithDefaultRealm", "urllib2", "urllib.request"), +    MovedAttribute("AbstractBasicAuthHandler", "urllib2", "urllib.request"), +    MovedAttribute("HTTPBasicAuthHandler", "urllib2", "urllib.request"), +    MovedAttribute("ProxyBasicAuthHandler", "urllib2", "urllib.request"), +    MovedAttribute("AbstractDigestAuthHandler", "urllib2", "urllib.request"), +    MovedAttribute("HTTPDigestAuthHandler", "urllib2", "urllib.request"), +    MovedAttribute("ProxyDigestAuthHandler", "urllib2", "urllib.request"), +    MovedAttribute("HTTPHandler", "urllib2", "urllib.request"), +    MovedAttribute("HTTPSHandler", "urllib2", "urllib.request"), +    MovedAttribute("FileHandler", "urllib2", "urllib.request"), +    MovedAttribute("FTPHandler", "urllib2", "urllib.request"), +    MovedAttribute("CacheFTPHandler", "urllib2", "urllib.request"), +    MovedAttribute("UnknownHandler", "urllib2", "urllib.request"), +    MovedAttribute("HTTPErrorProcessor", "urllib2", "urllib.request"), +    MovedAttribute("urlretrieve", "urllib", "urllib.request"), +    MovedAttribute("urlcleanup", "urllib", "urllib.request"), +    MovedAttribute("URLopener", "urllib", "urllib.request"), +    MovedAttribute("FancyURLopener", "urllib", "urllib.request"), +    MovedAttribute("proxy_bypass", "urllib", "urllib.request"), +] +for attr in _urllib_request_moved_attributes: +    setattr(Module_six_moves_urllib_request, attr.name, attr) +del attr + +Module_six_moves_urllib_request._moved_attributes = _urllib_request_moved_attributes + +_importer._add_module(Module_six_moves_urllib_request(__name__ + ".moves.urllib.request"), +                      "moves.urllib_request", "moves.urllib.request") + + +class Module_six_moves_urllib_response(_LazyModule): + +    """Lazy loading of moved objects in six.moves.urllib_response""" + + +_urllib_response_moved_attributes = [ +    MovedAttribute("addbase", "urllib", "urllib.response"), +    MovedAttribute("addclosehook", "urllib", "urllib.response"), +    MovedAttribute("addinfo", "urllib", "urllib.response"), +    MovedAttribute("addinfourl", "urllib", "urllib.response"), +] +for attr in _urllib_response_moved_attributes: +    setattr(Module_six_moves_urllib_response, attr.name, attr) +del attr + +Module_six_moves_urllib_response._moved_attributes = _urllib_response_moved_attributes + +_importer._add_module(Module_six_moves_urllib_response(__name__ + ".moves.urllib.response"), +                      "moves.urllib_response", "moves.urllib.response") + + +class Module_six_moves_urllib_robotparser(_LazyModule): + +    """Lazy loading of moved objects in six.moves.urllib_robotparser""" + + +_urllib_robotparser_moved_attributes = [ +    MovedAttribute("RobotFileParser", "robotparser", "urllib.robotparser"), +] +for attr in _urllib_robotparser_moved_attributes: +    setattr(Module_six_moves_urllib_robotparser, attr.name, attr) +del attr + +Module_six_moves_urllib_robotparser._moved_attributes = _urllib_robotparser_moved_attributes + +_importer._add_module(Module_six_moves_urllib_robotparser(__name__ + ".moves.urllib.robotparser"), +                      "moves.urllib_robotparser", "moves.urllib.robotparser") + + +class Module_six_moves_urllib(types.ModuleType): + +    """Create a six.moves.urllib namespace that resembles the Python 3 namespace""" +    __path__ = []  # mark as package +    parse = _importer._get_module("moves.urllib_parse") +    error = _importer._get_module("moves.urllib_error") +    request = _importer._get_module("moves.urllib_request") +    response = _importer._get_module("moves.urllib_response") +    robotparser = _importer._get_module("moves.urllib_robotparser") + +    def __dir__(self): +        return ['parse', 'error', 'request', 'response', 'robotparser'] + +_importer._add_module(Module_six_moves_urllib(__name__ + ".moves.urllib"), +                      "moves.urllib") + + +def add_move(move): +    """Add an item to six.moves.""" +    setattr(_MovedItems, move.name, move) + + +def remove_move(name): +    """Remove item from six.moves.""" +    try: +        delattr(_MovedItems, name) +    except AttributeError: +        try: +            del moves.__dict__[name] +        except KeyError: +            raise AttributeError("no such move, %r" % (name,)) + + +if PY3: +    _meth_func = "__func__" +    _meth_self = "__self__" + +    _func_closure = "__closure__" +    _func_code = "__code__" +    _func_defaults = "__defaults__" +    _func_globals = "__globals__" +else: +    _meth_func = "im_func" +    _meth_self = "im_self" + +    _func_closure = "func_closure" +    _func_code = "func_code" +    _func_defaults = "func_defaults" +    _func_globals = "func_globals" + + +try: +    advance_iterator = next +except NameError: +    def advance_iterator(it): +        return it.next() +next = advance_iterator + + +try: +    callable = callable +except NameError: +    def callable(obj): +        return any("__call__" in klass.__dict__ for klass in type(obj).__mro__) + + +if PY3: +    def get_unbound_function(unbound): +        return unbound + +    create_bound_method = types.MethodType + +    def create_unbound_method(func, cls): +        return func + +    Iterator = object +else: +    def get_unbound_function(unbound): +        return unbound.im_func + +    def create_bound_method(func, obj): +        return types.MethodType(func, obj, obj.__class__) + +    def create_unbound_method(func, cls): +        return types.MethodType(func, None, cls) + +    class Iterator(object): + +        def next(self): +            return type(self).__next__(self) + +    callable = callable +_add_doc(get_unbound_function, +         """Get the function out of a possibly unbound function""") + + +get_method_function = operator.attrgetter(_meth_func) +get_method_self = operator.attrgetter(_meth_self) +get_function_closure = operator.attrgetter(_func_closure) +get_function_code = operator.attrgetter(_func_code) +get_function_defaults = operator.attrgetter(_func_defaults) +get_function_globals = operator.attrgetter(_func_globals) + + +if PY3: +    def iterkeys(d, **kw): +        return iter(d.keys(**kw)) + +    def itervalues(d, **kw): +        return iter(d.values(**kw)) + +    def iteritems(d, **kw): +        return iter(d.items(**kw)) + +    def iterlists(d, **kw): +        return iter(d.lists(**kw)) + +    viewkeys = operator.methodcaller("keys") + +    viewvalues = operator.methodcaller("values") + +    viewitems = operator.methodcaller("items") +else: +    def iterkeys(d, **kw): +        return d.iterkeys(**kw) + +    def itervalues(d, **kw): +        return d.itervalues(**kw) + +    def iteritems(d, **kw): +        return d.iteritems(**kw) + +    def iterlists(d, **kw): +        return d.iterlists(**kw) + +    viewkeys = operator.methodcaller("viewkeys") + +    viewvalues = operator.methodcaller("viewvalues") + +    viewitems = operator.methodcaller("viewitems") + +_add_doc(iterkeys, "Return an iterator over the keys of a dictionary.") +_add_doc(itervalues, "Return an iterator over the values of a dictionary.") +_add_doc(iteritems, +         "Return an iterator over the (key, value) pairs of a dictionary.") +_add_doc(iterlists, +         "Return an iterator over the (key, [values]) pairs of a dictionary.") + + +if PY3: +    def b(s): +        return s.encode("latin-1") + +    def u(s): +        return s +    unichr = chr +    import struct +    int2byte = struct.Struct(">B").pack +    del struct +    byte2int = operator.itemgetter(0) +    indexbytes = operator.getitem +    iterbytes = iter +    import io +    StringIO = io.StringIO +    BytesIO = io.BytesIO +    _assertCountEqual = "assertCountEqual" +    if sys.version_info[1] <= 1: +        _assertRaisesRegex = "assertRaisesRegexp" +        _assertRegex = "assertRegexpMatches" +    else: +        _assertRaisesRegex = "assertRaisesRegex" +        _assertRegex = "assertRegex" +else: +    def b(s): +        return s +    # Workaround for standalone backslash + +    def u(s): +        return unicode(s.replace(r'\\', r'\\\\'), "unicode_escape") +    unichr = unichr +    int2byte = chr + +    def byte2int(bs): +        return ord(bs[0]) + +    def indexbytes(buf, i): +        return ord(buf[i]) +    iterbytes = functools.partial(itertools.imap, ord) +    import StringIO +    StringIO = BytesIO = StringIO.StringIO +    _assertCountEqual = "assertItemsEqual" +    _assertRaisesRegex = "assertRaisesRegexp" +    _assertRegex = "assertRegexpMatches" +_add_doc(b, """Byte literal""") +_add_doc(u, """Text literal""") + + +def assertCountEqual(self, *args, **kwargs): +    return getattr(self, _assertCountEqual)(*args, **kwargs) + + +def assertRaisesRegex(self, *args, **kwargs): +    return getattr(self, _assertRaisesRegex)(*args, **kwargs) + + +def assertRegex(self, *args, **kwargs): +    return getattr(self, _assertRegex)(*args, **kwargs) + + +if PY3: +    exec_ = getattr(moves.builtins, "exec") + +    def reraise(tp, value, tb=None): +        if value is None: +            value = tp() +        if value.__traceback__ is not tb: +            raise value.with_traceback(tb) +        raise value + +else: +    def exec_(_code_, _globs_=None, _locs_=None): +        """Execute code in a namespace.""" +        if _globs_ is None: +            frame = sys._getframe(1) +            _globs_ = frame.f_globals +            if _locs_ is None: +                _locs_ = frame.f_locals +            del frame +        elif _locs_ is None: +            _locs_ = _globs_ +        exec("""exec _code_ in _globs_, _locs_""") + +    exec_("""def reraise(tp, value, tb=None): +    raise tp, value, tb +""") + + +if sys.version_info[:2] == (3, 2): +    exec_("""def raise_from(value, from_value): +    if from_value is None: +        raise value +    raise value from from_value +""") +elif sys.version_info[:2] > (3, 2): +    exec_("""def raise_from(value, from_value): +    raise value from from_value +""") +else: +    def raise_from(value, from_value): +        raise value + + +print_ = getattr(moves.builtins, "print", None) +if print_ is None: +    def print_(*args, **kwargs): +        """The new-style print function for Python 2.4 and 2.5.""" +        fp = kwargs.pop("file", sys.stdout) +        if fp is None: +            return + +        def write(data): +            if not isinstance(data, basestring): +                data = str(data) +            # If the file has an encoding, encode unicode with it. +            if (isinstance(fp, file) and +                    isinstance(data, unicode) and +                    fp.encoding is not None): +                errors = getattr(fp, "errors", None) +                if errors is None: +                    errors = "strict" +                data = data.encode(fp.encoding, errors) +            fp.write(data) +        want_unicode = False +        sep = kwargs.pop("sep", None) +        if sep is not None: +            if isinstance(sep, unicode): +                want_unicode = True +            elif not isinstance(sep, str): +                raise TypeError("sep must be None or a string") +        end = kwargs.pop("end", None) +        if end is not None: +            if isinstance(end, unicode): +                want_unicode = True +            elif not isinstance(end, str): +                raise TypeError("end must be None or a string") +        if kwargs: +            raise TypeError("invalid keyword arguments to print()") +        if not want_unicode: +            for arg in args: +                if isinstance(arg, unicode): +                    want_unicode = True +                    break +        if want_unicode: +            newline = unicode("\n") +            space = unicode(" ") +        else: +            newline = "\n" +            space = " " +        if sep is None: +            sep = space +        if end is None: +            end = newline +        for i, arg in enumerate(args): +            if i: +                write(sep) +            write(arg) +        write(end) +if sys.version_info[:2] < (3, 3): +    _print = print_ + +    def print_(*args, **kwargs): +        fp = kwargs.get("file", sys.stdout) +        flush = kwargs.pop("flush", False) +        _print(*args, **kwargs) +        if flush and fp is not None: +            fp.flush() + +_add_doc(reraise, """Reraise an exception.""") + +if sys.version_info[0:2] < (3, 4): +    def wraps(wrapped, assigned=functools.WRAPPER_ASSIGNMENTS, +              updated=functools.WRAPPER_UPDATES): +        def wrapper(f): +            f = functools.wraps(wrapped, assigned, updated)(f) +            f.__wrapped__ = wrapped +            return f +        return wrapper +else: +    wraps = functools.wraps + + +def with_metaclass(meta, *bases): +    """Create a base class with a metaclass.""" +    # This requires a bit of explanation: the basic idea is to make a dummy +    # metaclass for one level of class instantiation that replaces itself with +    # the actual metaclass. +    class metaclass(meta): + +        def __new__(cls, name, this_bases, d): +            return meta(name, bases, d) +    return type.__new__(metaclass, 'temporary_class', (), {}) + + +def add_metaclass(metaclass): +    """Class decorator for creating a class with a metaclass.""" +    def wrapper(cls): +        orig_vars = cls.__dict__.copy() +        slots = orig_vars.get('__slots__') +        if slots is not None: +            if isinstance(slots, str): +                slots = [slots] +            for slots_var in slots: +                orig_vars.pop(slots_var) +        orig_vars.pop('__dict__', None) +        orig_vars.pop('__weakref__', None) +        return metaclass(cls.__name__, cls.__bases__, orig_vars) +    return wrapper + + +def python_2_unicode_compatible(klass): +    """ +    A decorator that defines __unicode__ and __str__ methods under Python 2. +    Under Python 3 it does nothing. + +    To support Python 2 and 3 with a single code base, define a __str__ method +    returning text and apply this decorator to the class. +    """ +    if PY2: +        if '__str__' not in klass.__dict__: +            raise ValueError("@python_2_unicode_compatible cannot be applied " +                             "to %s because it doesn't define __str__()." % +                             klass.__name__) +        klass.__unicode__ = klass.__str__ +        klass.__str__ = lambda self: self.__unicode__().encode('utf-8') +    return klass + + +# Complete the moves implementation. +# This code is at the end of this module to speed up module loading. +# Turn this module into a package. +__path__ = []  # required for PEP 302 and PEP 451 +__package__ = __name__  # see PEP 366 @ReservedAssignment +if globals().get("__spec__") is not None: +    __spec__.submodule_search_locations = []  # PEP 451 @UndefinedVariable +# Remove other six meta path importers, since they cause problems. This can +# happen if six is removed from sys.modules and then reloaded. (Setuptools does +# this for some reason.) +if sys.meta_path: +    for i, importer in enumerate(sys.meta_path): +        # Here's some real nastiness: Another "instance" of the six module might +        # be floating around. Therefore, we can't use isinstance() to check for +        # the six meta path importer, since the other six instance will have +        # inserted an importer with different class. +        if (type(importer).__name__ == "_SixMetaPathImporter" and +                importer.name == __name__): +            del sys.meta_path[i] +            break +    del i, importer +# Finally, add the importer to the meta path import hook. +sys.meta_path.append(_importer) diff --git a/python/urllib3/packages/ssl_match_hostname/__init__.py b/python/urllib3/packages/ssl_match_hostname/__init__.py new file mode 100644 index 0000000..d6594eb --- /dev/null +++ b/python/urllib3/packages/ssl_match_hostname/__init__.py @@ -0,0 +1,19 @@ +import sys + +try: +    # Our match_hostname function is the same as 3.5's, so we only want to +    # import the match_hostname function if it's at least that good. +    if sys.version_info < (3, 5): +        raise ImportError("Fallback to vendored code") + +    from ssl import CertificateError, match_hostname +except ImportError: +    try: +        # Backport of the function from a pypi module +        from backports.ssl_match_hostname import CertificateError, match_hostname +    except ImportError: +        # Our vendored copy +        from ._implementation import CertificateError, match_hostname + +# Not needed, but documenting what we provide. +__all__ = ('CertificateError', 'match_hostname') diff --git a/python/urllib3/packages/ssl_match_hostname/_implementation.py b/python/urllib3/packages/ssl_match_hostname/_implementation.py new file mode 100644 index 0000000..d6e66c0 --- /dev/null +++ b/python/urllib3/packages/ssl_match_hostname/_implementation.py @@ -0,0 +1,156 @@ +"""The match_hostname() function from Python 3.3.3, essential when using SSL.""" + +# Note: This file is under the PSF license as the code comes from the python +# stdlib.   http://docs.python.org/3/license.html + +import re +import sys + +# ipaddress has been backported to 2.6+ in pypi.  If it is installed on the +# system, use it to handle IPAddress ServerAltnames (this was added in +# python-3.5) otherwise only do DNS matching.  This allows +# backports.ssl_match_hostname to continue to be used in Python 2.7. +try: +    import ipaddress +except ImportError: +    ipaddress = None + +__version__ = '3.5.0.1' + + +class CertificateError(ValueError): +    pass + + +def _dnsname_match(dn, hostname, max_wildcards=1): +    """Matching according to RFC 6125, section 6.4.3 + +    http://tools.ietf.org/html/rfc6125#section-6.4.3 +    """ +    pats = [] +    if not dn: +        return False + +    # Ported from python3-syntax: +    # leftmost, *remainder = dn.split(r'.') +    parts = dn.split(r'.') +    leftmost = parts[0] +    remainder = parts[1:] + +    wildcards = leftmost.count('*') +    if wildcards > max_wildcards: +        # Issue #17980: avoid denials of service by refusing more +        # than one wildcard per fragment.  A survey of established +        # policy among SSL implementations showed it to be a +        # reasonable choice. +        raise CertificateError( +            "too many wildcards in certificate DNS name: " + repr(dn)) + +    # speed up common case w/o wildcards +    if not wildcards: +        return dn.lower() == hostname.lower() + +    # RFC 6125, section 6.4.3, subitem 1. +    # The client SHOULD NOT attempt to match a presented identifier in which +    # the wildcard character comprises a label other than the left-most label. +    if leftmost == '*': +        # When '*' is a fragment by itself, it matches a non-empty dotless +        # fragment. +        pats.append('[^.]+') +    elif leftmost.startswith('xn--') or hostname.startswith('xn--'): +        # RFC 6125, section 6.4.3, subitem 3. +        # The client SHOULD NOT attempt to match a presented identifier +        # where the wildcard character is embedded within an A-label or +        # U-label of an internationalized domain name. +        pats.append(re.escape(leftmost)) +    else: +        # Otherwise, '*' matches any dotless string, e.g. www* +        pats.append(re.escape(leftmost).replace(r'\*', '[^.]*')) + +    # add the remaining fragments, ignore any wildcards +    for frag in remainder: +        pats.append(re.escape(frag)) + +    pat = re.compile(r'\A' + r'\.'.join(pats) + r'\Z', re.IGNORECASE) +    return pat.match(hostname) + + +def _to_unicode(obj): +    if isinstance(obj, str) and sys.version_info < (3,): +        obj = unicode(obj, encoding='ascii', errors='strict') +    return obj + +def _ipaddress_match(ipname, host_ip): +    """Exact matching of IP addresses. + +    RFC 6125 explicitly doesn't define an algorithm for this +    (section 1.7.2 - "Out of Scope"). +    """ +    # OpenSSL may add a trailing newline to a subjectAltName's IP address +    # Divergence from upstream: ipaddress can't handle byte str +    ip = ipaddress.ip_address(_to_unicode(ipname).rstrip()) +    return ip == host_ip + + +def match_hostname(cert, hostname): +    """Verify that *cert* (in decoded format as returned by +    SSLSocket.getpeercert()) matches the *hostname*.  RFC 2818 and RFC 6125 +    rules are followed, but IP addresses are not accepted for *hostname*. + +    CertificateError is raised on failure. On success, the function +    returns nothing. +    """ +    if not cert: +        raise ValueError("empty or no certificate, match_hostname needs a " +                         "SSL socket or SSL context with either " +                         "CERT_OPTIONAL or CERT_REQUIRED") +    try: +        # Divergence from upstream: ipaddress can't handle byte str +        host_ip = ipaddress.ip_address(_to_unicode(hostname)) +    except ValueError: +        # Not an IP address (common case) +        host_ip = None +    except UnicodeError: +        # Divergence from upstream: Have to deal with ipaddress not taking +        # byte strings.  addresses should be all ascii, so we consider it not +        # an ipaddress in this case +        host_ip = None +    except AttributeError: +        # Divergence from upstream: Make ipaddress library optional +        if ipaddress is None: +            host_ip = None +        else: +            raise +    dnsnames = [] +    san = cert.get('subjectAltName', ()) +    for key, value in san: +        if key == 'DNS': +            if host_ip is None and _dnsname_match(value, hostname): +                return +            dnsnames.append(value) +        elif key == 'IP Address': +            if host_ip is not None and _ipaddress_match(value, host_ip): +                return +            dnsnames.append(value) +    if not dnsnames: +        # The subject is only checked when there is no dNSName entry +        # in subjectAltName +        for sub in cert.get('subject', ()): +            for key, value in sub: +                # XXX according to RFC 2818, the most specific Common Name +                # must be used. +                if key == 'commonName': +                    if _dnsname_match(value, hostname): +                        return +                    dnsnames.append(value) +    if len(dnsnames) > 1: +        raise CertificateError("hostname %r " +            "doesn't match either of %s" +            % (hostname, ', '.join(map(repr, dnsnames)))) +    elif len(dnsnames) == 1: +        raise CertificateError("hostname %r " +            "doesn't match %r" +            % (hostname, dnsnames[0])) +    else: +        raise CertificateError("no appropriate commonName or " +            "subjectAltName fields were found") diff --git a/python/urllib3/poolmanager.py b/python/urllib3/poolmanager.py new file mode 100644 index 0000000..fe5491c --- /dev/null +++ b/python/urllib3/poolmanager.py @@ -0,0 +1,450 @@ +from __future__ import absolute_import +import collections +import functools +import logging + +from ._collections import RecentlyUsedContainer +from .connectionpool import HTTPConnectionPool, HTTPSConnectionPool +from .connectionpool import port_by_scheme +from .exceptions import LocationValueError, MaxRetryError, ProxySchemeUnknown +from .packages.six.moves.urllib.parse import urljoin +from .request import RequestMethods +from .util.url import parse_url +from .util.retry import Retry + + +__all__ = ['PoolManager', 'ProxyManager', 'proxy_from_url'] + + +log = logging.getLogger(__name__) + +SSL_KEYWORDS = ('key_file', 'cert_file', 'cert_reqs', 'ca_certs', +                'ssl_version', 'ca_cert_dir', 'ssl_context') + +# All known keyword arguments that could be provided to the pool manager, its +# pools, or the underlying connections. This is used to construct a pool key. +_key_fields = ( +    'key_scheme',  # str +    'key_host',  # str +    'key_port',  # int +    'key_timeout',  # int or float or Timeout +    'key_retries',  # int or Retry +    'key_strict',  # bool +    'key_block',  # bool +    'key_source_address',  # str +    'key_key_file',  # str +    'key_cert_file',  # str +    'key_cert_reqs',  # str +    'key_ca_certs',  # str +    'key_ssl_version',  # str +    'key_ca_cert_dir',  # str +    'key_ssl_context',  # instance of ssl.SSLContext or urllib3.util.ssl_.SSLContext +    'key_maxsize',  # int +    'key_headers',  # dict +    'key__proxy',  # parsed proxy url +    'key__proxy_headers',  # dict +    'key_socket_options',  # list of (level (int), optname (int), value (int or str)) tuples +    'key__socks_options',  # dict +    'key_assert_hostname',  # bool or string +    'key_assert_fingerprint',  # str +    'key_server_hostname', #str +) + +#: The namedtuple class used to construct keys for the connection pool. +#: All custom key schemes should include the fields in this key at a minimum. +PoolKey = collections.namedtuple('PoolKey', _key_fields) + + +def _default_key_normalizer(key_class, request_context): +    """ +    Create a pool key out of a request context dictionary. + +    According to RFC 3986, both the scheme and host are case-insensitive. +    Therefore, this function normalizes both before constructing the pool +    key for an HTTPS request. If you wish to change this behaviour, provide +    alternate callables to ``key_fn_by_scheme``. + +    :param key_class: +        The class to use when constructing the key. This should be a namedtuple +        with the ``scheme`` and ``host`` keys at a minimum. +    :type  key_class: namedtuple +    :param request_context: +        A dictionary-like object that contain the context for a request. +    :type  request_context: dict + +    :return: A namedtuple that can be used as a connection pool key. +    :rtype:  PoolKey +    """ +    # Since we mutate the dictionary, make a copy first +    context = request_context.copy() +    context['scheme'] = context['scheme'].lower() +    context['host'] = context['host'].lower() + +    # These are both dictionaries and need to be transformed into frozensets +    for key in ('headers', '_proxy_headers', '_socks_options'): +        if key in context and context[key] is not None: +            context[key] = frozenset(context[key].items()) + +    # The socket_options key may be a list and needs to be transformed into a +    # tuple. +    socket_opts = context.get('socket_options') +    if socket_opts is not None: +        context['socket_options'] = tuple(socket_opts) + +    # Map the kwargs to the names in the namedtuple - this is necessary since +    # namedtuples can't have fields starting with '_'. +    for key in list(context.keys()): +        context['key_' + key] = context.pop(key) + +    # Default to ``None`` for keys missing from the context +    for field in key_class._fields: +        if field not in context: +            context[field] = None + +    return key_class(**context) + + +#: A dictionary that maps a scheme to a callable that creates a pool key. +#: This can be used to alter the way pool keys are constructed, if desired. +#: Each PoolManager makes a copy of this dictionary so they can be configured +#: globally here, or individually on the instance. +key_fn_by_scheme = { +    'http': functools.partial(_default_key_normalizer, PoolKey), +    'https': functools.partial(_default_key_normalizer, PoolKey), +} + +pool_classes_by_scheme = { +    'http': HTTPConnectionPool, +    'https': HTTPSConnectionPool, +} + + +class PoolManager(RequestMethods): +    """ +    Allows for arbitrary requests while transparently keeping track of +    necessary connection pools for you. + +    :param num_pools: +        Number of connection pools to cache before discarding the least +        recently used pool. + +    :param headers: +        Headers to include with all requests, unless other headers are given +        explicitly. + +    :param \\**connection_pool_kw: +        Additional parameters are used to create fresh +        :class:`urllib3.connectionpool.ConnectionPool` instances. + +    Example:: + +        >>> manager = PoolManager(num_pools=2) +        >>> r = manager.request('GET', 'http://google.com/') +        >>> r = manager.request('GET', 'http://google.com/mail') +        >>> r = manager.request('GET', 'http://yahoo.com/') +        >>> len(manager.pools) +        2 + +    """ + +    proxy = None + +    def __init__(self, num_pools=10, headers=None, **connection_pool_kw): +        RequestMethods.__init__(self, headers) +        self.connection_pool_kw = connection_pool_kw +        self.pools = RecentlyUsedContainer(num_pools, +                                           dispose_func=lambda p: p.close()) + +        # Locally set the pool classes and keys so other PoolManagers can +        # override them. +        self.pool_classes_by_scheme = pool_classes_by_scheme +        self.key_fn_by_scheme = key_fn_by_scheme.copy() + +    def __enter__(self): +        return self + +    def __exit__(self, exc_type, exc_val, exc_tb): +        self.clear() +        # Return False to re-raise any potential exceptions +        return False + +    def _new_pool(self, scheme, host, port, request_context=None): +        """ +        Create a new :class:`ConnectionPool` based on host, port, scheme, and +        any additional pool keyword arguments. + +        If ``request_context`` is provided, it is provided as keyword arguments +        to the pool class used. This method is used to actually create the +        connection pools handed out by :meth:`connection_from_url` and +        companion methods. It is intended to be overridden for customization. +        """ +        pool_cls = self.pool_classes_by_scheme[scheme] +        if request_context is None: +            request_context = self.connection_pool_kw.copy() + +        # Although the context has everything necessary to create the pool, +        # this function has historically only used the scheme, host, and port +        # in the positional args. When an API change is acceptable these can +        # be removed. +        for key in ('scheme', 'host', 'port'): +            request_context.pop(key, None) + +        if scheme == 'http': +            for kw in SSL_KEYWORDS: +                request_context.pop(kw, None) + +        return pool_cls(host, port, **request_context) + +    def clear(self): +        """ +        Empty our store of pools and direct them all to close. + +        This will not affect in-flight connections, but they will not be +        re-used after completion. +        """ +        self.pools.clear() + +    def connection_from_host(self, host, port=None, scheme='http', pool_kwargs=None): +        """ +        Get a :class:`ConnectionPool` based on the host, port, and scheme. + +        If ``port`` isn't given, it will be derived from the ``scheme`` using +        ``urllib3.connectionpool.port_by_scheme``. If ``pool_kwargs`` is +        provided, it is merged with the instance's ``connection_pool_kw`` +        variable and used to create the new connection pool, if one is +        needed. +        """ + +        if not host: +            raise LocationValueError("No host specified.") + +        request_context = self._merge_pool_kwargs(pool_kwargs) +        request_context['scheme'] = scheme or 'http' +        if not port: +            port = port_by_scheme.get(request_context['scheme'].lower(), 80) +        request_context['port'] = port +        request_context['host'] = host + +        return self.connection_from_context(request_context) + +    def connection_from_context(self, request_context): +        """ +        Get a :class:`ConnectionPool` based on the request context. + +        ``request_context`` must at least contain the ``scheme`` key and its +        value must be a key in ``key_fn_by_scheme`` instance variable. +        """ +        scheme = request_context['scheme'].lower() +        pool_key_constructor = self.key_fn_by_scheme[scheme] +        pool_key = pool_key_constructor(request_context) + +        return self.connection_from_pool_key(pool_key, request_context=request_context) + +    def connection_from_pool_key(self, pool_key, request_context=None): +        """ +        Get a :class:`ConnectionPool` based on the provided pool key. + +        ``pool_key`` should be a namedtuple that only contains immutable +        objects. At a minimum it must have the ``scheme``, ``host``, and +        ``port`` fields. +        """ +        with self.pools.lock: +            # If the scheme, host, or port doesn't match existing open +            # connections, open a new ConnectionPool. +            pool = self.pools.get(pool_key) +            if pool: +                return pool + +            # Make a fresh ConnectionPool of the desired type +            scheme = request_context['scheme'] +            host = request_context['host'] +            port = request_context['port'] +            pool = self._new_pool(scheme, host, port, request_context=request_context) +            self.pools[pool_key] = pool + +        return pool + +    def connection_from_url(self, url, pool_kwargs=None): +        """ +        Similar to :func:`urllib3.connectionpool.connection_from_url`. + +        If ``pool_kwargs`` is not provided and a new pool needs to be +        constructed, ``self.connection_pool_kw`` is used to initialize +        the :class:`urllib3.connectionpool.ConnectionPool`. If ``pool_kwargs`` +        is provided, it is used instead. Note that if a new pool does not +        need to be created for the request, the provided ``pool_kwargs`` are +        not used. +        """ +        u = parse_url(url) +        return self.connection_from_host(u.host, port=u.port, scheme=u.scheme, +                                         pool_kwargs=pool_kwargs) + +    def _merge_pool_kwargs(self, override): +        """ +        Merge a dictionary of override values for self.connection_pool_kw. + +        This does not modify self.connection_pool_kw and returns a new dict. +        Any keys in the override dictionary with a value of ``None`` are +        removed from the merged dictionary. +        """ +        base_pool_kwargs = self.connection_pool_kw.copy() +        if override: +            for key, value in override.items(): +                if value is None: +                    try: +                        del base_pool_kwargs[key] +                    except KeyError: +                        pass +                else: +                    base_pool_kwargs[key] = value +        return base_pool_kwargs + +    def urlopen(self, method, url, redirect=True, **kw): +        """ +        Same as :meth:`urllib3.connectionpool.HTTPConnectionPool.urlopen` +        with custom cross-host redirect logic and only sends the request-uri +        portion of the ``url``. + +        The given ``url`` parameter must be absolute, such that an appropriate +        :class:`urllib3.connectionpool.ConnectionPool` can be chosen for it. +        """ +        u = parse_url(url) +        conn = self.connection_from_host(u.host, port=u.port, scheme=u.scheme) + +        kw['assert_same_host'] = False +        kw['redirect'] = False + +        if 'headers' not in kw: +            kw['headers'] = self.headers.copy() + +        if self.proxy is not None and u.scheme == "http": +            response = conn.urlopen(method, url, **kw) +        else: +            response = conn.urlopen(method, u.request_uri, **kw) + +        redirect_location = redirect and response.get_redirect_location() +        if not redirect_location: +            return response + +        # Support relative URLs for redirecting. +        redirect_location = urljoin(url, redirect_location) + +        # RFC 7231, Section 6.4.4 +        if response.status == 303: +            method = 'GET' + +        retries = kw.get('retries') +        if not isinstance(retries, Retry): +            retries = Retry.from_int(retries, redirect=redirect) + +        # Strip headers marked as unsafe to forward to the redirected location. +        # Check remove_headers_on_redirect to avoid a potential network call within +        # conn.is_same_host() which may use socket.gethostbyname() in the future. +        if (retries.remove_headers_on_redirect +                and not conn.is_same_host(redirect_location)): +            for header in retries.remove_headers_on_redirect: +                kw['headers'].pop(header, None) + +        try: +            retries = retries.increment(method, url, response=response, _pool=conn) +        except MaxRetryError: +            if retries.raise_on_redirect: +                raise +            return response + +        kw['retries'] = retries +        kw['redirect'] = redirect + +        log.info("Redirecting %s -> %s", url, redirect_location) +        return self.urlopen(method, redirect_location, **kw) + + +class ProxyManager(PoolManager): +    """ +    Behaves just like :class:`PoolManager`, but sends all requests through +    the defined proxy, using the CONNECT method for HTTPS URLs. + +    :param proxy_url: +        The URL of the proxy to be used. + +    :param proxy_headers: +        A dictionary containing headers that will be sent to the proxy. In case +        of HTTP they are being sent with each request, while in the +        HTTPS/CONNECT case they are sent only once. Could be used for proxy +        authentication. + +    Example: +        >>> proxy = urllib3.ProxyManager('http://localhost:3128/') +        >>> r1 = proxy.request('GET', 'http://google.com/') +        >>> r2 = proxy.request('GET', 'http://httpbin.org/') +        >>> len(proxy.pools) +        1 +        >>> r3 = proxy.request('GET', 'https://httpbin.org/') +        >>> r4 = proxy.request('GET', 'https://twitter.com/') +        >>> len(proxy.pools) +        3 + +    """ + +    def __init__(self, proxy_url, num_pools=10, headers=None, +                 proxy_headers=None, **connection_pool_kw): + +        if isinstance(proxy_url, HTTPConnectionPool): +            proxy_url = '%s://%s:%i' % (proxy_url.scheme, proxy_url.host, +                                        proxy_url.port) +        proxy = parse_url(proxy_url) +        if not proxy.port: +            port = port_by_scheme.get(proxy.scheme, 80) +            proxy = proxy._replace(port=port) + +        if proxy.scheme not in ("http", "https"): +            raise ProxySchemeUnknown(proxy.scheme) + +        self.proxy = proxy +        self.proxy_headers = proxy_headers or {} + +        connection_pool_kw['_proxy'] = self.proxy +        connection_pool_kw['_proxy_headers'] = self.proxy_headers + +        super(ProxyManager, self).__init__( +            num_pools, headers, **connection_pool_kw) + +    def connection_from_host(self, host, port=None, scheme='http', pool_kwargs=None): +        if scheme == "https": +            return super(ProxyManager, self).connection_from_host( +                host, port, scheme, pool_kwargs=pool_kwargs) + +        return super(ProxyManager, self).connection_from_host( +            self.proxy.host, self.proxy.port, self.proxy.scheme, pool_kwargs=pool_kwargs) + +    def _set_proxy_headers(self, url, headers=None): +        """ +        Sets headers needed by proxies: specifically, the Accept and Host +        headers. Only sets headers not provided by the user. +        """ +        headers_ = {'Accept': '*/*'} + +        netloc = parse_url(url).netloc +        if netloc: +            headers_['Host'] = netloc + +        if headers: +            headers_.update(headers) +        return headers_ + +    def urlopen(self, method, url, redirect=True, **kw): +        "Same as HTTP(S)ConnectionPool.urlopen, ``url`` must be absolute." +        u = parse_url(url) + +        if u.scheme == "http": +            # For proxied HTTPS requests, httplib sets the necessary headers +            # on the CONNECT to the proxy. For HTTP, we'll definitely +            # need to set 'Host' at the very least. +            headers = kw.get('headers', self.headers) +            kw['headers'] = self._set_proxy_headers(url, headers) + +        return super(ProxyManager, self).urlopen(method, url, redirect=redirect, **kw) + + +def proxy_from_url(url, **kw): +    return ProxyManager(proxy_url=url, **kw) diff --git a/python/urllib3/request.py b/python/urllib3/request.py new file mode 100644 index 0000000..8f2f44b --- /dev/null +++ b/python/urllib3/request.py @@ -0,0 +1,150 @@ +from __future__ import absolute_import + +from .filepost import encode_multipart_formdata +from .packages.six.moves.urllib.parse import urlencode + + +__all__ = ['RequestMethods'] + + +class RequestMethods(object): +    """ +    Convenience mixin for classes who implement a :meth:`urlopen` method, such +    as :class:`~urllib3.connectionpool.HTTPConnectionPool` and +    :class:`~urllib3.poolmanager.PoolManager`. + +    Provides behavior for making common types of HTTP request methods and +    decides which type of request field encoding to use. + +    Specifically, + +    :meth:`.request_encode_url` is for sending requests whose fields are +    encoded in the URL (such as GET, HEAD, DELETE). + +    :meth:`.request_encode_body` is for sending requests whose fields are +    encoded in the *body* of the request using multipart or www-form-urlencoded +    (such as for POST, PUT, PATCH). + +    :meth:`.request` is for making any kind of request, it will look up the +    appropriate encoding format and use one of the above two methods to make +    the request. + +    Initializer parameters: + +    :param headers: +        Headers to include with all requests, unless other headers are given +        explicitly. +    """ + +    _encode_url_methods = {'DELETE', 'GET', 'HEAD', 'OPTIONS'} + +    def __init__(self, headers=None): +        self.headers = headers or {} + +    def urlopen(self, method, url, body=None, headers=None, +                encode_multipart=True, multipart_boundary=None, +                **kw):  # Abstract +        raise NotImplementedError("Classes extending RequestMethods must implement " +                                  "their own ``urlopen`` method.") + +    def request(self, method, url, fields=None, headers=None, **urlopen_kw): +        """ +        Make a request using :meth:`urlopen` with the appropriate encoding of +        ``fields`` based on the ``method`` used. + +        This is a convenience method that requires the least amount of manual +        effort. It can be used in most situations, while still having the +        option to drop down to more specific methods when necessary, such as +        :meth:`request_encode_url`, :meth:`request_encode_body`, +        or even the lowest level :meth:`urlopen`. +        """ +        method = method.upper() + +        urlopen_kw['request_url'] = url + +        if method in self._encode_url_methods: +            return self.request_encode_url(method, url, fields=fields, +                                           headers=headers, +                                           **urlopen_kw) +        else: +            return self.request_encode_body(method, url, fields=fields, +                                            headers=headers, +                                            **urlopen_kw) + +    def request_encode_url(self, method, url, fields=None, headers=None, +                           **urlopen_kw): +        """ +        Make a request using :meth:`urlopen` with the ``fields`` encoded in +        the url. This is useful for request methods like GET, HEAD, DELETE, etc. +        """ +        if headers is None: +            headers = self.headers + +        extra_kw = {'headers': headers} +        extra_kw.update(urlopen_kw) + +        if fields: +            url += '?' + urlencode(fields) + +        return self.urlopen(method, url, **extra_kw) + +    def request_encode_body(self, method, url, fields=None, headers=None, +                            encode_multipart=True, multipart_boundary=None, +                            **urlopen_kw): +        """ +        Make a request using :meth:`urlopen` with the ``fields`` encoded in +        the body. This is useful for request methods like POST, PUT, PATCH, etc. + +        When ``encode_multipart=True`` (default), then +        :meth:`urllib3.filepost.encode_multipart_formdata` is used to encode +        the payload with the appropriate content type. Otherwise +        :meth:`urllib.urlencode` is used with the +        'application/x-www-form-urlencoded' content type. + +        Multipart encoding must be used when posting files, and it's reasonably +        safe to use it in other times too. However, it may break request +        signing, such as with OAuth. + +        Supports an optional ``fields`` parameter of key/value strings AND +        key/filetuple. A filetuple is a (filename, data, MIME type) tuple where +        the MIME type is optional. For example:: + +            fields = { +                'foo': 'bar', +                'fakefile': ('foofile.txt', 'contents of foofile'), +                'realfile': ('barfile.txt', open('realfile').read()), +                'typedfile': ('bazfile.bin', open('bazfile').read(), +                              'image/jpeg'), +                'nonamefile': 'contents of nonamefile field', +            } + +        When uploading a file, providing a filename (the first parameter of the +        tuple) is optional but recommended to best mimic behavior of browsers. + +        Note that if ``headers`` are supplied, the 'Content-Type' header will +        be overwritten because it depends on the dynamic random boundary string +        which is used to compose the body of the request. The random boundary +        string can be explicitly set with the ``multipart_boundary`` parameter. +        """ +        if headers is None: +            headers = self.headers + +        extra_kw = {'headers': {}} + +        if fields: +            if 'body' in urlopen_kw: +                raise TypeError( +                    "request got values for both 'fields' and 'body', can only specify one.") + +            if encode_multipart: +                body, content_type = encode_multipart_formdata(fields, boundary=multipart_boundary) +            else: +                body, content_type = urlencode(fields), 'application/x-www-form-urlencoded' + +            extra_kw['body'] = body +            extra_kw['headers'] = {'Content-Type': content_type} + +        extra_kw['headers'].update(headers) +        extra_kw.update(urlopen_kw) + +        return self.urlopen(method, url, **extra_kw) diff --git a/python/urllib3/response.py b/python/urllib3/response.py new file mode 100644 index 0000000..c112690 --- /dev/null +++ b/python/urllib3/response.py @@ -0,0 +1,705 @@ +from __future__ import absolute_import +from contextlib import contextmanager +import zlib +import io +import logging +from socket import timeout as SocketTimeout +from socket import error as SocketError + +from ._collections import HTTPHeaderDict +from .exceptions import ( +    BodyNotHttplibCompatible, ProtocolError, DecodeError, ReadTimeoutError, +    ResponseNotChunked, IncompleteRead, InvalidHeader +) +from .packages.six import string_types as basestring, PY3 +from .packages.six.moves import http_client as httplib +from .connection import HTTPException, BaseSSLError +from .util.response import is_fp_closed, is_response_to_head + +log = logging.getLogger(__name__) + + +class DeflateDecoder(object): + +    def __init__(self): +        self._first_try = True +        self._data = b'' +        self._obj = zlib.decompressobj() + +    def __getattr__(self, name): +        return getattr(self._obj, name) + +    def decompress(self, data): +        if not data: +            return data + +        if not self._first_try: +            return self._obj.decompress(data) + +        self._data += data +        try: +            decompressed = self._obj.decompress(data) +            if decompressed: +                self._first_try = False +                self._data = None +            return decompressed +        except zlib.error: +            self._first_try = False +            self._obj = zlib.decompressobj(-zlib.MAX_WBITS) +            try: +                return self.decompress(self._data) +            finally: +                self._data = None + + +class GzipDecoderState(object): + +    FIRST_MEMBER = 0 +    OTHER_MEMBERS = 1 +    SWALLOW_DATA = 2 + + +class GzipDecoder(object): + +    def __init__(self): +        self._obj = zlib.decompressobj(16 + zlib.MAX_WBITS) +        self._state = GzipDecoderState.FIRST_MEMBER + +    def __getattr__(self, name): +        return getattr(self._obj, name) + +    def decompress(self, data): +        ret = bytearray() +        if self._state == GzipDecoderState.SWALLOW_DATA or not data: +            return bytes(ret) +        while True: +            try: +                ret += self._obj.decompress(data) +            except zlib.error: +                previous_state = self._state +                # Ignore data after the first error +                self._state = GzipDecoderState.SWALLOW_DATA +                if previous_state == GzipDecoderState.OTHER_MEMBERS: +                    # Allow trailing garbage acceptable in other gzip clients +                    return bytes(ret) +                raise +            data = self._obj.unused_data +            if not data: +                return bytes(ret) +            self._state = GzipDecoderState.OTHER_MEMBERS +            self._obj = zlib.decompressobj(16 + zlib.MAX_WBITS) + + +class MultiDecoder(object): +    """ +    From RFC7231: +        If one or more encodings have been applied to a representation, the +        sender that applied the encodings MUST generate a Content-Encoding +        header field that lists the content codings in the order in which +        they were applied. +    """ + +    def __init__(self, modes): +        self._decoders = [_get_decoder(m.strip()) for m in modes.split(',')] + +    def flush(self): +        return self._decoders[0].flush() + +    def decompress(self, data): +        for d in reversed(self._decoders): +            data = d.decompress(data) +        return data + + +def _get_decoder(mode): +    if ',' in mode: +        return MultiDecoder(mode) + +    if mode == 'gzip': +        return GzipDecoder() + +    return DeflateDecoder() + + +class HTTPResponse(io.IOBase): +    """ +    HTTP Response container. + +    Backwards-compatible to httplib's HTTPResponse but the response ``body`` is +    loaded and decoded on-demand when the ``data`` property is accessed.  This +    class is also compatible with the Python standard library's :mod:`io` +    module, and can hence be treated as a readable object in the context of that +    framework. + +    Extra parameters for behaviour not present in httplib.HTTPResponse: + +    :param preload_content: +        If True, the response's body will be preloaded during construction. + +    :param decode_content: +        If True, will attempt to decode the body based on the +        'content-encoding' header. + +    :param original_response: +        When this HTTPResponse wrapper is generated from an httplib.HTTPResponse +        object, it's convenient to include the original for debug purposes. It's +        otherwise unused. + +    :param retries: +        The retries contains the last :class:`~urllib3.util.retry.Retry` that +        was used during the request. + +    :param enforce_content_length: +        Enforce content length checking. Body returned by server must match +        value of Content-Length header, if present. Otherwise, raise error. +    """ + +    CONTENT_DECODERS = ['gzip', 'deflate'] +    REDIRECT_STATUSES = [301, 302, 303, 307, 308] + +    def __init__(self, body='', headers=None, status=0, version=0, reason=None, +                 strict=0, preload_content=True, decode_content=True, +                 original_response=None, pool=None, connection=None, msg=None, +                 retries=None, enforce_content_length=False, +                 request_method=None, request_url=None): + +        if isinstance(headers, HTTPHeaderDict): +            self.headers = headers +        else: +            self.headers = HTTPHeaderDict(headers) +        self.status = status +        self.version = version +        self.reason = reason +        self.strict = strict +        self.decode_content = decode_content +        self.retries = retries +        self.enforce_content_length = enforce_content_length + +        self._decoder = None +        self._body = None +        self._fp = None +        self._original_response = original_response +        self._fp_bytes_read = 0 +        self.msg = msg +        self._request_url = request_url + +        if body and isinstance(body, (basestring, bytes)): +            self._body = body + +        self._pool = pool +        self._connection = connection + +        if hasattr(body, 'read'): +            self._fp = body + +        # Are we using the chunked-style of transfer encoding? +        self.chunked = False +        self.chunk_left = None +        tr_enc = self.headers.get('transfer-encoding', '').lower() +        # Don't incur the penalty of creating a list and then discarding it +        encodings = (enc.strip() for enc in tr_enc.split(",")) +        if "chunked" in encodings: +            self.chunked = True + +        # Determine length of response +        self.length_remaining = self._init_length(request_method) + +        # If requested, preload the body. +        if preload_content and not self._body: +            self._body = self.read(decode_content=decode_content) + +    def get_redirect_location(self): +        """ +        Should we redirect and where to? + +        :returns: Truthy redirect location string if we got a redirect status +            code and valid location. ``None`` if redirect status and no +            location. ``False`` if not a redirect status code. +        """ +        if self.status in self.REDIRECT_STATUSES: +            return self.headers.get('location') + +        return False + +    def release_conn(self): +        if not self._pool or not self._connection: +            return + +        self._pool._put_conn(self._connection) +        self._connection = None + +    @property +    def data(self): +        # For backwords-compat with earlier urllib3 0.4 and earlier. +        if self._body: +            return self._body + +        if self._fp: +            return self.read(cache_content=True) + +    @property +    def connection(self): +        return self._connection + +    def isclosed(self): +        return is_fp_closed(self._fp) + +    def tell(self): +        """ +        Obtain the number of bytes pulled over the wire so far. May differ from +        the amount of content returned by :meth:``HTTPResponse.read`` if bytes +        are encoded on the wire (e.g, compressed). +        """ +        return self._fp_bytes_read + +    def _init_length(self, request_method): +        """ +        Set initial length value for Response content if available. +        """ +        length = self.headers.get('content-length') + +        if length is not None: +            if self.chunked: +                # This Response will fail with an IncompleteRead if it can't be +                # received as chunked. This method falls back to attempt reading +                # the response before raising an exception. +                log.warning("Received response with both Content-Length and " +                            "Transfer-Encoding set. This is expressly forbidden " +                            "by RFC 7230 sec 3.3.2. Ignoring Content-Length and " +                            "attempting to process response as Transfer-Encoding: " +                            "chunked.") +                return None + +            try: +                # RFC 7230 section 3.3.2 specifies multiple content lengths can +                # be sent in a single Content-Length header +                # (e.g. Content-Length: 42, 42). This line ensures the values +                # are all valid ints and that as long as the `set` length is 1, +                # all values are the same. Otherwise, the header is invalid. +                lengths = set([int(val) for val in length.split(',')]) +                if len(lengths) > 1: +                    raise InvalidHeader("Content-Length contained multiple " +                                        "unmatching values (%s)" % length) +                length = lengths.pop() +            except ValueError: +                length = None +            else: +                if length < 0: +                    length = None + +        # Convert status to int for comparison +        # In some cases, httplib returns a status of "_UNKNOWN" +        try: +            status = int(self.status) +        except ValueError: +            status = 0 + +        # Check for responses that shouldn't include a body +        if status in (204, 304) or 100 <= status < 200 or request_method == 'HEAD': +            length = 0 + +        return length + +    def _init_decoder(self): +        """ +        Set-up the _decoder attribute if necessary. +        """ +        # Note: content-encoding value should be case-insensitive, per RFC 7230 +        # Section 3.2 +        content_encoding = self.headers.get('content-encoding', '').lower() +        if self._decoder is None: +            if content_encoding in self.CONTENT_DECODERS: +                self._decoder = _get_decoder(content_encoding) +            elif ',' in content_encoding: +                encodings = [e.strip() for e in content_encoding.split(',') if e.strip() in self.CONTENT_DECODERS] +                if len(encodings): +                    self._decoder = _get_decoder(content_encoding) + +    def _decode(self, data, decode_content, flush_decoder): +        """ +        Decode the data passed in and potentially flush the decoder. +        """ +        try: +            if decode_content and self._decoder: +                data = self._decoder.decompress(data) +        except (IOError, zlib.error) as e: +            content_encoding = self.headers.get('content-encoding', '').lower() +            raise DecodeError( +                "Received response with content-encoding: %s, but " +                "failed to decode it." % content_encoding, e) + +        if flush_decoder and decode_content: +            data += self._flush_decoder() + +        return data + +    def _flush_decoder(self): +        """ +        Flushes the decoder. Should only be called if the decoder is actually +        being used. +        """ +        if self._decoder: +            buf = self._decoder.decompress(b'') +            return buf + self._decoder.flush() + +        return b'' + +    @contextmanager +    def _error_catcher(self): +        """ +        Catch low-level python exceptions, instead re-raising urllib3 +        variants, so that low-level exceptions are not leaked in the +        high-level api. + +        On exit, release the connection back to the pool. +        """ +        clean_exit = False + +        try: +            try: +                yield + +            except SocketTimeout: +                # FIXME: Ideally we'd like to include the url in the ReadTimeoutError but +                # there is yet no clean way to get at it from this context. +                raise ReadTimeoutError(self._pool, None, 'Read timed out.') + +            except BaseSSLError as e: +                # FIXME: Is there a better way to differentiate between SSLErrors? +                if 'read operation timed out' not in str(e):  # Defensive: +                    # This shouldn't happen but just in case we're missing an edge +                    # case, let's avoid swallowing SSL errors. +                    raise + +                raise ReadTimeoutError(self._pool, None, 'Read timed out.') + +            except (HTTPException, SocketError) as e: +                # This includes IncompleteRead. +                raise ProtocolError('Connection broken: %r' % e, e) + +            # If no exception is thrown, we should avoid cleaning up +            # unnecessarily. +            clean_exit = True +        finally: +            # If we didn't terminate cleanly, we need to throw away our +            # connection. +            if not clean_exit: +                # The response may not be closed but we're not going to use it +                # anymore so close it now to ensure that the connection is +                # released back to the pool. +                if self._original_response: +                    self._original_response.close() + +                # Closing the response may not actually be sufficient to close +                # everything, so if we have a hold of the connection close that +                # too. +                if self._connection: +                    self._connection.close() + +            # If we hold the original response but it's closed now, we should +            # return the connection back to the pool. +            if self._original_response and self._original_response.isclosed(): +                self.release_conn() + +    def read(self, amt=None, decode_content=None, cache_content=False): +        """ +        Similar to :meth:`httplib.HTTPResponse.read`, but with two additional +        parameters: ``decode_content`` and ``cache_content``. + +        :param amt: +            How much of the content to read. If specified, caching is skipped +            because it doesn't make sense to cache partial content as the full +            response. + +        :param decode_content: +            If True, will attempt to decode the body based on the +            'content-encoding' header. + +        :param cache_content: +            If True, will save the returned data such that the same result is +            returned despite of the state of the underlying file object. This +            is useful if you want the ``.data`` property to continue working +            after having ``.read()`` the file object. (Overridden if ``amt`` is +            set.) +        """ +        self._init_decoder() +        if decode_content is None: +            decode_content = self.decode_content + +        if self._fp is None: +            return + +        flush_decoder = False +        data = None + +        with self._error_catcher(): +            if amt is None: +                # cStringIO doesn't like amt=None +                data = self._fp.read() +                flush_decoder = True +            else: +                cache_content = False +                data = self._fp.read(amt) +                if amt != 0 and not data:  # Platform-specific: Buggy versions of Python. +                    # Close the connection when no data is returned +                    # +                    # This is redundant to what httplib/http.client _should_ +                    # already do.  However, versions of python released before +                    # December 15, 2012 (http://bugs.python.org/issue16298) do +                    # not properly close the connection in all cases. There is +                    # no harm in redundantly calling close. +                    self._fp.close() +                    flush_decoder = True +                    if self.enforce_content_length and self.length_remaining not in (0, None): +                        # This is an edge case that httplib failed to cover due +                        # to concerns of backward compatibility. We're +                        # addressing it here to make sure IncompleteRead is +                        # raised during streaming, so all calls with incorrect +                        # Content-Length are caught. +                        raise IncompleteRead(self._fp_bytes_read, self.length_remaining) + +        if data: +            self._fp_bytes_read += len(data) +            if self.length_remaining is not None: +                self.length_remaining -= len(data) + +            data = self._decode(data, decode_content, flush_decoder) + +            if cache_content: +                self._body = data + +        return data + +    def stream(self, amt=2**16, decode_content=None): +        """ +        A generator wrapper for the read() method. A call will block until +        ``amt`` bytes have been read from the connection or until the +        connection is closed. + +        :param amt: +            How much of the content to read. The generator will return up to +            much data per iteration, but may return less. This is particularly +            likely when using compressed data. However, the empty string will +            never be returned. + +        :param decode_content: +            If True, will attempt to decode the body based on the +            'content-encoding' header. +        """ +        if self.chunked and self.supports_chunked_reads(): +            for line in self.read_chunked(amt, decode_content=decode_content): +                yield line +        else: +            while not is_fp_closed(self._fp): +                data = self.read(amt=amt, decode_content=decode_content) + +                if data: +                    yield data + +    @classmethod +    def from_httplib(ResponseCls, r, **response_kw): +        """ +        Given an :class:`httplib.HTTPResponse` instance ``r``, return a +        corresponding :class:`urllib3.response.HTTPResponse` object. + +        Remaining parameters are passed to the HTTPResponse constructor, along +        with ``original_response=r``. +        """ +        headers = r.msg + +        if not isinstance(headers, HTTPHeaderDict): +            if PY3:  # Python 3 +                headers = HTTPHeaderDict(headers.items()) +            else:  # Python 2 +                headers = HTTPHeaderDict.from_httplib(headers) + +        # HTTPResponse objects in Python 3 don't have a .strict attribute +        strict = getattr(r, 'strict', 0) +        resp = ResponseCls(body=r, +                           headers=headers, +                           status=r.status, +                           version=r.version, +                           reason=r.reason, +                           strict=strict, +                           original_response=r, +                           **response_kw) +        return resp + +    # Backwards-compatibility methods for httplib.HTTPResponse +    def getheaders(self): +        return self.headers + +    def getheader(self, name, default=None): +        return self.headers.get(name, default) + +    # Backwards compatibility for http.cookiejar +    def info(self): +        return self.headers + +    # Overrides from io.IOBase +    def close(self): +        if not self.closed: +            self._fp.close() + +        if self._connection: +            self._connection.close() + +    @property +    def closed(self): +        if self._fp is None: +            return True +        elif hasattr(self._fp, 'isclosed'): +            return self._fp.isclosed() +        elif hasattr(self._fp, 'closed'): +            return self._fp.closed +        else: +            return True + +    def fileno(self): +        if self._fp is None: +            raise IOError("HTTPResponse has no file to get a fileno from") +        elif hasattr(self._fp, "fileno"): +            return self._fp.fileno() +        else: +            raise IOError("The file-like object this HTTPResponse is wrapped " +                          "around has no file descriptor") + +    def flush(self): +        if self._fp is not None and hasattr(self._fp, 'flush'): +            return self._fp.flush() + +    def readable(self): +        # This method is required for `io` module compatibility. +        return True + +    def readinto(self, b): +        # This method is required for `io` module compatibility. +        temp = self.read(len(b)) +        if len(temp) == 0: +            return 0 +        else: +            b[:len(temp)] = temp +            return len(temp) + +    def supports_chunked_reads(self): +        """ +        Checks if the underlying file-like object looks like a +        httplib.HTTPResponse object. We do this by testing for the fp +        attribute. If it is present we assume it returns raw chunks as +        processed by read_chunked(). +        """ +        return hasattr(self._fp, 'fp') + +    def _update_chunk_length(self): +        # First, we'll figure out length of a chunk and then +        # we'll try to read it from socket. +        if self.chunk_left is not None: +            return +        line = self._fp.fp.readline() +        line = line.split(b';', 1)[0] +        try: +            self.chunk_left = int(line, 16) +        except ValueError: +            # Invalid chunked protocol response, abort. +            self.close() +            raise httplib.IncompleteRead(line) + +    def _handle_chunk(self, amt): +        returned_chunk = None +        if amt is None: +            chunk = self._fp._safe_read(self.chunk_left) +            returned_chunk = chunk +            self._fp._safe_read(2)  # Toss the CRLF at the end of the chunk. +            self.chunk_left = None +        elif amt < self.chunk_left: +            value = self._fp._safe_read(amt) +            self.chunk_left = self.chunk_left - amt +            returned_chunk = value +        elif amt == self.chunk_left: +            value = self._fp._safe_read(amt) +            self._fp._safe_read(2)  # Toss the CRLF at the end of the chunk. +            self.chunk_left = None +            returned_chunk = value +        else:  # amt > self.chunk_left +            returned_chunk = self._fp._safe_read(self.chunk_left) +            self._fp._safe_read(2)  # Toss the CRLF at the end of the chunk. +            self.chunk_left = None +        return returned_chunk + +    def read_chunked(self, amt=None, decode_content=None): +        """ +        Similar to :meth:`HTTPResponse.read`, but with an additional +        parameter: ``decode_content``. + +        :param amt: +            How much of the content to read. If specified, caching is skipped +            because it doesn't make sense to cache partial content as the full +            response. + +        :param decode_content: +            If True, will attempt to decode the body based on the +            'content-encoding' header. +        """ +        self._init_decoder() +        # FIXME: Rewrite this method and make it a class with a better structured logic. +        if not self.chunked: +            raise ResponseNotChunked( +                "Response is not chunked. " +                "Header 'transfer-encoding: chunked' is missing.") +        if not self.supports_chunked_reads(): +            raise BodyNotHttplibCompatible( +                "Body should be httplib.HTTPResponse like. " +                "It should have have an fp attribute which returns raw chunks.") + +        with self._error_catcher(): +            # Don't bother reading the body of a HEAD request. +            if self._original_response and is_response_to_head(self._original_response): +                self._original_response.close() +                return + +            # If a response is already read and closed +            # then return immediately. +            if self._fp.fp is None: +                return + +            while True: +                self._update_chunk_length() +                if self.chunk_left == 0: +                    break +                chunk = self._handle_chunk(amt) +                decoded = self._decode(chunk, decode_content=decode_content, +                                       flush_decoder=False) +                if decoded: +                    yield decoded + +            if decode_content: +                # On CPython and PyPy, we should never need to flush the +                # decoder. However, on Jython we *might* need to, so +                # lets defensively do it anyway. +                decoded = self._flush_decoder() +                if decoded:  # Platform-specific: Jython. +                    yield decoded + +            # Chunk content ends with \r\n: discard it. +            while True: +                line = self._fp.fp.readline() +                if not line: +                    # Some sites may not end with '\r\n'. +                    break +                if line == b'\r\n': +                    break + +            # We read everything; close the "file". +            if self._original_response: +                self._original_response.close() + +    def geturl(self): +        """ +        Returns the URL that was the source of this response. +        If the request that generated this response redirected, this method +        will return the final redirect location. +        """ +        if self.retries is not None and len(self.retries.history): +            return self.retries.history[-1].redirect_location +        else: +            return self._request_url diff --git a/python/urllib3/util/__init__.py b/python/urllib3/util/__init__.py new file mode 100644 index 0000000..2f2770b --- /dev/null +++ b/python/urllib3/util/__init__.py @@ -0,0 +1,54 @@ +from __future__ import absolute_import +# For backwards compatibility, provide imports that used to be here. +from .connection import is_connection_dropped +from .request import make_headers +from .response import is_fp_closed +from .ssl_ import ( +    SSLContext, +    HAS_SNI, +    IS_PYOPENSSL, +    IS_SECURETRANSPORT, +    assert_fingerprint, +    resolve_cert_reqs, +    resolve_ssl_version, +    ssl_wrap_socket, +) +from .timeout import ( +    current_time, +    Timeout, +) + +from .retry import Retry +from .url import ( +    get_host, +    parse_url, +    split_first, +    Url, +) +from .wait import ( +    wait_for_read, +    wait_for_write +) + +__all__ = ( +    'HAS_SNI', +    'IS_PYOPENSSL', +    'IS_SECURETRANSPORT', +    'SSLContext', +    'Retry', +    'Timeout', +    'Url', +    'assert_fingerprint', +    'current_time', +    'is_connection_dropped', +    'is_fp_closed', +    'get_host', +    'parse_url', +    'make_headers', +    'resolve_cert_reqs', +    'resolve_ssl_version', +    'split_first', +    'ssl_wrap_socket', +    'wait_for_read', +    'wait_for_write' +) diff --git a/python/urllib3/util/connection.py b/python/urllib3/util/connection.py new file mode 100644 index 0000000..5ad70b2 --- /dev/null +++ b/python/urllib3/util/connection.py @@ -0,0 +1,134 @@ +from __future__ import absolute_import +import socket +from .wait import NoWayToWaitForSocketError, wait_for_read +from ..contrib import _appengine_environ + + +def is_connection_dropped(conn):  # Platform-specific +    """ +    Returns True if the connection is dropped and should be closed. + +    :param conn: +        :class:`httplib.HTTPConnection` object. + +    Note: For platforms like AppEngine, this will always return ``False`` to +    let the platform handle connection recycling transparently for us. +    """ +    sock = getattr(conn, 'sock', False) +    if sock is False:  # Platform-specific: AppEngine +        return False +    if sock is None:  # Connection already closed (such as by httplib). +        return True +    try: +        # Returns True if readable, which here means it's been dropped +        return wait_for_read(sock, timeout=0.0) +    except NoWayToWaitForSocketError:  # Platform-specific: AppEngine +        return False + + +# This function is copied from socket.py in the Python 2.7 standard +# library test suite. Added to its signature is only `socket_options`. +# One additional modification is that we avoid binding to IPv6 servers +# discovered in DNS if the system doesn't have IPv6 functionality. +def create_connection(address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, +                      source_address=None, socket_options=None): +    """Connect to *address* and return the socket object. + +    Convenience function.  Connect to *address* (a 2-tuple ``(host, +    port)``) and return the socket object.  Passing the optional +    *timeout* parameter will set the timeout on the socket instance +    before attempting to connect.  If no *timeout* is supplied, the +    global default timeout setting returned by :func:`getdefaulttimeout` +    is used.  If *source_address* is set it must be a tuple of (host, port) +    for the socket to bind as a source address before making the connection. +    An host of '' or port 0 tells the OS to use the default. +    """ + +    host, port = address +    if host.startswith('['): +        host = host.strip('[]') +    err = None + +    # Using the value from allowed_gai_family() in the context of getaddrinfo lets +    # us select whether to work with IPv4 DNS records, IPv6 records, or both. +    # The original create_connection function always returns all records. +    family = allowed_gai_family() + +    for res in socket.getaddrinfo(host, port, family, socket.SOCK_STREAM): +        af, socktype, proto, canonname, sa = res +        sock = None +        try: +            sock = socket.socket(af, socktype, proto) + +            # If provided, set socket level options before connecting. +            _set_socket_options(sock, socket_options) + +            if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT: +                sock.settimeout(timeout) +            if source_address: +                sock.bind(source_address) +            sock.connect(sa) +            return sock + +        except socket.error as e: +            err = e +            if sock is not None: +                sock.close() +                sock = None + +    if err is not None: +        raise err + +    raise socket.error("getaddrinfo returns an empty list") + + +def _set_socket_options(sock, options): +    if options is None: +        return + +    for opt in options: +        sock.setsockopt(*opt) + + +def allowed_gai_family(): +    """This function is designed to work in the context of +    getaddrinfo, where family=socket.AF_UNSPEC is the default and +    will perform a DNS search for both IPv6 and IPv4 records.""" + +    family = socket.AF_INET +    if HAS_IPV6: +        family = socket.AF_UNSPEC +    return family + + +def _has_ipv6(host): +    """ Returns True if the system can bind an IPv6 address. """ +    sock = None +    has_ipv6 = False + +    # App Engine doesn't support IPV6 sockets and actually has a quota on the +    # number of sockets that can be used, so just early out here instead of +    # creating a socket needlessly. +    # See https://github.com/urllib3/urllib3/issues/1446 +    if _appengine_environ.is_appengine_sandbox(): +        return False + +    if socket.has_ipv6: +        # has_ipv6 returns true if cPython was compiled with IPv6 support. +        # It does not tell us if the system has IPv6 support enabled. To +        # determine that we must bind to an IPv6 address. +        # https://github.com/shazow/urllib3/pull/611 +        # https://bugs.python.org/issue658327 +        try: +            sock = socket.socket(socket.AF_INET6) +            sock.bind((host, 0)) +            has_ipv6 = True +        except Exception: +            pass + +    if sock: +        sock.close() +    return has_ipv6 + + +HAS_IPV6 = _has_ipv6('::1') diff --git a/python/urllib3/util/queue.py b/python/urllib3/util/queue.py new file mode 100644 index 0000000..d3d379a --- /dev/null +++ b/python/urllib3/util/queue.py @@ -0,0 +1,21 @@ +import collections +from ..packages import six +from ..packages.six.moves import queue + +if six.PY2: +    # Queue is imported for side effects on MS Windows. See issue #229. +    import Queue as _unused_module_Queue  # noqa: F401 + + +class LifoQueue(queue.Queue): +    def _init(self, _): +        self.queue = collections.deque() + +    def _qsize(self, len=len): +        return len(self.queue) + +    def _put(self, item): +        self.queue.append(item) + +    def _get(self): +        return self.queue.pop() diff --git a/python/urllib3/util/request.py b/python/urllib3/util/request.py new file mode 100644 index 0000000..3ddfcd5 --- /dev/null +++ b/python/urllib3/util/request.py @@ -0,0 +1,118 @@ +from __future__ import absolute_import +from base64 import b64encode + +from ..packages.six import b, integer_types +from ..exceptions import UnrewindableBodyError + +ACCEPT_ENCODING = 'gzip,deflate' +_FAILEDTELL = object() + + +def make_headers(keep_alive=None, accept_encoding=None, user_agent=None, +                 basic_auth=None, proxy_basic_auth=None, disable_cache=None): +    """ +    Shortcuts for generating request headers. + +    :param keep_alive: +        If ``True``, adds 'connection: keep-alive' header. + +    :param accept_encoding: +        Can be a boolean, list, or string. +        ``True`` translates to 'gzip,deflate'. +        List will get joined by comma. +        String will be used as provided. + +    :param user_agent: +        String representing the user-agent you want, such as +        "python-urllib3/0.6" + +    :param basic_auth: +        Colon-separated username:password string for 'authorization: basic ...' +        auth header. + +    :param proxy_basic_auth: +        Colon-separated username:password string for 'proxy-authorization: basic ...' +        auth header. + +    :param disable_cache: +        If ``True``, adds 'cache-control: no-cache' header. + +    Example:: + +        >>> make_headers(keep_alive=True, user_agent="Batman/1.0") +        {'connection': 'keep-alive', 'user-agent': 'Batman/1.0'} +        >>> make_headers(accept_encoding=True) +        {'accept-encoding': 'gzip,deflate'} +    """ +    headers = {} +    if accept_encoding: +        if isinstance(accept_encoding, str): +            pass +        elif isinstance(accept_encoding, list): +            accept_encoding = ','.join(accept_encoding) +        else: +            accept_encoding = ACCEPT_ENCODING +        headers['accept-encoding'] = accept_encoding + +    if user_agent: +        headers['user-agent'] = user_agent + +    if keep_alive: +        headers['connection'] = 'keep-alive' + +    if basic_auth: +        headers['authorization'] = 'Basic ' + \ +            b64encode(b(basic_auth)).decode('utf-8') + +    if proxy_basic_auth: +        headers['proxy-authorization'] = 'Basic ' + \ +            b64encode(b(proxy_basic_auth)).decode('utf-8') + +    if disable_cache: +        headers['cache-control'] = 'no-cache' + +    return headers + + +def set_file_position(body, pos): +    """ +    If a position is provided, move file to that point. +    Otherwise, we'll attempt to record a position for future use. +    """ +    if pos is not None: +        rewind_body(body, pos) +    elif getattr(body, 'tell', None) is not None: +        try: +            pos = body.tell() +        except (IOError, OSError): +            # This differentiates from None, allowing us to catch +            # a failed `tell()` later when trying to rewind the body. +            pos = _FAILEDTELL + +    return pos + + +def rewind_body(body, body_pos): +    """ +    Attempt to rewind body to a certain position. +    Primarily used for request redirects and retries. + +    :param body: +        File-like object that supports seek. + +    :param int pos: +        Position to seek to in file. +    """ +    body_seek = getattr(body, 'seek', None) +    if body_seek is not None and isinstance(body_pos, integer_types): +        try: +            body_seek(body_pos) +        except (IOError, OSError): +            raise UnrewindableBodyError("An error occurred when rewinding request " +                                        "body for redirect/retry.") +    elif body_pos is _FAILEDTELL: +        raise UnrewindableBodyError("Unable to record file position for rewinding " +                                    "request body during a redirect/retry.") +    else: +        raise ValueError("body_pos must be of type integer, " +                         "instead it was %s." % type(body_pos)) diff --git a/python/urllib3/util/response.py b/python/urllib3/util/response.py new file mode 100644 index 0000000..3d54864 --- /dev/null +++ b/python/urllib3/util/response.py @@ -0,0 +1,87 @@ +from __future__ import absolute_import +from ..packages.six.moves import http_client as httplib + +from ..exceptions import HeaderParsingError + + +def is_fp_closed(obj): +    """ +    Checks whether a given file-like object is closed. + +    :param obj: +        The file-like object to check. +    """ + +    try: +        # Check `isclosed()` first, in case Python3 doesn't set `closed`. +        # GH Issue #928 +        return obj.isclosed() +    except AttributeError: +        pass + +    try: +        # Check via the official file-like-object way. +        return obj.closed +    except AttributeError: +        pass + +    try: +        # Check if the object is a container for another file-like object that +        # gets released on exhaustion (e.g. HTTPResponse). +        return obj.fp is None +    except AttributeError: +        pass + +    raise ValueError("Unable to determine whether fp is closed.") + + +def assert_header_parsing(headers): +    """ +    Asserts whether all headers have been successfully parsed. +    Extracts encountered errors from the result of parsing headers. + +    Only works on Python 3. + +    :param headers: Headers to verify. +    :type headers: `httplib.HTTPMessage`. + +    :raises urllib3.exceptions.HeaderParsingError: +        If parsing errors are found. +    """ + +    # This will fail silently if we pass in the wrong kind of parameter. +    # To make debugging easier add an explicit check. +    if not isinstance(headers, httplib.HTTPMessage): +        raise TypeError('expected httplib.Message, got {0}.'.format( +            type(headers))) + +    defects = getattr(headers, 'defects', None) +    get_payload = getattr(headers, 'get_payload', None) + +    unparsed_data = None +    if get_payload: +        # get_payload is actually email.message.Message.get_payload; +        # we're only interested in the result if it's not a multipart message +        if not headers.is_multipart(): +            payload = get_payload() + +            if isinstance(payload, (bytes, str)): +                unparsed_data = payload + +    if defects or unparsed_data: +        raise HeaderParsingError(defects=defects, unparsed_data=unparsed_data) + + +def is_response_to_head(response): +    """ +    Checks whether the request of a response has been a HEAD-request. +    Handles the quirks of AppEngine. + +    :param conn: +    :type conn: :class:`httplib.HTTPResponse` +    """ +    # FIXME: Can we do this somehow without accessing private httplib _method? +    method = response._method +    if isinstance(method, int):  # Platform-specific: Appengine +        return method == 3 +    return method.upper() == 'HEAD' diff --git a/python/urllib3/util/retry.py b/python/urllib3/util/retry.py new file mode 100644 index 0000000..e7d0abd --- /dev/null +++ b/python/urllib3/util/retry.py @@ -0,0 +1,411 @@ +from __future__ import absolute_import +import time +import logging +from collections import namedtuple +from itertools import takewhile +import email +import re + +from ..exceptions import ( +    ConnectTimeoutError, +    MaxRetryError, +    ProtocolError, +    ReadTimeoutError, +    ResponseError, +    InvalidHeader, +) +from ..packages import six + + +log = logging.getLogger(__name__) + + +# Data structure for representing the metadata of requests that result in a retry. +RequestHistory = namedtuple('RequestHistory', ["method", "url", "error", +                                               "status", "redirect_location"]) + + +class Retry(object): +    """ Retry configuration. + +    Each retry attempt will create a new Retry object with updated values, so +    they can be safely reused. + +    Retries can be defined as a default for a pool:: + +        retries = Retry(connect=5, read=2, redirect=5) +        http = PoolManager(retries=retries) +        response = http.request('GET', 'http://example.com/') + +    Or per-request (which overrides the default for the pool):: + +        response = http.request('GET', 'http://example.com/', retries=Retry(10)) + +    Retries can be disabled by passing ``False``:: + +        response = http.request('GET', 'http://example.com/', retries=False) + +    Errors will be wrapped in :class:`~urllib3.exceptions.MaxRetryError` unless +    retries are disabled, in which case the causing exception will be raised. + +    :param int total: +        Total number of retries to allow. Takes precedence over other counts. + +        Set to ``None`` to remove this constraint and fall back on other +        counts. It's a good idea to set this to some sensibly-high value to +        account for unexpected edge cases and avoid infinite retry loops. + +        Set to ``0`` to fail on the first retry. + +        Set to ``False`` to disable and imply ``raise_on_redirect=False``. + +    :param int connect: +        How many connection-related errors to retry on. + +        These are errors raised before the request is sent to the remote server, +        which we assume has not triggered the server to process the request. + +        Set to ``0`` to fail on the first retry of this type. + +    :param int read: +        How many times to retry on read errors. + +        These errors are raised after the request was sent to the server, so the +        request may have side-effects. + +        Set to ``0`` to fail on the first retry of this type. + +    :param int redirect: +        How many redirects to perform. Limit this to avoid infinite redirect +        loops. + +        A redirect is a HTTP response with a status code 301, 302, 303, 307 or +        308. + +        Set to ``0`` to fail on the first retry of this type. + +        Set to ``False`` to disable and imply ``raise_on_redirect=False``. + +    :param int status: +        How many times to retry on bad status codes. + +        These are retries made on responses, where status code matches +        ``status_forcelist``. + +        Set to ``0`` to fail on the first retry of this type. + +    :param iterable method_whitelist: +        Set of uppercased HTTP method verbs that we should retry on. + +        By default, we only retry on methods which are considered to be +        idempotent (multiple requests with the same parameters end with the +        same state). See :attr:`Retry.DEFAULT_METHOD_WHITELIST`. + +        Set to a ``False`` value to retry on any verb. + +    :param iterable status_forcelist: +        A set of integer HTTP status codes that we should force a retry on. +        A retry is initiated if the request method is in ``method_whitelist`` +        and the response status code is in ``status_forcelist``. + +        By default, this is disabled with ``None``. + +    :param float backoff_factor: +        A backoff factor to apply between attempts after the second try +        (most errors are resolved immediately by a second try without a +        delay). urllib3 will sleep for:: + +            {backoff factor} * (2 ** ({number of total retries} - 1)) + +        seconds. If the backoff_factor is 0.1, then :func:`.sleep` will sleep +        for [0.0s, 0.2s, 0.4s, ...] between retries. It will never be longer +        than :attr:`Retry.BACKOFF_MAX`. + +        By default, backoff is disabled (set to 0). + +    :param bool raise_on_redirect: Whether, if the number of redirects is +        exhausted, to raise a MaxRetryError, or to return a response with a +        response code in the 3xx range. + +    :param bool raise_on_status: Similar meaning to ``raise_on_redirect``: +        whether we should raise an exception, or return a response, +        if status falls in ``status_forcelist`` range and retries have +        been exhausted. + +    :param tuple history: The history of the request encountered during +        each call to :meth:`~Retry.increment`. The list is in the order +        the requests occurred. Each list item is of class :class:`RequestHistory`. + +    :param bool respect_retry_after_header: +        Whether to respect Retry-After header on status codes defined as +        :attr:`Retry.RETRY_AFTER_STATUS_CODES` or not. + +    :param iterable remove_headers_on_redirect: +        Sequence of headers to remove from the request when a response +        indicating a redirect is returned before firing off the redirected +        request. +    """ + +    DEFAULT_METHOD_WHITELIST = frozenset([ +        'HEAD', 'GET', 'PUT', 'DELETE', 'OPTIONS', 'TRACE']) + +    RETRY_AFTER_STATUS_CODES = frozenset([413, 429, 503]) + +    DEFAULT_REDIRECT_HEADERS_BLACKLIST = frozenset(['Authorization']) + +    #: Maximum backoff time. +    BACKOFF_MAX = 120 + +    def __init__(self, total=10, connect=None, read=None, redirect=None, status=None, +                 method_whitelist=DEFAULT_METHOD_WHITELIST, status_forcelist=None, +                 backoff_factor=0, raise_on_redirect=True, raise_on_status=True, +                 history=None, respect_retry_after_header=True, +                 remove_headers_on_redirect=DEFAULT_REDIRECT_HEADERS_BLACKLIST): + +        self.total = total +        self.connect = connect +        self.read = read +        self.status = status + +        if redirect is False or total is False: +            redirect = 0 +            raise_on_redirect = False + +        self.redirect = redirect +        self.status_forcelist = status_forcelist or set() +        self.method_whitelist = method_whitelist +        self.backoff_factor = backoff_factor +        self.raise_on_redirect = raise_on_redirect +        self.raise_on_status = raise_on_status +        self.history = history or tuple() +        self.respect_retry_after_header = respect_retry_after_header +        self.remove_headers_on_redirect = remove_headers_on_redirect + +    def new(self, **kw): +        params = dict( +            total=self.total, +            connect=self.connect, read=self.read, redirect=self.redirect, status=self.status, +            method_whitelist=self.method_whitelist, +            status_forcelist=self.status_forcelist, +            backoff_factor=self.backoff_factor, +            raise_on_redirect=self.raise_on_redirect, +            raise_on_status=self.raise_on_status, +            history=self.history, +            remove_headers_on_redirect=self.remove_headers_on_redirect +        ) +        params.update(kw) +        return type(self)(**params) + +    @classmethod +    def from_int(cls, retries, redirect=True, default=None): +        """ Backwards-compatibility for the old retries format.""" +        if retries is None: +            retries = default if default is not None else cls.DEFAULT + +        if isinstance(retries, Retry): +            return retries + +        redirect = bool(redirect) and None +        new_retries = cls(retries, redirect=redirect) +        log.debug("Converted retries value: %r -> %r", retries, new_retries) +        return new_retries + +    def get_backoff_time(self): +        """ Formula for computing the current backoff + +        :rtype: float +        """ +        # We want to consider only the last consecutive errors sequence (Ignore redirects). +        consecutive_errors_len = len(list(takewhile(lambda x: x.redirect_location is None, +                                                    reversed(self.history)))) +        if consecutive_errors_len <= 1: +            return 0 + +        backoff_value = self.backoff_factor * (2 ** (consecutive_errors_len - 1)) +        return min(self.BACKOFF_MAX, backoff_value) + +    def parse_retry_after(self, retry_after): +        # Whitespace: https://tools.ietf.org/html/rfc7230#section-3.2.4 +        if re.match(r"^\s*[0-9]+\s*$", retry_after): +            seconds = int(retry_after) +        else: +            retry_date_tuple = email.utils.parsedate(retry_after) +            if retry_date_tuple is None: +                raise InvalidHeader("Invalid Retry-After header: %s" % retry_after) +            retry_date = time.mktime(retry_date_tuple) +            seconds = retry_date - time.time() + +        if seconds < 0: +            seconds = 0 + +        return seconds + +    def get_retry_after(self, response): +        """ Get the value of Retry-After in seconds. """ + +        retry_after = response.getheader("Retry-After") + +        if retry_after is None: +            return None + +        return self.parse_retry_after(retry_after) + +    def sleep_for_retry(self, response=None): +        retry_after = self.get_retry_after(response) +        if retry_after: +            time.sleep(retry_after) +            return True + +        return False + +    def _sleep_backoff(self): +        backoff = self.get_backoff_time() +        if backoff <= 0: +            return +        time.sleep(backoff) + +    def sleep(self, response=None): +        """ Sleep between retry attempts. + +        This method will respect a server's ``Retry-After`` response header +        and sleep the duration of the time requested. If that is not present, it +        will use an exponential backoff. By default, the backoff factor is 0 and +        this method will return immediately. +        """ + +        if response: +            slept = self.sleep_for_retry(response) +            if slept: +                return + +        self._sleep_backoff() + +    def _is_connection_error(self, err): +        """ Errors when we're fairly sure that the server did not receive the +        request, so it should be safe to retry. +        """ +        return isinstance(err, ConnectTimeoutError) + +    def _is_read_error(self, err): +        """ Errors that occur after the request has been started, so we should +        assume that the server began processing it. +        """ +        return isinstance(err, (ReadTimeoutError, ProtocolError)) + +    def _is_method_retryable(self, method): +        """ Checks if a given HTTP method should be retried upon, depending if +        it is included on the method whitelist. +        """ +        if self.method_whitelist and method.upper() not in self.method_whitelist: +            return False + +        return True + +    def is_retry(self, method, status_code, has_retry_after=False): +        """ Is this method/status code retryable? (Based on whitelists and control +        variables such as the number of total retries to allow, whether to +        respect the Retry-After header, whether this header is present, and +        whether the returned status code is on the list of status codes to +        be retried upon on the presence of the aforementioned header) +        """ +        if not self._is_method_retryable(method): +            return False + +        if self.status_forcelist and status_code in self.status_forcelist: +            return True + +        return (self.total and self.respect_retry_after_header and +                has_retry_after and (status_code in self.RETRY_AFTER_STATUS_CODES)) + +    def is_exhausted(self): +        """ Are we out of retries? """ +        retry_counts = (self.total, self.connect, self.read, self.redirect, self.status) +        retry_counts = list(filter(None, retry_counts)) +        if not retry_counts: +            return False + +        return min(retry_counts) < 0 + +    def increment(self, method=None, url=None, response=None, error=None, +                  _pool=None, _stacktrace=None): +        """ Return a new Retry object with incremented retry counters. + +        :param response: A response object, or None, if the server did not +            return a response. +        :type response: :class:`~urllib3.response.HTTPResponse` +        :param Exception error: An error encountered during the request, or +            None if the response was received successfully. + +        :return: A new ``Retry`` object. +        """ +        if self.total is False and error: +            # Disabled, indicate to re-raise the error. +            raise six.reraise(type(error), error, _stacktrace) + +        total = self.total +        if total is not None: +            total -= 1 + +        connect = self.connect +        read = self.read +        redirect = self.redirect +        status_count = self.status +        cause = 'unknown' +        status = None +        redirect_location = None + +        if error and self._is_connection_error(error): +            # Connect retry? +            if connect is False: +                raise six.reraise(type(error), error, _stacktrace) +            elif connect is not None: +                connect -= 1 + +        elif error and self._is_read_error(error): +            # Read retry? +            if read is False or not self._is_method_retryable(method): +                raise six.reraise(type(error), error, _stacktrace) +            elif read is not None: +                read -= 1 + +        elif response and response.get_redirect_location(): +            # Redirect retry? +            if redirect is not None: +                redirect -= 1 +            cause = 'too many redirects' +            redirect_location = response.get_redirect_location() +            status = response.status + +        else: +            # Incrementing because of a server error like a 500 in +            # status_forcelist and a the given method is in the whitelist +            cause = ResponseError.GENERIC_ERROR +            if response and response.status: +                if status_count is not None: +                    status_count -= 1 +                cause = ResponseError.SPECIFIC_ERROR.format( +                    status_code=response.status) +                status = response.status + +        history = self.history + (RequestHistory(method, url, error, status, redirect_location),) + +        new_retry = self.new( +            total=total, +            connect=connect, read=read, redirect=redirect, status=status_count, +            history=history) + +        if new_retry.is_exhausted(): +            raise MaxRetryError(_pool, url, error or ResponseError(cause)) + +        log.debug("Incremented Retry for (url='%s'): %r", url, new_retry) + +        return new_retry + +    def __repr__(self): +        return ('{cls.__name__}(total={self.total}, connect={self.connect}, ' +                'read={self.read}, redirect={self.redirect}, status={self.status})').format( +                    cls=type(self), self=self) + + +# For backwards compatibility (equivalent to pre-v1.9): +Retry.DEFAULT = Retry(3) diff --git a/python/urllib3/util/ssl_.py b/python/urllib3/util/ssl_.py new file mode 100644 index 0000000..64ea192 --- /dev/null +++ b/python/urllib3/util/ssl_.py @@ -0,0 +1,381 @@ +from __future__ import absolute_import +import errno +import warnings +import hmac +import socket + +from binascii import hexlify, unhexlify +from hashlib import md5, sha1, sha256 + +from ..exceptions import SSLError, InsecurePlatformWarning, SNIMissingWarning +from ..packages import six + + +SSLContext = None +HAS_SNI = False +IS_PYOPENSSL = False +IS_SECURETRANSPORT = False + +# Maps the length of a digest to a possible hash function producing this digest +HASHFUNC_MAP = { +    32: md5, +    40: sha1, +    64: sha256, +} + + +def _const_compare_digest_backport(a, b): +    """ +    Compare two digests of equal length in constant time. + +    The digests must be of type str/bytes. +    Returns True if the digests match, and False otherwise. +    """ +    result = abs(len(a) - len(b)) +    for l, r in zip(bytearray(a), bytearray(b)): +        result |= l ^ r +    return result == 0 + + +_const_compare_digest = getattr(hmac, 'compare_digest', +                                _const_compare_digest_backport) + + +try:  # Test for SSL features +    import ssl +    from ssl import wrap_socket, CERT_NONE, PROTOCOL_SSLv23 +    from ssl import HAS_SNI  # Has SNI? +except ImportError: +    pass + + +try: +    from ssl import OP_NO_SSLv2, OP_NO_SSLv3, OP_NO_COMPRESSION +except ImportError: +    OP_NO_SSLv2, OP_NO_SSLv3 = 0x1000000, 0x2000000 +    OP_NO_COMPRESSION = 0x20000 + + +# Python 2.7 doesn't have inet_pton on non-Linux so we fallback on inet_aton in +# those cases. This means that we can only detect IPv4 addresses in this case. +if hasattr(socket, 'inet_pton'): +    inet_pton = socket.inet_pton +else: +    # Maybe we can use ipaddress if the user has urllib3[secure]? +    try: +        import ipaddress + +        def inet_pton(_, host): +            if isinstance(host, bytes): +                host = host.decode('ascii') +            return ipaddress.ip_address(host) + +    except ImportError:  # Platform-specific: Non-Linux +        def inet_pton(_, host): +            return socket.inet_aton(host) + + +# A secure default. +# Sources for more information on TLS ciphers: +# +# - https://wiki.mozilla.org/Security/Server_Side_TLS +# - https://www.ssllabs.com/projects/best-practices/index.html +# - https://hynek.me/articles/hardening-your-web-servers-ssl-ciphers/ +# +# The general intent is: +# - Prefer TLS 1.3 cipher suites +# - prefer cipher suites that offer perfect forward secrecy (DHE/ECDHE), +# - prefer ECDHE over DHE for better performance, +# - prefer any AES-GCM and ChaCha20 over any AES-CBC for better performance and +#   security, +# - prefer AES-GCM over ChaCha20 because hardware-accelerated AES is common, +# - disable NULL authentication, MD5 MACs and DSS for security reasons. +DEFAULT_CIPHERS = ':'.join([ +    'TLS13-AES-256-GCM-SHA384', +    'TLS13-CHACHA20-POLY1305-SHA256', +    'TLS13-AES-128-GCM-SHA256', +    'ECDH+AESGCM', +    'ECDH+CHACHA20', +    'DH+AESGCM', +    'DH+CHACHA20', +    'ECDH+AES256', +    'DH+AES256', +    'ECDH+AES128', +    'DH+AES', +    'RSA+AESGCM', +    'RSA+AES', +    '!aNULL', +    '!eNULL', +    '!MD5', +]) + +try: +    from ssl import SSLContext  # Modern SSL? +except ImportError: +    import sys + +    class SSLContext(object):  # Platform-specific: Python 2 +        def __init__(self, protocol_version): +            self.protocol = protocol_version +            # Use default values from a real SSLContext +            self.check_hostname = False +            self.verify_mode = ssl.CERT_NONE +            self.ca_certs = None +            self.options = 0 +            self.certfile = None +            self.keyfile = None +            self.ciphers = None + +        def load_cert_chain(self, certfile, keyfile): +            self.certfile = certfile +            self.keyfile = keyfile + +        def load_verify_locations(self, cafile=None, capath=None): +            self.ca_certs = cafile + +            if capath is not None: +                raise SSLError("CA directories not supported in older Pythons") + +        def set_ciphers(self, cipher_suite): +            self.ciphers = cipher_suite + +        def wrap_socket(self, socket, server_hostname=None, server_side=False): +            warnings.warn( +                'A true SSLContext object is not available. This prevents ' +                'urllib3 from configuring SSL appropriately and may cause ' +                'certain SSL connections to fail. You can upgrade to a newer ' +                'version of Python to solve this. For more information, see ' +                'https://urllib3.readthedocs.io/en/latest/advanced-usage.html' +                '#ssl-warnings', +                InsecurePlatformWarning +            ) +            kwargs = { +                'keyfile': self.keyfile, +                'certfile': self.certfile, +                'ca_certs': self.ca_certs, +                'cert_reqs': self.verify_mode, +                'ssl_version': self.protocol, +                'server_side': server_side, +            } +            return wrap_socket(socket, ciphers=self.ciphers, **kwargs) + + +def assert_fingerprint(cert, fingerprint): +    """ +    Checks if given fingerprint matches the supplied certificate. + +    :param cert: +        Certificate as bytes object. +    :param fingerprint: +        Fingerprint as string of hexdigits, can be interspersed by colons. +    """ + +    fingerprint = fingerprint.replace(':', '').lower() +    digest_length = len(fingerprint) +    hashfunc = HASHFUNC_MAP.get(digest_length) +    if not hashfunc: +        raise SSLError( +            'Fingerprint of invalid length: {0}'.format(fingerprint)) + +    # We need encode() here for py32; works on py2 and p33. +    fingerprint_bytes = unhexlify(fingerprint.encode()) + +    cert_digest = hashfunc(cert).digest() + +    if not _const_compare_digest(cert_digest, fingerprint_bytes): +        raise SSLError('Fingerprints did not match. Expected "{0}", got "{1}".' +                       .format(fingerprint, hexlify(cert_digest))) + + +def resolve_cert_reqs(candidate): +    """ +    Resolves the argument to a numeric constant, which can be passed to +    the wrap_socket function/method from the ssl module. +    Defaults to :data:`ssl.CERT_NONE`. +    If given a string it is assumed to be the name of the constant in the +    :mod:`ssl` module or its abbreviation. +    (So you can specify `REQUIRED` instead of `CERT_REQUIRED`. +    If it's neither `None` nor a string we assume it is already the numeric +    constant which can directly be passed to wrap_socket. +    """ +    if candidate is None: +        return CERT_NONE + +    if isinstance(candidate, str): +        res = getattr(ssl, candidate, None) +        if res is None: +            res = getattr(ssl, 'CERT_' + candidate) +        return res + +    return candidate + + +def resolve_ssl_version(candidate): +    """ +    like resolve_cert_reqs +    """ +    if candidate is None: +        return PROTOCOL_SSLv23 + +    if isinstance(candidate, str): +        res = getattr(ssl, candidate, None) +        if res is None: +            res = getattr(ssl, 'PROTOCOL_' + candidate) +        return res + +    return candidate + + +def create_urllib3_context(ssl_version=None, cert_reqs=None, +                           options=None, ciphers=None): +    """All arguments have the same meaning as ``ssl_wrap_socket``. + +    By default, this function does a lot of the same work that +    ``ssl.create_default_context`` does on Python 3.4+. It: + +    - Disables SSLv2, SSLv3, and compression +    - Sets a restricted set of server ciphers + +    If you wish to enable SSLv3, you can do:: + +        from urllib3.util import ssl_ +        context = ssl_.create_urllib3_context() +        context.options &= ~ssl_.OP_NO_SSLv3 + +    You can do the same to enable compression (substituting ``COMPRESSION`` +    for ``SSLv3`` in the last line above). + +    :param ssl_version: +        The desired protocol version to use. This will default to +        PROTOCOL_SSLv23 which will negotiate the highest protocol that both +        the server and your installation of OpenSSL support. +    :param cert_reqs: +        Whether to require the certificate verification. This defaults to +        ``ssl.CERT_REQUIRED``. +    :param options: +        Specific OpenSSL options. These default to ``ssl.OP_NO_SSLv2``, +        ``ssl.OP_NO_SSLv3``, ``ssl.OP_NO_COMPRESSION``. +    :param ciphers: +        Which cipher suites to allow the server to select. +    :returns: +        Constructed SSLContext object with specified options +    :rtype: SSLContext +    """ +    context = SSLContext(ssl_version or ssl.PROTOCOL_SSLv23) + +    context.set_ciphers(ciphers or DEFAULT_CIPHERS) + +    # Setting the default here, as we may have no ssl module on import +    cert_reqs = ssl.CERT_REQUIRED if cert_reqs is None else cert_reqs + +    if options is None: +        options = 0 +        # SSLv2 is easily broken and is considered harmful and dangerous +        options |= OP_NO_SSLv2 +        # SSLv3 has several problems and is now dangerous +        options |= OP_NO_SSLv3 +        # Disable compression to prevent CRIME attacks for OpenSSL 1.0+ +        # (issue #309) +        options |= OP_NO_COMPRESSION + +    context.options |= options + +    context.verify_mode = cert_reqs +    if getattr(context, 'check_hostname', None) is not None:  # Platform-specific: Python 3.2 +        # We do our own verification, including fingerprints and alternative +        # hostnames. So disable it here +        context.check_hostname = False +    return context + + +def ssl_wrap_socket(sock, keyfile=None, certfile=None, cert_reqs=None, +                    ca_certs=None, server_hostname=None, +                    ssl_version=None, ciphers=None, ssl_context=None, +                    ca_cert_dir=None): +    """ +    All arguments except for server_hostname, ssl_context, and ca_cert_dir have +    the same meaning as they do when using :func:`ssl.wrap_socket`. + +    :param server_hostname: +        When SNI is supported, the expected hostname of the certificate +    :param ssl_context: +        A pre-made :class:`SSLContext` object. If none is provided, one will +        be created using :func:`create_urllib3_context`. +    :param ciphers: +        A string of ciphers we wish the client to support. +    :param ca_cert_dir: +        A directory containing CA certificates in multiple separate files, as +        supported by OpenSSL's -CApath flag or the capath argument to +        SSLContext.load_verify_locations(). +    """ +    context = ssl_context +    if context is None: +        # Note: This branch of code and all the variables in it are no longer +        # used by urllib3 itself. We should consider deprecating and removing +        # this code. +        context = create_urllib3_context(ssl_version, cert_reqs, +                                         ciphers=ciphers) + +    if ca_certs or ca_cert_dir: +        try: +            context.load_verify_locations(ca_certs, ca_cert_dir) +        except IOError as e:  # Platform-specific: Python 2.7 +            raise SSLError(e) +        # Py33 raises FileNotFoundError which subclasses OSError +        # These are not equivalent unless we check the errno attribute +        except OSError as e:  # Platform-specific: Python 3.3 and beyond +            if e.errno == errno.ENOENT: +                raise SSLError(e) +            raise +    elif getattr(context, 'load_default_certs', None) is not None: +        # try to load OS default certs; works well on Windows (require Python3.4+) +        context.load_default_certs() + +    if certfile: +        context.load_cert_chain(certfile, keyfile) + +    # If we detect server_hostname is an IP address then the SNI +    # extension should not be used according to RFC3546 Section 3.1 +    # We shouldn't warn the user if SNI isn't available but we would +    # not be using SNI anyways due to IP address for server_hostname. +    if ((server_hostname is not None and not is_ipaddress(server_hostname)) +            or IS_SECURETRANSPORT): +        if HAS_SNI and server_hostname is not None: +            return context.wrap_socket(sock, server_hostname=server_hostname) + +        warnings.warn( +            'An HTTPS request has been made, but the SNI (Server Name ' +            'Indication) extension to TLS is not available on this platform. ' +            'This may cause the server to present an incorrect TLS ' +            'certificate, which can cause validation failures. You can upgrade to ' +            'a newer version of Python to solve this. For more information, see ' +            'https://urllib3.readthedocs.io/en/latest/advanced-usage.html' +            '#ssl-warnings', +            SNIMissingWarning +        ) + +    return context.wrap_socket(sock) + + +def is_ipaddress(hostname): +    """Detects whether the hostname given is an IP address. + +    :param str hostname: Hostname to examine. +    :return: True if the hostname is an IP address, False otherwise. +    """ +    if six.PY3 and isinstance(hostname, bytes): +        # IDN A-label bytes are ASCII compatible. +        hostname = hostname.decode('ascii') + +    families = [socket.AF_INET] +    if hasattr(socket, 'AF_INET6'): +        families.append(socket.AF_INET6) + +    for af in families: +        try: +            inet_pton(af, hostname) +        except (socket.error, ValueError, OSError): +            pass +        else: +            return True +    return False diff --git a/python/urllib3/util/timeout.py b/python/urllib3/util/timeout.py new file mode 100644 index 0000000..cec817e --- /dev/null +++ b/python/urllib3/util/timeout.py @@ -0,0 +1,242 @@ +from __future__ import absolute_import +# The default socket timeout, used by httplib to indicate that no timeout was +# specified by the user +from socket import _GLOBAL_DEFAULT_TIMEOUT +import time + +from ..exceptions import TimeoutStateError + +# A sentinel value to indicate that no timeout was specified by the user in +# urllib3 +_Default = object() + + +# Use time.monotonic if available. +current_time = getattr(time, "monotonic", time.time) + + +class Timeout(object): +    """ Timeout configuration. + +    Timeouts can be defined as a default for a pool:: + +        timeout = Timeout(connect=2.0, read=7.0) +        http = PoolManager(timeout=timeout) +        response = http.request('GET', 'http://example.com/') + +    Or per-request (which overrides the default for the pool):: + +        response = http.request('GET', 'http://example.com/', timeout=Timeout(10)) + +    Timeouts can be disabled by setting all the parameters to ``None``:: + +        no_timeout = Timeout(connect=None, read=None) +        response = http.request('GET', 'http://example.com/, timeout=no_timeout) + + +    :param total: +        This combines the connect and read timeouts into one; the read timeout +        will be set to the time leftover from the connect attempt. In the +        event that both a connect timeout and a total are specified, or a read +        timeout and a total are specified, the shorter timeout will be applied. + +        Defaults to None. + +    :type total: integer, float, or None + +    :param connect: +        The maximum amount of time to wait for a connection attempt to a server +        to succeed. Omitting the parameter will default the connect timeout to +        the system default, probably `the global default timeout in socket.py +        <http://hg.python.org/cpython/file/603b4d593758/Lib/socket.py#l535>`_. +        None will set an infinite timeout for connection attempts. + +    :type connect: integer, float, or None + +    :param read: +        The maximum amount of time to wait between consecutive +        read operations for a response from the server. Omitting +        the parameter will default the read timeout to the system +        default, probably `the global default timeout in socket.py +        <http://hg.python.org/cpython/file/603b4d593758/Lib/socket.py#l535>`_. +        None will set an infinite timeout. + +    :type read: integer, float, or None + +    .. note:: + +        Many factors can affect the total amount of time for urllib3 to return +        an HTTP response. + +        For example, Python's DNS resolver does not obey the timeout specified +        on the socket. Other factors that can affect total request time include +        high CPU load, high swap, the program running at a low priority level, +        or other behaviors. + +        In addition, the read and total timeouts only measure the time between +        read operations on the socket connecting the client and the server, +        not the total amount of time for the request to return a complete +        response. For most requests, the timeout is raised because the server +        has not sent the first byte in the specified time. This is not always +        the case; if a server streams one byte every fifteen seconds, a timeout +        of 20 seconds will not trigger, even though the request will take +        several minutes to complete. + +        If your goal is to cut off any request after a set amount of wall clock +        time, consider having a second "watcher" thread to cut off a slow +        request. +    """ + +    #: A sentinel object representing the default timeout value +    DEFAULT_TIMEOUT = _GLOBAL_DEFAULT_TIMEOUT + +    def __init__(self, total=None, connect=_Default, read=_Default): +        self._connect = self._validate_timeout(connect, 'connect') +        self._read = self._validate_timeout(read, 'read') +        self.total = self._validate_timeout(total, 'total') +        self._start_connect = None + +    def __str__(self): +        return '%s(connect=%r, read=%r, total=%r)' % ( +            type(self).__name__, self._connect, self._read, self.total) + +    @classmethod +    def _validate_timeout(cls, value, name): +        """ Check that a timeout attribute is valid. + +        :param value: The timeout value to validate +        :param name: The name of the timeout attribute to validate. This is +            used to specify in error messages. +        :return: The validated and casted version of the given value. +        :raises ValueError: If it is a numeric value less than or equal to +            zero, or the type is not an integer, float, or None. +        """ +        if value is _Default: +            return cls.DEFAULT_TIMEOUT + +        if value is None or value is cls.DEFAULT_TIMEOUT: +            return value + +        if isinstance(value, bool): +            raise ValueError("Timeout cannot be a boolean value. It must " +                             "be an int, float or None.") +        try: +            float(value) +        except (TypeError, ValueError): +            raise ValueError("Timeout value %s was %s, but it must be an " +                             "int, float or None." % (name, value)) + +        try: +            if value <= 0: +                raise ValueError("Attempted to set %s timeout to %s, but the " +                                 "timeout cannot be set to a value less " +                                 "than or equal to 0." % (name, value)) +        except TypeError:  # Python 3 +            raise ValueError("Timeout value %s was %s, but it must be an " +                             "int, float or None." % (name, value)) + +        return value + +    @classmethod +    def from_float(cls, timeout): +        """ Create a new Timeout from a legacy timeout value. + +        The timeout value used by httplib.py sets the same timeout on the +        connect(), and recv() socket requests. This creates a :class:`Timeout` +        object that sets the individual timeouts to the ``timeout`` value +        passed to this function. + +        :param timeout: The legacy timeout value. +        :type timeout: integer, float, sentinel default object, or None +        :return: Timeout object +        :rtype: :class:`Timeout` +        """ +        return Timeout(read=timeout, connect=timeout) + +    def clone(self): +        """ Create a copy of the timeout object + +        Timeout properties are stored per-pool but each request needs a fresh +        Timeout object to ensure each one has its own start/stop configured. + +        :return: a copy of the timeout object +        :rtype: :class:`Timeout` +        """ +        # We can't use copy.deepcopy because that will also create a new object +        # for _GLOBAL_DEFAULT_TIMEOUT, which socket.py uses as a sentinel to +        # detect the user default. +        return Timeout(connect=self._connect, read=self._read, +                       total=self.total) + +    def start_connect(self): +        """ Start the timeout clock, used during a connect() attempt + +        :raises urllib3.exceptions.TimeoutStateError: if you attempt +            to start a timer that has been started already. +        """ +        if self._start_connect is not None: +            raise TimeoutStateError("Timeout timer has already been started.") +        self._start_connect = current_time() +        return self._start_connect + +    def get_connect_duration(self): +        """ Gets the time elapsed since the call to :meth:`start_connect`. + +        :return: Elapsed time. +        :rtype: float +        :raises urllib3.exceptions.TimeoutStateError: if you attempt +            to get duration for a timer that hasn't been started. +        """ +        if self._start_connect is None: +            raise TimeoutStateError("Can't get connect duration for timer " +                                    "that has not started.") +        return current_time() - self._start_connect + +    @property +    def connect_timeout(self): +        """ Get the value to use when setting a connection timeout. + +        This will be a positive float or integer, the value None +        (never timeout), or the default system timeout. + +        :return: Connect timeout. +        :rtype: int, float, :attr:`Timeout.DEFAULT_TIMEOUT` or None +        """ +        if self.total is None: +            return self._connect + +        if self._connect is None or self._connect is self.DEFAULT_TIMEOUT: +            return self.total + +        return min(self._connect, self.total) + +    @property +    def read_timeout(self): +        """ Get the value for the read timeout. + +        This assumes some time has elapsed in the connection timeout and +        computes the read timeout appropriately. + +        If self.total is set, the read timeout is dependent on the amount of +        time taken by the connect timeout. If the connection time has not been +        established, a :exc:`~urllib3.exceptions.TimeoutStateError` will be +        raised. + +        :return: Value to use for the read timeout. +        :rtype: int, float, :attr:`Timeout.DEFAULT_TIMEOUT` or None +        :raises urllib3.exceptions.TimeoutStateError: If :meth:`start_connect` +            has not yet been called on this object. +        """ +        if (self.total is not None and +                self.total is not self.DEFAULT_TIMEOUT and +                self._read is not None and +                self._read is not self.DEFAULT_TIMEOUT): +            # In case the connect timeout has not yet been established. +            if self._start_connect is None: +                return self._read +            return max(0, min(self.total - self.get_connect_duration(), +                              self._read)) +        elif self.total is not None and self.total is not self.DEFAULT_TIMEOUT: +            return max(0, self.total - self.get_connect_duration()) +        else: +            return self._read diff --git a/python/urllib3/util/url.py b/python/urllib3/util/url.py new file mode 100644 index 0000000..6b6f996 --- /dev/null +++ b/python/urllib3/util/url.py @@ -0,0 +1,230 @@ +from __future__ import absolute_import +from collections import namedtuple + +from ..exceptions import LocationParseError + + +url_attrs = ['scheme', 'auth', 'host', 'port', 'path', 'query', 'fragment'] + +# We only want to normalize urls with an HTTP(S) scheme. +# urllib3 infers URLs without a scheme (None) to be http. +NORMALIZABLE_SCHEMES = ('http', 'https', None) + + +class Url(namedtuple('Url', url_attrs)): +    """ +    Datastructure for representing an HTTP URL. Used as a return value for +    :func:`parse_url`. Both the scheme and host are normalized as they are +    both case-insensitive according to RFC 3986. +    """ +    __slots__ = () + +    def __new__(cls, scheme=None, auth=None, host=None, port=None, path=None, +                query=None, fragment=None): +        if path and not path.startswith('/'): +            path = '/' + path +        if scheme: +            scheme = scheme.lower() +        if host and scheme in NORMALIZABLE_SCHEMES: +            host = host.lower() +        return super(Url, cls).__new__(cls, scheme, auth, host, port, path, +                                       query, fragment) + +    @property +    def hostname(self): +        """For backwards-compatibility with urlparse. We're nice like that.""" +        return self.host + +    @property +    def request_uri(self): +        """Absolute path including the query string.""" +        uri = self.path or '/' + +        if self.query is not None: +            uri += '?' + self.query + +        return uri + +    @property +    def netloc(self): +        """Network location including host and port""" +        if self.port: +            return '%s:%d' % (self.host, self.port) +        return self.host + +    @property +    def url(self): +        """ +        Convert self into a url + +        This function should more or less round-trip with :func:`.parse_url`. The +        returned url may not be exactly the same as the url inputted to +        :func:`.parse_url`, but it should be equivalent by the RFC (e.g., urls +        with a blank port will have : removed). + +        Example: :: + +            >>> U = parse_url('http://google.com/mail/') +            >>> U.url +            'http://google.com/mail/' +            >>> Url('http', 'username:password', 'host.com', 80, +            ... '/path', 'query', 'fragment').url +            'http://username:password@host.com:80/path?query#fragment' +        """ +        scheme, auth, host, port, path, query, fragment = self +        url = '' + +        # We use "is not None" we want things to happen with empty strings (or 0 port) +        if scheme is not None: +            url += scheme + '://' +        if auth is not None: +            url += auth + '@' +        if host is not None: +            url += host +        if port is not None: +            url += ':' + str(port) +        if path is not None: +            url += path +        if query is not None: +            url += '?' + query +        if fragment is not None: +            url += '#' + fragment + +        return url + +    def __str__(self): +        return self.url + + +def split_first(s, delims): +    """ +    Given a string and an iterable of delimiters, split on the first found +    delimiter. Return two split parts and the matched delimiter. + +    If not found, then the first part is the full input string. + +    Example:: + +        >>> split_first('foo/bar?baz', '?/=') +        ('foo', 'bar?baz', '/') +        >>> split_first('foo/bar?baz', '123') +        ('foo/bar?baz', '', None) + +    Scales linearly with number of delims. Not ideal for large number of delims. +    """ +    min_idx = None +    min_delim = None +    for d in delims: +        idx = s.find(d) +        if idx < 0: +            continue + +        if min_idx is None or idx < min_idx: +            min_idx = idx +            min_delim = d + +    if min_idx is None or min_idx < 0: +        return s, '', None + +    return s[:min_idx], s[min_idx + 1:], min_delim + + +def parse_url(url): +    """ +    Given a url, return a parsed :class:`.Url` namedtuple. Best-effort is +    performed to parse incomplete urls. Fields not provided will be None. + +    Partly backwards-compatible with :mod:`urlparse`. + +    Example:: + +        >>> parse_url('http://google.com/mail/') +        Url(scheme='http', host='google.com', port=None, path='/mail/', ...) +        >>> parse_url('google.com:80') +        Url(scheme=None, host='google.com', port=80, path=None, ...) +        >>> parse_url('/foo?bar') +        Url(scheme=None, host=None, port=None, path='/foo', query='bar', ...) +    """ + +    # While this code has overlap with stdlib's urlparse, it is much +    # simplified for our needs and less annoying. +    # Additionally, this implementations does silly things to be optimal +    # on CPython. + +    if not url: +        # Empty +        return Url() + +    scheme = None +    auth = None +    host = None +    port = None +    path = None +    fragment = None +    query = None + +    # Scheme +    if '://' in url: +        scheme, url = url.split('://', 1) + +    # Find the earliest Authority Terminator +    # (http://tools.ietf.org/html/rfc3986#section-3.2) +    url, path_, delim = split_first(url, ['/', '?', '#']) + +    if delim: +        # Reassemble the path +        path = delim + path_ + +    # Auth +    if '@' in url: +        # Last '@' denotes end of auth part +        auth, url = url.rsplit('@', 1) + +    # IPv6 +    if url and url[0] == '[': +        host, url = url.split(']', 1) +        host += ']' + +    # Port +    if ':' in url: +        _host, port = url.split(':', 1) + +        if not host: +            host = _host + +        if port: +            # If given, ports must be integers. No whitespace, no plus or +            # minus prefixes, no non-integer digits such as ^2 (superscript). +            if not port.isdigit(): +                raise LocationParseError(url) +            try: +                port = int(port) +            except ValueError: +                raise LocationParseError(url) +        else: +            # Blank ports are cool, too. (rfc3986#section-3.2.3) +            port = None + +    elif not host and url: +        host = url + +    if not path: +        return Url(scheme, auth, host, port, path, query, fragment) + +    # Fragment +    if '#' in path: +        path, fragment = path.split('#', 1) + +    # Query +    if '?' in path: +        path, query = path.split('?', 1) + +    return Url(scheme, auth, host, port, path, query, fragment) + + +def get_host(url): +    """ +    Deprecated. Use :func:`parse_url` instead. +    """ +    p = parse_url(url) +    return p.scheme or 'http', p.hostname, p.port diff --git a/python/urllib3/util/wait.py b/python/urllib3/util/wait.py new file mode 100644 index 0000000..4db71ba --- /dev/null +++ b/python/urllib3/util/wait.py @@ -0,0 +1,150 @@ +import errno +from functools import partial +import select +import sys +try: +    from time import monotonic +except ImportError: +    from time import time as monotonic + +__all__ = ["NoWayToWaitForSocketError", "wait_for_read", "wait_for_write"] + + +class NoWayToWaitForSocketError(Exception): +    pass + + +# How should we wait on sockets? +# +# There are two types of APIs you can use for waiting on sockets: the fancy +# modern stateful APIs like epoll/kqueue, and the older stateless APIs like +# select/poll. The stateful APIs are more efficient when you have a lots of +# sockets to keep track of, because you can set them up once and then use them +# lots of times. But we only ever want to wait on a single socket at a time +# and don't want to keep track of state, so the stateless APIs are actually +# more efficient. So we want to use select() or poll(). +# +# Now, how do we choose between select() and poll()? On traditional Unixes, +# select() has a strange calling convention that makes it slow, or fail +# altogether, for high-numbered file descriptors. The point of poll() is to fix +# that, so on Unixes, we prefer poll(). +# +# On Windows, there is no poll() (or at least Python doesn't provide a wrapper +# for it), but that's OK, because on Windows, select() doesn't have this +# strange calling convention; plain select() works fine. +# +# So: on Windows we use select(), and everywhere else we use poll(). We also +# fall back to select() in case poll() is somehow broken or missing. + +if sys.version_info >= (3, 5): +    # Modern Python, that retries syscalls by default +    def _retry_on_intr(fn, timeout): +        return fn(timeout) +else: +    # Old and broken Pythons. +    def _retry_on_intr(fn, timeout): +        if timeout is None: +            deadline = float("inf") +        else: +            deadline = monotonic() + timeout + +        while True: +            try: +                return fn(timeout) +            # OSError for 3 <= pyver < 3.5, select.error for pyver <= 2.7 +            except (OSError, select.error) as e: +                # 'e.args[0]' incantation works for both OSError and select.error +                if e.args[0] != errno.EINTR: +                    raise +                else: +                    timeout = deadline - monotonic() +                    if timeout < 0: +                        timeout = 0 +                    if timeout == float("inf"): +                        timeout = None +                    continue + + +def select_wait_for_socket(sock, read=False, write=False, timeout=None): +    if not read and not write: +        raise RuntimeError("must specify at least one of read=True, write=True") +    rcheck = [] +    wcheck = [] +    if read: +        rcheck.append(sock) +    if write: +        wcheck.append(sock) +    # When doing a non-blocking connect, most systems signal success by +    # marking the socket writable. Windows, though, signals success by marked +    # it as "exceptional". We paper over the difference by checking the write +    # sockets for both conditions. (The stdlib selectors module does the same +    # thing.) +    fn = partial(select.select, rcheck, wcheck, wcheck) +    rready, wready, xready = _retry_on_intr(fn, timeout) +    return bool(rready or wready or xready) + + +def poll_wait_for_socket(sock, read=False, write=False, timeout=None): +    if not read and not write: +        raise RuntimeError("must specify at least one of read=True, write=True") +    mask = 0 +    if read: +        mask |= select.POLLIN +    if write: +        mask |= select.POLLOUT +    poll_obj = select.poll() +    poll_obj.register(sock, mask) + +    # For some reason, poll() takes timeout in milliseconds +    def do_poll(t): +        if t is not None: +            t *= 1000 +        return poll_obj.poll(t) + +    return bool(_retry_on_intr(do_poll, timeout)) + + +def null_wait_for_socket(*args, **kwargs): +    raise NoWayToWaitForSocketError("no select-equivalent available") + + +def _have_working_poll(): +    # Apparently some systems have a select.poll that fails as soon as you try +    # to use it, either due to strange configuration or broken monkeypatching +    # from libraries like eventlet/greenlet. +    try: +        poll_obj = select.poll() +        _retry_on_intr(poll_obj.poll, 0) +    except (AttributeError, OSError): +        return False +    else: +        return True + + +def wait_for_socket(*args, **kwargs): +    # We delay choosing which implementation to use until the first time we're +    # called. We could do it at import time, but then we might make the wrong +    # decision if someone goes wild with monkeypatching select.poll after +    # we're imported. +    global wait_for_socket +    if _have_working_poll(): +        wait_for_socket = poll_wait_for_socket +    elif hasattr(select, "select"): +        wait_for_socket = select_wait_for_socket +    else:  # Platform-specific: Appengine. +        wait_for_socket = null_wait_for_socket +    return wait_for_socket(*args, **kwargs) + + +def wait_for_read(sock, timeout=None): +    """ Waits for reading to be available on a given socket. +    Returns True if the socket is readable, or False if the timeout expired. +    """ +    return wait_for_socket(sock, read=True, timeout=timeout) + + +def wait_for_write(sock, timeout=None): +    """ Waits for writing to be available on a given socket. +    Returns True if the socket is readable, or False if the timeout expired. +    """ +    return wait_for_socket(sock, write=True, timeout=timeout) @@ -2,16 +2,19 @@ from gevent import monkey  monkey.patch_all()  import gevent.socket -from gevent.pywsgi import WSGIServer  from youtube.youtube import youtube +from youtube import util  import http_errors +import settings + +from gevent.pywsgi import WSGIServer  import urllib +import urllib3  import socket  import socks, sockshandler  import subprocess  import re -import settings @@ -31,15 +34,14 @@ def proxy_site(env, start_response):          url += '?' + env['QUERY_STRING'] -    req = urllib.request.Request(url, headers=headers) -    if settings.route_tor: -        opener = urllib.request.build_opener(sockshandler.SocksiPyHandler(socks.PROXY_TYPE_SOCKS5, "127.0.0.1", 9150)) -        response = opener.open(req, timeout=10) -    else: -        response = urllib.request.urlopen(req, timeout=10) +    content, response = util.fetch_url(url, headers, return_response=True) + +    headers = response.getheaders() +    if isinstance(headers, urllib3._collections.HTTPHeaderDict): +        headers = headers.items() -    start_response('200 OK', response.getheaders() ) -    return response.read() +    start_response('200 OK', headers ) +    return content  site_handlers = {      'youtube.com':youtube, diff --git a/youtube/accounts.py b/youtube/accounts.py index bde9852..375bf2a 100644 --- a/youtube/accounts.py +++ b/youtube/accounts.py @@ -1,10 +1,10 @@  # Contains functions having to do with logging in +from youtube import util, html_common +import settings  import urllib  import json -from youtube import common  import re -import settings  import http.cookiejar  import io  import os @@ -106,7 +106,7 @@ def get_account_login_page(env, start_response):      '''      page = ''' -    <form action="''' + common.URL_ORIGIN + '''/login" method="POST"> +    <form action="''' + util.URL_ORIGIN + '''/login" method="POST">          <div class="form-field">              <label for="username">Username:</label>              <input type="text" id="username" name="username"> @@ -130,10 +130,10 @@ Using Tor to log in should only be done if the account was created using a proxy      </div>      ''' -    return common.yt_basic_template.substitute( +    return html_common.yt_basic_template.substitute(          page_title = "Login",          style = style, -        header = common.get_header(), +        header = html_common.get_header(),          page = page,      ).encode('utf-8') @@ -229,7 +229,7 @@ def _login(username, password, cookiejar, use_tor):      Taken from youtube-dl      """ -    login_page = common.fetch_url(_LOGIN_URL, yt_dl_headers, report_text='Downloaded login page', cookiejar_receive=cookiejar, use_tor=use_tor).decode('utf-8') +    login_page = util.fetch_url(_LOGIN_URL, yt_dl_headers, report_text='Downloaded login page', cookiejar_receive=cookiejar, use_tor=use_tor).decode('utf-8')      '''with open('debug/login_page', 'w', encoding='utf-8') as f:          f.write(login_page)'''      #print(cookiejar.as_lwp_str()) @@ -255,7 +255,7 @@ def _login(username, password, cookiejar, use_tor):              'Google-Accounts-XSRF': 1,          }          headers.update(yt_dl_headers) -        result = common.fetch_url(url, headers, report_text=note, data=data, cookiejar_send=cookiejar, cookiejar_receive=cookiejar, use_tor=use_tor).decode('utf-8') +        result = util.fetch_url(url, headers, report_text=note, data=data, cookiejar_send=cookiejar, cookiejar_receive=cookiejar, use_tor=use_tor).decode('utf-8')          #print(cookiejar.as_lwp_str())          '''with open('debug/' + note, 'w', encoding='utf-8') as f:              f.write(result)''' @@ -387,7 +387,7 @@ def _login(username, password, cookiejar, use_tor):          return False      try: -        check_cookie_results = common.fetch_url(check_cookie_url, headers=yt_dl_headers, report_text="Checked cookie", cookiejar_send=cookiejar, cookiejar_receive=cookiejar, use_tor=use_tor).decode('utf-8') +        check_cookie_results = util.fetch_url(check_cookie_url, headers=yt_dl_headers, report_text="Checked cookie", cookiejar_send=cookiejar, cookiejar_receive=cookiejar, use_tor=use_tor).decode('utf-8')      except (urllib.error.URLError, compat_http_client.HTTPException, socket.error) as err:          return False @@ -398,7 +398,7 @@ def _login(username, password, cookiejar, use_tor):          warn('Unable to log in')          return False -    select_site_page = common.fetch_url('https://m.youtube.com/select_site', headers=common.mobile_ua, report_text="Retrieved page for channel id", cookiejar_send=cookiejar, use_tor=use_tor).decode('utf-8') +    select_site_page = util.fetch_url('https://m.youtube.com/select_site', headers=util.mobile_ua, report_text="Retrieved page for channel id", cookiejar_send=cookiejar, use_tor=use_tor).decode('utf-8')      match = _CHANNEL_ID_RE.search(select_site_page)      if match is None:          warn('Failed to find channel id') diff --git a/youtube/channel.py b/youtube/channel.py index c83d7d1..55316e2 100644 --- a/youtube/channel.py +++ b/youtube/channel.py @@ -1,6 +1,6 @@  import base64 -import youtube.common as common -from youtube.common import default_multi_get, URL_ORIGIN, get_thumbnail_url, video_id +from youtube import util, yt_data_extract, html_common +  import http_errors  import urllib  import json @@ -91,7 +91,7 @@ def get_channel_tab(channel_id, page="1", sort=3, tab='videos', view=1):      url = "https://www.youtube.com/browse_ajax?ctoken=" + ctoken      print("Sending channel tab ajax request") -    content = common.fetch_url(url, common.desktop_ua + headers_1) +    content = util.fetch_url(url, util.desktop_ua + headers_1)      print("Finished recieving channel tab response")      '''with open('debug/channel_debug', 'wb') as f: @@ -110,7 +110,7 @@ def get_number_of_videos(channel_id):      # Sometimes retrieving playlist info fails with 403 for no discernable reason      try: -        response = common.fetch_url(url, common.mobile_ua + headers_pbj) +        response = util.fetch_url(url, util.mobile_ua + headers_pbj)      except urllib.error.HTTPError as e:          if e.code != 403:              raise @@ -133,20 +133,20 @@ def get_channel_id(username):      # method that gives the smallest possible response at ~10 kb      # needs to be as fast as possible      url = 'https://m.youtube.com/user/' + username + '/about?ajax=1&disable_polymer=true' -    response = common.fetch_url(url, common.mobile_ua + headers_1).decode('utf-8') +    response = util.fetch_url(url, util.mobile_ua + headers_1).decode('utf-8')      return re.search(r'"channel_id":\s*"([a-zA-Z0-9_-]*)"', response).group(1)  def grid_items_html(items, additional_info={}):      result = '''            <nav class="item-grid">\n'''      for item in items: -        result += common.renderer_html(item, additional_info) +        result += html_common.renderer_html(item, additional_info)      result += '''\n</nav>'''      return result  def list_items_html(items, additional_info={}):      result = '''                <nav class="item-list">'''      for item in items: -        result += common.renderer_html(item, additional_info) +        result += html_common.renderer_html(item, additional_info)      result += '''\n</nav>'''      return result @@ -168,11 +168,11 @@ def channel_tabs_html(channel_id, current_tab, search_box_value=''):              )          else:              result += channel_tab_template.substitute( -                href_attribute = ' href="' + URL_ORIGIN + '/channel/' + channel_id + '/' + tab_name.lower() + '"', +                href_attribute = ' href="' + util.URL_ORIGIN + '/channel/' + channel_id + '/' + tab_name.lower() + '"',                  tab_name = tab_name,              )      result += channel_search_template.substitute( -        action = URL_ORIGIN + "/channel/" + channel_id + "/search", +        action = util.URL_ORIGIN + "/channel/" + channel_id + "/search",          search_box_value = html.escape(search_box_value),      )      return result @@ -192,7 +192,7 @@ def channel_sort_buttons_html(channel_id, tab, current_sort):              )          else:              result += channel_sort_button_template.substitute( -                href_attribute=' href="' + URL_ORIGIN + '/channel/' + channel_id + '/' + tab + '?sort=' + sort_number + '"', +                href_attribute=' href="' + util.URL_ORIGIN + '/channel/' + channel_id + '/' + tab + '?sort=' + sort_number + '"',                  text = 'Sort by ' + sort_name              )      return result @@ -246,7 +246,7 @@ def channel_videos_html(polymer_json, current_page=1, current_sort=3, number_of_      items_html = grid_items_html(items, {'author': microformat['title']})      return yt_channel_items_template.substitute( -        header              = common.get_header(), +        header              = html_common.get_header(),          channel_title       = microformat['title'],          channel_id          = channel_id,          channel_tabs        = channel_tabs_html(channel_id, 'Videos'), @@ -254,7 +254,7 @@ def channel_videos_html(polymer_json, current_page=1, current_sort=3, number_of_          avatar              = '/' + microformat['thumbnail']['thumbnails'][0]['url'],          page_title          = microformat['title'] + ' - Channel',          items               = items_html, -        page_buttons        = common.page_buttons_html(current_page, math.ceil(number_of_videos/30), URL_ORIGIN + "/channel/" + channel_id + "/videos", current_query_string), +        page_buttons        = html_common.page_buttons_html(current_page, math.ceil(number_of_videos/30), util.URL_ORIGIN + "/channel/" + channel_id + "/videos", current_query_string),          number_of_results   = '{:,}'.format(number_of_videos) + " videos",      ) @@ -268,7 +268,7 @@ def channel_playlists_html(polymer_json, current_sort=3):      items_html = grid_items_html(items, {'author': microformat['title']})      return yt_channel_items_template.substitute( -        header              = common.get_header(), +        header              = html_common.get_header(),          channel_title       = microformat['title'],          channel_id          = channel_id,          channel_tabs        = channel_tabs_html(channel_id, 'Playlists'), @@ -312,25 +312,25 @@ def channel_about_page(polymer_json):          channel_links += channel_link_template.substitute(              url     = html.escape(url), -            text    = common.get_plain_text(link_json['title']), +            text    = yt_data_extract.get_plain_text(link_json['title']),          )      stats = ''      for stat_name in ('subscriberCountText', 'joinedDateText', 'viewCountText', 'country'):          try: -            stat_value = common.get_plain_text(channel_metadata[stat_name]) +            stat_value = yt_data_extract.get_plain_text(channel_metadata[stat_name])          except KeyError:              continue          else:              stats += stat_template.substitute(stat_value=stat_value)      try: -        description = common.format_text_runs(common.get_formatted_text(channel_metadata['description'])) +        description = yt_data_extract.format_text_runs(yt_data_extract.get_formatted_text(channel_metadata['description']))      except KeyError:          description = ''      return yt_channel_about_template.substitute( -        header              = common.get_header(), -        page_title          = common.get_plain_text(channel_metadata['title']) + ' - About', -        channel_title       = common.get_plain_text(channel_metadata['title']), +        header              = html_common.get_header(), +        page_title          = yt_data_extract.get_plain_text(channel_metadata['title']) + ' - About', +        channel_title       = yt_data_extract.get_plain_text(channel_metadata['title']),          avatar              = html.escape(avatar),          description         = description,          links               = channel_links, @@ -354,14 +354,14 @@ def channel_search_page(polymer_json, query, current_page=1, number_of_videos =      items_html = list_items_html(items)      return yt_channel_items_template.substitute( -        header              = common.get_header(), +        header              = html_common.get_header(),          channel_title       = html.escape(microformat['title']),          channel_id          = channel_id,          channel_tabs        = channel_tabs_html(channel_id, '', query),          avatar              = '/' + microformat['thumbnail']['thumbnails'][0]['url'],          page_title          = html.escape(query + ' - Channel search'),          items               = items_html, -        page_buttons        = common.page_buttons_html(current_page, math.ceil(number_of_videos/29), URL_ORIGIN + "/channel/" + channel_id + "/search", current_query_string), +        page_buttons        = html_common.page_buttons_html(current_page, math.ceil(number_of_videos/29), util.URL_ORIGIN + "/channel/" + channel_id + "/search", current_query_string),          number_of_results   = '',          sort_buttons        = '',      ) @@ -371,7 +371,7 @@ def get_channel_search_json(channel_id, query, page):      ctoken = proto.string(2, channel_id) + proto.string(3, params) + proto.string(11, query)      ctoken = base64.urlsafe_b64encode(proto.nested(80226972, ctoken)).decode('ascii') -    polymer_json = common.fetch_url("https://www.youtube.com/browse_ajax?ctoken=" + ctoken, common.desktop_ua + headers_1) +    polymer_json = util.fetch_url("https://www.youtube.com/browse_ajax?ctoken=" + ctoken, util.desktop_ua + headers_1)      '''with open('debug/channel_search_debug', 'wb') as f:          f.write(polymer_json)'''      polymer_json = json.loads(polymer_json) @@ -388,10 +388,10 @@ def get_channel_page(env, start_response):          tab = 'videos'      parameters = env['parameters'] -    page_number = int(common.default_multi_get(parameters, 'page', 0, default='1')) -    sort = common.default_multi_get(parameters, 'sort', 0, default='3') -    view = common.default_multi_get(parameters, 'view', 0, default='1') -    query = common.default_multi_get(parameters, 'query', 0, default='') +    page_number = int(util.default_multi_get(parameters, 'page', 0, default='1')) +    sort = util.default_multi_get(parameters, 'sort', 0, default='3') +    view = util.default_multi_get(parameters, 'view', 0, default='1') +    query = util.default_multi_get(parameters, 'query', 0, default='')      if tab == 'videos':          tasks = ( @@ -403,11 +403,11 @@ def get_channel_page(env, start_response):          result = channel_videos_html(polymer_json, page_number, sort, number_of_videos, env['QUERY_STRING'])      elif tab == 'about': -        polymer_json = common.fetch_url('https://www.youtube.com/channel/' + channel_id + '/about?pbj=1', common.desktop_ua + headers_1) +        polymer_json = util.fetch_url('https://www.youtube.com/channel/' + channel_id + '/about?pbj=1', util.desktop_ua + headers_1)          polymer_json = json.loads(polymer_json)          result = channel_about_page(polymer_json)      elif tab == 'playlists': -        polymer_json = common.fetch_url('https://www.youtube.com/channel/' + channel_id + '/playlists?pbj=1&view=1&sort=' + playlist_sort_codes[sort], common.desktop_ua + headers_1) +        polymer_json = util.fetch_url('https://www.youtube.com/channel/' + channel_id + '/playlists?pbj=1&view=1&sort=' + playlist_sort_codes[sort], util.desktop_ua + headers_1)          '''with open('debug/channel_playlists_debug', 'wb') as f:              f.write(polymer_json)'''          polymer_json = json.loads(polymer_json) @@ -447,22 +447,22 @@ def get_channel_page_general_url(env, start_response):          return b'Invalid channel url'      if page == 'videos': -        polymer_json = common.fetch_url(base_url + '/videos?pbj=1&view=0', common.desktop_ua + headers_1) +        polymer_json = util.fetch_url(base_url + '/videos?pbj=1&view=0', util.desktop_ua + headers_1)          '''with open('debug/user_page_videos', 'wb') as f:              f.write(polymer_json)'''          polymer_json = json.loads(polymer_json)          result = channel_videos_html(polymer_json)      elif page == 'about': -        polymer_json = common.fetch_url(base_url + '/about?pbj=1', common.desktop_ua + headers_1) +        polymer_json = util.fetch_url(base_url + '/about?pbj=1', util.desktop_ua + headers_1)          polymer_json = json.loads(polymer_json)          result = channel_about_page(polymer_json)      elif page == 'playlists': -        polymer_json = common.fetch_url(base_url+ '/playlists?pbj=1&view=1', common.desktop_ua + headers_1) +        polymer_json = util.fetch_url(base_url+ '/playlists?pbj=1&view=1', util.desktop_ua + headers_1)          polymer_json = json.loads(polymer_json)          result = channel_playlists_html(polymer_json)      elif page == 'search':          raise NotImplementedError() -        '''polymer_json = common.fetch_url('https://www.youtube.com/user' + username +  '/search?pbj=1&' + query_string, common.desktop_ua + headers_1) +        '''polymer_json = util.fetch_url('https://www.youtube.com/user' + username +  '/search?pbj=1&' + query_string, util.desktop_ua + headers_1)          polymer_json = json.loads(polymer_json)          return channel_search_page('''      else: diff --git a/youtube/comments.py b/youtube/comments.py index 10209e7..94b086e 100644 --- a/youtube/comments.py +++ b/youtube/comments.py @@ -1,13 +1,14 @@ +from youtube import proto, util, html_common, yt_data_extract, accounts +import settings +  import json -from youtube import proto, common, accounts  import base64 -from youtube.common import uppercase_escape, default_multi_get, format_text_runs, URL_ORIGIN, fetch_url  from string import Template  import urllib.request  import urllib  import html -import settings  import re +  comment_area_template = Template('''  <section class="comment-area">  $video-metadata @@ -130,7 +131,7 @@ def request_comments(ctoken, replies=False):      url = base_url + ctoken.replace("=", "%3D") + "&pbj=1"      for i in range(0,8):    # don't retry more than 8 times -        content = fetch_url(url, headers=mobile_headers, report_text="Retrieved comments") +        content = util.fetch_url(url, headers=mobile_headers, report_text="Retrieved comments")          if content[0:4] == b")]}'":             # random closing characters included at beginning of response for some reason              content = content[4:]          elif content[0:10] == b'\n<!DOCTYPE':   # occasionally returns html instead of json for no reason @@ -151,10 +152,10 @@ def single_comment_ctoken(video_id, comment_id):  def parse_comments_ajax(content, replies=False):      try: -        content = json.loads(uppercase_escape(content.decode('utf-8'))) +        content = json.loads(util.uppercase_escape(content.decode('utf-8')))          #print(content)          comments_raw = content['content']['continuation_contents']['contents'] -        ctoken = default_multi_get(content, 'content', 'continuation_contents', 'continuations', 0, 'continuation', default='') +        ctoken = util.default_multi_get(content, 'content', 'continuation_contents', 'continuations', 0, 'continuation', default='')          comments = []          for comment_raw in comments_raw: @@ -163,7 +164,7 @@ def parse_comments_ajax(content, replies=False):                  if comment_raw['replies'] is not None:                      reply_ctoken = comment_raw['replies']['continuations'][0]['continuation']                      comment_id, video_id = get_ids(reply_ctoken) -                    replies_url = URL_ORIGIN + '/comments?parent_id=' + comment_id + "&video_id=" + video_id +                    replies_url = util.URL_ORIGIN + '/comments?parent_id=' + comment_id + "&video_id=" + video_id                  comment_raw = comment_raw['comment']              comment = {              'author': comment_raw['author']['runs'][0]['text'], @@ -189,7 +190,7 @@ reply_count_regex = re.compile(r'(\d+)')  def parse_comments_polymer(content, replies=False):      try:          video_title = '' -        content = json.loads(uppercase_escape(content.decode('utf-8'))) +        content = json.loads(util.uppercase_escape(content.decode('utf-8')))          url = content[1]['url']          ctoken = urllib.parse.parse_qs(url[url.find('?')+1:])['ctoken'][0]          video_id = ctoken_metadata(ctoken)['video_id'] @@ -200,7 +201,7 @@ def parse_comments_polymer(content, replies=False):              comments_raw = content[1]['response']['continuationContents']['commentRepliesContinuation']['contents']              replies = True -        ctoken = default_multi_get(content, 1, 'response', 'continuationContents', 'commentSectionContinuation', 'continuations', 0, 'nextContinuationData', 'continuation', default='') +        ctoken = util.default_multi_get(content, 1, 'response', 'continuationContents', 'commentSectionContinuation', 'continuations', 0, 'nextContinuationData', 'continuation', default='')          comments = []          for comment_raw in comments_raw: @@ -219,8 +220,8 @@ def parse_comments_polymer(content, replies=False):                  if 'replies' in comment_raw:                      #reply_ctoken = comment_raw['replies']['commentRepliesRenderer']['continuations'][0]['nextContinuationData']['continuation']                      #comment_id, video_id = get_ids(reply_ctoken) -                    replies_url = URL_ORIGIN + '/comments?parent_id=' + parent_id + "&video_id=" + video_id -                    view_replies_text = common.get_plain_text(comment_raw['replies']['commentRepliesRenderer']['moreText']) +                    replies_url = util.URL_ORIGIN + '/comments?parent_id=' + parent_id + "&video_id=" + video_id +                    view_replies_text = yt_data_extract.get_plain_text(comment_raw['replies']['commentRepliesRenderer']['moreText'])                      match = reply_count_regex.search(view_replies_text)                      if match is None:                          view_replies_text = '1 reply' @@ -228,24 +229,31 @@ def parse_comments_polymer(content, replies=False):                          view_replies_text = match.group(1) + " replies"                  elif not replies:                      view_replies_text = "Reply" -                    replies_url = URL_ORIGIN + '/post_comment?parent_id=' + parent_id + "&video_id=" + video_id +                    replies_url = util.URL_ORIGIN + '/post_comment?parent_id=' + parent_id + "&video_id=" + video_id                  comment_raw = comment_raw['comment']              comment_raw = comment_raw['commentRenderer']              comment = { -            'author': common.get_plain_text(comment_raw['authorText']), -            'author_url': comment_raw['authorEndpoint']['commandMetadata']['webCommandMetadata']['url'], -            'author_channel_id': comment_raw['authorEndpoint']['browseEndpoint']['browseId'], -            'author_id': comment_raw['authorId'], +            'author_id': comment_raw.get('authorId', ''),              'author_avatar': comment_raw['authorThumbnail']['thumbnails'][0]['url'],              'likes': comment_raw['likeCount'], -            'published': common.get_plain_text(comment_raw['publishedTimeText']), +            'published': yt_data_extract.get_plain_text(comment_raw['publishedTimeText']),              'text': comment_raw['contentText'].get('runs', ''),              'view_replies_text': view_replies_text,              'replies_url': replies_url,              'video_id': video_id,              'comment_id': comment_raw['commentId'],              } + +            if 'authorText' in comment_raw:     # deleted channels have no name or channel link +                comment['author'] = yt_data_extract.get_plain_text(comment_raw['authorText']) +                comment['author_url'] = comment_raw['authorEndpoint']['commandMetadata']['webCommandMetadata']['url'] +                comment['author_channel_id'] = comment_raw['authorEndpoint']['browseEndpoint']['browseId'] +            else: +                comment['author'] = '' +                comment['author_url'] = '' +                comment['author_channel_id'] = '' +              comments.append(comment)      except Exception as e:          print('Error parsing comments: ' + str(e)) @@ -264,13 +272,13 @@ def get_comments_html(comments):              replies = reply_link_template.substitute(url=comment['replies_url'], view_replies_text=html.escape(comment['view_replies_text']))          if settings.enable_comment_avatars:              avatar = comment_avatar_template.substitute( -                author_url = URL_ORIGIN + comment['author_url'], +                author_url = util.URL_ORIGIN + comment['author_url'],                  author_avatar = '/' + comment['author_avatar'],              )          else:              avatar = ''          if comment['author_channel_id'] in accounts.accounts: -            delete_url = (URL_ORIGIN + '/delete_comment?video_id=' +            delete_url = (util.URL_ORIGIN + '/delete_comment?video_id='                  + comment['video_id']                  + '&channel_id='+ comment['author_channel_id']                  + '&author_id=' + comment['author_id'] @@ -280,14 +288,14 @@ def get_comments_html(comments):          else:              action_buttons = '' -        permalink = URL_ORIGIN + '/watch?v=' + comment['video_id'] + '&lc=' + comment['comment_id'] +        permalink = util.URL_ORIGIN + '/watch?v=' + comment['video_id'] + '&lc=' + comment['comment_id']          html_result += comment_template.substitute(              author=comment['author'], -            author_url = URL_ORIGIN + comment['author_url'], +            author_url = util.URL_ORIGIN + comment['author_url'],              avatar = avatar,              likes = str(comment['likes']) + ' likes' if str(comment['likes']) != '0' else '',              published = comment['published'], -            text = format_text_runs(comment['text']), +            text = yt_data_extract.format_text_runs(comment['text']),              datetime = '',  #TODO              replies = replies,              action_buttons = action_buttons, @@ -297,10 +305,10 @@ def get_comments_html(comments):  def video_comments(video_id, sort=0, offset=0, lc='', secret_key=''):      if settings.enable_comments: -        post_comment_url = common.URL_ORIGIN + "/post_comment?video_id=" + video_id +        post_comment_url = util.URL_ORIGIN + "/post_comment?video_id=" + video_id          post_comment_link = '''<a class="sort-button" href="''' + post_comment_url + '''">Post comment</a>''' -        other_sort_url = common.URL_ORIGIN + '/comments?ctoken=' + make_comment_ctoken(video_id, sort=1 - sort, lc=lc) +        other_sort_url = util.URL_ORIGIN + '/comments?ctoken=' + make_comment_ctoken(video_id, sort=1 - sort, lc=lc)          other_sort_name = 'newest' if sort == 0 else 'top'          other_sort_link = '''<a class="sort-button" href="''' + other_sort_url + '''">Sort by ''' + other_sort_name + '''</a>''' @@ -314,7 +322,7 @@ def video_comments(video_id, sort=0, offset=0, lc='', secret_key=''):          if ctoken == '':              more_comments_button = ''          else: -            more_comments_button = more_comments_template.substitute(url = common.URL_ORIGIN + '/comments?ctoken=' + ctoken) +            more_comments_button = more_comments_template.substitute(url = util.URL_ORIGIN + '/comments?ctoken=' + ctoken)          result = '''<section class="comments-area">\n'''          result += comment_links + '\n' @@ -350,7 +358,7 @@ comment_box_template = Template('''          <select id="account-selection" name="channel_id">  $options          </select> -        <a href="''' + common.URL_ORIGIN + '''/login" target="_blank">Add account</a> +        <a href="''' + util.URL_ORIGIN + '''/login" target="_blank">Add account</a>      </div>      <textarea name="comment_text"></textarea>      $video_id_input @@ -359,7 +367,7 @@ $options  def get_comments_page(env, start_response):      start_response('200 OK',  [('Content-type','text/html'),] )      parameters = env['parameters'] -    ctoken = default_multi_get(parameters, 'ctoken', 0, default='') +    ctoken = util.default_multi_get(parameters, 'ctoken', 0, default='')      replies = False      if not ctoken:          video_id = parameters['video_id'][0] @@ -384,17 +392,17 @@ def get_comments_page(env, start_response):              page_number = page_number,              sort = 'top' if metadata['sort'] == 0 else 'newest',              title = html.escape(comment_info['video_title']), -            url = common.URL_ORIGIN + '/watch?v=' + metadata['video_id'], +            url = util.URL_ORIGIN + '/watch?v=' + metadata['video_id'],              thumbnail = '/i.ytimg.com/vi/'+ metadata['video_id'] + '/mqdefault.jpg',          )          comment_box = comment_box_template.substitute( -            form_action= common.URL_ORIGIN + '/post_comment', +            form_action= util.URL_ORIGIN + '/post_comment',              video_id_input='''<input type="hidden" name="video_id" value="''' + metadata['video_id'] + '''">''',              post_text='Post comment',              options=comment_box_account_options(),          ) -        other_sort_url = common.URL_ORIGIN + '/comments?ctoken=' + make_comment_ctoken(metadata['video_id'], sort=1 - metadata['sort']) +        other_sort_url = util.URL_ORIGIN + '/comments?ctoken=' + make_comment_ctoken(metadata['video_id'], sort=1 - metadata['sort'])          other_sort_name = 'newest' if metadata['sort'] == 0 else 'top'          other_sort_link = '''<a class="sort-button" href="''' + other_sort_url + '''">Sort by ''' + other_sort_name + '''</a>''' @@ -408,7 +416,7 @@ def get_comments_page(env, start_response):      if ctoken == '':          more_comments_button = ''      else: -        more_comments_button = more_comments_template.substitute(url = URL_ORIGIN + '/comments?ctoken=' + ctoken) +        more_comments_button = more_comments_template.substitute(url = util.URL_ORIGIN + '/comments?ctoken=' + ctoken)      comments_area = '<section class="comments-area">\n'      comments_area += video_metadata + comment_box + comment_links + '\n'      comments_area += '<div class="comments">\n' @@ -417,7 +425,7 @@ def get_comments_page(env, start_response):      comments_area += more_comments_button + '\n'      comments_area += '</section>\n'      return yt_comments_template.substitute( -        header = common.get_header(), +        header = html_common.get_header(),          comments_area = comments_area,          page_title = page_title,      ).encode('utf-8') diff --git a/youtube/common.py b/youtube/html_common.py index cb963ce..8e65a1f 100644 --- a/youtube/common.py +++ b/youtube/html_common.py @@ -1,46 +1,8 @@  from youtube.template import Template -from youtube import local_playlist -import settings -import html +from youtube import local_playlist, yt_data_extract, util +  import json -import re -import urllib.parse -import gzip -import brotli -import time -import socks, sockshandler - -URL_ORIGIN = "/https://www.youtube.com" - - -# videos (all of type str): - -# id -# title -# url -# author -# author_url -# thumbnail -# description -# published -# duration -# likes -# dislikes -# views -# playlist_index - -# playlists: - -# id -# title -# url -# author -# author_url -# thumbnail -# description -# updated -# size -# first_video_id +import html  with open('yt_basic_template.html', 'r', encoding='utf-8') as file: @@ -139,205 +101,8 @@ medium_channel_item_template = Template('''  ''') -class HTTPAsymmetricCookieProcessor(urllib.request.BaseHandler): -    '''Separate cookiejars for receiving and sending''' -    def __init__(self, cookiejar_send=None, cookiejar_receive=None): -        import http.cookiejar -        self.cookiejar_send = cookiejar_send -        self.cookiejar_receive = cookiejar_receive - -    def http_request(self, request): -        if self.cookiejar_send is not None: -            self.cookiejar_send.add_cookie_header(request) -        return request - -    def http_response(self, request, response): -        if self.cookiejar_receive is not None: -            self.cookiejar_receive.extract_cookies(response, request) -        return response - -    https_request = http_request -    https_response = http_response - - -def decode_content(content, encoding_header): -    encodings = encoding_header.replace(' ', '').split(',') -    for encoding in reversed(encodings): -        if encoding == 'identity': -            continue -        if encoding == 'br': -            content = brotli.decompress(content) -        elif encoding == 'gzip': -            content = gzip.decompress(content) -    return content - -def fetch_url(url, headers=(), timeout=15, report_text=None, data=None, cookiejar_send=None, cookiejar_receive=None, use_tor=True): -    ''' -    When cookiejar_send is set to a CookieJar object, -     those cookies will be sent in the request (but cookies in response will not be merged into it) -    When cookiejar_receive is set to a CookieJar object, -     cookies received in the response will be merged into the object (nothing will be sent from it) -    When both are set to the same object, cookies will be sent from the object, -     and response cookies will be merged into it. -    ''' -    headers = dict(headers)     # Note: Calling dict() on a dict will make a copy -    headers['Accept-Encoding'] = 'gzip, br' - -    # prevent python version being leaked by urllib if User-Agent isn't provided -    #  (urllib will use ex. Python-urllib/3.6 otherwise) -    if 'User-Agent' not in headers and 'user-agent' not in headers and 'User-agent' not in headers: -        headers['User-Agent'] = 'Python-urllib' - -    if data is not None: -        if isinstance(data, str): -            data = data.encode('ascii') -        elif not isinstance(data, bytes): -            data = urllib.parse.urlencode(data).encode('ascii') - -    start_time = time.time() - - -    req = urllib.request.Request(url, data=data, headers=headers) - -    cookie_processor = HTTPAsymmetricCookieProcessor(cookiejar_send=cookiejar_send, cookiejar_receive=cookiejar_receive) - -    if use_tor and settings.route_tor: -        opener = urllib.request.build_opener(sockshandler.SocksiPyHandler(socks.PROXY_TYPE_SOCKS5, "127.0.0.1", 9150), cookie_processor) -    else: -        opener = urllib.request.build_opener(cookie_processor) - -    response = opener.open(req, timeout=timeout) -    response_time = time.time() - - -    content = response.read() -    read_finish = time.time() -    if report_text: -        print(report_text, '    Latency:', round(response_time - start_time,3), '    Read time:', round(read_finish - response_time,3)) -    content = decode_content(content, response.getheader('Content-Encoding', default='identity')) -    return content - -mobile_user_agent = 'Mozilla/5.0 (iPhone; CPU iPhone OS 10_3_1 like Mac OS X) AppleWebKit/603.1.30 (KHTML, like Gecko) Version/10.0 Mobile/14E304 Safari/602.1' -mobile_ua = (('User-Agent', mobile_user_agent),) -desktop_user_agent = 'Mozilla/5.0 (Windows NT 6.1; rv:52.0) Gecko/20100101 Firefox/52.0' -desktop_ua = (('User-Agent', desktop_user_agent),) -def dict_add(*dicts): -    for dictionary in dicts[1:]: -        dicts[0].update(dictionary) -    return dicts[0] -def video_id(url): -    url_parts = urllib.parse.urlparse(url) -    return urllib.parse.parse_qs(url_parts.query)['v'][0] - -def uppercase_escape(s): -     return re.sub( -         r'\\U([0-9a-fA-F]{8})', -         lambda m: chr(int(m.group(1), base=16)), s) - -def default_multi_get(object, *keys, default): -    ''' Like dict.get(), but for nested dictionaries/sequences, supporting keys or indices. Last argument is the default value to use in case of any IndexErrors or KeyErrors ''' -    try: -        for key in keys: -            object = object[key] -        return object -    except (IndexError, KeyError): -        return default - -def get_plain_text(node): -    try: -        return html.escape(node['simpleText']) -    except KeyError: -        return unformmated_text_runs(node['runs']) -         -def unformmated_text_runs(runs): -    result = '' -    for text_run in runs: -        result += html.escape(text_run["text"]) -    return result - -def format_text_runs(runs): -    if isinstance(runs, str): -        return runs -    result = '' -    for text_run in runs: -        if text_run.get("bold", False): -            result += "<b>" + html.escape(text_run["text"]) + "</b>" -        elif text_run.get('italics', False): -            result += "<i>" + html.escape(text_run["text"]) + "</i>" -        else: -            result += html.escape(text_run["text"]) -    return result - -# default, sddefault, mqdefault, hqdefault, hq720 -def get_thumbnail_url(video_id): -    return "/i.ytimg.com/vi/" + video_id + "/mqdefault.jpg" -     -def seconds_to_timestamp(seconds): -    seconds = int(seconds) -    hours, seconds = divmod(seconds,3600) -    minutes, seconds = divmod(seconds,60) -    if hours != 0: -        timestamp = str(hours) + ":" -        timestamp += str(minutes).zfill(2)  # zfill pads with zeros -    else: -        timestamp = str(minutes) - -    timestamp += ":" + str(seconds).zfill(2) -    return timestamp - - -# ----- -# HTML -# ----- - -def small_video_item_html(item): -    video_info = json.dumps({key: item[key] for key in ('id', 'title', 'author', 'duration')}) -    return small_video_item_template.substitute( -        title       = html.escape(item["title"]), -        views       = item["views"], -        author      = html.escape(item["author"]), -        duration    = item["duration"], -        url         = URL_ORIGIN + "/watch?v=" + item["id"], -        thumbnail   = get_thumbnail_url(item['id']), -        video_info  = html.escape(video_info), -    ) - -def small_playlist_item_html(item): -    return small_playlist_item_template.substitute( -        title=html.escape(item["title"]), -        size = item['size'], -        author="", -        url = URL_ORIGIN + "/playlist?list=" + item["id"], -        thumbnail= get_thumbnail_url(item['first_video_id']), -    ) - -def medium_playlist_item_html(item): -    return medium_playlist_item_template.substitute( -        title=html.escape(item["title"]), -        size = item['size'], -        author=item['author'], -        author_url= URL_ORIGIN + item['author_url'], -        url = URL_ORIGIN + "/playlist?list=" + item["id"], -        thumbnail= item['thumbnail'], -    ) - -def medium_video_item_html(medium_video_info): -    info = medium_video_info -        -    return medium_video_item_template.substitute( -            title=html.escape(info["title"]), -            views=info["views"], -            published = info["published"], -            description = format_text_runs(info["description"]), -            author=html.escape(info["author"]), -            author_url=info["author_url"], -            duration=info["duration"], -            url = URL_ORIGIN + "/watch?v=" + info["id"], -            thumbnail=info['thumbnail'], -            datetime='', # TODO -        )  header_template = Template(''' @@ -440,158 +205,28 @@ def get_header(search_box_value=""): -def get_url(node): -    try: -        return node['runs'][0]['navigationEndpoint']['commandMetadata']['webCommandMetadata']['url'] -    except KeyError: -        return node['navigationEndpoint']['commandMetadata']['webCommandMetadata']['url'] -def get_text(node): -    try: -        return node['simpleText'] -    except KeyError: -            pass -    try: -        return node['runs'][0]['text'] -    except IndexError: # empty text runs -        return '' -def get_formatted_text(node): -    try: -        return node['runs'] -    except KeyError: -        return node['simpleText'] -def get_badges(node): -    badges = [] -    for badge_node in node: -        badge = badge_node['metadataBadgeRenderer']['label'] -        if badge.lower() != 'new': -            badges.append(badge) -    return badges -def get_thumbnail(node): -    try: -        return node['thumbnails'][0]['url']     # polymer format -    except KeyError: -        return node['url']     # ajax format - -dispatch = { - -# polymer format     -    'title':                ('title',       get_text), -    'publishedTimeText':    ('published',   get_text), -    'videoId':              ('id',          lambda node: node), -    'descriptionSnippet':   ('description', get_formatted_text), -    'lengthText':           ('duration',    get_text), -    'thumbnail':            ('thumbnail',   get_thumbnail), -    'thumbnails':           ('thumbnail',   lambda node: node[0]['thumbnails'][0]['url']), - -    'viewCountText':        ('views',       get_text), -    'numVideosText':        ('size',        lambda node: get_text(node).split(' ')[0]),     # the format is "324 videos" -    'videoCountText':       ('size',        get_text), -    'playlistId':           ('id',          lambda node: node), -    'descriptionText':      ('description', get_formatted_text), - -    'subscriberCountText':  ('subscriber_count',    get_text), -    'channelId':            ('id',          lambda node: node), -    'badges':               ('badges',      get_badges), - -# ajax format -    'view_count_text':  ('views',       get_text), -    'num_videos_text':  ('size',        lambda node: get_text(node).split(' ')[0]), -    'owner_text':       ('author',      get_text), -    'owner_endpoint':   ('author_url',  lambda node: node['url']), -    'description':      ('description', get_formatted_text), -    'index':            ('playlist_index', get_text), -    'short_byline':     ('author',      get_text), -    'length':           ('duration',    get_text), -    'video_id':         ('id',          lambda node: node), -} -def renderer_info(renderer): -    try: -        info = {} -        if 'viewCountText' in renderer:     # prefer this one as it contains all the digits -            info['views'] = get_text(renderer['viewCountText']) -        elif 'shortViewCountText' in renderer: -            info['views'] = get_text(renderer['shortViewCountText']) - -        if 'ownerText' in renderer: -            info['author'] = renderer['ownerText']['runs'][0]['text'] -            info['author_url'] = renderer['ownerText']['runs'][0]['navigationEndpoint']['commandMetadata']['webCommandMetadata']['url'] -        try: -            overlays = renderer['thumbnailOverlays'] -        except KeyError: -            pass -        else: -            for overlay in overlays: -                if 'thumbnailOverlayTimeStatusRenderer' in overlay: -                    info['duration'] = get_text(overlay['thumbnailOverlayTimeStatusRenderer']['text']) -                # show renderers don't have videoCountText -                elif 'thumbnailOverlayBottomPanelRenderer' in overlay: -                    info['size'] = get_text(overlay['thumbnailOverlayBottomPanelRenderer']['text']) - -        # show renderers don't have playlistId, have to dig into the url to get it -        try: -            info['id'] = renderer['navigationEndpoint']['watchEndpoint']['playlistId'] -        except KeyError: -            pass -        for key, node in renderer.items(): -            if key in ('longBylineText', 'shortBylineText'): -                info['author'] = get_text(node) -                try: -                    info['author_url'] = get_url(node) -                except KeyError: -                    pass - -            # show renderers don't have thumbnail key at top level, dig into thumbnailRenderer -            elif key == 'thumbnailRenderer' and 'showCustomThumbnailRenderer' in node: -                info['thumbnail'] = node['showCustomThumbnailRenderer']['thumbnail']['thumbnails'][0]['url'] -            else: -                try: -                    simple_key, function = dispatch[key] -                except KeyError: -                    continue -                info[simple_key] = function(node) -        return info -    except KeyError: -        print(renderer) -        raise -     -def ajax_info(item_json): -    try: -        info = {}           -        for key, node in item_json.items(): -            try: -                simple_key, function = dispatch[key] -            except KeyError: -                continue -            info[simple_key] = function(node) -        return info -    except KeyError: -        print(item_json) -        raise -     +  def badges_html(badges):      return ' | '.join(map(html.escape, badges)) - - -  html_transform_dispatch = {      'title':        html.escape,      'published':    html.escape,      'id':           html.escape, -    'description':  format_text_runs, +    'description':  yt_data_extract.format_text_runs,      'duration':     html.escape,      'thumbnail':    lambda url: html.escape('/' + url.lstrip('/')),      'size':         html.escape,      'author':       html.escape, -    'author_url':   lambda url: html.escape(URL_ORIGIN + url), +    'author_url':   lambda url: html.escape(util.URL_ORIGIN + url),      'views':        html.escape,      'subscriber_count': html.escape,      'badges':       badges_html, @@ -645,7 +280,7 @@ def video_item_html(item, template, html_exclude=set()):      html_ready = get_html_ready(item)      html_ready['video_info'] = html.escape(json.dumps(video_info) ) -    html_ready['url'] = URL_ORIGIN + "/watch?v=" + html_ready['id'] +    html_ready['url'] = util.URL_ORIGIN + "/watch?v=" + html_ready['id']      html_ready['datetime'] = '' #TODO      for key in html_exclude: @@ -658,7 +293,7 @@ def video_item_html(item, template, html_exclude=set()):  def playlist_item_html(item, template, html_exclude=set()):      html_ready = get_html_ready(item) -    html_ready['url'] = URL_ORIGIN + "/playlist?list=" + html_ready['id'] +    html_ready['url'] = util.URL_ORIGIN + "/playlist?list=" + html_ready['id']      html_ready['datetime'] = '' #TODO      for key in html_exclude: @@ -672,10 +307,6 @@ def playlist_item_html(item, template, html_exclude=set()): -def update_query_string(query_string, items): -    parameters = urllib.parse.parse_qs(query_string) -    parameters.update(items) -    return urllib.parse.urlencode(parameters, doseq=True)  page_button_template = Template('''<a class="page-button" href="$href">$page</a>''')  current_page_button_template = Template('''<div class="page-button">$page</div>''') @@ -694,7 +325,7 @@ def page_buttons_html(current_page, estimated_pages, url, current_query_string):              template = current_page_button_template          else:              template = page_button_template -        result += template.substitute(page=page, href = url + "?" + update_query_string(current_query_string, {'page': [str(page)]}) ) +        result += template.substitute(page=page, href = url + "?" + util.update_query_string(current_query_string, {'page': [str(page)]}) )      return result @@ -723,15 +354,15 @@ def renderer_html(renderer, additional_info={}, current_query_string=''):          return renderer_html(renderer['contents'][0], additional_info, current_query_string)      if type == 'channelRenderer': -        info = renderer_info(renderer) +        info = yt_data_extract.renderer_info(renderer)          html_ready = get_html_ready(info) -        html_ready['url'] = URL_ORIGIN + "/channel/" + html_ready['id'] +        html_ready['url'] = util.URL_ORIGIN + "/channel/" + html_ready['id']          return medium_channel_item_template.substitute(html_ready)      if type in ('movieRenderer', 'clarificationRenderer'):          return '' -    info = renderer_info(renderer) +    info = yt_data_extract.renderer_info(renderer)      info.update(additional_info)      html_exclude = set(additional_info.keys())      if type == 'compactVideoRenderer': @@ -745,4 +376,4 @@ def renderer_html(renderer, additional_info={}, current_query_string=''):      #print(renderer)      #raise NotImplementedError('Unknown renderer type: ' + type) -    return '' +    return ''
\ No newline at end of file diff --git a/youtube/local_playlist.py b/youtube/local_playlist.py index 0375040..e354013 100644 --- a/youtube/local_playlist.py +++ b/youtube/local_playlist.py @@ -1,11 +1,12 @@ +from youtube.template import Template +from youtube import util, html_common +import settings +  import os  import json -from youtube.template import Template -from youtube import common  import html  import gevent  import urllib -import settings  playlists_directory = os.path.join(settings.data_dir, "playlists")  thumbnails_directory = os.path.join(settings.data_dir, "playlist_thumbnails") @@ -38,7 +39,7 @@ def download_thumbnail(playlist_name, video_id):      url = "https://i.ytimg.com/vi/" + video_id + "/mqdefault.jpg"      save_location = os.path.join(thumbnails_directory, playlist_name, video_id + ".jpg")      try: -        thumbnail = common.fetch_url(url, report_text="Saved local playlist thumbnail: " + video_id) +        thumbnail = util.fetch_url(url, report_text="Saved local playlist thumbnail: " + video_id)      except urllib.error.HTTPError as e:          print("Failed to download thumbnail for " + video_id + ": " + str(e))          return @@ -78,15 +79,15 @@ def get_local_playlist_page(name):              if info['id'] + ".jpg" in thumbnails:                  info['thumbnail'] = "/youtube.com/data/playlist_thumbnails/" + name + "/" + info['id'] + ".jpg"              else: -                info['thumbnail'] = common.get_thumbnail_url(info['id']) +                info['thumbnail'] = util.get_thumbnail_url(info['id'])                  missing_thumbnails.append(info['id']) -            videos_html += common.video_item_html(info, common.small_video_item_template) +            videos_html += html_common.video_item_html(info, html_common.small_video_item_template)          except json.decoder.JSONDecodeError:              pass      gevent.spawn(download_thumbnails, name, missing_thumbnails)      return local_playlist_template.substitute(          page_title = name + ' - Local playlist', -        header = common.get_header(), +        header = html_common.get_header(),          videos = videos_html,          title = name,          page_buttons = '' @@ -127,11 +128,11 @@ def get_playlists_list_page():      page = '''<ul>\n'''      list_item_template = Template('''    <li><a href="$url">$name</a></li>\n''')      for name in get_playlist_names(): -        page += list_item_template.substitute(url = html.escape(common.URL_ORIGIN + '/playlists/' + name), name = html.escape(name)) +        page += list_item_template.substitute(url = html.escape(util.URL_ORIGIN + '/playlists/' + name), name = html.escape(name))      page += '''</ul>\n''' -    return common.yt_basic_template.substitute( +    return html_common.yt_basic_template.substitute(          page_title = "Local playlists", -        header = common.get_header(), +        header = html_common.get_header(),          style = '',          page = page,      ) @@ -151,7 +152,7 @@ def path_edit_playlist(env, start_response):      if parameters['action'][0] == 'remove':          playlist_name = env['path_parts'][1]          remove_from_playlist(playlist_name, parameters['video_info_list']) -        start_response('303 See Other', [('Location', common.URL_ORIGIN + env['PATH_INFO']),] ) +        start_response('303 See Other', [('Location', util.URL_ORIGIN + env['PATH_INFO']),] )          return b''      else: diff --git a/youtube/playlist.py b/youtube/playlist.py index cc0da33..fbe6448 100644 --- a/youtube/playlist.py +++ b/youtube/playlist.py @@ -1,14 +1,14 @@ +from youtube import util, yt_data_extract, html_common, template, proto +  import base64 -import youtube.common as common  import urllib  import json -from string import Template -import youtube.proto as proto +import string  import gevent  import math  with open("yt_playlist_template.html", "r") as file: -    yt_playlist_template = Template(file.read()) +    yt_playlist_template = template.Template(file.read()) @@ -48,10 +48,10 @@ headers_1 = (  def playlist_first_page(playlist_id, report_text = "Retrieved playlist"):      url = 'https://m.youtube.com/playlist?list=' + playlist_id + '&pbj=1' -    content = common.fetch_url(url, common.mobile_ua + headers_1, report_text=report_text) +    content = util.fetch_url(url, util.mobile_ua + headers_1, report_text=report_text)      '''with open('debug/playlist_debug', 'wb') as f:          f.write(content)''' -    content = json.loads(common.uppercase_escape(content.decode('utf-8'))) +    content = json.loads(util.uppercase_escape(content.decode('utf-8')))      return content @@ -68,15 +68,15 @@ def get_videos(playlist_id, page):          'X-YouTube-Client-Version': '2.20180508',      } -    content = common.fetch_url(url, headers, report_text="Retrieved playlist") +    content = util.fetch_url(url, headers, report_text="Retrieved playlist")      '''with open('debug/playlist_debug', 'wb') as f:          f.write(content)''' -    info = json.loads(common.uppercase_escape(content.decode('utf-8'))) +    info = json.loads(util.uppercase_escape(content.decode('utf-8')))      return info -playlist_stat_template = Template(''' +playlist_stat_template = string.Template('''  <div>$stat</div>''')  def get_playlist_page(env, start_response):      start_response('200 OK', [('Content-type','text/html'),]) @@ -100,22 +100,22 @@ def get_playlist_page(env, start_response):          video_list = this_page_json['response']['continuationContents']['playlistVideoListContinuation']['contents']      videos_html = ''      for video_json in video_list: -        info = common.renderer_info(video_json['playlistVideoRenderer']) -        videos_html += common.video_item_html(info, common.small_video_item_template) +        info = yt_data_extract.renderer_info(video_json['playlistVideoRenderer']) +        videos_html += html_common.video_item_html(info, html_common.small_video_item_template) -    metadata = common.renderer_info(first_page_json['response']['header']['playlistHeaderRenderer']) +    metadata = yt_data_extract.renderer_info(first_page_json['response']['header']['playlistHeaderRenderer'])      video_count = int(metadata['size'].replace(',', '')) -    page_buttons = common.page_buttons_html(int(page), math.ceil(video_count/20), common.URL_ORIGIN + "/playlist", env['QUERY_STRING']) +    page_buttons = html_common.page_buttons_html(int(page), math.ceil(video_count/20), util.URL_ORIGIN + "/playlist", env['QUERY_STRING']) -    html_ready = common.get_html_ready(metadata) +    html_ready = html_common.get_html_ready(metadata)      html_ready['page_title'] = html_ready['title'] + ' - Page ' + str(page)      stats = ''      stats += playlist_stat_template.substitute(stat=html_ready['size'] + ' videos')      stats += playlist_stat_template.substitute(stat=html_ready['views'])      return yt_playlist_template.substitute( -        header          = common.get_header(), +        header          = html_common.get_header(),          videos          = videos_html,          page_buttons    = page_buttons,          stats = stats, diff --git a/youtube/post_comment.py b/youtube/post_comment.py index 92c45e1..876a1c0 100644 --- a/youtube/post_comment.py +++ b/youtube/post_comment.py @@ -1,11 +1,11 @@  # Contains functions having to do with posting/editing/deleting comments +from youtube import util, html_common, proto, comments, accounts +import settings  import urllib  import json -from youtube import common, proto, comments, accounts  import re  import traceback -import settings  import os  def _post_comment(text, video_id, session_token, cookiejar): @@ -31,7 +31,7 @@ def _post_comment(text, video_id, session_token, cookiejar):      data = urllib.parse.urlencode(data_dict).encode() -    content = common.fetch_url("https://m.youtube.com/service_ajax?name=createCommentEndpoint", headers=headers, data=data, cookiejar_send=cookiejar) +    content = util.fetch_url("https://m.youtube.com/service_ajax?name=createCommentEndpoint", headers=headers, data=data, cookiejar_send=cookiejar)      code = json.loads(content)['code']      print("Comment posting code: " + code) @@ -62,7 +62,7 @@ def _post_comment_reply(text, video_id, parent_comment_id, session_token, cookie      }      data = urllib.parse.urlencode(data_dict).encode() -    content = common.fetch_url("https://m.youtube.com/service_ajax?name=createCommentReplyEndpoint", headers=headers, data=data, cookiejar_send=cookiejar) +    content = util.fetch_url("https://m.youtube.com/service_ajax?name=createCommentReplyEndpoint", headers=headers, data=data, cookiejar_send=cookiejar)      code = json.loads(content)['code']      print("Comment posting code: " + code) @@ -90,7 +90,7 @@ def _delete_comment(video_id, comment_id, author_id, session_token, cookiejar):      }      data = urllib.parse.urlencode(data_dict).encode() -    content = common.fetch_url("https://m.youtube.com/service_ajax?name=performCommentActionEndpoint", headers=headers, data=data, cookiejar_send=cookiejar) +    content = util.fetch_url("https://m.youtube.com/service_ajax?name=performCommentActionEndpoint", headers=headers, data=data, cookiejar_send=cookiejar)      code = json.loads(content)['code']      print("Comment deletion code: " + code)      return code @@ -101,8 +101,8 @@ def get_session_token(video_id, cookiejar):      # youtube-dl uses disable_polymer=1 which uses a different request format which has an obfuscated javascript algorithm to generate a parameter called "bgr"      # Tokens retrieved from disable_polymer pages only work with that format. Tokens retrieved on mobile only work using mobile requests      # Additionally, tokens retrieved without sending the same cookie won't work. So this is necessary even if the bgr and stuff was reverse engineered. -    headers = {'User-Agent': common.mobile_user_agent} -    mobile_page = common.fetch_url('https://m.youtube.com/watch?v=' + video_id, headers, report_text="Retrieved session token for comment", cookiejar_send=cookiejar, cookiejar_receive=cookiejar).decode() +    headers = {'User-Agent': util.mobile_user_agent} +    mobile_page = util.fetch_url('https://m.youtube.com/watch?v=' + video_id, headers, report_text="Retrieved session token for comment", cookiejar_send=cookiejar, cookiejar_receive=cookiejar).decode()      match = xsrf_token_regex.search(mobile_page)      if match:          return match.group(1).replace("%3D", "=") @@ -118,9 +118,9 @@ def delete_comment(env, start_response):      code = _delete_comment(video_id, parameters['comment_id'][0], parameters['author_id'][0], token, cookiejar)      if code == "SUCCESS": -        start_response('303 See Other',  [('Location', common.URL_ORIGIN + '/comment_delete_success'),] ) +        start_response('303 See Other',  [('Location', util.URL_ORIGIN + '/comment_delete_success'),] )      else: -        start_response('303 See Other',  [('Location', common.URL_ORIGIN + '/comment_delete_fail'),] ) +        start_response('303 See Other',  [('Location', util.URL_ORIGIN + '/comment_delete_fail'),] )  def post_comment(env, start_response):      parameters = env['parameters'] @@ -131,11 +131,11 @@ def post_comment(env, start_response):      if 'parent_id' in parameters:          code = _post_comment_reply(parameters['comment_text'][0], parameters['video_id'][0], parameters['parent_id'][0], token, cookiejar) -        start_response('303 See Other',  (('Location', common.URL_ORIGIN + '/comments?' + env['QUERY_STRING']),) ) +        start_response('303 See Other',  (('Location', util.URL_ORIGIN + '/comments?' + env['QUERY_STRING']),) )      else:          code = _post_comment(parameters['comment_text'][0], parameters['video_id'][0], token, cookiejar) -        start_response('303 See Other',  (('Location', common.URL_ORIGIN + '/comments?ctoken=' + comments.make_comment_ctoken(video_id, sort=1)),) ) +        start_response('303 See Other',  (('Location', util.URL_ORIGIN + '/comments?ctoken=' + comments.make_comment_ctoken(video_id, sort=1)),) )      return b'' @@ -163,10 +163,10 @@ def get_delete_comment_page(env, start_response):      page += '''          <input type="submit" value="Yes, delete it">      </form>''' -    return common.yt_basic_template.substitute( +    return html_common.yt_basic_template.substitute(          page_title = "Delete comment?",          style = style, -        header = common.get_header(), +        header = html_common.get_header(),          page = page,      ).encode('utf-8') @@ -174,7 +174,7 @@ def get_post_comment_page(env, start_response):      start_response('200 OK', [('Content-type','text/html'),])      parameters = env['parameters']      video_id = parameters['video_id'][0] -    parent_id = common.default_multi_get(parameters, 'parent_id', 0, default='') +    parent_id = util.default_multi_get(parameters, 'parent_id', 0, default='')      style = ''' main{      display: grid; @@ -194,23 +194,23 @@ textarea{  }'''      if parent_id:   # comment reply          comment_box = comments.comment_box_template.substitute( -            form_action = common.URL_ORIGIN + '/comments?parent_id=' + parent_id + "&video_id=" + video_id, +            form_action = util.URL_ORIGIN + '/comments?parent_id=' + parent_id + "&video_id=" + video_id,              video_id_input = '',              post_text = "Post reply",              options=comments.comment_box_account_options(),          )      else:          comment_box = comments.comment_box_template.substitute( -            form_action = common.URL_ORIGIN + '/post_comment', +            form_action = util.URL_ORIGIN + '/post_comment',              video_id_input = '''<input type="hidden" name="video_id" value="''' + video_id + '''">''',              post_text = "Post comment",              options=comments.comment_box_account_options(),          )      page = '''<div class="left">\n''' + comment_box + '''</div>\n''' -    return common.yt_basic_template.substitute( +    return html_common.yt_basic_template.substitute(          page_title = "Post comment reply" if parent_id else "Post a comment",          style = style, -        header = common.get_header(), +        header = html_common.get_header(),          page = page,      ).encode('utf-8') diff --git a/youtube/proto.py b/youtube/proto.py index 004375a..d966455 100644 --- a/youtube/proto.py +++ b/youtube/proto.py @@ -60,7 +60,7 @@ def unpadded_b64encode(data):  def as_bytes(value):      if isinstance(value, str): -        return value.encode('ascii') +        return value.encode('utf-8')      return value diff --git a/youtube/search.py b/youtube/search.py index db65eaa..0cef0f3 100644 --- a/youtube/search.py +++ b/youtube/search.py @@ -1,11 +1,12 @@ +from youtube import util, html_common, yt_data_extract, proto +  import json  import urllib  import html  from string import Template  import base64  from math import ceil -from youtube.common import default_multi_get, get_thumbnail_url, URL_ORIGIN -from youtube import common, proto +  with open("yt_search_results_template.html", "r") as file:      yt_search_results_template = file.read() @@ -54,7 +55,7 @@ def get_search_json(query, page, autocorrect, sort, filters):          'X-YouTube-Client-Version': '2.20180418',      }      url += "&pbj=1&sp=" + page_number_to_sp_parameter(page, autocorrect, sort, filters).replace("=", "%3D") -    content = common.fetch_url(url, headers=headers, report_text="Got search results") +    content = util.fetch_url(url, headers=headers, report_text="Got search results")      info = json.loads(content)      return info @@ -70,9 +71,9 @@ def get_search_page(env, start_response):      start_response('200 OK', [('Content-type','text/html'),])      parameters = env['parameters']      if len(parameters) == 0: -        return common.yt_basic_template.substitute( +        return html_common.yt_basic_template.substitute(              page_title = "Search", -            header = common.get_header(), +            header = html_common.get_header(),              style = '',              page = '',          ).encode('utf-8') @@ -100,24 +101,24 @@ def get_search_page(env, start_response):              renderer = renderer[type]              corrected_query_string = parameters.copy()              corrected_query_string['query'] = [renderer['correctedQueryEndpoint']['searchEndpoint']['query']] -            corrected_query_url = URL_ORIGIN + '/search?' + urllib.parse.urlencode(corrected_query_string, doseq=True) +            corrected_query_url = util.URL_ORIGIN + '/search?' + urllib.parse.urlencode(corrected_query_string, doseq=True)              corrections = did_you_mean.substitute(                  corrected_query_url = corrected_query_url, -                corrected_query = common.format_text_runs(renderer['correctedQuery']['runs']), +                corrected_query = yt_data_extract.format_text_runs(renderer['correctedQuery']['runs']),              )              continue          if type == 'showingResultsForRenderer':              renderer = renderer[type]              no_autocorrect_query_string = parameters.copy()              no_autocorrect_query_string['autocorrect'] = ['0'] -            no_autocorrect_query_url = URL_ORIGIN + '/search?' + urllib.parse.urlencode(no_autocorrect_query_string, doseq=True) +            no_autocorrect_query_url = util.URL_ORIGIN + '/search?' + urllib.parse.urlencode(no_autocorrect_query_string, doseq=True)              corrections = showing_results_for.substitute( -                corrected_query = common.format_text_runs(renderer['correctedQuery']['runs']), +                corrected_query = yt_data_extract.format_text_runs(renderer['correctedQuery']['runs']),                  original_query_url = no_autocorrect_query_url,                  original_query = html.escape(renderer['originalQuery']['simpleText']),              )              continue -        result_list_html += common.renderer_html(renderer, current_query_string=env['QUERY_STRING']) +        result_list_html += html_common.renderer_html(renderer, current_query_string=env['QUERY_STRING'])      page = int(page)      if page <= 5: @@ -129,13 +130,13 @@ def get_search_page(env, start_response):      result = Template(yt_search_results_template).substitute( -        header              = common.get_header(query), +        header              = html_common.get_header(query),          results             = result_list_html,           page_title          = query + " - Search",           search_box_value    = html.escape(query),          number_of_results   = '{:,}'.format(estimated_results),          number_of_pages     = '{:,}'.format(estimated_pages), -        page_buttons        = common.page_buttons_html(page, estimated_pages, URL_ORIGIN + "/search", env['QUERY_STRING']), +        page_buttons        = html_common.page_buttons_html(page, estimated_pages, util.URL_ORIGIN + "/search", env['QUERY_STRING']),          corrections         = corrections          )      return result.encode('utf-8') diff --git a/youtube/subscriptions.py b/youtube/subscriptions.py index ff7d0df..0c7e8a5 100644 --- a/youtube/subscriptions.py +++ b/youtube/subscriptions.py @@ -1,4 +1,4 @@ -from youtube import common, channel +from youtube import util, yt_data_extract, html_common, channel  import settings  from string import Template  import sqlite3 @@ -169,7 +169,7 @@ def _get_upstream_videos(channel_id, time_last_checked):      content = response.read()      print('Retrieved videos for ' + channel_id) -    content = common.decode_content(content, response.getheader('Content-Encoding', default='identity')) +    content = util.decode_content(content, response.getheader('Content-Encoding', default='identity'))      feed = atoma.parse_atom_bytes(content) @@ -191,7 +191,7 @@ def _get_upstream_videos(channel_id, time_last_checked):      # Now check channel page to retrieve missing information for videos      json_channel_videos = channel.get_grid_items(channel.get_channel_tab(channel_id)[1]['response'])      for json_video in json_channel_videos: -        info = common.renderer_info(json_video['gridVideoRenderer']) +        info = yt_data_extract.renderer_info(json_video['gridVideoRenderer'])          if 'description' not in info:              info['description'] = ''          if info['id'] in atom_videos: @@ -205,12 +205,12 @@ def get_subscriptions_page(env, start_response):      items_html = '''<nav class="item-grid">\n'''      for item in _get_videos(30, 0): -        items_html += common.video_item_html(item, common.small_video_item_template) +        items_html += html_common.video_item_html(item, html_common.small_video_item_template)      items_html += '''\n</nav>'''      start_response('200 OK', [('Content-type','text/html'),])      return subscriptions_template.substitute( -        header = common.get_header(), +        header = html_common.get_header(),          items = items_html,          page_buttons = '',      ).encode('utf-8') @@ -243,7 +243,7 @@ def post_subscriptions_page(env, start_response):          finally:              connection.close() -        start_response('303 See Other', [('Location', common.URL_ORIGIN + '/subscriptions'),] ) +        start_response('303 See Other', [('Location', util.URL_ORIGIN + '/subscriptions'),] )          return b''      else:          start_response('400 Bad Request', ()) diff --git a/youtube/util.py b/youtube/util.py new file mode 100644 index 0000000..9950815 --- /dev/null +++ b/youtube/util.py @@ -0,0 +1,229 @@ +import settings +import socks, sockshandler +import gzip +import brotli +import urllib.parse +import re +import time + +# The trouble with the requests library: It ships its own certificate bundle via certifi +#  instead of using the system certificate store, meaning self-signed certificates +#  configured by the user will not work. Some draconian networks block TLS unless a corporate +#  certificate is installed on the system. Additionally, some users install a self signed cert +#  in order to use programs to modify or monitor requests made by programs on the system. + +# Finally, certificates expire and need to be updated, or are sometimes revoked. Sometimes +#  certificate authorites go rogue and need to be untrusted. Since we are going through Tor exit nodes, +#  this becomes all the more important. A rogue CA could issue a fake certificate for accounts.google.com, and a +#  malicious exit node could use this to decrypt traffic when logging in and retrieve passwords. Examples: +#   https://www.engadget.com/2015/10/29/google-warns-symantec-over-certificates/ +#   https://nakedsecurity.sophos.com/2013/12/09/serious-security-google-finds-fake-but-trusted-ssl-certificates-for-its-domains-made-in-france/ + +# In the requests documentation it says: +#    "Before version 2.16, Requests bundled a set of root CAs that it trusted, sourced from the Mozilla trust store. +#     The certificates were only updated once for each Requests version. When certifi was not installed, +#     this led to extremely out-of-date certificate bundles when using significantly older versions of Requests. +#     For the sake of security we recommend upgrading certifi frequently!" +#   (http://docs.python-requests.org/en/master/user/advanced/#ca-certificates) + +# Expecting users to remember to manually update certifi on Linux isn't reasonable in my view. +#  On windows, this is even worse since I am distributing all dependencies. This program is not +#  updated frequently, and using requests would lead to outdated certificates. Certificates +#  should be updated with OS updates, instead of thousands of developers of different programs +#  being expected to do this correctly 100% of the time. + +# There is hope that this might be fixed eventually: +#   https://github.com/kennethreitz/requests/issues/2966 + +# Until then, I will use a mix of urllib3 and urllib. +import urllib3 +import urllib3.contrib.socks + +URL_ORIGIN = "/https://www.youtube.com" + +connection_pool = urllib3.PoolManager(cert_reqs = 'CERT_REQUIRED') + +old_tor_connection_pool = None +tor_connection_pool = urllib3.contrib.socks.SOCKSProxyManager('socks5://127.0.0.1:9150/', cert_reqs = 'CERT_REQUIRED') + +tor_pool_refresh_time = time.monotonic()   # prevent problems due to clock changes + +def get_pool(use_tor): +    global old_tor_connection_pool +    global tor_connection_pool +    global tor_pool_refresh_time + +    if not use_tor: +        return connection_pool + +    # Tor changes circuits after 10 minutes: https://tor.stackexchange.com/questions/262/for-how-long-does-a-circuit-stay-alive +    current_time = time.monotonic() +    if current_time - tor_pool_refresh_time > 300:   # close pool after 5 minutes +        tor_connection_pool.clear() + +        # Keep a reference for 5 min to avoid it getting garbage collected while sockets still in use +        old_tor_connection_pool = tor_connection_pool + +        tor_connection_pool = urllib3.contrib.socks.SOCKSProxyManager('socks5://127.0.0.1:9150/', cert_reqs = 'CERT_REQUIRED') +        tor_pool_refresh_time = current_time + +    return tor_connection_pool + + + +class HTTPAsymmetricCookieProcessor(urllib.request.BaseHandler): +    '''Separate cookiejars for receiving and sending''' +    def __init__(self, cookiejar_send=None, cookiejar_receive=None): +        import http.cookiejar +        self.cookiejar_send = cookiejar_send +        self.cookiejar_receive = cookiejar_receive + +    def http_request(self, request): +        if self.cookiejar_send is not None: +            self.cookiejar_send.add_cookie_header(request) +        return request + +    def http_response(self, request, response): +        if self.cookiejar_receive is not None: +            self.cookiejar_receive.extract_cookies(response, request) +        return response + +    https_request = http_request +    https_response = http_response + + +def decode_content(content, encoding_header): +    encodings = encoding_header.replace(' ', '').split(',') +    for encoding in reversed(encodings): +        if encoding == 'identity': +            continue +        if encoding == 'br': +            content = brotli.decompress(content) +        elif encoding == 'gzip': +            content = gzip.decompress(content) +    return content + +def fetch_url(url, headers=(), timeout=15, report_text=None, data=None, cookiejar_send=None, cookiejar_receive=None, use_tor=True, return_response=False): +    ''' +    When cookiejar_send is set to a CookieJar object, +     those cookies will be sent in the request (but cookies in response will not be merged into it) +    When cookiejar_receive is set to a CookieJar object, +     cookies received in the response will be merged into the object (nothing will be sent from it) +    When both are set to the same object, cookies will be sent from the object, +     and response cookies will be merged into it. +    ''' +    headers = dict(headers)     # Note: Calling dict() on a dict will make a copy +    headers['Accept-Encoding'] = 'gzip, br' + +    # prevent python version being leaked by urllib if User-Agent isn't provided +    #  (urllib will use ex. Python-urllib/3.6 otherwise) +    if 'User-Agent' not in headers and 'user-agent' not in headers and 'User-agent' not in headers: +        headers['User-Agent'] = 'Python-urllib' + +    method = "GET" +    if data is not None: +        method = "POST" +        if isinstance(data, str): +            data = data.encode('ascii') +        elif not isinstance(data, bytes): +            data = urllib.parse.urlencode(data).encode('ascii') + +    start_time = time.time() + +    if cookiejar_send is not None or cookiejar_receive is not None:     # Use urllib +        req = urllib.request.Request(url, data=data, headers=headers) + +        cookie_processor = HTTPAsymmetricCookieProcessor(cookiejar_send=cookiejar_send, cookiejar_receive=cookiejar_receive) + +        if use_tor and settings.route_tor: +            opener = urllib.request.build_opener(sockshandler.SocksiPyHandler(socks.PROXY_TYPE_SOCKS5, "127.0.0.1", 9150), cookie_processor) +        else: +            opener = urllib.request.build_opener(cookie_processor) + +        response = opener.open(req, timeout=timeout) +        response_time = time.time() + + +        content = response.read() + +    else:           # Use a urllib3 pool. Cookies can't be used since urllib3 doesn't have easy support for them. +        pool = get_pool(use_tor and settings.route_tor) + +        response = pool.request(method, url, headers=headers, timeout=timeout, preload_content=False, decode_content=False) +        response_time = time.time() + +        content = response.read() +        response.release_conn() + +    read_finish = time.time() +    if report_text: +        print(report_text, '    Latency:', round(response_time - start_time,3), '    Read time:', round(read_finish - response_time,3)) +    content = decode_content(content, response.getheader('Content-Encoding', default='identity')) + +    if return_response: +        return content, response +    return content + +mobile_user_agent = 'Mozilla/5.0 (iPhone; CPU iPhone OS 10_3_1 like Mac OS X) AppleWebKit/603.1.30 (KHTML, like Gecko) Version/10.0 Mobile/14E304 Safari/602.1' +mobile_ua = (('User-Agent', mobile_user_agent),) +desktop_user_agent = 'Mozilla/5.0 (Windows NT 6.1; rv:52.0) Gecko/20100101 Firefox/52.0' +desktop_ua = (('User-Agent', desktop_user_agent),) + + + + + + + + + + +def dict_add(*dicts): +    for dictionary in dicts[1:]: +        dicts[0].update(dictionary) +    return dicts[0] + +def video_id(url): +    url_parts = urllib.parse.urlparse(url) +    return urllib.parse.parse_qs(url_parts.query)['v'][0] + +def default_multi_get(object, *keys, default): +    ''' Like dict.get(), but for nested dictionaries/sequences, supporting keys or indices. Last argument is the default value to use in case of any IndexErrors or KeyErrors ''' +    try: +        for key in keys: +            object = object[key] +        return object +    except (IndexError, KeyError): +        return default + + +# default, sddefault, mqdefault, hqdefault, hq720 +def get_thumbnail_url(video_id): +    return "/i.ytimg.com/vi/" + video_id + "/mqdefault.jpg" +     +def seconds_to_timestamp(seconds): +    seconds = int(seconds) +    hours, seconds = divmod(seconds,3600) +    minutes, seconds = divmod(seconds,60) +    if hours != 0: +        timestamp = str(hours) + ":" +        timestamp += str(minutes).zfill(2)  # zfill pads with zeros +    else: +        timestamp = str(minutes) + +    timestamp += ":" + str(seconds).zfill(2) +    return timestamp + + + +def update_query_string(query_string, items): +    parameters = urllib.parse.parse_qs(query_string) +    parameters.update(items) +    return urllib.parse.urlencode(parameters, doseq=True) + + + +def uppercase_escape(s): +     return re.sub( +         r'\\U([0-9a-fA-F]{8})', +         lambda m: chr(int(m.group(1), base=16)), s)
\ No newline at end of file diff --git a/youtube/watch.py b/youtube/watch.py index 04a5b5d..06b525a 100644 --- a/youtube/watch.py +++ b/youtube/watch.py @@ -1,12 +1,12 @@ +from youtube import util, html_common, comments +  from youtube_dl.YoutubeDL import YoutubeDL  from youtube_dl.extractor.youtube import YoutubeError  import json  import urllib  from string import Template  import html -import youtube.common as common -from youtube.common import default_multi_get, get_thumbnail_url, video_id, URL_ORIGIN -import youtube.comments as comments +  import gevent  import settings  import os @@ -127,9 +127,11 @@ def get_related_items_html(info):      result = ""      for item in info['related_vids']:          if 'list' in item:  # playlist: -            result += common.small_playlist_item_html(watch_page_related_playlist_info(item)) +            item = watch_page_related_playlist_info(item) +            result += html_common.playlist_item_html(item, html_common.small_playlist_item_template)          else: -            result += common.small_video_item_html(watch_page_related_video_info(item)) +            item = watch_page_related_video_info(item) +            result += html_common.video_item_html(item, html_common.small_video_item_template)      return result @@ -137,11 +139,12 @@ def get_related_items_html(info):  # converts these to standard names  def watch_page_related_video_info(item):      result = {key: item[key] for key in ('id', 'title', 'author')} -    result['duration'] = common.seconds_to_timestamp(item['length_seconds']) +    result['duration'] = util.seconds_to_timestamp(item['length_seconds'])      try:          result['views'] = item['short_view_count_text']      except KeyError:          result['views'] = '' +    result['thumbnail'] = util.get_thumbnail_url(item['id'])      return result  def watch_page_related_playlist_info(item): @@ -150,14 +153,15 @@ def watch_page_related_playlist_info(item):          'title': item['playlist_title'],          'id': item['list'],          'first_video_id': item['video_id'], +        'thumbnail': util.get_thumbnail_url(item['video_id']),      }  def sort_formats(info):      sorted_formats = info['formats'].copy() -    sorted_formats.sort(key=lambda x: default_multi_get(_formats, x['format_id'], 'height', default=0)) +    sorted_formats.sort(key=lambda x: util.default_multi_get(_formats, x['format_id'], 'height', default=0))      for index, format in enumerate(sorted_formats): -        if default_multi_get(_formats, format['format_id'], 'height', default=0) >= 360: +        if util.default_multi_get(_formats, format['format_id'], 'height', default=0) >= 360:              break      sorted_formats = sorted_formats[index:] + sorted_formats[0:index]      sorted_formats = [format for format in info['formats'] if format['acodec'] != 'none' and format['vcodec'] != 'none'] @@ -236,7 +240,7 @@ def get_watch_page(env, start_response):          start_response('200 OK', [('Content-type','text/html'),]) -        lc = common.default_multi_get(env['parameters'], 'lc', 0, default='') +        lc = util.default_multi_get(env['parameters'], 'lc', 0, default='')          if settings.route_tor:              proxy = 'socks5://127.0.0.1:9150/'          else: @@ -256,17 +260,17 @@ def get_watch_page(env, start_response):          #chosen_format = choose_format(info)          if isinstance(info, str): # youtube error -            return common.yt_basic_template.substitute( +            return html_common.yt_basic_template.substitute(                  page_title = "Error",                  style = "", -                header = common.get_header(), +                header = html_common.get_header(),                  page = html.escape(info),              ).encode('utf-8')          sorted_formats = sort_formats(info)          video_info = { -            "duration": common.seconds_to_timestamp(info["duration"]), +            "duration": util.seconds_to_timestamp(info["duration"]),              "id":       info['id'],              "title":    info['title'],              "author":   info['uploader'], @@ -338,7 +342,7 @@ def get_watch_page(env, start_response):          page = yt_watch_template.substitute(              video_title             = html.escape(info["title"]),              page_title              = html.escape(info["title"]), -            header                  = common.get_header(), +            header                  = html_common.get_header(),              uploader                = html.escape(info["uploader"]),              uploader_channel_url    = '/' + info["uploader_url"],              upload_date             = upload_date, diff --git a/youtube/youtube.py b/youtube/youtube.py index 288f68b..4ec7962 100644 --- a/youtube/youtube.py +++ b/youtube/youtube.py @@ -1,7 +1,7 @@  import mimetypes  import urllib.parse  import os -from youtube import local_playlist, watch, search, playlist, channel, comments, common, post_comment, accounts, subscriptions +from youtube import local_playlist, watch, search, playlist, channel, comments, post_comment, accounts, util, subscriptions  import settings  YOUTUBE_FILES = (      "/shared.css", @@ -68,7 +68,7 @@ def youtube(env, start_response):          elif path.startswith("/api/"):              start_response('200 OK',  [('Content-type', 'text/vtt'),] ) -            result = common.fetch_url('https://www.youtube.com' + path + ('?' + query_string if query_string else '')) +            result = util.fetch_url('https://www.youtube.com' + path + ('?' + query_string if query_string else ''))              result = result.replace(b"align:start position:0%", b"")              return result diff --git a/youtube/yt_data_extract.py b/youtube/yt_data_extract.py new file mode 100644 index 0000000..5483911 --- /dev/null +++ b/youtube/yt_data_extract.py @@ -0,0 +1,205 @@ +import html + +# videos (all of type str): + +# id +# title +# url +# author +# author_url +# thumbnail +# description +# published +# duration +# likes +# dislikes +# views +# playlist_index + +# playlists: + +# id +# title +# url +# author +# author_url +# thumbnail +# description +# updated +# size +# first_video_id + + + + + + + +def get_plain_text(node): +    try: +        return html.escape(node['simpleText']) +    except KeyError: +        return unformmated_text_runs(node['runs']) +         +def unformmated_text_runs(runs): +    result = '' +    for text_run in runs: +        result += html.escape(text_run["text"]) +    return result + +def format_text_runs(runs): +    if isinstance(runs, str): +        return runs +    result = '' +    for text_run in runs: +        if text_run.get("bold", False): +            result += "<b>" + html.escape(text_run["text"]) + "</b>" +        elif text_run.get('italics', False): +            result += "<i>" + html.escape(text_run["text"]) + "</i>" +        else: +            result += html.escape(text_run["text"]) +    return result + + + + + + + + +def get_url(node): +    try: +        return node['runs'][0]['navigationEndpoint']['commandMetadata']['webCommandMetadata']['url'] +    except KeyError: +        return node['navigationEndpoint']['commandMetadata']['webCommandMetadata']['url'] + + +def get_text(node): +    try: +        return node['simpleText'] +    except KeyError: +            pass +    try: +        return node['runs'][0]['text'] +    except IndexError: # empty text runs +        return '' + +def get_formatted_text(node): +    try: +        return node['runs'] +    except KeyError: +        return node['simpleText'] + +def get_badges(node): +    badges = [] +    for badge_node in node: +        badge = badge_node['metadataBadgeRenderer']['label'] +        if badge.lower() != 'new': +            badges.append(badge) +    return badges + +def get_thumbnail(node): +    try: +        return node['thumbnails'][0]['url']     # polymer format +    except KeyError: +        return node['url']     # ajax format + +dispatch = { + +# polymer format     +    'title':                ('title',       get_text), +    'publishedTimeText':    ('published',   get_text), +    'videoId':              ('id',          lambda node: node), +    'descriptionSnippet':   ('description', get_formatted_text), +    'lengthText':           ('duration',    get_text), +    'thumbnail':            ('thumbnail',   get_thumbnail), +    'thumbnails':           ('thumbnail',   lambda node: node[0]['thumbnails'][0]['url']), + +    'viewCountText':        ('views',       get_text), +    'numVideosText':        ('size',        lambda node: get_text(node).split(' ')[0]),     # the format is "324 videos" +    'videoCountText':       ('size',        get_text), +    'playlistId':           ('id',          lambda node: node), +    'descriptionText':      ('description', get_formatted_text), + +    'subscriberCountText':  ('subscriber_count',    get_text), +    'channelId':            ('id',          lambda node: node), +    'badges':               ('badges',      get_badges), + +# ajax format +    'view_count_text':  ('views',       get_text), +    'num_videos_text':  ('size',        lambda node: get_text(node).split(' ')[0]), +    'owner_text':       ('author',      get_text), +    'owner_endpoint':   ('author_url',  lambda node: node['url']), +    'description':      ('description', get_formatted_text), +    'index':            ('playlist_index', get_text), +    'short_byline':     ('author',      get_text), +    'length':           ('duration',    get_text), +    'video_id':         ('id',          lambda node: node), + +} + +def renderer_info(renderer): +    try: +        info = {} +        if 'viewCountText' in renderer:     # prefer this one as it contains all the digits +            info['views'] = get_text(renderer['viewCountText']) +        elif 'shortViewCountText' in renderer: +            info['views'] = get_text(renderer['shortViewCountText']) + +        if 'ownerText' in renderer: +            info['author'] = renderer['ownerText']['runs'][0]['text'] +            info['author_url'] = renderer['ownerText']['runs'][0]['navigationEndpoint']['commandMetadata']['webCommandMetadata']['url'] +        try: +            overlays = renderer['thumbnailOverlays'] +        except KeyError: +            pass +        else: +            for overlay in overlays: +                if 'thumbnailOverlayTimeStatusRenderer' in overlay: +                    info['duration'] = get_text(overlay['thumbnailOverlayTimeStatusRenderer']['text']) +                # show renderers don't have videoCountText +                elif 'thumbnailOverlayBottomPanelRenderer' in overlay: +                    info['size'] = get_text(overlay['thumbnailOverlayBottomPanelRenderer']['text']) + +        # show renderers don't have playlistId, have to dig into the url to get it +        try: +            info['id'] = renderer['navigationEndpoint']['watchEndpoint']['playlistId'] +        except KeyError: +            pass +        for key, node in renderer.items(): +            if key in ('longBylineText', 'shortBylineText'): +                info['author'] = get_text(node) +                try: +                    info['author_url'] = get_url(node) +                except KeyError: +                    pass + +            # show renderers don't have thumbnail key at top level, dig into thumbnailRenderer +            elif key == 'thumbnailRenderer' and 'showCustomThumbnailRenderer' in node: +                info['thumbnail'] = node['showCustomThumbnailRenderer']['thumbnail']['thumbnails'][0]['url'] +            else: +                try: +                    simple_key, function = dispatch[key] +                except KeyError: +                    continue +                info[simple_key] = function(node) +        return info +    except KeyError: +        print(renderer) +        raise +     +def ajax_info(item_json): +    try: +        info = {}           +        for key, node in item_json.items(): +            try: +                simple_key, function = dispatch[key] +            except KeyError: +                continue +            info[simple_key] = function(node) +        return info +    except KeyError: +        print(item_json) +        raise +     + diff --git a/youtube_dl/extractor/youtube.py b/youtube_dl/extractor/youtube.py index 4fab4e0..52c8731 100644 --- a/youtube_dl/extractor/youtube.py +++ b/youtube_dl/extractor/youtube.py @@ -1712,7 +1712,8 @@ class YoutubeIE(YoutubeBaseInfoExtractor):                          view_count = extract_view_count(get_video_info)                      if not video_info:                          video_info = get_video_info -                    if 'token' in get_video_info: +                    get_token = get_video_info.get('token') or get_video_info.get('account_playback_token') +                    if get_token:                          # Different get_video_info requests may report different results, e.g.                          # some may report video unavailability, but some may serve it without                          # any complaint (see https://github.com/rg3/youtube-dl/issues/7362, @@ -1722,7 +1723,8 @@ class YoutubeIE(YoutubeBaseInfoExtractor):                          # due to YouTube measures against IP ranges of hosting providers.                          # Working around by preferring the first succeeded video_info containing                          # the token if no such video_info yet was found. -                        if 'token' not in video_info: +                        token = video_info.get('token') or video_info.get('account_playback_token') +                        if not token:                              video_info = get_video_info                          break @@ -1731,7 +1733,8 @@ class YoutubeIE(YoutubeBaseInfoExtractor):                  r'(?s)<h1[^>]+id="unavailable-message"[^>]*>(.+?)</h1>',                  video_webpage, 'unavailable message', default=None) -        if 'token' not in video_info: +        token = video_info.get('token') or video_info.get('account_playback_token') +        if not token:              if 'reason' in video_info:                  if 'The uploader has not made this video available in your country.' in video_info['reason']:                      regions_allowed = self._html_search_meta( diff --git a/youtube_dl/extractor/youtube_unmodified_reference.py b/youtube_dl/extractor/youtube_unmodified_reference.py index c8bf98b..c12c417 100644 --- a/youtube_dl/extractor/youtube_unmodified_reference.py +++ b/youtube_dl/extractor/youtube_unmodified_reference.py @@ -1648,7 +1648,8 @@ class YoutubeIE(YoutubeBaseInfoExtractor):                          view_count = extract_view_count(get_video_info)                      if not video_info:                          video_info = get_video_info -                    if 'token' in get_video_info: +                    get_token = get_video_info.get('token') or get_video_info.get('account_playback_token') +                    if get_token:                          # Different get_video_info requests may report different results, e.g.                          # some may report video unavailability, but some may serve it without                          # any complaint (see https://github.com/rg3/youtube-dl/issues/7362, @@ -1658,7 +1659,8 @@ class YoutubeIE(YoutubeBaseInfoExtractor):                          # due to YouTube measures against IP ranges of hosting providers.                          # Working around by preferring the first succeeded video_info containing                          # the token if no such video_info yet was found. -                        if 'token' not in video_info: +                        token = video_info.get('token') or video_info.get('account_playback_token') +                        if not token:                              video_info = get_video_info                          break @@ -1667,7 +1669,8 @@ class YoutubeIE(YoutubeBaseInfoExtractor):                  r'(?s)<h1[^>]+id="unavailable-message"[^>]*>(.+?)</h1>',                  video_webpage, 'unavailable message', default=None) -        if 'token' not in video_info: +        token = video_info.get('token') or video_info.get('account_playback_token') +        if not token:              if 'reason' in video_info:                  if 'The uploader has not made this video available in your country.' in video_info['reason']:                      regions_allowed = self._html_search_meta( | 
