import inspect
import sys
import types
import typing
import warnings
import weakref
from collections import defaultdict
from collections.abc import Mapping
from functools import reduce, wraps
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
import dill
import numpy as np
import pandas as pd
import yaml
from yaml.constructor import ConstructorError
from meerkat import env
from meerkat.tools.lazy_loader import LazyLoader
# Bind `torch` through the project's LazyLoader rather than a plain import.
# NOTE(review): presumably this defers the real import until first attribute
# access - confirm in meerkat.tools.lazy_loader.
torch = LazyLoader("torch")
def is_subclass(v, cls):
    """Return ``issubclass(v, cls)``, treating a ``TypeError`` as ``False``.

    ``issubclass`` raises ``TypeError`` when ``v`` is not a class (for
    example an instance); this helper reports ``False`` in that case instead
    of propagating the error.

    Args:
        v: The candidate object, usually a class.
        cls: The class (or tuple of classes) to test against.

    Returns:
        The result of ``issubclass(v, cls)``, or ``False`` if it raised
        ``TypeError``.
    """
    try:
        return issubclass(v, cls)
    except TypeError:
        return False
def has_var_kwargs(fn: Callable) -> bool:
    """Report whether a callable accepts arbitrary keyword arguments.

    Args:
        fn: The function whose signature is inspected.

    Returns:
        True if the signature contains a ``**kwargs``-style parameter,
        False otherwise.
    """
    parameters = inspect.signature(fn).parameters
    return any(
        param.kind == param.VAR_KEYWORD for param in parameters.values()
    )
def has_var_args(fn: Callable) -> bool:
    """Report whether a callable accepts arbitrary positional arguments.

    Args:
        fn: The function whose signature is inspected.

    Returns:
        True if the signature contains an ``*args``-style parameter,
        False otherwise.
    """
    parameters = inspect.signature(fn).parameters
    return any(
        param.kind == param.VAR_POSITIONAL for param in parameters.values()
    )
def get_type_hint_args(type_hint):
    """Return the type arguments of a type hint.

    On Python 3.8 and newer this delegates to ``typing.get_args``; on older
    interpreters it reads the hint's ``__args__`` attribute directly.
    """
    if sys.version_info < (3, 8):
        # `typing.get_args` is not used before 3.8.
        return type_hint.__args__
    return typing.get_args(type_hint)
def get_type_hint_origin(type_hint):
    """Return the origin of a type hint.

    On Python 3.8 and newer this delegates to ``typing.get_origin``; on older
    interpreters it reads the hint's ``__origin__`` attribute directly.
    """
    if sys.version_info < (3, 8):
        # `typing.get_origin` is not used before 3.8.
        return type_hint.__origin__
    return typing.get_origin(type_hint)
class classproperty(property):
    """A property whose getter receives the owning class, not an instance.

    Taken from https://stackoverflow.com/a/13624858.

    The behavior of class properties using the @classmethod and
    @property decorators has changed across Python versions. This class
    (should) provide consistent behavior across Python versions. See
    https://stackoverflow.com/a/1800999 for more information.
    """

    def __get__(self, owner_self, owner_cls):
        # Always hand the class to the getter, whether the attribute is
        # looked up on the class itself or on one of its instances.
        return self.fget(owner_cls)
def deprecated(replacement: Optional[str] = None):
    """Build a decorator that marks a function as deprecated.

    Every call to the decorated function emits a ``DeprecationWarning``
    before delegating to the original function.

    Args:
        replacement: Optional name of the function callers should use
            instead. When given, it is appended to the warning message.

    Returns:
        A decorator that wraps a function with the deprecation warning.
    """

    def _decorator(func):
        # Build the message once. The suffix depends on `replacement`, not
        # on the wrapper function.
        message = "Call to deprecated function {}.".format(func.__name__)
        if replacement is not None:
            message += " Use {} instead.".format(replacement)

        @wraps(func)
        def new_func(*args, **kwargs):
            # Temporarily force DeprecationWarning to be shown, then put the
            # "default" action back afterwards.
            warnings.simplefilter("always", DeprecationWarning)  # turn off filter
            warnings.warn(
                message,
                category=DeprecationWarning,
                stacklevel=2,  # attribute the warning to the caller
            )
            warnings.simplefilter("default", DeprecationWarning)  # reset filter
            return func(*args, **kwargs)

        return new_func

    return _decorator
