# Source code for meerkat.columns.scalar.abstract

from __future__ import annotations

from abc import abstractmethod
from typing import TYPE_CHECKING, Any, List, Set, Tuple, Union

import numpy as np
import pandas as pd
import pyarrow as pa
from pandas.core.accessor import CachedAccessor

from meerkat.block.abstract import BlockView
from meerkat.block.arrow_block import ArrowBlock
from meerkat.block.pandas_block import PandasBlock
from meerkat.columns.tensor.abstract import TensorColumn
from meerkat.tools.lazy_loader import LazyLoader

from ..abstract import Column

# Bind ``torch`` through the project's LazyLoader instead of a plain import.
# NOTE(review): presumably this defers the real import until the first
# attribute access -- confirm in meerkat.tools.lazy_loader.
torch = LazyLoader("torch")

if TYPE_CHECKING:
    # Real imports for static type checkers only; never executed at runtime.
    import torch

    from meerkat.dataframe import DataFrame

# Raw data types accepted when constructing a scalar column.
# "torch.Tensor" (the tensor class) replaces "torch.TensorType", which is the
# TorchScript type descriptor and never matches an actual tensor. It stays a
# string forward reference so that building the alias does not touch torch.
ScalarColumnTypes = Union[
    np.ndarray, "torch.Tensor", pd.Series, pa.Array, List, Tuple
]


class StringMethods:
    """Accessor that exposes string functions on a column.

    Every implemented method forwards to the wrapped column's
    ``_dispatch_unary_function`` with ``_namespace="str"``. Explicit
    parameters are forwarded as keyword arguments and any extra ``**kwargs``
    are passed through unchanged.
    """

    def __init__(self, data: Column):
        # The column whose ``_dispatch_unary_function`` performs the work.
        self.column = data

    def _dispatch_str(self, compute_fn: str, **kwargs):
        """Forward ``compute_fn`` to the column under the ``"str"`` namespace."""
        return self.column._dispatch_unary_function(
            compute_fn, _namespace="str", **kwargs
        )

    def len(self, **kwargs) -> ScalarColumn:
        """Dispatch the ``len`` string function."""
        return self._dispatch_str("len", **kwargs)

    # predicate str methods (annotated as returning a ScalarColumn of bools)
    def isalnum(self, **kwargs) -> ScalarColumn:
        """Dispatch the ``isalnum`` string function."""
        return self._dispatch_str("isalnum", **kwargs)

    def isalpha(self, **kwargs) -> ScalarColumn:
        """Dispatch the ``isalpha`` string function."""
        return self._dispatch_str("isalpha", **kwargs)

    def isdecimal(self, **kwargs) -> ScalarColumn:
        """Dispatch the ``isdecimal`` string function."""
        return self._dispatch_str("isdecimal", **kwargs)

    def isdigit(self, **kwargs) -> ScalarColumn:
        """Dispatch the ``isdigit`` string function."""
        return self._dispatch_str("isdigit", **kwargs)

    def islower(self, **kwargs) -> ScalarColumn:
        """Dispatch the ``islower`` string function."""
        return self._dispatch_str("islower", **kwargs)

    def isupper(self, **kwargs) -> ScalarColumn:
        """Dispatch the ``isupper`` string function."""
        return self._dispatch_str("isupper", **kwargs)

    def isnumeric(self, **kwargs) -> ScalarColumn:
        """Dispatch the ``isnumeric`` string function."""
        return self._dispatch_str("isnumeric", **kwargs)

    def isspace(self, **kwargs) -> ScalarColumn:
        """Dispatch the ``isspace`` string function."""
        return self._dispatch_str("isspace", **kwargs)

    def istitle(self, **kwargs) -> ScalarColumn:
        """Dispatch the ``istitle`` string function."""
        return self._dispatch_str("istitle", **kwargs)

    # transform str methods
    def center(self, width: int, fillchar: str = " ", **kwargs) -> ScalarColumn:
        """Dispatch ``center`` with the given ``width`` and ``fillchar``."""
        return self._dispatch_str(
            "center", width=width, fillchar=fillchar, **kwargs
        )

    def capitalize(self, **kwargs) -> ScalarColumn:
        """Dispatch the ``capitalize`` string function."""
        return self._dispatch_str("capitalize", **kwargs)

    def lower(self, **kwargs) -> ScalarColumn:
        """Dispatch the ``lower`` string function."""
        return self._dispatch_str("lower", **kwargs)

    def upper(self, **kwargs) -> ScalarColumn:
        """Dispatch the ``upper`` string function."""
        return self._dispatch_str("upper", **kwargs)

    def swapcase(self, **kwargs) -> ScalarColumn:
        """Dispatch the ``swapcase`` string function."""
        return self._dispatch_str("swapcase", **kwargs)

    def strip(self, to_strip: str = None, **kwargs) -> ScalarColumn:
        """Dispatch ``strip``; ``to_strip`` is forwarded as given."""
        return self._dispatch_str("strip", to_strip=to_strip, **kwargs)

    def lstrip(self, to_strip: str = None, **kwargs) -> ScalarColumn:
        """Dispatch ``lstrip``; ``to_strip`` is forwarded as given."""
        return self._dispatch_str("lstrip", to_strip=to_strip, **kwargs)

    def rstrip(self, to_strip: str = None, **kwargs) -> ScalarColumn:
        """Dispatch ``rstrip``; ``to_strip`` is forwarded as given."""
        return self._dispatch_str("rstrip", to_strip=to_strip, **kwargs)

    def replace(
        self, pat: str, repl: str, n: int = -1, regex: bool = False, **kwargs
    ) -> ScalarColumn:
        """Dispatch ``replace`` with ``pat``, ``repl``, ``n`` and ``regex``."""
        return self._dispatch_str(
            "replace", pat=pat, repl=repl, n=n, regex=regex, **kwargs
        )

    def title(self, **kwargs) -> ScalarColumn:
        """Dispatch the ``title`` string function."""
        return self._dispatch_str("title", **kwargs)

    def split(
        self, pat: str = None, n: int = -1, regex: bool = False, **kwargs
    ) -> "DataFrame":
        """Not implemented here.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError()

    def rsplit(
        self, pat: str = None, n: int = -1, regex: bool = False, **kwargs
    ) -> "DataFrame":
        """Not implemented here.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError()

    # search / match str methods
    def startswith(self, pat: str, **kwargs) -> ScalarColumn:
        """Dispatch ``startswith`` with the pattern ``pat``."""
        return self._dispatch_str("startswith", pat=pat, **kwargs)

    def contains(
        self, pat: str, case: bool = True, regex: bool = True, **kwargs
    ) -> ScalarColumn:
        """Dispatch ``contains`` with ``pat``, ``case`` and ``regex``.

        Extra keyword arguments are forwarded, as every other dispatching
        method of this accessor already does.
        """
        return self._dispatch_str(
            "contains", pat=pat, case=case, regex=regex, **kwargs
        )

    def extract(self, pat: str, **kwargs) -> "DataFrame":
        """Dispatch ``extract`` with the pattern ``pat``."""
        return self._dispatch_str("extract", pat=pat, **kwargs)


