Source code for meerkat.interactive.endpoint

from __future__ import annotations

import inspect
import logging
import typing
from functools import partial, wraps
from typing import Any, Callable, Generic, Union

from fastapi import APIRouter, Body
from pydantic import BaseModel, create_model

from meerkat.interactive.graph import Store, trigger, unmarked
from meerkat.interactive.graph.store import _unpack_stores_from_object
from meerkat.interactive.node import Node, NodeMixin
from meerkat.interactive.types import T
from meerkat.mixins.identifiable import IdentifiableMixin, is_meerkat_id
from meerkat.state import state
from meerkat.tools.utils import get_type_hint_args, get_type_hint_origin, has_var_args

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# KG: this dynamically defined model must be declared at module level;
# otherwise we get a FastAPI error when it is only declared inside the
# Endpoint class.
# The placeholder is rebound (via `global`) in `Endpoint.add_route` each time
# a Pydantic model is generated for an endpoint function.
FnPydanticModel = None


class SingletonRouter(type):
    """A metaclass that ensures that only one instance of a router is created,

    *for a given prefix*.

    A prefix is a string that is used to identify a router. For example,
    the prefix for the router that handles endpoints is "/endpoint". We
    want to ensure that only one router is created for each prefix.

    Instances are cached per ``(class, prefix)`` pair, so two different
    classes using this metaclass never share an instance.
    """

    # Cache of already-created instances, keyed by (cls, prefix).
    _instances = {}

    def __call__(cls, *args, **kwargs):
        """Return the cached instance for ``(cls, prefix)``, creating it once.

        Args:
            *args: Positional arguments forwarded to the class constructor.
                If ``prefix`` is not given as a keyword, the first positional
                argument is taken to be the prefix (this assumes the
                constructor takes ``prefix`` first, as ``SimpleRouter`` does).
            **kwargs: Keyword arguments forwarded to the class constructor.

        Returns:
            The single instance of ``cls`` registered for the prefix.

        Raises:
            KeyError: If no prefix was supplied at all.
        """
        if "prefix" in kwargs:
            prefix = kwargs["prefix"]
        elif args:
            # Prefix passed positionally, e.g. `Router("/endpoint")`.
            prefix = args[0]
        else:
            raise KeyError("prefix")

        key = (cls, prefix)
        # Look up if this (cls, prefix) pair has been created before
        if key not in cls._instances:
            # If not, we let a new instance be created
            cls._instances[key] = super().__call__(*args, **kwargs)
        return cls._instances[key]


class SimpleRouter(IdentifiableMixin, APIRouter):  # , metaclass=SingletonRouter):
    """A very simple FastAPI router.

    This router allows you to pass in arbitrary keyword arguments that are
    passed to the FastAPI router, and sets sensible defaults for the
    prefix, tags, and responses.

    Note that if you create two routers with the same prefix, they will
    not be the same object.

    Attributes:
        prefix (str): The prefix for this router.
        **kwargs: Arbitrary keyword arguments that are passed to the FastAPI
            router.
    """

    # KG: using the SingletonRouter metaclass causes a bug.
    # app.include_router() inside Endpoint is called multiple times
    # for the same router. This causes an error because some
    # endpoints are registered multiple times because the FastAPI
    # class doesn't check if an endpoint is already registered.
    # As a patch, we're generating one router per Endpoint object
    # (this could generate multiple routers for the same prefix, but
    # that's not been a problem).

    _self_identifiable_group: str = "routers"

    def __init__(self, prefix: str, **kwargs):
        """Initialize the router for `prefix`.

        Args:
            prefix: The prefix for this router. It is also used to derive the
                single tag (slashes stripped from the ends, inner slashes
                turned into dashes) and the `id` passed to the parent
                initializer.
            **kwargs: Extra keyword arguments forwarded to the parent
                initializer.
        """
        # e.g. "/my/prefix/" -> "my-prefix"
        tag = prefix.strip("/").replace("/", "-")
        not_found = {404: {"description": "Not found"}}
        router_id = self.prepend_meerkat_id_prefix(prefix)

        super().__init__(
            prefix=prefix,
            tags=[tag],
            responses=not_found,
            id=router_id,
            **kwargs,
        )


