Source code for meerkat.ops.complete

import os
import re
import string

import meerkat as mk
import meerkat.tools.docs as docs
from meerkat.ops.map import _SHARED_DOCS_


@docs.doc(source=_SHARED_DOCS_)
def complete(
    df: mk.DataFrame,
    prompt: str,
    engine: str,
    batch_size: int = 1,
    use_ray: bool = False,
    num_blocks: int = 100,
    blocks_per_window: int = 10,
    pbar: bool = False,
    client_connection: str = None,
    cache_connection: str = "~/.manifest/cache.sqlite",
    *,
    temperature: float = 0,
    max_tokens: int = 1,
) -> mk.ScalarColumn:
    """Apply a generative language model to each row in a DataFrame.

    Each row is rendered into ``prompt`` with :meth:`str.format` and the
    resulting strings are sent, one batch at a time, to the model through a
    ``manifest.Manifest`` client.

    Args:
        df (DataFrame): The :class:`DataFrame` to which the function will be
            applied.
        prompt (str): A :meth:`str.format` template. Every named replacement
            field must refer to a column of ``df`` (attribute or index access
            on a field is allowed); only the referenced columns are passed to
            the model.
        engine (str): A string of the form ``"<client>/<engine>"``. The text
            before the first slash is the manifest client name (e.g.
            ``openai``); everything after it, which may itself contain
            slashes, is the engine name.
        ${batch_size}
        ${use_ray}
        ${num_blocks}
        ${blocks_per_window}
        ${pbar}
        client_connection: The connection string for the client. This is
            typically the API key. If it is not provided, it is inferred from
            the engine; this is currently only supported for the ``openai``
            client, which reads the ``OPENAI_API_KEY`` environment variable.
        cache_connection: Path of the sqlite file manifest uses as its cache.
            ``~`` is expanded and the parent directory is created if needed.
        temperature: Sampling temperature passed to manifest. Defaults to 0.
        max_tokens: Maximum number of tokens generated per row, passed to
            manifest. Defaults to 1.

    Returns:
        mk.ScalarColumn: The result of :func:`meerkat.map`, holding what
        ``manifest.run`` returned for each row of ``df``.

    Raises:
        ValueError: If ``engine`` is not of the form ``"<client>/<engine>"``,
            if ``client_connection`` cannot be inferred for the client, or if
            ``prompt`` is not a valid format string.
        KeyError: If ``client_connection`` is inferred for the ``openai``
            client but ``OPENAI_API_KEY`` is not set.
    """
    # Imported inside the function, presumably so ``manifest`` is only
    # required when this op is actually used.
    from manifest import Manifest

    input_engine = engine
    # Split on the first "/" only so engine names containing "/" survive.
    client_name, sep, engine = engine.partition("/")
    if not (sep and client_name and engine):
        raise ValueError(
            "Expected `engine` of the form '<client>/<engine>', "
            f"got {input_engine!r}."
        )

    if client_connection is None:
        if client_name == "openai":
            client_connection = os.environ["OPENAI_API_KEY"]
        else:
            raise ValueError(
                f"Cannot infer client connection from engine {input_engine}."
            )

    # Column names referenced by the prompt. ``string.Formatter().parse``
    # understands escaped braces and conversion/format specs, which a naive
    # regex would mis-read as column names. The root name is taken before any
    # attribute/index access, and ``dict.fromkeys`` drops repeats while keeping
    # first-seen order so ``df[keys]`` never selects a column twice. Parsing
    # here also rejects a malformed template before any client is created.
    keys = list(
        dict.fromkeys(
            re.split(r"[.\[]", field_name, maxsplit=1)[0]
            for _, field_name, _, _ in string.Formatter().parse(prompt)
            if field_name
        )
    )

    # Normalize the cache path and make sure its directory exists before
    # manifest opens the sqlite file.
    cache_connection = os.path.abspath(os.path.expanduser(cache_connection))
    os.makedirs(os.path.dirname(cache_connection), exist_ok=True)

    manifest = Manifest(
        client_name=client_name,
        client_connection=client_connection,
        engine=engine,
        temperature=temperature,
        max_tokens=max_tokens,
        cache_name="sqlite",
        cache_connection=cache_connection,
    )

    def _run_manifest(rows: mk.DataFrame):
        # Render one prompt per row in the batch and send them in one call.
        return manifest.run([prompt.format(**row) for row in rows.iterrows()])

    return mk.map(
        df[keys],
        function=_run_manifest,
        inputs="row",
        is_batched_fn=True,
        batch_size=batch_size,
        pbar=pbar,
        use_ray=use_ray,
        num_blocks=num_blocks,
        blocks_per_window=blocks_per_window,
    )