# Source code for meerkat.ops.sample

from typing import Union

import numpy as np

from meerkat import Column, DataFrame
from meerkat.interactive.graph.reactivity import reactive


def _preprocess_weights(weights, length: int) -> np.ndarray:
    """Validate and clean sampling weights the way pandas' ``DataFrame.sample`` does.

    Args:
        weights: Array-like of per-row weights (numpy array, list, ...).
        length (int): Number of rows in the data being sampled.

    Returns:
        np.ndarray: A new 1-D float64 array with NaN entries replaced by 0.

    Raises:
        ValueError: If the weights are not 1-D with one entry per row, or if
            they contain ``inf`` or negative values.
    """
    # ``np.array`` copies, so the caller's array is never modified below.
    weights = np.array(weights, dtype=np.float64)
    if weights.shape != (length,):
        raise ValueError("Weights and data to be sampled must be of same length")
    if np.isinf(weights).any():
        raise ValueError("weight vector may not include `inf` values")
    if (weights < 0).any():
        raise ValueError("weight vector may not include negative values")
    # Like pandas, missing weights mean "never pick this row" rather than an error.
    weights[np.isnan(weights)] = 0
    return weights


@reactive
def sample(
    data: Union[DataFrame, Column],
    n: int = None,
    frac: float = None,
    replace: bool = False,
    weights: Union[str, np.ndarray] = None,
    random_state: Union[int, np.random.RandomState] = None,
) -> Union[DataFrame, Column]:
    """Select a random sample of rows from a DataFrame or Column.

    Roughly equivalent to ``sample`` in Pandas
    https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.sample.html.

    Args:
        data (Union[DataFrame, Column]): DataFrame or Column to sample from.
        n (int): Number of samples to draw. If `frac` is specified, this
            parameter should not be passed. Defaults to 1 if `frac` is not
            passed.
        frac (float): Fraction of rows to sample. If `n` is specified, this
            parameter should not be passed. The number of rows drawn is
            ``round(frac * len(data))``.
        replace (bool): Sample with or without replacement. Defaults to False.
        weights (Union[str, np.ndarray]): Weights to use for sampling.
            If `None` (default), the rows will be sampled uniformly.
            If an array-like (e.g. a numpy array or a Column), it must hold
            one weight per row, and the sample will be weighted accordingly.
            If a string and `data` is a DataFrame, the weights are read from
            the column with that name.
            Missing (NaN) weights are treated as zero. Weights that do not
            sum to 1 will be normalized to sum to 1.
        random_state (Union[int, np.random.RandomState]): Random state or seed
            to use for sampling.

    Return:
        Union[DataFrame, Column]: A random sample of rows from DataFrame or
        Column.

    Raises:
        ValueError: Raised in any of these cases:
            - `weights` is a string and `data` is a Column.
            - `weights` does not have one entry per row.
            - `weights` contains ``inf`` or negative values.
            - Both `n` and `frac` are passed.
            - Either `n` or `frac` is negative.
            - ``frac > 1`` and `replace` is False.
    """
    import pandas.core.common as com
    from pandas.core.sample import process_sampling_size
    from pandas.core.sample import sample as _sample

    if isinstance(weights, str):
        if isinstance(data, Column):
            raise ValueError(
                "Weights passed to `sample` must be a numpy array if data is a Column."
            )
        weights = data[weights].to_numpy()
    elif isinstance(weights, Column):
        # Allow a weight column to be passed directly, not only by name.
        weights = weights.to_numpy()

    rs = com.random_state(random_state)

    # Validates n/frac. Returns None when the size must be derived from
    # ``frac``, mirroring pandas' own ``DataFrame.sample``.
    n = process_sampling_size(n=n, frac=frac, replace=replace)
    if n is None:
        n = round(frac * len(data))

    # The private pandas ``sample`` helper expects weights that have already
    # been cleaned (pandas does this in ``preprocess_weights``). Without this
    # step, lists fail on ``.sum()`` and NaNs poison the probability vector.
    if weights is not None:
        weights = _preprocess_weights(weights, len(data))

    sampled_indices = _sample(
        obj_len=len(data),
        size=n,
        replace=replace,
        weights=weights,
        random_state=rs,
    )
    return data[sampled_indices]