Skip to content

Commit aabbbc5

Browse files
authored
REF: get rid of StringArrayNumpySemantics (#62149)
1 parent 5b16660 commit aabbbc5

File tree

7 files changed

+86
-88
lines changed

7 files changed

+86
-88
lines changed

asv_bench/benchmarks/strings.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
DataFrame,
99
Index,
1010
Series,
11+
StringDtype,
1112
)
1213
from pandas.arrays import StringArray
1314

@@ -290,10 +291,10 @@ def setup(self):
290291
self.series_arr_nan = np.concatenate([self.series_arr, np.array([NA] * 1000)])
291292

292293
def time_string_array_construction(self):
293-
StringArray(self.series_arr)
294+
StringArray(self.series_arr, dtype=StringDtype())
294295

295296
def time_string_array_with_nan_construction(self):
296-
StringArray(self.series_arr_nan)
297+
StringArray(self.series_arr_nan, dtype=StringDtype())
297298

298299
def peakmem_stringarray_construction(self):
299-
StringArray(self.series_arr)
300+
StringArray(self.series_arr, dtype=StringDtype())

pandas/core/arrays/string_.py

Lines changed: 50 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ def construct_array_type(self) -> type_t[BaseStringArray]:
303303
elif self.storage == "pyarrow" and self._na_value is libmissing.NA:
304304
return ArrowStringArray
305305
elif self.storage == "python":
306-
return StringArrayNumpySemantics
306+
return StringArray
307307
else:
308308
return ArrowStringArray
309309

@@ -490,8 +490,10 @@ def _str_map_str_or_object(
490490
)
491491
# error: "BaseStringArray" has no attribute "_from_pyarrow_array"
492492
return self._from_pyarrow_array(result) # type: ignore[attr-defined]
493-
# error: Too many arguments for "BaseStringArray"
494-
return type(self)(result) # type: ignore[call-arg]
493+
else:
494+
# StringArray
495+
# error: Too many arguments for "BaseStringArray"
496+
return type(self)(result, dtype=self.dtype) # type: ignore[call-arg]
495497

496498
else:
497499
# This is when the result type is object. We reach this when
@@ -581,6 +583,8 @@ class StringArray(BaseStringArray, NumpyExtensionArray): # type: ignore[misc]
581583
nan-likes(``None``, ``np.nan``) for the ``values`` parameter
582584
in addition to strings and :attr:`pandas.NA`
583585
586+
dtype : StringDtype
587+
Dtype for the array.
584588
copy : bool, default False
585589
Whether to copy the array of data.
586590
@@ -635,36 +639,56 @@ class StringArray(BaseStringArray, NumpyExtensionArray): # type: ignore[misc]
635639

636640
# undo the NumpyExtensionArray hack
637641
_typ = "extension"
638-
_storage = "python"
639-
_na_value: libmissing.NAType | float = libmissing.NA
640642

641-
def __init__(self, values, copy: bool = False) -> None:
643+
def __init__(
644+
self, values, *, dtype: StringDtype | None = None, copy: bool = False
645+
) -> None:
646+
if dtype is None:
647+
dtype = StringDtype()
642648
values = extract_array(values)
643649

644650
super().__init__(values, copy=copy)
645651
if not isinstance(values, type(self)):
646-
self._validate()
652+
self._validate(dtype)
647653
NDArrayBacked.__init__(
648654
self,
649655
self._ndarray,
650-
StringDtype(storage=self._storage, na_value=self._na_value),
656+
dtype,
651657
)
652658