class EndpointFrontend(BaseModel):
    """Schema describing an endpoint when it is sent to the frontend."""

    # Id of the endpoint, or None when the endpoint has no id.
    endpointId: typing.Optional[str]


# TODO: technically Endpoint doesn't need to be NodeMixin (probably)
class Endpoint(IdentifiableMixin, NodeMixin, Generic[T]):
    """A function wrapped so that Meerkat can run it as an endpoint.

    An `Endpoint` holds a function `fn` and, optionally, a FastAPI router
    (derived from `prefix`) together with a `route` under which `add_route`
    exposes the function. Calling the endpoint delegates to `run`, which
    executes `fn` and then collects modifications via `trigger()`.
    """

    # `Body` pre-configured with `embed=True`. Used as the default value of
    # the single `kwargs` parameter of the wrapper generated in `add_route`.
    # NOTE: this must stay above `def partial` below, which rebinds the name
    # `partial` inside the class body.
    EmbeddedBody = partial(Body, embed=True)
    _self_identifiable_group: str = "endpoints"

    def __init__(
        self,
        fn: Callable = None,
        prefix: Union[str, APIRouter] = None,
        route: str = None,
    ):
        """Create an endpoint from a function in Meerkat.

        Typically, you will not need to call this class directly, but
        instead use the `endpoint` decorator.

        Args:
            fn (Callable): The function to create an endpoint from.
            prefix (str or APIRouter): The prefix for this endpoint. A string
                creates a new `SimpleRouter`; an `APIRouter` is used as is.
                If None, the endpoint has no router.
            route (str): The route for this endpoint.

        Raises:
            TypeError: If `fn` is not callable or takes `*args`
                (see `_validate_fn`).

        Note:
        All endpoints can be hit with a POST request at
        /{endpoint_id}/dispatch/
        The request needs a JSON body with the following keys:
            - kwargs: a dictionary of keyword arguments to be
                passed to the endpoint function `fn`
            - payload: additional payload, if any

        Optionally, the user can customize how endpoints are
        organized by specifying a prefix and a route. The prefix
        is a string that is used to identify a router. For example,
        the prefix for the router that handles endpoints is "/endpoint".
        The route is a string that is used to identify an endpoint
        within a router. For example, the route for the endpoint
        that handles the `get` function could be "/get".

        If only a prefix is specified, then the route will be the
        name of the function e.g. "my_endpoint". If both a prefix
        and a route are specified, then the route will be the
        specified route e.g. "/specific/route/".

        Refer to the FastAPI documentation for more information
        on how to create routers and endpoints.
        """
        super().__init__()
        if fn is None:
            self.id = None
        self.fn = fn
        # NOTE(review): `_validate_fn` rejects non-callables, so `fn=None`
        # currently always raises TypeError despite the default above.
        self._validate_fn()

        if prefix is None:
            # No prefix, no router
            self.router = None
        elif isinstance(prefix, APIRouter):
            # An existing router was handed in; use it directly
            self.router = prefix
        else:
            # Make the router
            self.router = SimpleRouter(prefix=prefix)

        self.prefix = prefix
        self.route = route

    def __repr__(self) -> str:
        """Return a string with the endpoint id, function name, prefix and
        route."""
        if hasattr(self.fn, "__name__"):
            name = self.fn.__name__
        elif hasattr(self.fn, "func"):
            # e.g. a `functools.partial` without an explicit `__name__`
            name = self.fn.func.__name__
        else:
            name = None
        return (
            f"Endpoint(id={self.id}, name={name}, prefix={self.prefix}, "
            f"route={self.route})"
        )

    def _validate_fn(self):
        """Validate the function `fn`.

        Raises:
            TypeError: If `fn` is not callable, or if it has a `*args`
                parameter.
        """
        if not callable(self.fn):
            raise TypeError(f"Endpoint function {self.fn} is not callable.")

        # Disallow *args
        if has_var_args(self.fn):
            raise TypeError(
                f"Endpoint function {self.fn} has a `*args` parameter."
                " Please use keyword arguments instead."
            )

        # Do we allow lambdas?

    @property
    def frontend(self):
        """The `EndpointFrontend` schema for this endpoint (its id)."""
        return EndpointFrontend(
            endpointId=self.id,
        )

    def to_json(self):
        """Return a dict with this endpoint's id under `endpointId`."""
        return {"endpointId": self._self_id}

    def run(self, *args, **kwargs) -> Any:
        """Actually run the endpoint function `fn`.

        All arguments of `fn` must be bound (here, or earlier through
        `.partial()`) before it can run: positional-or-keyword parameters
        have to be supplied by keyword. Variadic parameters (`*args` /
        `**kwargs`) never need to be filled.

        Args:
            *args: Positional arguments to pass to `fn`.
            **kwargs: Keyword arguments to pass to `fn`.

        Returns:
            A tuple `(result, modifications)` where `result` is the return
            value of `fn` and `modifications` is what `trigger()` returned.

        Raises:
            ValueError: If some arguments of `fn` are still unfilled.
            Exception: Anything raised by `fn` is re-raised after the
                modification queue has been unreadied.
        """
        logger.debug(f"Running endpoint {self}.")

        # Apply a partial function to ingest the additional arguments
        # that are passed in
        partial_fn = partial(self.fn, *args, **kwargs)

        # Check if the partial_fn has any arguments left to be filled
        spec = inspect.getfullargspec(partial_fn)
        # Check if spec has no args: if it does have args,
        # it means that we can't call the function without filling them in
        no_args = len(spec.args) == 0
        # Keyword-only arguments that have not been bound yet
        missing_kwonlyargs = [
            arg for arg in spec.kwonlyargs if arg not in partial_fn.keywords
        ]

        # Get the signature
        signature = inspect.signature(partial_fn)
        # Parameters without a default are unfilled. Variadic parameters
        # never have a default but can always be omitted, so they are
        # not counted as unfilled.
        variadic_kinds = (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        )
        unfilled_params = [
            param.name
            for param in signature.parameters.values()
            if param.kind not in variadic_kinds and param.default is param.empty
        ]

        if not no_args or missing_kwonlyargs or unfilled_params:
            # Report the missing keyword arguments
            missing_args = missing_kwonlyargs + unfilled_params
            raise ValueError(
                f"Endpoint {self.id} still has arguments left to be "
                f"filled (args: {spec.args}, kwargs: {missing_args}). "
                "Ensure that all keyword arguments "
                "are passed in when calling `.run()` on this endpoint."
            )

        # Clear the modification queue before running the function
        # This is an invariant: there should be no pending modifications
        # when running an endpoint, so that only the modifications
        # that are made by the endpoint are applied
        state.modification_queue.clear()

        # Ready the ModificationQueue so that it can be used to track
        # modifications made by the endpoint
        state.modification_queue.ready()
        state.progress_queue.add(
            self.fn.func.__name__ if isinstance(self.fn, partial) else self.fn.__name__
        )

        try:
            # The function should not add any operations to the graph.
            with unmarked():
                result = partial_fn()
        except Exception as e:
            # Unready the modification queue
            state.modification_queue.unready()
            raise e

        with unmarked():
            modifications = trigger()

        # End the progress bar
        state.progress_queue.add(None)

        return result, modifications

    def partial(self, *args, **kwargs) -> Endpoint:
        """Return a new Endpoint with `args` and `kwargs` bound to `fn`.

        `NodeMixin` arguments are attached to the graph (an inode is created
        for them if needed), get this endpoint's inode as a non-triggering
        child, and are replaced by their inode id in the bound arguments.

        Args:
            *args: Positional arguments to bind.
            **kwargs: Keyword arguments to bind.

        Returns:
            Endpoint: A new endpoint (no prefix, no route) wrapping the
                partially-applied function.
        """
        # Any NodeMixin objects that are passed in as arguments
        # should have this Endpoint as a non-triggering child
        if not self.has_inode():
            node = self.create_inode()
            self.attach_to_inode(node)

        for arg in list(args) + list(kwargs.values()):
            if isinstance(arg, NodeMixin):
                if not arg.has_inode():
                    # Stores reuse their own id as the inode id
                    inode_id = None if not isinstance(arg, Store) else arg.id
                    node = arg.create_inode(inode_id=inode_id)
                    arg.attach_to_inode(node)

                arg.inode.add_child(self.inode, triggers=False)

        # TODO (sabri): make this work for derived dataframes
        # There's a subtle issue with partial that we should figure out. I spent an
        # hour or so on it, but am gonna table it til after the deadline tomorrow
        # because I have a hacky workaround. Basically, if we create an endpoint
        # partial passing a "derived" dataframe, when the endpoint is called, we
        # should expect that the current value of the dataframe will be passed.
        # Currently, the original value of the dataframe is passed. It makes sense to
        # me why this is happening, but the right fix is eluding me.

        # All NodeMixin objects need to be replaced by their node id.
        # This ensures that we can resolve the correct object at runtime
        # even if the object is a result of a reactive function
        # (i.e. not a root of the graph).
        def _get_node_id_or_arg(arg):
            if isinstance(arg, NodeMixin):
                assert arg.has_inode()
                return arg.inode.id
            return arg

        args = [_get_node_id_or_arg(arg) for arg in args]
        kwargs = {key: _get_node_id_or_arg(val) for key, val in kwargs.items()}

        fn = partial(self.fn, *args, **kwargs)
        # `functools.partial` objects have no `__name__`; copy it over so
        # code that reads `fn.__name__` keeps working.
        fn.__name__ = self.fn.__name__
        return Endpoint(
            fn=fn,
            prefix=None,
            route=None,
        )

    def compose(self, fn: Union[Endpoint, Callable]) -> Endpoint:
        """Create a new Endpoint that applies `fn` to the return value of this
        Endpoint. Effectively equivalent to `fn(self.fn(*args, **kwargs))`.

        If `fn` doesn't take any inputs, then `fn` will be called with no
        arguments (the return value of this endpoint is discarded).

        Args:
            fn (Endpoint, callable): An Endpoint or a callable function that accepts
                a single argument of the same type as the return of this Endpoint
                (i.e. self).

        Return:
            Endpoint: The new composed Endpoint, with the same prefix and
                route as this one.
        """
        if not isinstance(fn, Endpoint):
            fn = Endpoint(fn=fn)

        # `fn` may not take any inputs.
        # FIXME: Should this logic be in ``compose``? or some other function?
        sig = get_signature(fn)
        pipe_return = len(sig.parameters) > 0

        @wraps(self.fn)
        def composed(*args, **kwargs):
            out = self.fn(*args, **kwargs)
            return fn.fn(out) if pipe_return else fn.fn()

        composed.__name__ = f"composed({str(self)} | {str(fn)})"

        return Endpoint(
            fn=composed,
            prefix=self.prefix,
            route=self.route,
        )

    def add_route(self, method: str = "POST") -> None:
        """Add a FastAPI route for this endpoint to the router. This function
        will not do anything if the router is None (i.e. no prefix was
        specified).

        This function is called automatically when the endpoint is
        created using the `endpoint` decorator.

        Args:
            method: The HTTP method to register the route for.

        Raises:
            ValueError: If no route was specified and a parameter of `fn`
                has neither a type annotation nor a default value.
        """
        if self.router is None:
            return

        if self.route is None:
            # The route will be postfixed with the fn name
            self.route = f"/{self.fn.__name__}/"

            # Analyze the function signature of `fn` to
            # construct a dictionary, mapping argument names
            # to their types and default values for creating a
            # Pydantic model.

            # During this we also
            # - make sure that args are either type-hinted or
            #   annotated with a default value (can't create
            #  a Pydantic model without a type hint or default)
            # - replace arguments that have type-hints which
            #   are subclasses of `IdentifiableMixin` with
            #   strings (i.e. the id of the Identifiable)
            #   (e.g. `Store` -> `str`)
            signature = inspect.signature(self.fn)

            pydantic_model_params = {}
            for p, param in signature.parameters.items():
                annot = param.annotation
                default = param.default
                has_default = default is not inspect._empty

                if annot is inspect.Parameter.empty:
                    if p == "kwargs":
                        # Allow arbitrary keyword arguments
                        pydantic_model_params[p] = (dict, ...)
                        continue

                    if not has_default:
                        raise ValueError(
                            f"Parameter {p} must have a type annotation or "
                            "a default value."
                        )
                    # NOTE(review): un-annotated parameters that do have a
                    # default are not added to the model, so they cannot be
                    # set through the API and always take their default.
                elif isinstance(annot, type) and issubclass(annot, IdentifiableMixin):
                    # e.g. Stores must be referred to by str ids when
                    # passed into the API
                    pydantic_model_params[p] = (str, ...)
                else:
                    pydantic_model_params[p] = (
                        (annot, default) if has_default else (annot, ...)
                    )

            # Allow arbitrary types in the Pydantic model
            class Config:
                arbitrary_types_allowed = True

            # `self.prefix` may be an APIRouter rather than a string; in that
            # case take the prefix string from the router itself.
            if isinstance(self.prefix, str):
                prefix_str = self.prefix
            else:
                prefix_str = self.router.prefix
            model_name = (
                f"{self.fn.__name__.capitalize()}"
                f"{prefix_str.replace('/', '').capitalize()}Model"
            )

            # Create the Pydantic model, named `{Fn_name}{Prefix}Model`.
            # It is stored in the module-level global (see the comment at
            # its declaration) so the annotation of `_fn` below can be
            # resolved.
            global FnPydanticModel
            FnPydanticModel = create_model(
                model_name,
                __config__=Config,
                **pydantic_model_params,
            )

            # Create a wrapper function, with kwargs that conform to the
            # Pydantic model, and a return annotation that matches `fn`
            def _fn(
                kwargs: FnPydanticModel = Endpoint.EmbeddedBody(),
            ):  # -> signature.return_annotation:
                return self.fn(**kwargs.dict())

            # Name the wrapper function the same as `fn`, so it looks nice
            # in the docs
            _fn.__name__ = self.fn.__name__
        else:
            # If the user specifies a route manually, then they're responsible for
            # everything, including type-hints and default values.
            signature = inspect.signature(self.fn)
            for p in signature.parameters:
                annot = signature.parameters[p].annotation

                # If annot is a subclass of `IdentifiableMixin`, replace
                # it with the `str` type (i.e. the id of the Identifiable)
                # (e.g. `Store` -> `str`)
                if isinstance(annot, type) and issubclass(annot, IdentifiableMixin):
                    self.fn.__annotations__[p] = str

            _fn = self.fn

        # Make FastAPI endpoint for POST requests
        # (the registered path always ends with a trailing slash)
        self.router.add_api_route(
            self.route + "/" if not self.route.endswith("/") else self.route,
            _fn,
            methods=[method],
        )

        # Must add the router to the app again, everytime a new route is added
        # otherwise, the new route does not show up in the docs
        from meerkat.interactive.api.main import app

        app.include_router(self.router)

    def __call__(self, *args, __fn_only=False, **kwargs):
        """Calling the endpoint will just call .run(...) by default.

        If `__fn_only=True` is specified, it will call the raw function
        underlying this endpoint.

        Args:
            *args: Positional arguments for `fn`.
            **kwargs: Keyword arguments for `fn`. A `__fn_only` key, if
                present, is consumed here and not forwarded.

        Returns:
            The return value of `fn` when `__fn_only` is truthy, otherwise
            the return value of `run`.
        """
        # Python's private name mangling rewrites the `__fn_only` parameter
        # of this method to `_Endpoint__fn_only`, so a caller outside the
        # class writing `endpoint(..., __fn_only=True)` ends up putting the
        # flag into **kwargs. Pick it up from there as well (string literals
        # are not mangled).
        fn_only = kwargs.pop("__fn_only", False) or __fn_only
        if fn_only:
            return self.fn(*args, **kwargs)
        return self.run(*args, **kwargs)

    @classmethod
    def __get_validators__(cls):
        """Yield the validator(s) for this type (Pydantic validator hook)."""
        yield cls.validate

    @classmethod
    def validate(cls, v):
        """Return `v` if it is already an instance of `cls`; otherwise wrap
        it with `make_endpoint`."""
        if not isinstance(v, cls):
            return make_endpoint(v)
        return v


