# Source code for meerkat.columns.abstract

from __future__ import annotations

import abc
import logging
import pathlib
import reprlib
from ast import Dict
from copy import copy
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    List,
    Mapping,
    Optional,
    Sequence,
    Type,
    Union,
)

import numpy as np
import pandas as pd
import pyarrow as pa

import meerkat.config
from meerkat.errors import ConversionError
from meerkat.interactive.graph.marking import unmarked
from meerkat.interactive.node import NodeMixin
from meerkat.mixins.aggregate import AggregateMixin
from meerkat.mixins.blockable import BlockableMixin
from meerkat.mixins.cloneable import CloneableMixin
from meerkat.mixins.collate import CollateMixin
from meerkat.mixins.deferable import DeferrableMixin
from meerkat.mixins.identifiable import IdentifiableMixin
from meerkat.mixins.indexing import MaterializationMixin
from meerkat.mixins.inspect_fn import FunctionInspectorMixin
from meerkat.mixins.io import ColumnIOMixin
from meerkat.mixins.reactifiable import ReactifiableMixin
from meerkat.provenance import ProvenanceMixin, capture_provenance
from meerkat.tools.lazy_loader import LazyLoader
from meerkat.tools.utils import convert_to_batch_column_fn, translate_index

if TYPE_CHECKING:
    import torch

    from meerkat.interactive.formatter.base import FormatterGroup

torch = LazyLoader("torch")  # noqa: F811

logger = logging.getLogger(__name__)


