From 88426d9446758c707fb511408f2d6f56de952db4 Mon Sep 17 00:00:00 2001 From: pukkandan Date: Wed, 8 Feb 2023 08:14:36 +0530 Subject: [compat_utils] Improve `passthrough_module` --- yt_dlp/compat/compat_utils.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) (limited to 'yt_dlp/compat/compat_utils.py') diff --git a/yt_dlp/compat/compat_utils.py b/yt_dlp/compat/compat_utils.py index 373389a46..f8679c98e 100644 --- a/yt_dlp/compat/compat_utils.py +++ b/yt_dlp/compat/compat_utils.py @@ -1,5 +1,6 @@ import collections import contextlib +import functools import importlib import sys import types @@ -22,6 +23,10 @@ def _is_package(module): return '__path__' in vars(module) +def _is_dunder(name): + return name.startswith('__') and name.endswith('__') + + class EnhancedModule(types.ModuleType): def __new__(cls, name, *args, **kwargs): if name not in sys.modules: @@ -44,7 +49,7 @@ class EnhancedModule(types.ModuleType): try: ret = super().__getattribute__(attr) except AttributeError: - if attr.startswith('__') and attr.endswith('__'): + if _is_dunder(attr): raise getter = getattr(self, '__getattr__', None) if not getter: @@ -53,7 +58,7 @@ class EnhancedModule(types.ModuleType): return ret.fget() if isinstance(ret, property) else ret -def passthrough_module(parent, child, allowed_attributes=None, *, callback=lambda _: None): +def passthrough_module(parent, child, allowed_attributes=(..., ), *, callback=lambda _: None): """Passthrough parent module into a child module, creating the parent if necessary""" parent = EnhancedModule(parent) @@ -68,24 +73,23 @@ def passthrough_module(parent, child, allowed_attributes=None, *, callback=lambd callback(attr) return ret + @functools.lru_cache(maxsize=None) def from_child(attr): nonlocal child - - if allowed_attributes is None: - if attr.startswith('__') and attr.endswith('__'): + if attr not in allowed_attributes: + if ... not in allowed_attributes or _is_dunder(attr): return _NO_ATTRIBUTE - elif attr not in allowed_attributes: - return _NO_ATTRIBUTE if isinstance(child, str): child = importlib.import_module(child, parent.__name__) - with contextlib.suppress(AttributeError): - return getattr(child, attr) - if _is_package(child): with contextlib.suppress(ImportError): - return importlib.import_module(f'.{attr}', child.__name__) + return passthrough_module(f'{parent.__name__}.{attr}', + importlib.import_module(f'.{attr}', child.__name__)) + + with contextlib.suppress(AttributeError): + return getattr(child, attr) return _NO_ATTRIBUTE -- cgit v1.2.3