from __future__ import annotations

import os
import re
import warnings
from typing import TYPE_CHECKING, Any, List, Sequence, Set, Union

import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
from pandas.core.accessor import CachedAccessor

from meerkat.block.abstract import BlockView
from meerkat.block.arrow_block import ArrowBlock
from meerkat.errors import ImmutableError
from meerkat.tools.lazy_loader import LazyLoader

from ..abstract import Column
from .abstract import ScalarColumn, StringMethods

if TYPE_CHECKING:
    from meerkat import DataFrame
    from meerkat.interactive.formatter.base import BaseFormatter

torch = LazyLoader("torch")


class ArrowStringMethods(StringMethods):
    def center(self, width: int, fillchar: str = " ", **kwargs) -> ScalarColumn:
        return self.column._dispatch_unary_function(
            "utf8_center", width=width, padding=fillchar, **kwargs
        )

    def extract(self, pat: str, **kwargs) -> "DataFrame":
        from meerkat import DataFrame

        # Pandas raises a value error if the pattern does not include a group
        # but pyarrow does not. We check for this case and raise a value error.
        if not re.search(r"\(\?P<\w+>", pat):
            raise ValueError(
                "Pattern does not contain capture group. Use '(?P<name>...)' instead"
            )

        struct_array = pc.extract_regex(self.column.data, pattern=pat, **kwargs)
        result = {}
        for field_index in range(struct_array.type.num_fields):
            field = struct_array.type.field(field_index)
            result[field.name] = self.column._clone(
                pc.struct_field(struct_array, field.name)
            )
        return DataFrame(result)

    def _split(
        self, pat=None, n=-1, reverse: bool = False, regex: bool = False, **kwargs
    ) -> "DataFrame":
        from meerkat import DataFrame

        fn = pc.split_pattern_regex if regex else pc.split_pattern
        list_array = fn(
            self.column.data,
            pattern=pat,
            max_splits=n if n != -1 else None,
            reverse=reverse,
            **kwargs,
        )

        # need to find the max length of the list array
        if n == -1:
            n = pc.max(pc.list_value_length(list_array)).as_py() - 1

        return DataFrame(
            {
                str(i): self.column._clone(
                    data=pc.list_flatten(
                        pc.list_slice(
                            list_array,
                            start=i,
                            stop=i + 1,
                            return_fixed_size_list=True,
                        )
                    )
                )
                for i in range(n + 1)
            }
        )

    def split(
        self, pat: str = None, n: int = -1, regex: bool = False, **kwargs
    ) -> "DataFrame":
        return self._split(pat=pat, n=n, reverse=False, regex=regex, **kwargs)

    def rsplit(
        self, pat: str = None, n: int = -1, regex: bool = False, **kwargs
    ) -> "DataFrame":
        return self._split(pat=pat, n=n, reverse=True, regex=regex, **kwargs)

    def startswith(self, pat: str, **kwargs) -> ScalarColumn:
        return self.column._dispatch_unary_function(
            "starts_with", pattern=pat, **kwargs
        )

    def strip(self, to_strip: str = None, **kwargs) -> ScalarColumn:
        if to_strip is None:
            return self.column._dispatch_unary_function(
                "utf8_trim_whitespace", **kwargs
            )
        else:
            # pyarrow names these kernels utf8_trim / utf8_ltrim / utf8_rtrim;
            # there is no utf8_strip family in pyarrow.compute.
            return self.column._dispatch_unary_function(
                "utf8_trim", characters=to_strip, **kwargs
            )

    def lstrip(self, to_strip: str = None, **kwargs) -> ScalarColumn:
        if to_strip is None:
            return self.column._dispatch_unary_function(
                "utf8_ltrim_whitespace", **kwargs
            )
        else:
            return self.column._dispatch_unary_function(
                "utf8_ltrim", characters=to_strip, **kwargs
            )

    def rstrip(self, to_strip: str = None, **kwargs) -> ScalarColumn:
        if to_strip is None:
            return self.column._dispatch_unary_function(
                "utf8_rtrim_whitespace", **kwargs
            )
        else:
            return self.column._dispatch_unary_function(
                "utf8_rtrim", characters=to_strip, **kwargs
            )

    def replace(
        self, pat: str, repl: str, n: int = -1, regex: bool = False, **kwargs
    ) -> ScalarColumn:
        fn = pc.replace_substring_regex if regex else pc.replace_substring
        return self.column._clone(
            fn(
                self.column.data,
                pattern=pat,
                replacement=repl,
                max_replacements=n if n != -1 else None,
                **kwargs,
            )
        )

    def contains(self, pat: str, case: bool = True, regex: bool = True) -> ScalarColumn:
        fn = pc.match_substring_regex if regex else pc.match_substring
        return self.column._clone(
            fn(
                self.column.data,
                pattern=pat,
                ignore_case=not case,
            )
        )
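
# Illustrative sketch (an editorial addition, not part of the original module):
# how the string accessor above is typically exercised. It assumes an
# ArrowScalarColumn (defined below) wrapping a pyarrow string array; the
# variable names are hypothetical.
#
#   col = ArrowScalarColumn(pa.array(["a-1", "b-2", "c-3"]))
#   col.str.contains("a")                              # boolean ArrowScalarColumn
#   col.str.replace("-", ":")                          # string ArrowScalarColumn
#   col.str.split("-")                                 # DataFrame with columns "0", "1"
#   col.str.extract(r"(?P<letter>\w+)-(?P<num>\d+)")   # DataFrame with "letter", "num"
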

class ArrowScalarColumn(ScalarColumn):
    block_class: type = ArrowBlock

    str = CachedAccessor("str", ArrowStringMethods)

    def __init__(
        self,
        data: Sequence,
        *args,
        **kwargs,
    ):
        if isinstance(data, BlockView):
            if not isinstance(data.block, ArrowBlock):
                raise ValueError(
                    "ArrowArrayColumn can only be initialized with ArrowBlock."
                )
        elif not isinstance(data, (pa.Array, pa.ChunkedArray)):
            # Arrow cannot construct an array from a torch.Tensor.
            if isinstance(data, torch.Tensor):
                data = data.numpy()
            data = pa.array(data)

        super(ArrowScalarColumn, self).__init__(data=data, *args, **kwargs)

    def _get(self, index, materialize: bool = True):
        index = ArrowBlock._convert_index(index)
        if isinstance(index, slice) or isinstance(index, int):
            data = self._data[index]
        elif index.dtype == bool:
            data = self._data.filter(pa.array(index))
        else:
            data = self._data.take(index)

        if self._is_batch_index(index):
            return self._clone(data=data)
        else:
            # Convert to Python object for consistency with other ScalarColumn
            # implementations.
            return data.as_py()

    def _set(self, index, value):
        raise ImmutableError("ArrowArrayColumn is immutable.")

    def _is_valid_primary_key(self):
        try:
            return len(self.unique()) == len(self)
        except Exception as e:
            warnings.warn(f"Unable to check if column is a valid primary key: {e}")
            return False

    def _keyidx_to_posidx(self, keyidx: Any) -> int:
        """Get the posidx of the first occurrence of the given keyidx. Raise a
        key error if the keyidx is not found.

        Args:
            keyidx: The keyidx to search for.

        Returns:
            The posidx of the first occurrence of the given keyidx.
        """
        posidx = pc.index(self.data, keyidx)
        if posidx == -1:
            raise KeyError(f"keyidx {keyidx} not found in column.")
        return posidx.as_py()

    def _keyidxs_to_posidxs(self, keyidxs: Sequence[Any]) -> np.ndarray:
        # FIXME: this implementation is very slow. This should be done with indices
        return np.array([self._keyidx_to_posidx(keyidx) for keyidx in keyidxs])

    def _repr_cell(self, index) -> object:
        return self.data[index]

    def _get_default_formatters(self) -> BaseFormatter:
        # can't implement this as a class level property because then it will treat
        # the formatter as a method
        from meerkat.interactive.formatter import (
            NumberFormatterGroup,
            TextFormatterGroup,
        )

        if len(self) == 0:
            return super()._get_default_formatters()

        if self.data.type == pa.string():
            return TextFormatterGroup()

        cell = self[0]
        return NumberFormatterGroup(dtype=type(cell).__name__)
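
    # Illustrative sketch (an editorial addition, not part of the original module):
    # construction and indexing behavior of this column. Variable names are
    # hypothetical.
    #
    #   col = ArrowScalarColumn([1, 2, 3])   # list is converted via pa.array
    #   col[0]                               # -> 1, a Python object (via .as_py())
    #   col[1:]                              # -> ArrowScalarColumn of length 2
    #   col[0] = 5                           # raises ImmutableError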

    @classmethod
    def _state_keys(cls) -> Set:
        return super()._state_keys()

    def _write_data(self, path):
        table = pa.Table.from_arrays([self.data], names=["0"])
        ArrowBlock._write_table(os.path.join(path, "data.arrow"), table)

    @staticmethod
    def _read_data(path, mmap=False):
        table = ArrowBlock._read_table(os.path.join(path, "data.arrow"), mmap=mmap)
        return table["0"]

    @classmethod
    def concat(cls, columns: Sequence[ArrowScalarColumn]):
        arrays = []
        for c in columns:
            if isinstance(c.data, pa.Array):
                arrays.append(c.data)
            elif isinstance(c.data, pa.ChunkedArray):
                arrays.extend(c.data.chunks)
            else:
                raise ValueError(f"Unexpected type {type(c.data)}")
        data = pa.concat_arrays(arrays)
        return columns[0]._clone(data=data)
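
    # Illustrative sketch (editorial addition): _write_data/_read_data round-trip
    # the underlying array through a single-column Arrow table named "0", and
    # concat flattens any ChunkedArrays into one contiguous pa.Array. Hypothetical
    # usage:
    #
    #   a = ArrowScalarColumn(["x", "y"])
    #   b = ArrowScalarColumn(["z"])
    #   merged = ArrowScalarColumn.concat([a, b])   # length-3 column, one pa.Array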

    def equals(self, other: Column) -> bool:
        if other.__class__ != self.__class__:
            return False
        return pc.all(pc.equal(self.data, other.data)).as_py()

    @property
    def dtype(self) -> pa.DataType:
        return self.data.type

    KWARG_MAPPING = {"skipna": "skip_nulls"}

    COMPUTE_FN_MAPPING = {
        "var": "variance",
        "std": "stddev",
        "sub": "subtract",
        "mul": "multiply",
        "truediv": "divide",
        "pow": "power",
        "eq": "equal",
        "ne": "not_equal",
        "lt": "less",
        "gt": "greater",
        "le": "less_equal",
        "ge": "greater_equal",
        "isna": "is_nan",
        "capitalize": "utf8_capitalize",
        "center": "utf8_center",
        "isalnum": "utf8_is_alnum",
        "isalpha": "utf8_is_alpha",
        "isdecimal": "utf8_is_decimal",
        "isdigit": "utf8_is_digit",
        "islower": "utf8_is_lower",
        "isnumeric": "utf8_is_numeric",
        "isspace": "utf8_is_space",
        "istitle": "utf8_is_title",
        "isupper": "utf8_is_upper",
        "lower": "utf8_lower",
        "upper": "utf8_upper",
        "len": "utf8_length",
        "lstrip": "utf8_ltrim",
        "rstrip": "utf8_rtrim",
        "strip": "utf8_trim",
        "swapcase": "utf8_swapcase",
        "title": "utf8_title",
    }

    def _dispatch_aggregation_function(self, compute_fn: str, **kwargs):
        kwargs = {self.KWARG_MAPPING.get(k, k): v for k, v in kwargs.items()}
        out = getattr(pc, self.COMPUTE_FN_MAPPING.get(compute_fn, compute_fn))(
            self.data, **kwargs
        )
        return out.as_py()

    def mode(self, **kwargs) -> ScalarColumn:
        if "n" in kwargs:
            raise ValueError(
                "Meerkat does not support passing `n` to `mode` when "
                "backend is Arrow."
            )

        # matching behavior of Pandas, get all counts, but only return top modes
        struct_array = pc.mode(self.data, n=len(self), **kwargs)
        modes = []
        count = struct_array[0]["count"]
        for mode in struct_array:
            if count != mode["count"]:
                break
            modes.append(mode["mode"].as_py())
        return ArrowScalarColumn(modes)

    def median(self, skipna: bool = True, **kwargs) -> Any:
        warnings.warn("Arrow backend computes an approximate median.")
        return pc.approximate_median(self.data, skip_nulls=skipna).as_py()

    def _dispatch_arithmetic_function(
        self, other: ScalarColumn, compute_fn: str, right: bool, *args, **kwargs
    ):
        if isinstance(other, Column):
            assert isinstance(other, ArrowScalarColumn)
            other = other.data

        compute_fn = self.COMPUTE_FN_MAPPING.get(compute_fn, compute_fn)
        if right:
            out = self._clone(
                data=getattr(pc, compute_fn)(other, self.data, *args, **kwargs)
            )
            return out
        else:
            return self._clone(
                data=getattr(pc, compute_fn)(self.data, other, *args, **kwargs)
            )

    def _true_div(self, other, right: bool = False, **kwargs) -> ScalarColumn:
        if isinstance(other, Column):
            assert isinstance(other, ArrowScalarColumn)
            other = other.data

        # convert other to float if it is an integer
        if isinstance(other, pa.ChunkedArray) or isinstance(other, pa.Array):
            if other.type == pa.int64():
                other = other.cast(pa.float64())
        else:
            other = pa.scalar(other, type=pa.float64())

        if right:
            return self._clone(pc.divide(other, self.data), **kwargs)
        else:
            return self._clone(pc.divide(self.data, other), **kwargs)

    def __add__(self, other: ScalarColumn):
        if self.dtype == pa.string():
            # pyarrow expects a final str used as the separator
            return self._dispatch_arithmetic_function(
                other, "binary_join_element_wise", False, ""
            )
        return self._dispatch_arithmetic_function(other, "add", right=False)

    def __radd__(self, other: ScalarColumn):
        if self.dtype == pa.string():
            return self._dispatch_arithmetic_function(
                other, "binary_join_element_wise", True, ""
            )
        return self._dispatch_arithmetic_function(other, "add", right=False)

    def __truediv__(self, other: ScalarColumn):
        return self._true_div(other, right=False)

    def __rtruediv__(self, other: ScalarColumn):
        return self._true_div(other, right=True)

    def _floor_div(self, other, right: bool = False, **kwargs) -> ScalarColumn:
        _true_div = self._true_div(other, right=right, **kwargs)
        return _true_div._clone(data=pc.floor(_true_div.data))

    def __floordiv__(self, other: ScalarColumn):
        return self._floor_div(other, right=False)

    def __rfloordiv__(self, other: ScalarColumn):
        return self._floor_div(other, right=True)

    def __mod__(self, other: ScalarColumn):
        raise NotImplementedError("Modulo is not supported by Arrow backend.")

    def __rmod__(self, other: ScalarColumn):
        raise NotImplementedError("Modulo is not supported by Arrow backend.")

    def _dispatch_comparison_function(
        self, other: ScalarColumn, compute_fn: str, **kwargs
    ):
        if isinstance(other, Column):
            assert isinstance(other, ArrowScalarColumn)
            other = other.data

        compute_fn = self.COMPUTE_FN_MAPPING.get(compute_fn, compute_fn)
        return self._clone(data=getattr(pc, compute_fn)(self.data, other, **kwargs))

    def _dispatch_logical_function(
        self, other: ScalarColumn, compute_fn: str, **kwargs
    ):
        if isinstance(other, Column):
            assert isinstance(other, ArrowScalarColumn)
            other = other.data

        compute_fn = self.COMPUTE_FN_MAPPING.get(compute_fn, compute_fn)
        if other is None:
            return self._clone(data=getattr(pc, compute_fn)(self.data, **kwargs))
        return self._clone(data=getattr(pc, compute_fn)(self.data, other, **kwargs))

    def isin(self, values: Union[List, Set], **kwargs) -> ScalarColumn:
        return self._clone(data=pc.is_in(self.data, pa.array(values), **kwargs))

    def _dispatch_unary_function(
        self, compute_fn: str, _namespace: str = None, **kwargs
    ):
        compute_fn = self.COMPUTE_FN_MAPPING.get(compute_fn, compute_fn)
        return self._clone(data=getattr(pc, compute_fn)(self.data, **kwargs))

    def isnull(self, **kwargs) -> ScalarColumn:
        return self._clone(data=pc.is_null(self.data, nan_is_null=True, **kwargs))
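
# Illustrative sketch (an editorial addition, not part of the original module): how
# the dispatch machinery above maps familiar pandas-style names onto pyarrow.compute
# kernels. The column values and variable names are hypothetical.
#
#   col = ArrowScalarColumn([1.0, 2.0, 3.0])
#   col + 1          # __add__ -> pc.add(col.data, 1)
#   col / 2          # _true_div casts the operand to float64, then pc.divide
#   col // 2         # _floor_div -> pc.floor(pc.divide(...))
#   col.std()        # "std" is remapped to pc.stddev via COMPUTE_FN_MAPPING
#   col.isin([1.0])  # pc.is_in against pa.array([1.0])
#   col.isnull()     # pc.is_null with nan_is_null=True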