# Source code for ebonite.build.provider.ml_model_multi

import os
import warnings
from typing import List, Optional, Union

from pyjackson import dumps
from pyjackson.decorators import cached_property

from ebonite.build.provider.utils import BuildableWithServer
from ebonite.core.analyzer import CanIsAMustHookMixin
from ebonite.core.analyzer.buildable import BuildableHook
from ebonite.build.provider.ml_model import MLModelProvider, read
from ebonite.core.objects import ArtifactCollection, Model, Requirements, Task
from ebonite.core.objects.artifacts import _RelativePathWrapper, CompositeArtifactCollection
from ebonite.core.objects.core import WithMetadataRepository
from ebonite.runtime.interface.ml_model import MODEL_BIN_PATH, MODELS_META_PATH
from ebonite.runtime.server import Server
from ebonite.utils.module import get_object_requirements
from ebonite.utils.log import logger


class MLModelMultiProvider(MLModelProvider):
    """Provider to put multiple models in one service

    :param models: List of Model instances
    :param server: Server instance to build with
    :param debug: Debug for instance
    """

    def __init__(self, models: List[Model], server: Server, debug: bool = False):
        from ebonite.runtime.interface.ml_model import MultiModelLoader
        # Bypass MLModelProvider.__init__ and call its parent's initializer directly so that a
        # MultiModelLoader is used (presumably MLModelProvider.__init__ is single-model specific).
        super(MLModelProvider, self).__init__(server, MultiModelLoader(), debug)
        self.models: List[Model] = models

    @cached_property
    def _requirements(self) -> Requirements:
        """Union of server, loader and all models requirements"""
        return (get_object_requirements(self.server) +
                get_object_requirements(self.loader) +
                sum((model.requirements for model in self.models), Requirements()))

    def get_requirements(self):
        """Returns union of server, loader and all models requirements"""
        return self._requirements

    def get_sources(self):
        """Returns source files for the service.

        The mapping contains the serialized models meta file (models without artifacts) under
        ``MODELS_META_PATH``, everything returned by ``self._get_sources()``, and every file from
        the server's ``additional_sources`` keyed by its base name.
        """
        return {
            MODELS_META_PATH: dumps([model.without_artifacts() for model in self.models]),
            **self._get_sources(),
            **{os.path.basename(f): read(f) for f in self.server.additional_sources}
        }

    def get_artifacts(self) -> ArtifactCollection:
        """Returns binaries of models artifacts

        Artifacts of the i-th model are wrapped under the relative path ``MODEL_BIN_PATH/<i>``.
        """
        # TODO additional server binaries
        return CompositeArtifactCollection([
            _RelativePathWrapper(m.artifact_any, os.path.join(MODEL_BIN_PATH, str(i)))
            for i, m in enumerate(self.models)
        ])

    @staticmethod
    def _version_key(version) -> tuple:
        """Numeric sort key for a dotted version string, e.g. '3.10.1' -> (3, 10, 1).

        Non-numeric suffixes of a component are ignored ('0rc1' -> 0).
        """
        import re
        key = []
        for part in str(version).split('.'):
            match = re.match(r'\d+', part)
            key.append(int(match.group()) if match else 0)
        return tuple(key)

    def get_python_version(self):
        """Returns the newest python version recorded among the models.

        Models without a recorded python version are ignored; returns None if no model has one
        (including an empty model list). Logs a warning when the recorded versions differ.
        """
        versions = {model.params.get(Model.PYTHON_VERSION) for model in self.models}
        # a missing version must not take part in the comparison (None vs str/None is a TypeError)
        versions.discard(None)
        if not versions:
            return None
        if len(versions) > 1:
            logger.warning('Models in MLModelMultiProvider have varying python versions in requirements')
        # compare numerically: as plain strings '3.10' would sort before '3.9'
        return max(versions, key=self._version_key)
class MultiModelBuildable(BuildableWithServer, WithMetadataRepository):
    """Buildable that puts several models into one service.

    :param model_ids: non-empty list of either model ids or Model instances (not a mix of both)
    :param server_type: type of the server to build with
    :param debug: debug flag passed on to the provider
    :raises ValueError: if ``model_ids`` is empty
    :raises TypeError: if ``model_ids`` mixes model ids and Model instances
    """

    def __init__(self, model_ids: List[Union[int, Model]], server_type: str, debug: bool = False):
        if not model_ids:
            raise ValueError('model_ids must contain at least one model')
        super().__init__(server_type)
        is_model = [isinstance(m, Model) for m in model_ids]
        if all(is_model):
            # copy, so later mutation of the caller's list does not leak into the cache
            self.models_cache = list(model_ids)
            self.model_ids = [m.id for m in model_ids]
            self.bind_meta_repo(self.models_cache[0]._meta)
        elif not any(is_model):
            self.model_ids = list(model_ids)
            self.models_cache = None
        else:
            raise TypeError('model_ids must contain either only model ids or only Model instances')
        self.debug = debug

    @property
    def task(self) -> Optional[Task]:
        """Task of the first model; warns if the models do not share exactly one task"""
        tasks = [m.task for m in self.models]
        # tolerate models without a task instead of failing on None.id
        task_ids = {t.id if t is not None else None for t in tasks}
        if len(task_ids) != 1:
            warnings.warn(f'Ambiguious task for buildable {self}')
        return tasks[0]

    @property
    def models(self) -> List[Model]:
        """Models of this buildable, fetched from the bound metadata repository on first access"""
        if self.models_cache is None:
            self._check_meta(False)
            self.models_cache = [self._meta.get_model_by_id(mid) for mid in self.model_ids]
        return self.models_cache

    def get_provider(self) -> MLModelMultiProvider:
        """Returns a provider packing all models of this buildable into one service"""
        return MLModelMultiProvider(self.models, self.server, self.debug)
class BuildableMultiModelHook(BuildableHook, CanIsAMustHookMixin):
    """Hook that turns a list of Model objects into a MultiModelBuildable"""

    def must_process(self, obj) -> bool:
        """True if ``obj`` is a list whose items are all Model instances"""
        return isinstance(obj, list) and all(isinstance(o, Model) for o in obj)

    def process(self, obj, **kwargs):
        """Creates a MultiModelBuildable for the list of models ``obj``.

        :param obj: list of Model instances
        :param kwargs: must contain ``server`` (its ``type`` is used); optional ``debug`` (default False)
        :raises ValueError: if no ``server`` is given
        """
        server = kwargs.get('server')  # TODO ???
        if server is None:
            # fail with a clear message instead of AttributeError on None.type
            raise ValueError('server must be provided to build a multi-model buildable')
        debug = kwargs.get('debug', False)
        return MultiModelBuildable(obj, server.type, debug)