"""DataFrame class."""
from __future__ import annotations
import logging
import os
import pathlib
import shutil
import warnings
from typing import (
TYPE_CHECKING,
Any,
Callable,
Collection,
Dict,
Iterable,
List,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Type,
Union,
)
import cytoolz as tz
import dill
import numpy as np
import pandas as pd
import pyarrow as pa
from pandas._libs import lib
import meerkat
from meerkat.block.manager import BlockManager
from meerkat.columns.abstract import Column
from meerkat.columns.scalar.abstract import ScalarColumn
from meerkat.columns.scalar.arrow import ArrowScalarColumn
from meerkat.errors import ConversionError
from meerkat.interactive.graph.marking import is_unmarked_context, unmarked
from meerkat.interactive.graph.reactivity import reactive
from meerkat.interactive.graph.store import Store
from meerkat.interactive.modification import DataFrameModification
from meerkat.interactive.node import NodeMixin
from meerkat.mixins.cloneable import CloneableMixin
from meerkat.mixins.deferable import DeferrableMixin
from meerkat.mixins.identifiable import IdentifiableMixin
from meerkat.mixins.indexing import IndexerMixin, MaterializationMixin
from meerkat.mixins.inspect_fn import FunctionInspectorMixin
from meerkat.mixins.reactifiable import ReactifiableMixin
from meerkat.provenance import ProvenanceMixin
from meerkat.row import Row
from meerkat.tools.lazy_loader import LazyLoader
from meerkat.tools.utils import convert_to_batch_fn, dump_yaml, load_yaml
torch = LazyLoader("torch")
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal
if TYPE_CHECKING:
from meerkat.interactive.formatter.base import FormatterGroup
logger = logging.getLogger(__name__)
Example = Dict
Batch = Dict[str, Union[List, Column]]
BatchOrDataset = Union[Batch, "DataFrame"]
class DataFrame(
CloneableMixin,
FunctionInspectorMixin,
IdentifiableMixin,
NodeMixin,
DeferrableMixin,
MaterializationMixin,
IndexerMixin,
ProvenanceMixin,
ReactifiableMixin,
):
"""A collection of equal length columns.
Args:
data (Union[dict, list]): A dictionary of columns or a list of dictionaries.
primary_key (Union[str, bool], optional): The name of the primary key column,
or ``True`` to infer or create one, or ``False`` to skip setting a primary key.
Defaults to ``True``.
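Example:
A minimal sketch of constructing a DataFrame from a dictionary of equal-length
columns (the column names here are hypothetical, and the conventional
``import meerkat as mk`` alias is assumed):
>>> import meerkat as mk
>>> df = mk.DataFrame({"id": [0, 1, 2], "text": ["a", "b", "c"]})
>>> df["text"]   # column selection
>>> df[0]        # row selection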
"""
_self_identifiable_group: str = "dataframes"
# Path to a log directory
logdir: pathlib.Path = pathlib.Path.home() / "meerkat/"
# Create a directory
logdir.mkdir(parents=True, exist_ok=True)
def __init__(
self,
data: Union[dict, list] = None,
primary_key: Union[str, bool] = True,
*args,
**kwargs,
):
super(DataFrame, self).__init__(
*args,
**kwargs,
)
self._primary_key = None if primary_key is False else primary_key
self.data = data
if primary_key is True:
self._infer_primary_key(create=True)
def _react(self):
"""Converts the object to a reactive object in-place."""
self._self_marked = True
for col in self.columns:
self[col].mark()
return self
def _no_react(self):
"""Converts the object to a non-reactive object in-place."""
self._self_marked = False
for col in self.columns:
self[col].unmark()
return self
@property
def gui(self):
from meerkat.interactive.gui import DataFrameGUI
return DataFrameGUI(self)
@unmarked()
def _repr_pandas_(self, max_rows: int = None):
if max_rows is None:
max_rows = meerkat.config.display.max_rows
df, formatters = self.data._repr_pandas_(max_rows=max_rows)
rename = {k: f"{k}" for k, v in self.items()}
return (
df[self.columns].rename(columns=rename),
{rename[k]: v for k, v in formatters.items()},
)
@unmarked()
def _ipython_display_(self):
from IPython.display import HTML, display
from meerkat.state import state
if state.frontend_info is None:
max_rows = meerkat.config.display.max_rows
df, formatters = self._repr_pandas_(max_rows=max_rows)
return display(
HTML(df.to_html(formatters=formatters, max_rows=max_rows, escape=False))
)
return self.gui.table()._ipython_display_()
@unmarked()
def __repr__(self):
return (
f"{self.__class__.__name__}" f"(nrows: {self.nrows}, ncols: {self.ncols})"
)
def __len__(self):
# __len__ is required to return an int. It cannot return a Store[int].
# As such, it cannot be wrapped in `@reactive` decorator.
# To get the length of a DataFrame in a reactive way, use `df.nrows`.
self._reactive_warning("len", placeholder="df")
with unmarked():
_len = self.nrows
return _len
def __contains__(self, item):
# __contains__ is called when using the `in` keyword.
# `in` casts the output to a boolean (i.e. `bool(self.__contains__(item))`).
# Store.__bool__ is not reactive because __bool__ has to return
# a bool (not a Store). Thus, we cannot wrap __contains__ in
# `@reactive` decorator.
if not is_unmarked_context() and self.marked:
warnings.warn(
"The `in` operator is not reactive. Use `df.contains(...)`:\n"
"\t>>> df.contains(item)"
)
return self.columns.__contains__(item)
@reactive()
def contains(self, item):
return self.__contains__(item)
@property
def data(self) -> BlockManager:
"""Get the underlying data (excluding invisible rows).
To access underlying data with invisible rows, use `_data`.
"""
return self._data
def _set_data(self, value: Union[BlockManager, Mapping] = None):
if isinstance(value, BlockManager):
self._data = value
elif isinstance(value, Mapping):
self._data = BlockManager.from_dict(value)
elif isinstance(value, Sequence):
if not isinstance(value[0], Mapping):
raise ValueError(
"Cannot set DataFrame `data` to a Sequence containing object of "
f" type {type(value[0])}. Must be a Sequence of Mapping."
)
gen = (list(x.keys()) for x in value)
columns = lib.fast_unique_multiple_list_gen(gen)
data = lib.dicts_to_array(value, columns=columns)
# Assert all columns are the same length
data = {column: list(data[:, idx]) for idx, column in enumerate(columns)}
self._data = BlockManager.from_dict(data)
elif value is None:
self._data = BlockManager()
else:
raise ValueError(
f"Cannot set DataFrame `data` to object of type {type(value)}."
)
# check that the primary key is still valid after data reset
if self._primary_key is not None:
if (
self._primary_key not in self
or not self.primary_key._is_valid_primary_key()
):
self.set_primary_key(None)
@data.setter
def data(self, value):
self._set_data(value)
@property
def columns(self):
"""Column names in the DataFrame."""
return list(self.data.keys())
@property
def primary_key(self) -> ScalarColumn:
"""The column acting as the primary key."""
if self._primary_key is None:
return None
return self[self._primary_key]
@property
def primary_key_name(self) -> str:
"""The name of the column acting as the primary key."""
return self._primary_key
def set_primary_key(self, column: str, inplace: bool = False) -> DataFrame:
"""Set the DataFrame's primary key using an existing column. By default this
is an out-of-place operation; pass ``inplace=True`` to modify this DataFrame
directly. For more information on primary keys, see the User Guide.
Args:
column (str): The name of an existing column to set as the primary key.
inplace (bool): Whether to set the primary key in-place. Defaults to False.
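Example:
An illustrative sketch; the "id" column is hypothetical:
>>> df = mk.DataFrame({"id": ["a", "b", "c"], "x": [1, 2, 3]}, primary_key=False)
>>> df = df.set_primary_key("id")
>>> df.loc["b"]  # key-based row access via the primary key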
"""
if column is not None:
if column not in self.columns:
raise ValueError(
"Must pass the name of an existing column to `set_primary_key`."
f'"{column}" not in {self.columns}'
)
if not self[column]._is_valid_primary_key():
raise ValueError(
f'Column "{column}" cannot be used as a primary key. Ensure that '
"columns of this type can be used as primary keys and \
that the values in the column are unique."
)
if inplace:
self._primary_key = column
return self
else:
return self._clone(_primary_key=column)
def create_primary_key(self, column: str):
"""Create a primary key of contiguous integers.
Args:
column (str): The name of the column to create.
"""
self[column] = np.arange(self.nrows)
self.set_primary_key(column, inplace=True)
def _infer_primary_key(self, create: bool = False) -> str:
"""Infer the primary key from the data."""
for column in self.columns:
if self[column]._is_valid_primary_key():
self.set_primary_key(column, inplace=True)
return column
if create:
self.create_primary_key("pkey")
return "pkey"
return None
@property
def nrows(self):
"""Number of rows in the DataFrame."""
if self.ncols == 0:
return 0
return self._data.nrows
@property
def ncols(self):
"""Number of rows in the DataFrame."""
return self.data.ncols
@property
def shape(self):
"""Shape of the DataFrame (num_rows, num_columns)."""
return self.nrows, self.ncols
@reactive(nested_return=False)
def size(self):
"""Shape of the DataFrame (num_rows, num_columns)."""
return self.shape
def add_column(self, name: str, data: Column.Columnable, overwrite=False) -> None:
"""Add a column to the DataFrame."""
assert isinstance(
name, str
), f"Column name must of type `str`, not `{type(name)}`."
assert (name not in self.columns) or overwrite, (
f"Column with name `{name}` already exists, "
f"set `overwrite=True` to overwrite."
)
if not is_listlike(data):
data = [data] * self.nrows
if name in self.columns:
self.remove_column(name)
column = Column.from_data(data)
assert len(column) == len(self) or len(self.columns) == 0, (
f"`add_column` failed. "
f"Values length {len(column)} != dataset length {len(self)}."
)
# Add the column
self.data[name] = column
logger.info(f"Added column `{name}` with length `{len(column)}`.")
def remove_column(self, column: str) -> None:
"""Remove a column from the dataset."""
assert column in self.columns, f"Column `{column}` does not exist."
# Remove the column
del self.data[column]
if column == self._primary_key:
# need to reset the primary key if we remove the column
self.set_primary_key(None, inplace=True)
logger.info(f"Removed column `{column}`.")
# @capture_provenance(capture_args=["axis"])
@reactive()
def append(
self,
df: DataFrame,
axis: Union[str, int] = "rows",
suffixes: Tuple[str] = None,
overwrite: bool = False,
) -> DataFrame:
"""Append a batch of data to the dataset.
`example_or_batch` must have the same columns as the dataset
(regardless of what columns are visible).
"""
return meerkat.concat(
[self, df], axis=axis, suffixes=suffixes, overwrite=overwrite
)
@reactive()
def head(self, n: int = 5) -> DataFrame:
"""Get the first `n` examples of the DataFrame."""
return self[:n]
@reactive()
def tail(self, n: int = 5) -> DataFrame:
"""Get the last `n` examples of the DataFrame."""
return self[-n:]
def _get_loc(self, keyidx, materialize: bool = False):
if self.primary_key is None:
raise ValueError(
"Cannot use `loc` without a primary key. Set a primary key using "
"`set_primary_key`."
)
if isinstance(
keyidx, (np.ndarray, list, tuple, pd.Series, torch.Tensor, Column)
):
posidxs = self.primary_key._keyidxs_to_posidxs(keyidx)
return self._clone(
data=self.data.apply("_get", index=posidxs, materialize=materialize)
)
else:
posidx = self.primary_key._keyidx_to_posidx(keyidx)
row = self.data.apply("_get", index=posidx, materialize=materialize)
return {k: row[k] for k in self.columns}
def _set_loc(self, keyidx: Union[str, int], column: str, value: any):
if self.primary_key is None:
raise ValueError(
"Cannot use `loc` without a primary key. Set a primary key using "
"`set_primary_key`."
)
if isinstance(
keyidx, (np.ndarray, list, tuple, pd.Series, torch.Tensor, Column)
):
raise NotImplementedError("")
else:
posidx = self.primary_key._keyidx_to_posidx(keyidx)
self[column][posidx] = value
if self.has_inode():
mod = DataFrameModification(id=self.inode.id, scope=self.columns)
mod.add_to_queue()
def _get(self, posidx, materialize: bool = False):
if isinstance(posidx, str):
# str index => column selection (AbstractColumn)
if posidx in self.columns:
return self.data[posidx]
raise KeyError(f"Column `{posidx}` does not exist.")
elif isinstance(posidx, int):
# int index => single row (dict)
row = self.data.apply("_get", index=posidx, materialize=materialize)
return Row({k: row[k] for k in self.columns})
# cases where `index` returns a dataframe
index_type = None
if isinstance(posidx, slice):
# slice index => multiple row selection (DataFrame)
index_type = "row"
elif (isinstance(posidx, tuple) or isinstance(posidx, list)) and len(posidx):
# tuple or list index => multiple row selection (DataFrame)
if isinstance(posidx[0], str):
index_type = "column"
else:
index_type = "row"
elif isinstance(posidx, np.ndarray):
if len(posidx.shape) != 1:
raise ValueError(
"Index must have 1 axis, not {}".format(len(posidx.shape))
)
# numpy array index => multiple row selection (DataFrame)
index_type = "row"
elif torch.is_tensor(posidx):
if len(posidx.shape) != 1:
raise ValueError(
"Index must have 1 axis, not {}".format(len(posidx.shape))
)
# torch tensor index => multiple row selection (DataFrame)
index_type = "row"
elif isinstance(posidx, pd.Series):
index_type = "row"
elif isinstance(posidx, Column):
# column index => multiple row selection (DataFrame)
index_type = "row"
else:
raise TypeError("Invalid index type: {}".format(type(posidx)))
if index_type == "column":
if not set(posidx).issubset(self.columns):
missing_cols = set(posidx) - set(self.columns)
raise KeyError(f"DataFrame does not have columns {missing_cols}")
df = self._clone(data=self.data[posidx])
return df
elif index_type == "row": # pragma: no cover
return self._clone(
data=self.data.apply("_get", index=posidx, materialize=materialize)
)
# @capture_provenance(capture_args=[])
@reactive()
def __getitem__(self, posidx):
return self._get(posidx, materialize=False)
def __setitem__(self, posidx, value):
self.add_column(name=posidx, data=value, overwrite=True)
# This condition will issue modifications even if we're outside an endpoint
# but those modifications will be cleared by the endpoint before it is
# run.
if self.has_inode():
columns = self.columns
if isinstance(columns, Store):
# df modifications expects a list, not a Store[list]
columns = columns.value
# Add a modification if it's on the graph
mod = DataFrameModification(id=self.id, scope=columns)
mod.add_to_queue()
def __call__(self):
return self._get(slice(0, len(self), 1), materialize=True)
def set(self, value: DataFrame):
"""Set the data of this DataFrame to the data of another DataFrame.
This is used inside endpoints to tell Meerkat when a DataFrame
has been modified. Calling this method outside of an endpoint
will not have any effect on the graph.
"""
self._set_data(value._data)
self._set_state(value._get_state())
if self.has_inode():
columns = self.columns
if isinstance(columns, Store):
columns = columns.value
# Add a modification if it's on the graph
# TODO: think about what the scope should be.
# How does `scope` relate to the the skip_fn mechanism
# in Operation?
mod = DataFrameModification(id=self.inode.id, scope=columns)
mod.add_to_queue()
def consolidate(self):
self.data.consolidate()
@classmethod
# @capture_provenance()
def from_batch(
cls,
batch: Batch,
) -> DataFrame:
"""Convert a batch to a Dataset."""
return cls(batch)
@classmethod
# @capture_provenance()
def from_batches(
cls,
batches: Sequence[Batch],
) -> DataFrame:
"""Convert a list of batches to a dataset."""
return cls.from_batch(
tz.merge_with(
tz.compose(list, tz.concat),
*batches,
),
)
@classmethod
# @capture_provenance()
def from_pandas(
cls,
df: pd.DataFrame,
index: bool = True,
primary_key: Optional[str] = None,
) -> DataFrame:
"""Create a Meerkat DataFrame from a Pandas DataFrame.
.. warning::
In Meerkat, column names must be strings, so non-string column
names in the Pandas DataFrame will be converted.
Args:
df: The Pandas DataFrame to convert.
index: Whether to include the index of the Pandas DataFrame as a
column in the Meerkat DataFrame.
primary_key: The name of the column to use as the primary key.
If `index` is True and `primary_key` is None, the index will
be used as the primary key. If `index` is False, then no
primary key will be set. Optional default is None.
Returns:
DataFrame: The Meerkat DataFrame.
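Example:
A minimal sketch of the conversion (the column names are hypothetical):
>>> pandas_df = pd.DataFrame({"id": [1, 2, 3], "text": ["a", "b", "c"]})
>>> df = DataFrame.from_pandas(pandas_df, index=False, primary_key="id")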
"""
# column names must be str in meerkat
df = df.rename(mapper=str, axis="columns")
pandas_df = df
df = cls.from_batch(
pandas_df.to_dict("series"),
)
if index:
df["index"] = pandas_df.index.to_series()
if primary_key is None:
df.set_primary_key("index", inplace=True)
if primary_key is not None:
df.set_primary_key(primary_key, inplace=True)
return df
@classmethod
# @capture_provenance()
def from_arrow(
cls,
table: pa.Table,
):
"""Create a Dataset from a pandas DataFrame."""
from meerkat.block.arrow_block import ArrowBlock
from meerkat.columns.scalar.arrow import ArrowScalarColumn
block_views = ArrowBlock.from_block_data(table)
return cls.from_batch(
{view.block_index: ArrowScalarColumn(view) for view in block_views}
)
@classmethod
def from_huggingface(cls, *args, **kwargs):
"""Load a Huggingface dataset as a DataFrame.
Use this to replace `datasets.load_dataset`, so
>>> dict_of_datasets = datasets.load_dataset('boolq')
becomes
>>> dict_of_dataframes = DataFrame.from_huggingface('boolq')
"""
import datasets
datasets.logging.set_verbosity_error()
import pyarrow.compute as pc
# Load the dataset
dataset = datasets.load_dataset(*args, **kwargs)
def _convert_columns(dataset: datasets.Dataset):
df = cls.from_arrow(dataset._data.table)
for name, feature in dataset.features.items():
if isinstance(feature, datasets.Audio):
column = df[name]
bytes = ArrowScalarColumn(pc.struct_field(column._data, "bytes"))
path = ArrowScalarColumn(pc.struct_field(column._data, "path"))
if (~bytes.isnull()).all():
from meerkat.interactive.formatter import AudioFormatterGroup
df[name] = (
df[name]
.defer(lambda x: x["bytes"])
.format(AudioFormatterGroup())
)
elif (~path.isnull()).all():
from meerkat.columns.deferred.file import FileColumn
df[name] = FileColumn(path, type="audio")
else:
raise ValueError(
"Huggingface column must either provide bytes or path for "
"every row."
)
elif isinstance(feature, datasets.Image):
column = df[name]
bytes = ArrowScalarColumn(pc.struct_field(column._data, "bytes"))
path = ArrowScalarColumn(pc.struct_field(column._data, "path"))
if (~ArrowScalarColumn(bytes).isnull()).all():
import io
from PIL import Image
from meerkat.interactive.formatter import ImageFormatterGroup
df[name] = bytes.defer(
lambda x: Image.open(io.BytesIO(x))
).format(ImageFormatterGroup().defer())
elif (~path.isnull()).all():
from meerkat.columns.deferred.file import FileColumn
df[name] = FileColumn(path, type="image")
else:
raise ValueError(
"Huggingface column must either provide bytes or path for "
"every row."
)
return df
if isinstance(dataset, dict):
return dict(
map(
lambda t: (
t[0],
_convert_columns(t[1]),
),
dataset.items(),
)
)
else:
return _convert_columns(dataset)
@classmethod
# @capture_provenance(capture_args=["filepath"])
def from_csv(
cls,
filepath: str,
primary_key: str = None,
backend: str = "pandas",
*args,
**kwargs,
) -> DataFrame:
"""Create a DataFrame from a csv file. All of the columns will be
:class:`meerkat.ScalarColumn` with backend Pandas.
Args:
filepath (str): The file path or buffer to load from.
Same as :func:`pandas.read_csv`.
*args: Argument list for :func:`pandas.read_csv`.
**kwargs: Keyword arguments forwarded to :func:`pandas.read_csv`.
Returns:
DataFrame: The constructed dataframe.
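Example:
An illustrative sketch; the file path and column name are hypothetical:
>>> df = DataFrame.from_csv("data.csv", primary_key="id")
>>> df = DataFrame.from_csv("data.csv", backend="arrow")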
"""
if backend == "pandas":
df = cls.from_pandas(pd.read_csv(filepath, *args, **kwargs), index=False)
elif backend == "arrow":
df = cls.from_arrow(pa.csv.read_csv(filepath, *args, **kwargs))
if primary_key is not None:
df.set_primary_key(primary_key, inplace=True)
return df
@classmethod
# @capture_provenance()
def from_feather(
cls,
filepath: str,
primary_key: str = None,
columns: Optional[Sequence[str]] = None,
use_threads: bool = True,
**kwargs,
) -> DataFrame:
"""Create a DataFrame from a feather file. All of the columns will be
:class:`meerkat.ScalarColumn` with backend Pandas.
Args:
filepath (str): The file path or buffer to load from.
Same as :func:`pandas.read_feather`.
columns (Optional[Sequence[str]]): The columns to load.
Same as :func:`pandas.read_feather`.
use_threads (bool): Whether to use threads to read the file.
Same as :func:`pandas.read_feather`.
**kwargs: Keyword arguments forwarded to :func:`pandas.read_feather`.
Returns:
DataFrame: The constructed dataframe.
"""
df = cls.from_pandas(
pd.read_feather(
filepath, columns=columns, use_threads=use_threads, **kwargs
),
index=False,
)
if primary_key is not None:
df.set_primary_key(primary_key, inplace=True)
return df
@classmethod
# @capture_provenance()
def from_parquet(
cls,
filepath: str,
primary_key: str = None,
engine: str = "auto",
columns: Optional[Sequence[str]] = None,
**kwargs,
) -> DataFrame:
"""Create a DataFrame from a parquet file. All of the columns will be
:class:`meerkat.ScalarColumn` with backend Pandas.
Args:
filepath (str): The file path or buffer to load from.
Same as :func:`pandas.read_parquet`.
engine (str): The parquet engine to use.
Same as :func:`pandas.read_parquet`.
columns (Optional[Sequence[str]]): The columns to load.
Same as :func:`pandas.read_parquet`.
**kwargs: Keyword arguments forwarded to :func:`pandas.read_parquet`.
Returns:
DataFrame: The constructed dataframe.
"""
df = cls.from_pandas(
pd.read_parquet(filepath, engine=engine, columns=columns, **kwargs),
index=False,
)
if primary_key is not None:
df.set_primary_key(primary_key, inplace=True)
return df
@classmethod
def from_json(
cls,
filepath: str,
primary_key: str = None,
orient: str = "records",
lines: bool = False,
backend: str = "pandas",
**kwargs,
) -> DataFrame:
"""Load a DataFrame from a json file.
By default, data in the JSON file should be a list of dictionaries, each with
an entry for each column. This is the ``orient="records"`` format. If the data
is in a different format in the JSON, you can specify the ``orient`` parameter.
See :func:`pandas.read_json` for more details.
Args:
filepath (str): The file path or buffer to load from.
Same as :func:`pandas.read_json`.
orient (str): The expected JSON string format. Options are:
"split", "records", "index", "columns", "values".
Same as :func:`pandas.read_json`.
lines (bool): Whether the json file is a jsonl file.
Same as :func:`pandas.read_json`.
backend (str): The backend to use for loading and the resulting columns.
One of ["pandas", "arrow"]. Defaults to "pandas".
**kwargs: Keyword arguments forwarded to :func:`pandas.read_json`.
Returns:
DataFrame: The constructed dataframe.
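Example:
An illustrative sketch; the file paths are hypothetical:
>>> df = DataFrame.from_json("data.json")  # a JSON list of row dictionaries
>>> df = DataFrame.from_json("data.jsonl", lines=True, backend="arrow")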
"""
from pyarrow import json
if backend == "arrow":
if lines is False:
raise ValueError("Arrow backend only supports lines=True.")
df = cls.from_arrow(json.read_json(filepath, **kwargs))
elif backend == "pandas":
df = cls.from_pandas(
pd.read_json(filepath, orient=orient, lines=lines, **kwargs),
index=False,
)
if primary_key is not None:
df.set_primary_key(primary_key, inplace=True)
return df
def _to_base(self, base: str = "pandas", **kwargs) -> Dict[str, Any]:
"""Helper method to convert a Meerkat DataFrame to a base format."""
data = {}
for name, column in self.items():
try:
converted_column = getattr(column, f"to_{base}")(**kwargs)
except ConversionError:
warnings.warn(
f"Could not convert column {name} of type {type(column)}, it will "
"be dropped from the output."
)
continue
data[name] = converted_column
return data
def to_pandas(
self, index: bool = False, allow_objects: bool = False
) -> pd.DataFrame:
"""Convert a Meerkat DataFrame to a Pandas DataFrame.
Args:
index (bool): Use the primary key as the index of the Pandas
DataFrame. Defaults to False.
Returns:
pd.DataFrame: The constructed dataframe.
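Example:
A minimal sketch, assuming all columns are convertible to pandas:
>>> pandas_df = df.to_pandas()
>>> pandas_df = df.to_pandas(index=True)  # use the primary key as the index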
"""
df = pd.DataFrame(self._to_base(base="pandas", allow_objects=allow_objects))
if index:
df.set_index(self.primary_key, inplace=True)
return df
def to_arrow(self) -> pa.Table:
"""Convert a Meerkat DataFrame to an Arrow Table.
Returns:
pa.Table: The constructed table.
"""
return pa.table(self._to_base(base="arrow"))
def _choose_engine(self):
"""Choose which library to use for export.
If any of the columns are ArrowScalarColumn, then use Arrow.
"""
for column in self.values():
if isinstance(column, ArrowScalarColumn):
return "arrow"
return "pandas"
def to_csv(self, filepath: str, engine: str = "auto"):
"""Save a DataFrame to a csv file.
The engine used to write the csv to disk.
Args:
filepath (str): The file path to save to.
engine (str): The library to use to write the csv. One of ["pandas",
"arrow", "auto"]. If "auto", then the library will be chosen based on
the column types.
"""
if engine == "auto":
engine = self._choose_engine()
if engine == "arrow":
from pyarrow.csv import write_csv
write_csv(self.to_arrow(), filepath)
else:
self.to_pandas().to_csv(filepath, index=False)
def to_feather(self, filepath: str, engine: str = "auto"):
"""Save a DataFrame to a feather file.
The engine used to write the feather to disk.
Args:
filepath (str): The file path to save to.
engine (str): The library to use to write the feather. One of ["pandas",
"arrow", "auto"]. If "auto", then the library will be chosen based on
the column types.
"""
if engine == "auto":
engine = self._choose_engine()
if engine == "arrow":
from pyarrow.feather import write_feather
write_feather(self.to_arrow(), filepath)
else:
self.to_pandas().to_feather(filepath)
def to_parquet(self, filepath: str, engine: str = "auto"):
"""Save a DataFrame to a parquet file.
The engine used to write the parquet to disk.
Args:
filepath (str): The file path to save to.
engine (str): The library to use to write the parquet. One of ["pandas",
"arrow", "auto"]. If "auto", then the library will be chosen based on
the column types.
"""
if engine == "auto":
engine = self._choose_engine()
if engine == "arrow":
from pyarrow.parquet import write_table
write_table(self.to_arrow(), filepath)
else:
self.to_pandas().to_parquet(filepath)
def to_json(
self, filepath: str, lines: bool = False, orient: str = "records"
) -> None:
"""Save a Dataset to a json file.
Args:
filepath (str): The file path to save to.
lines (bool): Whether to write the json file as a jsonl file.
orient (str): The orientation of the json file.
Same as :func:`pandas.DataFrame.to_json`.
"""
self.to_pandas().to_json(filepath, lines=lines, orient=orient)
def _get_collate_fns(self, columns: Iterable[str] = None):
columns = self.data.keys() if columns is None else columns
return {name: self.data[name].collate for name in columns}
def _collate(self, batch: List):
batch = tz.merge_with(list, *batch)
column_to_collate = self._get_collate_fns(batch.keys())
new_batch = {}
for name, values in batch.items():
new_batch[name] = column_to_collate[name](values)
df = self._clone(data=new_batch)
return df
@staticmethod
def _convert_to_batch_fn(
function: Callable, with_indices: bool, materialize: bool = True, **kwargs
) -> callable:
return convert_to_batch_fn(
function=function,
with_indices=with_indices,
materialize=materialize,
**kwargs,
)
def batch(
self,
batch_size: int = 1,
drop_last_batch: bool = False,
num_workers: int = 0,
materialize: bool = True,
shuffle: bool = False,
*args,
**kwargs,
):
"""Batch the dataset.
TODO:
Args:
batch_size: integer batch size
drop_last_batch: drop the last batch if its smaller than batch_size
Returns:
batches of data
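Example:
An illustrative sketch of iterating over shuffled batches:
>>> for batch in df.batch(batch_size=32, shuffle=True):
...     ...  # each batch is a DataFrame with up to 32 rows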
"""
cell_columns, batch_columns = [], []
from meerkat.columns.deferred.base import DeferredColumn
for name, column in self.items():
if isinstance(column, DeferredColumn) and materialize:
cell_columns.append(name)
else:
batch_columns.append(name)
indices = np.arange(len(self))
if shuffle:
indices = np.random.permutation(indices)
if batch_columns:
batch_indices = []
for i in range(0, len(self), batch_size):
if drop_last_batch and i + batch_size > len(self):
continue
batch_indices.append(indices[i : i + batch_size])
batch_dl = torch.utils.data.DataLoader(
self[batch_columns],
sampler=batch_indices,
batch_size=None,
batch_sampler=None,
drop_last=drop_last_batch,
num_workers=num_workers,
*args,
**kwargs,
)
if cell_columns:
df = self[cell_columns] if not shuffle else self[cell_columns][indices]
cell_dl = torch.utils.data.DataLoader(
df,
batch_size=batch_size,
collate_fn=self._collate,
drop_last=drop_last_batch,
num_workers=num_workers,
*args,
**kwargs,
)
if batch_columns and cell_columns:
for cell_batch, batch_batch in zip(cell_dl, batch_dl):
out = self._clone(data={**cell_batch.data, **batch_batch.data})
yield out() if materialize else out
elif batch_columns:
for batch_batch in batch_dl:
yield batch_batch() if materialize else batch_batch
elif cell_columns:
for cell_batch in cell_dl:
yield cell_batch() if materialize else cell_batch
# @capture_provenance(capture_args=["with_indices"])
def update(
self,
function: Optional[Callable] = None,
with_indices: bool = False,
input_columns: Optional[Union[str, List[str]]] = None,
is_batched_fn: bool = False,
batch_size: Optional[int] = 1,
remove_columns: Optional[List[str]] = None,
num_workers: int = 0,
output_type: Union[type, Dict[str, type]] = None,
mmap: bool = False,
mmap_path: str = None,
materialize: bool = True,
pbar: bool = False,
**kwargs,
) -> DataFrame:
"""Update the columns of the dataset."""
# TODO(karan): make this fn go faster
# most of the time is spent on the merge, speed it up further
# Return if `self` has no examples
if not len(self):
logger.info("Dataset empty, returning None.")
return self
# Get some information about the function
df = self[input_columns] if input_columns is not None else self
function_properties = df._inspect_function(
function, with_indices, is_batched_fn, materialize=materialize, **kwargs
)
assert (
function_properties.dict_output
), f"`function` {function} must return dict."
if not is_batched_fn:
# Convert to a batch function
function = convert_to_batch_fn(
function, with_indices=with_indices, materialize=materialize, **kwargs
)
logger.info(f"Converting `function` {function} to batched function.")
# Update always returns a new dataset
logger.info("Running update, a new dataset will be returned.")
# Copy the ._data dict with a reference to the actual columns
new_df = self.view()
# Calculate the values for the new columns using a .map()
output = new_df.map(
function=function,
with_indices=with_indices,
is_batched_fn=True,
batch_size=batch_size,
num_workers=num_workers,
output_type=output_type,
input_columns=input_columns,
mmap=mmap,
mmap_path=mmap_path,
materialize=materialize,
pbar=pbar,
**kwargs,
)
# Add new columns for the update
for col, vals in output.data.items():
new_df.add_column(col, vals, overwrite=True)
# Remove columns
if remove_columns:
for col in remove_columns:
new_df.remove_column(col)
logger.info(f"Removed columns {remove_columns}.")
return new_df
def map(
self,
function: Callable,
is_batched_fn: bool = False,
batch_size: int = 1,
inputs: Union[Mapping[str, str], Sequence[str]] = None,
outputs: Union[Mapping[any, str], Sequence[str]] = None,
output_type: Union[Mapping[str, Type["Column"]], Type["Column"]] = None,
materialize: bool = True,
**kwargs,
) -> Optional[Union[Dict, List, Column]]:
from meerkat.ops.map import map
return map(
data=self,
function=function,
is_batched_fn=is_batched_fn,
batch_size=batch_size,
inputs=inputs,
outputs=outputs,
output_type=output_type,
materialize=materialize,
**kwargs,
)
# @capture_provenance(capture_args=["function"])
@reactive()
def filter(
self,
function: Optional[Callable] = None,
with_indices=False,
input_columns: Optional[Union[str, List[str]]] = None,
is_batched_fn: bool = False,
batch_size: Optional[int] = 1,
drop_last_batch: bool = False,
num_workers: int = 0,
materialize: bool = True,
pbar: bool = False,
**kwargs,
) -> Optional[DataFrame]:
"""Filter operation on the DataFrame."""
# Return if `self` has no examples
if not len(self):
logger.info("DataFrame empty, returning None.")
return None
# Get some information about the function
df = self[input_columns] if input_columns is not None else self
function_properties = df._inspect_function(
function,
with_indices,
is_batched_fn=is_batched_fn,
materialize=materialize,
**kwargs,
)
assert function_properties.bool_output, "function must return boolean."
# Map to get the boolean outputs and indices
logger.info("Running `filter`, a new DataFrame will be returned.")
outputs = self.map(
function=function,
with_indices=with_indices,
input_columns=input_columns,
is_batched_fn=is_batched_fn,
batch_size=batch_size,
drop_last_batch=drop_last_batch,
num_workers=num_workers,
materialize=materialize,
pbar=pbar,
**kwargs,
)
indices = np.where(outputs)[0]
# filter returns a new dataframe
return self[indices]
@reactive()
def merge(
self,
right: meerkat.DataFrame,
how: str = "inner",
on: Union[str, List[str]] = None,
left_on: Union[str, List[str]] = None,
right_on: Union[str, List[str]] = None,
sort: bool = False,
suffixes: Sequence[str] = ("_x", "_y"),
validate=None,
):
from meerkat import merge
return merge(
self,
right,
how=how,
on=on,
left_on=left_on,
right_on=right_on,
sort=sort,
suffixes=suffixes,
validate=validate,
)
@reactive()
def sort(
self,
by: Union[str, List[str]],
ascending: Union[bool, List[bool]] = True,
kind: str = "quicksort",
) -> DataFrame:
"""Sort the DataFrame by the values in the specified columns. Similar
to ``sort_values`` in pandas.
Args:
by (Union[str, List[str]]): The columns to sort by.
ascending (Union[bool, List[bool]]): Whether to sort in ascending or
descending order. If a list, must be the same length as `by`. Defaults
to True.
kind (str): The kind of sort to use. Defaults to 'quicksort'. Options
include 'quicksort', 'mergesort', 'heapsort', 'stable'.
Return:
DataFrame: A sorted view of DataFrame.
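Example:
An illustrative sketch; the column names are hypothetical:
>>> df = df.sort(by="score", ascending=False)
>>> df = df.sort(by=["label", "score"], ascending=[True, False])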
"""
from meerkat import sort
return sort(data=self, by=by, ascending=ascending, kind=kind)
@reactive()
def sample(
self,
n: int = None,
frac: float = None,
replace: bool = False,
weights: Union[str, np.ndarray] = None,
random_state: Union[int, np.random.RandomState] = None,
) -> DataFrame:
"""Select a random sample of rows from DataFrame. Roughly equivalent to
``sample`` in Pandas https://pandas.pydata.org/docs/reference/api/panda
s.DataFrame.sample.html.
Args:
n (int): Number of samples to draw. If `frac` is specified, this parameter
should not be passed. Defaults to 1 if `frac` is not passed.
frac (float): Fraction of rows to sample. If `n` is specified, this
parameter should not be passed.
replace (bool): Sample with or without replacement. Defaults to False.
weights (Union[str, np.ndarray]): Weights to use for sampling. If `None`
(default), the rows will be sampled uniformly. If a numpy array, the
sample will be weighted accordingly. If a string, the weights will be
applied to the rows based on the column with the name specified. If
weights do not sum to 1 they will be normalized to sum to 1.
random_state (Union[int, np.random.RandomState]): Random state or seed to
use for sampling.
Return:
DataFrame: A random sample of rows from the DataFrame.
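Example:
An illustrative sketch; the "score" column is hypothetical:
>>> df = df.sample(n=100, random_state=42)
>>> df = df.sample(frac=0.1, weights="score")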
"""
from meerkat import sample
return sample(
data=self,
n=n,
frac=frac,
replace=replace,
weights=weights,
random_state=random_state,
)
@reactive()
def shuffle(self, seed: int = None) -> DataFrame:
"""Shuffle the rows of the DataFrame out-of-place.
Args:
seed (int): Random seed to use for shuffling.
Returns:
DataFrame: A shuffled view of the DataFrame.
"""
from meerkat import shuffle
return shuffle(data=self, seed=seed)
@reactive()
def rename(
self,
mapper: Union[Dict, Callable] = None,
errors: Literal["ignore", "raise"] = "ignore",
) -> DataFrame:
"""Return a new DataFrame with the specified column labels renamed.
Dictionary values must be unique (1-to-1). Labels not specified will be
left unchanged. Extra labels will not throw an error unless ``errors="raise"``.
Args:
mapper (Union[Dict, Callable], optional): Dict-like of function
transformations to apply to the values of the columns. Defaults
to None.
errors (Literal['ignore', 'raise'], optional): If 'raise', raise a
KeyError when the Dict contains labels that do not exist in the
DataFrame. If 'ignore', extra keys will be ignored. Defaults to
'ignore'.
Raises:
KeyError: If ``errors="raise"`` and `mapper` contains a label that does not
exist in the DataFrame.
Returns:
DataFrame: A new DataFrame with the specified column labels renamed.
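Example:
An illustrative sketch; the column names are hypothetical:
>>> df = df.rename({"old_name": "new_name"})
>>> df = df.rename(str.upper)  # rename every column with a callable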
"""
# Copy the ._data dict with a reference to the actual columns
new_df = self.view()
new_primary_key = None
if isinstance(mapper, Dict):
assert len(set(mapper.values())) == len(
mapper.values()
), "Dictionary values must be unique (1-to-1)."
names = self.columns # used to preserve order of columns
for i, (old_name, new_name) in enumerate(mapper.items()):
if old_name not in self.keys():
if errors == "raise":
raise KeyError(
f"Cannot rename nonexistent column `{old_name}`."
)
continue
if old_name == new_name:
continue
if self.primary_key_name == old_name:
new_primary_key = new_name
# Copy old column into new column
new_df[new_name] = new_df[old_name]
new_df.remove_column(old_name)
names[names.index(old_name)] = new_name
new_df = new_df[names]
elif isinstance(mapper, Callable):
for old_name in new_df.columns:
new_name = mapper(old_name)
if old_name == new_name:
continue
if self.primary_key_name == old_name:
new_primary_key = new_name
# Copy old column into new column
new_df[new_name] = new_df[old_name]
new_df.remove_column(old_name)
else:
logger.info(
f"Mapper type is not one of Dict or Callable: {type(mapper)}. Returning"
)
if new_primary_key is not None:
new_df = new_df.set_primary_key(new_primary_key)
return new_df
@reactive()
def drop(
self, columns: Union[str, Collection[str]], check_exists=True
) -> DataFrame:
"""Return a new DataFrame with the specified columns dropped.
Args:
columns (Union[str, Collection[str]]): The columns to drop.
Return:
DataFrame: A new DataFrame with the specified columns dropped.
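Example:
An illustrative sketch; the column names are hypothetical:
>>> df = df.drop("temp")
>>> df = df.drop(["temp", "debug"], check_exists=False)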
"""
if isinstance(columns, str):
columns = [columns]
for c in columns:
if c not in self.columns and check_exists:
raise ValueError(
f"Cannot drop nonexistent column '{c}' from DataFrame."
)
return self[[c for c in self.columns if c not in columns]]
def items(self):
# TODO: Add support for decorating iterators with reactive.
for name in self.columns:
yield name, self.data[name]
@reactive()
def keys(self):
return self.columns
def iterrows(self):
for i in range(len(self)):
yield self[i]
def values(self):
# TODO: Add support for decorating iterators with reactive.
for name in self.columns:
yield self.data[name]
@classmethod
def read(
cls,
path: str,
overwrite: bool = False,
*args,
**kwargs,
) -> DataFrame:
"""Load a DataFrame stored on disk."""
from meerkat.datasets.utils import download_df, extract_tar_file
# URL
if path.startswith("http://") or path.startswith("https://"):
return download_df(path, overwrite=overwrite)
path = os.path.abspath(os.path.expanduser(path))
if not os.path.exists(path):
raise ValueError(f"Path does not exist: {path}")
# Compressed file.
# TODO: Add support for other compressed formats.
if path.endswith(".tar.gz"):
download_dir = os.path.join(os.path.dirname(path), ".mktemp")
cache_dir = os.path.join(
download_dir, os.path.basename(path).split(".tar.gz")[0]
)
os.makedirs(download_dir, exist_ok=True)
extract_tar_file(path, download_dir=download_dir)
df = cls.read(cache_dir, overwrite=overwrite, *args, **kwargs)
shutil.rmtree(download_dir)
return df
# Load the metadata
metadata = load_yaml(os.path.join(path, "meta.yaml"))
if "state" in metadata:
state = metadata["state"]
else:
# backwards compatibility with dill dataframes
state = dill.load(open(os.path.join(path, "state.dill"), "rb"))
df = cls.__new__(cls)
df._set_id() # TODO: consider if we want to persist this id
df._set_inode() # FIXME: this seems hacky
df._set_state(state)
# Load the block manager
mgr_dir = os.path.join(path, "mgr")
if os.path.exists(mgr_dir):
data = BlockManager.read(mgr_dir, **kwargs)
else:
# backwards compatibility with pre-manager dataframes
data = {
name: dtype.read(os.path.join(path, "columns", name), *args, **kwargs)
for name, dtype in metadata["column_dtypes"].items()
}
df._set_data(data)
return df
def write(
self,
path: str,
) -> None:
"""Save a DataFrame to disk."""
# Make all the directories to the path
path = os.path.abspath(os.path.expanduser(path))
os.makedirs(path, exist_ok=True)
# Get the DataFrame state
state = self._get_state()
# Get the metadata
metadata = {
"dtype": type(self),
"column_dtypes": {name: type(col) for name, col in self.data.items()},
"len": len(self),
"state": state,
}
# write the block manager
mgr_dir = os.path.join(path, "mgr")
self.data.write(mgr_dir)
# Save the metadata as a yaml file
metadata_path = os.path.join(path, "meta.yaml")
dump_yaml(metadata, metadata_path)
def to_huggingface(self, repository, commit_message: str = None):
"""Upload a DataFrame to a HuggingFace repository.
This method will dump the dataframe into the ``repository.local_dir``.
If ``commit_message`` is specified, the repository will be pushed to the hub.
The dataframe can then be accessed with:
>>> repo = huggingface_hub.snapshot_download(repository)
>>> # or repo = huggingface_hub.Repository(clone_from=repository)
>>> df = mk.read(repo)
Args:
repository: The huggingface_hub.Repository object to upload to.
commit_message: The commit message to use when pushing to the HuggingFace Hub.
Note:
This will overwrite the existing DataFrame in the repository.
Example:
>>> repo = huggingface_hub.Repository(
... local_dir="my-dataset",
... clone_from="user/my-dataset",
... repo_type="dataset")
>>> df.to_huggingface(repo, commit_message="uploading dataframe")
"""
repository.git_pull()
self.write(repository.local_dir)
repository.push_to_hub(commit_message=commit_message)
@classmethod
def _state_keys(cls) -> Set[str]:
"""List of attributes that describe the state of the object."""
return {"_primary_key"}
def _set_state(self, state: dict):
self.__dict__.update(state)
# backwards compatibility for old dataframes
if "_primary_key" not in state:
self._primary_key = None
def _view_data(self) -> object:
return self.data.view()
def _copy_data(self) -> object:
return self.data.copy()
def __finalize__(self, *args, **kwargs):
return self
@reactive()
def groupby(self, *args, **kwargs):
from meerkat.ops.sliceby.groupby import groupby
return groupby(self, *args, **kwargs)
@reactive()
def sliceby(self, *args, **kwargs):
from meerkat.ops.sliceby.sliceby import sliceby
return sliceby(self, *args, **kwargs)
@reactive()
def clusterby(self, *args, **kwargs):
from meerkat.ops.sliceby.clusterby import clusterby
return clusterby(self, *args, **kwargs)
@reactive()
def explainby(self, *args, **kwargs):
from meerkat.ops.sliceby.explainby import explainby
return explainby(self, *args, **kwargs)
@reactive()
def aggregate(
self, function: Union[str, Callable], nuisance: str = "drop", *args, **kwargs
) -> Dict[str, Any]:
from meerkat.ops.aggregate.aggregate import aggregate
return aggregate(self, function, *args, nuisance=nuisance, **kwargs)
@reactive()
def mean(self, *args, nuisance: str = "drop", **kwargs):
from meerkat.ops.aggregate.aggregate import aggregate
return aggregate(self, function="mean", nuisance=nuisance, *args, **kwargs)
def is_listlike(obj) -> bool:
"""Check if the object is listlike.
Args:
obj (object): The object to check.
Return:
bool: True if the object is listlike, False otherwise.
"""
is_column = isinstance(obj, Column)
is_sequential = (
hasattr(obj, "__len__")
and hasattr(obj, "__getitem__")
and not isinstance(obj, str)
)
return is_column or is_sequential
DataFrame.mark = DataFrame._react
DataFrame.unmark = DataFrame._no_react