Source code for meerkat.columns.deferred.base

from __future__ import annotations

import logging
import os
import warnings
from typing import Callable, Collection, Sequence, Type, Union

import dill
import numpy as np
import yaml

from meerkat.block.abstract import BlockView
from meerkat.block.deferred_block import DeferredBlock, DeferredCellOp, DeferredOp
from meerkat.cells.abstract import AbstractCell
from meerkat.columns.abstract import Column
from meerkat.errors import ConcatWarning, ImmutableError
from meerkat.tools.lazy_loader import LazyLoader

Image = LazyLoader("PIL.Image")


logger = logging.getLogger(__name__)


[docs]class DeferredCell(AbstractCell): def __init__(self, data: DeferredCellOp): self._data = data @property def data(self) -> object: """Get the data associated with this cell.""" return self._data
[docs] def get(self, *args, **kwargs): return self.data._get()
def __eq__(self, other): return (other.__class__ == self.__class__) and (self.data == other.data) def __repr__(self): name = getattr(self.data.fn, "__qualname__", repr(self.data.fn)) return f"{self.__class__.__qualname__}(fn={name})" def __call__(self): return self.data._get()
[docs]class DeferredColumn(Column): block_class: type = DeferredBlock def __init__( self, data: Union[DeferredOp, BlockView], output_type: Type["Column"] = None, *args, **kwargs, ): self._output_type = output_type super(DeferredColumn, self).__init__(data, *args, **kwargs) def __call__( self, use_ray: bool = False, pbar: bool = False, num_blocks: int = None, blocks_per_window: int = None, batch_size: int = 1, ): from meerkat.ops.map import _materialize return _materialize( self, use_ray=use_ray, pbar=pbar, num_blocks=num_blocks, blocks_per_window=blocks_per_window, batch_size=batch_size, ) def _set(self, index, value): raise ImmutableError("LambdaColumn is immutable.") @property def fn(self) -> Callable: """Subclasses like `ImageColumn` should be able to implement their own version.""" return self.data.fn def _create_cell(self, data: object) -> DeferredCell: return DeferredCell(data=data) def _get(self, index, materialize: bool = False, _data: np.ndarray = None): index = self._translate_index(index) data = self.data._get(index=index, materialize=materialize) if isinstance(index, int): if materialize: return data else: return self._create_cell(data=data) elif isinstance(index, np.ndarray): # support for blocks if materialize: # materialize could change the data in unknown ways, cannot clone return self.convert_to_output_type(data=self.collate(data)) else: return self._clone(data=data) @classmethod def _state_keys(cls) -> Collection: return super()._state_keys() | {"_output_type"} @staticmethod def concat(columns: Sequence[DeferredColumn]): for c in columns: if c.fn != columns[0].fn: warnings.warn( ConcatWarning("Concatenating LambdaColumns with different `fn`.") ) break return columns[0]._clone(data=DeferredOp.concat([c.data for c in columns])) def _write_data(self, path): return self.data.write(os.path.join(path, "data"))
[docs] def is_equal(self, other: Column) -> bool: if other.__class__ != self.__class__: return False return self.data.is_equal(other.data)
@staticmethod def _read_data(path: str): try: return DeferredOp.read(path=os.path.join(path, "data")) except KeyError: # TODO(Sabri): Remove this in a future version, once we no longer need to # support old DataFrames. warnings.warn( "Reading a LambdaColumn stored in a format that will soon be" " deprecated. Please re-write the column to the new format." ) meta = yaml.load( open(os.path.join(path, "data", "meta.yaml")), Loader=yaml.FullLoader, ) if issubclass(meta["dtype"], Column): col = Column.read(os.path.join(path, "data")) else: raise ValueError( "Support for LambdaColumns based on a DataFrame is deprecated." ) state = dill.load(open(os.path.join(path, "state.dill"), "rb")) return DeferredOp( args=[col], kwargs={}, fn=state["fn"], is_batched_fn=False, batch_size=1, ) def _get_default_formatters(self) -> Callable: # materialize a sample into a column from meerkat.interactive.formatter.base import deferred_formatter_group col = self._get(index=slice(0, 1, 1), materialize=True) return deferred_formatter_group(col.formatters) def _repr_cell(self, idx): return self[idx] def convert_to_output_type(self, data: any): if self._output_type is None: from meerkat import column return column(data) return self._output_type(data)