Source code for meerkat.ops.embed

from typing import Callable, Union

import PIL

import meerkat as mk
from meerkat.tools.lazy_loader import LazyLoader
from meerkat.tools.utils import choose_device

from .clip import clip
from .encoder import Encoder
from .registry import encoders
from .robust import robust
from .transformers import transformers

bit = LazyLoader(".bit")

torch = LazyLoader("torch")

__all__ = ["clip", "bit", "transformers", "robust", "embed"]
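
# The lazily loaded modules above are only imported on first attribute access, so
# heavy optional dependencies (e.g. torch) are not pulled in until an encoder
# actually needs them. A minimal sketch of the usage pattern as observed in this
# module (illustrative, not an exhaustive description of LazyLoader's API):
#
#   torch = LazyLoader("torch")
#   torch.cuda.is_available()  # the real `import torch` is triggered here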


def infer_modality(col: mk.Column):
    if isinstance(col, mk.ImageColumn):
        return "image"
    elif isinstance(col, (mk.ScalarColumn, str)):
        return "text"
    elif isinstance(col, mk.ArrowScalarColumn):
        import pyarrow

        if isinstance(col[0], pyarrow.lib.StringScalar):
            return "text"

    raise ValueError(
        f"Cannot infer modality from column of type {type(col)}. "
        "Please pass in the modality argument explicitly "
        "with `modality='text'` or `modality='image'`."
    )
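
# Usage sketch (hypothetical column names; assumes an image column and a
# string-valued ScalarColumn, mirroring the docstring example below):
#
#   import meerkat as mk
#   df = mk.datasets.get("imagenette")
#   infer_modality(df["img"])    # -> "image"
#   infer_modality(df["label"])  # -> "text" (any ScalarColumn is treated as text)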


# @cache(params=["encoder", "modality", ""])
def embed(
    data: Union[mk.DataFrame, mk.Column, str, PIL.Image.Image],
    input: str = None,
    encoder: Union[str, Encoder] = "clip",
    modality: str = None,
    out_col: str = None,
    device: Union[int, str] = "auto",
    mmap_dir: str = None,
    num_workers: int = 0,
    batch_size: int = 128,
    pbar: bool = True,
    **kwargs,
) -> Union[mk.DataFrame, mk.Column]:
    """Embed a column of data with an encoder from the encoder registry.

    Examples
    --------
    Suppose you have an image dataset (e.g. Imagenette, CIFAR-10) loaded into a
    `Meerkat DataFrame <https://github.com/robustness-gym/meerkat>`_. You can embed
    the images in the dataset with CLIP using a code snippet like:

    .. code-block:: python

        import meerkat as mk

        df = mk.datasets.get("imagenette")

        df = mk.embed(
            data=df,
            input="img",
            encoder="clip"
        )

    Args:
        data (Union[mk.DataFrame, mk.Column]): A dataframe or column containing the
            data to embed.
        input (str, optional): If ``data`` is a dataframe, the name of the column
            to embed. If ``data`` is a column, this parameter is ignored.
            Defaults to None.
        encoder (Union[str, Encoder], optional): Name of the encoder to use. List
            supported encoders with ``meerkat.ops.embed.encoders``. Defaults to
            "clip". Alternatively, pass an
            :class:`~meerkat.ops.embed.encoder.Encoder` object containing a custom
            encoder.
        modality (str, optional): The modality of the data to be embedded. Defaults
            to None, in which case the modality is inferred from the type of the
            input column.
        out_col (str, optional): The name of the column where the embeddings are
            stored. Defaults to None, in which case it is ``"{encoder}({input})"``.
        device (Union[int, str], optional): The device on which the encoder is run.
            Defaults to "auto", in which case a GPU is used if one is available.
        mmap_dir (str, optional): The path to the directory where a memory-mapped
            file containing the embeddings will be written. Defaults to None, in
            which case the embeddings are not memmapped.
        num_workers (int, optional): Number of worker processes used to load the
            data from disk. Defaults to 0.
        batch_size (int, optional): Size of the batches used to encode the data.
            Defaults to 128.
        pbar (bool, optional): Whether to display a progress bar. Defaults to True.
        **kwargs: Additional keyword arguments are passed to the encoder. To see
            supported arguments for each encoder, see the encoder documentation
            (e.g. :func:`~meerkat.ops.embed.clip`).

    Returns:
        Union[mk.DataFrame, mk.Column]: A view of ``data`` with a new column
        containing the embeddings, named according to the ``out_col`` parameter.
        If ``data`` is a column, the embedding column is returned directly.
    """
    col = data if isinstance(data, mk.Column) else data[input]

    if len(data) == 0:
        return data

    device = choose_device(device)

    if out_col is None:
        out_col = f"{encoder}({input})"

    if modality is None:
        modality = infer_modality(col=col)

    # TODO(karan): a hacky way to handle error with processing
    # pyarrow.lib.StringScalars in a mk.ArrowArrayColumn
    if modality == "text" and isinstance(col, mk.ArrowScalarColumn):
        col = mk.ScalarColumn(col.to_pandas())

    if isinstance(encoder, str):
        encoder = encoders.get(encoder, device=device, **kwargs)

    if isinstance(encoder, dict):
        if modality not in encoder:
            raise ValueError(
                f'Encoder "{encoder}" does not support modality "{modality}".'
            )
        encoder = encoder[modality]

    out = _embed(
        col=col,
        encode=encoder.encode,
        preprocess=encoder.preprocess,
        collate=encoder.collate,
        device=device,
        mmap_dir=mmap_dir,
        num_workers=num_workers,
        batch_size=batch_size,
        pbar=pbar,
    )

    if isinstance(data, mk.DataFrame):
        data[out_col] = out
        return data
    else:
        return out
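
# Usage sketch for `embed` (illustrative only; the dataset and column names mirror
# the docstring example above):
#
#   import meerkat as mk
#   df = mk.datasets.get("imagenette")
#   df = mk.embed(data=df, input="img", encoder="clip")
#   # The embeddings land in a new column named "clip(img)" (the default `out_col`).
#
#   # Passing a column instead of a dataframe returns the embedding column directly:
#   emb = mk.embed(data=df["img"], encoder="clip")
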
def _embed(
    col: mk.Column,
    encode: Callable,
    preprocess: Callable,
    collate: Callable,
    device: int = None,
    mmap_dir: str = None,
    num_workers: int = 0,
    batch_size: int = 128,
    pbar: bool = True,
):
    def _encode(x):
        out = encode(_prepare_input(x))
        if torch.is_tensor(out):
            # Move embeddings off the GPU and convert to numpy so they can be
            # stored in a regular Meerkat column.
            out = out.cpu().detach().numpy()
        return out

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    if preprocess is not None:
        # Lazily apply the preprocessing function; it only runs when the deferred
        # column is materialized during `map` below.
        embed_input = col.defer(preprocess)
    else:
        embed_input = col

    if collate is not None:
        embed_input.collate_fn = collate

    def _prepare_input(x):
        if isinstance(x, mk.Column):
            x = x.data
        if torch.is_tensor(x):
            x = x.to(device)
        return x

    with torch.no_grad():
        out = embed_input.map(
            _encode,
            pbar=pbar,
            is_batched_fn=True,
            batch_size=batch_size,
            # num_workers=num_workers,
            # mmap=mmap_dir is not None,
            # mmap_path=None
            # if mmap_dir is None
            # else os.path.join(mmap_dir, "emb_mmap.npy"),
            # flush_size=128,
        )
    return out
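
# Minimal sketch of the pipeline `_embed` builds (the toy lambda encoder and the
# numpy import are hypothetical stand-ins, not part of meerkat):
#
#   import numpy as np
#   import meerkat as mk
#
#   col = mk.ScalarColumn(["a short string", "another string"])
#   out = _embed(
#       col=col,
#       encode=lambda batch: np.zeros((len(batch), 4)),  # stand-in for a real model
#       preprocess=None,
#       collate=None,
#       device="cpu",
#       batch_size=2,
#   )
#   # `out` holds one embedding row per input, produced batch-by-batch under
#   # `torch.no_grad()` with optional deferred preprocessing.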