@@ -303,7 +303,7 @@ def construct_array_type(self) -> type_t[BaseStringArray]:
303
303
elif self .storage == "pyarrow" and self ._na_value is libmissing .NA :
304
304
return ArrowStringArray
305
305
elif self .storage == "python" :
306
- return StringArrayNumpySemantics
306
+ return StringArray
307
307
else :
308
308
return ArrowStringArray
309
309
@@ -490,8 +490,10 @@ def _str_map_str_or_object(
490
490
)
491
491
# error: "BaseStringArray" has no attribute "_from_pyarrow_array"
492
492
return self ._from_pyarrow_array (result ) # type: ignore[attr-defined]
493
- # error: Too many arguments for "BaseStringArray"
494
- return type (self )(result ) # type: ignore[call-arg]
493
+ else :
494
+ # StringArray
495
+ # error: Too many arguments for "BaseStringArray"
496
+ return type (self )(result , dtype = self .dtype ) # type: ignore[call-arg]
495
497
496
498
else :
497
499
# This is when the result type is object. We reach this when
@@ -581,6 +583,8 @@ class StringArray(BaseStringArray, NumpyExtensionArray): # type: ignore[misc]
581
583
nan-likes(``None``, ``np.nan``) for the ``values`` parameter
582
584
in addition to strings and :attr:`pandas.NA`
583
585
586
+ dtype : StringDtype
587
+ Dtype for the array.
584
588
copy : bool, default False
585
589
Whether to copy the array of data.
586
590
@@ -635,36 +639,56 @@ class StringArray(BaseStringArray, NumpyExtensionArray): # type: ignore[misc]
635
639
636
640
# undo the NumpyExtensionArray hack
637
641
_typ = "extension"
638
- _storage = "python"
639
- _na_value : libmissing .NAType | float = libmissing .NA
640
642
641
- def __init__ (self , values , copy : bool = False ) -> None :
643
+ def __init__ (
644
+ self , values , * , dtype : StringDtype | None = None , copy : bool = False
645
+ ) -> None :
646
+ if dtype is None :
647
+ dtype = StringDtype ()
642
648
values = extract_array (values )
643
649
644
650
super ().__init__ (values , copy = copy )
645
651
if not isinstance (values , type (self )):
646
- self ._validate ()
652
+ self ._validate (dtype )
647
653
NDArrayBacked .__init__ (
648
654
self ,
649
655
self ._ndarray ,
650
- StringDtype ( storage = self . _storage , na_value = self . _na_value ) ,
656
+ dtype ,
651
657
)
652
658
653
- def _validate (self ) -> None :
659
+ def _validate (self , dtype : StringDtype ) -> None :
654
660
"""Validate that we only store NA or strings."""
655
- if len (self ._ndarray ) and not lib .is_string_array (self ._ndarray , skipna = True ):
656
- raise ValueError ("StringArray requires a sequence of strings or pandas.NA" )
657
- if self ._ndarray .dtype != "object" :
658
- raise ValueError (
659
- "StringArray requires a sequence of strings or pandas.NA. Got "
660
- f"'{ self ._ndarray .dtype } ' dtype instead."
661
- )
662
- # Check to see if need to convert Na values to pd.NA
663
- if self ._ndarray .ndim > 2 :
664
- # Ravel if ndims > 2 b/c no cythonized version available
665
- lib .convert_nans_to_NA (self ._ndarray .ravel ("K" ))
661
+
662
+ if dtype ._na_value is libmissing .NA :
663
+ if len (self ._ndarray ) and not lib .is_string_array (
664
+ self ._ndarray , skipna = True
665
+ ):
666
+ raise ValueError (
667
+ "StringArray requires a sequence of strings or pandas.NA"
668
+ )
669
+ if self ._ndarray .dtype != "object" :
670
+ raise ValueError (
671
+ "StringArray requires a sequence of strings or pandas.NA. Got "
672
+ f"'{ self ._ndarray .dtype } ' dtype instead."
673
+ )
674
+ # Check to see if need to convert Na values to pd.NA
675
+ if self ._ndarray .ndim > 2 :
676
+ # Ravel if ndims > 2 b/c no cythonized version available
677
+ lib .convert_nans_to_NA (self ._ndarray .ravel ("K" ))
678
+ else :
679
+ lib .convert_nans_to_NA (self ._ndarray )
666
680
else :
667
- lib .convert_nans_to_NA (self ._ndarray )
681
+ # Validate that we only store NaN or strings.
682
+ if len (self ._ndarray ) and not lib .is_string_array (
683
+ self ._ndarray , skipna = True
684
+ ):
685
+ raise ValueError ("StringArray requires a sequence of strings or NaN" )
686
+ if self ._ndarray .dtype != "object" :
687
+ raise ValueError (
688
+ "StringArray requires a sequence of strings "
689
+ "or NaN. Got '{self._ndarray.dtype}' dtype instead."
690
+ )
691
+ # TODO validate or force NA/None to NaN
668
692
669
693
def _validate_scalar (self , value ):
670
694
# used by NDArrayBackedExtensionIndex.insert
@@ -732,8 +756,8 @@ def _cast_pointwise_result(self, values) -> ArrayLike:
732
756
@classmethod
733
757
def _empty (cls , shape , dtype ) -> StringArray :
734
758
values = np .empty (shape , dtype = object )
735
- values [:] = libmissing . NA
736
- return cls (values ).astype (dtype , copy = False )
759
+ values [:] = dtype . na_value
760
+ return cls (values , dtype = dtype ).astype (dtype , copy = False )
737
761
738
762
def __arrow_array__ (self , type = None ):
739
763
"""
@@ -933,7 +957,7 @@ def _accumulate(self, name: str, *, skipna: bool = True, **kwargs) -> StringArra
933
957
if self ._hasna :
934
958
na_mask = cast ("npt.NDArray[np.bool_]" , isna (ndarray ))
935
959
if np .all (na_mask ):
936
- return type (self )(ndarray )
960
+ return type (self )(ndarray , dtype = self . dtype )
937
961
if skipna :
938
962
if name == "cumsum" :
939
963
ndarray = np .where (na_mask , "" , ndarray )
@@ -967,7 +991,7 @@ def _accumulate(self, name: str, *, skipna: bool = True, **kwargs) -> StringArra
967
991
# Argument 2 to "where" has incompatible type "NAType | float"
968
992
np_result = np .where (na_mask , self .dtype .na_value , np_result ) # type: ignore[arg-type]
969
993
970
- result = type (self )(np_result )
994
+ result = type (self )(np_result , dtype = self . dtype )
971
995
return result
972
996
973
997
def _wrap_reduction_result (self , axis : AxisInt | None , result ) -> Any :
@@ -1046,7 +1070,7 @@ def _cmp_method(self, other, op):
1046
1070
and other .dtype .na_value is libmissing .NA
1047
1071
):
1048
1072
# NA has priority of NaN semantics
1049
- return NotImplemented
1073
+ return op ( self . astype ( other . dtype , copy = False ), other )
1050
1074
1051
1075
if isinstance (other , ArrowExtensionArray ):
1052
1076
if isinstance (other , BaseStringArray ):
@@ -1099,29 +1123,3 @@ def _cmp_method(self, other, op):
1099
1123
return res_arr
1100
1124
1101
1125
_arith_method = _cmp_method
1102
-
1103
-
1104
- class StringArrayNumpySemantics (StringArray ):
1105
- _storage = "python"
1106
- _na_value = np .nan
1107
-
1108
- def _validate (self ) -> None :
1109
- """Validate that we only store NaN or strings."""
1110
- if len (self ._ndarray ) and not lib .is_string_array (self ._ndarray , skipna = True ):
1111
- raise ValueError (
1112
- "StringArrayNumpySemantics requires a sequence of strings or NaN"
1113
- )
1114
- if self ._ndarray .dtype != "object" :
1115
- raise ValueError (
1116
- "StringArrayNumpySemantics requires a sequence of strings or NaN. Got "
1117
- f"'{ self ._ndarray .dtype } ' dtype instead."
1118
- )
1119
- # TODO validate or force NA/None to NaN
1120
-
1121
- @classmethod
1122
- def _from_sequence (
1123
- cls , scalars , * , dtype : Dtype | None = None , copy : bool = False
1124
- ) -> Self :
1125
- if dtype is None :
1126
- dtype = StringDtype (storage = "python" , na_value = np .nan )
1127
- return super ()._from_sequence (scalars , dtype = dtype , copy = copy )
0 commit comments