aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--test/test_networking.py29
-rw-r--r--yt_dlp/YoutubeDL.py7
-rw-r--r--yt_dlp/networking/common.py40
3 files changed, 65 insertions, 11 deletions
diff --git a/test/test_networking.py b/test/test_networking.py
index 9c33b0d4c..2622d24da 100644
--- a/test/test_networking.py
+++ b/test/test_networking.py
@@ -1035,17 +1035,17 @@ class TestRequestDirector:
assert isinstance(director.send(Request('http://')), FakeResponse)
def test_unsupported_handlers(self):
- director = RequestDirector(logger=FakeLogger())
- director.add_handler(FakeRH(logger=FakeLogger()))
-
class SupportedRH(RequestHandler):
_SUPPORTED_URL_SCHEMES = ['http']
def _send(self, request: Request):
return Response(fp=io.BytesIO(b'supported'), headers={}, url=request.url)
- # This handler should by default take preference over FakeRH
+ director = RequestDirector(logger=FakeLogger())
director.add_handler(SupportedRH(logger=FakeLogger()))
+ director.add_handler(FakeRH(logger=FakeLogger()))
+
+ # First should take preference
assert director.send(Request('http://')).read() == b'supported'
assert director.send(Request('any://')).read() == b''
@@ -1072,6 +1072,27 @@ class TestRequestDirector:
director.add_handler(UnexpectedRH(logger=FakeLogger))
assert director.send(Request('any://'))
+ def test_preference(self):
+ director = RequestDirector(logger=FakeLogger())
+ director.add_handler(FakeRH(logger=FakeLogger()))
+
+ class SomeRH(RequestHandler):
+ _SUPPORTED_URL_SCHEMES = ['http']
+
+ def _send(self, request: Request):
+ return Response(fp=io.BytesIO(b'supported'), headers={}, url=request.url)
+
+ def some_preference(rh, request):
+ return (0 if not isinstance(rh, SomeRH)
+ else 100 if 'prefer' in request.headers
+ else -1)
+
+ director.add_handler(SomeRH(logger=FakeLogger()))
+ director.preferences.add(some_preference)
+
+ assert director.send(Request('http://')).read() == b''
+ assert director.send(Request('http://', headers={'prefer': '1'})).read() == b'supported'
+
# XXX: do we want to move this to test_YoutubeDL.py?
class TestYoutubeDLNetworking:
diff --git a/yt_dlp/YoutubeDL.py b/yt_dlp/YoutubeDL.py
index 87bca5bbe..666d89b46 100644
--- a/yt_dlp/YoutubeDL.py
+++ b/yt_dlp/YoutubeDL.py
@@ -34,7 +34,7 @@ from .extractor.common import UnsupportedURLIE
from .extractor.openload import PhantomJSwrapper
from .minicurses import format_text
from .networking import HEADRequest, Request, RequestDirector
-from .networking.common import _REQUEST_HANDLERS
+from .networking.common import _REQUEST_HANDLERS, _RH_PREFERENCES
from .networking.exceptions import (
HTTPError,
NoSupportingHandlers,
@@ -683,7 +683,7 @@ class YoutubeDL:
self.params['http_headers'] = HTTPHeaderDict(std_headers, self.params.get('http_headers'))
self._load_cookies(self.params['http_headers'].get('Cookie')) # compat
self.params['http_headers'].pop('Cookie', None)
- self._request_director = self.build_request_director(_REQUEST_HANDLERS.values())
+ self._request_director = self.build_request_director(_REQUEST_HANDLERS.values(), _RH_PREFERENCES)
if auto_init and auto_init != 'no_verbose_header':
self.print_debug_header()
@@ -4077,7 +4077,7 @@ class YoutubeDL:
except HTTPError as e: # TODO: Remove in a future release
raise _CompatHTTPError(e) from e
- def build_request_director(self, handlers):
+ def build_request_director(self, handlers, preferences=None):
logger = _YDLLogger(self)
headers = self.params['http_headers'].copy()
proxies = self.proxies.copy()
@@ -4106,6 +4106,7 @@ class YoutubeDL:
},
}),
))
+ director.preferences.update(preferences or [])
return director
def encode(self, s):
diff --git a/yt_dlp/networking/common.py b/yt_dlp/networking/common.py
index 8fba8c1c5..584c7bb4d 100644
--- a/yt_dlp/networking/common.py
+++ b/yt_dlp/networking/common.py
@@ -31,8 +31,19 @@ from ..utils import (
)
from ..utils.networking import HTTPHeaderDict, normalize_url
-if typing.TYPE_CHECKING:
- RequestData = bytes | Iterable[bytes] | typing.IO | None
+
+def register_preference(*handlers: type[RequestHandler]):
+ assert all(issubclass(handler, RequestHandler) for handler in handlers)
+
+ def outer(preference: Preference):
+ @functools.wraps(preference)
+ def inner(handler, *args, **kwargs):
+ if not handlers or isinstance(handler, handlers):
+ return preference(handler, *args, **kwargs)
+ return 0
+ _RH_PREFERENCES.add(inner)
+ return inner
+ return outer
class RequestDirector:
@@ -40,12 +51,17 @@ class RequestDirector:
Helper class that, when given a request, forward it to a RequestHandler that supports it.
+ Preference functions in the form of func(handler, request) -> int
+ can be registered into the `preferences` set. These are used to sort handlers
+ in order of preference.
+
@param logger: Logger instance.
@param verbose: Print debug request information to stdout.
"""
def __init__(self, logger, verbose=False):
self.handlers: dict[str, RequestHandler] = {}
+ self.preferences: set[Preference] = set()
self.logger = logger # TODO(Grub4k): default logger
self.verbose = verbose
@@ -58,6 +74,16 @@ class RequestDirector:
assert isinstance(handler, RequestHandler), 'handler must be a RequestHandler'
self.handlers[handler.RH_KEY] = handler
+ def _get_handlers(self, request: Request) -> list[RequestHandler]:
+ """Sorts handlers by preference, given a request"""
+ preferences = {
+ rh: sum(pref(rh, request) for pref in self.preferences)
+ for rh in self.handlers.values()
+ }
+ self._print_verbose('Handler preferences for this request: %s' % ', '.join(
+ f'{rh.RH_NAME}={pref}' for rh, pref in preferences.items()))
+ return sorted(self.handlers.values(), key=preferences.get, reverse=True)
+
def _print_verbose(self, msg):
if self.verbose:
self.logger.stdout(f'director: {msg}')
@@ -73,8 +99,7 @@ class RequestDirector:
unexpected_errors = []
unsupported_errors = []
- # TODO (future): add a per-request preference system
- for handler in reversed(list(self.handlers.values())):
+ for handler in self._get_handlers(request):
self._print_verbose(f'Checking if "{handler.RH_NAME}" supports this request.')
try:
handler.validate(request)
@@ -530,3 +555,10 @@ class Response(io.IOBase):
def getheader(self, name, default=None):
deprecation_warning('Response.getheader() is deprecated, use Response.get_header', stacklevel=2)
return self.get_header(name, default)
+
+
+if typing.TYPE_CHECKING:
+ RequestData = bytes | Iterable[bytes] | typing.IO | None
+ Preference = typing.Callable[[RequestHandler, Request], int]
+
+_RH_PREFERENCES: set[Preference] = set()