Source code for meerkat.interactive.app.src.lib.component.core.match

import ast
from dataclasses import dataclass
from typing import TYPE_CHECKING, Union

import numpy as np
from fastapi import HTTPException

from meerkat.dataframe import DataFrame
from meerkat.interactive.app.src.lib.component.abstract import Component
from meerkat.interactive.endpoint import Endpoint, EndpointProperty, endpoint
from meerkat.interactive.event import EventInterface
from meerkat.interactive.graph import Store, reactive

if TYPE_CHECKING:
    from meerkat.ops.embed.encoder import Encoder

try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal

# Binary operators allowed in a match query, keyed by the name of the ``ast``
# operator class (``ast.Add`` -> "Add", ``ast.Mult`` -> "Mult", ...).
_SUPPORTED_BIN_OPS = {
    "Add": lambda x, y: x + y,
    "Sub": lambda x, y: x - y,
    "Mult": lambda x, y: x * y,
    "Div": lambda x, y: x / y,
    "Pow": lambda x, y: x**y,
}

# Functions callable by plain name inside a match query.
_SUPPORTED_CALLS = {
    # Concatenate the evaluated arguments along axis 1.
    "concat": lambda *args: np.concatenate(args, axis=1),
}


def parse_query(query: str, encoder: Union[str, "Encoder"] = "clip"):
    """Parse a match query expression and evaluate it.

    The query is a single Python expression built from:

    * string literals, which are embedded with ``encoder`` via ``mk.embed``;
    * numeric literals, which are used as plain scalars (e.g. as weights);
    * the binary operators ``+ - * / **``;
    * calls to ``concat(...)`` with positional arguments only.

    Args:
        query: The query expression. Surrounding whitespace is ignored.
        encoder: The encoder (name or instance) used to embed string literals.

    Returns:
        The result of evaluating the expression.

    Raises:
        SyntaxError: If ``query`` is not a valid Python expression.
        ValueError: If the expression uses an unsupported construct.
    """
    # ``ast.parse`` (unlike ``eval``) rejects leading whitespace in eval mode,
    # so strip it first; stripping the ends never changes an expression.
    return _parse_query(ast.parse(query.strip(), mode="eval").body, encoder=encoder)


def _parse_query(node: ast.AST, encoder: Union[str, "Encoder"]):
    """Recursively evaluate one node of a parsed match query.

    Args:
        node: The AST node to evaluate.
        encoder: The encoder used to embed string literals.

    Returns:
        The evaluated value of ``node``.

    Raises:
        ValueError: If the node, or an operator/function it uses, is not
            supported.
    """
    if isinstance(node, ast.BinOp):
        op_name = node.op.__class__.__name__
        # Check the operator before evaluating the operands, so an invalid
        # query fails fast without computing any embeddings.
        if op_name not in _SUPPORTED_BIN_OPS:
            raise ValueError(
                f"Unsupported operator '{op_name}' in query. "
                f"Supported operators: {sorted(_SUPPORTED_BIN_OPS)}"
            )
        return _SUPPORTED_BIN_OPS[op_name](
            _parse_query(node.left, encoder=encoder),
            _parse_query(node.right, encoder=encoder),
        )
    elif isinstance(node, ast.Call):
        # Only plain names such as ``concat(...)`` can be called. Attribute
        # calls (``np.concatenate(...)``) have no ``id`` and are rejected.
        func_name = node.func.id if isinstance(node.func, ast.Name) else None
        if func_name is None or func_name not in _SUPPORTED_CALLS:
            shown = f"'{func_name}'" if func_name is not None else ast.dump(node.func)
            raise ValueError(
                f"Unsupported function {shown} in query. "
                f"Supported functions: {sorted(_SUPPORTED_CALLS)}"
            )
        if node.keywords:
            # Keyword arguments were previously dropped silently.
            raise ValueError(
                f"Function '{func_name}' does not accept keyword arguments"
            )
        return _SUPPORTED_CALLS[func_name](
            *[_parse_query(arg, encoder=encoder) for arg in node.args]
        )
    elif isinstance(node, ast.Constant):
        value = node.value
        # Numeric literals are scalars (weights, exponents) rather than text.
        # ``bool`` is excluded because it subclasses ``int``.
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            return value

        import meerkat as mk

        return mk.embed(
            data=mk.column([value]),
            encoder=encoder,
            num_workers=0,
            pbar=False,
        )
    else:
        node_repr = node.id if hasattr(node, "id") else node
        if isinstance(node_repr, str):
            node_repr = f"'{node_repr}'"
        raise ValueError(f"Unsupported query {node_repr}")