653-
def _validate(self) -> None:
659+
def _validate(self, dtype: StringDtype) -> None:
654660
"""Validate that we only store NA or strings."""
655-
if len(self._ndarray) and not lib.is_string_array(self._ndarray, skipna=True):
656-
raise ValueError("StringArray requires a sequence of strings or pandas.NA")
657-
if self._ndarray.dtype != "object":
658-
raise ValueError(
659-
"StringArray requires a sequence of strings or pandas.NA. Got "
660-
f"'{self._ndarray.dtype}' dtype instead."
661-
)
662-
# Check to see if need to convert Na values to pd.NA
663-
if self._ndarray.ndim > 2:
664-
# Ravel if ndims > 2 b/c no cythonized version available
665-
lib.convert_nans_to_NA(self._ndarray.ravel("K"))
661+
662+
if dtype._na_value is libmissing.NA:
663+
if len(self._ndarray) and not lib.is_string_array(
664+
self._ndarray, skipna=True
665+
):
666+
raise ValueError(
667+
"StringArray requires a sequence of strings or pandas.NA"
668+
)
669+
if self._ndarray.dtype != "object":
670+
raise ValueError(
671+
"StringArray requires a sequence of strings or pandas.NA. Got "
672+
f"'{self._ndarray.dtype}' dtype instead."
673+
)
674+
# Check to see if need to convert Na values to pd.NA
675+
if self._ndarray.ndim > 2:
676+
# Ravel if ndims > 2 b/c no cythonized version available
677+
lib.convert_nans_to_NA(self._ndarray.ravel("K"))
678+
else:
679+
lib.convert_nans_to_NA(self._ndarray)
666680
else:
667-
lib.convert_nans_to_NA(self._ndarray)
681+
# Validate that we only store NaN or strings.
682+
if len(self._ndarray) and not lib.is_string_array(
683+
self._ndarray, skipna=True
684+
):
685+
raise ValueError("StringArray requires a sequence of strings or NaN")
686+
if self._ndarray.dtype != "object":
687+
raise ValueError(
688+
"StringArray requires a sequence of strings "
689+
"or NaN. Got '{self._ndarray.dtype}' dtype instead."
690+
)
691+
# TODO validate or force NA/None to NaN
668692

669693
def _validate_scalar(self, value):
670694
# used by NDArrayBackedExtensionIndex.insert
@@ -732,8 +756,8 @@ def _cast_pointwise_result(self, values) -> ArrayLike:
732756
@classmethod
733757
def _empty(cls, shape, dtype) -> StringArray:
734758
values = np.empty(shape, dtype=object)
735-
values[:] = libmissing.NA
736-
return cls(values).astype(dtype, copy=False)
759+
values[:] = dtype.na_value
760+
return cls(values, dtype=dtype).astype(dtype, copy=False)
737761