def requires(*packages):
    """Build a decorator declaring the packages a function depends on.

    Each time the decorated function is called, every listed package is
    checked with ``env.package_available``; the first unavailable one
    causes an ``ImportError`` and the function is not run.

    Args:
        *packages: Names of the packages that must be available.

    Returns:
        A decorator that adds the availability check to a function.
    """

    def _decorator(func):
        @wraps(func)
        def wrapped(*args, **kwargs):
            for pkg in packages:
                caller = f"{func.__qualname__}()"
                if env.package_available(pkg):
                    continue
                raise ImportError(
                    f"Missing package `{pkg}` which is required for {caller}."
                )
            return func(*args, **kwargs)

        return wrapped

    return _decorator
class WeakMapping(Mapping):
    """Mapping that holds its values through weak references.

    Values are wrapped in ``weakref.ref`` when stored, so the mapping does
    not keep them alive. A key whose referent has been collected stays in
    ``refs`` (and is still counted and iterated), but looking it up raises
    ``KeyError``.
    """

    def __init__(self):
        # key -> weak reference to the stored object
        self.refs: Dict[Any, weakref.ReferenceType] = {}

    def __getitem__(self, key: str):
        target = self.refs[key]()
        if target is not None:
            return target
        # The referent was garbage collected after it was stored.
        raise KeyError(f"Object with key {key} no longer exists")

    def __setitem__(self, key: str, value: Any):
        self.refs[key] = weakref.ref(value)

    def __delitem__(self, key: str):
        del self.refs[key]

    def __iter__(self):
        return iter(self.refs)

    def __len__(self):
        return len(self.refs)
def nested_getattr(obj, attr, *args):
    """Resolve a dotted attribute path on an object.

    ``attr`` is split on ``"."`` and each component is looked up with
    ``getattr`` on the result of the previous lookup. Any extra positional
    arguments are forwarded to every ``getattr`` call, so a default value
    applies at each step of the path.

    # noqa: E501
    Source: https://stackoverflow.com/questions/31174295/getattr-and-setattr-on-nested-subobjects-chained-properties
    """
    current = obj
    for component in attr.split("."):
        current = getattr(current, component, *args)
    return current
def nested_apply(obj: object, fn: callable, base_types: Tuple[type] = ()):
    """Apply ``fn`` to the leaves of a nested list/tuple/dict structure.

    Lists, tuples and dicts are traversed recursively and rebuilt as a plain
    ``list``, ``tuple`` or ``dict`` respectively (dict keys are kept, only
    values are transformed). Anything else is passed to ``fn``.

    Args:
        obj: The object or nested container to transform.
        fn: Function applied to every leaf.
        base_types: Types treated as leaves even if they are lists, tuples
            or dicts; this check runs before the container checks.

    Returns:
        The transformed structure.
    """

    def _recurse(child):
        return nested_apply(child, fn=fn, base_types=base_types)

    # Explicit leaf types win over container handling.
    if isinstance(obj, base_types):
        return fn(obj)
    if isinstance(obj, list):
        return list(map(_recurse, obj))
    if isinstance(obj, tuple):
        return tuple(map(_recurse, obj))
    if isinstance(obj, dict):
        return {key: _recurse(value) for key, value in obj.items()}
    return fn(obj)
# Ordered (old, new) substring replacements for module paths and class names
# that have been renamed. `MeerkatLoader.find_python_name` applies them to
# Python names found in YAML, and `MeerkatUnpickler.find_class` applies them
# to module paths when a pickled class cannot be resolved as-is. Entries are
# applied in list order with plain `str.replace`.
BACKWARDS_COMPAT_REPLACEMENTS = [
("meerkat.ml", "meerkat.nn"),
("meerkat.columns.numpy_column", "meerkat.columns.tensor.numpy"),
("NumpyArrayColumn", "NumPyTensorColumn"),
("meerkat.columns.tensor_column", "meerkat.columns.tensor.torch"),
("meerkat.columns.pandas_column", "meerkat.columns.scalar.pandas"),
("meerkat.columns.arrow_column", "meerkat.columns.scalar.arrow"),
("meerkat.columns.image_column", "meerkat.columns.deferred.image"),
("meerkat.columns.file_column", "meerkat.columns.deferred.file"),
("meerkat.columns.list_column", "meerkat.columns.object.base"),
("meerkat.block.lambda_block", "meerkat.block.deferred_block"),
(
"meerkat.interactive.app.src.lib.component.filter",
"meerkat.interactive.app.src.lib.component.core.filter",
),
("ListColumn", "ObjectColumn"),
("LambdaBlock", "DeferredBlock"),
("NumpyBlock", "NumPyBlock"),
]
class MeerkatDumper(yaml.Dumper):
    """YAML dumper with dill-backed representers for objects and functions."""

    @staticmethod
    def _pickled_object_representer(dumper, data):
        """Represent ``data`` as a ``!PickledObject`` mapping.

        The mapping stores the object's class and its dill-pickled bytes.
        """
        payload = {"class": data.__class__, "pickle": dill.dumps(data)}
        return dumper.represent_mapping("!PickledObject", payload)

    @staticmethod
    def _function_representer(dumper, data):
        """Represent a function.

        Lambdas are tagged ``!Lambda`` and functions whose qualified name
        contains ``<locals>`` are tagged ``!NestedFunction``; both are
        stored as a mapping holding their source code and dill-pickled
        bytes. Every other function is represented by name.
        """
        if data.__name__ == "<lambda>":
            tag = "!Lambda"
        elif "<locals>" in data.__qualname__:
            tag = "!NestedFunction"
        else:
            return dumper.represent_name(data)

        payload = {"code": inspect.getsource(data), "pickle": dill.dumps(data)}
        return dumper.represent_mapping(tag, payload)