class Column(
    AggregateMixin,
    BlockableMixin,
    CloneableMixin,
    CollateMixin,
    ColumnIOMixin,
    FunctionInspectorMixin,
    IdentifiableMixin,
    DeferrableMixin,
    MaterializationMixin,
    NodeMixin,
    ProvenanceMixin,
    ReactifiableMixin,
    abc.ABC,
):
    """An abstract class for Meerkat columns.

    The values of a column live in ``_data``. Subclasses customize cell/batch
    access (``_get_cell``, ``_get_batch``, ``_set_cell``, ...) and the
    conversion methods (``to_numpy``, ``to_pandas``, ...), which in this base
    class either index ``_data`` directly or raise.
    """

    _data: Sequence = None

    # Path to a log directory
    logdir: pathlib.Path = pathlib.Path.home() / "meerkat/"

    # Create a directory.
    # NOTE: this runs when the class body executes (i.e. at import time).
    logdir.mkdir(parents=True, exist_ok=True)

    _self_identifiable_group: str = "columns"

    def __init__(
        self,
        data: Sequence = None,
        collate_fn: Callable = None,
        formatters: FormatterGroup = None,
        *args,
        **kwargs,
    ):
        """
        Args:
            data (Sequence, optional): The values stored by the column.
                Defaults to None.
            collate_fn (Callable, optional): Forwarded to the mixin
                initializers (presumably consumed by ``CollateMixin``).
                Defaults to None.
            formatters (FormatterGroup, optional): Formatters used to display
                the column. When None, ``_get_default_formatters()`` is used.
                Defaults to None.
        """
        # Assign to data
        self._set_data(data)

        super(Column, self).__init__(
            collate_fn=collate_fn,
            *args,
            **kwargs,
        )

        self._formatters = (
            formatters if formatters is not None else self._get_default_formatters()
        )

    @unmarked()
    def __repr__(self):
        # Fixed: the closing parenthesis of "column(" was previously missing.
        return (
            f"column({reprlib.repr([x for x in self[:10]])}, "
            f"backend={type(self).__name__})"
        )

    @unmarked()
    def __str__(self):
        return reprlib.repr([x for x in self[:10]])

    def streamlit(self):
        """Return the ``_repr_pandas_`` representation for display."""
        return self._repr_pandas_()

    def _set_data(self, data):
        """Store ``data``, unpacking block views when the column is blockable."""
        if self.is_blockable():
            data = self._unpack_block_view(data)
        self._data = data

    def _is_valid_primary_key(self):
        """Subclasses should implement checks for ensuring that the column
        could be used as a valid primary key.

        Specifically, the check should ensure that the values in the
        column are unique. If the check does not pass, returns False. If
        the subclass has not implemented this method, returns False.
        """
        return False

    def _keyidx_to_posidx(self, keyidx: Any) -> int:
        """Get the posidx of the first occurrence of the given keyidx. Raise a
        key error if the keyidx is not found.

        Args:
            keyidx: The keyidx to search for.

        Returns:
            The posidx of the first occurrence of the given keyidx.
        """
        raise NotImplementedError()

    def _keyidxs_to_posidxs(self, keyidxs: Sequence[Any]) -> np.ndarray:
        """Get the posidxs of the given keyidxs. Raise a key error if any of
        the keyidxs are not found.

        Args:
            keyidxs: The keyidxs to search for.

        Returns:
            The posidxs of the given keyidxs.
        """
        raise NotImplementedError()

    @property
    def data(self):
        """Get the underlying data."""
        return self._data

    @data.setter
    def data(self, value):
        self._set_data(value)

    @property
    def metadata(self):
        return {}

    @classmethod
    def _state_keys(cls) -> set:
        """List of attributes that describe the state of the object."""
        return {"_collate_fn", "_formatters"}

    def _get_cell(self, index: int, materialize: bool = True) -> Any:
        """Get a single cell from the column.

        Args:
            index (int): This is an index into the ALL rows, not just visible
                rows. In other words, we assume that the index passed in has
                already been remapped via `_remap_index`, if
                `self.visible_rows` is not `None`.
            materialize (bool, optional): Materialize and return the object.
                This argument is used by subclasses of `Column` that hold data
                in an unmaterialized format. Defaults to True.
        """
        return self._data[index]

    def _get_batch(self, indices: np.ndarray, materialize: bool = True) -> Column:
        """Get a batch of cells from the column.

        Args:
            indices (np.ndarray): Indices into ALL rows, not just visible rows.
                In other words, we assume that the indices passed in have
                already been remapped via `_remap_index`, if
                `self.visible_rows` is not `None`.
            materialize (bool, optional): Materialize and return the objects.
                This argument is used by subclasses of `Column` that hold data
                in an unmaterialized format. Defaults to True.

        Returns:
            The cells fetched one-by-one with ``_get_cell`` and passed through
            ``self.collate``.
        """
        # Both materialize settings previously took identical branches.
        return self.collate(
            [self._get_cell(int(i), materialize=materialize) for i in indices]
        )

    def _get(self, index, materialize: bool = True, _data: np.ndarray = None):
        """Resolve ``index`` and fetch the corresponding cell(s).

        Args:
            index: Any index accepted by ``translate_index``.
            materialize: Forwarded to ``_get_cell`` / ``_get_batch``.
            _data: Already-fetched data. When given, the lookup is skipped.

        Returns:
            The cell for an integer index, or a clone of this column holding
            the selected rows for an array index. Any other translated index
            type falls through and returns None.
        """
        index = self._translate_index(index)
        if isinstance(index, int):
            if _data is None:
                _data = self._get_cell(index, materialize=materialize)
            return _data

        elif isinstance(index, np.ndarray):
            # support for blocks
            if _data is None:
                _data = self._get_batch(index, materialize=materialize)
            return self._clone(data=_data)

    def __getitem__(self, index):
        return self._get(index, materialize=False)

    def _set_cell(self, index, value):
        self._data[index] = value

    def _set_batch(self, indices: np.ndarray, values):
        for index, value in zip(indices, values):
            self._set_cell(int(index), value)

    def _set(self, index, value):
        """Write ``value`` at ``index``.

        Raises:
            ValueError: If the translated index is neither an int nor a
                sequence/array.
        """
        index = self._translate_index(index)
        if isinstance(index, int):
            self._set_cell(index, value)
        elif isinstance(index, (Sequence, np.ndarray)):
            self._set_batch(index, value)
        else:
            raise ValueError

    def __setitem__(self, index, value):
        self._set(index, value)

    def _is_batch_index(self, index):
        # np.ndarray indexed with a tuple of length 1 does not return an np.ndarray
        # so we match this behavior
        return not (
            isinstance(index, int) or (isinstance(index, tuple) and len(index) == 1)
        )

    def _translate_index(self, index):
        return translate_index(index, length=len(self))

    @staticmethod
    def _convert_to_batch_fn(
        function: Callable, with_indices: bool, materialize: bool = True, **kwargs
    ) -> Callable:
        return convert_to_batch_column_fn(
            function=function,
            with_indices=with_indices,
            materialize=materialize,
            **kwargs,
        )

    @unmarked()
    def __len__(self):
        """Return ``full_length()`` after emitting the reactive ``len`` warning."""
        self._reactive_warning("len", "col")
        return self.full_length()

    def full_length(self):
        """Return ``len(self._data)``, or 0 when no data is set."""
        if self._data is None:
            return 0
        return len(self._data)

    @unmarked()
    def _repr_cell_(self, index) -> object:
        raise NotImplementedError

    def _get_default_formatters(self) -> "FormatterGroup":
        from meerkat.interactive.formatter import TextFormatter
        from meerkat.interactive.formatter.base import FormatterGroup

        # by default all object should have a `str` representation
        return FormatterGroup(base=TextFormatter())

    @property
    def formatters(self) -> "FormatterGroup":
        return self._formatters

    @formatters.setter
    def formatters(self, formatters: Union["FormatterGroup", Dict]):
        """Set the formatters. A plain dict is wrapped in a ``FormatterGroup``."""
        # Imported locally: the module-level import only exists under
        # TYPE_CHECKING, so the name is not bound at runtime.
        from meerkat.interactive.formatter.base import FormatterGroup

        # Fixed: previously unpacked the builtin `dict` type instead of the
        # argument. Existing FormatterGroup instances are stored unchanged.
        if isinstance(formatters, dict) and not isinstance(formatters, FormatterGroup):
            formatters = FormatterGroup(**formatters)
        self._formatters = formatters

    def format(self, formatters: "FormatterGroup") -> Column:
        """Return a view of the column with ``formatters`` merged over the
        current ones."""
        new_col = self.view()
        new_col.formatters = self.formatters.copy()
        new_col.formatters.update(formatters)
        return new_col

    @unmarked()
    def _repr_pandas_(self, max_rows: int = None) -> pd.Series:
        """Build a ``pd.Series`` of cell representations plus the base HTML
        formatter.

        Returns:
            A ``(series, formatter)`` tuple, where ``formatter`` is None when
            there is no base formatter.
        """
        if max_rows is None:
            max_rows = meerkat.config.display.max_rows

        # NOTE(review): this calls `_repr_cell`, which is not defined in this
        # class (only `_repr_cell_` is); presumably provided elsewhere — confirm.
        if len(self) > max_rows:
            col = pd.Series(
                [self._repr_cell(idx) for idx in range(max_rows // 2)]
                # presumably a placeholder row for the elided middle section
                + [self._repr_cell(0)]
                + [
                    self._repr_cell(idx)
                    for idx in range(len(self) - max_rows // 2, len(self))
                ]
            )
        else:
            col = pd.Series([self._repr_cell(idx) for idx in range(len(self))])

        # TODO: if the objects have a _repr_html_ method, we should be able to use
        # that instead of explicitly relying on the column having a formatter.
        return (
            col,
            self.formatters["base"]
            if self.formatters["base"] is None
            else self.formatters["base"].html,
        )

    @unmarked()
    def _repr_html_(self, max_rows: int = None):
        # pd.Series objects do not implement _repr_html_
        if max_rows is None:
            max_rows = meerkat.config.display.max_rows

        if len(self) > max_rows:
            pd_index = np.concatenate(
                (
                    np.arange(max_rows // 2),
                    np.zeros(1),
                    np.arange(len(self) - max_rows // 2, len(self)),
                ),
            )
        else:
            pd_index = np.arange(len(self))

        col_name = f"({self.__class__.__name__})"
        col, formatter = self._repr_pandas_(max_rows=max_rows)
        df = col.to_frame(name=col_name)
        df = df.set_index(pd_index.astype(int))
        return df.to_html(
            max_rows=max_rows,
            formatters={col_name: formatter},
            escape=False,
        )

    def map(
        self,
        function: Callable,
        is_batched_fn: bool = False,
        batch_size: int = 1,
        inputs: Union[Mapping[str, str], Sequence[str]] = None,
        outputs: Union[Mapping[any, str], Sequence[str]] = None,
        output_type: Union[Mapping[str, Type["Column"]], Type["Column"]] = None,
        materialize: bool = True,
        **kwargs,
    ) -> Optional[Union[Dict, List, Column]]:
        """Apply ``function`` over the column by delegating to
        ``meerkat.ops.map.map``."""
        from meerkat.ops.map import map

        return map(
            data=self,
            function=function,
            is_batched_fn=is_batched_fn,
            batch_size=batch_size,
            inputs=inputs,
            outputs=outputs,
            output_type=output_type,
            materialize=materialize,
            **kwargs,
        )

    @capture_provenance()
    def filter(
        self,
        function: Callable,
        with_indices=False,
        input_columns: Optional[Union[str, List[str]]] = None,
        is_batched_fn: bool = False,
        batch_size: Optional[int] = 1,
        drop_last_batch: bool = False,
        num_workers: Optional[int] = 0,
        materialize: bool = True,
        **kwargs,
    ) -> Optional[Column]:
        """Filter the elements of the column using a function.

        ``function`` must return booleans. The rows for which it returns True
        are kept. An empty column is returned unchanged. ``input_columns`` is
        currently unused.
        """
        # Return if `self` has no examples
        if not len(self):
            logger.info("Dataset empty, returning it .")
            return self

        # Get some information about the function
        function_properties = self._inspect_function(
            function,
            with_indices,
            is_batched_fn=is_batched_fn,
            materialize=materialize,
            **kwargs,
        )
        assert function_properties.bool_output, "function must return boolean."

        # Map to get the boolean outputs and indices
        logger.info("Running `filter`, a new dataset will be returned.")
        outputs = self.map(
            function=function,
            with_indices=with_indices,
            is_batched_fn=is_batched_fn,
            batch_size=batch_size,
            drop_last_batch=drop_last_batch,
            num_workers=num_workers,
            materialize=materialize,
            **kwargs,
        )
        indices = np.where(outputs)[0]
        return self[indices]

    def sort(
        self, ascending: Union[bool, List[bool]] = True, kind: str = "quicksort"
    ) -> Column:
        """Return a sorted view of the column.

        Args:
            ascending (Union[bool, List[bool]]): Whether to sort in ascending or
                descending order. If a list, must be the same length as `by`.
                Defaults to True.
            kind (str): The kind of sort to use. Defaults to 'quicksort'.
                Options include 'quicksort', 'mergesort', 'heapsort', 'stable'.

        Return:
            Column: A view of the column with the sorted data.
        """
        raise NotImplementedError

    def argsort(
        self, ascending: Union[bool, List[bool]] = True, kind: str = "quicksort"
    ) -> Column:
        """Return indices that would sort the column.

        Args:
            ascending (Union[bool, List[bool]]): Whether to sort in ascending or
                descending order. If a list, must be the same length as `by`.
                Defaults to True.
            kind (str): The kind of sort to use. Defaults to 'quicksort'.
                Options include 'quicksort', 'mergesort', 'heapsort', 'stable'.

        Return:
            Column: The indices that would sort the column.
        """
        raise NotImplementedError

    def sample(
        self,
        n: int = None,
        frac: float = None,
        replace: bool = False,
        weights: Union[str, np.ndarray] = None,
        random_state: Union[int, np.random.RandomState] = None,
    ) -> Column:
        """Select a random sample of rows from Column. Roughly equivalent to
        ``sample`` in Pandas
        https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.sample.html.

        Args:
            n (int): Number of samples to draw. If `frac` is specified, this
                parameter should not be passed. Defaults to 1 if `frac` is not
                passed.
            frac (float): Fraction of rows to sample. If `n` is specified, this
                parameter should not be passed.
            replace (bool): Sample with or without replacement. Defaults to False.
            weights (np.ndarray): Weights to use for sampling. If `None`
                (default), the rows will be sampled uniformly. If a numpy array,
                the sample will be weighted accordingly. If weights do not sum
                to 1 they will be normalized to sum to 1.
            random_state (Union[int, np.random.RandomState]): Random state or
                seed to use for sampling.

        Return:
            Column: A random sample of rows from the Column.
        """
        from meerkat import sample

        return sample(
            data=self,
            n=n,
            frac=frac,
            replace=replace,
            weights=weights,
            random_state=random_state,
        )

    def append(self, column: Column) -> None:
        # TODO(Sabri): implement a naive `ComposedColumn` for generic append and
        # implement specific ones for ListColumn, NumpyColumn etc.
        raise NotImplementedError

    @staticmethod
    def concat(columns: Sequence[Column]) -> None:
        # TODO(Sabri): implement a naive `ComposedColumn` for generic append and
        # implement specific ones for ListColumn, NumpyColumn etc.
        raise NotImplementedError

    def is_equal(self, other: Column) -> bool:
        """Test whether this column and ``other`` are equal.

        Not implemented in the base class.

        Args:
            other (Column): The column to compare against.
        """
        raise NotImplementedError()

    def batch(
        self,
        batch_size: int = 1,
        drop_last_batch: bool = False,
        collate: bool = True,
        num_workers: int = 0,
        materialize: bool = True,
        *args,
        **kwargs,
    ):
        """Batch the column.

        Args:
            batch_size: integer batch size
            drop_last_batch: drop the last batch if its smaller than batch_size
            collate: whether to collate the returned batches
            num_workers: forwarded to ``torch.utils.data.DataLoader``
            materialize: iterate over ``self.mz`` instead of ``self``

        Returns:
            A ``torch.utils.data.DataLoader`` yielding batches of data.
        """
        if (
            self._get_batch.__func__ == Column._get_batch
            and self._get.__func__ == Column._get
        ):
            # Default indexing: let the DataLoader batch and collate per row.
            return torch.utils.data.DataLoader(
                self.mz if materialize else self,
                batch_size=batch_size,
                collate_fn=self.collate if collate else lambda x: x,
                drop_last=drop_last_batch,
                num_workers=num_workers,
                *args,
                **kwargs,
            )

        # Custom batch indexing: feed whole index arrays through the sampler so
        # each dataset access fetches one full batch.
        num_rows = len(self)
        indices = np.arange(num_rows)
        batch_indices = [
            indices[i : i + batch_size]
            for i in range(0, num_rows, batch_size)
            if not (drop_last_batch and i + batch_size > num_rows)
        ]
        return torch.utils.data.DataLoader(
            self.mz if materialize else self,
            sampler=batch_indices,
            batch_size=None,
            batch_sampler=None,
            # The incomplete trailing batch was already removed above; torch
            # rejects drop_last=True together with batch_size=None.
            drop_last=False,
            num_workers=num_workers,
            *args,
            **kwargs,
        )

    @classmethod
    def get_writer(cls, mmap: bool = False, template: Column = None):
        """Return a ``ConcatWriter`` for this column type.

        Raises:
            ValueError: If ``mmap`` is requested.
        """
        from meerkat.writers.concat_writer import ConcatWriter

        if mmap:
            raise ValueError("Memmapping not supported with this column type.")
        else:
            return ConcatWriter(output_type=cls, template=template)

    Columnable = Union[Sequence, np.ndarray, pd.Series, "torch.Tensor"]

    @classmethod
    # @capture_provenance()
    def from_data(cls, data: Union[Columnable, Column]):
        """Convert data to a meerkat column using the appropriate Column
        type."""
        return column(data)

    def head(self, n: int = 5) -> Column:
        """Get the first `n` examples of the column."""
        return self[:n]

    def tail(self, n: int = 5) -> Column:
        """Get the last `n` examples of the column.

        ``n == 0`` yields an empty column. Negative ``n`` keeps the slicing
        behaviour ``self[-n:]``.
        """
        if n == 0:
            # `self[-0:]` is `self[0:]`, i.e. the whole column.
            return self[len(self) :]
        return self[-n:]

    def to_pandas(self, allow_objects: bool = False) -> pd.Series:
        """Convert the column to a Pandas Series.

        If the column cannot be converted to a Pandas Series, this method will
        raise a `~meerkat.errors.ConversionError`.

        Args:
            allow_objects (bool): Unused by this base implementation.

        Returns:
            pd.Series: The column as a Pandas Series.
        """
        raise ConversionError(
            f"Cannot convert column of type {type(self)} to Pandas Series."
        )

    def to_arrow(self) -> pa.Array:
        """Convert the column to an Arrow Array.

        If the column cannot be converted to an Arrow Array, this method will
        raise a `~meerkat.errors.ConversionError`.

        Returns:
            pa.Array: The column as an Arrow Array.
        """
        raise ConversionError(
            f"Cannot convert column of type {type(self)} to Arrow Array."
        )

    def to_torch(self) -> "torch.Tensor":
        """Convert the column to a PyTorch Tensor.

        If the column cannot be converted to a PyTorch Tensor, this method will
        raise a `~meerkat.errors.ConversionError`.

        Returns:
            torch.Tensor: The column as a PyTorch Tensor.
        """
        raise ConversionError(
            f"Cannot convert column of type {type(self)} to PyTorch Tensor."
        )

    def to_numpy(self) -> np.ndarray:
        """Convert the column to a Numpy array.

        If the column cannot be converted to a Numpy array, this method will
        raise a `~meerkat.errors.ConversionError`.

        Returns:
            np.ndarray: The column as a Numpy array.
        """
        raise ConversionError(
            f"Cannot convert column of type {type(self)} to Numpy array."
        )

    def __array__(self, dtype=None, copy=None) -> np.ndarray:
        """Convert the data to a numpy array.

        ``dtype`` and ``copy`` follow the NumPy ``__array__`` protocol. NumPy 2
        passes ``copy`` and warns when it is not accepted. With both left at
        None, the result is exactly ``self.to_numpy()``.
        """
        arr = self.to_numpy()
        if dtype is not None:
            arr = np.asarray(arr, dtype=dtype)
        if copy:
            arr = np.array(arr, copy=True)
        return arr

    def to_json(self) -> dict:
        """Convert the column to a JSON object.

        If the column cannot be converted to a JSON object, this method will
        raise a `~meerkat.errors.ConversionError`.

        Returns:
            dict: The column as a JSON object.
        """
        raise ConversionError(
            f"Cannot convert column of type {type(self)} to JSON object."
        )

    def _copy_data(self) -> object:
        """Return a shallow copy of ``_data``."""
        return copy(self._data)

    def _view_data(self) -> object:
        """Return ``_data`` itself (no copy)."""
        return self._data

    @property
    def is_mmap(self):
        return False
def infer_column_type(data: Sequence) -> Type[Column]:
    """Pick the Column subclass that should hold ``data``.

    The checks run in this order, and the first match wins:

    * an existing ``Column`` returns its own type;
    * ``pd.Series`` returns ``ScalarColumn``;
    * ``pa.Array`` / ``pa.ChunkedArray`` returns ``ArrowScalarColumn``;
    * a torch tensor returns ``TorchTensorColumn``;
    * ``np.ndarray`` returns ``ScalarColumn`` when 1-D, otherwise
      ``NumPyTensorColumn``;
    * any other ``Sequence`` is dispatched on its first element (numpy,
      torch, or Python/numpy scalar); empty or unrecognized sequences
      return ``ObjectColumn``.

    Raises:
        ValueError: If ``data`` matches none of the above.
    """
    if isinstance(data, Column):
        return type(data)

    from .scalar.abstract import ScalarColumn

    if isinstance(data, pd.Series):
        return ScalarColumn

    if isinstance(data, (pa.Array, pa.ChunkedArray)):
        from .scalar.arrow import ArrowScalarColumn

        return ArrowScalarColumn

    if torch.is_tensor(data):
        from .tensor.torch import TorchTensorColumn

        # FIXME: Once we have a torch scalar column we should use that here
        # if len(data.shape) == 1:
        #     return ScalarColumn(data.cpu().detach().numpy())
        return TorchTensorColumn

    if isinstance(data, np.ndarray):
        if data.ndim != 1:
            from .tensor.numpy import NumPyTensorColumn

            return NumPyTensorColumn
        from .scalar.pandas import ScalarColumn

        return ScalarColumn

    if not isinstance(data, Sequence):
        raise ValueError(f"Cannot create column out of data of type {type(data)}")

    # Everything below dispatches on the first element of a generic sequence.
    from .tensor.numpy import NumPyTensorColumn

    nonempty = len(data) != 0

    if nonempty and isinstance(data[0], (np.ndarray, NumPyTensorColumn)):
        return NumPyTensorColumn

    from .tensor.torch import TorchTensorColumn

    if nonempty and (
        isinstance(data[0], TorchTensorColumn) or torch.is_tensor(data[0])
    ):
        return TorchTensorColumn

    if nonempty and isinstance(data[0], (str, int, float, bool, np.generic)):
        from .scalar.pandas import ScalarColumn

        return ScalarColumn

    from .object.base import ObjectColumn

    return ObjectColumn
def column(data: Sequence, scalar_backend: str = None) -> Column:
    """Create a Meerkat column from ``data``.

    The concrete column class is chosen by ``infer_column_type`` from the type
    and structure of ``data``.

    Args:
        data: The values to wrap. An existing ``Column`` is returned as-is.
        scalar_backend: Passed as ``backend=`` when the inferred type is
            ``ScalarColumn``; ignored for every other column type.

    Returns:
        Column: ``data`` itself if it is already a column, otherwise a new
        instance of the inferred column type.
    """
    if isinstance(data, Column):
        # TODO: Need to make this a view but should decide where to do it exactly
        return data  # .view()

    from .scalar.abstract import ScalarColumn

    column_type = infer_column_type(data)
    if column_type != ScalarColumn:
        return column_type(data)
    return ScalarColumn(data, backend=scalar_backend)