738762
def __arrow_array__(self, type=None):
739763
"""
@@ -933,7 +957,7 @@ def _accumulate(self, name: str, *, skipna: bool = True, **kwargs) -> StringArra
933957
if self._hasna:
934958
na_mask = cast("npt.NDArray[np.bool_]", isna(ndarray))
935959
if np.all(na_mask):
936-
return type(self)(ndarray)
960+
return type(self)(ndarray, dtype=self.dtype)
937961
if skipna:
938962
if name == "cumsum":
939963
ndarray = np.where(na_mask, "", ndarray)
@@ -967,7 +991,7 @@ def _accumulate(self, name: str, *, skipna: bool = True, **kwargs) -> StringArra
967991
# Argument 2 to "where" has incompatible type "NAType | float"
968992
np_result = np.where(na_mask, self.dtype.na_value, np_result) # type: ignore[arg-type]
969993

970-
result = type(self)(np_result)
994+
result = type(self)(np_result, dtype=self.dtype)
971995
return result
972996

973997
def _wrap_reduction_result(self, axis: AxisInt | None, result) -> Any:
@@ -1046,7 +1070,7 @@ def _cmp_method(self, other, op):
10461070
and other.dtype.na_value is libmissing.NA
10471071
):
10481072
# NA has priority of NaN semantics
1049-
return NotImplemented
1073+
return op(self.astype(other.dtype, copy=False), other)
10501074

10511075
if isinstance(other, ArrowExtensionArray):
10521076
if isinstance(other, BaseStringArray):
@@ -1099,29 +1123,3 @@ def _cmp_method(self, other, op):
10991123
return res_arr
11001124

11011125
_arith_method = _cmp_method
1102-
1103-
1104-
class StringArrayNumpySemantics(StringArray):
1105-
_storage = "python"
1106-
_na_value = np.nan
1107-
1108-
def _validate(self) -> None:
1109-
"""Validate that we only store NaN or strings."""
1110-
if len(self._ndarray) and not lib.is_string_array(self._ndarray, skipna=True):
1111-
raise ValueError(
1112-
"StringArrayNumpySemantics requires a sequence of strings or NaN"
1113-
)
1114-
if self._ndarray.dtype != "object":
1115-
raise ValueError(
1116-
"StringArrayNumpySemantics requires a sequence of strings or NaN. Got "
1117-
f"'{self._ndarray.dtype}' dtype instead."
1118-
)
1119-
# TODO validate or force NA/None to NaN
1120-
1121-
@classmethod
1122-
def _from_sequence(
1123-
cls, scalars, *, dtype: Dtype | None = None, copy: bool = False
1124-
) -> Self:
1125-
if dtype is None:
1126-
dtype = StringDtype(storage="python", na_value=np.nan)
1127-
return super()._from_sequence(scalars, dtype=dtype, copy=copy)

pandas/tests/arrays/string_/test_string.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121

2222
import pandas as pd
2323
import pandas._testing as tm
24-
from pandas.core.arrays.string_ import StringArrayNumpySemantics
2524
from pandas.core.arrays.string_arrow import (
2625
ArrowStringArray,
2726
)
@@ -115,7 +114,7 @@ def test_repr(dtype):
115114
arr_name = "ArrowStringArray"
116115
expected = f"<{arr_name}>\n['a', nan, 'b']\nLength: 3, dtype: str"
117116
elif dtype.storage == "python" and dtype.na_value is np.nan:
118-
arr_name = "StringArrayNumpySemantics"
117+
arr_name = "StringArray"
119118
expected = f"<{arr_name}>\n['a', nan, 'b']\nLength: 3, dtype: str"
120119
else:
121120
arr_name = "StringArray"
@@ -433,44 +432,45 @@ def test_comparison_methods_list(comparison_op, dtype):
433432
def test_constructor_raises(cls):
434433
if cls is pd.arrays.StringArray:
435434
msg = "StringArray requires a sequence of strings or pandas.NA"
436-
elif cls is StringArrayNumpySemantics:
437-
msg = "StringArrayNumpySemantics requires a sequence of strings or NaN"
435+
kwargs = {"dtype": pd.StringDtype()}
438436
else:
439437
msg = "Unsupported type '<class 'numpy.ndarray'>' for ArrowExtensionArray"
438+
kwargs = {}
440439

441440
with pytest.raises(ValueError, match=msg):
442-
cls(np.array(["a", "b"], dtype="S1"))
441+
cls(np.array(["a", "b"], dtype="S1"), **kwargs)
443442

444443
with pytest.raises(ValueError, match=msg):
445-
cls(np.array([]))
444+
cls(np.array([]), **kwargs)
446445

447-
if cls is pd.arrays.StringArray or cls is StringArrayNumpySemantics:
446+
if cls is pd.arrays.StringArray:
448447
# GH#45057 np.nan and None do NOT raise, as they are considered valid NAs
449448
# for string dtype
450-
cls(np.array(["a", np.nan], dtype=object))
451-
cls(np.array(["a", None], dtype=object))
449+
cls(np.array(["a", np.nan], dtype=object), **kwargs)
450+
cls(np.array(["a", None], dtype=object), **kwargs)
452451
else:
453452
with pytest.raises(ValueError, match=msg):
454-
cls(np.array(["a", np.nan], dtype=object))
453+
cls(np.array(["a", np.nan], dtype=object), **kwargs)
455454
with pytest.raises(ValueError, match=msg):
456-
cls(np.array(["a", None], dtype=object))
455+
cls(np.array(["a", None], dtype=object), **kwargs)
457456

458457
with pytest.raises(ValueError, match=msg):
459-
cls(np.array(["a", pd.NaT], dtype=object))
458+
cls(np.array(["a", pd.NaT], dtype=object), **kwargs)
460459

461460
with pytest.raises(ValueError, match=msg):
462-
cls(np.array(["a", np.datetime64("NaT", "ns")], dtype=object))
461+
cls(np.array(["a", np.datetime64("NaT", "ns")], dtype=object), **kwargs)
463462

464463
with pytest.raises(ValueError, match=msg):
465-
cls(np.array(["a", np.timedelta64("NaT", "ns")], dtype=object))
464+
cls(np.array(["a", np.timedelta64("NaT", "ns")], dtype=object), **kwargs)
466465

467466

468467
@pytest.mark.parametrize("na", [np.nan, np.float64("nan"), float("nan"), None, pd.NA])
469468
def test_constructor_nan_like(na):
470-
expected = pd.arrays.StringArray(np.array(["a", pd.NA]))
471-
tm.assert_extension_array_equal(
472-
pd.arrays.StringArray(np.array(["a", na], dtype="object")), expected
469+
expected = pd.arrays.StringArray(np.array(["a", pd.NA]), dtype=pd.StringDtype())
470+
result = pd.arrays.StringArray(
471+
np.array(["a", na], dtype="object"), dtype=pd.StringDtype()
473472
)
473+
tm.assert_extension_array_equal(result, expected)
474474

475475

476476
@pytest.mark.parametrize("copy", [True, False])
@@ -487,10 +487,10 @@ def test_from_sequence_no_mutate(copy, cls, dtype):
487487
expected = cls(
488488
pa.array(na_arr, type=pa.string(), from_pandas=True), dtype=dtype
489489
)
490-
elif cls is StringArrayNumpySemantics:
491-
expected = cls(nan_arr)
490+
elif dtype.na_value is np.nan:
491+
expected = cls(nan_arr, dtype=dtype)
492492
else:
493-
expected = cls(na_arr)
493+
expected = cls(na_arr, dtype=dtype)
494494

495495
tm.assert_extension_array_equal(result, expected)
496496
tm.assert_numpy_array_equal(nan_arr, expected_input)

pandas/tests/base/test_conversion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121
NumpyExtensionArray,
2222
PeriodArray,
2323
SparseArray,
24+
StringArray,
2425
TimedeltaArray,
2526
)
26-
from pandas.core.arrays.string_ import StringArrayNumpySemantics
2727
from pandas.core.arrays.string_arrow import ArrowStringArray
2828

2929

@@ -222,7 +222,7 @@ def test_iter_box_period(self):
222222
)
223223
def test_values_consistent(arr, expected_type, dtype, using_infer_string):
224224
if using_infer_string and dtype == "object":
225-
expected_type = ArrowStringArray if HAS_PYARROW else StringArrayNumpySemantics
225+
expected_type = ArrowStringArray if HAS_PYARROW else StringArray
226226
l_values = Series(arr)._values
227227
r_values = pd.Index(arr)._values
228228
assert type(l_values) is expected_type

pandas/tests/extension/test_common.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,13 @@ def __getitem__(self, item):
9393
def test_ellipsis_index():
9494
# GH#42430 1D slices over extension types turn into N-dimensional slices
9595
# over ExtensionArrays
96+
dtype = pd.StringDtype()
9697
df = pd.DataFrame(
97-
{"col1": CapturingStringArray(np.array(["hello", "world"], dtype=object))}
98+
{
99+
"col1": CapturingStringArray(
100+
np.array(["hello", "world"], dtype=object), dtype=dtype
101+
)
102+
}
98103
)
99104
_ = df.iloc[:1]
100105

pandas/tests/io/parser/test_upcast.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
BooleanArray,
1515
FloatingArray,
1616
IntegerArray,
17-
StringArray,
1817
)
1918

2019

@@ -95,7 +94,7 @@ def test_maybe_upcast_object(val, string_storage):
9594

9695
if string_storage == "python":
9796
exp_val = "c" if val == "c" else NA
98-
expected = StringArray(np.array(["a", "b", exp_val], dtype=np.object_))
97+
expected = pd.array(["a", "b", exp_val], dtype=pd.StringDtype())
9998
else:
10099
exp_val = "c" if val == "c" else None
101100
expected = ArrowStringArray(pa.array(["a", "b", exp_val]))

pandas/tests/io/test_orc.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import pandas as pd
1313
from pandas import read_orc
1414
import pandas._testing as tm
15-
from pandas.core.arrays import StringArray
1615

1716
pytest.importorskip("pyarrow.orc")
1817

@@ -368,13 +367,9 @@ def test_orc_dtype_backend_numpy_nullable():
368367

369368
expected = pd.DataFrame(
370369
{
371-
"string": StringArray(np.array(["a", "b", "c"], dtype=np.object_)),
372-
"string_with_nan": StringArray(
373-
np.array(["a", pd.NA, "c"], dtype=np.object_)
374-
),
375-
"string_with_none": StringArray(
376-
np.array(["a", pd.NA, "c"], dtype=np.object_)
377-
),
370+
"string": pd.array(["a", "b", "c"], dtype=pd.StringDtype()),
371+
"string_with_nan": pd.array(["a", pd.NA, "c"], dtype=pd.StringDtype()),
372+
"string_with_none": pd.array(["a", pd.NA, "c"], dtype=pd.StringDtype()),
378373
"int": pd.Series([1, 2, 3], dtype="Int64"),
379374
"int_with_nan": pd.Series([1, pd.NA, 3], dtype="Int64"),
380375
"na_only": pd.Series([pd.NA, pd.NA, pd.NA], dtype="Int64"),

0 commit comments

Comments
 (0)