aboutsummaryrefslogtreecommitdiffstats
path: root/yt_dlp/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'yt_dlp/utils.py')
-rw-r--r--yt_dlp/utils.py448
1 files changed, 378 insertions, 70 deletions
diff --git a/yt_dlp/utils.py b/yt_dlp/utils.py
index 90502dbc0..6ec8da11b 100644
--- a/yt_dlp/utils.py
+++ b/yt_dlp/utils.py
@@ -3,6 +3,8 @@
from __future__ import unicode_literals
+import asyncio
+import atexit
import base64
import binascii
import calendar
@@ -58,6 +60,7 @@ from .compat import (
compat_kwargs,
compat_os_name,
compat_parse_qs,
+ compat_shlex_split,
compat_shlex_quote,
compat_str,
compat_struct_pack,
@@ -72,6 +75,7 @@ from .compat import (
compat_urllib_parse_unquote_plus,
compat_urllib_request,
compat_urlparse,
+ compat_websockets,
compat_xpath,
)
@@ -144,6 +148,7 @@ std_headers = {
'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8',
'Accept-Encoding': 'gzip, deflate',
'Accept-Language': 'en-us,en;q=0.5',
+ 'Sec-Fetch-Mode': 'navigate',
}
@@ -415,17 +420,33 @@ def get_element_by_id(id, html):
return get_element_by_attribute('id', id, html)
+def get_element_html_by_id(id, html):
+ """Return the html of the tag with the specified ID in the passed HTML document"""
+ return get_element_html_by_attribute('id', id, html)
+
+
def get_element_by_class(class_name, html):
"""Return the content of the first tag with the specified class in the passed HTML document"""
retval = get_elements_by_class(class_name, html)
return retval[0] if retval else None
+def get_element_html_by_class(class_name, html):
+ """Return the html of the first tag with the specified class in the passed HTML document"""
+ retval = get_elements_html_by_class(class_name, html)
+ return retval[0] if retval else None
+
+
def get_element_by_attribute(attribute, value, html, escape_value=True):
retval = get_elements_by_attribute(attribute, value, html, escape_value)
return retval[0] if retval else None
+def get_element_html_by_attribute(attribute, value, html, escape_value=True):
+ retval = get_elements_html_by_attribute(attribute, value, html, escape_value)
+ return retval[0] if retval else None
+
+
def get_elements_by_class(class_name, html):
"""Return the content of all tags with the specified class in the passed HTML document as a list"""
return get_elements_by_attribute(
@@ -433,29 +454,123 @@ def get_elements_by_class(class_name, html):
html, escape_value=False)
-def get_elements_by_attribute(attribute, value, html, escape_value=True):
+def get_elements_html_by_class(class_name, html):
+ """Return the html of all tags with the specified class in the passed HTML document as a list"""
+ return get_elements_html_by_attribute(
+ 'class', r'[^\'"]*\b%s\b[^\'"]*' % re.escape(class_name),
+ html, escape_value=False)
+
+
+def get_elements_by_attribute(*args, **kwargs):
"""Return the content of the tag with the specified attribute in the passed HTML document"""
+ return [content for content, _ in get_elements_text_and_html_by_attribute(*args, **kwargs)]
+
+
+def get_elements_html_by_attribute(*args, **kwargs):
+ """Return the html of the tag with the specified attribute in the passed HTML document"""
+ return [whole for _, whole in get_elements_text_and_html_by_attribute(*args, **kwargs)]
+
+
+def get_elements_text_and_html_by_attribute(attribute, value, html, escape_value=True):
+ """
+ Return the text (content) and the html (whole) of the tag with the specified
+ attribute in the passed HTML document
+ """
+
+ value_quote_optional = '' if re.match(r'''[\s"'`=<>]''', value) else '?'
value = re.escape(value) if escape_value else value
- retlist = []
- for m in re.finditer(r'''(?xs)
- <([a-zA-Z0-9:._-]+)
- (?:\s+[a-zA-Z0-9:._-]+(?:=[a-zA-Z0-9:._-]*|="[^"]*"|='[^']*'|))*?
- \s+%s=['"]?%s['"]?
- (?:\s+[a-zA-Z0-9:._-]+(?:=[a-zA-Z0-9:._-]*|="[^"]*"|='[^']*'|))*?
- \s*>
- (?P<content>.*?)
- </\1>
- ''' % (re.escape(attribute), value), html):
- res = m.group('content')
+ partial_element_re = r'''(?x)
+ <(?P<tag>[a-zA-Z0-9:._-]+)
+ (?:\s(?:[^>"']|"[^"]*"|'[^']*')*)?
+ \s%(attribute)s\s*=\s*(?P<_q>['"]%(vqo)s)(?-x:%(value)s)(?P=_q)
+ ''' % {'attribute': re.escape(attribute), 'value': value, 'vqo': value_quote_optional}
+
+ for m in re.finditer(partial_element_re, html):
+ content, whole = get_element_text_and_html_by_tag(m.group('tag'), html[m.start():])
+
+ yield (
+ unescapeHTML(re.sub(r'^(?P<q>["\'])(?P<content>.*)(?P=q)$', r'\g<content>', content, flags=re.DOTALL)),
+ whole
+ )
+
+
+class HTMLBreakOnClosingTagParser(compat_HTMLParser):
+ """
+ HTML parser which raises HTMLBreakOnClosingTagException upon reaching the
+ closing tag for the first opening tag it has encountered, and can be used
+ as a context manager
+ """
+
+ class HTMLBreakOnClosingTagException(Exception):
+ pass
+
+ def __init__(self):
+ self.tagstack = collections.deque()
+ compat_HTMLParser.__init__(self)
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, *_):
+ self.close()
- if res.startswith('"') or res.startswith("'"):
- res = res[1:-1]
+ def close(self):
+ # handle_endtag does not return upon raising HTMLBreakOnClosingTagException,
+ # so data remains buffered; we no longer have any interest in it, thus
+ # override this method to discard it
+ pass
- retlist.append(unescapeHTML(res))
+ def handle_starttag(self, tag, _):
+ self.tagstack.append(tag)
- return retlist
+ def handle_endtag(self, tag):
+ if not self.tagstack:
+ raise compat_HTMLParseError('no tags in the stack')
+ while self.tagstack:
+ inner_tag = self.tagstack.pop()
+ if inner_tag == tag:
+ break
+ else:
+ raise compat_HTMLParseError(f'matching opening tag for closing {tag} tag not found')
+ if not self.tagstack:
+ raise self.HTMLBreakOnClosingTagException()
+
+
+def get_element_text_and_html_by_tag(tag, html):
+ """
+ For the first element with the specified tag in the passed HTML document
+    return its content (text) and the whole element (html)
+ """
+ def find_or_raise(haystack, needle, exc):
+ try:
+ return haystack.index(needle)
+ except ValueError:
+ raise exc
+ closing_tag = f'</{tag}>'
+ whole_start = find_or_raise(
+ html, f'<{tag}', compat_HTMLParseError(f'opening {tag} tag not found'))
+ content_start = find_or_raise(
+ html[whole_start:], '>', compat_HTMLParseError(f'malformed opening {tag} tag'))
+ content_start += whole_start + 1
+ with HTMLBreakOnClosingTagParser() as parser:
+ parser.feed(html[whole_start:content_start])
+ if not parser.tagstack or parser.tagstack[0] != tag:
+ raise compat_HTMLParseError(f'parser did not match opening {tag} tag')
+ offset = content_start
+ while offset < len(html):
+ next_closing_tag_start = find_or_raise(
+ html[offset:], closing_tag,
+ compat_HTMLParseError(f'closing {tag} tag not found'))
+ next_closing_tag_end = next_closing_tag_start + len(closing_tag)
+ try:
+ parser.feed(html[offset:offset + next_closing_tag_end])
+ offset += next_closing_tag_end
+ except HTMLBreakOnClosingTagParser.HTMLBreakOnClosingTagException:
+ return html[content_start:offset + next_closing_tag_start], \
+ html[whole_start:offset + next_closing_tag_end]
+ raise compat_HTMLParseError('unexpected end of html')
class HTMLAttributeParser(compat_HTMLParser):
@@ -527,10 +642,9 @@ def clean_html(html):
if html is None: # Convenience for sanitizing descriptions etc.
return html
- # Newline vs <br />
- html = html.replace('\n', ' ')
- html = re.sub(r'(?u)\s*<\s*br\s*/?\s*>\s*', '\n', html)
- html = re.sub(r'(?u)<\s*/\s*p\s*>\s*<\s*p[^>]*>', '\n', html)
+ html = re.sub(r'\s+', ' ', html)
+ html = re.sub(r'(?u)\s?<\s?br\s?/?\s?>\s?', '\n', html)
+ html = re.sub(r'(?u)<\s?/\s?p\s?>\s?<\s?p[^>]*>', '\n', html)
# Strip html tags
html = re.sub('<.*?>', '', html)
# Replace html entities
@@ -554,7 +668,7 @@ def sanitize_open(filename, open_mode):
import msvcrt
msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)
return (sys.stdout.buffer if hasattr(sys.stdout, 'buffer') else sys.stdout, filename)
- stream = open(encodeFilename(filename), open_mode)
+ stream = locked_file(filename, open_mode, block=False).open()
return (stream, filename)
except (IOError, OSError) as err:
if err.errno in (errno.EACCES,):
@@ -566,7 +680,7 @@ def sanitize_open(filename, open_mode):
raise
else:
# An exception here should be caught in the caller
- stream = open(encodeFilename(alt_filename), open_mode)
+            stream = locked_file(alt_filename, open_mode, block=False).open()
return (stream, alt_filename)
@@ -885,6 +999,8 @@ def make_HTTPS_handler(params, **kwargs):
opts_check_certificate = not params.get('nocheckcertificate')
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
context.check_hostname = opts_check_certificate
+ if params.get('legacyserverconnect'):
+ context.options |= 4 # SSL_OP_LEGACY_SERVER_CONNECT
context.verify_mode = ssl.CERT_REQUIRED if opts_check_certificate else ssl.CERT_NONE
if opts_check_certificate:
try:
@@ -905,13 +1021,9 @@ def make_HTTPS_handler(params, **kwargs):
def bug_reports_message(before=';'):
- if ytdl_is_updateable():
- update_cmd = 'type doas pacman -Sy hypervideo to update'
- else:
- update_cmd = 'see https://git.conocimientoslibres.ga/software/hypervideo.git/about/#how-do-i-update-hypervideo'
- msg = 'please report this issue on https://github.com/yt-dlp/yt-dlp .'
- msg += ' Make sure you are using the latest version; %s.' % update_cmd
- msg += ' Be sure to call yt-dlp with the --verbose flag and include its complete output.'
+ msg = ('please report this issue on https://github.com/yt-dlp/yt-dlp , '
+ 'filling out the "Broken site" issue template properly. '
+ 'Confirm you are on the latest version using -U')
before = before.rstrip()
if not before or before.endswith(('.', '!', '?')):
@@ -1734,7 +1846,7 @@ def datetime_from_str(date_str, precision='auto', format='%Y%m%d'):
if precision == 'auto':
auto_precision = True
precision = 'microsecond'
- today = datetime_round(datetime.datetime.now(), precision)
+ today = datetime_round(datetime.datetime.utcnow(), precision)
if date_str in ('now', 'today'):
return today
if date_str == 'yesterday':
@@ -2010,7 +2122,7 @@ if sys.platform == 'win32':
whole_low = 0xffffffff
whole_high = 0x7fffffff
- def _lock_file(f, exclusive):
+ def _lock_file(f, exclusive, block): # todo: block unused on win32
overlapped = OVERLAPPED()
overlapped.Offset = 0
overlapped.OffsetHigh = 0
@@ -2033,15 +2145,19 @@ else:
try:
import fcntl
- def _lock_file(f, exclusive):
- fcntl.flock(f, fcntl.LOCK_EX if exclusive else fcntl.LOCK_SH)
+ def _lock_file(f, exclusive, block):
+ fcntl.flock(f,
+ fcntl.LOCK_SH if not exclusive
+ else fcntl.LOCK_EX if block
+ else fcntl.LOCK_EX | fcntl.LOCK_NB)
def _unlock_file(f):
fcntl.flock(f, fcntl.LOCK_UN)
+
except ImportError:
UNSUPPORTED_MSG = 'file locking is not supported on this platform'
- def _lock_file(f, exclusive):
+ def _lock_file(f, exclusive, block):
raise IOError(UNSUPPORTED_MSG)
def _unlock_file(f):
@@ -2049,15 +2165,16 @@ else:
class locked_file(object):
- def __init__(self, filename, mode, encoding=None):
- assert mode in ['r', 'a', 'w']
+ def __init__(self, filename, mode, block=True, encoding=None):
+ assert mode in ['r', 'rb', 'a', 'ab', 'w', 'wb']
self.f = io.open(filename, mode, encoding=encoding)
self.mode = mode
+ self.block = block
def __enter__(self):
- exclusive = self.mode != 'r'
+ exclusive = 'r' not in self.mode
try:
- _lock_file(self.f, exclusive)
+ _lock_file(self.f, exclusive, self.block)
except IOError:
self.f.close()
raise
@@ -2078,6 +2195,15 @@ class locked_file(object):
def read(self, *args):
return self.f.read(*args)
+ def flush(self):
+ self.f.flush()
+
+ def open(self):
+ return self.__enter__()
+
+ def close(self, *args):
+        self.__exit__(None, None, None)
+
def get_filesystem_encoding():
encoding = sys.getfilesystemencoding()
@@ -2120,9 +2246,11 @@ def format_decimal_suffix(num, fmt='%d%s', *, factor=1000):
if num is None:
return None
exponent = 0 if num == 0 else int(math.log(num, factor))
- suffix = ['', *'KMGTPEZY'][exponent]
+ suffix = ['', *'kMGTPEZY'][exponent]
+ if factor == 1024:
+ suffix = {'k': 'Ki', '': ''}.get(suffix, f'{suffix}i')
converted = num / (factor ** exponent)
- return fmt % (converted, f'{suffix}i' if suffix and factor == 1024 else suffix)
+ return fmt % (converted, suffix)
def format_bytes(bytes):
@@ -2382,13 +2510,8 @@ class PUTRequest(compat_urllib_request.Request):
def int_or_none(v, scale=1, default=None, get_attr=None, invscale=1):
- if get_attr:
- if v is not None:
- v = getattr(v, get_attr, None)
- if v == '':
- v = None
- if v is None:
- return default
+ if get_attr and v is not None:
+ v = getattr(v, get_attr, None)
try:
return int(v) * invscale // scale
except (ValueError, TypeError, OverflowError):
@@ -2432,6 +2555,13 @@ def url_or_none(url):
return url if re.match(r'^(?:(?:https?|rt(?:m(?:pt?[es]?|fp)|sp[su]?)|mms|ftps?):)?//', url) else None
+def request_to_url(req):
+ if isinstance(req, compat_urllib_request.Request):
+ return req.get_full_url()
+ else:
+ return req
+
+
def strftime_or_none(timestamp, date_format, default=None):
datetime_object = None
try:
@@ -2452,9 +2582,14 @@ def parse_duration(s):
return None
days, hours, mins, secs, ms = [None] * 5
- m = re.match(r'(?:(?:(?:(?P<days>[0-9]+):)?(?P<hours>[0-9]+):)?(?P<mins>[0-9]+):)?(?P<secs>[0-9]+)(?P<ms>\.[0-9]+)?Z?$', s)
+ m = re.match(r'''(?x)
+ (?P<before_secs>
+ (?:(?:(?P<days>[0-9]+):)?(?P<hours>[0-9]+):)?(?P<mins>[0-9]+):)?
+ (?P<secs>(?(before_secs)[0-9]{1,2}|[0-9]+))
+ (?P<ms>[.:][0-9]+)?Z?$
+ ''', s)
if m:
- days, hours, mins, secs, ms = m.groups()
+ days, hours, mins, secs, ms = m.group('days', 'hours', 'mins', 'secs', 'ms')
else:
m = re.match(
r'''(?ix)(?:P?
@@ -2499,7 +2634,7 @@ def parse_duration(s):
if days:
duration += float(days) * 24 * 60 * 60
if ms:
- duration += float(ms)
+ duration += float(ms.replace(':', '.'))
return duration
@@ -2733,8 +2868,7 @@ class InAdvancePagedList(PagedList):
def _getslice(self, start, end):
start_page = start // self._pagesize
- end_page = (
- self._pagecount if end is None else (end // self._pagesize + 1))
+ end_page = self._pagecount if end is None else min(self._pagecount, end // self._pagesize + 1)
skip_elems = start - start_page * self._pagesize
only_more = None if end is None else end - start
for pagenum in range(start_page, end_page):
@@ -3055,6 +3189,7 @@ OUTTMPL_TYPES = {
'annotation': 'annotations.xml',
'infojson': 'info.json',
'link': None,
+ 'pl_video': None,
'pl_thumbnail': None,
'pl_description': 'description',
'pl_infojson': 'info.json',
@@ -3203,7 +3338,7 @@ def parse_codecs(codecs_str):
return {}
split_codecs = list(filter(None, map(
str.strip, codecs_str.strip().strip(',').split(','))))
- vcodec, acodec, hdr = None, None, None
+ vcodec, acodec, tcodec, hdr = None, None, None, None
for full_codec in split_codecs:
parts = full_codec.split('.')
codec = parts[0].replace('0', '')
@@ -3220,13 +3355,17 @@ def parse_codecs(codecs_str):
elif codec in ('flac', 'mp4a', 'opus', 'vorbis', 'mp3', 'aac', 'ac-3', 'ec-3', 'eac3', 'dtsc', 'dtse', 'dtsh', 'dtsl'):
if not acodec:
acodec = full_codec
+ elif codec in ('stpp', 'wvtt',):
+ if not tcodec:
+ tcodec = full_codec
else:
write_string('WARNING: Unknown codec %s\n' % full_codec, sys.stderr)
- if vcodec or acodec:
+ if vcodec or acodec or tcodec:
return {
'vcodec': vcodec or 'none',
'acodec': acodec or 'none',
'dynamic_range': hdr,
+ **({'tcodec': tcodec} if tcodec is not None else {}),
}
elif len(split_codecs) == 2:
return {
@@ -3316,12 +3455,11 @@ def render_table(header_row, data, delim=False, extra_gap=0, hide_empty=False):
return [max(width(str(v)) for v in col) for col in zip(*table)]
def filter_using_list(row, filterArray):
- return [col for (take, col) in zip(filterArray, row) if take]
+ return [col for take, col in itertools.zip_longest(filterArray, row, fillvalue=True) if take]
- if hide_empty:
- max_lens = get_max_lens(data)
- header_row = filter_using_list(header_row, max_lens)
- data = [filter_using_list(row, max_lens) for row in data]
+ max_lens = get_max_lens(data) if hide_empty else []
+ header_row = filter_using_list(header_row, max_lens)
+ data = [filter_using_list(row, max_lens) for row in data]
table = [header_row] + data
max_lens = get_max_lens(table)
@@ -4860,13 +4998,10 @@ def to_high_limit_path(path):
def format_field(obj, field=None, template='%s', ignore=(None, ''), default='', func=None):
- if field is None:
- val = obj if obj is not None else default
- else:
- val = obj.get(field, default)
- if func and val not in ignore:
- val = func(val)
- return template % val if val not in ignore else default
+ val = traverse_obj(obj, *variadic(field))
+ if val in ignore:
+ return default
+ return template % (func(val) if func else val)
def clean_podcast_url(url):
@@ -4942,11 +5077,12 @@ def traverse_obj(
''' Traverse nested list/dict/tuple
@param path_list A list of paths which are checked one by one.
Each path is a list of keys where each key is a string,
- a function, a tuple of strings or "...".
+ a function, a tuple of strings/None or "...".
                        When a function is given, it takes the key as argument and
returns whether the key matches or not. When a tuple is given,
all the keys given in the tuple are traversed, and
"..." traverses all the keys in the object
+ "None" returns the object without traversal
@param default Default value to return
@param expected_type Only accept final value of this type (Can also be any callable)
@param get_all Return all the values obtained from a path or only the first one
@@ -4965,8 +5101,8 @@ def traverse_obj(
nonlocal depth
path = tuple(variadic(path))
for i, key in enumerate(path):
- if obj is None:
- return None
+ if None in (key, obj):
+ return obj
if isinstance(key, (list, tuple)):
obj = [_traverse_obj(obj, sub_key, _current_depth) for sub_key in key]
key = ...
@@ -5034,7 +5170,6 @@ def traverse_obj(
return default
-# Deprecated
def traverse_dict(dictn, keys, casesense=True):
write_string('DeprecationWarning: yt_dlp.utils.traverse_dict is deprecated '
'and may be removed in a future version. Use yt_dlp.utils.traverse_obj instead')
@@ -5045,6 +5180,22 @@ def variadic(x, allowed_types=(str, bytes, dict)):
return x if isinstance(x, collections.abc.Iterable) and not isinstance(x, allowed_types) else (x,)
+def decode_base(value, digits):
+ # This will convert given base-x string to scalar (long or int)
+ table = {char: index for index, char in enumerate(digits)}
+ result = 0
+ base = len(digits)
+ for chr in value:
+ result *= base
+ result += table[chr]
+ return result
+
+
+def time_seconds(**kwargs):
+ t = datetime.datetime.now(datetime.timezone(datetime.timedelta(**kwargs)))
+ return t.timestamp()
+
+
# create a JSON Web Signature (jws) with HS256 algorithm
# the resulting format is in JWS Compact Serialization
# implemented following JWT https://www.rfc-editor.org/rfc/rfc7519.html
@@ -5099,3 +5250,160 @@ def join_nonempty(*values, delim='-', from_dict=None):
if from_dict is not None:
values = map(from_dict.get, values)
return delim.join(map(str, filter(None, values)))
+
+
+class Config:
+ own_args = None
+ filename = None
+ __initialized = False
+
+ def __init__(self, parser, label=None):
+ self._parser, self.label = parser, label
+ self._loaded_paths, self.configs = set(), []
+
+ def init(self, args=None, filename=None):
+ assert not self.__initialized
+ directory = ''
+ if filename:
+ location = os.path.realpath(filename)
+ directory = os.path.dirname(location)
+ if location in self._loaded_paths:
+ return False
+ self._loaded_paths.add(location)
+
+ self.__initialized = True
+ self.own_args, self.filename = args, filename
+ for location in self._parser.parse_args(args)[0].config_locations or []:
+ location = os.path.join(directory, expand_path(location))
+ if os.path.isdir(location):
+ location = os.path.join(location, 'yt-dlp.conf')
+ if not os.path.exists(location):
+ self._parser.error(f'config location {location} does not exist')
+ self.append_config(self.read_file(location), location)
+ return True
+
+ def __str__(self):
+ label = join_nonempty(
+ self.label, 'config', f'"{self.filename}"' if self.filename else '',
+ delim=' ')
+ return join_nonempty(
+ self.own_args is not None and f'{label[0].upper()}{label[1:]}: {self.hide_login_info(self.own_args)}',
+ *(f'\n{c}'.replace('\n', '\n| ')[1:] for c in self.configs),
+ delim='\n')
+
+ @staticmethod
+ def read_file(filename, default=[]):
+ try:
+ optionf = open(filename)
+ except IOError:
+ return default # silently skip if file is not present
+ try:
+ # FIXME: https://github.com/ytdl-org/youtube-dl/commit/dfe5fa49aed02cf36ba9f743b11b0903554b5e56
+ contents = optionf.read()
+ if sys.version_info < (3,):
+ contents = contents.decode(preferredencoding())
+ res = compat_shlex_split(contents, comments=True)
+ finally:
+ optionf.close()
+ return res
+
+ @staticmethod
+ def hide_login_info(opts):
+ PRIVATE_OPTS = set(['-p', '--password', '-u', '--username', '--video-password', '--ap-password', '--ap-username'])
+ eqre = re.compile('^(?P<key>' + ('|'.join(re.escape(po) for po in PRIVATE_OPTS)) + ')=.+$')
+
+ def _scrub_eq(o):
+ m = eqre.match(o)
+ if m:
+ return m.group('key') + '=PRIVATE'
+ else:
+ return o
+
+ opts = list(map(_scrub_eq, opts))
+ for idx, opt in enumerate(opts):
+ if opt in PRIVATE_OPTS and idx + 1 < len(opts):
+ opts[idx + 1] = 'PRIVATE'
+ return opts
+
+ def append_config(self, *args, label=None):
+ config = type(self)(self._parser, label)
+ config._loaded_paths = self._loaded_paths
+ if config.init(*args):
+ self.configs.append(config)
+
+ @property
+ def all_args(self):
+ for config in reversed(self.configs):
+ yield from config.all_args
+ yield from self.own_args or []
+
+ def parse_args(self):
+ return self._parser.parse_args(list(self.all_args))
+
+
+class WebSocketsWrapper():
+ """Wraps websockets module to use in non-async scopes"""
+
+ def __init__(self, url, headers=None):
+ self.loop = asyncio.events.new_event_loop()
+ self.conn = compat_websockets.connect(
+ url, extra_headers=headers, ping_interval=None,
+ close_timeout=float('inf'), loop=self.loop, ping_timeout=float('inf'))
+ atexit.register(self.__exit__, None, None, None)
+
+ def __enter__(self):
+ self.pool = self.run_with_loop(self.conn.__aenter__(), self.loop)
+ return self
+
+ def send(self, *args):
+ self.run_with_loop(self.pool.send(*args), self.loop)
+
+ def recv(self, *args):
+ return self.run_with_loop(self.pool.recv(*args), self.loop)
+
+ def __exit__(self, type, value, traceback):
+ try:
+ return self.run_with_loop(self.conn.__aexit__(type, value, traceback), self.loop)
+ finally:
+            self._cancel_all_tasks(self.loop)
+            self.loop.close()
+
+ # taken from https://github.com/python/cpython/blob/3.9/Lib/asyncio/runners.py with modifications
+ # for contributors: If there's any new library using asyncio needs to be run in non-async, move these function out of this class
+ @staticmethod
+ def run_with_loop(main, loop):
+ if not asyncio.coroutines.iscoroutine(main):
+ raise ValueError(f'a coroutine was expected, got {main!r}')
+
+ try:
+ return loop.run_until_complete(main)
+ finally:
+ loop.run_until_complete(loop.shutdown_asyncgens())
+ if hasattr(loop, 'shutdown_default_executor'):
+ loop.run_until_complete(loop.shutdown_default_executor())
+
+ @staticmethod
+ def _cancel_all_tasks(loop):
+ to_cancel = asyncio.tasks.all_tasks(loop)
+
+ if not to_cancel:
+ return
+
+ for task in to_cancel:
+ task.cancel()
+
+ loop.run_until_complete(
+ asyncio.tasks.gather(*to_cancel, loop=loop, return_exceptions=True))
+
+ for task in to_cancel:
+ if task.cancelled():
+ continue
+ if task.exception() is not None:
+ loop.call_exception_handler({
+ 'message': 'unhandled exception during asyncio.run() shutdown',
+ 'exception': task.exception(),
+ 'task': task,
+ })
+
+
+has_websockets = bool(compat_websockets)