diff options
Diffstat (limited to 'yt_dlp/utils.py')
-rw-r--r-- | yt_dlp/utils.py | 58 |
1 files changed, 48 insertions, 10 deletions
diff --git a/yt_dlp/utils.py b/yt_dlp/utils.py index 7d51fe472..55e1c4415 100644 --- a/yt_dlp/utils.py +++ b/yt_dlp/utils.py @@ -5424,6 +5424,9 @@ def traverse_obj( The keys in the path can be one of: - `None`: Return the current object. + - `set`: Requires the only item in the set to be a type or function, + like `{type}`/`{func}`. If a `type`, returns only values + of this type. If a function, returns `func(obj)`. - `str`/`int`: Return `obj[key]`. For `re.Match`, return `obj.group(key)`. - `slice`: Branch out and return all values in `obj[key]`. - `Ellipsis`: Branch out and return a list of all values. @@ -5432,6 +5435,8 @@ def traverse_obj( - `function`: Branch out and return values filtered by the function. Read as: `[value for key, value in obj if function(key, value)]`. For `Sequence`s, `key` is the index of the value. + For `re.Match`es, `key` is the group number (0 = full match) + as well as additionally any group names, if given. - `dict` Transform the current object and return a matching dict. Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`. @@ -5441,6 +5446,8 @@ def traverse_obj( @param default Value to return if the paths do not match. @param expected_type If a `type`, only accept final values of this type. If any other callable, try to call the function on each result. + If the last key in the path is a `dict`, it will apply to each value inside + the dict instead, recursively. This does respect branching paths. @param get_all If `False`, return the first matching result, otherwise all matching ones. @param casesense If `False`, consider string dictionary keys as case insensitive. @@ -5466,16 +5473,25 @@ def traverse_obj( else: type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,)) - def apply_key(key, obj): + def apply_key(key, test_type, obj): if obj is None: return elif key is None: yield obj + elif isinstance(key, set): + assert len(key) == 1, 'Set should only be used to wrap a single item' + item = next(iter(key)) + if isinstance(item, type): + if isinstance(obj, item): + yield obj + else: + yield try_call(item, args=(obj,)) + elif isinstance(key, (list, tuple)): for branch in key: - _, result = apply_path(obj, branch) + _, result = apply_path(obj, branch, test_type) yield from result elif key is ...: @@ -5494,7 +5510,9 @@ def traverse_obj( elif isinstance(obj, collections.abc.Mapping): iter_obj = obj.items() elif isinstance(obj, re.Match): - iter_obj = enumerate((obj.group(), *obj.groups())) + iter_obj = itertools.chain( + enumerate((obj.group(), *obj.groups())), + obj.groupdict().items()) elif traverse_string: iter_obj = enumerate(str(obj)) else: @@ -5502,7 +5520,7 @@ def traverse_obj( yield from (v for k, v in iter_obj if try_call(key, args=(k, v))) elif isinstance(key, dict): - iter_obj = ((k, _traverse_obj(obj, v)) for k, v in key.items()) + iter_obj = ((k, _traverse_obj(obj, v, test_type=test_type)) for k, v in key.items()) yield {k: v if v is not None else default for k, v in iter_obj if v is not None or default is not NO_DEFAULT} @@ -5537,11 +5555,24 @@ def traverse_obj( with contextlib.suppress(IndexError): yield obj[key] - def apply_path(start_obj, path): + def lazy_last(iterable): + iterator = iter(iterable) + prev = next(iterator, NO_DEFAULT) + if prev is NO_DEFAULT: + return + + for item in iterator: + yield False, prev + prev = item + + yield True, prev + + def apply_path(start_obj, path, test_type=False): objs = (start_obj,) has_branched = False - for key in variadic(path): + key = None + for last, key in lazy_last(variadic(path, (str, bytes, dict, set))): if is_user_input and key == ':': key = ... @@ -5551,14 +5582,21 @@ def traverse_obj( if key is ... or isinstance(key, (list, tuple)) or callable(key): has_branched = True - key_func = functools.partial(apply_key, key) + if __debug__ and callable(key): + # Verify function signature + inspect.signature(key).bind(None, None) + + key_func = functools.partial(apply_key, key, last) objs = itertools.chain.from_iterable(map(key_func, objs)) + if test_type and not isinstance(key, (dict, list, tuple)): + objs = map(type_test, objs) + return has_branched, objs - def _traverse_obj(obj, path, use_list=True): - has_branched, results = apply_path(obj, path) - results = LazyList(x for x in map(type_test, results) if x is not None) + def _traverse_obj(obj, path, use_list=True, test_type=True): + has_branched, results = apply_path(obj, path, test_type) + results = LazyList(x for x in results if x is not None) if get_all and has_branched: return results.exhaust() if results or use_list else None |