# Register the custom representers on MeerkatDumper: the dill-based
# representer as a multi-representer for `object`, and the function
# representer for plain Python functions (`types.FunctionType`).
MeerkatDumper.add_multi_representer(object, MeerkatDumper._pickled_object_representer)
MeerkatDumper.add_representer(types.FunctionType, MeerkatDumper._function_representer)
class MeerkatLoader(yaml.FullLoader):
    """YAML loader that may import ``meerkat.*`` modules on demand.

    PyYaml does not load unimported modules for safety reasons.
    We want to allow importing only meerkat modules.
    """

    def find_python_module(self, name: str, mark, unsafe=False):
        """Resolve a module, importing it first if it is a meerkat module."""
        try:
            return super().find_python_module(name=name, mark=mark, unsafe=unsafe)
        except ConstructorError as err:
            if not name.startswith("meerkat."):
                raise err
            __import__(name)
        # Retry now that the meerkat module has been imported.
        return super().find_python_module(name=name, mark=mark, unsafe=unsafe)

    def find_python_name(self, name: str, mark, unsafe=False):
        """Resolve a dotted Python name, honouring renamed meerkat paths."""
        # Rewrite legacy module/class names to their current equivalents.
        for old, new in BACKWARDS_COMPAT_REPLACEMENTS:
            name = name.replace(old, new)

        module_name = name.rsplit(".", 1)[0] if "." in name else "builtins"

        try:
            return super().find_python_name(name=name, mark=mark, unsafe=unsafe)
        except ConstructorError as err:
            if not name.startswith("meerkat."):
                raise err
            __import__(module_name)
        # Retry now that the owning meerkat module has been imported.
        return super().find_python_name(name=name, mark=mark, unsafe=unsafe)

    @staticmethod
    def _pickled_object_constructor(loader, node):
        """Rebuild an object from the ``pickle`` entry of its mapping node."""
        # NOTE(review): dill.loads executes arbitrary code; only load YAML
        # files from trusted sources.
        mapping = loader.construct_mapping(node)
        return dill.loads(mapping["pickle"])

    @staticmethod
    def _function_constructor(loader, node):
        """Rebuild a function from the ``pickle`` entry of its mapping node."""
        # NOTE(review): dill.loads executes arbitrary code; only load YAML
        # files from trusted sources.
        mapping = loader.construct_mapping(node)
        return dill.loads(mapping["pickle"])
# Register constructors for every custom tag that MeerkatDumper can emit.
MeerkatLoader.add_constructor(
    "!PickledObject", MeerkatLoader._pickled_object_constructor
)
MeerkatLoader.add_constructor("!Lambda", MeerkatLoader._function_constructor)
# `MeerkatDumper._function_representer` writes nested functions with the
# "!NestedFunction" tag using the same {"code", "pickle"} mapping as lambdas,
# so they are loaded with the same constructor. Without this registration a
# dumped nested function could not be read back.
MeerkatLoader.add_constructor("!NestedFunction", MeerkatLoader._function_constructor)
def dump_yaml(obj: Any, path: str, **kwargs):
    """Write ``obj`` to ``path`` as YAML using ``MeerkatDumper``.

    Args:
        obj: The object to serialize.
        path: Destination file path; opened in text write mode.
        **kwargs: Extra keyword arguments forwarded to ``yaml.dump``.
    """
    with open(path, "w") as stream:
        yaml.dump(obj, stream, Dumper=MeerkatDumper, **kwargs)