@endpoint()
def get_match_schema(df: DataFrame):
    """Describe the columns of ``df`` that a match can be run against.

    Only ``mk.TensorColumn`` columns with a 2-D shape qualify.

    Args:
        df: The DataFrame to inspect.

    Returns:
        A ``SchemaResponse`` with the qualifying columns' infos and the
        DataFrame's id and row count.
    """
    import meerkat as mk
    from meerkat.interactive.api.routers.dataframe import (
        SchemaResponse,
        _get_column_infos,
    )

    matchable = []
    for column_name, column in df.items():
        # TODO: We should know the provenance of embeddings and where they came from,
        # to explicitly check whether the encoder will match it in size.
        if isinstance(column, mk.TensorColumn) and len(column.shape) == 2:
            matchable.append(column_name)

    column_infos = _get_column_infos(df, matchable)
    return SchemaResponse(id=df.id, columns=column_infos, nrows=len(df))


def _calc_image_query(df: DataFrame, locs: list, against: str):
    """Average the ``against`` embeddings of the rows at ``locs``.

    ``set_criterion`` uses this for both the positive and the negative
    selections; the caller applies any weighting.

    Args:
        df: The DataFrame holding the embeddings.
        locs: Row locators passed to ``df.loc``.
        against: Name of the embedding column to average.

    Returns:
        The mean of ``df.loc[locs][against]`` taken over axis 0 (presumably
        the row axis, i.e. averaging across the selected rows).
    """
    selected_rows = df.loc[locs]
    return selected_rows[against].mean(axis=0)


@endpoint()
def set_criterion(
    df: DataFrame,
    query: str,
    against: str,
    criterion: Store,
    positives: list = None,
    negatives: list = None,
    encoder: Union[str, "Encoder"] = None,
):
    """Match a query (and/or selected examples) against a DataFrame column.

    The query embedding combines up to three parts:

    * the parsed ``query`` expression (see ``parse_query``), if non-blank;
    * minus ``0.25`` times the mean ``against`` embedding of the
      ``negatives`` rows, if any;
    * plus the mean ``against`` embedding of the ``positives`` rows, if any.

    The resulting :class:`MatchCriterion` is stored in ``criterion``. When
    ``against`` is not ``None``, each row is then scored as
    ``(df[against] @ query_embedding.T).squeeze()``. The scores are written to
    ``df[<criterion name>]``, and ``df.set(df)`` is called, presumably so
    dependents see the new column.

    Args:
        df: The DataFrame to score.
        query: The query expression; may be empty or blank when only
            selections are used.
        against: Name of the embedding column to match against.
        criterion: Store that receives the new :class:`MatchCriterion`.
        positives: Row locators of positive examples, if any.
        negatives: Row locators of negative examples, if any.
        encoder: Encoder passed to ``parse_query``.

    Returns:
        The ``criterion`` store. It is returned untouched when the query is
        blank and there are no positive or negative selections.

    Raises:
        HTTPException: With status 400 if ``df`` is not a DataFrame.
    """
    if not isinstance(df, DataFrame):
        raise HTTPException(
            status_code=400, detail="`match` expects a ref containing a dataframe"
        )

    # A whitespace-only query has nothing to parse (the parser would raise a
    # SyntaxError), so treat it the same as an empty query.
    has_query = bool(query and query.strip())
    if not has_query and not negatives and not positives:
        return criterion

    # Start from a scalar so the selections alone can build the embedding.
    query_embedding = 0.0
    if has_query:
        query_embedding = parse_query(query, encoder=encoder)
    if negatives:
        # Negative examples are down-weighted relative to positive ones.
        query_embedding = query_embedding - 0.25 * _calc_image_query(
            df, negatives, against
        )
    if positives:
        query_embedding = query_embedding + _calc_image_query(
            df, positives, against
        )

    match_criterion = MatchCriterion(
        query=query,
        against=against,
        query_embedding=query_embedding,
        name=f"match({against}, {query})",
        positives=positives,
        negatives=negatives,
    )
    criterion.set(match_criterion)

    # Scoring is only possible once there is a column to score against.
    if against is not None:
        data_embedding = df[against]
        scores = (data_embedding @ match_criterion.query_embedding.T).squeeze()
        df[match_criterion.name] = scores
        df.set(df)

    return criterion


@dataclass
class MatchCriterion:
    """The state of a match: what was queried and against which column.

    An "empty" criterion, with ``against``, ``query`` and ``name`` all
    ``None``, is used before any match is set and after a reset.

    Attributes:
        against: Name of the column the query is matched against.
        query: The raw query string.
        name: Name of the column that the match scores are written to.
        query_embedding: The combined query embedding, if computed.
        positives: Row locators of positive examples, if any.
        negatives: Row locators of negative examples, if any.
    """

    against: Union[str, None]
    query: Union[str, None]
    name: Union[str, None]
    query_embedding: Union[np.ndarray, None] = None
    positives: Union[list, None] = None
    negatives: Union[list, None] = None


