aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorpukkandan <pukkandan.ytdlp@gmail.com>2021-07-21 11:17:27 +0530
committerpukkandan <pukkandan.ytdlp@gmail.com>2021-07-21 11:30:06 +0530
commit352d63fdb52452f6e99d5603757c54c3f5c186d7 (patch)
tree80f265bad020115da9756f64becf3b872e273240
parent11f9be09122882b6308a396ff50e2dc141450316 (diff)
downloadhypervideo-pre-352d63fdb52452f6e99d5603757c54c3f5c186d7.tar.lz
hypervideo-pre-352d63fdb52452f6e99d5603757c54c3f5c186d7.tar.xz
hypervideo-pre-352d63fdb52452f6e99d5603757c54c3f5c186d7.zip
[utils] Improve `traverse_obj`
-rw-r--r--yt_dlp/extractor/youtube.py10
-rw-r--r--yt_dlp/utils.py21
2 files changed, 20 insertions, 11 deletions
diff --git a/yt_dlp/extractor/youtube.py b/yt_dlp/extractor/youtube.py
index aa0421a72..afe31a12d 100644
--- a/yt_dlp/extractor/youtube.py
+++ b/yt_dlp/extractor/youtube.py
@@ -1929,10 +1929,11 @@ class YoutubeIE(YoutubeBaseInfoExtractor):
return sts
def _mark_watched(self, video_id, player_responses):
- playback_url = url_or_none((traverse_obj(
- player_responses, ('playbackTracking', 'videostatsPlaybackUrl', 'baseUrl'),
- expected_type=str) or [None])[0])
+ playback_url = traverse_obj(
+ player_responses, (..., 'playbackTracking', 'videostatsPlaybackUrl', 'baseUrl'),
+ expected_type=url_or_none, get_all=False)
if not playback_url:
+ self.report_warning('Unable to mark watched')
return
parsed_playback_url = compat_urlparse.urlparse(playback_url)
qs = compat_urlparse.parse_qs(parsed_playback_url.query)
@@ -2606,8 +2607,7 @@ class YoutubeIE(YoutubeBaseInfoExtractor):
self._get_requested_clients(url, smuggled_data),
video_id, webpage, master_ytcfg, player_url, identity_token))
- get_first = lambda obj, keys, **kwargs: (
- traverse_obj(obj, (..., *variadic(keys)), **kwargs) or [None])[0]
+ get_first = lambda obj, keys, **kwargs: traverse_obj(obj, (..., *variadic(keys)), **kwargs, get_all=False)
playability_statuses = traverse_obj(
player_responses, (..., 'playabilityStatus'), expected_type=dict, default=[])
diff --git a/yt_dlp/utils.py b/yt_dlp/utils.py
index 4d3cbc7b4..4d12c0a8e 100644
--- a/yt_dlp/utils.py
+++ b/yt_dlp/utils.py
@@ -6225,7 +6225,7 @@ def load_plugins(name, suffix, namespace):
def traverse_obj(
- obj, *path_list, default=None, expected_type=None,
+ obj, *path_list, default=None, expected_type=None, get_all=True,
casesense=True, is_user_input=False, traverse_string=False):
''' Traverse nested list/dict/tuple
@param path_list A list of paths which are checked one by one.
@@ -6234,7 +6234,8 @@ def traverse_obj(
all the keys given in the tuple are traversed, and
"..." traverses all the keys in the object
@param default Default value to return
- @param expected_type Only accept final value of this type
+ @param expected_type Only accept final value of this type (Can also be any callable)
+ @param get_all Return all the values obtained from a path or only the first one
@param casesense Whether to consider dictionary keys as case sensitive
@param is_user_input Whether the keys are generated from user input. If True,
strings are converted to int/slice if necessary
@@ -6281,6 +6282,13 @@ def traverse_obj(
return None
return obj
+ if isinstance(expected_type, type):
+ type_test = lambda val: val if isinstance(val, expected_type) else None
+ elif expected_type is not None:
+ type_test = expected_type
+ else:
+ type_test = lambda val: val
+
for path in path_list:
depth = 0
val = _traverse_obj(obj, path)
@@ -6288,12 +6296,13 @@ def traverse_obj(
if depth:
for _ in range(depth - 1):
val = itertools.chain.from_iterable(v for v in val if v is not None)
- val = ([v for v in val if v is not None] if expected_type is None
- else [v for v in val if isinstance(v, expected_type)])
+ val = [v for v in map(type_test, val) if v is not None]
if val:
+ return val if get_all else val[0]
+ else:
+ val = type_test(val)
+ if val is not None:
return val
- elif expected_type is None or isinstance(val, expected_type):
- return val
return default