class EndpointProperty(Endpoint, Generic[T]):
    """An `Endpoint` subclass that adds no behavior of its own."""


def make_endpoint(endpoint_or_fn: Union[Callable, Endpoint, None]) -> Endpoint:
    """Return `endpoint_or_fn` as an Endpoint.

    An existing `Endpoint` is returned unchanged; anything else is passed
    to the `Endpoint` constructor as its function.
    """
    if isinstance(endpoint_or_fn, Endpoint):
        return endpoint_or_fn
    return Endpoint(endpoint_or_fn)


def endpoint(
    fn: Callable = None,
    prefix: Union[str, APIRouter] = None,
    route: str = None,
    method: str = "POST",
) -> Endpoint:
    """Decorator to mark a function as an endpoint.

    An endpoint is a function that can be called to
        - update the value of a Store (e.g. incrementing a counter)
        - update a DataFrame (e.g. adding a new row)
        - run a computation and return its result to the frontend
        - run a function in response to a frontend event (e.g. button
            click)

    Endpoints differ from reactive functions in that they are not
    automatically triggered by changes in their inputs. Instead,
    they are triggered by explicit calls to the endpoint function.

    The Store and DataFrame objects that are modified inside the endpoint
    function will automatically trigger reactive functions that
    depend on them.

    .. code-block:: python

        @endpoint
        def increment(count: Store, step: int = 1):
            count.set(count + step)
            # ^ update the count Store, which will trigger operations
            #   that depend on it

        # Create a button that calls the increment endpoint
        counter = Store(0)
        button = Button(on_click=increment(counter))
        # ^ read this as: call the increment endpoint with the `counter`
        # Store when the button is clicked

    Args:
        fn: The function to decorate. If None, a partially-applied
            decorator is returned so that `@endpoint(prefix=...)` works.
        prefix: The prefix to add to the route. If a string, it will be
            prepended to the route. If an APIRouter, the route will be
            added to the router.
        route: The route to add to the endpoint. If not specified, the
            route will be the name of the function.
        method: The HTTP method to use for the endpoint. Defaults to
            "POST".

    Returns:
        The decorated function, as an Endpoint object.
    """
    if fn is None:
        return partial(endpoint, prefix=prefix, route=route, method=method)

    @wraps(fn)
    def _endpoint(fn: Callable):
        # Gather up the hinted arguments that are Stores or that subclass
        # IdentifiableMixin, e.g. Store, Endpoint, Page, etc.
        identifiables = {}
        for name, annot in inspect.getfullargspec(fn).annotations.items():
            # TODO: See if we can remove this in the future.
            if _is_annotation_store(annot) or (
                isinstance(annot, type) and issubclass(annot, IdentifiableMixin)
            ):
                identifiables[name] = annot

        @wraps(fn)
        def wrapper(*args, **kwargs):
            # Bind the call to the signature of `fn` so every argument can
            # be looked up by parameter name
            fn_signature = inspect.signature(fn)
            fn_bound_arguments = fn_signature.bind(*args, **kwargs).arguments

            # `Identifiables` that are passed into the function
            # may be passed in as a string id, or as the object itself
            # If they are passed in as a string id, we need to get the object
            # from the registry
            _args, _kwargs = [], {}
            for k, v in fn_bound_arguments.items():
                if k in identifiables:
                    # Dereference the argument if it was passed in as a string id
                    if not isinstance(v, str) or isinstance(v, IdentifiableMixin):
                        # Either not a string id, so just use the object, or
                        # a string that is also an IdentifiableMixin
                        # e.g. Store("foo"), so just use v as is
                        _kwargs[k] = v
                        continue

                    # v is a string id
                    try:
                        # Directly try to look up the string id in the
                        # registry of the annotated type
                        _kwargs[k] = identifiables[k].from_id(v)
                    except Exception:
                        # If that fails, try to look up the string id in
                        # the Node registry, and then get the object
                        # from the Node
                        try:
                            _kwargs[k] = Node.from_id(v).obj
                        except Exception as e:
                            # If that fails and the object is a non-id string,
                            # then just use the string as is.
                            # We have to do this check here rather than above
                            # because we want to make sure we check for all
                            # identifiable and nodes before checking if the
                            # string is just a string.
                            # this is required for compatibility with
                            # IdentifiableMixin objects that do not start with
                            # the meerkat id prefix.
                            if not is_meerkat_id(v):
                                _kwargs[k] = v
                            else:
                                raise e
                elif k == "args":
                    # These are *args under the `args` key
                    # These are the only arguments that will be passed in as
                    # *args to the fn
                    v = [_resolve_id_to_obj(_value) for _value in v]
                    _args, _ = _unpack_stores_from_object(v)
                elif k == "kwargs":
                    # These are **kwargs under the `kwargs` key
                    v = {_k: _resolve_id_to_obj(_value) for _k, _value in v.items()}
                    v, _ = _unpack_stores_from_object(v)
                    _kwargs = {**_kwargs, **v}
                else:
                    # All other positional arguments that were not *args were
                    # bound, so they become kwargs
                    v, _ = _unpack_stores_from_object(_resolve_id_to_obj(v))
                    _kwargs[k] = v

            try:
                with unmarked():
                    # Run the function
                    result = fn(*_args, **_kwargs)
            except Exception as e:
                # If the function raises an exception, log it and re-raise.
                # In case the exception is about .set() being missing, add
                # a more helpful error message
                if "no attribute 'set'" in str(e):
                    # Get the name of the object that was passed in
                    # as a Store, but did not have a .set() method
                    obj_name = str(e).split("'")[1].strip("'")

                    # Update the error message to be more helpful
                    e = AttributeError(
                        f"Exception raised in endpoint `{fn.__name__}`. "
                        f"The object of type `{obj_name}` that you called to "
                        "update with `.set()` "
                        "is not a `Store`. You probably forgot to "
                        "annotate this object's typehint in the signature of "
                        f"`{fn.__name__}` as a `Store` i.e. \n\n"
                        "@endpoint\n"
                        f"def {fn.__name__}(..., parameter: Store, ...):\n\n"
                        "Remember that without this type annotation, the object "
                        "will be automatically unpacked by Meerkat inside the endpoint "
                        "if it is a `Store`."
                    )

                logger.exception(e)
                raise e

            # Return the result of the function
            return result

        # Register the endpoint and return it
        new_endpoint = Endpoint(
            fn=wrapper,
            prefix=prefix,
            route=route,
        )
        new_endpoint.add_route(method)
        return new_endpoint

    return _endpoint(fn)
