Source code for ebonite.ext.ext_loader

import importlib
import importlib.abc
import sys
from types import ModuleType
from typing import Dict, List, Union

from ebonite.config import Core
from ebonite.utils.classproperty import classproperty
from ebonite.utils.importing import import_module, module_importable, module_imported
from ebonite.utils.log import logger


class Extension:
    """
    Extension descriptor

    Two descriptors are considered equal (and hash equally) when they refer to the same
    main module, regardless of their requirements, `force` flag or validator.

    :param module: main extension module
    :param reqs: list of extension dependencies
    :param force: if True, disable lazy loading for this extension
    :param validator: boolean predicate which should evaluate to True for this extension to be loaded
    """

    def __init__(self, module, reqs: List[str], force=True, validator=None):
        self.force = force
        self.reqs = reqs
        self.module = module
        self.validator = validator

    def __str__(self):
        return f'<Extension {self.module}>'

    def __repr__(self):
        return str(self)

    def __eq__(self, other):
        # Comparing with a foreign object (e.g. a plain module-name string or None) must not
        # raise AttributeError; NotImplemented lets Python fall back to its default handling,
        # which makes `==` evaluate to False and `!=` to True.
        if not isinstance(other, Extension):
            return NotImplemented
        return self.module == other.module

    def __hash__(self):
        # Kept consistent with __eq__: identity is defined by the main module only.
        return hash(self.module)
class ExtensionDict(dict):
    """
    Container of :class:`.Extension` objects, keyed by each extension's main module

    :param extensions: descriptors to store; a later descriptor replaces an earlier one
        that has the same `module`
    """

    def __init__(self, *extensions: Extension):
        super().__init__()
        self.update({ext.module: ext for ext in extensions})
def __tensorflow_major_version():
    """
    Import tensorflow and return the major component of its version

    :return: the text before the first dot of `tensorflow.__version__`, as a string
    """
    # imported lazily so that merely importing this module does not pull tensorflow in
    import tensorflow as tf
    return tf.__version__.split('.')[0]


def is_tf_v1():
    """
    :return: True if the installed tensorflow reports major version 1
    """
    return __tensorflow_major_version() == '1'


def is_tf_v2():
    """
    :return: True if the installed tensorflow reports major version 2
    """
    return __tensorflow_major_version() == '2'
class ExtensionLoader:
    """
    Class that tracks and loads extensions.
    """
    # Extensions shipped with the package, keyed by main module name (see ExtensionDict).
    # The third positional argument is `force`; entries that omit it use the default
    # force=True and are therefore never deferred to the import hook.
    builtin_extensions: Dict[str, Extension] = ExtensionDict(
        Extension('ebonite.ext.numpy', ['numpy'], False),
        Extension('ebonite.ext.pandas', ['pandas'], False),
        Extension('ebonite.ext.sklearn', ['sklearn'], False),
        # the two tensorflow extensions share one requirement and are told apart by validator
        Extension('ebonite.ext.tensorflow', ['tensorflow'], False, is_tf_v1),
        Extension('ebonite.ext.tensorflow_v2', ['tensorflow'], False, is_tf_v2),
        Extension('ebonite.ext.torch', ['torch'], False),
        Extension('ebonite.ext.catboost', ['catboost'], False),
        Extension('ebonite.ext.aiohttp', ['aiohttp', 'aiohttp_swagger']),
        Extension('ebonite.ext.flask', ['flask', 'flasgger'], False),
        Extension('ebonite.ext.sqlalchemy', ['sqlalchemy']),
        Extension('ebonite.ext.s3', ['boto3']),
        Extension('ebonite.ext.imageio', ['imageio']),
        Extension('ebonite.ext.lightgbm', ['lightgbm'], False),
        Extension('ebonite.ext.xgboost', ['xgboost'], False),
        Extension('ebonite.ext.docker', ['docker'], False)
    )

    # extension descriptor -> module object returned by import_module for it
    _loaded_extensions: Dict[Extension, ModuleType] = {}

    @classproperty
    def loaded_extensions(cls) -> Dict[Extension, ModuleType]:
        """
        :return: mapping of loaded extensions to their imported modules
        """
        return cls._loaded_extensions

    @classmethod
    def _setup_import_hook(cls, extensions: List[Extension]):
        """
        Add import hook to sys.meta_path that will load extensions when their dependencies are imported

        If a hook of this kind is already installed, the extensions are registered with it
        instead of installing a second one.

        :param extensions: list of :class:`.Extension`
        """
        if not extensions:
            return

        hook = next((h for h in sys.meta_path if isinstance(h, _ImportLoadExtInterceptor)), None)
        if hook is None:
            hook = _ImportLoadExtInterceptor(module_to_extension={})
            sys.meta_path.insert(0, hook)
        # Several extensions may depend on the same requirement (e.g. both tensorflow
        # extensions), so the hook keeps every one of them instead of only the last.
        hook.add_extensions(extensions)

    @classmethod
    def load_all(cls, try_lazy=True):
        """
        Load all (builtin and additional) extensions

        Builtin extensions are loaded immediately when lazy loading is not applicable
        (`try_lazy` is False, `sys.frozen` is set, or the extension has `force=True`) and all
        their requirements pass `module_importable`. Otherwise they are loaded now only if all
        requirements pass `module_imported`, and are handed to the import hook if not.
        Modules listed in `Core.ADDITIONAL_EXTENSIONS` are always loaded directly.

        :param try_lazy: if `False`, use force load for all builtin extensions
        """
        for_hook = []
        for ext in cls.builtin_extensions.values():
            if not try_lazy or hasattr(sys, 'frozen') or ext.force:
                if all(module_importable(r) for r in ext.reqs):
                    cls.load(ext)
            else:
                if all(module_imported(r) for r in ext.reqs):
                    cls.load(ext)
                else:
                    for_hook.append(ext)

        cls._setup_import_hook(for_hook)

        for mod in Core.ADDITIONAL_EXTENSIONS:
            cls.load(mod)

    @classmethod
    def load(cls, extension: Union[str, Extension]):
        """
        Load single extension

        Nothing happens if the extension is already registered as loaded, if its main module
        passes `module_imported`, or if its validator evaluates to a falsy value.

        :param extension: str or :class:`.Extension` instance to load
        """
        if isinstance(extension, str):
            extension = Extension(extension, [], force=True)
        if extension not in cls._loaded_extensions and not module_imported(extension.module) and \
                (extension.validator is None or extension.validator()):
            logger.debug('Importing extension module %s', extension.module)
            cls._loaded_extensions[extension] = import_module(extension.module)


