diff options
Diffstat (limited to 'test/test_utils.py')
-rw-r--r-- | test/test_utils.py | 468 |
1 files changed, 391 insertions, 77 deletions
diff --git a/test/test_utils.py b/test/test_utils.py index 039900c..acb913a 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,89 +1,108 @@ #!/usr/bin/env python3 -# coding: utf-8 - -from __future__ import unicode_literals # Allow direct execution import os +import re import sys import unittest + sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -# Various small unit tests +import contextlib import io import itertools import json import xml.etree.ElementTree +from hypervideo_dl.compat import ( + compat_etree_fromstring, + compat_HTMLParseError, + compat_os_name, +) from hypervideo_dl.utils import ( + Config, + DateRange, + ExtractorError, + InAdvancePagedList, + LazyList, + OnDemandPagedList, age_restricted, args_to_str, - encode_base_n, + base_url, caesar, clean_html, clean_podcast_url, - Config, + cli_bool_option, + cli_option, + cli_valueless_option, date_from_str, datetime_from_str, - DateRange, detect_exe_version, determine_ext, + determine_file_encoding, + dfxp2srt, dict_get, + encode_base_n, encode_compat_str, encodeFilename, escape_rfc3986, escape_url, + expand_path, extract_attributes, - ExtractorError, find_xpath_attr, fix_xml_ampersands, - format_bytes, float_or_none, - get_element_by_class, + format_bytes, + get_compatible_ext, get_element_by_attribute, - get_elements_by_class, - get_elements_by_attribute, - get_element_html_by_class, + get_element_by_class, get_element_html_by_attribute, - get_elements_html_by_class, + get_element_html_by_class, + get_element_text_and_html_by_tag, + get_elements_by_attribute, + get_elements_by_class, get_elements_html_by_attribute, + get_elements_html_by_class, get_elements_text_and_html_by_attribute, - get_element_text_and_html_by_tag, - InAdvancePagedList, int_or_none, intlist_to_bytes, + iri_to_uri, is_html, js_to_json, limit_length, + locked_file, + lowercase_escape, + match_str, merge_dicts, mimetype2ext, month_by_name, multipart_encode, ohdave_rsa_encrypt, - OnDemandPagedList, orderedSet, parse_age_limit, + parse_bitrate, + parse_codecs, + parse_count, + parse_dfxp_time_expr, parse_duration, parse_filesize, - parse_count, parse_iso8601, - parse_resolution, - parse_bitrate, parse_qs, + parse_resolution, pkcs1pad, + prepend_extension, read_batch_urls, + remove_end, + remove_quotes, + remove_start, + render_table, + replace_extension, + rot47, sanitize_filename, sanitize_path, sanitize_url, sanitized_Request, - expand_path, - prepend_extension, - replace_extension, - remove_start, - remove_end, - remove_quotes, - rot47, shell_quote, smuggle_url, str_to_int, @@ -91,42 +110,23 @@ from hypervideo_dl.utils import ( strip_or_none, subtitles_filename, timeconvert, + traverse_obj, unescapeHTML, unified_strdate, unified_timestamp, unsmuggle_url, + update_url_query, uppercase_escape, - lowercase_escape, url_basename, url_or_none, - base_url, - urljoin, urlencode_postdata, + urljoin, urshift, - update_url_query, version_tuple, - xpath_with_ns, + xpath_attr, xpath_element, xpath_text, - xpath_attr, - render_table, - match_str, - parse_dfxp_time_expr, - dfxp2srt, - cli_option, - cli_valueless_option, - cli_bool_option, - parse_codecs, - iri_to_uri, - LazyList, -) -from hypervideo_dl.compat import ( - compat_chr, - compat_etree_fromstring, - compat_getenv, - compat_HTMLParseError, - compat_os_name, - compat_setenv, + xpath_with_ns, ) @@ -142,13 +142,13 @@ class TestUtil(unittest.TestCase): self.assertEqual(sanitize_filename('123'), '123') - self.assertEqual('abc_de', sanitize_filename('abc/de')) + self.assertEqual('abc⧸de', sanitize_filename('abc/de')) self.assertFalse('/' in sanitize_filename('abc/de///')) - self.assertEqual('abc_de', sanitize_filename('abc/<>\\*|de')) - self.assertEqual('xxx', sanitize_filename('xxx/<>\\*|')) - self.assertEqual('yes no', sanitize_filename('yes? no')) - self.assertEqual('this - that', sanitize_filename('this: that')) + self.assertEqual('abc_de', sanitize_filename('abc/<>\\*|de', is_id=False)) + self.assertEqual('xxx', sanitize_filename('xxx/<>\\*|', is_id=False)) + self.assertEqual('yes no', sanitize_filename('yes? no', is_id=False)) + self.assertEqual('this - that', sanitize_filename('this: that', is_id=False)) self.assertEqual(sanitize_filename('AT&T'), 'AT&T') aumlaut = 'ä' @@ -265,15 +265,22 @@ class TestUtil(unittest.TestCase): def test_expand_path(self): def env(var): - return '%{0}%'.format(var) if sys.platform == 'win32' else '${0}'.format(var) + return f'%{var}%' if sys.platform == 'win32' else f'${var}' - compat_setenv('hypervideo_dl_EXPATH_PATH', 'expanded') + os.environ['hypervideo_dl_EXPATH_PATH'] = 'expanded' self.assertEqual(expand_path(env('hypervideo_dl_EXPATH_PATH')), 'expanded') - self.assertEqual(expand_path(env('HOME')), compat_getenv('HOME')) - self.assertEqual(expand_path('~'), compat_getenv('HOME')) - self.assertEqual( - expand_path('~/%s' % env('hypervideo_dl_EXPATH_PATH')), - '%s/expanded' % compat_getenv('HOME')) + + old_home = os.environ.get('HOME') + test_str = R'C:\Documents and Settings\тест\Application Data' + try: + os.environ['HOME'] = test_str + self.assertEqual(expand_path(env('HOME')), os.getenv('HOME')) + self.assertEqual(expand_path('~'), os.getenv('HOME')) + self.assertEqual( + expand_path('~/%s' % env('hypervideo_dl_EXPATH_PATH')), + '%s/expanded' % os.getenv('HOME')) + finally: + os.environ['HOME'] = old_home or '' def test_prepend_extension(self): self.assertEqual(prepend_extension('abc.ext', 'temp'), 'abc.temp.ext') @@ -364,6 +371,7 @@ class TestUtil(unittest.TestCase): self.assertEqual(unified_strdate('2012/10/11 01:56:38 +0000'), '20121011') self.assertEqual(unified_strdate('1968 12 10'), '19681210') self.assertEqual(unified_strdate('1968-12-10'), '19681210') + self.assertEqual(unified_strdate('31-07-2022 20:00'), '20220731') self.assertEqual(unified_strdate('28/01/2014 21:00:00 +0100'), '20140128') self.assertEqual( unified_strdate('11/26/2014 11:30:00 AM PST', day_first=False), @@ -407,6 +415,10 @@ class TestUtil(unittest.TestCase): self.assertEqual(unified_timestamp('December 15, 2017 at 7:49 am'), 1513324140) self.assertEqual(unified_timestamp('2018-03-14T08:32:43.1493874+00:00'), 1521016363) + self.assertEqual(unified_timestamp('December 31 1969 20:00:01 EDT'), 1) + self.assertEqual(unified_timestamp('Wednesday 31 December 1969 18:01:26 MDT'), 86) + self.assertEqual(unified_timestamp('12/31/1969 20:01:18 EDT', False), 78) + def test_determine_ext(self): self.assertEqual(determine_ext('http://example.com/foo/bar.mp4/?download'), 'mp4') self.assertEqual(determine_ext('http://example.com/foo/bar/?download', None), None) @@ -537,9 +549,6 @@ class TestUtil(unittest.TestCase): self.assertEqual(str_to_int('123,456'), 123456) self.assertEqual(str_to_int('123.456'), 123456) self.assertEqual(str_to_int(523), 523) - # Python 3 has no long - if sys.version_info < (3, 0): - eval('self.assertEqual(str_to_int(123456L), 123456)') self.assertEqual(str_to_int('noninteger'), None) self.assertEqual(str_to_int([]), None) @@ -559,6 +568,7 @@ class TestUtil(unittest.TestCase): self.assertEqual(base_url('http://foo.de/bar/'), 'http://foo.de/bar/') self.assertEqual(base_url('http://foo.de/bar/baz'), 'http://foo.de/bar/') self.assertEqual(base_url('http://foo.de/bar/baz?x=z/x/c'), 'http://foo.de/bar/') + self.assertEqual(base_url('http://foo.de/bar/baz&x=z&w=y/x/c'), 'http://foo.de/bar/baz&x=z&w=y/x/') def test_urljoin(self): self.assertEqual(urljoin('http://foo.de/', '/a/b/c.txt'), 'http://foo.de/a/b/c.txt') @@ -668,8 +678,7 @@ class TestUtil(unittest.TestCase): def get_page(pagenum): firstid = pagenum * pagesize upto = min(size, pagenum * pagesize + pagesize) - for i in range(firstid, upto): - yield i + yield from range(firstid, upto) pl = OnDemandPagedList(get_page, pagesize) got = pl.getslice(*sliceargs) @@ -738,7 +747,7 @@ class TestUtil(unittest.TestCase): multipart_encode({b'field': b'value'}, boundary='AAAAAA')[0], b'--AAAAAA\r\nContent-Disposition: form-data; name="field"\r\n\r\nvalue\r\n--AAAAAA--\r\n') self.assertEqual( - multipart_encode({'欄位'.encode('utf-8'): '值'.encode('utf-8')}, boundary='AAAAAA')[0], + multipart_encode({'欄位'.encode(): '值'.encode()}, boundary='AAAAAA')[0], b'--AAAAAA\r\nContent-Disposition: form-data; name="\xe6\xac\x84\xe4\xbd\x8d"\r\n\r\n\xe5\x80\xbc\r\n--AAAAAA--\r\n') self.assertRaises( ValueError, multipart_encode, {b'field': b'value'}, boundary='value') @@ -896,7 +905,7 @@ class TestUtil(unittest.TestCase): 'dynamic_range': 'HDR10', }) self.assertEqual(parse_codecs('av01.0.12M.10.0.110.09.16.09.0'), { - 'vcodec': 'av01.0.12M.10', + 'vcodec': 'av01.0.12M.10.0.110.09.16.09.0', 'acodec': 'none', 'dynamic_range': 'HDR10', }) @@ -1091,6 +1100,12 @@ class TestUtil(unittest.TestCase): on = js_to_json('[1,//{},\n2]') self.assertEqual(json.loads(on), [1, 2]) + on = js_to_json(R'"\^\$\#"') + self.assertEqual(json.loads(on), R'^$#', msg='Unnecessary escapes should be stripped') + + on = js_to_json('\'"\\""\'') + self.assertEqual(json.loads(on), '"""', msg='Unnecessary quote escape should be escaped') + def test_js_to_json_malformed(self): self.assertEqual(js_to_json('42a1'), '42"a1"') self.assertEqual(js_to_json('42a-1'), '42"a"-1') @@ -1126,7 +1141,7 @@ class TestUtil(unittest.TestCase): self.assertEqual(extract_attributes('<e x="décomposé">'), {'x': 'décompose\u0301'}) # "Narrow" Python builds don't support unicode code points outside BMP. try: - compat_chr(0x10000) + chr(0x10000) supports_outside_bmp = True except ValueError: supports_outside_bmp = False @@ -1399,7 +1414,7 @@ ffmpeg version 2.4.4 Copyright (c) 2000-2014 the FFmpeg ...'''), '2.4.4') <p begin="3" dur="-1">Ignored, three</p> </div> </body> - </tt>'''.encode('utf-8') + </tt>'''.encode() srt_data = '''1 00:00:00,000 --> 00:00:01,000 The following line contains Chinese characters and special symbols @@ -1417,14 +1432,14 @@ Line ''' self.assertEqual(dfxp2srt(dfxp_data), srt_data) - dfxp_data_no_default_namespace = '''<?xml version="1.0" encoding="UTF-8"?> + dfxp_data_no_default_namespace = b'''<?xml version="1.0" encoding="UTF-8"?> <tt xml:lang="en" xmlns:tts="http://www.w3.org/ns/ttml#parameter"> <body> <div xml:lang="en"> <p begin="0" end="1">The first line</p> </div> </body> - </tt>'''.encode('utf-8') + </tt>''' srt_data = '''1 00:00:00,000 --> 00:00:01,000 The first line @@ -1432,7 +1447,7 @@ The first line ''' self.assertEqual(dfxp2srt(dfxp_data_no_default_namespace), srt_data) - dfxp_data_with_style = '''<?xml version="1.0" encoding="utf-8"?> + dfxp_data_with_style = b'''<?xml version="1.0" encoding="utf-8"?> <tt xmlns="http://www.w3.org/2006/10/ttaf1" xmlns:ttp="http://www.w3.org/2006/10/ttaf1#parameter" ttp:timeBase="media" xmlns:tts="http://www.w3.org/2006/10/ttaf1#style" xml:lang="en" xmlns:ttm="http://www.w3.org/2006/10/ttaf1#metadata"> <head> <styling> @@ -1450,7 +1465,7 @@ The first line <p style="s1" tts:textDecoration="underline" begin="00:00:09.56" id="p2" end="00:00:12.36"><span style="s2" tts:color="lime">inner<br /> </span>style</p> </div> </body> -</tt>'''.encode('utf-8') +</tt>''' srt_data = '''1 00:00:02,080 --> 00:00:05,840 <font color="white" face="sansSerif" size="16">default style<font color="red">custom style</font></font> @@ -1670,6 +1685,9 @@ Line 1 self.assertEqual(list(get_elements_text_and_html_by_attribute('class', 'foo', html)), []) self.assertEqual(list(get_elements_text_and_html_by_attribute('class', 'no-such-foo', html)), []) + self.assertEqual(list(get_elements_text_and_html_by_attribute( + 'class', 'foo', '<a class="foo">nice</a><span class="foo">nice</span>', tag='a')), [('nice', '<a class="foo">nice</a>')]) + GET_ELEMENT_BY_TAG_TEST_STRING = ''' random text lorem ipsum</p> <div> @@ -1757,7 +1775,7 @@ Line 1 def test(ll, idx, val, cache): self.assertEqual(ll[idx], val) - self.assertEqual(getattr(ll, '_LazyList__cache'), list(cache)) + self.assertEqual(ll._cache, list(cache)) ll = LazyList(range(10)) test(ll, 0, 0, range(1)) @@ -1795,6 +1813,302 @@ Line 1 self.assertEqual(Config.hide_login_info(['--username=foo']), ['--username=PRIVATE']) + def test_locked_file(self): + TEXT = 'test_locked_file\n' + FILE = 'test_locked_file.ytdl' + MODES = 'war' # Order is important + + try: + for lock_mode in MODES: + with locked_file(FILE, lock_mode, False) as f: + if lock_mode == 'r': + self.assertEqual(f.read(), TEXT * 2, 'Wrong file content') + else: + f.write(TEXT) + for test_mode in MODES: + testing_write = test_mode != 'r' + try: + with locked_file(FILE, test_mode, False): + pass + except (BlockingIOError, PermissionError): + if not testing_write: # FIXME + print(f'Known issue: Exclusive lock ({lock_mode}) blocks read access ({test_mode})') + continue + self.assertTrue(testing_write, f'{test_mode} is blocked by {lock_mode}') + else: + self.assertFalse(testing_write, f'{test_mode} is not blocked by {lock_mode}') + finally: + with contextlib.suppress(OSError): + os.remove(FILE) + + def test_determine_file_encoding(self): + self.assertEqual(determine_file_encoding(b''), (None, 0)) + self.assertEqual(determine_file_encoding(b'--verbose -x --audio-format mkv\n'), (None, 0)) + + self.assertEqual(determine_file_encoding(b'\xef\xbb\xbf'), ('utf-8', 3)) + self.assertEqual(determine_file_encoding(b'\x00\x00\xfe\xff'), ('utf-32-be', 4)) + self.assertEqual(determine_file_encoding(b'\xff\xfe'), ('utf-16-le', 2)) + + self.assertEqual(determine_file_encoding(b'\xff\xfe# coding: utf-8\n--verbose'), ('utf-16-le', 2)) + + self.assertEqual(determine_file_encoding(b'# coding: utf-8\n--verbose'), ('utf-8', 0)) + self.assertEqual(determine_file_encoding(b'# coding: someencodinghere-12345\n--verbose'), ('someencodinghere-12345', 0)) + + self.assertEqual(determine_file_encoding(b'#coding:utf-8\n--verbose'), ('utf-8', 0)) + self.assertEqual(determine_file_encoding(b'# coding: utf-8 \r\n--verbose'), ('utf-8', 0)) + + self.assertEqual(determine_file_encoding('# coding: utf-32-be'.encode('utf-32-be')), ('utf-32-be', 0)) + self.assertEqual(determine_file_encoding('# coding: utf-16-le'.encode('utf-16-le')), ('utf-16-le', 0)) + + def test_get_compatible_ext(self): + self.assertEqual(get_compatible_ext( + vcodecs=[None], acodecs=[None, None], vexts=['mp4'], aexts=['m4a', 'm4a']), 'mkv') + self.assertEqual(get_compatible_ext( + vcodecs=[None], acodecs=[None], vexts=['flv'], aexts=['flv']), 'flv') + + self.assertEqual(get_compatible_ext( + vcodecs=[None], acodecs=[None], vexts=['mp4'], aexts=['m4a']), 'mp4') + self.assertEqual(get_compatible_ext( + vcodecs=[None], acodecs=[None], vexts=['mp4'], aexts=['webm']), 'mkv') + self.assertEqual(get_compatible_ext( + vcodecs=[None], acodecs=[None], vexts=['webm'], aexts=['m4a']), 'mkv') + self.assertEqual(get_compatible_ext( + vcodecs=[None], acodecs=[None], vexts=['webm'], aexts=['webm']), 'webm') + + self.assertEqual(get_compatible_ext( + vcodecs=['h264'], acodecs=['mp4a'], vexts=['mov'], aexts=['m4a']), 'mp4') + self.assertEqual(get_compatible_ext( + vcodecs=['av01.0.12M.08'], acodecs=['opus'], vexts=['mp4'], aexts=['webm']), 'webm') + + self.assertEqual(get_compatible_ext( + vcodecs=['vp9'], acodecs=['opus'], vexts=['webm'], aexts=['webm'], preferences=['flv', 'mp4']), 'mp4') + self.assertEqual(get_compatible_ext( + vcodecs=['av1'], acodecs=['mp4a'], vexts=['webm'], aexts=['m4a'], preferences=('webm', 'mkv')), 'mkv') + + def test_traverse_obj(self): + _TEST_DATA = { + 100: 100, + 1.2: 1.2, + 'str': 'str', + 'None': None, + '...': ..., + 'urls': [ + {'index': 0, 'url': 'https://www.example.com/0'}, + {'index': 1, 'url': 'https://www.example.com/1'}, + ], + 'data': ( + {'index': 2}, + {'index': 3}, + ), + 'dict': {}, + } + + # Test base functionality + self.assertEqual(traverse_obj(_TEST_DATA, ('str',)), 'str', + msg='allow tuple path') + self.assertEqual(traverse_obj(_TEST_DATA, ['str']), 'str', + msg='allow list path') + self.assertEqual(traverse_obj(_TEST_DATA, (value for value in ("str",))), 'str', + msg='allow iterable path') + self.assertEqual(traverse_obj(_TEST_DATA, 'str'), 'str', + msg='single items should be treated as a path') + self.assertEqual(traverse_obj(_TEST_DATA, None), _TEST_DATA) + self.assertEqual(traverse_obj(_TEST_DATA, 100), 100) + self.assertEqual(traverse_obj(_TEST_DATA, 1.2), 1.2) + + # Test Ellipsis behavior + self.assertCountEqual(traverse_obj(_TEST_DATA, ...), + (item for item in _TEST_DATA.values() if item is not None), + msg='`...` should give all values except `None`') + self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', 0, ...)), _TEST_DATA['urls'][0].values(), + msg='`...` selection for dicts should select all values') + self.assertEqual(traverse_obj(_TEST_DATA, (..., ..., 'url')), + ['https://www.example.com/0', 'https://www.example.com/1'], + msg='nested `...` queries should work') + self.assertCountEqual(traverse_obj(_TEST_DATA, (..., ..., 'index')), range(4), + msg='`...` query result should be flattened') + + # Test function as key + self.assertEqual(traverse_obj(_TEST_DATA, lambda x, y: x == 'urls' and isinstance(y, list)), + [_TEST_DATA['urls']], + msg='function as query key should perform a filter based on (key, value)') + self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], str)), {'str'}, + msg='exceptions in the query function should be catched') + + # Test alternative paths + self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'str'), 'str', + msg='multiple `paths` should be treated as alternative paths') + self.assertEqual(traverse_obj(_TEST_DATA, 'str', 100), 'str', + msg='alternatives should exit early') + self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'fail'), None, + msg='alternatives should return `default` if exhausted') + self.assertEqual(traverse_obj(_TEST_DATA, (..., 'fail'), 100), 100, + msg='alternatives should track their own branching return') + self.assertEqual(traverse_obj(_TEST_DATA, ('dict', ...), ('data', ...)), list(_TEST_DATA['data']), + msg='alternatives on empty objects should search further') + + # Test branch and path nesting + self.assertEqual(traverse_obj(_TEST_DATA, ('urls', (3, 0), 'url')), ['https://www.example.com/0'], + msg='tuple as key should be treated as branches') + self.assertEqual(traverse_obj(_TEST_DATA, ('urls', [3, 0], 'url')), ['https://www.example.com/0'], + msg='list as key should be treated as branches') + self.assertEqual(traverse_obj(_TEST_DATA, ('urls', ((1, 'fail'), (0, 'url')))), ['https://www.example.com/0'], + msg='double nesting in path should be treated as paths') + self.assertEqual(traverse_obj(['0', [1, 2]], [(0, 1), 0]), [1], + msg='do not fail early on branching') + self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', ((1, ('fail', 'url')), (0, 'url')))), + ['https://www.example.com/0', 'https://www.example.com/1'], + msg='tripple nesting in path should be treated as branches') + self.assertEqual(traverse_obj(_TEST_DATA, ('urls', ('fail', (..., 'url')))), + ['https://www.example.com/0', 'https://www.example.com/1'], + msg='ellipsis as branch path start gets flattened') + + # Test dictionary as key + self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2}), {0: 100, 1: 1.2}, + msg='dict key should result in a dict with the same keys') + self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', 0, 'url')}), + {0: 'https://www.example.com/0'}, + msg='dict key should allow paths') + self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', (3, 0), 'url')}), + {0: ['https://www.example.com/0']}, + msg='tuple in dict path should be treated as branches') + self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', ((1, 'fail'), (0, 'url')))}), + {0: ['https://www.example.com/0']}, + msg='double nesting in dict path should be treated as paths') + self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', ((1, ('fail', 'url')), (0, 'url')))}), + {0: ['https://www.example.com/1', 'https://www.example.com/0']}, + msg='tripple nesting in dict path should be treated as branches') + self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}), {}, + msg='remove `None` values when dict key') + self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}, default=...), {0: ...}, + msg='do not remove `None` values if `default`') + self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}), {0: {}}, + msg='do not remove empty values when dict key') + self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}, default=...), {0: {}}, + msg='do not remove empty values when dict key and a default') + self.assertEqual(traverse_obj(_TEST_DATA, {0: ('dict', ...)}), {0: []}, + msg='if branch in dict key not successful, return `[]`') + + # Testing default parameter behavior + _DEFAULT_DATA = {'None': None, 'int': 0, 'list': []} + self.assertEqual(traverse_obj(_DEFAULT_DATA, 'fail'), None, + msg='default value should be `None`') + self.assertEqual(traverse_obj(_DEFAULT_DATA, 'fail', 'fail', default=...), ..., + msg='chained fails should result in default') + self.assertEqual(traverse_obj(_DEFAULT_DATA, 'None', 'int'), 0, + msg='should not short cirquit on `None`') + self.assertEqual(traverse_obj(_DEFAULT_DATA, 'fail', default=1), 1, + msg='invalid dict key should result in `default`') + self.assertEqual(traverse_obj(_DEFAULT_DATA, 'None', default=1), 1, + msg='`None` is a deliberate sentinel and should become `default`') + self.assertEqual(traverse_obj(_DEFAULT_DATA, ('list', 10)), None, + msg='`IndexError` should result in `default`') + self.assertEqual(traverse_obj(_DEFAULT_DATA, (..., 'fail'), default=1), 1, + msg='if branched but not successful return `default` if defined, not `[]`') + self.assertEqual(traverse_obj(_DEFAULT_DATA, (..., 'fail'), default=None), None, + msg='if branched but not successful return `default` even if `default` is `None`') + self.assertEqual(traverse_obj(_DEFAULT_DATA, (..., 'fail')), [], + msg='if branched but not successful return `[]`, not `default`') + self.assertEqual(traverse_obj(_DEFAULT_DATA, ('list', ...)), [], + msg='if branched but object is empty return `[]`, not `default`') + + # Testing expected_type behavior + _EXPECTED_TYPE_DATA = {'str': 'str', 'int': 0} + self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=str), 'str', + msg='accept matching `expected_type` type') + self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=int), None, + msg='reject non matching `expected_type` type') + self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'int', expected_type=lambda x: str(x)), '0', + msg='transform type using type function') + self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', + expected_type=lambda _: 1 / 0), None, + msg='wrap expected_type fuction in try_call') + self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, ..., expected_type=str), ['str'], + msg='eliminate items that expected_type fails on') + + # Test get_all behavior + _GET_ALL_DATA = {'key': [0, 1, 2]} + self.assertEqual(traverse_obj(_GET_ALL_DATA, ('key', ...), get_all=False), 0, + msg='if not `get_all`, return only first matching value') + self.assertEqual(traverse_obj(_GET_ALL_DATA, ..., get_all=False), [0, 1, 2], + msg='do not overflatten if not `get_all`') + + # Test casesense behavior + _CASESENSE_DATA = { + 'KeY': 'value0', + 0: { + 'KeY': 'value1', + 0: {'KeY': 'value2'}, + }, + } + self.assertEqual(traverse_obj(_CASESENSE_DATA, 'key'), None, + msg='dict keys should be case sensitive unless `casesense`') + self.assertEqual(traverse_obj(_CASESENSE_DATA, 'keY', + casesense=False), 'value0', + msg='allow non matching key case if `casesense`') + self.assertEqual(traverse_obj(_CASESENSE_DATA, (0, ('keY',)), + casesense=False), ['value1'], + msg='allow non matching key case in branch if `casesense`') + self.assertEqual(traverse_obj(_CASESENSE_DATA, (0, ((0, 'keY'),)), + casesense=False), ['value2'], + msg='allow non matching key case in branch path if `casesense`') + + # Test traverse_string behavior + _TRAVERSE_STRING_DATA = {'str': 'str', 1.2: 1.2} + self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', 0)), None, + msg='do not traverse into string if not `traverse_string`') + self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', 0), + traverse_string=True), 's', + msg='traverse into string if `traverse_string`') + self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, (1.2, 1), + traverse_string=True), '.', + msg='traverse into converted data if `traverse_string`') + self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', ...), + traverse_string=True), list('str'), + msg='`...` branching into string should result in list') + self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', (0, 2)), + traverse_string=True), ['s', 'r'], + msg='branching into string should result in list') + self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', lambda _, x: x), + traverse_string=True), list('str'), + msg='function branching into string should result in list') + + # Test is_user_input behavior + _IS_USER_INPUT_DATA = {'range8': list(range(8))} + self.assertEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', '3'), + is_user_input=True), 3, + msg='allow for string indexing if `is_user_input`') + self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', '3:'), + is_user_input=True), tuple(range(8))[3:], + msg='allow for string slice if `is_user_input`') + self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':4:2'), + is_user_input=True), tuple(range(8))[:4:2], + msg='allow step in string slice if `is_user_input`') + self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':'), + is_user_input=True), range(8), + msg='`:` should be treated as `...` if `is_user_input`') + with self.assertRaises(TypeError, msg='too many params should result in error'): + traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':::'), is_user_input=True) + + # Test re.Match as input obj + mobj = re.fullmatch(r'0(12)(?P<group>3)(4)?', '0123') + self.assertEqual(traverse_obj(mobj, ...), [x for x in mobj.groups() if x is not None], + msg='`...` on a `re.Match` should give its `groups()`') + self.assertEqual(traverse_obj(mobj, lambda k, _: k in (0, 2)), ['0123', '3'], + msg='function on a `re.Match` should give groupno, value starting at 0') + self.assertEqual(traverse_obj(mobj, 'group'), '3', + msg='str key on a `re.Match` should give group with that name') + self.assertEqual(traverse_obj(mobj, 2), '3', + msg='int key on a `re.Match` should give group with that name') + self.assertEqual(traverse_obj(mobj, 'gRoUp', casesense=False), '3', + msg='str key on a `re.Match` should respect casesense') + self.assertEqual(traverse_obj(mobj, 'fail'), None, + msg='failing str key on a `re.Match` should return `default`') + self.assertEqual(traverse_obj(mobj, 'gRoUpS', casesense=False), None, + msg='failing str key on a `re.Match` should return `default`') + self.assertEqual(traverse_obj(mobj, 8), None, + msg='failing int key on a `re.Match` should return `default`') + if __name__ == '__main__': unittest.main() |