def endpoints(cls: type = None, prefix: str = None):
    """Decorator to mark a class as containing a collection of endpoints.
    All instance methods in the marked class will be converted to endpoints.

    This decorator is useful when you want to create a class that contains
    some logical state variables (e.g. a Counter class), along with methods
    to manipulate the values of those variables (e.g. increment or decrement
    the counter).

    Args:
        cls: The class to decorate. If None, a partially-applied decorator
            is returned so that `@endpoints(prefix=...)` works.
        prefix: The route prefix shared by the generated endpoints. Each
            wrapper instance appends `/<n>` to it, where `n` is a
            per-instance counter starting at 1.
            NOTE(review): `prefix` is concatenated with a string, so
            instantiating the wrapper with `prefix=None` raises TypeError.

    Returns:
        A wrapper class (`EndpointClass`) whose instances hold an instance
        of `cls` and expose its callable attributes as endpoints.
    """

    if cls is None:
        return partial(endpoints, prefix=prefix)

    # Per-class registry of wrapper instances -> numeric id, and the next
    # id to hand out.
    _ids = {}
    _max_ids = {}
    if cls not in _ids:
        _ids[cls] = {}
        _max_ids[cls] = 1

    def _endpoints(cls):
        class EndpointClass:
            def __init__(self, *args, **kwargs):
                self.instance = cls(*args, **kwargs)
                self.endpoints = {}

                # Access all the user-defined attributes of the instance
                # to create endpoints
                for attrib in dir(self.instance):
                    if attrib.startswith("__"):
                        continue
                    obj = self.instance.__getattribute__(attrib)
                    if callable(obj):
                        if attrib not in self.endpoints:
                            self.endpoints[attrib] = endpoint(
                                obj, prefix=prefix + f"/{_ids[cls][self]}"
                            )

            def __getattribute__(self, attrib):
                # Assign this wrapper instance its numeric id on first
                # attribute access
                if self not in _ids[cls]:
                    _ids[cls][self] = _max_ids[cls]
                    _max_ids[cls] += 1

                # Attributes of the wrapper itself take precedence
                try:
                    obj = super().__getattribute__(attrib)
                    return obj
                except AttributeError:
                    pass

                # Otherwise fall back to the wrapped instance; callables
                # that were registered are returned as endpoints
                obj = self.instance.__getattribute__(attrib)
                if callable(obj):
                    if attrib not in self.endpoints:
                        return obj
                    return self.endpoints[attrib]
                else:
                    return obj

        return EndpointClass

    return _endpoints(cls)