def load_yaml(path: str, **kwargs):
    """Read a YAML file from ``path`` using ``MeerkatLoader``.

    Args:
        path: Source file path; opened in text read mode.
        **kwargs: Extra keyword arguments forwarded to ``yaml.load``.

    Returns:
        The object constructed from the YAML document.
    """
    # NOTE(review): MeerkatLoader can unpickle embedded dill payloads, so
    # this should only be used on trusted files.
    with open(path, "r") as stream:
        document = yaml.load(stream, Loader=MeerkatLoader, **kwargs)
    return document
class MeerkatUnpickler(dill.Unpickler):
    """Unpickler that retries class lookups with renamed module paths."""

    def find_class(self, module, name):
        """Resolve ``module.name``, falling back to the renamed module path.

        The lookup is first attempted unchanged. If it fails for any reason,
        every entry of ``BACKWARDS_COMPAT_REPLACEMENTS`` is applied to
        ``module`` (``name`` is left untouched) and the lookup is retried.
        """
        try:
            return super().find_class(module, name)
        except Exception:
            renamed_module = module
            for old, new in BACKWARDS_COMPAT_REPLACEMENTS:
                renamed_module = renamed_module.replace(old, new)
            return super().find_class(renamed_module, name)
def meerkat_dill_load(path: str):
    """Load dill file with backwards compatibility for old column names.

    Args:
        path: Path of the dill file; opened in binary read mode.

    Returns:
        The object unpickled by ``MeerkatUnpickler``.
    """
    # Use a context manager so the file handle is closed even when
    # unpickling raises.
    # NOTE(review): unpickling executes arbitrary code; only load trusted
    # files.
    with open(path, "rb") as f:
        return MeerkatUnpickler(f).load()
def convert_to_batch_column_fn(
    function: Callable, with_indices: bool, materialize: bool = True, **kwargs
):
    """Lift a per-example function into one that processes a whole batch.

    The returned callable applies ``function`` to ``batch[i]`` for every
    position in the batch and collects the results. If the first result is
    a dict, the results are gathered into a dict of lists keyed like the
    outputs; otherwise they are gathered into a list. An empty batch
    yields ``None``.

    Args:
        function: Function applied to a single example.
        with_indices: If True, ``function`` also receives ``indices[i]`` as
            its second positional argument and the returned callable has
            the signature ``(batch, indices, *args, **kwargs)``. If False,
            the signature is ``(batch, *args, **kwargs)``.
        materialize: Accepted for interface compatibility; it does not
            change how examples are read from the batch.
        **kwargs: Accepted but not used by this function.

    Returns:
        The batched version of ``function``.
    """

    def _function(batch: Sequence, indices: Optional[List[int]], *args, **kwargs):
        outputs = None
        for i in range(len(batch)):
            example = batch[i]
            # Insert the example's index right after it when requested.
            index_arg = (indices[i],) if with_indices else ()
            output = function(example, *index_arg, *args, **kwargs)

            output_is_dict = isinstance(output, dict)
            if i == 0:
                # The first output decides the container type.
                outputs = defaultdict(list) if output_is_dict else []

            if output_is_dict:
                for key, value in output.items():
                    outputs[key].append(value)
            else:
                outputs.append(output)

        # Hand back a plain dict rather than the defaultdict accumulator.
        return dict(outputs) if isinstance(outputs, dict) else outputs

    if not with_indices:
        # Hide the indices argument from callers that do not supply it.
        return lambda batch, *args, **kwargs: _function(batch, None, *args, **kwargs)
    return _function
def convert_to_batch_fn(
    function: Callable, with_indices: bool, materialize: bool = True, **kwargs
):
    """Lift a per-example function into one that processes a dict batch.

    The batch maps column names to indexable columns; its size is taken from
    the first column. For every position ``i`` an example dict
    ``{name: column[i]}`` is built and passed to ``function``. If the first
    result is a dict, the results are gathered into a dict of lists;
    otherwise they are gathered into a list. A batch of size zero yields
    ``None``.

    Args:
        function: Function applied to a single example dict.
        with_indices: If True, ``function`` also receives ``indices[i]`` as
            its second positional argument and the returned callable has
            the signature ``(batch, indices, *args, **kwargs)``. If False,
            the signature is ``(batch, *args, **kwargs)``.
        materialize: Accepted for interface compatibility; it does not
            change how examples are read from the batch.
        **kwargs: Accepted but not used by this function.

    Returns:
        The batched version of ``function``.
    """

    def _function(
        batch: Dict[str, List], indices: Optional[List[int]], *args, **kwargs
    ):
        # The first column determines how many examples there are.
        first_key = list(batch.keys())[0]
        batch_size = len(batch[first_key])

        outputs = None
        for i in range(batch_size):
            example = {name: column[i] for name, column in batch.items()}
            # Insert the example's index right after it when requested.
            index_arg = (indices[i],) if with_indices else ()
            output = function(example, *index_arg, *args, **kwargs)

            output_is_dict = isinstance(output, dict)
            if i == 0:
                # The first output decides the container type.
                outputs = defaultdict(list) if output_is_dict else []

            if output_is_dict:
                for key, value in output.items():
                    outputs[key].append(value)
            else:
                outputs.append(output)

        # Hand back a plain dict rather than the defaultdict accumulator.
        return dict(outputs) if isinstance(outputs, dict) else outputs

    if not with_indices:
        # Hide the indices argument from callers that do not supply it.
        return lambda batch, *args, **kwargs: _function(batch, None, *args, **kwargs)
    return _function
