from abc import abstractmethod
from ebonite.core.analyzer.base import Hook, analyzer_class
from ebonite.core.objects.dataset_type import (PRIMITIVES, BytesDatasetType, DatasetType, DictDatasetType,
ListDatasetType, PrimitiveDatasetType, TupleDatasetType,
TupleLikeListDatasetType)
class DatasetHook(Hook):
    """
    Base hook type for :class:`DatasetAnalyzer`.
    Analysis result is an instance of :class:`~ebonite.core.objects.DatasetType`
    """

    @abstractmethod
    def process(self, obj, **kwargs) -> DatasetType:
        """
        Analyzes obj and returns result. Result type is determined by specific Hook class sub-hierarchy

        :param obj: object to analyze
        :param kwargs: additional information to be used for analysis
        :return: analysis result
        """
        pass  # pragma: no cover
# Module-level analyzer built from the DatasetHook hierarchy: its `analyze(obj)`
# dispatches to registered hooks and yields a DatasetType describing `obj`.
DatasetAnalyzer = analyzer_class(DatasetHook, DatasetType)
class PrimitivesHook(DatasetHook):
    """
    Hook for primitive data, for example when you model outputs just one int
    """

    def can_process(self, obj):
        # Fixed: the original returned True or fell through to an implicit None;
        # return an explicit bool so callers get a consistent contract.
        return type(obj) in PRIMITIVES

    def must_process(self, obj):
        # Never claims exclusivity: more specific hooks may still handle primitives.
        return False

    def process(self, obj, **kwargs) -> DatasetType:
        # Encode the primitive by its type name (e.g. 'int', 'float').
        return PrimitiveDatasetType(type(obj).__name__)
class OrderedCollectionHookDelegator(DatasetHook):
    """
    Hook for list/tuple data
    """

    def can_process(self, obj) -> bool:
        return isinstance(obj, (list, tuple))

    def must_process(self, obj) -> bool:
        return False

    def process(self, obj, **kwargs) -> DatasetType:
        # Tuples are heterogeneous fixed-size records: analyze each element.
        if isinstance(obj, tuple):
            return TupleDatasetType([DatasetAnalyzer.analyze(o) for o in obj])

        py_types = {type(o) for o in obj}
        # Short or mixed-type lists: keep per-element types (tuple-like list).
        if len(obj) <= 1 or len(py_types) > 1:
            return TupleLikeListDatasetType([DatasetAnalyzer.analyze(o) for o in obj])

        # py_types is guaranteed to be a singleton set here
        if not py_types.intersection(PRIMITIVES):
            return TupleLikeListDatasetType([DatasetAnalyzer.analyze(o) for o in obj])

        # Optimization for large lists of same primitive type elements:
        # analyze one representative and record only element type + length.
        return ListDatasetType(DatasetAnalyzer.analyze(obj[0]), len(obj))
class DictHookDelegator(DatasetHook):
    """
    Hook for dict data
    """

    def can_process(self, obj) -> bool:
        return isinstance(obj, dict)

    def must_process(self, obj) -> bool:
        return False

    def process(self, obj, **kwargs) -> DatasetType:
        # Analyze each value; keys are kept as-is in the resulting type.
        try:
            items = {k: DatasetAnalyzer.analyze(o) for k, o in obj.items()}
        except ValueError as e:
            # Chain the original error so the failing value's cause is preserved.
            raise ValueError(f"Cant process {obj} with DictHookDelegator") from e
        return DictDatasetType(items)
class BytesDatasetHook(DatasetHook):
    """
    Hook for bytes objects
    """

    def process(self, obj, **kwargs) -> DatasetType:
        # Raw bytes carry no further structure to analyze.
        return BytesDatasetType()

    def can_process(self, obj) -> bool:
        return isinstance(obj, bytes)

    def must_process(self, obj) -> bool:
        return False