# Source code for meerkat.columns.scalar.pandas

from __future__ import annotations

import abc
import functools
import logging
import numbers
import os
from typing import TYPE_CHECKING, Any, Callable, List, Sequence, Union

import numpy as np
import pandas as pd
import pyarrow as pa
from pandas.core.accessor import CachedAccessor
from pandas.core.arrays.categorical import CategoricalAccessor
from pandas.core.dtypes.common import (
    is_categorical_dtype,
    is_datetime64_dtype,
    is_datetime64tz_dtype,
    is_period_dtype,
    is_timedelta64_dtype,
)
from pandas.core.dtypes.generic import ABCSeries
from pandas.core.indexes.accessors import (
    CombinedDatetimelikeProperties,
    DatetimeProperties,
    PeriodProperties,
    TimedeltaProperties,
)
from yaml.representer import Representer

from meerkat.block.abstract import BlockView
from meerkat.block.pandas_block import PandasBlock
from meerkat.columns.abstract import Column
from meerkat.interactive.formatter.base import BaseFormatter
from meerkat.mixins.aggregate import AggregationError
from meerkat.tools.lazy_loader import LazyLoader

from .abstract import ScalarColumn, StringMethods

# `torch` is wrapped in a LazyLoader — presumably so this module can be imported
# without torch installed and the real import only happens on first use (TODO
# confirm against meerkat.tools.lazy_loader).
torch = LazyLoader("torch")

if TYPE_CHECKING:
    # Imports used only by static type checkers for annotations.
    import torch

    from meerkat.dataframe import DataFrame

# Register a PyYAML representer so objects whose type is `abc.ABCMeta` (i.e.
# classes built with the ABCMeta metaclass) are dumped by name via
# `represent_name` instead of failing to serialize.
Representer.add_representer(abc.ABCMeta, Representer.represent_name)

logger = logging.getLogger(__name__)


