diff options
| -rw-r--r-- | test/test_utils.py | 40 | ||||
| -rw-r--r-- | yt_dlp/utils.py | 58 | 
2 files changed, 88 insertions, 10 deletions
| diff --git a/test/test_utils.py b/test/test_utils.py index 3d5a6ea6b..ffe1b729f 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -105,6 +105,7 @@ from yt_dlp.utils import (      sanitized_Request,      shell_quote,      smuggle_url, +    str_or_none,      str_to_int,      strip_jsonp,      strip_or_none, @@ -2015,6 +2016,29 @@ Line 1                           msg='function as query key should perform a filter based on (key, value)')          self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], str)), {'str'},                                msg='exceptions in the query function should be catched') +        if __debug__: +            with self.assertRaises(Exception, msg='Wrong function signature should raise in debug'): +                traverse_obj(_TEST_DATA, lambda a: ...) +            with self.assertRaises(Exception, msg='Wrong function signature should raise in debug'): +                traverse_obj(_TEST_DATA, lambda a, b, c: ...) + +        # Test set as key (transformation/type, like `expected_type`) +        self.assertEqual(traverse_obj(_TEST_DATA, (..., {str.upper}, )), ['STR'], +                         msg='Function in set should be a transformation') +        self.assertEqual(traverse_obj(_TEST_DATA, (..., {str})), ['str'], +                         msg='Type in set should be a type filter') +        self.assertEqual(traverse_obj(_TEST_DATA, {dict}), _TEST_DATA, +                         msg='A single set should be wrapped into a path') +        self.assertEqual(traverse_obj(_TEST_DATA, (..., {str.upper})), ['STR'], +                         msg='Transformation function should not raise') +        self.assertEqual(traverse_obj(_TEST_DATA, (..., {str_or_none})), +                         [item for item in map(str_or_none, _TEST_DATA.values()) if item is not None], +                         msg='Function in set should be a transformation') +        if __debug__: +            with self.assertRaises(Exception, msg='Sets with length != 1 should raise in debug'): +                traverse_obj(_TEST_DATA, set()) +            with self.assertRaises(Exception, msg='Sets with length != 1 should raise in debug'): +                traverse_obj(_TEST_DATA, {str.upper, str})          # Test alternative paths          self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'str'), 'str', @@ -2106,6 +2130,20 @@ Line 1                           msg='wrap expected_type fuction in try_call')          self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, ..., expected_type=str), ['str'],                           msg='eliminate items that expected_type fails on') +        self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2}, expected_type=int), {0: 100}, +                         msg='type as expected_type should filter dict values') +        self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2, 2: 'None'}, expected_type=str_or_none), {0: '100', 1: '1.2'}, +                         msg='function as expected_type should transform dict values') +        self.assertEqual(traverse_obj(_TEST_DATA, ({0: 1.2}, 0, {int_or_none}), expected_type=int), 1, +                         msg='expected_type should not filter non final dict values') +        self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 100, 1: 'str'}}, expected_type=int), {0: {0: 100}}, +                         msg='expected_type should transform deep dict values') +        self.assertEqual(traverse_obj(_TEST_DATA, [({0: '...'}, {0: '...'})], expected_type=type(...)), [{0: ...}, {0: ...}], +                         msg='expected_type should transform branched dict values') +        self.assertEqual(traverse_obj({1: {3: 4}}, [(1, 2), 3], expected_type=int), [4], +                         msg='expected_type regression for type matching in tuple branching') +        self.assertEqual(traverse_obj(_TEST_DATA, ['data', ...], expected_type=int), [], +                         msg='expected_type regression for type matching in dict result')          # Test get_all behavior          _GET_ALL_DATA = {'key': [0, 1, 2]} @@ -2189,6 +2227,8 @@ Line 1                           msg='failing str key on a `re.Match` should return `default`')          self.assertEqual(traverse_obj(mobj, 8), None,                           msg='failing int key on a `re.Match` should return `default`') +        self.assertEqual(traverse_obj(mobj, lambda k, _: k in (0, 'group')), ['0123', '3'], +                         msg='function on a `re.Match` should give group name as well')  if __name__ == '__main__': diff --git a/yt_dlp/utils.py b/yt_dlp/utils.py index 7d51fe472..55e1c4415 100644 --- a/yt_dlp/utils.py +++ b/yt_dlp/utils.py @@ -5424,6 +5424,9 @@ def traverse_obj(      The keys in the path can be one of:          - `None`:           Return the current object. +        - `set`:            Requires the only item in the set to be a type or function, +                            like `{type}`/`{func}`. If a `type`, returns only values +                            of this type. If a function, returns `func(obj)`.          - `str`/`int`:      Return `obj[key]`. For `re.Match`, return `obj.group(key)`.          - `slice`:          Branch out and return all values in `obj[key]`.          - `Ellipsis`:       Branch out and return a list of all values. @@ -5432,6 +5435,8 @@ def traverse_obj(          - `function`:       Branch out and return values filtered by the function.                              Read as: `[value for key, value in obj if function(key, value)]`.                              For `Sequence`s, `key` is the index of the value. +                            For `re.Match`es, `key` is the group number (0 = full match) +                            as well as additionally any group names, if given.          - `dict`            Transform the current object and return a matching dict.                              Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`. @@ -5441,6 +5446,8 @@ def traverse_obj(      @param default          Value to return if the paths do not match.      @param expected_type    If a `type`, only accept final values of this type.                              If any other callable, try to call the function on each result. +                            If the last key in the path is a `dict`, it will apply to each value inside +                            the dict instead, recursively. This does respect branching paths.      @param get_all          If `False`, return the first matching result, otherwise all matching ones.      @param casesense        If `False`, consider string dictionary keys as case insensitive. @@ -5466,16 +5473,25 @@ def traverse_obj(      else:          type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,)) -    def apply_key(key, obj): +    def apply_key(key, test_type, obj):          if obj is None:              return          elif key is None:              yield obj +        elif isinstance(key, set): +            assert len(key) == 1, 'Set should only be used to wrap a single item' +            item = next(iter(key)) +            if isinstance(item, type): +                if isinstance(obj, item): +                    yield obj +            else: +                yield try_call(item, args=(obj,)) +          elif isinstance(key, (list, tuple)):              for branch in key: -                _, result = apply_path(obj, branch) +                _, result = apply_path(obj, branch, test_type)                  yield from result          elif key is ...: @@ -5494,7 +5510,9 @@ def traverse_obj(              elif isinstance(obj, collections.abc.Mapping):                  iter_obj = obj.items()              elif isinstance(obj, re.Match): -                iter_obj = enumerate((obj.group(), *obj.groups())) +                iter_obj = itertools.chain( +                    enumerate((obj.group(), *obj.groups())), +                    obj.groupdict().items())              elif traverse_string:                  iter_obj = enumerate(str(obj))              else: @@ -5502,7 +5520,7 @@ def traverse_obj(              yield from (v for k, v in iter_obj if try_call(key, args=(k, v)))          elif isinstance(key, dict): -            iter_obj = ((k, _traverse_obj(obj, v)) for k, v in key.items()) +            iter_obj = ((k, _traverse_obj(obj, v, test_type=test_type)) for k, v in key.items())              yield {k: v if v is not None else default for k, v in iter_obj                     if v is not None or default is not NO_DEFAULT} @@ -5537,11 +5555,24 @@ def traverse_obj(              with contextlib.suppress(IndexError):                  yield obj[key] -    def apply_path(start_obj, path): +    def lazy_last(iterable): +        iterator = iter(iterable) +        prev = next(iterator, NO_DEFAULT) +        if prev is NO_DEFAULT: +            return + +        for item in iterator: +            yield False, prev +            prev = item + +        yield True, prev + +    def apply_path(start_obj, path, test_type=False):          objs = (start_obj,)          has_branched = False -        for key in variadic(path): +        key = None +        for last, key in lazy_last(variadic(path, (str, bytes, dict, set))):              if is_user_input and key == ':':                  key = ... @@ -5551,14 +5582,21 @@ def traverse_obj(              if key is ... or isinstance(key, (list, tuple)) or callable(key):                  has_branched = True -            key_func = functools.partial(apply_key, key) +            if __debug__ and callable(key): +                # Verify function signature +                inspect.signature(key).bind(None, None) + +            key_func = functools.partial(apply_key, key, last)              objs = itertools.chain.from_iterable(map(key_func, objs)) +        if test_type and not isinstance(key, (dict, list, tuple)): +            objs = map(type_test, objs) +          return has_branched, objs -    def _traverse_obj(obj, path, use_list=True): -        has_branched, results = apply_path(obj, path) -        results = LazyList(x for x in map(type_test, results) if x is not None) +    def _traverse_obj(obj, path, use_list=True, test_type=True): +        has_branched, results = apply_path(obj, path, test_type) +        results = LazyList(x for x in results if x is not None)          if get_all and has_branched:              return results.exhaust() if results or use_list else None | 