class ScalarColumn(Column):
    """Abstract scalar column whose operations are routed through dispatchers.

    Constructing ``ScalarColumn(data)`` directly does not yield a plain
    ``ScalarColumn``: ``__new__`` allocates a ``PandasScalarColumn`` or an
    ``ArrowScalarColumn`` depending on ``backend`` or on the type of ``data``.

    All public operations (aggregations, arithmetic, comparisons, logical
    operators, ``isna``/``isnull``/``unique``) forward to one of the
    ``_dispatch_*`` hooks, which raise ``NotImplementedError`` here.
    """

    # String functions are reachable as ``column.str.<method>(...)``.
    str = CachedAccessor("str", StringMethods)

    def __new__(cls, data: ScalarColumnTypes = None, backend: str = None):
        """Allocate the concrete column class for ``data``.

        Args:
            data: Raw data for the column. ``None`` allocates ``cls`` itself.
            backend: Optional explicit backend, ``"arrow"`` or ``"pandas"``.
                When given, it takes precedence over the type of ``data``.

        Returns:
            A new, uninitialised instance of the selected class, or ``data``
            itself when it is already a ``ScalarColumn``.

        Raises:
            ValueError: If ``backend`` is unknown, or the type of ``data``
                (or of its block, for a ``BlockView``) is not supported.
        """
        # Imported here rather than at module level.
        # NOTE(review): presumably to avoid a circular import -- confirm.
        from .arrow import ArrowScalarColumn
        from .pandas import PandasScalarColumn

        # Subclasses, and calls without data, allocate the requested class.
        if (cls is not ScalarColumn) or (data is None):
            return super().__new__(cls)

        backends = {"arrow": ArrowScalarColumn, "pandas": PandasScalarColumn}
        if backend is not None:
            if backend not in backends:
                raise ValueError(
                    f"Cannot create `ScalarColumn` with backend '{backend}'. "
                    f"Expected one of {list(backends.keys())}"
                )
            else:
                return super().__new__(backends[backend])

        if isinstance(data, BlockView):
            # A block view selects the column class matching its block type.
            if isinstance(data.block, PandasBlock):
                return super().__new__(PandasScalarColumn)
            elif isinstance(data.block, ArrowBlock):
                return super().__new__(ArrowScalarColumn)
            else:
                raise ValueError(
                    f"Cannot create `ScalarColumn` from object of type {type(data)}."
                )

        # ``torch.Tensor`` is the tensor class. The previous check against
        # ``torch.TensorType`` (the TorchScript type descriptor) never matched
        # a tensor, so tensors fell through to the ValueError below.
        if isinstance(data, (np.ndarray, torch.Tensor, pd.Series, List, Tuple)):
            return super().__new__(PandasScalarColumn)
        elif isinstance(data, pa.Array):
            return super().__new__(ArrowScalarColumn)
        elif isinstance(data, TensorColumn) and len(data.shape) == 1:
            return super().__new__(PandasScalarColumn)
        elif isinstance(data, ScalarColumn):
            # Already a scalar column: hand back the same object.
            return data
        else:
            raise ValueError(
                f"Cannot create `ScalarColumn` from object of type {type(data)}."
            )

    def _dispatch_unary_function(
        self, compute_fn: str, _namespace: str = None, **kwargs
    ):
        """Hook for unary functions; not implemented in this class."""
        raise NotImplementedError()

    @property
    def dtype(self) -> Union[pa.DataType, np.dtype]:
        """Data type of the column; not implemented in this class."""
        raise NotImplementedError()

    # aggregation functions
    @abstractmethod
    def _dispatch_aggregation_function(self, compute_fn: str, **kwargs):
        """Hook for aggregation functions; not implemented in this class."""
        raise NotImplementedError()

    def mean(self, skipna: bool = True, **kwargs) -> float:
        """Dispatch the ``mean`` aggregation."""
        return self._dispatch_aggregation_function("mean", skipna=skipna, **kwargs)

    def median(self, skipna: bool = True, **kwargs) -> Any:
        """Dispatch the ``median`` aggregation."""
        return self._dispatch_aggregation_function("median", skipna=skipna, **kwargs)

    def mode(self, **kwargs) -> ScalarColumn:
        """Dispatch the ``mode`` aggregation."""
        return self._dispatch_aggregation_function("mode", **kwargs)

    def var(self, ddof: int = 1, **kwargs) -> Any:
        """Dispatch the ``var`` aggregation with ``ddof``."""
        return self._dispatch_aggregation_function("var", ddof=ddof, **kwargs)

    def std(self, ddof: int = 1, **kwargs) -> Any:
        """Dispatch the ``std`` aggregation with ``ddof``."""
        return self._dispatch_aggregation_function("std", ddof=ddof, **kwargs)

    def min(self, skipna: bool = True, **kwargs) -> Any:
        """Dispatch the ``min`` aggregation."""
        return self._dispatch_aggregation_function("min", skipna=skipna, **kwargs)

    def max(self, skipna: bool = True, **kwargs) -> Any:
        """Dispatch the ``max`` aggregation."""
        return self._dispatch_aggregation_function("max", skipna=skipna, **kwargs)

    def sum(self, skipna: bool = True, **kwargs) -> Any:
        """Dispatch the ``sum`` aggregation."""
        return self._dispatch_aggregation_function("sum", skipna=skipna, **kwargs)

    def product(self, skipna: bool = True, **kwargs) -> Any:
        """Dispatch the ``product`` aggregation."""
        return self._dispatch_aggregation_function("product", skipna=skipna, **kwargs)

    def any(self, skipna: bool = True, **kwargs) -> Any:
        """Dispatch the ``any`` aggregation."""
        return self._dispatch_aggregation_function("any", skipna=skipna, **kwargs)

    def all(self, skipna: bool = True, **kwargs) -> Any:
        """Dispatch the ``all`` aggregation."""
        return self._dispatch_aggregation_function("all", skipna=skipna, **kwargs)

    def unique(self, **kwargs) -> ScalarColumn:
        """Dispatch ``unique`` through the unary-function hook."""
        return self._dispatch_unary_function("unique", **kwargs)

    # arithmetic functions
    def _dispatch_arithmetic_function(
        self, other, compute_fn: str, right: bool, **kwargs
    ):
        """Hook for arithmetic operators; not implemented in this class.

        ``right`` is ``True`` when called from a reflected operator
        (``__radd__`` and friends), ``False`` otherwise.
        """
        raise NotImplementedError()

    def __add__(self, other: ScalarColumn):
        return self._dispatch_arithmetic_function(other, "add", right=False)

    def __radd__(self, other: ScalarColumn):
        return self._dispatch_arithmetic_function(other, "add", right=True)

    def __sub__(self, other: ScalarColumn):
        return self._dispatch_arithmetic_function(other, "sub", right=False)

    def __rsub__(self, other: ScalarColumn):
        return self._dispatch_arithmetic_function(other, "sub", right=True)

    def __mul__(self, other: ScalarColumn):
        return self._dispatch_arithmetic_function(other, "mul", right=False)

    def __rmul__(self, other: ScalarColumn):
        return self._dispatch_arithmetic_function(other, "mul", right=True)

    def __truediv__(self, other: ScalarColumn):
        return self._dispatch_arithmetic_function(other, "truediv", right=False)

    def __rtruediv__(self, other: ScalarColumn):
        return self._dispatch_arithmetic_function(other, "truediv", right=True)

    def __floordiv__(self, other: ScalarColumn):
        return self._dispatch_arithmetic_function(other, "floordiv", right=False)

    def __rfloordiv__(self, other: ScalarColumn):
        return self._dispatch_arithmetic_function(other, "floordiv", right=True)

    def __mod__(self, other: ScalarColumn):
        return self._dispatch_arithmetic_function(other, "mod", right=False)

    def __rmod__(self, other: ScalarColumn):
        return self._dispatch_arithmetic_function(other, "mod", right=True)

    def __pow__(self, other: ScalarColumn):
        return self._dispatch_arithmetic_function(other, "pow", right=False)

    def __rpow__(self, other: ScalarColumn):
        return self._dispatch_arithmetic_function(other, "pow", right=True)

    # comparison functions
    def _dispatch_comparison_function(self, other, compute_fn: str, **kwargs):
        """Hook for comparison operators; not implemented in this class."""
        raise NotImplementedError()

    def __eq__(self, other: ScalarColumn):
        return self._dispatch_comparison_function(other, "eq")

    def __ne__(self, other: ScalarColumn):
        return self._dispatch_comparison_function(other, "ne")

    def __lt__(self, other: ScalarColumn):
        return self._dispatch_comparison_function(other, "lt")

    def __le__(self, other: ScalarColumn):
        return self._dispatch_comparison_function(other, "le")

    def __gt__(self, other: ScalarColumn):
        return self._dispatch_comparison_function(other, "gt")

    def __ge__(self, other: ScalarColumn):
        return self._dispatch_comparison_function(other, "ge")

    # logical functions
    def _dispatch_logical_function(self, other, compute_fn: str, **kwargs):
        """Hook for logical operators; not implemented in this class."""
        raise NotImplementedError()

    def __and__(self, other: ScalarColumn):
        return self._dispatch_logical_function(other, "and")

    def __or__(self, other: ScalarColumn):
        return self._dispatch_logical_function(other, "or")

    def __invert__(self):
        # Unary operator: there is no second operand, so ``None`` is passed.
        return self._dispatch_logical_function(None, "invert")

    def __xor__(self, other: ScalarColumn):
        return self._dispatch_logical_function(other, "xor")

    # containment functions
    def isin(self, values: Union[List, Set], **kwargs) -> ScalarColumn:
        """Membership test against ``values``; not implemented in this class."""
        raise NotImplementedError()

    def isna(self, **kwargs) -> ScalarColumn:
        """Dispatch ``isna`` through the unary-function hook."""
        return self._dispatch_unary_function("isna", **kwargs)

    def isnull(self, **kwargs) -> ScalarColumn:
        """Dispatch ``isnull`` through the unary-function hook."""
        return self._dispatch_unary_function("isnull", **kwargs)