diff options
author | Simon Sawicki <37424085+Grub4K@users.noreply.github.com> | 2022-09-25 23:03:19 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-09-26 02:33:19 +0530 |
commit | ab029d7e9200a273d7204be68c0735b16971ff44 (patch) | |
tree | bcbe5a128b7bb86dace4970a841b7ee25cbdd8cf /test | |
parent | 0bd5a039ea234374821510ac0371e03e87a6a57f (diff) | |
download | hypervideo-pre-ab029d7e9200a273d7204be68c0735b16971ff44.tar.lz hypervideo-pre-ab029d7e9200a273d7204be68c0735b16971ff44.tar.xz hypervideo-pre-ab029d7e9200a273d7204be68c0735b16971ff44.zip |
[utils] `traverse_obj`: Rewrite, document and add tests (#5024)
Authored by: Grub4K
Diffstat (limited to 'test')
-rw-r--r-- | test/test_utils.py | 187 |
1 files changed, 187 insertions, 0 deletions
diff --git a/test/test_utils.py b/test/test_utils.py index 96477c53f..69313564a 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -109,6 +109,7 @@ from yt_dlp.utils import ( strip_or_none, subtitles_filename, timeconvert, + traverse_obj, unescapeHTML, unified_strdate, unified_timestamp, @@ -1874,6 +1875,192 @@ Line 1 self.assertEqual(get_compatible_ext( vcodecs=['av1'], acodecs=['mp4a'], vexts=['webm'], aexts=['m4a'], preferences=('webm', 'mkv')), 'mkv') + def test_traverse_obj(self): + _TEST_DATA = { + 100: 100, + 1.2: 1.2, + 'str': 'str', + 'None': None, + '...': ..., + 'urls': [ + {'index': 0, 'url': 'https://www.example.com/0'}, + {'index': 1, 'url': 'https://www.example.com/1'}, + ], + 'data': ( + {'index': 2}, + {'index': 3}, + ), + } + + # Test base functionality + self.assertEqual(traverse_obj(_TEST_DATA, ('str',)), 'str', + msg='allow tuple path') + self.assertEqual(traverse_obj(_TEST_DATA, ['str']), 'str', + msg='allow list path') + self.assertEqual(traverse_obj(_TEST_DATA, (value for value in ("str",))), 'str', + msg='allow iterable path') + self.assertEqual(traverse_obj(_TEST_DATA, 'str'), 'str', + msg='single items should be treated as a path') + self.assertEqual(traverse_obj(_TEST_DATA, None), _TEST_DATA) + self.assertEqual(traverse_obj(_TEST_DATA, 100), 100) + self.assertEqual(traverse_obj(_TEST_DATA, 1.2), 1.2) + + # Test Ellipsis behavior + self.assertCountEqual(traverse_obj(_TEST_DATA, ...), + (item for item in _TEST_DATA.values() if item is not None), + msg='`...` should give all values except `None`') + self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', 0, ...)), _TEST_DATA['urls'][0].values(), + msg='`...` selection for dicts should select all values') + self.assertEqual(traverse_obj(_TEST_DATA, (..., ..., 'url')), + ['https://www.example.com/0', 'https://www.example.com/1'], + msg='nested `...` queries should work') + self.assertCountEqual(traverse_obj(_TEST_DATA, (..., ..., 'index')), range(4), + msg='`...` query result should be flattened') + + # Test function as key + self.assertEqual(traverse_obj(_TEST_DATA, lambda x, y: x == 'urls' and isinstance(y, list)), + [_TEST_DATA['urls']], + msg='function as query key should perform a filter based on (key, value)') + self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], str)), {'str'}, + msg='exceptions in the query function should be catched') + + # Test alternative paths + self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'str'), 'str', + msg='multiple `path_list` should be treated as alternative paths') + self.assertEqual(traverse_obj(_TEST_DATA, 'str', 100), 'str', + msg='alternatives should exit early') + self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'fail'), None, + msg='alternatives should return `default` if exhausted') + + # Test branch and path nesting + self.assertEqual(traverse_obj(_TEST_DATA, ('urls', (3, 0), 'url')), ['https://www.example.com/0'], + msg='tuple as key should be treated as branches') + self.assertEqual(traverse_obj(_TEST_DATA, ('urls', [3, 0], 'url')), ['https://www.example.com/0'], + msg='list as key should be treated as branches') + self.assertEqual(traverse_obj(_TEST_DATA, ('urls', ((1, 'fail'), (0, 'url')))), ['https://www.example.com/0'], + msg='double nesting in path should be treated as paths') + self.assertEqual(traverse_obj(['0', [1, 2]], [(0, 1), 0]), [1], + msg='do not fail early on branching') + self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', ((1, ('fail', 'url')), (0, 'url')))), + ['https://www.example.com/0', 'https://www.example.com/1'], + msg='tripple nesting in path should be treated as branches') + self.assertEqual(traverse_obj(_TEST_DATA, ('urls', ('fail', (..., 'url')))), + ['https://www.example.com/0', 'https://www.example.com/1'], + msg='ellipsis as branch path start gets flattened') + + # Test dictionary as key + self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2}), {0: 100, 1: 1.2}, + msg='dict key should result in a dict with the same keys') + self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', 0, 'url')}), + {0: 'https://www.example.com/0'}, + msg='dict key should allow paths') + self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', (3, 0), 'url')}), + {0: ['https://www.example.com/0']}, + msg='tuple in dict path should be treated as branches') + self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', ((1, 'fail'), (0, 'url')))}), + {0: ['https://www.example.com/0']}, + msg='double nesting in dict path should be treated as paths') + self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', ((1, ('fail', 'url')), (0, 'url')))}), + {0: ['https://www.example.com/1', 'https://www.example.com/0']}, + msg='tripple nesting in dict path should be treated as branches') + self.assertEqual(traverse_obj({}, {0: 1}, default=...), {0: ...}, + msg='do not remove `None` values when dict key') + + # Testing default parameter behavior + _DEFAULT_DATA = {'None': None, 'int': 0, 'list': []} + self.assertEqual(traverse_obj(_DEFAULT_DATA, 'fail'), None, + msg='default value should be `None`') + self.assertEqual(traverse_obj(_DEFAULT_DATA, 'fail', 'fail', default=...), ..., + msg='chained fails should result in default') + self.assertEqual(traverse_obj(_DEFAULT_DATA, 'None', 'int'), 0, + msg='should not short cirquit on `None`') + self.assertEqual(traverse_obj(_DEFAULT_DATA, 'fail', default=1), 1, + msg='invalid dict key should result in `default`') + self.assertEqual(traverse_obj(_DEFAULT_DATA, 'None', default=1), 1, + msg='`None` is a deliberate sentinel and should become `default`') + self.assertEqual(traverse_obj(_DEFAULT_DATA, ('list', 10)), None, + msg='`IndexError` should result in `default`') + self.assertEqual(traverse_obj(_DEFAULT_DATA, (..., 'fail'), default=1), 1, + msg='if branched but not successfull return `default`, not `[]`') + + # Testing expected_type behavior + _EXPECTED_TYPE_DATA = {'str': 'str', 'int': 0} + self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=str), 'str', + msg='accept matching `expected_type` type') + self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=int), None, + msg='reject non matching `expected_type` type') + self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'int', expected_type=lambda x: str(x)), '0', + msg='transform type using type function') + self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', + expected_type=lambda _: 1 / 0), None, + msg='wrap expected_type fuction in try_call') + self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, ..., expected_type=str), ['str'], + msg='eliminate items that expected_type fails on') + + # Test get_all behavior + _GET_ALL_DATA = {'key': [0, 1, 2]} + self.assertEqual(traverse_obj(_GET_ALL_DATA, ('key', ...), get_all=False), 0, + msg='if not `get_all`, return only first matching value') + self.assertEqual(traverse_obj(_GET_ALL_DATA, ..., get_all=False), [0, 1, 2], + msg='do not overflatten if not `get_all`') + + # Test casesense behavior + _CASESENSE_DATA = { + 'KeY': 'value0', + 0: { + 'KeY': 'value1', + 0: {'KeY': 'value2'}, + }, + } + self.assertEqual(traverse_obj(_CASESENSE_DATA, 'key'), None, + msg='dict keys should be case sensitive unless `casesense`') + self.assertEqual(traverse_obj(_CASESENSE_DATA, 'keY', + casesense=False), 'value0', + msg='allow non matching key case if `casesense`') + self.assertEqual(traverse_obj(_CASESENSE_DATA, (0, ('keY',)), + casesense=False), ['value1'], + msg='allow non matching key case in branch if `casesense`') + self.assertEqual(traverse_obj(_CASESENSE_DATA, (0, ((0, 'keY'),)), + casesense=False), ['value2'], + msg='allow non matching key case in branch path if `casesense`') + + # Test traverse_string behavior + _TRAVERSE_STRING_DATA = {'str': 'str', 1.2: 1.2} + self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', 0)), None, + msg='do not traverse into string if not `traverse_string`') + self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', 0), + traverse_string=True), 's', + msg='traverse into string if `traverse_string`') + self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, (1.2, 1), + traverse_string=True), '.', + msg='traverse into converted data if `traverse_string`') + self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', ...), + traverse_string=True), list('str'), + msg='`...` branching into string should result in list') + self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', (0, 2)), + traverse_string=True), ['s', 'r'], + msg='branching into string should result in list') + self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', lambda _, x: x), + traverse_string=True), list('str'), + msg='function branching into string should result in list') + + # Test is_user_input behavior + _IS_USER_INPUT_DATA = {'range8': list(range(8))} + self.assertEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', '3'), + is_user_input=True), 3, + msg='allow for string indexing if `is_user_input`') + self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', '3:'), + is_user_input=True), tuple(range(8))[3:], + msg='allow for string slice if `is_user_input`') + self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':4:2'), + is_user_input=True), tuple(range(8))[:4:2], + msg='allow step in string slice if `is_user_input`') + self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':'), + is_user_input=True), range(8), + msg='`:` should be treated as `...` if `is_user_input`') + with self.assertRaises(TypeError, msg='too many params should result in error'): + traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':::'), is_user_input=True) + if __name__ == '__main__': unittest.main() |