def get_signature(fn: Union[Callable, Endpoint]) -> inspect.Signature:
    """Get the signature of a function or endpoint.

    Args:
        fn: The function or endpoint to get the signature of.

    Returns:
        The signature of the function or endpoint.
    """
    if isinstance(fn, Endpoint):
        fn = fn.fn
    return inspect.signature(fn)


def _resolve_id_to_obj(value):
    """Return the object behind `value` if it is a meerkat id string.

    Any other value is returned unchanged.
    """
    if isinstance(value, str) and is_meerkat_id(value):
        # This is a string that corresponds to a meerkat id,
        # so look it up.
        return Node.from_id(value).obj
    return value


def _is_annotation_store(type_hint) -> bool:
    """Check if a type hint is a Store or a Union of Stores.

    Returns True if:
        - The type hint is a Store
        - The type hint is a Union of Store and other non-Store values.
        - The type hint is a generic store Store[T] or Union[Store[T], ...]
    """
    if isinstance(type_hint, type) and issubclass(type_hint, Store):
        return True

    if isinstance(type_hint, typing._GenericAlias):
        origin = get_type_hint_origin(type_hint)
        args = get_type_hint_args(type_hint)

        if origin == typing.Union:
            return any(_is_annotation_store(arg) for arg in args)
        # Some generic aliases have a non-class origin (e.g. special forms
        # such as `typing.Literal`); `issubclass` would raise TypeError on
        # those, so only test real classes.
        if isinstance(origin, type) and issubclass(origin, Store):
            return True
    return False