diff options
| -rw-r--r-- | test/test_networking.py | 29 |
| -rw-r--r-- | yt_dlp/YoutubeDL.py | 7 |
| -rw-r--r-- | yt_dlp/networking/common.py | 40 |
3 files changed, 65 insertions, 11 deletions
| diff --git a/test/test_networking.py b/test/test_networking.py index 9c33b0d4c..2622d24da 100644 --- a/test/test_networking.py +++ b/test/test_networking.py @@ -1035,17 +1035,17 @@ class TestRequestDirector:          assert isinstance(director.send(Request('http://')), FakeResponse)      def test_unsupported_handlers(self): -        director = RequestDirector(logger=FakeLogger()) -        director.add_handler(FakeRH(logger=FakeLogger())) -          class SupportedRH(RequestHandler):              _SUPPORTED_URL_SCHEMES = ['http']              def _send(self, request: Request):                  return Response(fp=io.BytesIO(b'supported'), headers={}, url=request.url) -        # This handler should by default take preference over FakeRH +        director = RequestDirector(logger=FakeLogger())          director.add_handler(SupportedRH(logger=FakeLogger())) +        director.add_handler(FakeRH(logger=FakeLogger())) + +        # First should take preference          assert director.send(Request('http://')).read() == b'supported'          assert director.send(Request('any://')).read() == b'' @@ -1072,6 +1072,27 @@ class TestRequestDirector:          director.add_handler(UnexpectedRH(logger=FakeLogger))          assert director.send(Request('any://')) +    def test_preference(self): +        director = RequestDirector(logger=FakeLogger()) +        director.add_handler(FakeRH(logger=FakeLogger())) + +        class SomeRH(RequestHandler): +            _SUPPORTED_URL_SCHEMES = ['http'] + +            def _send(self, request: Request): +                return Response(fp=io.BytesIO(b'supported'), headers={}, url=request.url) + +        def some_preference(rh, request): +            return (0 if not isinstance(rh, SomeRH) +                    else 100 if 'prefer' in request.headers +                    else -1) + +        director.add_handler(SomeRH(logger=FakeLogger())) +        director.preferences.add(some_preference) + +        assert 
director.send(Request('http://')).read() == b'' +        assert director.send(Request('http://', headers={'prefer': '1'})).read() == b'supported' +  # XXX: do we want to move this to test_YoutubeDL.py?  class TestYoutubeDLNetworking: diff --git a/yt_dlp/YoutubeDL.py b/yt_dlp/YoutubeDL.py index 87bca5bbe..666d89b46 100644 --- a/yt_dlp/YoutubeDL.py +++ b/yt_dlp/YoutubeDL.py @@ -34,7 +34,7 @@ from .extractor.common import UnsupportedURLIE  from .extractor.openload import PhantomJSwrapper  from .minicurses import format_text  from .networking import HEADRequest, Request, RequestDirector -from .networking.common import _REQUEST_HANDLERS +from .networking.common import _REQUEST_HANDLERS, _RH_PREFERENCES  from .networking.exceptions import (      HTTPError,      NoSupportingHandlers, @@ -683,7 +683,7 @@ class YoutubeDL:          self.params['http_headers'] = HTTPHeaderDict(std_headers, self.params.get('http_headers'))          self._load_cookies(self.params['http_headers'].get('Cookie'))  # compat          self.params['http_headers'].pop('Cookie', None) -        self._request_director = self.build_request_director(_REQUEST_HANDLERS.values()) +        self._request_director = self.build_request_director(_REQUEST_HANDLERS.values(), _RH_PREFERENCES)          if auto_init and auto_init != 'no_verbose_header':              self.print_debug_header() @@ -4077,7 +4077,7 @@ class YoutubeDL:          except HTTPError as e:  # TODO: Remove in a future release              raise _CompatHTTPError(e) from e -    def build_request_director(self, handlers): +    def build_request_director(self, handlers, preferences=None):          logger = _YDLLogger(self)          headers = self.params['http_headers'].copy()          proxies = self.proxies.copy() @@ -4106,6 +4106,7 @@ class YoutubeDL:                      },                  }),              )) +        director.preferences.update(preferences or [])          return director      def encode(self, s): diff --git 
a/yt_dlp/networking/common.py b/yt_dlp/networking/common.py index 8fba8c1c5..584c7bb4d 100644 --- a/yt_dlp/networking/common.py +++ b/yt_dlp/networking/common.py @@ -31,8 +31,19 @@ from ..utils import (  )  from ..utils.networking import HTTPHeaderDict, normalize_url -if typing.TYPE_CHECKING: -    RequestData = bytes | Iterable[bytes] | typing.IO | None + +def register_preference(*handlers: type[RequestHandler]): +    assert all(issubclass(handler, RequestHandler) for handler in handlers) + +    def outer(preference: Preference): +        @functools.wraps(preference) +        def inner(handler, *args, **kwargs): +            if not handlers or isinstance(handler, handlers): +                return preference(handler, *args, **kwargs) +            return 0 +        _RH_PREFERENCES.add(inner) +        return inner +    return outer  class RequestDirector: @@ -40,12 +51,17 @@ class RequestDirector:      Helper class that, when given a request, forward it to a RequestHandler that supports it. +    Preference functions in the form of func(handler, request) -> int +    can be registered into the `preferences` set. These are used to sort handlers +    in order of preference. +      @param logger: Logger instance.      @param verbose: Print debug request information to stdout.      
"""      def __init__(self, logger, verbose=False):          self.handlers: dict[str, RequestHandler] = {} +        self.preferences: set[Preference] = set()          self.logger = logger  # TODO(Grub4k): default logger          self.verbose = verbose @@ -58,6 +74,16 @@ class RequestDirector:          assert isinstance(handler, RequestHandler), 'handler must be a RequestHandler'          self.handlers[handler.RH_KEY] = handler +    def _get_handlers(self, request: Request) -> list[RequestHandler]: +        """Sorts handlers by preference, given a request""" +        preferences = { +            rh: sum(pref(rh, request) for pref in self.preferences) +            for rh in self.handlers.values() +        } +        self._print_verbose('Handler preferences for this request: %s' % ', '.join( +            f'{rh.RH_NAME}={pref}' for rh, pref in preferences.items())) +        return sorted(self.handlers.values(), key=preferences.get, reverse=True) +      def _print_verbose(self, msg):          if self.verbose:              self.logger.stdout(f'director: {msg}') @@ -73,8 +99,7 @@ class RequestDirector:          unexpected_errors = []          unsupported_errors = [] -        # TODO (future): add a per-request preference system -        for handler in reversed(list(self.handlers.values())): +        for handler in self._get_handlers(request):              self._print_verbose(f'Checking if "{handler.RH_NAME}" supports this request.')              try:                  handler.validate(request) @@ -530,3 +555,10 @@ class Response(io.IOBase):      def getheader(self, name, default=None):          deprecation_warning('Response.getheader() is deprecated, use Response.get_header', stacklevel=2)          return self.get_header(name, default) + + +if typing.TYPE_CHECKING: +    RequestData = bytes | Iterable[bytes] | typing.IO | None +    Preference = typing.Callable[[RequestHandler, Request], int] + +_RH_PREFERENCES: set[Preference] = set() | 
