aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Makefile2
-rw-r--r--setup.cfg1
-rw-r--r--yt_dlp/YoutubeDL.py2
-rw-r--r--yt_dlp/utils/__init__.py14
-rw-r--r--yt_dlp/utils/_deprecated.py30
-rw-r--r--yt_dlp/utils/_legacy.py163
-rw-r--r--yt_dlp/utils/_utils.py (renamed from yt_dlp/utils.py)458
-rw-r--r--yt_dlp/utils/traversal.py254
8 files changed, 480 insertions, 444 deletions
diff --git a/Makefile b/Makefile
index d5d47629b..f03fe2052 100644
--- a/Makefile
+++ b/Makefile
@@ -74,7 +74,7 @@ offlinetest: codetest
$(PYTHON) -m pytest -k "not download"
# XXX: This is hard to maintain
-CODE_FOLDERS = yt_dlp yt_dlp/downloader yt_dlp/extractor yt_dlp/postprocessor yt_dlp/compat yt_dlp/dependencies
+CODE_FOLDERS = yt_dlp yt_dlp/downloader yt_dlp/extractor yt_dlp/postprocessor yt_dlp/compat yt_dlp/utils yt_dlp/dependencies
yt-dlp: yt_dlp/*.py yt_dlp/*/*.py
mkdir -p zip
for d in $(CODE_FOLDERS) ; do \
diff --git a/setup.cfg b/setup.cfg
index 6deaa7971..68d9e516d 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -8,6 +8,7 @@ ignore = E402,E501,E731,E741,W503
max_line_length = 120
per_file_ignores =
devscripts/lazy_load_template.py: F401
+ yt_dlp/utils/__init__.py: F401, F403
[autoflake]
diff --git a/yt_dlp/YoutubeDL.py b/yt_dlp/YoutubeDL.py
index 91aec1fe6..b8f1a05a0 100644
--- a/yt_dlp/YoutubeDL.py
+++ b/yt_dlp/YoutubeDL.py
@@ -124,7 +124,6 @@ from .utils import (
parse_filesize,
preferredencoding,
prepend_extension,
- register_socks_protocols,
remove_terminal_sequences,
render_table,
replace_extension,
@@ -739,7 +738,6 @@ class YoutubeDL:
when=when)
self._setup_opener()
- register_socks_protocols()
def preload_download_archive(fn):
"""Preload the archive, if any is specified"""
diff --git a/yt_dlp/utils/__init__.py b/yt_dlp/utils/__init__.py
new file mode 100644
index 000000000..74b39e2c7
--- /dev/null
+++ b/yt_dlp/utils/__init__.py
@@ -0,0 +1,14 @@
+import warnings
+
+from ..compat.compat_utils import passthrough_module
+
+# XXX: Implement this the same way as other DeprecationWarnings without circular import
+passthrough_module(__name__, '._legacy', callback=lambda attr: warnings.warn(
+ DeprecationWarning(f'{__name__}.{attr} is deprecated'), stacklevel=5))
+del passthrough_module
+
+# isort: off
+from .traversal import *
+from ._utils import *
+from ._utils import _configuration_args, _get_exe_version_output
+from ._deprecated import *
diff --git a/yt_dlp/utils/_deprecated.py b/yt_dlp/utils/_deprecated.py
new file mode 100644
index 000000000..4454d84a7
--- /dev/null
+++ b/yt_dlp/utils/_deprecated.py
@@ -0,0 +1,30 @@
+"""Deprecated - New code should avoid these"""
+
+from ._utils import preferredencoding
+
+
+def encodeFilename(s, for_subprocess=False):
+ assert isinstance(s, str)
+ return s
+
+
+def decodeFilename(b, for_subprocess=False):
+ return b
+
+
+def decodeArgument(b):
+ return b
+
+
+def decodeOption(optval):
+ if optval is None:
+ return optval
+ if isinstance(optval, bytes):
+ optval = optval.decode(preferredencoding())
+
+ assert isinstance(optval, str)
+ return optval
+
+
+def error_to_compat_str(err):
+ return str(err)
diff --git a/yt_dlp/utils/_legacy.py b/yt_dlp/utils/_legacy.py
new file mode 100644
index 000000000..cd009b504
--- /dev/null
+++ b/yt_dlp/utils/_legacy.py
@@ -0,0 +1,163 @@
+"""No longer used and new code should not use. Exists only for API compat."""
+
+import platform
+import struct
+import sys
+import urllib.parse
+import zlib
+
+from ._utils import decode_base_n, preferredencoding
+from .traversal import traverse_obj
+from ..dependencies import certifi, websockets
+
+has_certifi = bool(certifi)
+has_websockets = bool(websockets)
+
+
+def load_plugins(name, suffix, namespace):
+ from ..plugins import load_plugins
+ ret = load_plugins(name, suffix)
+ namespace.update(ret)
+ return ret
+
+
+def traverse_dict(dictn, keys, casesense=True):
+ return traverse_obj(dictn, keys, casesense=casesense, is_user_input=True, traverse_string=True)
+
+
+def decode_base(value, digits):
+ return decode_base_n(value, table=digits)
+
+
+def platform_name():
+ """ Returns the platform name as a str """
+ return platform.platform()
+
+
+def get_subprocess_encoding():
+ if sys.platform == 'win32' and sys.getwindowsversion()[0] >= 5:
+ # For subprocess calls, encode with locale encoding
+ # Refer to http://stackoverflow.com/a/9951851/35070
+ encoding = preferredencoding()
+ else:
+ encoding = sys.getfilesystemencoding()
+ if encoding is None:
+ encoding = 'utf-8'
+ return encoding
+
+
+# UNUSED
+# Based on png2str() written by @gdkchan and improved by @yokrysty
+# Originally posted at https://github.com/ytdl-org/youtube-dl/issues/9706
+def decode_png(png_data):
+ # Reference: https://www.w3.org/TR/PNG/
+ header = png_data[8:]
+
+ if png_data[:8] != b'\x89PNG\x0d\x0a\x1a\x0a' or header[4:8] != b'IHDR':
+ raise OSError('Not a valid PNG file.')
+
+ int_map = {1: '>B', 2: '>H', 4: '>I'}
+ unpack_integer = lambda x: struct.unpack(int_map[len(x)], x)[0]
+
+ chunks = []
+
+ while header:
+ length = unpack_integer(header[:4])
+ header = header[4:]
+
+ chunk_type = header[:4]
+ header = header[4:]
+
+ chunk_data = header[:length]
+ header = header[length:]
+
+ header = header[4:] # Skip CRC
+
+ chunks.append({
+ 'type': chunk_type,
+ 'length': length,
+ 'data': chunk_data
+ })
+
+ ihdr = chunks[0]['data']
+
+ width = unpack_integer(ihdr[:4])
+ height = unpack_integer(ihdr[4:8])
+
+ idat = b''
+
+ for chunk in chunks:
+ if chunk['type'] == b'IDAT':
+ idat += chunk['data']
+
+ if not idat:
+ raise OSError('Unable to read PNG data.')
+
+ decompressed_data = bytearray(zlib.decompress(idat))
+
+ stride = width * 3
+ pixels = []
+
+ def _get_pixel(idx):
+ x = idx % stride
+ y = idx // stride
+ return pixels[y][x]
+
+ for y in range(height):
+ basePos = y * (1 + stride)
+ filter_type = decompressed_data[basePos]
+
+ current_row = []
+
+ pixels.append(current_row)
+
+ for x in range(stride):
+ color = decompressed_data[1 + basePos + x]
+ basex = y * stride + x
+ left = 0
+ up = 0
+
+ if x > 2:
+ left = _get_pixel(basex - 3)
+ if y > 0:
+ up = _get_pixel(basex - stride)
+
+ if filter_type == 1: # Sub
+ color = (color + left) & 0xff
+ elif filter_type == 2: # Up
+ color = (color + up) & 0xff
+ elif filter_type == 3: # Average
+ color = (color + ((left + up) >> 1)) & 0xff
+ elif filter_type == 4: # Paeth
+ a = left
+ b = up
+ c = 0
+
+ if x > 2 and y > 0:
+ c = _get_pixel(basex - stride - 3)
+
+ p = a + b - c
+
+ pa = abs(p - a)
+ pb = abs(p - b)
+ pc = abs(p - c)
+
+ if pa <= pb and pa <= pc:
+ color = (color + a) & 0xff
+ elif pb <= pc:
+ color = (color + b) & 0xff
+ else:
+ color = (color + c) & 0xff
+
+ current_row.append(color)
+
+ return width, height, pixels
+
+
+def register_socks_protocols():
+ # "Register" SOCKS protocols
+ # In Python < 2.6.5, urlsplit() suffers from bug https://bugs.python.org/issue7904
+ # URLs with protocols not in urlparse.uses_netloc are not handled correctly
+ for scheme in ('socks', 'socks4', 'socks4a', 'socks5'):
+ if scheme not in urllib.parse.uses_netloc:
+ urllib.parse.uses_netloc.append(scheme)
diff --git a/yt_dlp/utils.py b/yt_dlp/utils/_utils.py
index 190af1b7d..f032af901 100644
--- a/yt_dlp/utils.py
+++ b/yt_dlp/utils/_utils.py
@@ -47,26 +47,18 @@ import urllib.request
import xml.etree.ElementTree
import zlib
-from .compat import functools # isort: split
-from .compat import (
+from . import traversal
+
+from ..compat import functools # isort: split
+from ..compat import (
compat_etree_fromstring,
compat_expanduser,
compat_HTMLParseError,
compat_os_name,
compat_shlex_quote,
)
-from .dependencies import brotli, certifi, websockets, xattr
-from .socks import ProxyType, sockssocket
-
-
-def register_socks_protocols():
- # "Register" SOCKS protocols
- # In Python < 2.6.5, urlsplit() suffers from bug https://bugs.python.org/issue7904
- # URLs with protocols not in urlparse.uses_netloc are not handled correctly
- for scheme in ('socks', 'socks4', 'socks4a', 'socks5'):
- if scheme not in urllib.parse.uses_netloc:
- urllib.parse.uses_netloc.append(scheme)
-
+from ..dependencies import brotli, certifi, websockets, xattr
+from ..socks import ProxyType, sockssocket
# This is not clearly defined otherwise
compiled_regex_type = type(re.compile(''))
@@ -928,27 +920,6 @@ class Popen(subprocess.Popen):
return stdout or default, stderr or default, proc.returncode
-def get_subprocess_encoding():
- if sys.platform == 'win32' and sys.getwindowsversion()[0] >= 5:
- # For subprocess calls, encode with locale encoding
- # Refer to http://stackoverflow.com/a/9951851/35070
- encoding = preferredencoding()
- else:
- encoding = sys.getfilesystemencoding()
- if encoding is None:
- encoding = 'utf-8'
- return encoding
-
-
-def encodeFilename(s, for_subprocess=False):
- assert isinstance(s, str)
- return s
-
-
-def decodeFilename(b, for_subprocess=False):
- return b
-
-
def encodeArgument(s):
# Legacy code that uses byte strings
# Uncomment the following line after fixing all post processors
@@ -956,20 +927,6 @@ def encodeArgument(s):
return s if isinstance(s, str) else s.decode('ascii')
-def decodeArgument(b):
- return b
-
-
-def decodeOption(optval):
- if optval is None:
- return optval
- if isinstance(optval, bytes):
- optval = optval.decode(preferredencoding())
-
- assert isinstance(optval, str)
- return optval
-
-
_timetuple = collections.namedtuple('Time', ('hours', 'minutes', 'seconds', 'milliseconds'))
@@ -1034,7 +991,7 @@ def make_HTTPS_handler(params, **kwargs):
context.verify_mode = ssl.CERT_REQUIRED if opts_check_certificate else ssl.CERT_NONE
if opts_check_certificate:
- if has_certifi and 'no-certifi' not in params.get('compat_opts', []):
+ if certifi and 'no-certifi' not in params.get('compat_opts', []):
context.load_verify_locations(cafile=certifi.where())
else:
try:
@@ -1068,7 +1025,7 @@ def make_HTTPS_handler(params, **kwargs):
def bug_reports_message(before=';'):
- from .update import REPOSITORY
+ from ..update import REPOSITORY
msg = (f'please report this issue on https://github.com/{REPOSITORY}/issues?q= , '
'filling out the appropriate issue template. Confirm you are on the latest version using yt-dlp -U')
@@ -2019,12 +1976,6 @@ class DateRange:
and self.start == other.start and self.end == other.end)
-def platform_name():
- """ Returns the platform name as a str """
- deprecation_warning(f'"{__name__}.platform_name" is deprecated, use "platform.platform" instead')
- return platform.platform()
-
-
@functools.cache
def system_identifier():
python_implementation = platform.python_implementation()
@@ -2076,7 +2027,7 @@ def write_string(s, out=None, encoding=None):
def deprecation_warning(msg, *, printer=None, stacklevel=0, **kwargs):
- from . import _IN_CLI
+ from .. import _IN_CLI
if _IN_CLI:
if msg in deprecation_warning._cache:
return
@@ -3284,13 +3235,6 @@ def variadic(x, allowed_types=NO_DEFAULT):
return x if is_iterable_like(x, blocked_types=allowed_types) else (x, )
-def dict_get(d, key_or_keys, default=None, skip_false_values=True):
- for val in map(d.get, variadic(key_or_keys)):
- if val is not None and (val or not skip_false_values):
- return val
- return default
-
-
def try_call(*funcs, expected_type=None, args=[], kwargs={}):
for f in funcs:
try:
@@ -3528,7 +3472,7 @@ def is_outdated_version(version, limit, assume_new=True):
def ytdl_is_updateable():
""" Returns if yt-dlp can be updated with -U """
- from .update import is_non_updateable
+ from ..update import is_non_updateable
return not is_non_updateable()
@@ -3538,10 +3482,6 @@ def args_to_str(args):
return ' '.join(compat_shlex_quote(a) for a in args)
-def error_to_compat_str(err):
- return str(err)
-
-
def error_to_str(err):
return f'{type(err).__name__}: {err}'
@@ -3628,7 +3568,7 @@ def mimetype2ext(mt, default=NO_DEFAULT):
mimetype = mt.partition(';')[0].strip().lower()
_, _, subtype = mimetype.rpartition('/')
- ext = traverse_obj(MAP, mimetype, subtype, subtype.rsplit('+')[-1])
+ ext = traversal.traverse_obj(MAP, mimetype, subtype, subtype.rsplit('+')[-1])
if ext:
return ext
elif default is not NO_DEFAULT:
@@ -3660,7 +3600,7 @@ def parse_codecs(codecs_str):
vcodec = full_codec
if parts[0] in ('dvh1', 'dvhe'):
hdr = 'DV'
- elif parts[0] == 'av1' and traverse_obj(parts, 3) == '10':
+ elif parts[0] == 'av1' and traversal.traverse_obj(parts, 3) == '10':
hdr = 'HDR10'
elif parts[:2] == ['vp9', '2']:
hdr = 'HDR10'
@@ -3706,8 +3646,7 @@ def get_compatible_ext(*, vcodecs, acodecs, vexts, aexts, preferences=None):
},
}
- sanitize_codec = functools.partial(
- try_get, getter=lambda x: x[0].split('.')[0].replace('0', '').lower())
+ sanitize_codec = functools.partial(try_get, getter=lambda x: x[0].split('.')[0].replace('0', ''))
vcodec, acodec = sanitize_codec(vcodecs), sanitize_codec(acodecs)
for ext in preferences or COMPATIBLE_CODECS.keys():
@@ -5088,12 +5027,6 @@ def decode_base_n(string, n=None, table=None):
return result
-def decode_base(value, digits):
- deprecation_warning(f'{__name__}.decode_base is deprecated and may be removed '
- f'in a future version. Use {__name__}.decode_base_n instead')
- return decode_base_n(value, table=digits)
-
-
def decode_packed_codes(code):
mobj = re.search(PACKED_CODES_RE, code)
obfuscated_code, base, count, symbols = mobj.groups()
@@ -5138,113 +5071,6 @@ def urshift(val, n):
return val >> n if val >= 0 else (val + 0x100000000) >> n
-# Based on png2str() written by @gdkchan and improved by @yokrysty
-# Originally posted at https://github.com/ytdl-org/youtube-dl/issues/9706
-def decode_png(png_data):
- # Reference: https://www.w3.org/TR/PNG/
- header = png_data[8:]
-
- if png_data[:8] != b'\x89PNG\x0d\x0a\x1a\x0a' or header[4:8] != b'IHDR':
- raise OSError('Not a valid PNG file.')
-
- int_map = {1: '>B', 2: '>H', 4: '>I'}
- unpack_integer = lambda x: struct.unpack(int_map[len(x)], x)[0]
-
- chunks = []
-
- while header:
- length = unpack_integer(header[:4])
- header = header[4:]
-
- chunk_type = header[:4]
- header = header[4:]
-
- chunk_data = header[:length]
- header = header[length:]
-
- header = header[4:] # Skip CRC
-
- chunks.append({
- 'type': chunk_type,
- 'length': length,
- 'data': chunk_data
- })
-
- ihdr = chunks[0]['data']
-
- width = unpack_integer(ihdr[:4])
- height = unpack_integer(ihdr[4:8])
-
- idat = b''
-
- for chunk in chunks:
- if chunk['type'] == b'IDAT':
- idat += chunk['data']
-
- if not idat:
- raise OSError('Unable to read PNG data.')
-
- decompressed_data = bytearray(zlib.decompress(idat))
-
- stride = width * 3
- pixels = []
-
- def _get_pixel(idx):
- x = idx % stride
- y = idx // stride
- return pixels[y][x]
-
- for y in range(height):
- basePos = y * (1 + stride)
- filter_type = decompressed_data[basePos]
-
- current_row = []
-
- pixels.append(current_row)
-
- for x in range(stride):
- color = decompressed_data[1 + basePos + x]
- basex = y * stride + x
- left = 0
- up = 0
-
- if x > 2:
- left = _get_pixel(basex - 3)
- if y > 0:
- up = _get_pixel(basex - stride)
-
- if filter_type == 1: # Sub
- color = (color + left) & 0xff
- elif filter_type == 2: # Up
- color = (color + up) & 0xff
- elif filter_type == 3: # Average
- color = (color + ((left + up) >> 1)) & 0xff
- elif filter_type == 4: # Paeth
- a = left
- b = up
- c = 0
-
- if x > 2 and y > 0:
- c = _get_pixel(basex - stride - 3)
-
- p = a + b - c
-
- pa = abs(p - a)
- pb = abs(p - b)
- pc = abs(p - c)
-
- if pa <= pb and pa <= pc:
- color = (color + a) & 0xff
- elif pb <= pc:
- color = (color + b) & 0xff
- else:
- color = (color + c) & 0xff
-
- current_row.append(color)
-
- return width, height, pixels
-
-
def write_xattr(path, key, value):
# Windows: Write xattrs to NTFS Alternate Data Streams:
# http://en.wikipedia.org/wiki/NTFS#Alternate_data_streams_.28ADS.29
@@ -5403,7 +5229,7 @@ def to_high_limit_path(path):
def format_field(obj, field=None, template='%s', ignore=NO_DEFAULT, default='', func=IDENTITY):
- val = traverse_obj(obj, *variadic(field))
+ val = traversal.traverse_obj(obj, *variadic(field))
if not val if ignore is NO_DEFAULT else val in variadic(ignore):
return default
return template % func(val)
@@ -5441,12 +5267,12 @@ def make_dir(path, to_screen=None):
return True
except OSError as err:
if callable(to_screen) is not None:
- to_screen('unable to create directory ' + error_to_compat_str(err))
+ to_screen(f'unable to create directory {err}')
return False
def get_executable_path():
- from .update import _get_variant_and_executable_path
+ from ..update import _get_variant_and_executable_path
return os.path.dirname(os.path.abspath(_get_variant_and_executable_path()[1]))
@@ -5470,244 +5296,6 @@ def get_system_config_dirs(package_name):
yield os.path.join('/etc', package_name)
-def traverse_obj(
- obj, *paths, default=NO_DEFAULT, expected_type=None, get_all=True,
- casesense=True, is_user_input=False, traverse_string=False):
- """
- Safely traverse nested `dict`s and `Iterable`s
-
- >>> obj = [{}, {"key": "value"}]
- >>> traverse_obj(obj, (1, "key"))
- "value"
-
- Each of the provided `paths` is tested and the first producing a valid result will be returned.
- The next path will also be tested if the path branched but no results could be found.
- Supported values for traversal are `Mapping`, `Iterable` and `re.Match`.
- Unhelpful values (`{}`, `None`) are treated as the absence of a value and discarded.
-
- The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`.
-
- The keys in the path can be one of:
- - `None`: Return the current object.
- - `set`: Requires the only item in the set to be a type or function,
- like `{type}`/`{func}`. If a `type`, returns only values
- of this type. If a function, returns `func(obj)`.
- - `str`/`int`: Return `obj[key]`. For `re.Match`, return `obj.group(key)`.
- - `slice`: Branch out and return all values in `obj[key]`.
- - `Ellipsis`: Branch out and return a list of all values.
- - `tuple`/`list`: Branch out and return a list of all matching values.
- Read as: `[traverse_obj(obj, branch) for branch in branches]`.
- - `function`: Branch out and return values filtered by the function.
- Read as: `[value for key, value in obj if function(key, value)]`.
- For `Iterable`s, `key` is the index of the value.
- For `re.Match`es, `key` is the group number (0 = full match)
- as well as additionally any group names, if given.
- - `dict` Transform the current object and return a matching dict.
- Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`.
-
- `tuple`, `list`, and `dict` all support nested paths and branches.
-
- @params paths Paths which to traverse by.
- @param default Value to return if the paths do not match.
- If the last key in the path is a `dict`, it will apply to each value inside
- the dict instead, depth first. Try to avoid if using nested `dict` keys.
- @param expected_type If a `type`, only accept final values of this type.
- If any other callable, try to call the function on each result.
- If the last key in the path is a `dict`, it will apply to each value inside
- the dict instead, recursively. This does respect branching paths.
- @param get_all If `False`, return the first matching result, otherwise all matching ones.
- @param casesense If `False`, consider string dictionary keys as case insensitive.
-
- The following are only meant to be used by YoutubeDL.prepare_outtmpl and are not part of the API
-
- @param is_user_input Whether the keys are generated from user input.
- If `True` strings get converted to `int`/`slice` if needed.
- @param traverse_string Whether to traverse into objects as strings.
- If `True`, any non-compatible object will first be
- converted into a string and then traversed into.
- The return value of that path will be a string instead,
- not respecting any further branching.
-
-
- @returns The result of the object traversal.
- If successful, `get_all=True`, and the path branches at least once,
- then a list of results is returned instead.
- If no `default` is given and the last path branches, a `list` of results
- is always returned. If a path ends on a `dict` that result will always be a `dict`.
- """
- casefold = lambda k: k.casefold() if isinstance(k, str) else k
-
- if isinstance(expected_type, type):
- type_test = lambda val: val if isinstance(val, expected_type) else None
- else:
- type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,))
-
- def apply_key(key, obj, is_last):
- branching = False
- result = None
-
- if obj is None and traverse_string:
- if key is ... or callable(key) or isinstance(key, slice):
- branching = True
- result = ()
-
- elif key is None:
- result = obj
-
- elif isinstance(key, set):
- assert len(key) == 1, 'Set should only be used to wrap a single item'
- item = next(iter(key))
- if isinstance(item, type):
- if isinstance(obj, item):
- result = obj
- else:
- result = try_call(item, args=(obj,))
-
- elif isinstance(key, (list, tuple)):
- branching = True
- result = itertools.chain.from_iterable(
- apply_path(obj, branch, is_last)[0] for branch in key)
-
- elif key is ...:
- branching = True
- if isinstance(obj, collections.abc.Mapping):
- result = obj.values()
- elif is_iterable_like(obj):
- result = obj
- elif isinstance(obj, re.Match):
- result = obj.groups()
- elif traverse_string:
- branching = False
- result = str(obj)
- else:
- result = ()
-
- elif callable(key):
- branching = True
- if isinstance(obj, collections.abc.Mapping):
- iter_obj = obj.items()
- elif is_iterable_like(obj):
- iter_obj = enumerate(obj)
- elif isinstance(obj, re.Match):
- iter_obj = itertools.chain(
- enumerate((obj.group(), *obj.groups())),
- obj.groupdict().items())
- elif traverse_string:
- branching = False
- iter_obj = enumerate(str(obj))
- else:
- iter_obj = ()
-
- result = (v for k, v in iter_obj if try_call(key, args=(k, v)))
- if not branching: # string traversal
- result = ''.join(result)
-
- elif isinstance(key, dict):
- iter_obj = ((k, _traverse_obj(obj, v, False, is_last)) for k, v in key.items())
- result = {
- k: v if v is not None else default for k, v in iter_obj
- if v is not None or default is not NO_DEFAULT
- } or None
-
- elif isinstance(obj, collections.abc.Mapping):
- result = (try_call(obj.get, args=(key,)) if casesense or try_call(obj.__contains__, args=(key,)) else
- next((v for k, v in obj.items() if casefold(k) == key), None))
-
- elif isinstance(obj, re.Match):
- if isinstance(key, int) or casesense:
- with contextlib.suppress(IndexError):
- result = obj.group(key)
-
- elif isinstance(key, str):
- result = next((v for k, v in obj.groupdict().items() if casefold(k) == key), None)
-
- elif isinstance(key, (int, slice)):
- if is_iterable_like(obj, collections.abc.Sequence):
- branching = isinstance(key, slice)
- with contextlib.suppress(IndexError):
- result = obj[key]
- elif traverse_string:
- with contextlib.suppress(IndexError):
- result = str(obj)[key]
-
- return branching, result if branching else (result,)
-
- def lazy_last(iterable):
- iterator = iter(iterable)
- prev = next(iterator, NO_DEFAULT)
- if prev is NO_DEFAULT:
- return
-
- for item in iterator:
- yield False, prev
- prev = item
-
- yield True, prev
-
- def apply_path(start_obj, path, test_type):
- objs = (start_obj,)
- has_branched = False
-
- key = None
- for last, key in lazy_last(variadic(path, (str, bytes, dict, set))):
- if is_user_input and isinstance(key, str):
- if key == ':':
- key = ...
- elif ':' in key:
- key = slice(*map(int_or_none, key.split(':')))
- elif int_or_none(key) is not None:
- key = int(key)
-
- if not casesense and isinstance(key, str):
- key = key.casefold()
-
- if __debug__ and callable(key):
- # Verify function signature
- inspect.signature(key).bind(None, None)
-
- new_objs = []
- for obj in objs:
- branching, results = apply_key(key, obj, last)
- has_branched |= branching
- new_objs.append(results)
-
- objs = itertools.chain.from_iterable(new_objs)
-
- if test_type and not isinstance(key, (dict, list, tuple)):
- objs = map(type_test, objs)
-
- return objs, has_branched, isinstance(key, dict)
-
- def _traverse_obj(obj, path, allow_empty, test_type):
- results, has_branched, is_dict = apply_path(obj, path, test_type)
- results = LazyList(item for item in results if item not in (None, {}))
- if get_all and has_branched:
- if results:
- return results.exhaust()
- if allow_empty:
- return [] if default is NO_DEFAULT else default
- return None
-
- return results[0] if results else {} if allow_empty and is_dict else None
-
- for index, path in enumerate(paths, 1):
- result = _traverse_obj(obj, path, index == len(paths), True)
- if result is not None:
- return result
-
- return None if default is NO_DEFAULT else default
-
-
-def traverse_dict(dictn, keys, casesense=True):
- deprecation_warning(f'"{__name__}.traverse_dict" is deprecated and may be removed '
- f'in a future version. Use "{__name__}.traverse_obj" instead')
- return traverse_obj(dictn, keys, casesense=casesense, is_user_input=True, traverse_string=True)
-
-
-def get_first(obj, *paths, **kwargs):
- return traverse_obj(obj, *((..., *variadic(keys)) for keys in paths), **kwargs, get_all=False)
-
-
def time_seconds(**kwargs):
"""
Returns TZ-aware time in seconds since the epoch (1970-01-01T00:00:00Z)
@@ -5803,7 +5391,7 @@ def number_of_digits(number):
def join_nonempty(*values, delim='-', from_dict=None):
if from_dict is not None:
- values = (traverse_obj(from_dict, variadic(v)) for v in values)
+ values = (traversal.traverse_obj(from_dict, variadic(v)) for v in values)
return delim.join(map(str, filter(None, values)))
@@ -6514,15 +6102,3 @@ class FormatSorter:
format['abr'] = format.get('tbr') - format.get('vbr', 0)
return tuple(self._calculate_field_preference(format, field) for field in self._order)
-
-
-# Deprecated
-has_certifi = bool(certifi)
-has_websockets = bool(websockets)
-
-
-def load_plugins(name, suffix, namespace):
- from .plugins import load_plugins
- ret = load_plugins(name, suffix)
- namespace.update(ret)
- return ret
diff --git a/yt_dlp/utils/traversal.py b/yt_dlp/utils/traversal.py
new file mode 100644
index 000000000..462c3ba5d
--- /dev/null
+++ b/yt_dlp/utils/traversal.py
@@ -0,0 +1,254 @@
+import collections.abc
+import contextlib
+import inspect
+import itertools
+import re
+
+from ._utils import (
+ IDENTITY,
+ NO_DEFAULT,
+ LazyList,
+ int_or_none,
+ is_iterable_like,
+ try_call,
+ variadic,
+)
+
+
+def traverse_obj(
+ obj, *paths, default=NO_DEFAULT, expected_type=None, get_all=True,
+ casesense=True, is_user_input=False, traverse_string=False):
+ """
+ Safely traverse nested `dict`s and `Iterable`s
+
+ >>> obj = [{}, {"key": "value"}]
+ >>> traverse_obj(obj, (1, "key"))
+ "value"
+
+ Each of the provided `paths` is tested and the first producing a valid result will be returned.
+ The next path will also be tested if the path branched but no results could be found.
+ Supported values for traversal are `Mapping`, `Iterable` and `re.Match`.
+ Unhelpful values (`{}`, `None`) are treated as the absence of a value and discarded.
+
+ The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`.
+
+ The keys in the path can be one of:
+ - `None`: Return the current object.
+ - `set`: Requires the only item in the set to be a type or function,
+ like `{type}`/`{func}`. If a `type`, returns only values
+ of this type. If a function, returns `func(obj)`.
+ - `str`/`int`: Return `obj[key]`. For `re.Match`, return `obj.group(key)`.
+ - `slice`: Branch out and return all values in `obj[key]`.
+ - `Ellipsis`: Branch out and return a list of all values.
+ - `tuple`/`list`: Branch out and return a list of all matching values.
+ Read as: `[traverse_obj(obj, branch) for branch in branches]`.
+ - `function`: Branch out and return values filtered by the function.
+ Read as: `[value for key, value in obj if function(key, value)]`.
+ For `Iterable`s, `key` is the index of the value.
+ For `re.Match`es, `key` is the group number (0 = full match)
+ as well as additionally any group names, if given.
+ - `dict` Transform the current object and return a matching dict.
+ Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`.
+
+ `tuple`, `list`, and `dict` all support nested paths and branches.
+
+ @params paths Paths which to traverse by.
+ @param default Value to return if the paths do not match.
+ If the last key in the path is a `dict`, it will apply to each value inside
+ the dict instead, depth first. Try to avoid if using nested `dict` keys.
+ @param expected_type If a `type`, only accept final values of this type.
+ If any other callable, try to call the function on each result.
+ If the last key in the path is a `dict`, it will apply to each value inside
+ the dict instead, recursively. This does respect branching paths.
+ @param get_all If `False`, return the first matching result, otherwise all matching ones.
+ @param casesense If `False`, consider string dictionary keys as case insensitive.
+
+ The following are only meant to be used by YoutubeDL.prepare_outtmpl and are not part of the API
+
+ @param is_user_input Whether the keys are generated from user input.
+ If `True` strings get converted to `int`/`slice` if needed.
+ @param traverse_string Whether to traverse into objects as strings.
+ If `True`, any non-compatible object will first be
+ converted into a string and then traversed into.
+ The return value of that path will be a string instead,
+ not respecting any further branching.
+
+
+ @returns The result of the object traversal.
+ If successful, `get_all=True`, and the path branches at least once,
+ then a list of results is returned instead.
+ If no `default` is given and the last path branches, a `list` of results
+ is always returned. If a path ends on a `dict` that result will always be a `dict`.
+ """
+ casefold = lambda k: k.casefold() if isinstance(k, str) else k
+
+ if isinstance(expected_type, type):
+ type_test = lambda val: val if isinstance(val, expected_type) else None
+ else:
+ type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,))
+
+ def apply_key(key, obj, is_last):
+ branching = False
+ result = None
+
+ if obj is None and traverse_string:
+ if key is ... or callable(key) or isinstance(key, slice):
+ branching = True
+ result = ()
+
+ elif key is None:
+ result = obj
+
+ elif isinstance(key, set):
+ assert len(key) == 1, 'Set should only be used to wrap a single item'
+ item = next(iter(key))
+ if isinstance(item, type):
+ if isinstance(obj, item):
+ result = obj
+ else:
+ result = try_call(item, args=(obj,))
+
+ elif isinstance(key, (list, tuple)):
+ branching = True
+ result = itertools.chain.from_iterable(
+ apply_path(obj, branch, is_last)[0] for branch in key)
+
+ elif key is ...:
+ branching = True
+ if isinstance(obj, collections.abc.Mapping):
+ result = obj.values()
+ elif is_iterable_like(obj):
+ result = obj
+ elif isinstance(obj, re.Match):
+ result = obj.groups()
+ elif traverse_string:
+ branching = False
+ result = str(obj)
+ else:
+ result = ()
+
+ elif callable(key):
+ branching = True
+ if isinstance(obj, collections.abc.Mapping):
+ iter_obj = obj.items()
+ elif is_iterable_like(obj):
+ iter_obj = enumerate(obj)
+ elif isinstance(obj, re.Match):
+ iter_obj = itertools.chain(
+ enumerate((obj.group(), *obj.groups())),
+ obj.groupdict().items())
+ elif traverse_string:
+ branching = False
+ iter_obj = enumerate(str(obj))
+ else:
+ iter_obj = ()
+
+ result = (v for k, v in iter_obj if try_call(key, args=(k, v)))
+ if not branching: # string traversal
+ result = ''.join(result)
+
+ elif isinstance(key, dict):
+ iter_obj = ((k, _traverse_obj(obj, v, False, is_last)) for k, v in key.items())
+ result = {
+ k: v if v is not None else default for k, v in iter_obj
+ if v is not None or default is not NO_DEFAULT
+ } or None
+
+ elif isinstance(obj, collections.abc.Mapping):
+ result = (try_call(obj.get, args=(key,)) if casesense or try_call(obj.__contains__, args=(key,)) else
+ next((v for k, v in obj.items() if casefold(k) == key), None))
+
+ elif isinstance(obj, re.Match):
+ if isinstance(key, int) or casesense:
+ with contextlib.suppress(IndexError):
+ result = obj.group(key)
+
+ elif isinstance(key, str):
+ result = next((v for k, v in obj.groupdict().items() if casefold(k) == key), None)
+
+ elif isinstance(key, (int, slice)):
+ if is_iterable_like(obj, collections.abc.Sequence):
+ branching = isinstance(key, slice)
+ with contextlib.suppress(IndexError):
+ result = obj[key]
+ elif traverse_string:
+ with contextlib.suppress(IndexError):
+ result = str(obj)[key]
+
+ return branching, result if branching else (result,)
+
+ def lazy_last(iterable):
+ iterator = iter(iterable)
+ prev = next(iterator, NO_DEFAULT)
+ if prev is NO_DEFAULT:
+ return
+
+ for item in iterator:
+ yield False, prev
+ prev = item
+
+ yield True, prev
+
+ def apply_path(start_obj, path, test_type):
+ objs = (start_obj,)
+ has_branched = False
+
+ key = None
+ for last, key in lazy_last(variadic(path, (str, bytes, dict, set))):
+ if is_user_input and isinstance(key, str):
+ if key == ':':
+ key = ...
+ elif ':' in key:
+ key = slice(*map(int_or_none, key.split(':')))
+ elif int_or_none(key) is not None:
+ key = int(key)
+
+ if not casesense and isinstance(key, str):
+ key = key.casefold()
+
+ if __debug__ and callable(key):
+ # Verify function signature
+ inspect.signature(key).bind(None, None)
+
+ new_objs = []
+ for obj in objs:
+ branching, results = apply_key(key, obj, last)
+ has_branched |= branching
+ new_objs.append(results)
+
+ objs = itertools.chain.from_iterable(new_objs)
+
+ if test_type and not isinstance(key, (dict, list, tuple)):
+ objs = map(type_test, objs)
+
+ return objs, has_branched, isinstance(key, dict)
+
+ def _traverse_obj(obj, path, allow_empty, test_type):
+ results, has_branched, is_dict = apply_path(obj, path, test_type)
+ results = LazyList(item for item in results if item not in (None, {}))
+ if get_all and has_branched:
+ if results:
+ return results.exhaust()
+ if allow_empty:
+ return [] if default is NO_DEFAULT else default
+ return None
+
+ return results[0] if results else {} if allow_empty and is_dict else None
+
+ for index, path in enumerate(paths, 1):
+ result = _traverse_obj(obj, path, index == len(paths), True)
+ if result is not None:
+ return result
+
+ return None if default is NO_DEFAULT else default
+
+
+def get_first(obj, *paths, **kwargs):
+ return traverse_obj(obj, *((..., *variadic(keys)) for keys in paths), **kwargs, get_all=False)
+
+
+def dict_get(d, key_or_keys, default=None, skip_false_values=True):
+ for val in map(d.get, variadic(key_or_keys)):
+ if val is not None and (val or not skip_false_values):
+ return val
+ return default