aboutsummaryrefslogtreecommitdiffstats
path: root/test/test_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_utils.py')
-rw-r--r--test/test_utils.py468
1 files changed, 391 insertions, 77 deletions
diff --git a/test/test_utils.py b/test/test_utils.py
index 039900c..acb913a 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -1,89 +1,108 @@
#!/usr/bin/env python3
-# coding: utf-8
-
-from __future__ import unicode_literals
# Allow direct execution
import os
+import re
import sys
import unittest
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-# Various small unit tests
+import contextlib
import io
import itertools
import json
import xml.etree.ElementTree
+from hypervideo_dl.compat import (
+ compat_etree_fromstring,
+ compat_HTMLParseError,
+ compat_os_name,
+)
from hypervideo_dl.utils import (
+ Config,
+ DateRange,
+ ExtractorError,
+ InAdvancePagedList,
+ LazyList,
+ OnDemandPagedList,
age_restricted,
args_to_str,
- encode_base_n,
+ base_url,
caesar,
clean_html,
clean_podcast_url,
- Config,
+ cli_bool_option,
+ cli_option,
+ cli_valueless_option,
date_from_str,
datetime_from_str,
- DateRange,
detect_exe_version,
determine_ext,
+ determine_file_encoding,
+ dfxp2srt,
dict_get,
+ encode_base_n,
encode_compat_str,
encodeFilename,
escape_rfc3986,
escape_url,
+ expand_path,
extract_attributes,
- ExtractorError,
find_xpath_attr,
fix_xml_ampersands,
- format_bytes,
float_or_none,
- get_element_by_class,
+ format_bytes,
+ get_compatible_ext,
get_element_by_attribute,
- get_elements_by_class,
- get_elements_by_attribute,
- get_element_html_by_class,
+ get_element_by_class,
get_element_html_by_attribute,
- get_elements_html_by_class,
+ get_element_html_by_class,
+ get_element_text_and_html_by_tag,
+ get_elements_by_attribute,
+ get_elements_by_class,
get_elements_html_by_attribute,
+ get_elements_html_by_class,
get_elements_text_and_html_by_attribute,
- get_element_text_and_html_by_tag,
- InAdvancePagedList,
int_or_none,
intlist_to_bytes,
+ iri_to_uri,
is_html,
js_to_json,
limit_length,
+ locked_file,
+ lowercase_escape,
+ match_str,
merge_dicts,
mimetype2ext,
month_by_name,
multipart_encode,
ohdave_rsa_encrypt,
- OnDemandPagedList,
orderedSet,
parse_age_limit,
+ parse_bitrate,
+ parse_codecs,
+ parse_count,
+ parse_dfxp_time_expr,
parse_duration,
parse_filesize,
- parse_count,
parse_iso8601,
- parse_resolution,
- parse_bitrate,
parse_qs,
+ parse_resolution,
pkcs1pad,
+ prepend_extension,
read_batch_urls,
+ remove_end,
+ remove_quotes,
+ remove_start,
+ render_table,
+ replace_extension,
+ rot47,
sanitize_filename,
sanitize_path,
sanitize_url,
sanitized_Request,
- expand_path,
- prepend_extension,
- replace_extension,
- remove_start,
- remove_end,
- remove_quotes,
- rot47,
shell_quote,
smuggle_url,
str_to_int,
@@ -91,42 +110,23 @@ from hypervideo_dl.utils import (
strip_or_none,
subtitles_filename,
timeconvert,
+ traverse_obj,
unescapeHTML,
unified_strdate,
unified_timestamp,
unsmuggle_url,
+ update_url_query,
uppercase_escape,
- lowercase_escape,
url_basename,
url_or_none,
- base_url,
- urljoin,
urlencode_postdata,
+ urljoin,
urshift,
- update_url_query,
version_tuple,
- xpath_with_ns,
+ xpath_attr,
xpath_element,
xpath_text,
- xpath_attr,
- render_table,
- match_str,
- parse_dfxp_time_expr,
- dfxp2srt,
- cli_option,
- cli_valueless_option,
- cli_bool_option,
- parse_codecs,
- iri_to_uri,
- LazyList,
-)
-from hypervideo_dl.compat import (
- compat_chr,
- compat_etree_fromstring,
- compat_getenv,
- compat_HTMLParseError,
- compat_os_name,
- compat_setenv,
+ xpath_with_ns,
)
@@ -142,13 +142,13 @@ class TestUtil(unittest.TestCase):
self.assertEqual(sanitize_filename('123'), '123')
- self.assertEqual('abc_de', sanitize_filename('abc/de'))
+ self.assertEqual('abc⧸de', sanitize_filename('abc/de'))
self.assertFalse('/' in sanitize_filename('abc/de///'))
- self.assertEqual('abc_de', sanitize_filename('abc/<>\\*|de'))
- self.assertEqual('xxx', sanitize_filename('xxx/<>\\*|'))
- self.assertEqual('yes no', sanitize_filename('yes? no'))
- self.assertEqual('this - that', sanitize_filename('this: that'))
+ self.assertEqual('abc_de', sanitize_filename('abc/<>\\*|de', is_id=False))
+ self.assertEqual('xxx', sanitize_filename('xxx/<>\\*|', is_id=False))
+ self.assertEqual('yes no', sanitize_filename('yes? no', is_id=False))
+ self.assertEqual('this - that', sanitize_filename('this: that', is_id=False))
self.assertEqual(sanitize_filename('AT&T'), 'AT&T')
aumlaut = 'ä'
@@ -265,15 +265,22 @@ class TestUtil(unittest.TestCase):
def test_expand_path(self):
def env(var):
- return '%{0}%'.format(var) if sys.platform == 'win32' else '${0}'.format(var)
+ return f'%{var}%' if sys.platform == 'win32' else f'${var}'
- compat_setenv('hypervideo_dl_EXPATH_PATH', 'expanded')
+ os.environ['hypervideo_dl_EXPATH_PATH'] = 'expanded'
self.assertEqual(expand_path(env('hypervideo_dl_EXPATH_PATH')), 'expanded')
- self.assertEqual(expand_path(env('HOME')), compat_getenv('HOME'))
- self.assertEqual(expand_path('~'), compat_getenv('HOME'))
- self.assertEqual(
- expand_path('~/%s' % env('hypervideo_dl_EXPATH_PATH')),
- '%s/expanded' % compat_getenv('HOME'))
+
+ old_home = os.environ.get('HOME')
+ test_str = R'C:\Documents and Settings\тест\Application Data'
+ try:
+ os.environ['HOME'] = test_str
+ self.assertEqual(expand_path(env('HOME')), os.getenv('HOME'))
+ self.assertEqual(expand_path('~'), os.getenv('HOME'))
+ self.assertEqual(
+ expand_path('~/%s' % env('hypervideo_dl_EXPATH_PATH')),
+ '%s/expanded' % os.getenv('HOME'))
+ finally:
+ os.environ['HOME'] = old_home or ''
def test_prepend_extension(self):
self.assertEqual(prepend_extension('abc.ext', 'temp'), 'abc.temp.ext')
@@ -364,6 +371,7 @@ class TestUtil(unittest.TestCase):
self.assertEqual(unified_strdate('2012/10/11 01:56:38 +0000'), '20121011')
self.assertEqual(unified_strdate('1968 12 10'), '19681210')
self.assertEqual(unified_strdate('1968-12-10'), '19681210')
+ self.assertEqual(unified_strdate('31-07-2022 20:00'), '20220731')
self.assertEqual(unified_strdate('28/01/2014 21:00:00 +0100'), '20140128')
self.assertEqual(
unified_strdate('11/26/2014 11:30:00 AM PST', day_first=False),
@@ -407,6 +415,10 @@ class TestUtil(unittest.TestCase):
self.assertEqual(unified_timestamp('December 15, 2017 at 7:49 am'), 1513324140)
self.assertEqual(unified_timestamp('2018-03-14T08:32:43.1493874+00:00'), 1521016363)
+ self.assertEqual(unified_timestamp('December 31 1969 20:00:01 EDT'), 1)
+ self.assertEqual(unified_timestamp('Wednesday 31 December 1969 18:01:26 MDT'), 86)
+ self.assertEqual(unified_timestamp('12/31/1969 20:01:18 EDT', False), 78)
+
def test_determine_ext(self):
self.assertEqual(determine_ext('http://example.com/foo/bar.mp4/?download'), 'mp4')
self.assertEqual(determine_ext('http://example.com/foo/bar/?download', None), None)
@@ -537,9 +549,6 @@ class TestUtil(unittest.TestCase):
self.assertEqual(str_to_int('123,456'), 123456)
self.assertEqual(str_to_int('123.456'), 123456)
self.assertEqual(str_to_int(523), 523)
- # Python 3 has no long
- if sys.version_info < (3, 0):
- eval('self.assertEqual(str_to_int(123456L), 123456)')
self.assertEqual(str_to_int('noninteger'), None)
self.assertEqual(str_to_int([]), None)
@@ -559,6 +568,7 @@ class TestUtil(unittest.TestCase):
self.assertEqual(base_url('http://foo.de/bar/'), 'http://foo.de/bar/')
self.assertEqual(base_url('http://foo.de/bar/baz'), 'http://foo.de/bar/')
self.assertEqual(base_url('http://foo.de/bar/baz?x=z/x/c'), 'http://foo.de/bar/')
+ self.assertEqual(base_url('http://foo.de/bar/baz&x=z&w=y/x/c'), 'http://foo.de/bar/baz&x=z&w=y/x/')
def test_urljoin(self):
self.assertEqual(urljoin('http://foo.de/', '/a/b/c.txt'), 'http://foo.de/a/b/c.txt')
@@ -668,8 +678,7 @@ class TestUtil(unittest.TestCase):
def get_page(pagenum):
firstid = pagenum * pagesize
upto = min(size, pagenum * pagesize + pagesize)
- for i in range(firstid, upto):
- yield i
+ yield from range(firstid, upto)
pl = OnDemandPagedList(get_page, pagesize)
got = pl.getslice(*sliceargs)
@@ -738,7 +747,7 @@ class TestUtil(unittest.TestCase):
multipart_encode({b'field': b'value'}, boundary='AAAAAA')[0],
b'--AAAAAA\r\nContent-Disposition: form-data; name="field"\r\n\r\nvalue\r\n--AAAAAA--\r\n')
self.assertEqual(
- multipart_encode({'欄位'.encode('utf-8'): '值'.encode('utf-8')}, boundary='AAAAAA')[0],
+ multipart_encode({'欄位'.encode(): '值'.encode()}, boundary='AAAAAA')[0],
b'--AAAAAA\r\nContent-Disposition: form-data; name="\xe6\xac\x84\xe4\xbd\x8d"\r\n\r\n\xe5\x80\xbc\r\n--AAAAAA--\r\n')
self.assertRaises(
ValueError, multipart_encode, {b'field': b'value'}, boundary='value')
@@ -896,7 +905,7 @@ class TestUtil(unittest.TestCase):
'dynamic_range': 'HDR10',
})
self.assertEqual(parse_codecs('av01.0.12M.10.0.110.09.16.09.0'), {
- 'vcodec': 'av01.0.12M.10',
+ 'vcodec': 'av01.0.12M.10.0.110.09.16.09.0',
'acodec': 'none',
'dynamic_range': 'HDR10',
})
@@ -1091,6 +1100,12 @@ class TestUtil(unittest.TestCase):
on = js_to_json('[1,//{},\n2]')
self.assertEqual(json.loads(on), [1, 2])
+ on = js_to_json(R'"\^\$\#"')
+ self.assertEqual(json.loads(on), R'^$#', msg='Unnecessary escapes should be stripped')
+
+ on = js_to_json('\'"\\""\'')
+ self.assertEqual(json.loads(on), '"""', msg='Unnecessary quote escape should be escaped')
+
def test_js_to_json_malformed(self):
self.assertEqual(js_to_json('42a1'), '42"a1"')
self.assertEqual(js_to_json('42a-1'), '42"a"-1')
@@ -1126,7 +1141,7 @@ class TestUtil(unittest.TestCase):
self.assertEqual(extract_attributes('<e x="décompose&#769;">'), {'x': 'décompose\u0301'})
# "Narrow" Python builds don't support unicode code points outside BMP.
try:
- compat_chr(0x10000)
+ chr(0x10000)
supports_outside_bmp = True
except ValueError:
supports_outside_bmp = False
@@ -1399,7 +1414,7 @@ ffmpeg version 2.4.4 Copyright (c) 2000-2014 the FFmpeg ...'''), '2.4.4')
<p begin="3" dur="-1">Ignored, three</p>
</div>
</body>
- </tt>'''.encode('utf-8')
+ </tt>'''.encode()
srt_data = '''1
00:00:00,000 --> 00:00:01,000
The following line contains Chinese characters and special symbols
@@ -1417,14 +1432,14 @@ Line
'''
self.assertEqual(dfxp2srt(dfxp_data), srt_data)
- dfxp_data_no_default_namespace = '''<?xml version="1.0" encoding="UTF-8"?>
+ dfxp_data_no_default_namespace = b'''<?xml version="1.0" encoding="UTF-8"?>
<tt xml:lang="en" xmlns:tts="http://www.w3.org/ns/ttml#parameter">
<body>
<div xml:lang="en">
<p begin="0" end="1">The first line</p>
</div>
</body>
- </tt>'''.encode('utf-8')
+ </tt>'''
srt_data = '''1
00:00:00,000 --> 00:00:01,000
The first line
@@ -1432,7 +1447,7 @@ The first line
'''
self.assertEqual(dfxp2srt(dfxp_data_no_default_namespace), srt_data)
- dfxp_data_with_style = '''<?xml version="1.0" encoding="utf-8"?>
+ dfxp_data_with_style = b'''<?xml version="1.0" encoding="utf-8"?>
<tt xmlns="http://www.w3.org/2006/10/ttaf1" xmlns:ttp="http://www.w3.org/2006/10/ttaf1#parameter" ttp:timeBase="media" xmlns:tts="http://www.w3.org/2006/10/ttaf1#style" xml:lang="en" xmlns:ttm="http://www.w3.org/2006/10/ttaf1#metadata">
<head>
<styling>
@@ -1450,7 +1465,7 @@ The first line
<p style="s1" tts:textDecoration="underline" begin="00:00:09.56" id="p2" end="00:00:12.36"><span style="s2" tts:color="lime">inner<br /> </span>style</p>
</div>
</body>
-</tt>'''.encode('utf-8')
+</tt>'''
srt_data = '''1
00:00:02,080 --> 00:00:05,840
<font color="white" face="sansSerif" size="16">default style<font color="red">custom style</font></font>
@@ -1670,6 +1685,9 @@ Line 1
self.assertEqual(list(get_elements_text_and_html_by_attribute('class', 'foo', html)), [])
self.assertEqual(list(get_elements_text_and_html_by_attribute('class', 'no-such-foo', html)), [])
+ self.assertEqual(list(get_elements_text_and_html_by_attribute(
+ 'class', 'foo', '<a class="foo">nice</a><span class="foo">nice</span>', tag='a')), [('nice', '<a class="foo">nice</a>')])
+
GET_ELEMENT_BY_TAG_TEST_STRING = '''
random text lorem ipsum</p>
<div>
@@ -1757,7 +1775,7 @@ Line 1
def test(ll, idx, val, cache):
self.assertEqual(ll[idx], val)
- self.assertEqual(getattr(ll, '_LazyList__cache'), list(cache))
+ self.assertEqual(ll._cache, list(cache))
ll = LazyList(range(10))
test(ll, 0, 0, range(1))
@@ -1795,6 +1813,302 @@ Line 1
self.assertEqual(Config.hide_login_info(['--username=foo']),
['--username=PRIVATE'])
+ def test_locked_file(self):
+ TEXT = 'test_locked_file\n'
+ FILE = 'test_locked_file.ytdl'
+ MODES = 'war' # Order is important
+
+ try:
+ for lock_mode in MODES:
+ with locked_file(FILE, lock_mode, False) as f:
+ if lock_mode == 'r':
+ self.assertEqual(f.read(), TEXT * 2, 'Wrong file content')
+ else:
+ f.write(TEXT)
+ for test_mode in MODES:
+ testing_write = test_mode != 'r'
+ try:
+ with locked_file(FILE, test_mode, False):
+ pass
+ except (BlockingIOError, PermissionError):
+ if not testing_write: # FIXME
+ print(f'Known issue: Exclusive lock ({lock_mode}) blocks read access ({test_mode})')
+ continue
+ self.assertTrue(testing_write, f'{test_mode} is blocked by {lock_mode}')
+ else:
+ self.assertFalse(testing_write, f'{test_mode} is not blocked by {lock_mode}')
+ finally:
+ with contextlib.suppress(OSError):
+ os.remove(FILE)
+
+ def test_determine_file_encoding(self):
+ self.assertEqual(determine_file_encoding(b''), (None, 0))
+ self.assertEqual(determine_file_encoding(b'--verbose -x --audio-format mkv\n'), (None, 0))
+
+ self.assertEqual(determine_file_encoding(b'\xef\xbb\xbf'), ('utf-8', 3))
+ self.assertEqual(determine_file_encoding(b'\x00\x00\xfe\xff'), ('utf-32-be', 4))
+ self.assertEqual(determine_file_encoding(b'\xff\xfe'), ('utf-16-le', 2))
+
+ self.assertEqual(determine_file_encoding(b'\xff\xfe# coding: utf-8\n--verbose'), ('utf-16-le', 2))
+
+ self.assertEqual(determine_file_encoding(b'# coding: utf-8\n--verbose'), ('utf-8', 0))
+ self.assertEqual(determine_file_encoding(b'# coding: someencodinghere-12345\n--verbose'), ('someencodinghere-12345', 0))
+
+ self.assertEqual(determine_file_encoding(b'#coding:utf-8\n--verbose'), ('utf-8', 0))
+ self.assertEqual(determine_file_encoding(b'# coding: utf-8 \r\n--verbose'), ('utf-8', 0))
+
+ self.assertEqual(determine_file_encoding('# coding: utf-32-be'.encode('utf-32-be')), ('utf-32-be', 0))
+ self.assertEqual(determine_file_encoding('# coding: utf-16-le'.encode('utf-16-le')), ('utf-16-le', 0))
+
+ def test_get_compatible_ext(self):
+ self.assertEqual(get_compatible_ext(
+ vcodecs=[None], acodecs=[None, None], vexts=['mp4'], aexts=['m4a', 'm4a']), 'mkv')
+ self.assertEqual(get_compatible_ext(
+ vcodecs=[None], acodecs=[None], vexts=['flv'], aexts=['flv']), 'flv')
+
+ self.assertEqual(get_compatible_ext(
+ vcodecs=[None], acodecs=[None], vexts=['mp4'], aexts=['m4a']), 'mp4')
+ self.assertEqual(get_compatible_ext(
+ vcodecs=[None], acodecs=[None], vexts=['mp4'], aexts=['webm']), 'mkv')
+ self.assertEqual(get_compatible_ext(
+ vcodecs=[None], acodecs=[None], vexts=['webm'], aexts=['m4a']), 'mkv')
+ self.assertEqual(get_compatible_ext(
+ vcodecs=[None], acodecs=[None], vexts=['webm'], aexts=['webm']), 'webm')
+
+ self.assertEqual(get_compatible_ext(
+ vcodecs=['h264'], acodecs=['mp4a'], vexts=['mov'], aexts=['m4a']), 'mp4')
+ self.assertEqual(get_compatible_ext(
+ vcodecs=['av01.0.12M.08'], acodecs=['opus'], vexts=['mp4'], aexts=['webm']), 'webm')
+
+ self.assertEqual(get_compatible_ext(
+ vcodecs=['vp9'], acodecs=['opus'], vexts=['webm'], aexts=['webm'], preferences=['flv', 'mp4']), 'mp4')
+ self.assertEqual(get_compatible_ext(
+ vcodecs=['av1'], acodecs=['mp4a'], vexts=['webm'], aexts=['m4a'], preferences=('webm', 'mkv')), 'mkv')
+
+ def test_traverse_obj(self):
+ _TEST_DATA = {
+ 100: 100,
+ 1.2: 1.2,
+ 'str': 'str',
+ 'None': None,
+ '...': ...,
+ 'urls': [
+ {'index': 0, 'url': 'https://www.example.com/0'},
+ {'index': 1, 'url': 'https://www.example.com/1'},
+ ],
+ 'data': (
+ {'index': 2},
+ {'index': 3},
+ ),
+ 'dict': {},
+ }
+
+ # Test base functionality
+ self.assertEqual(traverse_obj(_TEST_DATA, ('str',)), 'str',
+ msg='allow tuple path')
+ self.assertEqual(traverse_obj(_TEST_DATA, ['str']), 'str',
+ msg='allow list path')
+ self.assertEqual(traverse_obj(_TEST_DATA, (value for value in ("str",))), 'str',
+ msg='allow iterable path')
+ self.assertEqual(traverse_obj(_TEST_DATA, 'str'), 'str',
+ msg='single items should be treated as a path')
+ self.assertEqual(traverse_obj(_TEST_DATA, None), _TEST_DATA)
+ self.assertEqual(traverse_obj(_TEST_DATA, 100), 100)
+ self.assertEqual(traverse_obj(_TEST_DATA, 1.2), 1.2)
+
+ # Test Ellipsis behavior
+ self.assertCountEqual(traverse_obj(_TEST_DATA, ...),
+ (item for item in _TEST_DATA.values() if item is not None),
+ msg='`...` should give all values except `None`')
+ self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', 0, ...)), _TEST_DATA['urls'][0].values(),
+ msg='`...` selection for dicts should select all values')
+ self.assertEqual(traverse_obj(_TEST_DATA, (..., ..., 'url')),
+ ['https://www.example.com/0', 'https://www.example.com/1'],
+ msg='nested `...` queries should work')
+ self.assertCountEqual(traverse_obj(_TEST_DATA, (..., ..., 'index')), range(4),
+ msg='`...` query result should be flattened')
+
+ # Test function as key
+ self.assertEqual(traverse_obj(_TEST_DATA, lambda x, y: x == 'urls' and isinstance(y, list)),
+ [_TEST_DATA['urls']],
+ msg='function as query key should perform a filter based on (key, value)')
+ self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], str)), {'str'},
+ msg='exceptions in the query function should be catched')
+
+ # Test alternative paths
+ self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'str'), 'str',
+ msg='multiple `paths` should be treated as alternative paths')
+ self.assertEqual(traverse_obj(_TEST_DATA, 'str', 100), 'str',
+ msg='alternatives should exit early')
+ self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'fail'), None,
+ msg='alternatives should return `default` if exhausted')
+ self.assertEqual(traverse_obj(_TEST_DATA, (..., 'fail'), 100), 100,
+ msg='alternatives should track their own branching return')
+ self.assertEqual(traverse_obj(_TEST_DATA, ('dict', ...), ('data', ...)), list(_TEST_DATA['data']),
+ msg='alternatives on empty objects should search further')
+
+ # Test branch and path nesting
+ self.assertEqual(traverse_obj(_TEST_DATA, ('urls', (3, 0), 'url')), ['https://www.example.com/0'],
+ msg='tuple as key should be treated as branches')
+ self.assertEqual(traverse_obj(_TEST_DATA, ('urls', [3, 0], 'url')), ['https://www.example.com/0'],
+ msg='list as key should be treated as branches')
+ self.assertEqual(traverse_obj(_TEST_DATA, ('urls', ((1, 'fail'), (0, 'url')))), ['https://www.example.com/0'],
+ msg='double nesting in path should be treated as paths')
+ self.assertEqual(traverse_obj(['0', [1, 2]], [(0, 1), 0]), [1],
+ msg='do not fail early on branching')
+ self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', ((1, ('fail', 'url')), (0, 'url')))),
+ ['https://www.example.com/0', 'https://www.example.com/1'],
+ msg='tripple nesting in path should be treated as branches')
+ self.assertEqual(traverse_obj(_TEST_DATA, ('urls', ('fail', (..., 'url')))),
+ ['https://www.example.com/0', 'https://www.example.com/1'],
+ msg='ellipsis as branch path start gets flattened')
+
+ # Test dictionary as key
+ self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2}), {0: 100, 1: 1.2},
+ msg='dict key should result in a dict with the same keys')
+ self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', 0, 'url')}),
+ {0: 'https://www.example.com/0'},
+ msg='dict key should allow paths')
+ self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', (3, 0), 'url')}),
+ {0: ['https://www.example.com/0']},
+ msg='tuple in dict path should be treated as branches')
+ self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', ((1, 'fail'), (0, 'url')))}),
+ {0: ['https://www.example.com/0']},
+ msg='double nesting in dict path should be treated as paths')
+ self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', ((1, ('fail', 'url')), (0, 'url')))}),
+ {0: ['https://www.example.com/1', 'https://www.example.com/0']},
+ msg='tripple nesting in dict path should be treated as branches')
+ self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}), {},
+ msg='remove `None` values when dict key')
+ self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}, default=...), {0: ...},
+ msg='do not remove `None` values if `default`')
+ self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}), {0: {}},
+ msg='do not remove empty values when dict key')
+ self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}, default=...), {0: {}},
+ msg='do not remove empty values when dict key and a default')
+ self.assertEqual(traverse_obj(_TEST_DATA, {0: ('dict', ...)}), {0: []},
+ msg='if branch in dict key not successful, return `[]`')
+
+ # Testing default parameter behavior
+ _DEFAULT_DATA = {'None': None, 'int': 0, 'list': []}
+ self.assertEqual(traverse_obj(_DEFAULT_DATA, 'fail'), None,
+ msg='default value should be `None`')
+ self.assertEqual(traverse_obj(_DEFAULT_DATA, 'fail', 'fail', default=...), ...,
+ msg='chained fails should result in default')
+ self.assertEqual(traverse_obj(_DEFAULT_DATA, 'None', 'int'), 0,
+ msg='should not short cirquit on `None`')
+ self.assertEqual(traverse_obj(_DEFAULT_DATA, 'fail', default=1), 1,
+ msg='invalid dict key should result in `default`')
+ self.assertEqual(traverse_obj(_DEFAULT_DATA, 'None', default=1), 1,
+ msg='`None` is a deliberate sentinel and should become `default`')
+ self.assertEqual(traverse_obj(_DEFAULT_DATA, ('list', 10)), None,
+ msg='`IndexError` should result in `default`')
+ self.assertEqual(traverse_obj(_DEFAULT_DATA, (..., 'fail'), default=1), 1,
+ msg='if branched but not successful return `default` if defined, not `[]`')
+ self.assertEqual(traverse_obj(_DEFAULT_DATA, (..., 'fail'), default=None), None,
+ msg='if branched but not successful return `default` even if `default` is `None`')
+ self.assertEqual(traverse_obj(_DEFAULT_DATA, (..., 'fail')), [],
+ msg='if branched but not successful return `[]`, not `default`')
+ self.assertEqual(traverse_obj(_DEFAULT_DATA, ('list', ...)), [],
+ msg='if branched but object is empty return `[]`, not `default`')
+
+ # Testing expected_type behavior
+ _EXPECTED_TYPE_DATA = {'str': 'str', 'int': 0}
+ self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=str), 'str',
+ msg='accept matching `expected_type` type')
+ self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=int), None,
+ msg='reject non matching `expected_type` type')
+ self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'int', expected_type=lambda x: str(x)), '0',
+ msg='transform type using type function')
+ self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str',
+ expected_type=lambda _: 1 / 0), None,
+ msg='wrap expected_type fuction in try_call')
+ self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, ..., expected_type=str), ['str'],
+ msg='eliminate items that expected_type fails on')
+
+ # Test get_all behavior
+ _GET_ALL_DATA = {'key': [0, 1, 2]}
+ self.assertEqual(traverse_obj(_GET_ALL_DATA, ('key', ...), get_all=False), 0,
+ msg='if not `get_all`, return only first matching value')
+ self.assertEqual(traverse_obj(_GET_ALL_DATA, ..., get_all=False), [0, 1, 2],
+ msg='do not overflatten if not `get_all`')
+
+ # Test casesense behavior
+ _CASESENSE_DATA = {
+ 'KeY': 'value0',
+ 0: {
+ 'KeY': 'value1',
+ 0: {'KeY': 'value2'},
+ },
+ }
+ self.assertEqual(traverse_obj(_CASESENSE_DATA, 'key'), None,
+ msg='dict keys should be case sensitive unless `casesense`')
+ self.assertEqual(traverse_obj(_CASESENSE_DATA, 'keY',
+ casesense=False), 'value0',
+ msg='allow non matching key case if `casesense`')
+ self.assertEqual(traverse_obj(_CASESENSE_DATA, (0, ('keY',)),
+ casesense=False), ['value1'],
+ msg='allow non matching key case in branch if `casesense`')
+ self.assertEqual(traverse_obj(_CASESENSE_DATA, (0, ((0, 'keY'),)),
+ casesense=False), ['value2'],
+ msg='allow non matching key case in branch path if `casesense`')
+
+ # Test traverse_string behavior
+ _TRAVERSE_STRING_DATA = {'str': 'str', 1.2: 1.2}
+ self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', 0)), None,
+ msg='do not traverse into string if not `traverse_string`')
+ self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', 0),
+ traverse_string=True), 's',
+ msg='traverse into string if `traverse_string`')
+ self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, (1.2, 1),
+ traverse_string=True), '.',
+ msg='traverse into converted data if `traverse_string`')
+ self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', ...),
+ traverse_string=True), list('str'),
+ msg='`...` branching into string should result in list')
+ self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', (0, 2)),
+ traverse_string=True), ['s', 'r'],
+ msg='branching into string should result in list')
+ self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', lambda _, x: x),
+ traverse_string=True), list('str'),
+ msg='function branching into string should result in list')
+
+ # Test is_user_input behavior
+ _IS_USER_INPUT_DATA = {'range8': list(range(8))}
+ self.assertEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', '3'),
+ is_user_input=True), 3,
+ msg='allow for string indexing if `is_user_input`')
+ self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', '3:'),
+ is_user_input=True), tuple(range(8))[3:],
+ msg='allow for string slice if `is_user_input`')
+ self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':4:2'),
+ is_user_input=True), tuple(range(8))[:4:2],
+ msg='allow step in string slice if `is_user_input`')
+ self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':'),
+ is_user_input=True), range(8),
+ msg='`:` should be treated as `...` if `is_user_input`')
+ with self.assertRaises(TypeError, msg='too many params should result in error'):
+ traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':::'), is_user_input=True)
+
+ # Test re.Match as input obj
+ mobj = re.fullmatch(r'0(12)(?P<group>3)(4)?', '0123')
+ self.assertEqual(traverse_obj(mobj, ...), [x for x in mobj.groups() if x is not None],
+ msg='`...` on a `re.Match` should give its `groups()`')
+ self.assertEqual(traverse_obj(mobj, lambda k, _: k in (0, 2)), ['0123', '3'],
+ msg='function on a `re.Match` should give groupno, value starting at 0')
+ self.assertEqual(traverse_obj(mobj, 'group'), '3',
+ msg='str key on a `re.Match` should give group with that name')
+ self.assertEqual(traverse_obj(mobj, 2), '3',
+ msg='int key on a `re.Match` should give group with that name')
+ self.assertEqual(traverse_obj(mobj, 'gRoUp', casesense=False), '3',
+ msg='str key on a `re.Match` should respect casesense')
+ self.assertEqual(traverse_obj(mobj, 'fail'), None,
+ msg='failing str key on a `re.Match` should return `default`')
+ self.assertEqual(traverse_obj(mobj, 'gRoUpS', casesense=False), None,
+ msg='failing str key on a `re.Match` should return `default`')
+ self.assertEqual(traverse_obj(mobj, 8), None,
+ msg='failing int key on a `re.Match` should return `default`')
+
if __name__ == '__main__':
unittest.main()