# Source code for meerkat.columns.tensor.abstract

from typing import TYPE_CHECKING, List, Union

import numpy as np

from meerkat.block.abstract import BlockView
from meerkat.block.numpy_block import NumPyBlock
from meerkat.block.torch_block import TorchBlock
from meerkat.tools.lazy_loader import LazyLoader

from ..abstract import Column

# Lazily-loaded handle to the ``torch`` module; presumably defers the real
# import until an attribute is first accessed (see LazyLoader) — TODO confirm.
torch = LazyLoader("torch")

if TYPE_CHECKING:
    # ``torch.TensorType`` is a TorchScript (JIT) type object, not the tensor
    # class, so annotate tensor data with ``torch.Tensor`` instead.
    from torch import Tensor

    # Alias used to annotate tensor-column data. Only evaluated by static
    # type checkers (guarded by TYPE_CHECKING), never at runtime.
    TensorColumnTypes = Union[np.ndarray, Tensor]


class TensorColumn(Column):
    """Backend-agnostic entry point for tensor-valued columns.

    Instantiating ``TensorColumn`` itself dispatches, in ``__new__``, to one
    of the concrete backends: ``NumPyTensorColumn`` or ``TorchTensorColumn``.
    """

    def __new__(cls, data: "TensorColumnTypes" = None, backend: str = None):
        """Choose the concrete class to instantiate.

        Resolution order:

        1. ``backend`` is given: instantiate the matching class
           (``"torch"`` -> ``TorchTensorColumn``, ``"numpy"`` ->
           ``NumPyTensorColumn``), regardless of ``cls`` and ``data``.
        2. ``cls`` is not ``TensorColumn`` itself, or ``data`` is None:
           instantiate ``cls``.
        3. Otherwise infer the backend from ``data``:

           - a ``BlockView`` over a ``TorchBlock`` (torch) or a
             ``NumPyBlock`` (numpy);
           - an ``np.ndarray`` (numpy);
           - a torch tensor (torch);
           - a non-empty ``list``, decided by its first element (torch
             tensor -> torch, anything else -> numpy).

        Args:
            data: Column data used for backend inference.
            backend: Optional explicit backend name, ``"torch"`` or ``"numpy"``.

        Raises:
            ValueError: If ``backend`` is not a known backend name, if ``data``
                is an empty list, or if ``data`` is of an unsupported type.
        """
        import sys

        # Imported here rather than at module level, as in the original code;
        # likely to avoid a circular import with the concrete subclasses.
        from .numpy import NumPyTensorColumn
        from .torch import TorchTensorColumn

        def _is_torch_tensor(obj) -> bool:
            # A torch tensor can only exist once torch has been imported, so
            # consult sys.modules rather than the module-level LazyLoader.
            # Touching the LazyLoader would import torch, or fail outright
            # when torch is not installed.
            torch_module = sys.modules.get("torch")
            return torch_module is not None and torch_module.is_tensor(obj)

        backends = {"torch": TorchTensorColumn, "numpy": NumPyTensorColumn}

        if backend is not None:
            if backend not in backends:
                raise ValueError(
                    f"Backend {backend} not supported. "
                    f"Expected one of {list(backends.keys())}"
                )
            return super().__new__(backends[backend])

        # Inference below only applies when the abstract TensorColumn is being
        # instantiated with data; subclasses are built as themselves.
        if cls is not TensorColumn or data is None:
            return super().__new__(cls)

        if isinstance(data, BlockView):
            if isinstance(data.block, TorchBlock):
                return super().__new__(TorchTensorColumn)
            if isinstance(data.block, NumPyBlock):
                return super().__new__(NumPyTensorColumn)
            # Views over other block types fall through to the checks below.

        if isinstance(data, np.ndarray):
            return super().__new__(NumPyTensorColumn)

        if _is_torch_tensor(data):
            return super().__new__(TorchTensorColumn)

        if isinstance(data, list):
            if len(data) == 0:
                raise ValueError(
                    "Cannot create `TensorColumn` from empty list of tensors."
                )
            # The first element decides the backend for the whole list.
            if _is_torch_tensor(data[0]):
                return super().__new__(TorchTensorColumn)
            return super().__new__(NumPyTensorColumn)

        raise ValueError(
            f"Cannot create `TensorColumn` from object of type {type(data)}."
        )

    # Formatter hook, currently disabled:
    # def _get_default_formatters(self):
    #     from meerkat.interactive.formatter import TensorFormatterGroup
    #     return TensorFormatterGroup()