aboutsummaryrefslogtreecommitdiffstats
path: root/yt_dlp/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'yt_dlp/utils.py')
-rw-r--r--yt_dlp/utils.py28
1 files changed, 18 insertions, 10 deletions
diff --git a/yt_dlp/utils.py b/yt_dlp/utils.py
index f69311462..2f5e66720 100644
--- a/yt_dlp/utils.py
+++ b/yt_dlp/utils.py
@@ -3273,8 +3273,14 @@ def multipart_encode(data, boundary=None):
return out, content_type
-def variadic(x, allowed_types=(str, bytes, dict)):
- return x if isinstance(x, collections.abc.Iterable) and not isinstance(x, allowed_types) else (x,)
+def is_iterable_like(x, allowed_types=collections.abc.Iterable, blocked_types=NO_DEFAULT):
+ if blocked_types is NO_DEFAULT:
+ blocked_types = (str, bytes, collections.abc.Mapping)
+ return isinstance(x, allowed_types) and not isinstance(x, blocked_types)
+
+
+def variadic(x, allowed_types=NO_DEFAULT):
+ return x if is_iterable_like(x, blocked_types=allowed_types) else (x,)
def dict_get(d, key_or_keys, default=None, skip_false_values=True):
@@ -5467,7 +5473,7 @@ def traverse_obj(
obj, *paths, default=NO_DEFAULT, expected_type=None, get_all=True,
casesense=True, is_user_input=False, traverse_string=False):
"""
- Safely traverse nested `dict`s and `Sequence`s
+ Safely traverse nested `dict`s and `Iterable`s
>>> obj = [{}, {"key": "value"}]
>>> traverse_obj(obj, (1, "key"))
@@ -5475,7 +5481,7 @@ def traverse_obj(
Each of the provided `paths` is tested and the first producing a valid result will be returned.
The next path will also be tested if the path branched but no results could be found.
- Supported values for traversal are `Mapping`, `Sequence` and `re.Match`.
+ Supported values for traversal are `Mapping`, `Iterable` and `re.Match`.
Unhelpful values (`{}`, `None`) are treated as the absence of a value and discarded.
The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`.
@@ -5492,7 +5498,7 @@ def traverse_obj(
Read as: `[traverse_obj(obj, branch) for branch in branches]`.
- `function`: Branch out and return values filtered by the function.
Read as: `[value for key, value in obj if function(key, value)]`.
- For `Sequence`s, `key` is the index of the value.
+ For `Iterable`s, `key` is the index of the value.
For `re.Match`es, `key` is the group number (0 = full match)
as well as additionally any group names, if given.
- `dict` Transform the current object and return a matching dict.
@@ -5540,7 +5546,9 @@ def traverse_obj(
result = None
if obj is None and traverse_string:
- pass
+ if key is ... or callable(key) or isinstance(key, slice):
+ branching = True
+ result = ()
elif key is None:
result = obj
@@ -5563,7 +5571,7 @@ def traverse_obj(
branching = True
if isinstance(obj, collections.abc.Mapping):
result = obj.values()
- elif isinstance(obj, collections.abc.Iterable) and not isinstance(obj, (str, bytes)):
+ elif is_iterable_like(obj):
result = obj
elif isinstance(obj, re.Match):
result = obj.groups()
@@ -5577,7 +5585,7 @@ def traverse_obj(
branching = True
if isinstance(obj, collections.abc.Mapping):
iter_obj = obj.items()
- elif isinstance(obj, collections.abc.Iterable) and not isinstance(obj, (str, bytes)):
+ elif is_iterable_like(obj):
iter_obj = enumerate(obj)
elif isinstance(obj, re.Match):
iter_obj = itertools.chain(
@@ -5601,7 +5609,7 @@ def traverse_obj(
} or None
elif isinstance(obj, collections.abc.Mapping):
- result = (obj.get(key) if casesense or (key in obj) else
+ result = (try_call(obj.get, args=(key,)) if casesense or try_call(obj.__contains__, args=(key,)) else
next((v for k, v in obj.items() if casefold(k) == key), None))
elif isinstance(obj, re.Match):
@@ -5613,7 +5621,7 @@ def traverse_obj(
result = next((v for k, v in obj.groupdict().items() if casefold(k) == key), None)
elif isinstance(key, (int, slice)):
- if isinstance(obj, collections.abc.Sequence) and not isinstance(obj, (str, bytes)):
+ if is_iterable_like(obj, collections.abc.Sequence):
branching = isinstance(key, slice)
with contextlib.suppress(IndexError):
result = obj[key]