Source code for meerkat.columns.object.base

from __future__ import annotations

import abc
import logging
from typing import Sequence

import cytoolz as tz
import numpy as np
import pandas as pd
from PIL.Image import Image
from yaml.representer import Representer

from meerkat.columns.abstract import Column
from meerkat.mixins.cloneable import CloneableMixin

# Register PyYAML's name-based representer for objects whose type is
# ``abc.ABCMeta`` (i.e. classes created with that metaclass), so that such
# classes are dumped by name.
# NOTE(review): presumably required because column classes end up in
# YAML-serialized metadata — confirm against the code that writes it.
Representer.add_representer(abc.ABCMeta, Representer.represent_name)


# Module-level logger named after this module.
logger = logging.getLogger(__name__)


class ObjectColumn(Column):
    """Column whose cells are arbitrary Python objects.

    The data handed to the constructor is copied into a ``list`` before it
    is passed on to ``Column.__init__``.
    """

    def __init__(
        self,
        data: Sequence = None,
        *args,
        **kwargs,
    ):
        """Create the column.

        Args:
            data: Objects to store. Any non-``None`` value is materialized
                with ``list(data)``; ``None`` is forwarded unchanged.
            *args: Extra positional arguments forwarded to ``Column``.
            **kwargs: Extra keyword arguments forwarded to ``Column``.
        """
        if data is not None:
            data = list(data)
        super().__init__(*args, data=data, **kwargs)

    @classmethod
    def from_list(cls, data: Sequence):
        """Build a column from a sequence of objects."""
        return cls(data=data)

    def batch(
        self,
        batch_size: int = 1,
        drop_last_batch: bool = False,
        collate: bool = True,
        *args,
        **kwargs,
    ):
        """Iterate over the column in consecutive slices.

        Args:
            batch_size: Number of rows per slice.
            drop_last_batch: If true, a trailing slice shorter than
                ``batch_size`` is not yielded.
            collate: If true, each slice is passed through ``self.collate``
                before being yielded; otherwise the raw slice is yielded.
            *args: Accepted for signature compatibility; unused here.
            **kwargs: Accepted for signature compatibility; unused here.

        Yields:
            ``self.collate(self[i : i + batch_size])`` or the slice itself,
            depending on ``collate``.
        """
        total = len(self)
        for start in range(0, total, batch_size):
            stop = start + batch_size
            if drop_last_batch and stop > total:
                # Only the final step can overshoot, so nothing follows it.
                break
            chunk = self[start:stop]
            yield self.collate(chunk) if collate else chunk

    @classmethod
    def concat(cls, columns: Sequence[ObjectColumn]):
        """Concatenate the ``data`` of several columns into one column.

        Args:
            columns: Columns whose ``data`` attributes are chained in order.

        Returns:
            When ``cls`` is a ``CloneableMixin`` subclass and ``columns`` is
            non-empty, ``columns[0]._clone(data=...)``; otherwise
            ``cls.from_list(...)``. An empty ``columns`` therefore produces
            an empty column instead of failing on ``columns[0]``.
        """
        data = list(tz.concat(col.data for col in columns))
        if len(columns) > 0 and issubclass(cls, CloneableMixin):
            # NOTE(review): cloning the first column presumably carries over
            # its non-data state — confirm against ``_clone``.
            return columns[0]._clone(data=data)
        return cls.from_list(data)

    def is_equal(self, other: Column) -> bool:
        """Return whether ``other`` has the same class and equal ``data``."""
        return (self.__class__ == other.__class__) and self.data == other.data

    def _repr_cell(self, index) -> object:
        """Return the object shown for the cell at ``index`` (the cell itself)."""
        return self[index]

    def _get_default_formatters(self):
        """Pick default formatters based on the first element.

        Returns:
            An ``ImageFormatterGroup`` when the column is non-empty and its
            first element is a ``PIL.Image.Image``; otherwise whatever the
            parent class returns.
        """
        # Imported lazily rather than at module top.
        # NOTE(review): presumably to avoid an import cycle — confirm.
        from meerkat.interactive.formatter.image import ImageFormatterGroup

        # Guard against empty columns: there is no first element to inspect.
        if len(self) > 0 and isinstance(self[0], Image):
            return ImageFormatterGroup()
        return super()._get_default_formatters()

    def to_pandas(self, allow_objects: bool = False) -> pd.Series:
        """Return a ``pd.Series`` built from every element of the column.

        Args:
            allow_objects: Unused by this implementation.
        """
        return pd.Series([self[idx] for idx in range(len(self))])

    def to_numpy(self):
        """Return ``np.array(self.data)``.

        The resulting dtype and shape are whatever NumPy infers from the
        stored objects.
        """
        return np.array(self.data)