aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--yt_dlp/compat.py11
-rw-r--r--yt_dlp/utils.py69
2 files changed, 80 insertions, 0 deletions
diff --git a/yt_dlp/compat.py b/yt_dlp/compat.py
index b97d4512e..2bc6a6b7f 100644
--- a/yt_dlp/compat.py
+++ b/yt_dlp/compat.py
@@ -134,6 +134,16 @@ except AttributeError:
asyncio.run = compat_asyncio_run
+try: # >= 3.7
+ asyncio.tasks.all_tasks
+except AttributeError:
+ asyncio.tasks.all_tasks = asyncio.tasks.Task.all_tasks
+
+try:
+ import websockets as compat_websockets
+except ImportError:
+ compat_websockets = None
+
# Python 3.8+ does not honor %HOME% on windows, but this breaks compatibility with youtube-dl
# See https://github.com/yt-dlp/yt-dlp/issues/792
# https://docs.python.org/3/library/os.path.html#os.path.expanduser
@@ -303,6 +313,7 @@ __all__ = [
'compat_urllib_response',
'compat_urlparse',
'compat_urlretrieve',
+ 'compat_websockets',
'compat_xml_parse_error',
'compat_xpath',
'compat_zip',
diff --git a/yt_dlp/utils.py b/yt_dlp/utils.py
index bb8d65cad..c5489d494 100644
--- a/yt_dlp/utils.py
+++ b/yt_dlp/utils.py
@@ -3,6 +3,7 @@
from __future__ import unicode_literals
+import asyncio
import base64
import binascii
import calendar
@@ -73,6 +74,7 @@ from .compat import (
compat_urllib_parse_unquote_plus,
compat_urllib_request,
compat_urlparse,
+ compat_websockets,
compat_xpath,
)
@@ -5311,3 +5313,70 @@ class Config:
def parse_args(self):
return self._parser.parse_args(list(self.all_args))
+
+
+class WebSocketsWrapper():
+ """Wraps websockets module to use in non-async scopes"""
+
+ def __init__(self, url, headers=None):
+ self.loop = asyncio.events.new_event_loop()
+ self.conn = compat_websockets.connect(
+ url, extra_headers=headers, ping_interval=None,
+ close_timeout=float('inf'), loop=self.loop, ping_timeout=float('inf'))
+
+ def __enter__(self):
+ self.pool = self.run_with_loop(self.conn.__aenter__(), self.loop)
+ return self
+
+ def send(self, *args):
+ self.run_with_loop(self.pool.send(*args), self.loop)
+
+ def recv(self, *args):
+ return self.run_with_loop(self.pool.recv(*args), self.loop)
+
+ def __exit__(self, type, value, traceback):
+ try:
+ return self.run_with_loop(self.conn.__aexit__(type, value, traceback), self.loop)
+ finally:
+ self.loop.close()
+ self.r_cancel_all_tasks(self.loop)
+
+ # taken from https://github.com/python/cpython/blob/3.9/Lib/asyncio/runners.py with modifications
+ # for contributors: If there's any new library using asyncio needs to be run in non-async, move these function out of this class
+ @staticmethod
+ def run_with_loop(main, loop):
+ if not asyncio.coroutines.iscoroutine(main):
+ raise ValueError(f'a coroutine was expected, got {main!r}')
+
+ try:
+ return loop.run_until_complete(main)
+ finally:
+ loop.run_until_complete(loop.shutdown_asyncgens())
+ if hasattr(loop, 'shutdown_default_executor'):
+ loop.run_until_complete(loop.shutdown_default_executor())
+
+ @staticmethod
+ def _cancel_all_tasks(loop):
+ to_cancel = asyncio.tasks.all_tasks(loop)
+
+ if not to_cancel:
+ return
+
+ for task in to_cancel:
+ task.cancel()
+
+ loop.run_until_complete(
+ asyncio.tasks.gather(*to_cancel, loop=loop, return_exceptions=True))
+
+ for task in to_cancel:
+ if task.cancelled():
+ continue
+ if task.exception() is not None:
+ loop.call_exception_handler({
+ 'message': 'unhandled exception during asyncio.run() shutdown',
+ 'exception': task.exception(),
+ 'task': task,
+ })
+
+
+has_websockets = bool(compat_websockets)