def getattr_decorator(fn: Callable):
    """Wrap ``fn`` so that pandas return values become meerkat objects.

    A ``pd.Series`` result is wrapped in a ``PandasScalarColumn``; a
    ``pd.DataFrame`` result is converted with ``DataFrame.from_pandas`` after
    its column names are cast to ``str``. Any other result is passed through.

    Args:
        fn: The callable to wrap (typically a bound pandas method).

    Returns:
        A wrapper with the same signature metadata as ``fn``.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        result = fn(*args, **kwargs)

        if isinstance(result, pd.DataFrame):
            from meerkat import DataFrame

            # meerkat only allows string column names.
            renamed = result.rename(mapper=str, axis="columns")
            return DataFrame.from_pandas(renamed)

        if isinstance(result, pd.Series):
            return PandasScalarColumn(result)

        return result

    return wrapper


class _ReturnColumnMixin:
    """Mixin that converts pandas values returned by attribute access.

    When an attribute is looked up on the mixed-in object:

    * callables are wrapped with ``getattr_decorator`` so their results are
      converted when called,
    * ``pd.Series`` values become ``PandasScalarColumn``,
    * ``pd.DataFrame`` values become a meerkat ``DataFrame`` with column names
      cast to ``str`` (the same conversion ``getattr_decorator`` applies),
    * anything else is returned unchanged.
    """

    def __getattribute__(self, name):
        if name == "__class__":
            # This is needed to avoid _pickle.PicklingError: args[0] from __newobj__
            # args has the wrong class when pickling
            return super().__getattribute__(name)

        # Only the lookup itself belongs in the try: an AttributeError raised
        # while converting the value must not be reported as a missing attribute.
        try:
            attr = super().__getattribute__(name)
        except AttributeError as err:
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{name}'"
            ) from err

        if isinstance(attr, Callable):
            return getattr_decorator(attr)
        if isinstance(attr, pd.Series):
            return PandasScalarColumn(attr)
        if isinstance(attr, pd.DataFrame):
            from meerkat import DataFrame

            # column names must be str in meerkat
            return DataFrame.from_pandas(attr.rename(mapper=str, axis="columns"))
        return attr


# class _MeerkatStringMethods(_ReturnColumnMixin, StringMethods):
#     def __init__(self, data: Column):
#         super().__init__(data.data)


class PandasStringMethods(StringMethods):
    """``.str`` accessor for pandas-backed columns.

    ``split``, ``rsplit`` and ``extract`` expand their result into a meerkat
    ``DataFrame`` with one column per output field (column names cast to
    ``str``); each column is a clone of the source column holding that field.
    """

    def _expand_to_dataframe(self, frame: pd.DataFrame) -> "DataFrame":
        """Convert an expanded pandas DataFrame into a meerkat DataFrame."""
        from meerkat import DataFrame

        return DataFrame(
            {str(name): self.column._clone(data=col) for name, col in frame.items()}
        )

    def split(
        self, pat: str = None, n: int = -1, regex: bool = False, **kwargs
    ) -> "DataFrame":
        """Split each string on ``pat`` into separate columns.

        Args:
            pat: Separator to split on. ``None`` (the default) splits on a
                single space ``" "``, matching this method's historical
                behavior.
            n: Maximum number of splits; ``-1`` means no limit.
            regex: Whether ``pat`` is treated as a regular expression.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            A meerkat ``DataFrame`` with one column per split field.
        """
        # Previously `pat` was ignored and " " was always used; honor it now
        # while keeping " " as the default for backward compatibility.
        sep = " " if pat is None else pat
        return self._expand_to_dataframe(
            self.column.data.str.split(sep, n=n, regex=regex, expand=True)
        )

    def rsplit(
        self, pat: str = None, n: int = -1, regex: bool = False, **kwargs
    ) -> "DataFrame":
        """Split each string on ``pat`` from the right into separate columns.

        Args:
            pat: Separator to split on. ``None`` (the default) splits on a
                single space ``" "``.
            n: Maximum number of splits; ``-1`` means no limit.
            regex: Must be ``False``; regex splitting is not supported.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            A meerkat ``DataFrame`` with one column per split field.

        Raises:
            NotImplementedError: If ``regex`` is ``True``.
        """
        if regex is True:
            raise NotImplementedError("regex=True is not supported for rsplit")

        sep = " " if pat is None else pat
        return self._expand_to_dataframe(
            self.column.data.str.rsplit(sep, n=n, expand=True)
        )

    def extract(self, pat: str, **kwargs) -> "DataFrame":
        """Extract the capture groups of regex ``pat`` into separate columns.

        Args:
            pat: Regular expression with capture groups.
            **kwargs: Forwarded to ``pandas.Series.str.extract``.

        Returns:
            A meerkat ``DataFrame`` with one column per capture group.
        """
        return self._expand_to_dataframe(
            self.column.data.str.extract(pat, expand=True, **kwargs)
        )


class _MeerkatDatetimeProperties(_ReturnColumnMixin, DatetimeProperties):
    """pandas ``DatetimeProperties`` whose pandas results are returned as meerkat objects."""


class _MeerkatTimedeltaProperties(_ReturnColumnMixin, TimedeltaProperties):
    """pandas ``TimedeltaProperties`` whose pandas results are returned as meerkat objects."""


class _MeerkatPeriodProperties(_ReturnColumnMixin, PeriodProperties):
    """pandas ``PeriodProperties`` whose pandas results are returned as meerkat objects."""


class _MeerkatCategoricalAccessor(_ReturnColumnMixin, CategoricalAccessor):
    """pandas ``.cat`` accessor whose pandas results are returned as meerkat objects."""


class _MeerkatCombinedDatetimelikeProperties(CombinedDatetimelikeProperties):
    """Dispatcher used as the ``.dt`` accessor factory.

    Calling this class never returns an instance of it: ``__new__`` inspects
    the dtype of ``data`` (unwrapping categoricals to their categories' dtype)
    and returns a ``_MeerkatDatetimeProperties``, ``_MeerkatTimedeltaProperties``
    or ``_MeerkatPeriodProperties`` object.
    """

    def __new__(cls, data: pd.Series):
        """Return the datetime-like properties object matching ``data``'s dtype.

        Raises:
            TypeError: If ``data`` is not series-like (``ABCSeries`` check).
            AttributeError: If ``data`` does not hold datetime, timedelta or
                period values.
        """
        # CombinedDatetimelikeProperties isn't really instantiated. Instead
        # we need to choose which parent (datetime or timedelta) is
        # appropriate. Since we're checking the dtypes anyway, we'll just
        # do all the validation here.

        if not isinstance(data, ABCSeries):
            raise TypeError(
                f"cannot convert an object of type {type(data)} to a datetimelike index"
            )

        # Dtype checks use isinstance / numpy `kind` instead of the
        # `is_*_dtype` helpers: `is_categorical_dtype`, `is_datetime64tz_dtype`
        # and `is_period_dtype` are deprecated in pandas >= 2.1 and warn on
        # every call.
        orig = data if isinstance(data.dtype, pd.CategoricalDtype) else None
        if orig is not None:
            # Dispatch on the dtype of the categories, not the categorical.
            data = data._constructor(
                orig.array,
                name=orig.name,
                copy=False,
                dtype=orig._values.categories.dtype,
            )

        dtype = data.dtype
        is_np_dtype = isinstance(dtype, np.dtype)
        if isinstance(dtype, pd.DatetimeTZDtype) or (is_np_dtype and dtype.kind == "M"):
            # naive datetime64 and tz-aware datetimes share one properties class
            return _MeerkatDatetimeProperties(data, orig)
        if is_np_dtype and dtype.kind == "m":
            return _MeerkatTimedeltaProperties(data, orig)
        if isinstance(dtype, pd.PeriodDtype):
            return _MeerkatPeriodProperties(data, orig)
        raise AttributeError("Can only use .dt accessor with datetimelike values")


class PandasScalarColumn(
    ScalarColumn,
    np.lib.mixins.NDArrayOperatorsMixin,
):
    """A scalar column backed by a ``pandas.Series``.

    The column's data is either a ``pd.Series`` (whose index is reset to a
    contiguous index on assignment) or a ``BlockView`` into a ``PandasBlock``.
    Attributes not defined on the column are looked up on the underlying
    series (see ``__getattr__``).
    """

    block_class: type = PandasBlock

    # Operand types (besides PandasScalarColumn) accepted by __array_ufunc__.
    # NOTE: must stay above the `str` accessor below, which shadows the builtin
    # `str` inside this class body.
    _HANDLED_TYPES = (np.ndarray, numbers.Number, str)

    dt = CachedAccessor("dt", _MeerkatCombinedDatetimelikeProperties)
    cat = CachedAccessor("cat", _MeerkatCategoricalAccessor)
    str = CachedAccessor("str", PandasStringMethods)
    # str = CachedAccessor("str", _MeerkatStringMethods)
    # plot = CachedAccessor("plot", pandas.plotting.PlotAccessor)
    # sparse = CachedAccessor("sparse", SparseAccessor)

    def _set_data(self, data: object):
        """Normalize ``data`` into a ``pd.Series`` or a ``PandasBlock`` view.

        Raises:
            ValueError: If ``data`` is a ``BlockView`` whose block is not a
                ``PandasBlock``.
        """
        if isinstance(data, PandasScalarColumn):
            # unpack series if it is a PandasScalarColumn
            data = data.data

        if isinstance(data, BlockView):
            if not isinstance(data.block, PandasBlock):
                raise ValueError(
                    "Cannot create `PandasSeriesColumn` from a `BlockView` not "
                    "referencing a `PandasBlock`."
                )
        elif isinstance(data, pd.Series):
            # Force the index to be contiguous so that comparisons between different
            # pandas series columns are always possible.
            data = data.reset_index(drop=True)
        else:
            data = pd.Series(data)

        super(PandasScalarColumn, self)._set_data(data)

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        """Apply a numpy ufunc to the unwrapped series data.

        Returns ``NotImplemented`` if any input or ``out`` operand is neither a
        ``PandasScalarColumn`` nor one of ``_HANDLED_TYPES``. Results are
        wrapped in ``type(self)``.
        """
        out = kwargs.get("out", ())
        for x in inputs + out:
            # Only support operations with instances of _HANDLED_TYPES.
            # Use ArrayLike instead of type(self) for isinstance to
            # allow subclasses that don't override __array_ufunc__ to
            # handle ArrayLike objects.
            if not isinstance(x, self._HANDLED_TYPES + (PandasScalarColumn,)):
                return NotImplemented

        # Defer to the implementation of the ufunc on unwrapped values.
        inputs = tuple(
            x.data if isinstance(x, PandasScalarColumn) else x for x in inputs
        )
        if out:
            kwargs["out"] = tuple(
                x.data if isinstance(x, PandasScalarColumn) else x for x in out
            )
        result = getattr(ufunc, method)(*inputs, **kwargs)

        if type(result) is tuple:
            # multiple return values
            return tuple(type(self)(x) for x in result)  # pragma: no cover
        elif method == "at":
            # no return value
            return None  # pragma: no cover
        else:
            # one return value
            return type(self)(result)

    def __getattr__(self, name):
        """Fall back to attributes of the underlying series.

        Callable attributes are wrapped with ``getattr_decorator`` so pandas
        results are returned as meerkat objects.
        """
        if name == "__getstate__" or name == "__setstate__":
            # for pickle, it's important to raise an attribute error if __getstate__
            # or __setstate__ is called. Without this, pickle will use the __setstate__
            # and __getstate__ of the underlying pandas Series
            raise AttributeError()
        try:
            out = getattr(object.__getattribute__(self, "data"), name)
            if isinstance(out, Callable):
                return getattr_decorator(out)
            else:
                return out
        except AttributeError:
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute '{name}'"
            )

    @classmethod
    def from_array(cls, data: np.ndarray, *args, **kwargs):
        """Create a column from an array-like ``data``."""
        # NOTE(review): if `args` is non-empty and `data` is also the first
        # parameter of `__init__`, this raises TypeError ("multiple values for
        # argument 'data'") — confirm the intended signature.
        return cls(data=data, *args, **kwargs)

    def _get(self, index, materialize: bool = True):
        index = self.block_class._convert_index(index)

        data = self._data.iloc[index]
        if self._is_batch_index(index):
            # batch indices produce a new column; a single index returns the cell
            return self._clone(data=data)
        else:
            return data

    def _set_cell(self, index, value):
        self._data.iloc[index] = value

    def _set_batch(self, indices, values):
        self._data.iloc[indices] = values

    @classmethod
    def concat(cls, columns: Sequence[PandasScalarColumn]):
        """Concatenate ``columns`` into a clone of the first column."""
        data = pd.concat([c.data for c in columns])
        return columns[0]._clone(data=data)

    def _write_data(self, path: str) -> None:
        data_path = os.path.join(path, "data.pd")
        self.data.to_pickle(data_path)

    @staticmethod
    def _read_data(
        path: str,
    ):
        data_path = os.path.join(path, "data.pd")

        # Load in the data.
        # NOTE: `pd.read_pickle` unpickles arbitrary objects; only load trusted paths.
        return pd.read_pickle(data_path)

    def _repr_cell(self, index) -> object:
        return self[index]

    def _get_default_formatters(self) -> BaseFormatter:
        """Pick a formatter group based on the column's dtype / first cell."""
        # can't implement this as a class level property because then it will treat
        # the formatter as a method
        from meerkat.interactive.formatter import (
            BooleanFormatterGroup,
            NumberFormatterGroup,
            TextFormatterGroup,
        )

        if len(self) == 0:
            return super()._get_default_formatters()

        # `self.dtype == pd.StringDtype` compared an instance with the class and
        # was always False; use isinstance so "string" columns get text formatting.
        if self.dtype == object or isinstance(self.dtype, pd.StringDtype):
            return TextFormatterGroup()

        cell = self[0]
        if isinstance(cell, np.generic):
            if isinstance(cell, np.bool_):
                return BooleanFormatterGroup()
            return NumberFormatterGroup(dtype=type(cell.item()).__name__)

        return super()._get_default_formatters()

    def _is_valid_primary_key(self):
        return self.data.is_unique

    def _keyidx_to_posidx(self, keyidx: Any) -> int:
        """Return the position of the first cell equal to ``keyidx``.

        Raises:
            KeyError: If no cell equals ``keyidx``.
        """
        # TODO(sabri): when we implement indices, we should use them here if we have
        # one
        where_result = np.where(self.data == keyidx)

        if len(where_result[0]) == 0:
            raise KeyError(f"keyidx {keyidx} not found in column.")
        posidx = where_result[0][0]
        return int(posidx)

    def _keyidxs_to_posidxs(self, keyidxs: Sequence[Any]) -> np.ndarray:
        # FIXME: this implementation is very slow. This should be done with indices
        return np.array([self._keyidx_to_posidx(keyidx) for keyidx in keyidxs])

    def sort(
        self, ascending: Union[bool, List[bool]] = True, kind: str = "quicksort"
    ) -> PandasScalarColumn:
        """Return a sorted view of the column.

        Args:
            ascending (Union[bool, List[bool]]): Whether to sort in ascending or
                descending order. A list must contain exactly one bool.
                Defaults to True.
            kind (str): The kind of sort to use. Defaults to 'quicksort'. Options
                include 'quicksort', 'mergesort', 'heapsort', 'stable'.

        Return:
            PandasScalarColumn: A view of the column with the sorted data.
        """
        # calls argsort() function to retrieve ordered indices
        sorted_index = self.argsort(ascending, kind)
        return self[sorted_index]

    def argsort(
        self, ascending: bool = True, kind: str = "quicksort"
    ) -> pd.Series:
        """Return the positions that would sort the column.

        Args:
            ascending (bool): Whether to sort in ascending or descending order.
                Defaults to True.
            kind (str): The kind of sort to use. Defaults to 'quicksort'. Options
                include 'quicksort', 'mergesort', 'heapsort', 'stable'.

        Return:
            pd.Series: Integer positions into the column, in sorted order, with
            missing values placed last. The series shares the column's index
            and name.

        Raises:
            Exception: If the column has more than one dimension.
        """
        num_columns = len(self.shape)

        # Raise error if array has more than one column
        if num_columns > 1:
            raise Exception("No implementation for array with more than one column.")

        data = self.data
        # Sort on a positional index so the resulting labels are positions.
        # This replaces `(-1 * data).argsort()` for descending order, which is
        # wrong for strings (-1 * "a" == ""), raises for datetimes and loses
        # precision for large unsigned ints, and avoids `Series.argsort`'s -1
        # placeholders for missing values (which `self[...]` would treat as the
        # last row). pandas keeps ties in original order for stable `kind`s.
        positions = (
            data.reset_index(drop=True)
            .sort_values(ascending=ascending, kind=kind)
            .index.to_numpy()
        )
        return pd.Series(positions, index=data.index, name=data.name)

    def to_tensor(self) -> "torch.Tensor":
        """Use `column.to_tensor()` instead of `torch.tensor(column)`, which is very
        slow."""
        dtype = self.data.values.dtype
        if not np.issubdtype(dtype, np.number):
            raise ValueError(
                f"Cannot convert `PandasSeriesColumn` with dtype={dtype} to tensor."
            )

        # TODO (Sabri): understand why `torch.tensor(column)` is so slow
        return torch.tensor(self.data.values)

    def to_numpy(self) -> np.ndarray:
        """Return the underlying series' ``values``."""
        return self.values

    def to_pandas(self, allow_objects: bool = False) -> pd.Series:
        """Return the underlying series with a fresh contiguous index."""
        return self.data.reset_index(drop=True)

    def to_arrow(self) -> pa.Array:
        return pa.array(self.data.values)

    def is_equal(self, other: Column) -> bool:
        """Return True if ``other`` is the same column type with equal values.

        Values are compared element-wise with ``==`` (so NaN != NaN); see
        ``equals`` for NaN-aware comparison.
        """
        if other.__class__ != self.__class__:
            return False
        # Element-wise `==` on arrays of different lengths raises in numpy >= 1.25.
        if len(self.data) != len(other.data):
            return False
        return bool((self.data.values == other.data.values).all())

    def to_json(self) -> List[Any]:
        return self.data.tolist()

    @property
    def dtype(self) -> Any:
        return self.data.dtype

    def equals(self, other: Column) -> bool:
        """Return True if ``other`` is the same column type and ``Series.equals``."""
        if other.__class__ != self.__class__:
            return False
        return self.data.equals(other.data)

    def _dispatch_aggregation_function(self, compute_fn: str, **kwargs):
        return getattr(self.data, compute_fn)(**kwargs)

    def mean(self, skipna: bool = True, **kwargs):
        """Return the mean of the column.

        Raises:
            AggregationError: If pandas cannot compute a mean for this dtype.
        """
        try:
            return self.data.mean(skipna=skipna, **kwargs)
        except TypeError as err:
            raise AggregationError(
                "Cannot apply mean aggregation to Pandas Series with "
                f"dtype '{self.data.dtype}'."
            ) from err

    def _dispatch_arithmetic_function(
        self, other: ScalarColumn, compute_fn: str, right: bool, **kwargs
    ):
        if isinstance(other, Column):
            assert isinstance(other, PandasScalarColumn)
            other = other.data

        if right:
            # reflected operator, e.g. "add" -> "__radd__"
            compute_fn = f"r{compute_fn}"

        return self._clone(
            data=getattr(self.data, f"__{compute_fn}__")(other, **kwargs)
        )

    def _dispatch_comparison_function(
        self, other: ScalarColumn, compute_fn: str, **kwargs
    ):
        if isinstance(other, Column):
            assert isinstance(other, PandasScalarColumn)
            other = other.data

        return self._clone(
            data=getattr(self.data, f"__{compute_fn}__")(other, **kwargs)
        )

    def _dispatch_logical_function(
        self, other: ScalarColumn, compute_fn: str, **kwargs
    ):
        if isinstance(other, Column):
            assert isinstance(other, PandasScalarColumn)
            other = other.data

        if other is None:
            # unary logical operator (no right-hand operand)
            return self._clone(data=getattr(self.data, f"__{compute_fn}__")(**kwargs))

        return self._clone(
            data=getattr(self.data, f"__{compute_fn}__")(other, **kwargs)
        )

    def isin(self, values: Sequence[Any]) -> "PandasScalarColumn":
        return self._clone(data=self.data.isin(values))

    def _dispatch_unary_function(
        self, compute_fn: str, _namespace: str = None, **kwargs
    ):
        if _namespace is not None:
            obj = getattr(self.data, _namespace)
        else:
            obj = self.data
        return self._clone(data=getattr(obj, compute_fn)(**kwargs))
# Alias so code referring to `PandasSeriesColumn` (the name still used in this
# module's error messages) resolves to `PandasScalarColumn`; presumably kept for
# backward compatibility.
PandasSeriesColumn = PandasScalarColumn