author     Jesús <heckyel@hyperbola.info>    2021-10-18 15:24:21 -0500
committer  Jesús <heckyel@hyperbola.info>    2021-10-18 15:24:21 -0500
commit     5122028a4bcac4ae577ef7fbd55ccad5cb34ef5e (patch)
tree       65209bc739db35e31f1c9b5b868eb5df4fe12ae3 /hypervideo_dl/utils.py
parent     27fe903c511691c078942bef5ee9a05a43b15c8f (diff)
download   hypervideo-5122028a4bcac4ae577ef7fbd55ccad5cb34ef5e.tar.lz
           hypervideo-5122028a4bcac4ae577ef7fbd55ccad5cb34ef5e.tar.xz
           hypervideo-5122028a4bcac4ae577ef7fbd55ccad5cb34ef5e.zip
update from upstream
Diffstat (limited to 'hypervideo_dl/utils.py')

-rw-r--r--  hypervideo_dl/utils.py  1053
1 file changed, 880 insertions, 173 deletions
diff --git a/hypervideo_dl/utils.py b/hypervideo_dl/utils.py
index fc62f09..0199f4c 100644
--- a/hypervideo_dl/utils.py
+++ b/hypervideo_dl/utils.py
@@ -1,4 +1,4 @@
-#!/usr/bin/env python
+#!/usr/bin/env python3
 # coding: utf-8
 
 from __future__ import unicode_literals
@@ -16,6 +16,9 @@ import email.header
 import errno
 import functools
 import gzip
+import hashlib
+import hmac
+import importlib.util
 import io
 import itertools
 import json
@@ -50,6 +53,7 @@ from .compat import (
     compat_html_entities_html5,
     compat_http_client,
     compat_integer_types,
+    compat_numeric_types,
     compat_kwargs,
     compat_os_name,
     compat_parse_qs,
@@ -61,6 +65,9 @@ from .compat import (
     compat_urllib_parse,
     compat_urllib_parse_urlencode,
     compat_urllib_parse_urlparse,
+    compat_urllib_parse_urlunparse,
+    compat_urllib_parse_quote,
+    compat_urllib_parse_quote_plus,
     compat_urllib_parse_unquote_plus,
     compat_urllib_request,
     compat_urlparse,
@@ -1735,12 +1742,16 @@ DATE_FORMATS = (
     '%b %dth %Y %I:%M',
     '%Y %m %d',
     '%Y-%m-%d',
+    '%Y.%m.%d.',
     '%Y/%m/%d',
     '%Y/%m/%d %H:%M',
     '%Y/%m/%d %H:%M:%S',
+    '%Y%m%d%H%M',
+    '%Y%m%d%H%M%S',
     '%Y-%m-%d %H:%M',
     '%Y-%m-%d %H:%M:%S',
     '%Y-%m-%d %H:%M:%S.%f',
+    '%Y-%m-%d %H:%M:%S:%f',
     '%d.%m.%Y %H:%M',
     '%d.%m.%Y %H.%M',
     '%Y-%m-%dT%H:%M:%SZ',
@@ -1753,6 +1764,7 @@ DATE_FORMATS = (
     '%b %d %Y at %H:%M:%S',
     '%B %d %Y at %H:%M',
     '%B %d %Y at %H:%M:%S',
+    '%H:%M %d-%b-%Y',
 )
 
 DATE_FORMATS_DAY_FIRST = list(DATE_FORMATS)
@@ -1985,6 +1997,7 @@ def get_elements_by_attribute(attribute, value, html, escape_value=True):
 
 class HTMLAttributeParser(compat_HTMLParser):
     """Trivial HTML parser to gather the attributes for a single element"""
+
     def __init__(self):
         self.attrs = {}
         compat_HTMLParser.__init__(self)
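
The `DATE_FORMATS` entries added above are ordinary `strptime` patterns. A stdlib-only sanity check (sample strings are invented for illustration):

```python
import datetime

# (format added in this commit, a sample string it should now parse)
for fmt, sample in [
        ('%Y%m%d%H%M%S', '20211018152421'),
        ('%Y.%m.%d.', '2021.10.18.'),
        ('%H:%M %d-%b-%Y', '15:24 18-Oct-2021')]:
    print(datetime.datetime.strptime(sample, fmt))
```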
@@ -2086,7 +2099,9 @@ def sanitize_filename(s, restricted=False, is_id=False):
     def replace_insane(char):
         if restricted and char in ACCENT_CHARS:
             return ACCENT_CHARS[char]
-        if char == '?' or ord(char) < 32 or ord(char) == 127:
+        elif not restricted and char == '\n':
+            return ' '
+        elif char == '?' or ord(char) < 32 or ord(char) == 127:
             return ''
         elif char == '"':
             return '' if restricted else '\''
@@ -2100,6 +2115,8 @@ def sanitize_filename(s, restricted=False, is_id=False):
             return '_'
         return char
 
+    if s == '':
+        return ''
     # Handle timestamps
     s = re.sub(r'[0-9]+(?::[0-9]+)+', lambda m: m.group(0).replace(':', '_'), s)
     result = ''.join(map(replace_insane, s))
@@ -2118,13 +2135,18 @@ def sanitize_filename(s, restricted=False, is_id=False):
     return result
 
 
-def sanitize_path(s):
+def sanitize_path(s, force=False):
     """Sanitizes and normalizes path on Windows"""
-    if sys.platform != 'win32':
+    if sys.platform == 'win32':
+        force = False
+        drive_or_unc, _ = os.path.splitdrive(s)
+        if sys.version_info < (2, 7) and not drive_or_unc:
+            drive_or_unc, _ = os.path.splitunc(s)
+    elif force:
+        drive_or_unc = ''
+    else:
         return s
-    drive_or_unc, _ = os.path.splitdrive(s)
-    if sys.version_info < (2, 7) and not drive_or_unc:
-        drive_or_unc, _ = os.path.splitunc(s)
+
     norm_path = os.path.normpath(remove_start(s, drive_or_unc)).split(os.path.sep)
     if drive_or_unc:
         norm_path.pop(0)
@@ -2133,6 +2155,8 @@ def sanitize_path(s):
                       for path_part in norm_path]
     if drive_or_unc:
         sanitized_path.insert(0, drive_or_unc + os.path.sep)
+    elif force and s[0] == os.path.sep:
+        sanitized_path.insert(0, os.path.sep)
     return os.path.join(*sanitized_path)
@@ -2154,8 +2178,24 @@ def sanitize_url(url):
     return url
 
 
+def extract_basic_auth(url):
+    parts = compat_urlparse.urlsplit(url)
+    if parts.username is None:
+        return url, None
+    url = compat_urlparse.urlunsplit(parts._replace(netloc=(
+        parts.hostname if parts.port is None
+        else '%s:%d' % (parts.hostname, parts.port))))
+    auth_payload = base64.b64encode(
+        ('%s:%s' % (parts.username, parts.password or '')).encode('utf-8'))
+    return url, 'Basic ' + auth_payload.decode('utf-8')
+
+
 def sanitized_Request(url, *args, **kwargs):
-    return compat_urllib_request.Request(sanitize_url(url), *args, **kwargs)
+    url, auth_header = extract_basic_auth(escape_url(sanitize_url(url)))
+    if auth_header is not None:
+        headers = args[1] if len(args) >= 2 else kwargs.setdefault('headers', {})
+        headers['Authorization'] = auth_header
+    return compat_urllib_request.Request(url, *args, **kwargs)
 
 
 def expand_path(s):
@@ -2212,6 +2252,26 @@ def unescapeHTML(s):
         r'&([^&;]+;)', lambda m: _htmlentity_transform(m.group(1)), s)
 
 
+def escapeHTML(text):
+    return (
+        text
+        .replace('&', '&amp;')
+        .replace('<', '&lt;')
+        .replace('>', '&gt;')
+        .replace('"', '&quot;')
+        .replace("'", '&#39;')
+    )
+
+
+def process_communicate_or_kill(p, *args, **kwargs):
+    try:
+        return p.communicate(*args, **kwargs)
+    except BaseException:  # Including KeyboardInterrupt
+        p.kill()
+        p.wait()
+        raise
+
+
 def get_subprocess_encoding():
     if sys.platform == 'win32' and sys.getwindowsversion()[0] >= 5:
         # For subprocess calls, encode with locale encoding
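
The new `extract_basic_auth` helper, wired into `sanitized_Request` above, moves URL credentials into an `Authorization` header. A minimal usage sketch (assuming the module is importable as `hypervideo_dl.utils`; expected outputs worked out by hand):

```python
from hypervideo_dl.utils import extract_basic_auth

url, auth = extract_basic_auth('https://user:secret@example.com/feed')
print(url)   # 'https://example.com/feed' -- credentials stripped from the netloc
print(auth)  # 'Basic dXNlcjpzZWNyZXQ='  -- base64 of 'user:secret'
```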
@@ -2282,49 +2342,68 @@ def decodeOption(optval):
     return optval
 
 
-def formatSeconds(secs):
+def formatSeconds(secs, delim=':', msec=False):
     if secs > 3600:
-        return '%d:%02d:%02d' % (secs // 3600, (secs % 3600) // 60, secs % 60)
+        ret = '%d%s%02d%s%02d' % (secs // 3600, delim, (secs % 3600) // 60, delim, secs % 60)
     elif secs > 60:
-        return '%d:%02d' % (secs // 60, secs % 60)
+        ret = '%d%s%02d' % (secs // 60, delim, secs % 60)
     else:
-        return '%d' % secs
+        ret = '%d' % secs
+    return '%s.%03d' % (ret, secs % 1) if msec else ret
 
 
-def make_HTTPS_handler(params, **kwargs):
-    opts_no_check_certificate = params.get('nocheckcertificate', False)
-    if hasattr(ssl, 'create_default_context'):  # Python >= 3.4 or 2.7.9
-        context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
-        if opts_no_check_certificate:
-            context.check_hostname = False
-            context.verify_mode = ssl.CERT_NONE
+def _ssl_load_windows_store_certs(ssl_context, storename):
+    # Code adapted from _load_windows_store_certs in https://github.com/python/cpython/blob/main/Lib/ssl.py
+    try:
+        certs = [cert for cert, encoding, trust in ssl.enum_certificates(storename)
+                 if encoding == 'x509_asn' and (
+                     trust is True or ssl.Purpose.SERVER_AUTH.oid in trust)]
+    except PermissionError:
+        return
+    for cert in certs:
         try:
-            return YoutubeDLHTTPSHandler(params, context=context, **kwargs)
-        except TypeError:
-            # Python 2.7.8
-            # (create_default_context present but HTTPSHandler has no context=)
+            ssl_context.load_verify_locations(cadata=cert)
+        except ssl.SSLError:
             pass
 
-    if sys.version_info < (3, 2):
-        return YoutubeDLHTTPSHandler(params, **kwargs)
-    else:  # Python < 3.4
-        context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
-        context.verify_mode = (ssl.CERT_NONE
-                               if opts_no_check_certificate
-                               else ssl.CERT_REQUIRED)
-        context.set_default_verify_paths()
-        return YoutubeDLHTTPSHandler(params, context=context, **kwargs)
-
-def bug_reports_message():
+
+def make_HTTPS_handler(params, **kwargs):
+    opts_check_certificate = not params.get('nocheckcertificate')
+    context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+    context.check_hostname = opts_check_certificate
+    context.verify_mode = ssl.CERT_REQUIRED if opts_check_certificate else ssl.CERT_NONE
+    if opts_check_certificate:
+        try:
+            context.load_default_certs()
+            # Work around the issue in load_default_certs when there are bad certificates. See:
+            # https://github.com/hypervideo/hypervideo/issues/1060,
+            # https://bugs.python.org/issue35665, https://bugs.python.org/issue45312
+        except ssl.SSLError:
+            # enum_certificates is not present in mingw python. See https://github.com/hypervideo/hypervideo/issues/1151
+            if sys.platform == 'win32' and hasattr(ssl, 'enum_certificates'):
+                # Create a new context to discard any certificates that were already loaded
+                context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+                context.check_hostname, context.verify_mode = True, ssl.CERT_REQUIRED
+                for storename in ('CA', 'ROOT'):
+                    _ssl_load_windows_store_certs(context, storename)
+            context.set_default_verify_paths()
+    return YoutubeDLHTTPSHandler(params, context=context, **kwargs)
+
+
+def bug_reports_message(before=';'):
     if ytdl_is_updateable():
         update_cmd = 'type doas pacman -Sy hypervideo to update'
     else:
-        update_cmd = 'see https://yt-dl.org/update on how to update'
-    msg = '; please report this issue on https://yt-dl.org/bug .'
+        update_cmd = 'see https://git.conocimientoslibres.ga/software/hypervideo.git/about/#how-do-i-update-hypervideo'
+    msg = 'please report this issue on https://github.com/hypervideo/hypervideo .'
     msg += ' Make sure you are using the latest version; %s.' % update_cmd
     msg += ' Be sure to call hypervideo with the --verbose flag and include its complete output.'
-    return msg
+
+    before = before.rstrip()
+    if not before or before.endswith(('.', '!', '?')):
+        msg = msg[0].title() + msg[1:]
+
+    return (before + ' ' if before else '') + msg
 
 
 class YoutubeDLError(Exception):
@@ -2332,28 +2411,36 @@ class YoutubeDLError(Exception):
     pass
 
 
+network_exceptions = [compat_urllib_error.URLError, compat_http_client.HTTPException, socket.error]
+if hasattr(ssl, 'CertificateError'):
+    network_exceptions.append(ssl.CertificateError)
+network_exceptions = tuple(network_exceptions)
+
+
 class ExtractorError(YoutubeDLError):
     """Error during info extraction."""
 
-    def __init__(self, msg, tb=None, expected=False, cause=None, video_id=None):
+    def __init__(self, msg, tb=None, expected=False, cause=None, video_id=None, ie=None):
        """ tb, if given, is the original traceback (so that it can be printed out).
        If expected is set, this is a normal error message and most likely not a bug in hypervideo.
        """
-
-        if sys.exc_info()[0] in (compat_urllib_error.URLError, socket.timeout, UnavailableVideoError):
+        if sys.exc_info()[0] in network_exceptions:
             expected = True
-        if video_id is not None:
-            msg = video_id + ': ' + msg
-        if cause:
-            msg += ' (caused by %r)' % cause
-        if not expected:
-            msg += bug_reports_message()
-        super(ExtractorError, self).__init__(msg)
+
+        self.msg = str(msg)
         self.traceback = tb
-        self.exc_info = sys.exc_info()  # preserve original exception
+        self.expected = expected
         self.cause = cause
         self.video_id = video_id
+        self.ie = ie
+        self.exc_info = sys.exc_info()  # preserve original exception
+
+        super(ExtractorError, self).__init__(''.join((
+            format_field(ie, template='[%s] '),
+            format_field(video_id, template='%s: '),
+            self.msg,
+            format_field(cause, template=' (caused by %r)'),
+            '' if expected else bug_reports_message())))
 
     def format_traceback(self):
         if self.traceback is None:
@@ -2379,6 +2466,7 @@ class GeoRestrictedError(ExtractorError):
     This exception may be thrown when a video is not available from your
     geographic location due to geographic restrictions imposed by a website.
     """
+
     def __init__(self, msg, countries=None):
         super(GeoRestrictedError, self).__init__(msg, expected=True)
         self.msg = msg
@@ -2399,6 +2487,15 @@ class DownloadError(YoutubeDLError):
         self.exc_info = exc_info
 
 
+class EntryNotInPlaylist(YoutubeDLError):
+    """Entry not in playlist exception.
+
+    This exception will be thrown by YoutubeDL when a requested entry
+    is not found in the playlist info_dict
+    """
+    pass
+
+
 class SameFileError(YoutubeDLError):
     """Same File exception.
 
@@ -2420,6 +2517,21 @@ class PostProcessingError(YoutubeDLError):
         self.msg = msg
 
 
+class ExistingVideoReached(YoutubeDLError):
+    """ --break-on-existing triggered """
+    pass
+
+
+class RejectedVideoReached(YoutubeDLError):
+    """ --break-on-reject triggered """
+    pass
+
+
+class ThrottledDownload(YoutubeDLError):
+    """ Download speed below --throttled-rate. """
+    pass
+
+
 class MaxDownloadsReached(YoutubeDLError):
     """ --max-downloads limit has been reached. """
     pass
@@ -2582,6 +2694,8 @@ class YoutubeDLHandler(compat_urllib_request.HTTPHandler):
 
     @staticmethod
     def deflate(data):
+        if not data:
+            return data
         try:
             return zlib.decompress(data, -zlib.MAX_WBITS)
         except zlib.error:
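
`formatSeconds` now takes a configurable delimiter, which callers can use for filesystem-safe timestamps. A short sketch (same import assumption as above):

```python
from hypervideo_dl.utils import formatSeconds

print(formatSeconds(3661))             # '1:01:01'
print(formatSeconds(3661, delim='_'))  # '1_01_01' -- safe in filenames
print(formatSeconds(45))               # '45'
```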
@@ -2938,8 +3052,16 @@ class YoutubeDLRedirectHandler(compat_urllib_request.HTTPRedirectHandler):
 
 def extract_timezone(date_str):
     m = re.search(
-        r'^.{8,}?(?P<tz>Z$| ?(?P<sign>\+|-)(?P<hours>[0-9]{2}):?(?P<minutes>[0-9]{2})$)',
-        date_str)
+        r'''(?x)
+            ^.{8,}?                                              # >=8 char non-TZ prefix, if present
+            (?P<tz>Z|                                            # just the UTC Z, or
+               (?:(?<=.\b\d{4}|\b\d{2}:\d\d)|                    # preceded by 4 digits or hh:mm or
+                  (?<!.\b[a-zA-Z]{3}|[a-zA-Z]{4}|..\b\d\d))      # not preceded by 3 alpha word or >= 4 alpha or 2 digits
+               [ ]?                                              # optional space
+               (?P<sign>\+|-)                                    # +/-
+               (?P<hours>[0-9]{2}):?(?P<minutes>[0-9]{2})        # hh[:]mm
+            $)
+        ''', date_str)
     if not m:
         timezone = datetime.timedelta()
     else:
@@ -3055,33 +3177,83 @@ def subtitles_filename(filename, sub_lang, sub_format, expected_real_ext=None):
     return replace_extension(filename, sub_lang + '.' + sub_format, expected_real_ext)
 
 
-def date_from_str(date_str):
+def datetime_from_str(date_str, precision='auto', format='%Y%m%d'):
     """
     Return a datetime object from a string in the format YYYYMMDD or
-    (now|today)[+-][0-9](day|week|month|year)(s)?"""
-    today = datetime.date.today()
+    (now|today|date)[+-][0-9](microsecond|second|minute|hour|day|week|month|year)(s)?
+
+    format: string date format used to return datetime object from
+    precision: round the time portion of a datetime object.
+                auto|microsecond|second|minute|hour|day.
+                auto: round to the unit provided in date_str (if applicable).
+    """
+    auto_precision = False
+    if precision == 'auto':
+        auto_precision = True
+        precision = 'microsecond'
+    today = datetime_round(datetime.datetime.now(), precision)
     if date_str in ('now', 'today'):
         return today
     if date_str == 'yesterday':
         return today - datetime.timedelta(days=1)
-    match = re.match(r'(now|today)(?P<sign>[+-])(?P<time>\d+)(?P<unit>day|week|month|year)(s)?', date_str)
+    match = re.match(
+        r'(?P<start>.+)(?P<sign>[+-])(?P<time>\d+)(?P<unit>microsecond|second|minute|hour|day|week|month|year)(s)?',
+        date_str)
     if match is not None:
-        sign = match.group('sign')
-        time = int(match.group('time'))
-        if sign == '-':
-            time = -time
+        start_time = datetime_from_str(match.group('start'), precision, format)
+        time = int(match.group('time')) * (-1 if match.group('sign') == '-' else 1)
         unit = match.group('unit')
-        # A bad approximation?
-        if unit == 'month':
+        if unit == 'month' or unit == 'year':
+            new_date = datetime_add_months(start_time, time * 12 if unit == 'year' else time)
             unit = 'day'
-            time *= 30
-        elif unit == 'year':
-            unit = 'day'
-            time *= 365
-        unit += 's'
-        delta = datetime.timedelta(**{unit: time})
-        return today + delta
-    return datetime.datetime.strptime(date_str, '%Y%m%d').date()
+        else:
+            if unit == 'week':
+                unit = 'day'
+                time *= 7
+            delta = datetime.timedelta(**{unit + 's': time})
+            new_date = start_time + delta
+        if auto_precision:
+            return datetime_round(new_date, unit)
+        return new_date
+
+    return datetime_round(datetime.datetime.strptime(date_str, format), precision)
+
+
+def date_from_str(date_str, format='%Y%m%d'):
+    """
+    Return a datetime object from a string in the format YYYYMMDD or
+    (now|today|date)[+-][0-9](microsecond|second|minute|hour|day|week|month|year)(s)?
+
+    format: string date format used to return datetime object from
+    """
+    return datetime_from_str(date_str, precision='microsecond', format=format).date()
+
+
+def datetime_add_months(dt, months):
+    """Increment/Decrement a datetime object by months."""
+    month = dt.month + months - 1
+    year = dt.year + month // 12
+    month = month % 12 + 1
+    day = min(dt.day, calendar.monthrange(year, month)[1])
+    return dt.replace(year, month, day)
+
+
+def datetime_round(dt, precision='day'):
+    """
+    Round a datetime object's time to a specific precision
+    """
+    if precision == 'microsecond':
+        return dt
+
+    unit_seconds = {
+        'day': 86400,
+        'hour': 3600,
+        'minute': 60,
+        'second': 1,
+    }
+    roundto = lambda x, n: ((x + n / 2) // n) * n
+    timestamp = calendar.timegm(dt.timetuple())
+    return datetime.datetime.utcfromtimestamp(roundto(timestamp, unit_seconds[precision]))
 
 
 def hyphenate_date(date_str):
@@ -3135,6 +3307,14 @@ def platform_name():
     return res
 
 
+def get_windows_version():
+    ''' Get Windows version. None if it's not running on Windows '''
+    if compat_os_name == 'nt':
+        return version_tuple(platform.win32_ver()[1])
+    else:
+        return None
+
+
 def _windows_write_string(s, out):
     """
     Returns True if the string was written using special methods,
     False if it has yet to be written out."""
@@ -3607,6 +3787,11 @@ def remove_quotes(s):
     return s
 
 
+def get_domain(url):
+    domain = re.match(r'(?:https?:\/\/)?(?:www\.)?(?P<domain>[^\n\/]+\.[^\n\/]+)(?:\/(.*))?', url)
+    return domain.group('domain') if domain else None
+
+
 def url_basename(url):
     path = compat_urlparse.urlparse(url).path
     return path.strip('/').split('/')[-1]
@@ -3692,6 +3877,18 @@ def url_or_none(url):
     return url if re.match(r'^(?:(?:https?|rt(?:m(?:pt?[es]?|fp)|sp[su]?)|mms|ftps?):)?//', url) else None
 
 
+def strftime_or_none(timestamp, date_format, default=None):
+    datetime_object = None
+    try:
+        if isinstance(timestamp, compat_numeric_types):  # unix timestamp
+            datetime_object = datetime.datetime.utcfromtimestamp(timestamp)
+        elif isinstance(timestamp, compat_str):  # assume YYYYMMDD
+            datetime_object = datetime.datetime.strptime(timestamp, '%Y%m%d')
+        return datetime_object.strftime(date_format)
+    except (ValueError, TypeError, AttributeError):
+        return default
+
+
 def parse_duration(s):
     if not isinstance(s, compat_basestring):
         return None
@@ -3769,7 +3966,8 @@ def check_executable(exe, args=[]):
     """ Checks if the given binary is installed somewhere in PATH, and returns its name.
     args can be a list of arguments for a short output (like -version) """
     try:
-        subprocess.Popen([exe] + args, stdout=subprocess.PIPE, stderr=subprocess.PIPE).communicate()
+        process_communicate_or_kill(subprocess.Popen(
+            [exe] + args, stdout=subprocess.PIPE, stderr=subprocess.PIPE))
     except OSError:
         return False
     return exe
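
The reworked date helpers accept relative expressions with calendar-aware month/year arithmetic instead of the old 30/365-day approximation. A sketch of the behaviour (import assumption as above; the clamping result follows from `datetime_add_months`):

```python
from datetime import datetime
from hypervideo_dl.utils import date_from_str, datetime_add_months

print(date_from_str('now-2weeks'))   # the date 14 days ago
print(date_from_str('today-1year'))  # same calendar day, previous year
print(datetime_add_months(datetime(2021, 1, 31), 1))  # 2021-02-28 00:00:00, day clamped
```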
     # See https://github.com/ytdl-org/youtube-dl/issues/955#issuecomment-209789656
-        out, _ = subprocess.Popen(
+        out, _ = process_communicate_or_kill(subprocess.Popen(
             [encodeArgument(exe)] + args,
             stdin=subprocess.PIPE,
-            stdout=subprocess.PIPE, stderr=subprocess.STDOUT).communicate()
+            stdout=subprocess.PIPE, stderr=subprocess.STDOUT))
     except OSError:
         return False
     if isinstance(out, bytes):  # Python 2.x
@@ -3805,49 +4003,144 @@ def detect_exe_version(output, version_re=None, unrecognized='present'):
         return unrecognized
 
 
-class PagedList(object):
+class LazyList(collections.abc.Sequence):
+    ''' Lazy immutable list from an iterable
+    Note that slices of a LazyList are lists and not LazyList'''
+
+    class IndexError(IndexError):
+        pass
+
+    def __init__(self, iterable):
+        self.__iterable = iter(iterable)
+        self.__cache = []
+        self.__reversed = False
+
+    def __iter__(self):
+        if self.__reversed:
+            # We need to consume the entire iterable to iterate in reverse
+            yield from self.exhaust()
+            return
+        yield from self.__cache
+        for item in self.__iterable:
+            self.__cache.append(item)
+            yield item
+
+    def __exhaust(self):
+        self.__cache.extend(self.__iterable)
+        return self.__cache
+
+    def exhaust(self):
+        ''' Evaluate the entire iterable '''
+        return self.__exhaust()[::-1 if self.__reversed else 1]
+
+    @staticmethod
+    def __reverse_index(x):
+        return None if x is None else -(x + 1)
+
+    def __getitem__(self, idx):
+        if isinstance(idx, slice):
+            if self.__reversed:
+                idx = slice(self.__reverse_index(idx.start), self.__reverse_index(idx.stop), -(idx.step or 1))
+            start, stop, step = idx.start, idx.stop, idx.step or 1
+        elif isinstance(idx, int):
+            if self.__reversed:
+                idx = self.__reverse_index(idx)
+            start, stop, step = idx, idx, 0
+        else:
+            raise TypeError('indices must be integers or slices')
+        if ((start or 0) < 0 or (stop or 0) < 0
+                or (start is None and step < 0)
+                or (stop is None and step > 0)):
+            # We need to consume the entire iterable to be able to slice from the end
+            # Obviously, never use this with infinite iterables
+            self.__exhaust()
+            try:
+                return self.__cache[idx]
+            except IndexError as e:
+                raise self.IndexError(e) from e
+        n = max(start or 0, stop or 0) - len(self.__cache) + 1
+        if n > 0:
+            self.__cache.extend(itertools.islice(self.__iterable, n))
+        try:
+            return self.__cache[idx]
+        except IndexError as e:
+            raise self.IndexError(e) from e
+
+    def __bool__(self):
+        try:
+            self[-1] if self.__reversed else self[0]
+        except self.IndexError:
+            return False
+        return True
+
+    def __len__(self):
+        self.__exhaust()
+        return len(self.__cache)
+
+    def reverse(self):
+        self.__reversed = not self.__reversed
+        return self
+
+    def __repr__(self):
+        # repr and str should mimic a list. So we exhaust the iterable
+        return repr(self.exhaust())
+
+    def __str__(self):
+        return repr(self.exhaust())
+
+
+class PagedList:
     def __len__(self):
         # This is only useful for tests
         return len(self.getslice())
 
-
-class OnDemandPagedList(PagedList):
     def __init__(self, pagefunc, pagesize, use_cache=True):
         self._pagefunc = pagefunc
         self._pagesize = pagesize
         self._use_cache = use_cache
-        if use_cache:
-            self._cache = {}
+        self._cache = {}
+
+    def getpage(self, pagenum):
+        page_results = self._cache.get(pagenum) or list(self._pagefunc(pagenum))
+        if self._use_cache:
+            self._cache[pagenum] = page_results
+        return page_results
 
     def getslice(self, start=0, end=None):
-        res = []
+        return list(self._getslice(start, end))
+
+    def _getslice(self, start, end):
+        raise NotImplementedError('This method must be implemented by subclasses')
+
+    def __getitem__(self, idx):
+        # NOTE: cache must be enabled if this is used
+        if not isinstance(idx, int) or idx < 0:
+            raise TypeError('indices must be non-negative integers')
+        entries = self.getslice(idx, idx + 1)
+        return entries[0] if entries else None
+
+
+class OnDemandPagedList(PagedList):
+    def _getslice(self, start, end):
         for pagenum in itertools.count(start // self._pagesize):
             firstid = pagenum * self._pagesize
             nextfirstid = pagenum * self._pagesize + self._pagesize
             if start >= nextfirstid:
                 continue
 
-            page_results = None
-            if self._use_cache:
-                page_results = self._cache.get(pagenum)
-            if page_results is None:
-                page_results = list(self._pagefunc(pagenum))
-            if self._use_cache:
-                self._cache[pagenum] = page_results
-
             startv = (
                 start % self._pagesize
                 if firstid <= start < nextfirstid
                 else 0)
-
             endv = (
                 ((end - 1) % self._pagesize) + 1
                 if (end is not None and firstid <= end <= nextfirstid)
                 else None)
 
+            page_results = self.getpage(pagenum)
             if startv != 0 or endv is not None:
                 page_results = page_results[startv:endv]
-            res.extend(page_results)
+            yield from page_results
 
             # A little optimization - if current page is not "full", ie. does
             # not contain page_size videos then we can assume that this page
@@ -3860,36 +4153,31 @@ class OnDemandPagedList(PagedList):
             # break out early as well
             if end == nextfirstid:
                 break
-        return res
 
 
 class InAdvancePagedList(PagedList):
     def __init__(self, pagefunc, pagecount, pagesize):
-        self._pagefunc = pagefunc
         self._pagecount = pagecount
-        self._pagesize = pagesize
+        PagedList.__init__(self, pagefunc, pagesize, True)
 
-    def getslice(self, start=0, end=None):
-        res = []
+    def _getslice(self, start, end):
         start_page = start // self._pagesize
         end_page = (
             self._pagecount if end is None else (end // self._pagesize + 1))
         skip_elems = start - start_page * self._pagesize
         only_more = None if end is None else end - start
         for pagenum in range(start_page, end_page):
-            page = list(self._pagefunc(pagenum))
+            page_results = self.getpage(pagenum)
             if skip_elems:
-                page = page[skip_elems:]
+                page_results = page_results[skip_elems:]
                 skip_elems = None
             if only_more is not None:
-                if len(page) < only_more:
-                    only_more -= len(page)
+                if len(page_results) < only_more:
+                    only_more -= len(page_results)
                 else:
-                    page = page[:only_more]
-                    res.extend(page)
+                    yield from page_results[:only_more]
                     break
-            res.extend(page)
-        return res
+            yield from page_results
 
 
 def uppercase_escape(s):
@@ -3927,17 +4215,24 @@ def escape_url(url):
     ).geturl()
 
 
+def parse_qs(url):
+    return compat_parse_qs(compat_urllib_parse_urlparse(url).query)
+
+
 def read_batch_urls(batch_fd):
     def fixup(url):
         if not isinstance(url, compat_str):
             url = url.decode('utf-8', 'replace')
-        BOM_UTF8 = '\xef\xbb\xbf'
-        if url.startswith(BOM_UTF8):
-            url = url[len(BOM_UTF8):]
-        url = url.strip()
-        if url.startswith(('#', ';', ']')):
+        BOM_UTF8 = ('\xef\xbb\xbf', '\ufeff')
+        for bom in BOM_UTF8:
+            if url.startswith(bom):
+                url = url[len(bom):]
+        url = url.lstrip()
+        if not url or url.startswith(('#', ';', ']')):
             return False
-        return url
+        # "#" cannot be stripped out since it is part of the URI
+        # However, it can be safely stripped out if following a whitespace
+        return re.split(r'\s#', url, 1)[0].rstrip()
 
     with contextlib.closing(batch_fd) as fd:
         return [url for url in map(fixup, fd) if url]
@@ -4040,9 +4335,7 @@ def dict_get(d, key_or_keys, default=None, skip_false_values=True):
 
 
 def try_get(src, getter, expected_type=None):
-    if not isinstance(getter, (list, tuple)):
-        getter = [getter]
-    for get in getter:
+    for get in variadic(getter):
         try:
             v = get(src)
         except (AttributeError, KeyError, TypeError, IndexError):
@@ -4097,6 +4390,7 @@ def parse_age_limit(s):
     m = re.match(r'^(?P<age>\d{1,2})\+?$', s)
     if m:
         return int(m.group('age'))
+    s = s.upper()
     if s in US_RATINGS:
         return US_RATINGS[s]
     m = re.match(r'^TV[_-]?(%s)$' % '|'.join(k[3:] for k in TV_PARENTAL_GUIDELINES), s)
@@ -4115,8 +4409,9 @@ def strip_jsonp(code):
         r'\g<callback_data>', code)
 
 
-def js_to_json(code):
-    COMMENT_RE = r'/\*(?:(?!\*/).)*?\*/|//[^\n]*'
+def js_to_json(code, vars={}):
+    # vars is a dict of var, val pairs to substitute
+    COMMENT_RE = r'/\*(?:(?!\*/).)*?\*/|//[^\n]*\n'
     SKIP_RE = r'\s*(?:{comment})?\s*'.format(comment=COMMENT_RE)
     INTEGER_TABLE = (
         (r'(?s)^(0[xX][0-9a-fA-F]+){skip}:?$'.format(skip=SKIP_RE), 16),
@@ -4127,6 +4422,8 @@ def js_to_json(code):
         v = m.group(0)
         if v in ('true', 'false', 'null'):
             return v
+        elif v in ('undefined', 'void 0'):
+            return 'null'
         elif v.startswith('/*') or v.startswith('//') or v.startswith('!') or v == ',':
             return ""
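
The new `LazyList` wraps an iterable but only consumes as much of it as indexing demands, caching what it has seen; `len()` and negative indices force full evaluation. A small demo (import assumption as above):

```python
from hypervideo_dl.utils import LazyList

def numbers():
    for n in range(5):
        print('producing', n)
        yield n

items = LazyList(numbers())
print(items[0])    # consumes only the first element
print(items[1:3])  # slices are plain lists: [1, 2]
print(len(items))  # forces full evaluation: 5
```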
@@ -4144,13 +4441,16 @@ def js_to_json(code):
                 i = int(im.group(1), base)
                 return '"%d":' % i if v.endswith(':') else '%d' % i
 
+        if v in vars:
+            return vars[v]
+
         return '"%s"' % v
 
     return re.sub(r'''(?sx)
         "(?:[^"\\]*(?:\\\\|\\['"nurtbfx/\n]))*[^"\\]*"|
         '(?:[^'\\]*(?:\\\\|\\['"nurtbfx/\n]))*[^'\\]*'|
         {comment}|,(?={skip}[\]}}])|
-        (?:(?<![0-9])[eE]|[a-df-zA-DF-Z_])[.a-zA-Z_0-9]*|
+        void\s0|(?:(?<![0-9])[eE]|[a-df-zA-DF-Z_$])[.a-zA-Z_$0-9]*|
         \b(?:0[xX][0-9a-fA-F]+|0+[0-7]+)(?:{skip}:)?|
         [0-9]+(?={skip}:)|
         !+
@@ -4167,7 +4467,40 @@ def qualities(quality_ids):
     return q
 
 
-DEFAULT_OUTTMPL = '%(title)s-%(id)s.%(ext)s'
+DEFAULT_OUTTMPL = {
+    'default': '%(title)s [%(id)s].%(ext)s',
+    'chapter': '%(title)s - %(section_number)03d %(section_title)s [%(id)s].%(ext)s',
+}
+OUTTMPL_TYPES = {
+    'chapter': None,
+    'subtitle': None,
+    'thumbnail': None,
+    'description': 'description',
+    'annotation': 'annotations.xml',
+    'infojson': 'info.json',
+    'pl_thumbnail': None,
+    'pl_description': 'description',
+    'pl_infojson': 'info.json',
+}
+
+# As of [1] format syntax is:
+#  %[mapping_key][conversion_flags][minimum_width][.precision][length_modifier]type
+# 1. https://docs.python.org/2/library/stdtypes.html#string-formatting
+STR_FORMAT_RE_TMPL = r'''(?x)
+    (?<!%)(?P<prefix>(?:%%)*)
+    %
+    (?P<has_key>\((?P<key>{0})\))?
+    (?P<format>
+        (?P<conversion>[#0\-+ ]+)?
+        (?P<min_width>\d+)?
+        (?P<precision>\.\d+)?
+        (?P<len_mod>[hlL])?  # unused in python
+        {1}  # conversion type
+    )
+'''
+
+
+STR_FORMAT_TYPES = 'diouxXeEfFgGcrs'
 
 
 def limit_length(s, length):
@@ -4195,9 +4528,10 @@ def is_outdated_version(version, limit, assume_new=True):
 
 def ytdl_is_updateable():
     """ Returns if hypervideo can be updated with -U """
-    from zipimport import zipimporter
-    return isinstance(globals().get('__loader__'), zipimporter) or hasattr(sys, 'frozen')
+
+    from .update import is_non_updateable
+
+    return not is_non_updateable()
 
 
 def args_to_str(args):
@@ -4218,19 +4552,24 @@ def mimetype2ext(mt):
     if mt is None:
         return None
 
-    ext = {
+    mt, _, params = mt.partition(';')
+    mt = mt.strip()
+
+    FULL_MAP = {
         'audio/mp4': 'm4a',
         # Per RFC 3003, audio/mpeg can be .mp1, .mp2 or .mp3. Here use .mp3 as
         # it's the most popular one
         'audio/mpeg': 'mp3',
-    }.get(mt)
+        'audio/x-wav': 'wav',
+        'audio/wav': 'wav',
+        'audio/wave': 'wav',
+    }
+
+    ext = FULL_MAP.get(mt)
     if ext is not None:
         return ext
 
-    _, _, res = mt.rpartition('/')
-    res = res.split(';')[0].strip().lower()
-
-    return {
+    SUBTYPE_MAP = {
         '3gpp': '3gp',
         'smptett+xml': 'tt',
         'ttaf+xml': 'dfxp',
@@ -4249,7 +4588,28 @@ def mimetype2ext(mt):
         'quicktime': 'mov',
         'mp2t': 'ts',
         'x-wav': 'wav',
-    }.get(res, res)
+        'filmstrip+json': 'fs',
+        'svg+xml': 'svg',
+    }
+
+    _, _, subtype = mt.rpartition('/')
+    ext = SUBTYPE_MAP.get(subtype.lower())
+    if ext is not None:
+        return ext
+
+    SUFFIX_MAP = {
+        'json': 'json',
+        'xml': 'xml',
+        'zip': 'zip',
+        'gzip': 'gz',
+    }
+
+    _, _, suffix = subtype.partition('+')
+    ext = SUFFIX_MAP.get(suffix)
+    if ext is not None:
+        return ext
+
+    return subtype.replace('+', '.')
 
 
 def parse_codecs(codecs_str):
@@ -4257,13 +4617,22 @@ def parse_codecs(codecs_str):
     if not codecs_str:
         return {}
     split_codecs = list(filter(None, map(
-        lambda str: str.strip(), codecs_str.strip().strip(',').split(','))))
-    vcodec, acodec = None, None
+        str.strip, codecs_str.strip().strip(',').split(','))))
+    vcodec, acodec, hdr = None, None, None
     for full_codec in split_codecs:
         codec = full_codec.split('.')[0]
-        if codec in ('avc1', 'avc2', 'avc3', 'avc4', 'vp9', 'vp8', 'hev1', 'hev2', 'h263', 'h264', 'mp4v', 'hvc1', 'av01', 'theora'):
+        if codec in ('avc1', 'avc2', 'avc3', 'avc4', 'vp9', 'vp8', 'hev1', 'hev2', 'h263', 'h264', 'mp4v', 'hvc1', 'av01', 'theora', 'dvh1', 'dvhe'):
             if not vcodec:
                 vcodec = full_codec
+                if codec in ('dvh1', 'dvhe'):
+                    hdr = 'DV'
+                elif codec == 'vp9' and vcodec.startswith('vp9.2'):
+                    hdr = 'HDR10'
+                elif codec == 'av01':
+                    parts = full_codec.split('.')
+                    if len(parts) > 3 and parts[3] == '10':
+                        hdr = 'HDR10'
+                    vcodec = '.'.join(parts[:4])
         elif codec in ('mp4a', 'opus', 'vorbis', 'mp3', 'aac', 'ac-3', 'ec-3', 'eac3', 'dtsc', 'dtse', 'dtsh', 'dtsl'):
             if not acodec:
                 acodec = full_codec
@@ -4279,6 +4648,7 @@ def parse_codecs(codecs_str):
         return {
             'vcodec': vcodec or 'none',
             'acodec': acodec or 'none',
+            'dynamic_range': hdr,
         }
     return {}
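
`parse_codecs` now also reports dynamic range, recognising Dolby Vision fourccs (`dvh1`, `dvhe`) and flagging 10-bit AV1 and VP9.2 profiles as HDR10. For instance (import assumption as above; values worked out from the branch logic):

```python
from hypervideo_dl.utils import parse_codecs

print(parse_codecs('dvh1.05.06, ec-3'))
# {'vcodec': 'dvh1.05.06', 'acodec': 'ec-3', 'dynamic_range': 'DV'}
print(parse_codecs('av01.0.12M.10, opus'))
# an av01 codec whose fourth dot-part is '10' is tagged as HDR10
```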
@@ -4353,66 +4723,85 @@ def determine_protocol(info_dict):
     return compat_urllib_parse_urlparse(url).scheme
 
 
-def render_table(header_row, data):
+def render_table(header_row, data, delim=False, extraGap=0, hideEmpty=False):
     """ Render a list of rows, each as a list of values """
+
+    def get_max_lens(table):
+        return [max(len(compat_str(v)) for v in col) for col in zip(*table)]
+
+    def filter_using_list(row, filterArray):
+        return [col for (take, col) in zip(filterArray, row) if take]
+
+    if hideEmpty:
+        max_lens = get_max_lens(data)
+        header_row = filter_using_list(header_row, max_lens)
+        data = [filter_using_list(row, max_lens) for row in data]
+
     table = [header_row] + data
-    max_lens = [max(len(compat_str(v)) for v in col) for col in zip(*table)]
-    format_str = ' '.join('%-' + compat_str(ml + 1) + 's' for ml in max_lens[:-1]) + '%s'
+    max_lens = get_max_lens(table)
+    if delim:
+        table = [header_row] + [['-' * ml for ml in max_lens]] + data
+    format_str = ' '.join('%-' + compat_str(ml + extraGap) + 's' for ml in max_lens[:-1]) + ' %s'
     return '\n'.join(format_str % tuple(row) for row in table)
 
 
-def _match_one(filter_part, dct):
+def _match_one(filter_part, dct, incomplete):
+    # TODO: Generalize code with YoutubeDL._build_format_filter
+    STRING_OPERATORS = {
+        '*=': operator.contains,
+        '^=': lambda attr, value: attr.startswith(value),
+        '$=': lambda attr, value: attr.endswith(value),
+        '~=': lambda attr, value: re.search(value, attr),
+    }
     COMPARISON_OPERATORS = {
+        **STRING_OPERATORS,
+        '<=': operator.le,  # "<=" must be defined above "<"
         '<': operator.lt,
-        '<=': operator.le,
-        '>': operator.gt,
         '>=': operator.ge,
+        '>': operator.gt,
         '=': operator.eq,
-        '!=': operator.ne,
     }
+
     operator_rex = re.compile(r'''(?x)\s*
         (?P<key>[a-z_]+)
-        \s*(?P<op>%s)(?P<none_inclusive>\s*\?)?\s*
+        \s*(?P<negation>!\s*)?(?P<op>%s)(?P<none_inclusive>\s*\?)?\s*
         (?:
-            (?P<intval>[0-9.]+(?:[kKmMgGtTpPeEzZyY]i?[Bb]?)?)|
-            (?P<quote>["\'])(?P<quotedstrval>(?:\\.|(?!(?P=quote)|\\).)+?)(?P=quote)|
-            (?P<strval>(?![0-9.])[a-z0-9A-Z]*)
+            (?P<quote>["\'])(?P<quotedstrval>.+?)(?P=quote)|
+            (?P<strval>.+?)
        )
        \s*$
        ''' % '|'.join(map(re.escape, COMPARISON_OPERATORS.keys())))
     m = operator_rex.search(filter_part)
     if m:
-        op = COMPARISON_OPERATORS[m.group('op')]
-        actual_value = dct.get(m.group('key'))
-        if (m.group('quotedstrval') is not None
-            or m.group('strval') is not None
+        m = m.groupdict()
+        unnegated_op = COMPARISON_OPERATORS[m['op']]
+        if m['negation']:
+            op = lambda attr, value: not unnegated_op(attr, value)
+        else:
+            op = unnegated_op
+        comparison_value = m['quotedstrval'] or m['strval'] or m['intval']
+        if m['quote']:
+            comparison_value = comparison_value.replace(r'\%s' % m['quote'], m['quote'])
+        actual_value = dct.get(m['key'])
+        numeric_comparison = None
+        if isinstance(actual_value, compat_numeric_types):
             # If the original field is a string and matching comparisonvalue is
             # a number we should respect the origin of the original field
             # and process comparison value as a string (see
-            # https://github.com/ytdl-org/youtube-dl/issues/11082).
-            or actual_value is not None and m.group('intval') is not None
-            and isinstance(actual_value, compat_str)):
-            if m.group('op') not in ('=', '!='):
-                raise ValueError(
-                    'Operator %s does not support string values!' % m.group('op'))
-            comparison_value = m.group('quotedstrval') or m.group('strval') or m.group('intval')
-            quote = m.group('quote')
-            if quote is not None:
-                comparison_value = comparison_value.replace(r'\%s' % quote, quote)
-        else:
+            # https://github.com/ytdl-org/youtube-dl/issues/11082)
             try:
-                comparison_value = int(m.group('intval'))
+                numeric_comparison = int(comparison_value)
             except ValueError:
-                comparison_value = parse_filesize(m.group('intval'))
-                if comparison_value is None:
-                    comparison_value = parse_filesize(m.group('intval') + 'B')
-                if comparison_value is None:
-                    raise ValueError(
-                        'Invalid integer value %r in filter part %r' % (
-                            m.group('intval'), filter_part))
+                numeric_comparison = parse_filesize(comparison_value)
+                if numeric_comparison is None:
+                    numeric_comparison = parse_filesize(f'{comparison_value}B')
+                if numeric_comparison is None:
+                    numeric_comparison = parse_duration(comparison_value)
+        if numeric_comparison is not None and m['op'] in STRING_OPERATORS:
+            raise ValueError('Operator %s only supports string values!' % m['op'])
         if actual_value is None:
-            return m.group('none_inclusive')
-        return op(actual_value, comparison_value)
+            return incomplete or m['none_inclusive']
+        return op(actual_value, comparison_value if numeric_comparison is None else numeric_comparison)
 
     UNARY_OPERATORS = {
         '': lambda v: (v is True) if isinstance(v, bool) else (v is not None),
@@ -4426,21 +4815,25 @@ def _match_one(filter_part, dct):
     if m:
         op = UNARY_OPERATORS[m.group('op')]
         actual_value = dct.get(m.group('key'))
+        if incomplete and actual_value is None:
+            return True
         return op(actual_value)
 
     raise ValueError('Invalid filter part %r' % filter_part)
 
 
-def match_str(filter_str, dct):
-    """ Filter a dictionary with a simple string syntax. Returns True (=passes filter) or false """
-
+def match_str(filter_str, dct, incomplete=False):
+    """ Filter a dictionary with a simple string syntax. Returns True (=passes filter) or false
+        When incomplete, all conditions passes on missing fields
+    """
     return all(
-        _match_one(filter_part, dct) for filter_part in filter_str.split('&'))
+        _match_one(filter_part.replace(r'\&', '&'), dct, incomplete)
+        for filter_part in re.split(r'(?<!\\)&', filter_str))
 
 
 def match_filter_func(filter_str):
-    def _match_func(info_dict):
-        if match_str(filter_str, info_dict):
+    def _match_func(info_dict, *args, **kwargs):
+        if match_str(filter_str, info_dict, *args, **kwargs):
             return None
         else:
             video_title = info_dict.get('title', info_dict.get('id', 'video'))
@@ -4651,12 +5044,37 @@ def cli_valueless_option(params, command_option, param, expected_value=True):
     return [command_option] if param == expected_value else []
 
 
-def cli_configuration_args(params, param, default=[]):
-    ex_args = params.get(param)
-    if ex_args is None:
+def cli_configuration_args(argdict, keys, default=[], use_compat=True):
+    if isinstance(argdict, (list, tuple)):  # for backward compatibility
+        if use_compat:
+            return argdict
+        else:
+            argdict = None
+    if argdict is None:
         return default
-    assert isinstance(ex_args, list)
-    return ex_args
+    assert isinstance(argdict, dict)
+
+    assert isinstance(keys, (list, tuple))
+    for key_list in keys:
+        arg_list = list(filter(
+            lambda x: x is not None,
+            [argdict.get(key.lower()) for key in variadic(key_list)]))
+        if arg_list:
+            return [arg for args in arg_list for arg in args]
+    return default
+
+
+def _configuration_args(main_key, argdict, exe, keys=None, default=[], use_compat=True):
+    main_key, exe = main_key.lower(), exe.lower()
+    root_key = exe if main_key == exe else f'{main_key}+{exe}'
+    keys = [f'{root_key}{k}' for k in (keys or [''])]
+    if root_key in keys:
+        if main_key != exe:
+            keys.append((main_key, exe))
+        keys.append('default')
+    else:
+        use_compat = False
+    return cli_configuration_args(argdict, keys, default, use_compat)
 
 
 class ISO639Utils(object):
@@ -5725,7 +6143,7 @@ def write_xattr(path, key, value):
                     cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE)
             except EnvironmentError as e:
                 raise XAttrMetadataError(e.errno, e.strerror)
-            stdout, stderr = p.communicate()
+            stdout, stderr = process_communicate_or_kill(p)
             stderr = stderr.decode('utf-8', 'replace')
             if p.returncode != 0:
                 raise XAttrMetadataError(p.returncode, stderr)
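
The rewritten `_match_one` adds string operators (`*=`, `^=`, `$=`, `~=`), negation, and duration-style numbers to the filter mini-language. A sketch of the resulting `match_str` semantics (import assumption as above):

```python
from hypervideo_dl.utils import match_str

info = {'duration': 300, 'title': 'Official Video', 'like_count': None}
print(match_str('duration < 10:00 & title *= Official', info))  # True
print(match_str('like_count > 100', info))                      # False: field is missing
print(match_str('like_count > 100', info, incomplete=True))     # True: missing fields pass
```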
@@ -5757,6 +6175,95 @@ def random_birthday(year_field, month_field, day_field):
     }
 
 
+# Templates for internet shortcut files, which are plain text files.
+DOT_URL_LINK_TEMPLATE = '''
+[InternetShortcut]
+URL=%(url)s
+'''.lstrip()
+
+DOT_WEBLOC_LINK_TEMPLATE = '''
+<?xml version="1.0" encoding="UTF-8"?>
+<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
+<plist version="1.0">
+<dict>
+\t<key>URL</key>
+\t<string>%(url)s</string>
+</dict>
+</plist>
+'''.lstrip()
+
+DOT_DESKTOP_LINK_TEMPLATE = '''
+[Desktop Entry]
+Encoding=UTF-8
+Name=%(filename)s
+Type=Link
+URL=%(url)s
+Icon=text-html
+'''.lstrip()
+
+
+def iri_to_uri(iri):
+    """
+    Converts an IRI (Internationalized Resource Identifier, allowing Unicode characters) to a URI (Uniform Resource Identifier, ASCII-only).
+
+    The function doesn't add an additional layer of escaping; e.g., it doesn't escape `%3C` as `%253C`. Instead, it percent-escapes characters with an underlying UTF-8 encoding *besides* those already escaped, leaving the URI intact.
+    """
+
+    iri_parts = compat_urllib_parse_urlparse(iri)
+
+    if '[' in iri_parts.netloc:
+        raise ValueError('IPv6 URIs are not, yet, supported.')
+        # Querying `.netloc`, when there's only one bracket, also raises a ValueError.
+
+    # The `safe` argument values, that the following code uses, contain the characters that should not be percent-encoded. Everything else but letters, digits and '_.-' will be percent-encoded with an underlying UTF-8 encoding. Everything already percent-encoded will be left as is.
+
+    net_location = ''
+    if iri_parts.username:
+        net_location += compat_urllib_parse_quote(iri_parts.username, safe=r"!$%&'()*+,~")
+        if iri_parts.password is not None:
+            net_location += ':' + compat_urllib_parse_quote(iri_parts.password, safe=r"!$%&'()*+,~")
+        net_location += '@'
+
+    net_location += iri_parts.hostname.encode('idna').decode('utf-8')  # Punycode for Unicode hostnames.
+    # The 'idna' encoding produces ASCII text.
+    if iri_parts.port is not None and iri_parts.port != 80:
+        net_location += ':' + str(iri_parts.port)
+
+    return compat_urllib_parse_urlunparse(
+        (iri_parts.scheme,
+            net_location,
+
+            compat_urllib_parse_quote_plus(iri_parts.path, safe=r"!$%&'()*+,/:;=@|~"),
+
+            # Unsure about the `safe` argument, since this is a legacy way of handling parameters.
+            compat_urllib_parse_quote_plus(iri_parts.params, safe=r"!$%&'()*+,/:;=@|~"),
+
+            # Not totally sure about the `safe` argument, since the source does not explicitly mention the query URI component.
+            compat_urllib_parse_quote_plus(iri_parts.query, safe=r"!$%&'()*+,/:;=?@{|}~"),
+
+            compat_urllib_parse_quote_plus(iri_parts.fragment, safe=r"!#$%&'()*+,/:;=?@{|}~")))
+
+    # Source for `safe` arguments: https://url.spec.whatwg.org/#percent-encoded-bytes.
+
+
+def to_high_limit_path(path):
+    if sys.platform in ['win32', 'cygwin']:
+        # Work around MAX_PATH limitation on Windows. The maximum allowed length for the individual path segments may still be quite limited.
+        return r'\\?\ '.rstrip() + os.path.abspath(path)
+
+    return path
+
+
+def format_field(obj, field=None, template='%s', ignore=(None, ''), default='', func=None):
+    if field is None:
+        val = obj if obj is not None else default
+    else:
+        val = obj.get(field, default)
+    if func and val not in ignore:
+        val = func(val)
+    return template % val if val not in ignore else default
+
+
 def clean_podcast_url(url):
     return re.sub(r'''(?x)
         (?:
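
`format_field` is the small formatting helper that the reworked `ExtractorError` above uses to build its message prefix; empty and `None` values collapse to a default instead of rendering as the string 'None'. For instance (import assumption as above):

```python
from hypervideo_dl.utils import format_field

info = {'id': 'abc123', 'uploader': None}
print(format_field(info, 'id', template='[%s] '))     # '[abc123] '
print(format_field(info, 'uploader', default='N/A'))  # 'N/A', since None is ignored
```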
@@ -5772,3 +6279,203 @@ def clean_podcast_url(url):
             st\.fm # https://podsights.com/docs/
         )/e
     )/''', '', url)
+
+
+_HEX_TABLE = '0123456789abcdef'
+
+
+def random_uuidv4():
+    return re.sub(r'[xy]', lambda x: _HEX_TABLE[random.randint(0, 15)], 'xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx')
+
+
+def make_dir(path, to_screen=None):
+    try:
+        dn = os.path.dirname(path)
+        if dn and not os.path.exists(dn):
+            os.makedirs(dn)
+        return True
+    except (OSError, IOError) as err:
+        if callable(to_screen) is not None:
+            to_screen('unable to create directory ' + error_to_compat_str(err))
+        return False
+
+
+def get_executable_path():
+    from zipimport import zipimporter
+    if hasattr(sys, 'frozen'):  # Running from PyInstaller
+        path = os.path.dirname(sys.executable)
+    elif isinstance(globals().get('__loader__'), zipimporter):  # Running from ZIP
+        path = os.path.join(os.path.dirname(__file__), '../..')
+    else:
+        path = os.path.join(os.path.dirname(__file__), '..')
+    return os.path.abspath(path)
+
+
+def load_plugins(name, suffix, namespace):
+    classes = {}
+    try:
+        plugins_spec = importlib.util.spec_from_file_location(
+            name, os.path.join(get_executable_path(), 'ytdlp_plugins', name, '__init__.py'))
+        plugins = importlib.util.module_from_spec(plugins_spec)
+        sys.modules[plugins_spec.name] = plugins
+        plugins_spec.loader.exec_module(plugins)
+        for name in dir(plugins):
+            if name in namespace:
+                continue
+            if not name.endswith(suffix):
+                continue
+            klass = getattr(plugins, name)
+            classes[name] = namespace[name] = klass
+    except FileNotFoundError:
+        pass
+    return classes
+
+
+def traverse_obj(
+        obj, *path_list, default=None, expected_type=None, get_all=True,
+        casesense=True, is_user_input=False, traverse_string=False):
+    ''' Traverse nested list/dict/tuple
+    @param path_list        A list of paths which are checked one by one.
+                            Each path is a list of keys where each key is a string,
+                            a function, a tuple of strings or "...".
+                            When a function is given, it takes the key as argument and
+                            returns whether the key matches or not. When a tuple is given,
+                            all the keys given in the tuple are traversed, and
+                            "..." traverses all the keys in the object
+    @param default          Default value to return
+    @param expected_type    Only accept final value of this type (Can also be any callable)
+    @param get_all          Return all the values obtained from a path or only the first one
+    @param casesense        Whether to consider dictionary keys as case sensitive
+    @param is_user_input    Whether the keys are generated from user input. If True,
+                            strings are converted to int/slice if necessary
+    @param traverse_string  Whether to traverse inside strings. If True, any
+                            non-compatible object will also be converted into a string
+    # TODO: Write tests
+    '''
+    if not casesense:
+        _lower = lambda k: (k.lower() if isinstance(k, str) else k)
+        path_list = (map(_lower, variadic(path)) for path in path_list)
+
+    def _traverse_obj(obj, path, _current_depth=0):
+        nonlocal depth
+        if obj is None:
+            return None
+        path = tuple(variadic(path))
+        for i, key in enumerate(path):
+            if isinstance(key, (list, tuple)):
+                obj = [_traverse_obj(obj, sub_key, _current_depth) for sub_key in key]
+                key = ...
+            if key is ...:
+                obj = (obj.values() if isinstance(obj, dict)
+                       else obj if isinstance(obj, (list, tuple, LazyList))
+                       else str(obj) if traverse_string else [])
+                _current_depth += 1
+                depth = max(depth, _current_depth)
+                return [_traverse_obj(inner_obj, path[i + 1:], _current_depth) for inner_obj in obj]
+            elif callable(key):
+                if isinstance(obj, (list, tuple, LazyList)):
+                    obj = enumerate(obj)
+                elif isinstance(obj, dict):
+                    obj = obj.items()
+                else:
+                    if not traverse_string:
+                        return None
+                    obj = str(obj)
+                _current_depth += 1
+                depth = max(depth, _current_depth)
+                return [_traverse_obj(v, path[i + 1:], _current_depth) for k, v in obj if key(k)]
+            elif isinstance(obj, dict) and not (is_user_input and key == ':'):
+                obj = (obj.get(key) if casesense or (key in obj)
+                       else next((v for k, v in obj.items() if _lower(k) == key), None))
+            else:
+                if is_user_input:
+                    key = (int_or_none(key) if ':' not in key
+                           else slice(*map(int_or_none, key.split(':'))))
+                    if key == slice(None):
+                        return _traverse_obj(obj, (..., *path[i + 1:]), _current_depth)
+                if not isinstance(key, (int, slice)):
+                    return None
+                if not isinstance(obj, (list, tuple, LazyList)):
+                    if not traverse_string:
+                        return None
+                    obj = str(obj)
+                try:
+                    obj = obj[key]
+                except IndexError:
+                    return None
+        return obj
+
+    if isinstance(expected_type, type):
+        type_test = lambda val: val if isinstance(val, expected_type) else None
+    elif expected_type is not None:
+        type_test = expected_type
+    else:
+        type_test = lambda val: val
+
+    for path in path_list:
+        depth = 0
+        val = _traverse_obj(obj, path)
+        if val is not None:
+            if depth:
+                for _ in range(depth - 1):
+                    val = itertools.chain.from_iterable(v for v in val if v is not None)
+                val = [v for v in map(type_test, val) if v is not None]
+                if val:
+                    return val if get_all else val[0]
+            else:
+                val = type_test(val)
+                if val is not None:
+                    return val
+    return default
+
+
+def traverse_dict(dictn, keys, casesense=True):
+    ''' For backward compatibility. Do not use '''
+    return traverse_obj(dictn, keys, casesense=casesense,
+                        is_user_input=True, traverse_string=True)
+
+
+def variadic(x, allowed_types=(str, bytes)):
+    return x if isinstance(x, collections.abc.Iterable) and not isinstance(x, allowed_types) else (x,)
+
+
+# create a JSON Web Signature (jws) with HS256 algorithm
+# the resulting format is in JWS Compact Serialization
+# implemented following JWT https://www.rfc-editor.org/rfc/rfc7519.html
+# implemented following JWS https://www.rfc-editor.org/rfc/rfc7515.html
+def jwt_encode_hs256(payload_data, key, headers={}):
+    header_data = {
+        'alg': 'HS256',
+        'typ': 'JWT',
+    }
+    if headers:
+        header_data.update(headers)
+    header_b64 = base64.b64encode(json.dumps(header_data).encode('utf-8'))
+    payload_b64 = base64.b64encode(json.dumps(payload_data).encode('utf-8'))
+    h = hmac.new(key.encode('utf-8'), header_b64 + b'.' + payload_b64, hashlib.sha256)
+    signature_b64 = base64.b64encode(h.digest())
+    token = header_b64 + b'.' + payload_b64 + b'.' + signature_b64
+    return token
+
+
+def supports_terminal_sequences(stream):
+    if compat_os_name == 'nt':
+        if get_windows_version() < (10, 0, 10586):
+            return False
+    elif not os.getenv('TERM'):
+        return False
+    try:
+        return stream.isatty()
+    except BaseException:
+        return False
+
+
+TERMINAL_SEQUENCES = {
+    'DOWN': '\n',
+    'UP': '\x1b[A',
+    'ERASE_LINE': '\x1b[K',
+    'RED': '\033[0;31m',
+    'YELLOW': '\033[0;33m',
+    'BLUE': '\033[0;34m',
+    'RESET_STYLE': '\033[0m',
+}
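
`traverse_obj` is the new generic accessor for nested info_dicts; `...` fans out over every key at a level and `default` covers missing paths. A usage sketch (import assumption as above; results follow from the traversal rules in the docstring):

```python
from hypervideo_dl.utils import traverse_obj

data = {'playlist': {'entries': [{'id': 'a1', 'title': 'First'}, {'id': 'b2'}]}}
print(traverse_obj(data, ('playlist', 'entries', 0, 'title')))  # 'First'
print(traverse_obj(data, ('playlist', 'entries', ..., 'id')))   # ['a1', 'b2']
print(traverse_obj(data, ('playlist', 'missing'), default=0))   # 0
```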