class OnGetMatchSchemaMatch(EventInterface):
    """Event interface for ``Match.get_match_schema``.

    Declares no event-specific arguments.
    """


class OnMatchMatch(EventInterface):
    """Event interface for ``Match.on_match``.

    Declares a single ``criterion`` argument holding the match criterion.
    """

    criterion: MatchCriterion


# Module-level alias for the ``get_match_schema`` endpoint. ``Match.__init__``
# has a parameter with the same name, which shadows this function inside that
# method, so ``__init__`` reaches the endpoint through this alias instead.
_get_match_schema = get_match_schema


class Match(Component):
    """Component that matches a query against an embedding column of ``df``.

    The match endpoint is ``set_criterion`` with this component's DataFrame,
    encoder, criterion store and selection stores pre-bound (optionally
    composed with a user-supplied ``on_match``). It stores a
    :class:`MatchCriterion` in ``self.criterion`` and writes a score column to
    ``df``. Call :meth:`set_selection` to also allow matching on selected
    positive/negative example rows.
    """

    df: DataFrame
    # Name of the column in ``df`` to match against.
    against: str
    # The query text.
    text: str = ""
    title: str = "Match"
    # Whether selection-based matching is enabled; ``set_selection`` sets it.
    enable_selection: bool = False
    # Whether ``on_reset_selection`` also clears the current criterion.
    reset_criterion: bool = False

    # TODO: Revisit this, how to deal with endpoint interfaces when there is composition
    # and positional arguments
    on_match: EndpointProperty[OnMatchMatch] = None
    get_match_schema: EndpointProperty[OnGetMatchSchemaMatch] = None
    on_clickminus: Endpoint = None
    on_unclickminus: Endpoint = None
    on_clickplus: Endpoint = None
    on_unclickplus: Endpoint = None
    on_reset: Endpoint = None

    def __init__(
        self,
        df: DataFrame = None,
        *,
        against: str,
        text: str = "",
        encoder: Union[str, "Encoder"] = "clip",
        title: str = "Match",
        enable_selection: bool = False,
        reset_criterion: bool = False,
        on_match: EndpointProperty = None,
        get_match_schema: EndpointProperty = None,
        on_clickminus: Endpoint = None,
        on_unclickminus: Endpoint = None,
        on_clickplus: Endpoint = None,
        on_unclickplus: Endpoint = None,
        on_reset: Endpoint = None,
    ):
        """
        Args:
            df: The DataFrame.
            against: The column to match against.
            text: The query text.
            encoder: The encoder to use. It is not forwarded to the base
                component; it is only bound into the match endpoint.
            title: The title of the component.
            enable_selection: Whether to enable selection for image-based matching.
            reset_criterion: Whether to reset the criterion when on_reset is called.
            on_match: The endpoint to call when the match button is clicked.
                This endpoint will be called after ``self.criterion`` is set.
            get_match_schema: Forwarded to the base component, but then
                overwritten with the module's ``get_match_schema`` endpoint
                bound to ``df``.
            on_clickminus, on_unclickminus, on_clickplus, on_unclickplus,
            on_reset: Forwarded to the base component. ``set_selection``
                replaces them with this component's selection endpoints.
        """
        super().__init__(
            df=df,
            against=against,
            text=text,
            title=title,
            enable_selection=enable_selection,
            reset_criterion=reset_criterion,
            on_match=on_match,
            get_match_schema=get_match_schema,
            on_clickminus=on_clickminus,
            on_unclickminus=on_unclickminus,
            on_clickplus=on_clickplus,
            on_unclickplus=on_unclickplus,
            on_reset=on_reset,
        )

        # we do not add the against or the query to the partial, because we don't
        # want them to be maintained on the backend
        # if they are maintained on the backend, then a store update dispatch will
        # run on every key stroke

        # NOTE(review): this discards any ``get_match_schema`` passed by the caller.
        self.get_match_schema = _get_match_schema.partial(df=self.df)

        # Current match criterion; starts out "empty" (all identifying fields None).
        self._criterion: MatchCriterion = Store(
            MatchCriterion(against=None, query=None, name=None),
            backend_only=True,
        )
        # Selections bound into the match endpoint as ``negatives`` / ``positives``.
        self.negative_selection = Store([], backend_only=True)
        self.positive_selection = Store([], backend_only=True)
        # Which list external selection changes are copied into; "" means neither.
        self._mode: Store[
            Literal[
                "set_negative_selection",
                "set_positive_selection",
                "",
            ]
        ] = Store("")

        # The match endpoint. ``query`` and ``against`` are left unbound (see the
        # comment above) and must be supplied when the endpoint is called.
        on_match = set_criterion.partial(
            df=self.df,
            encoder=encoder,
            criterion=self._criterion,
            positives=self.positive_selection,
            negatives=self.negative_selection,
        )

        # Chain a user-supplied ``on_match`` after ``set_criterion``.
        if self.on_match is not None:
            on_match = on_match.compose(self.on_match)

        self.on_match = on_match

    def set_selection(self, selection: Store[list]):
        """Enable selection-based matching driven by an external selection.

        Sets ``enable_selection``, creates the per-mode stash stores, wires the
        plus/minus/reset endpoints to this component's selection endpoints, and
        registers the reactive handler on ``selection``.

        Args:
            selection: Store holding the current selection (presumably row
                locators usable with ``df.loc`` -- confirm against callers).
        """
        self.external_selection = selection
        self.enable_selection.set(True)

        # Stashes of what was selected the last time each mode was active;
        # restored into the external selection when that mode is re-entered.
        self._positive_selection = Store([], backend_only=True)
        self._negative_selection = Store([], backend_only=True)

        self.on_clickminus = self.on_set_negative_selection.partial(self)
        self.on_clickplus = self.on_set_positive_selection.partial(self)
        self.on_unclickminus = self.on_unset_negative_selection.partial(self)
        self.on_unclickplus = self.on_unset_positive_selection.partial(self)
        self.on_reset = self.on_reset_selection.partial(self)

        # Call the ``@reactive()`` handler once with the store, presumably so it
        # is re-run whenever ``external_selection`` changes.
        self.on_external_selection_change(self.external_selection)

    @endpoint()
    def on_reset_selection(self):
        """Reset all the selections.

        Clears both selections, the external selection and both stashes, and
        leaves any mode. If ``reset_criterion`` is set, also replaces the
        criterion with an empty one.
        """
        self.negative_selection.set([])
        self.positive_selection.set([])
        self.external_selection.set([])
        self._mode.set("")
        self._positive_selection.set([])
        self._negative_selection.set([])
        if self.reset_criterion:
            self._criterion.set(MatchCriterion(against=None, query=None, name=None))

    @reactive()
    def on_external_selection_change(self, external_selection):
        """Copy the external selection into the list for the active mode.

        Does nothing when no mode is active (``_mode`` is "").
        """
        if self._mode == "set_negative_selection":
            self.negative_selection.set(external_selection)
        elif self._mode == "set_positive_selection":
            self.positive_selection.set(external_selection)

    @endpoint()
    def on_set_negative_selection(self):
        """Enter negative-selection mode.

        If positive mode was active, its current selection is stashed first.
        The previously stashed negatives are then restored into the external
        selection.
        """
        if self._mode == "set_positive_selection":
            self._positive_selection.set(self.external_selection.value)
        # The mode is switched before the external selection is updated, presumably
        # so the reactive handler routes the restored value to ``negative_selection``.
        self._mode.set("set_negative_selection")
        self.external_selection.set(self._negative_selection.value)

    @endpoint()
    def on_unset_negative_selection(self):
        """Leave negative-selection mode, stashing the current selection.

        The mode is cleared before the external selection is emptied, so
        ``negative_selection`` is not overwritten with ``[]``.
        """
        self._negative_selection.set(self.external_selection.value)
        self._mode.set("")
        self.external_selection.set([])

    @endpoint()
    def on_set_positive_selection(self):
        """Enter positive-selection mode.

        If negative mode was active, its current selection is stashed first.
        The previously stashed positives are then restored into the external
        selection.
        """
        if self._mode == "set_negative_selection":
            self._negative_selection.set(self.external_selection.value)
        # The mode is switched before the external selection is updated, presumably
        # so the reactive handler routes the restored value to ``positive_selection``.
        self._mode.set("set_positive_selection")
        self.external_selection.set(self._positive_selection.value)

    @endpoint()
    def on_unset_positive_selection(self):
        """Leave positive-selection mode, stashing the current selection.

        The mode is cleared before the external selection is emptied, so
        ``positive_selection`` is not overwritten with ``[]``.
        """
        self._positive_selection.set(self.external_selection.value)
        self._mode.set("")
        self.external_selection.set([])

    @property
    def criterion(self) -> MatchCriterion:
        """The store (``self._criterion``) holding the current match criterion."""
        return self._criterion