def convert_to_python(obj: Any):
    """Utility for converting NumPy and torch dtypes to native python types.

    Useful when sending objects to frontend.

    Torch tensors are converted to NumPy arrays, and NumPy scalars
    (``np.generic``) are converted to the matching Python scalar with
    ``.item()``. Any other object is returned unchanged.

    Args:
        obj: The object to convert.

    Returns:
        The converted object, or ``obj`` itself when no conversion applies.
    """
    import torch

    if torch.is_tensor(obj):
        # `Tensor.numpy()` raises for tensors that require grad or that live
        # on a non-CPU device, so detach and move to the CPU first.
        obj = obj.detach().cpu().numpy()

    if isinstance(obj, np.generic):
        obj = obj.item()
    return obj
def translate_index(index, length: int):
    """Normalize an index into either a single-element index or row indices.

    Args:
        index: An ``int`` or length-1 tuple (returned unchanged), or a batch
            index: a slice, list, tuple, ``pd.Series``, torch tensor,
            ``np.ndarray``, ``ScalarColumn`` or one-axis ``TensorColumn``.
        length: Length of the indexed object, used to resolve slices.

    Returns:
        ``index`` itself for a single-element index; otherwise a
        one-dimensional ``np.ndarray`` of positions (boolean masks are
        converted to the positions of their true entries).

    Raises:
        TypeError: If an array-like index does not have exactly one axis, or
            the index type is not supported.
    """
    # np.ndarray indexed with a tuple of length 1 does not return an
    # np.ndarray, so such a tuple is treated like an int: a single element.
    selects_single_element = isinstance(index, int) or (
        isinstance(index, tuple) and len(index) == 1
    )
    if selects_single_element:
        return index

    from ..columns.scalar.abstract import ScalarColumn
    from ..columns.tensor.abstract import TensorColumn

    # Coerce the supported array-like inputs to np.ndarray.
    if isinstance(index, pd.Series):
        index = index.values
    if torch.is_tensor(index):
        index = index.numpy()
    if isinstance(index, (tuple, list)):
        index = np.array(index)
    if isinstance(index, ScalarColumn):
        index = index.to_numpy()
    if isinstance(index, TensorColumn):
        num_axes = len(index.shape)
        if num_axes != 1:
            raise TypeError(
                "`TensorColumn` index must have 1 axis, not {}".format(num_axes)
            )
        index = index.to_numpy()

    # From here on the index selects a batch.
    if isinstance(index, slice):
        # Expand the slice into explicit positions.
        return np.arange(*index.indices(length))

    if not isinstance(index, np.ndarray):
        raise TypeError("Object of type {} is not a valid index".format(type(index)))

    if len(index.shape) != 1:
        raise TypeError(
            "`np.ndarray` index must have 1 axis, not {}".format(len(index.shape))
        )
    if index.dtype == bool:
        # Boolean mask: keep the positions of the true entries.
        return np.where(index)[0]
    return index
def choose_device(device: str = "auto"):
    """Choose the device to use for a Meerkat operation.

    Args:
        device: Requested device. ``"auto"`` resolves to ``"cuda"`` when
            ``torch.cuda.is_available()`` is true and to ``"cpu"``
            otherwise; any other value is returned unchanged.

    Returns:
        ``"cpu"`` whenever ``config.system.use_gpu`` is falsy, otherwise the
        resolved device.
    """
    from meerkat.config import config

    # The global configuration can forbid GPU use regardless of the request.
    if not config.system.use_gpu:
        return "cpu"

    if device != "auto":
        return device
    return "cuda" if torch.cuda.is_available() else "cpu"