aboutsummaryrefslogtreecommitdiffstats
path: root/yt_dlp/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'yt_dlp/utils.py')
-rw-r--r--yt_dlp/utils.py141
1 files changed, 81 insertions, 60 deletions
diff --git a/yt_dlp/utils.py b/yt_dlp/utils.py
index e1e0f7b25..878b2b6a8 100644
--- a/yt_dlp/utils.py
+++ b/yt_dlp/utils.py
@@ -5420,7 +5420,7 @@ def traverse_obj(
Each of the provided `paths` is tested and the first producing a valid result will be returned.
The next path will also be tested if the path branched but no results could be found.
Supported values for traversal are `Mapping`, `Sequence` and `re.Match`.
- A value of None is treated as the absence of a value.
+ Unhelpful values (`[]`, `{}`, `None`) are treated as the absence of a value and discarded.
The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`.
@@ -5446,6 +5446,8 @@ def traverse_obj(
@params paths Paths which to traverse by.
@param default Value to return if the paths do not match.
+ If the last key in the path is a `dict`, it will apply to each value inside
+ the dict instead, depth first. Try to avoid if using nested `dict` keys.
@param expected_type If a `type`, only accept final values of this type.
If any other callable, try to call the function on each result.
If the last key in the path is a `dict`, it will apply to each value inside
@@ -5460,12 +5462,15 @@ def traverse_obj(
@param traverse_string Whether to traverse into objects as strings.
If `True`, any non-compatible object will first be
converted into a string and then traversed into.
+ The return value of that path will be a string instead,
+ not respecting any further branching.
@returns The result of the object traversal.
If successful, `get_all=True`, and the path branches at least once,
then a list of results is returned instead.
- A list is always returned if the last path branches and no `default` is given.
+ If no `default` is given and the last path branches, a `list` of results
+ is always returned. If a path ends on a `dict` that result will always be a `dict`.
"""
is_sequence = lambda x: isinstance(x, collections.abc.Sequence) and not isinstance(x, (str, bytes))
casefold = lambda k: k.casefold() if isinstance(k, str) else k
@@ -5475,87 +5480,94 @@ def traverse_obj(
else:
type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,))
- def apply_key(key, test_type, obj):
+ def apply_key(key, obj, is_last):
+ branching = False
+ result = None
+
if obj is None:
- return
+ pass
elif key is None:
- yield obj
+ result = obj
elif isinstance(key, set):
assert len(key) == 1, 'Set should only be used to wrap a single item'
item = next(iter(key))
if isinstance(item, type):
if isinstance(obj, item):
- yield obj
+ result = obj
else:
- yield try_call(item, args=(obj,))
+ result = try_call(item, args=(obj,))
elif isinstance(key, (list, tuple)):
- for branch in key:
- _, result = apply_path(obj, branch, test_type)
- yield from result
+ branching = True
+ result = itertools.chain.from_iterable(
+ apply_path(obj, branch, is_last)[0] for branch in key)
elif key is ...:
+ branching = True
if isinstance(obj, collections.abc.Mapping):
- yield from obj.values()
+ result = obj.values()
elif is_sequence(obj):
- yield from obj
+ result = obj
elif isinstance(obj, re.Match):
- yield from obj.groups()
+ result = obj.groups()
elif traverse_string:
- yield from str(obj)
+ branching = False
+ result = str(obj)
+ else:
+ result = ()
elif callable(key):
- if is_sequence(obj):
- iter_obj = enumerate(obj)
- elif isinstance(obj, collections.abc.Mapping):
+ branching = True
+ if isinstance(obj, collections.abc.Mapping):
iter_obj = obj.items()
+ elif is_sequence(obj):
+ iter_obj = enumerate(obj)
elif isinstance(obj, re.Match):
iter_obj = itertools.chain(
enumerate((obj.group(), *obj.groups())),
obj.groupdict().items())
elif traverse_string:
+ branching = False
iter_obj = enumerate(str(obj))
else:
- return
- yield from (v for k, v in iter_obj if try_call(key, args=(k, v)))
+ iter_obj = ()
+
+ result = (v for k, v in iter_obj if try_call(key, args=(k, v)))
+ if not branching: # string traversal
+ result = ''.join(result)
elif isinstance(key, dict):
- iter_obj = ((k, _traverse_obj(obj, v, test_type=test_type)) for k, v in key.items())
- yield {k: v if v is not None else default for k, v in iter_obj
- if v is not None or default is not NO_DEFAULT}
+ iter_obj = ((k, _traverse_obj(obj, v, False, is_last)) for k, v in key.items())
+ result = {
+ k: v if v is not None else default for k, v in iter_obj
+ if v is not None or default is not NO_DEFAULT
+ } or None
elif isinstance(obj, collections.abc.Mapping):
- yield (obj.get(key) if casesense or (key in obj)
- else next((v for k, v in obj.items() if casefold(k) == key), None))
+ result = (obj.get(key) if casesense or (key in obj) else
+ next((v for k, v in obj.items() if casefold(k) == key), None))
elif isinstance(obj, re.Match):
if isinstance(key, int) or casesense:
with contextlib.suppress(IndexError):
- yield obj.group(key)
- return
+ result = obj.group(key)
- if not isinstance(key, str):
- return
-
- yield next((v for k, v in obj.groupdict().items() if casefold(k) == key), None)
-
- else:
- if is_user_input:
- key = (int_or_none(key) if ':' not in key
- else slice(*map(int_or_none, key.split(':'))))
-
- if not isinstance(key, (int, slice)):
- return
+ elif isinstance(key, str):
+ result = next((v for k, v in obj.groupdict().items() if casefold(k) == key), None)
+ elif isinstance(key, (int, slice)):
if not is_sequence(obj):
- if not traverse_string:
- return
- obj = str(obj)
+ if traverse_string:
+ with contextlib.suppress(IndexError):
+ result = str(obj)[key]
+ else:
+ branching = isinstance(key, slice)
+ with contextlib.suppress(IndexError):
+ result = obj[key]
- with contextlib.suppress(IndexError):
- yield obj[key]
+ return branching, result if branching else (result,)
def lazy_last(iterable):
iterator = iter(iterable)
@@ -5569,45 +5581,54 @@ def traverse_obj(
yield True, prev
- def apply_path(start_obj, path, test_type=False):
+ def apply_path(start_obj, path, test_type):
objs = (start_obj,)
has_branched = False
key = None
for last, key in lazy_last(variadic(path, (str, bytes, dict, set))):
- if is_user_input and key == ':':
- key = ...
+ if is_user_input and isinstance(key, str):
+ if key == ':':
+ key = ...
+ elif ':' in key:
+ key = slice(*map(int_or_none, key.split(':')))
+ elif int_or_none(key) is not None:
+ key = int(key)
if not casesense and isinstance(key, str):
key = key.casefold()
- if key is ... or isinstance(key, (list, tuple)) or callable(key):
- has_branched = True
-
if __debug__ and callable(key):
# Verify function signature
inspect.signature(key).bind(None, None)
- key_func = functools.partial(apply_key, key, last)
- objs = itertools.chain.from_iterable(map(key_func, objs))
+ new_objs = []
+ for obj in objs:
+ branching, results = apply_key(key, obj, last)
+ has_branched |= branching
+ new_objs.append(results)
+
+ objs = itertools.chain.from_iterable(new_objs)
if test_type and not isinstance(key, (dict, list, tuple)):
objs = map(type_test, objs)
- return has_branched, objs
-
- def _traverse_obj(obj, path, use_list=True, test_type=True):
- has_branched, results = apply_path(obj, path, test_type)
- results = LazyList(x for x in results if x is not None)
+ return objs, has_branched, isinstance(key, dict)
+ def _traverse_obj(obj, path, allow_empty, test_type):
+ results, has_branched, is_dict = apply_path(obj, path, test_type)
+ results = LazyList(item for item in results if item not in (None, [], {}))
if get_all and has_branched:
- return results.exhaust() if results or use_list else None
+ if results:
+ return results.exhaust()
+ if allow_empty:
+ return [] if default is NO_DEFAULT else default
+ return None
- return results[0] if results else None
+ return results[0] if results else {} if allow_empty and is_dict else None
for index, path in enumerate(paths, 1):
- use_list = default is NO_DEFAULT and index == len(paths)
- result = _traverse_obj(obj, path, use_list)
+ result = _traverse_obj(obj, path, index == len(paths), True)
if result is not None:
return result