class _ImportLoadExtRegisterer(importlib.abc.PathEntryFinder):
    """A hook that registers all modules that are being imported"""

    def __init__(self):
        # full names of every module the import system asked this finder about, in order
        self.imported = []

    def find_module(self, fullname, path=None):
        self.imported.append(fullname)
        # never claim the module: returning None passes the request to the next finder
        return None


class _ImportLoadExtInterceptor(importlib.abc.Loader, importlib.abc.PathEntryFinder):
    """
    Import hook implementation to load extensions on dependency import

    NOTE(review): `find_module`/`load_module` are the legacy finder/loader protocol; the
    import system stopped consulting `find_module` on meta path finders in Python 3.12, so
    lazy loading would need a `find_spec`-based port there - confirm supported versions.

    :param module_to_extension: dict requirement -> :class:`.Extension` or list of
        :class:`.Extension`; stored internally as requirement -> list
    """

    def __init__(self, module_to_extension: Dict[str, Union[Extension, List[Extension]]]):
        self.module_to_extension: Dict[str, List[Extension]] = {}
        for req, exts in module_to_extension.items():
            # accept the legacy "one extension per requirement" shape as well
            if isinstance(exts, Extension):
                exts = [exts]
            for ext in exts:
                self._add(req, ext)

    def _add(self, req: str, extension: Extension):
        """Register `extension` as waiting for `req`, ignoring duplicates"""
        waiting = self.module_to_extension.setdefault(req, [])
        if extension not in waiting:
            waiting.append(extension)

    def add_extensions(self, extensions: List[Extension]):
        """
        Register extensions under each of their requirements

        :param extensions: list of :class:`.Extension`
        """
        for ext in extensions:
            for req in ext.reqs:
                self._add(req, ext)

    def find_module(self, fullname, path=None):
        # hijack importing machinery: claim every module so that load_module runs for it
        return self

    def load_module(self, fullname):
        # change this hook to registering hook
        reg = _ImportLoadExtRegisterer()
        sys.meta_path = [reg] + [x for x in sys.meta_path if x is not self]
        try:
            # fallback to ordinary importing
            module = importlib.import_module(fullname)
        finally:
            # put this hook back
            sys.meta_path = [self] + [x for x in sys.meta_path if x is not reg]

        # check all that was imported and import all extensions that are ready
        for imported in reg.imported:
            if not module_imported(imported):
                continue
            # iterate over a snapshot: loading an extension triggers further imports
            for extension in tuple(self.module_to_extension.get(imported, ())):
                if all(module_imported(m) for m in extension.reqs):
                    ExtensionLoader.load(extension)

        return module
def load_extensions(*exts: str):
    """
    Load extensions

    Each given module name is passed to :meth:`ExtensionLoader.load`, in order.

    :param exts: list of extension main modules
    """
    for module_name in exts:
        ExtensionLoader.load(module_name)