Source code for meerkat.columns.tensor.torch

from __future__ import annotations

import abc
import functools
import logging
import os
from typing import TYPE_CHECKING, Callable, List, Mapping, Sequence, Tuple, Union

import numpy as np
import pandas as pd
import pyarrow as pa
from yaml.representer import Representer

from meerkat.block.abstract import BlockView
from meerkat.block.torch_block import TorchBlock
from meerkat.mixins.cloneable import CloneableMixin
from meerkat.tools.lazy_loader import LazyLoader
from meerkat.writers.concat_writer import ConcatWriter
from meerkat.writers.numpy_writer import NumpyMemmapWriter

from ..abstract import Column
from .abstract import TensorColumn

torch = LazyLoader("torch")

if TYPE_CHECKING:
    import torch

Representer.add_representer(abc.ABCMeta, Representer.represent_name)

Columnable = Union[Sequence, np.ndarray, pd.Series, "torch.Tensor"]

logger = logging.getLogger(__name__)


def getattr_decorator(fn: Callable):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        out = fn(*args, **kwargs)
        if isinstance(out, torch.Tensor):
            if out.ndim == 0:
                return out.clone().detach()
            return TorchTensorColumn(out)
        else:
            return out

    return wrapper
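
# Usage sketch (illustrative, not part of the library source): attribute access
# on a column is forwarded to the underlying tensor (see `__getattr__` below),
# and callable attributes are wrapped by `getattr_decorator` so that tensor
# results come back as columns while 0-d results are returned as plain tensors.
#
#     col = TorchTensorColumn([3, 1, 2])
#     col.argmax()  # forwarded to torch.Tensor.argmax; 0-d result, plain tensor
#     col.abs()     # forwarded to torch.Tensor.abs; re-wrapped as a column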


def _as_tensor(data: Union["torch.Tensor", np.ndarray, pd.Series]) -> "torch.Tensor":
    """Overloaded as_tensor function to support other data types."""
    if not isinstance(data, (np.ndarray, torch.Tensor)):
        data = np.asarray(data)
    return torch.as_tensor(data)


class TorchTensorColumn(
    np.lib.mixins.NDArrayOperatorsMixin,
    TensorColumn,
):
    block_class: type = TorchBlock

    def __init__(
        self,
        data: Sequence = None,
        *args,
        **kwargs,
    ):
        if isinstance(data, BlockView):
            if not isinstance(data.block, TorchBlock):
                raise ValueError(
                    "Cannot create `TensorColumn` from a `BlockView` not "
                    "referencing a `TensorBlock`."
                )
        elif data is not None and not isinstance(data, TorchTensorColumn):
            if isinstance(data, Sequence) and len(data) > 0:
                # TODO: We need to apply this check and do proper conversion of
                # every element in the sequence, e.g. a list of mixed ndarrays
                # and torch tensors [np.ndarray, torch.Tensor] should work.
                if torch.is_tensor(data[0]):
                    # np.asarray supports a list of numpy arrays (it simply
                    # stacks them before putting them into an array) but
                    # torch.as_tensor does not. We want to support this for
                    # consistency and because it is important for map.
                    data = torch.stack(data)
                else:
                    data = np.asarray(data)
            data = _as_tensor(data)
        super(TorchTensorColumn, self).__init__(data=data, *args, **kwargs)

    def __torch_function__(self, func, types, args=(), kwargs=None):
        def _process_arg(arg):
            if isinstance(arg, type(self)):
                return arg.data
            elif isinstance(arg, (List, Tuple)):
                # Specifically use list and tuple because these are
                # expected types for arguments in torch operations.
                return type(arg)([_process_arg(_a) for _a in arg])
            elif isinstance(arg, Mapping):
                # All mappings can be converted to dictionaries
                # when processed by torch operations.
                return {_k: _process_arg(_a) for _k, _a in arg.items()}
            else:
                return arg

        def _process_ret(ret):
            # This function may need to be refactored into an instance method
            # because the from_data implementation is different for each
            # class.
            if isinstance(ret, torch.Tensor):
                if ret.ndim == 0:
                    return ret.clone().detach()
                return self.from_data(ret)
            elif isinstance(ret, (List, Tuple)):
                return type(ret)([_process_arg(_a) for _a in ret])
            elif isinstance(ret, Mapping):
                return {_k: _process_arg(_a) for _k, _a in ret.items()}
            else:
                return ret

        if kwargs is None:
            kwargs = {}
        args = [_process_arg(a) for a in args]
        ret = func(*args, **kwargs)
        return _process_ret(ret)

    def __getattr__(self, name):
        try:
            out = getattr(object.__getattribute__(self, "data"), name)
            if isinstance(out, Callable):
                return getattr_decorator(out)
            else:
                return out
        except AttributeError:
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute '{name}'"
            )

    def _get(self, index, materialize: bool = True):
        index = self.block_class._convert_index(index)
        data = self._data[index]
        if self._is_batch_index(index):
            # a batch index selects multiple rows, so wrap the result in a new
            # column; a scalar index returns the underlying data directly
            return self._clone(data=data)
        else:
            return data

    def _set_batch(self, indices, values):
        self._data[indices] = values

    @classmethod
    def concat(cls, columns: Sequence[TorchTensorColumn]):
        data = torch.cat([c.data for c in columns])
        if issubclass(cls, CloneableMixin):
            return columns[0]._clone(data=data)
        return cls(data)

    @classmethod
    def get_writer(cls, mmap: bool = False, template: Column = None):
        if mmap:
            return NumpyMemmapWriter()
        else:
            return ConcatWriter(template=template, output_type=TorchTensorColumn)

    def _repr_cell(self, index) -> object:
        if len(self.shape) > 1:
            if len(self.shape) == 2 and self.shape[1] < 5:
                return self[index]
            return f"torch.Tensor(shape={self.shape[1:]})"
        else:
            return self[index]

    def _get_default_formatters(self) -> Callable:
        from meerkat.interactive.formatter import (
            NumberFormatterGroup,
            TensorFormatterGroup,
            TextFormatterGroup,
        )

        if len(self) == 0:
            return NumberFormatterGroup()

        if len(self.shape) > 1:
            return TensorFormatterGroup(dtype=str(self.dtype))

        cell = self.data[0]
        if isinstance(cell, np.generic):
            return NumberFormatterGroup(dtype=type(cell.item()).__name__)
        return TextFormatterGroup()
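
    # Usage sketch (illustrative, not part of the library source): the
    # `__torch_function__` protocol lets torch operations accept columns
    # directly, unwrapping column arguments and re-wrapping tensor results.
    #
    #     col = TorchTensorColumn(torch.arange(6).reshape(3, 2))
    #     out = torch.mul(col, 2)   # returns a TorchTensorColumn
    #     torch.sum(col)            # 0-d result, returned as a plain tensor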
    @classmethod
    def from_data(cls, data: Union[Columnable, Column]):
        """Convert data to a TorchTensorColumn."""
        if torch.is_tensor(data):
            return cls(data)
        else:
            return super(TorchTensorColumn, cls).from_data(data)
    def _copy_data(self) -> "torch.Tensor":
        return self._data.clone()

    def _view_data(self) -> object:
        return self._data

    def _write_data(self, path: str) -> None:
        # save all cell data in a single pickle file
        torch.save(self.data, os.path.join(path, "data.pt"))

    @staticmethod
    def _read_data(path: str) -> "torch.Tensor":
        return torch.load(os.path.join(path, "data.pt"))
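
    # Round-trip sketch (illustrative; `path` is a hypothetical writable
    # directory): `_write_data` and `_read_data` persist the underlying tensor
    # with `torch.save` / `torch.load`.
    #
    #     col._write_data(path)                      # writes path/data.pt
    #     data = TorchTensorColumn._read_data(path)  # loads the tensor back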
    def sort(
        self, ascending: Union[bool, List[bool]] = True, kind: str = "quicksort"
    ) -> TorchTensorColumn:
        """Return a sorted view of the column.

        Args:
            ascending (Union[bool, List[bool]]): Whether to sort in ascending
                or descending order. Defaults to True.
            kind (str): The kind of sort to use. Defaults to 'quicksort'.
                Options include 'quicksort', 'mergesort', 'heapsort', 'stable'.

        Returns:
            Column: A view of the column with the sorted data.
        """
        # use argsort to retrieve the ordered indices, then index into the column
        sorted_index = self.argsort(ascending=ascending, kind=kind)
        return self[sorted_index]
    def argsort(
        self, ascending: Union[bool, List[bool]] = True, kind: str = "quicksort"
    ) -> TorchTensorColumn:
        """Return the indices that would sort the column.

        Args:
            ascending (Union[bool, List[bool]]): Whether to sort in ascending
                or descending order. Defaults to True.
            kind (str): The kind of sort to use. Defaults to 'quicksort'.
                Options include 'quicksort', 'mergesort', 'heapsort', 'stable'.

        Returns:
            TensorColumn: A column of indices that would sort the data.

        Raises:
            Exception: If the underlying tensor has more than one dimension.
                Multi-dimensional argsort is not yet implemented.
        """
        try:
            self.size()[1]
        except IndexError:
            # Case 1: the tensor is one-dimensional.
            if not ascending:
                # return indices that sort the data in descending order
                return torch.argsort(self.data, dim=-1, descending=True)
            # return indices that sort the data in ascending order
            return torch.argsort(self.data, dim=-1, descending=False)
        else:
            # Case 2: the tensor has more than one dimension; raise an error.
            raise Exception("No implementation for array with more than one column.")
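
    # Usage sketch (illustrative): sorting is only implemented for 1-D columns,
    # and `sort` composes `argsort` with indexing.
    #
    #     col = TorchTensorColumn([3, 1, 2])
    #     col.argsort()              # -> tensor([1, 2, 0])
    #     col.sort(ascending=False)  # -> column with data [3, 2, 1]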
    def is_equal(self, other: Column) -> bool:
        return (other.__class__ == self.__class__) and (self.data == other.data).all()
    def to_tensor(self) -> "torch.Tensor":
        return self.data
    def to_pandas(self, allow_objects: bool = True) -> pd.Series:
        if len(self.shape) == 1:
            return pd.Series(self.to_numpy())
        elif allow_objects:
            # a multi-dimensional tensor can only be stored in a 1-D series as
            # an object column of row slices
            data = self.to_numpy()
            return pd.Series([data[int(idx)] for idx in range(len(self))])
        else:
            # fall back to the base-class conversion
            return super().to_pandas()
    def to_numpy(self) -> np.ndarray:
        return self.data.detach().cpu().numpy()
    def to_arrow(self) -> pa.Array:
        if len(self.shape) == 1:
            return pa.array(self.to_numpy())
        else:
            return super().to_arrow()
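
    # Conversion sketch (illustrative): 1-D columns convert directly, while
    # columns with more than one dimension fall back to an object series
    # (pandas) or to the base-class behavior (arrow).
    #
    #     col = TorchTensorColumn([1.0, 2.0, 3.0])
    #     col.to_numpy()   # -> array([1., 2., 3.])
    #     col.to_pandas()  # -> pd.Series of floats
    #     col.to_arrow()   # -> pa.Array of type double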
    def mean(
        self, dim: int = None, keepdim: bool = False, *args, **kwargs
    ) -> "torch.Tensor":
        # torch only supports mean for floating point dtypes
        if self.data.dtype not in [
            torch.float,
            torch.double,
            torch.cfloat,
            torch.cdouble,
            torch.half,
            torch.bfloat16,
        ]:
            data = self.data.float()
        else:
            data = self.data
        if dim is not None:
            return data.mean(*args, dim=dim, keepdim=keepdim, **kwargs)
        else:
            # reduce to a Python scalar when no dimension is specified
            return data.mean(*args, **kwargs).numpy().item()
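
# Usage sketch (illustrative): integer columns are cast to float before the
# reduction because torch only defines `mean` for floating-point and complex
# dtypes.
#
#     col = TorchTensorColumn([1, 2, 3, 4])
#     col.mean()       # -> 2.5, a Python scalar
#     col.mean(dim=0)  # -> tensor(2.5000)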