Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 1 addition & 67 deletions python/ray/air/_internal/tensorflow_utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from typing import Dict, Optional, Union

import numpy as np
import pyarrow
import tensorflow as tf

from ray.air.util.data_batch_conversion import _unwrap_ndarray_object_type_if_needed
from ray.air.util.tensor_extensions.arrow import get_arrow_extension_tensor_types

if TYPE_CHECKING:
from ray.data._internal.pandas_block import PandasBlockSchema


def convert_ndarray_to_tf_tensor(
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function is now duplicated across ray.air and ray.data, but it's not dead code.

In a follow-up, I think we should remove this and update the reference in tensorflow_predictor to use the one defined in ray.data.

Expand Down Expand Up @@ -74,64 +69,3 @@ def convert_ndarray_batch_to_tf_tensor_batch(
}

return batch


def get_type_spec(
    schema: Union["pyarrow.lib.Schema", "PandasBlockSchema"],
    columns: Union[str, List[str]],
) -> Union[tf.TypeSpec, Dict[str, tf.TypeSpec]]:
    """Build TensorFlow type specs for the given schema columns.

    Args:
        schema: A pyarrow or pandas-block schema exposing ``names`` and
            ``types`` attributes.
        columns: Either one column name or a list of column names to
            build specs for.

    Returns:
        A single ``tf.TypeSpec`` when ``columns`` is a string, otherwise a
        dict mapping each requested column name to its ``tf.TypeSpec``.
    """
    import pyarrow as pa

    from ray.data.extensions import TensorDtype

    arrow_tensor_types = get_arrow_extension_tensor_types()

    # Guard against being handed a schema *class* instead of an instance.
    assert not isinstance(schema, type)

    column_dtypes: Dict[str, Union[np.dtype, pa.DataType]] = dict(
        zip(schema.names, schema.types)
    )

    def _to_tf_dtype(dtype: Union[np.dtype, pa.DataType]) -> tf.dtypes.DType:
        # Unwrap list types to their element type, then normalize Arrow
        # and tensor-extension dtypes down to something tf understands.
        if isinstance(dtype, pa.ListType):
            dtype = dtype.value_type
        if isinstance(dtype, pa.DataType):
            dtype = dtype.to_pandas_dtype()
        if isinstance(dtype, TensorDtype):
            dtype = dtype.element_dtype
        return tf.dtypes.as_dtype(dtype)

    def _to_tf_shape(dtype: Union[np.dtype, pa.DataType]) -> Tuple[int, ...]:
        # Leading None is the batch dimension.
        shape: Tuple[int, ...] = (None,)
        if isinstance(dtype, arrow_tensor_types):
            dtype = dtype.to_pandas_dtype()
        if isinstance(dtype, pa.ListType):
            shape += (None,)
        elif isinstance(dtype, TensorDtype):
            shape += dtype.element_shape
        return shape

    def _make_spec(
        dtype: Union[np.dtype, pa.DataType], *, name: str
    ) -> tf.TypeSpec:
        shape = _to_tf_shape(dtype)
        tf_dtype = _to_tf_dtype(dtype)
        # The batch dimension is always `None`, so any additional
        # `None`-valued dimension means the tensor is ragged.
        if sum(dim is None for dim in shape) > 1:
            return tf.RaggedTensorSpec(shape, dtype=tf_dtype)
        return tf.TensorSpec(shape, dtype=tf_dtype, name=name)

    if isinstance(columns, str):
        return _make_spec(column_dtypes[columns], name=columns)

    return {
        column: _make_spec(dtype, name=column)
        for column, dtype in column_dtypes.items()
        if column in columns
    }
8 changes: 4 additions & 4 deletions python/ray/air/util/data_batch_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,10 +217,10 @@ def _convert_batch_type_to_numpy(
)
return data
elif pyarrow is not None and isinstance(data, pyarrow.Table):
from ray.air.util.tensor_extensions.arrow import (
from ray.data._internal.arrow_ops import transform_pyarrow
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this still importing from data? Should this util also be moved to data?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I think you already created a duplicate in python/ray/data, should we just delete this file

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, the reason this is here is that this utility is still depended on by ray.train.predictors, so it is a duplicate. How about we just address this in a later PR when we remove predictors?

from ray.data._internal.tensor_extensions.arrow import (
get_arrow_extension_fixed_shape_tensor_types,
)
from ray.data._internal.arrow_ops import transform_pyarrow

column_values_ndarrays = []

Expand Down Expand Up @@ -292,7 +292,7 @@ def _cast_ndarray_columns_to_tensor_extension(df: "pd.DataFrame") -> "pd.DataFra
# SettingWithCopyWarning was moved to pd.errors in Pandas 1.5.0.
SettingWithCopyWarning = pd.errors.SettingWithCopyWarning

from ray.air.util.tensor_extensions.pandas import (
from ray.data._internal.tensor_extensions.pandas import (
TensorArray,
column_needs_tensor_extension,
)
Expand Down Expand Up @@ -334,7 +334,7 @@ def _cast_tensor_columns_to_ndarrays(df: "pd.DataFrame") -> "pd.DataFrame":
except AttributeError:
# SettingWithCopyWarning was moved to pd.errors in Pandas 1.5.0.
SettingWithCopyWarning = pd.errors.SettingWithCopyWarning
from ray.air.util.tensor_extensions.pandas import TensorDtype
from ray.data._internal.tensor_extensions.pandas import TensorDtype

# Try to convert any tensor extension columns to ndarray columns.
# TODO(Clark): Optimize this with propagated DataFrame metadata containing a list of
Expand Down
2 changes: 1 addition & 1 deletion python/ray/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@
import pyarrow as pa

# Import these arrow extension types to ensure that they are registered.
from ray.air.util.tensor_extensions.arrow import ( # noqa
from ray.data._internal.tensor_extensions.arrow import ( # noqa
ArrowTensorType,
ArrowVariableShapedTensorType,
)
Expand Down
12 changes: 6 additions & 6 deletions python/ray/data/_internal/arrow_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,14 @@

from ray._private.arrow_utils import get_pyarrow_version
from ray._private.ray_constants import env_integer
from ray.air.constants import TENSOR_COLUMN_NAME
from ray.air.util.tensor_extensions.arrow import (
convert_to_pyarrow_array,
pyarrow_table_from_pydict,
)
from ray.data._internal.arrow_ops import transform_polars, transform_pyarrow
from ray.data._internal.arrow_ops.transform_pyarrow import shuffle
from ray.data._internal.row import row_repr, row_repr_pretty, row_str
from ray.data._internal.table_block import TableBlockAccessor, TableBlockBuilder
from ray.data._internal.tensor_extensions.arrow import (
convert_to_pyarrow_array,
pyarrow_table_from_pydict,
)
from ray.data.block import (
Block,
BlockAccessor,
Expand All @@ -38,6 +37,7 @@
BlockType,
U,
)
from ray.data.constants import TENSOR_COLUMN_NAME
from ray.data.context import DEFAULT_TARGET_MAX_BLOCK_SIZE, DataContext
from ray.data.expressions import Expr

Expand Down Expand Up @@ -272,7 +272,7 @@ def schema(self) -> "pyarrow.lib.Schema":
return self._table.schema

def to_pandas(self) -> "pandas.DataFrame":
from ray.air.util.data_batch_conversion import _cast_tensor_columns_to_ndarrays
from ray.data.util.data_batch_conversion import _cast_tensor_columns_to_ndarrays

# We specify ignore_metadata=True because pyarrow will use the metadata
# to build the Table. This is handled incorrectly for older pyarrow versions
Expand Down
18 changes: 9 additions & 9 deletions python/ray/data/_internal/arrow_ops/transform_pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ray._private.arrow_utils import get_pyarrow_version
from ray._private.ray_constants import env_integer
from ray._private.utils import INT32_MAX
from ray.air.util.tensor_extensions.arrow import (
from ray.data._internal.tensor_extensions.arrow import (
MIN_PYARROW_VERSION_CHUNKED_ARRAY_TO_NUMPY_ZERO_COPY_ONLY,
PYARROW_VERSION,
get_arrow_extension_fixed_shape_tensor_types,
Expand Down Expand Up @@ -143,7 +143,7 @@ def take_table(
extension arrays. This is exposed as a static method for easier use on
intermediate tables, not underlying an ArrowBlockAccessor.
"""
from ray.air.util.transform_pyarrow import (
from ray.data._internal.utils.transform_pyarrow import (
_concatenate_extension_column,
_is_pa_extension_type,
)
Expand Down Expand Up @@ -176,7 +176,7 @@ def _reconcile_diverging_fields(
Returns:
A dictionary of diverging fields with their reconciled types.
"""
from ray.air.util.object_extensions.arrow import ArrowPythonObjectType
from ray.data._internal.object_extensions.arrow import ArrowPythonObjectType

reconciled_fields = {}
field_types = defaultdict(list) # field_name -> list of types seen so far
Expand Down Expand Up @@ -232,8 +232,8 @@ def _reconcile_field(

Returns reconciled type or None if default PyArrow handling is sufficient.
"""
from ray.air.util.object_extensions.arrow import ArrowPythonObjectType
from ray.air.util.tensor_extensions.arrow import (
from ray.data._internal.object_extensions.arrow import ArrowPythonObjectType
from ray.data._internal.tensor_extensions.arrow import (
get_arrow_extension_tensor_types,
)

Expand Down Expand Up @@ -431,7 +431,7 @@ def _backfill_missing_fields(
"""
import pyarrow as pa

from ray.air.util.tensor_extensions.arrow import (
from ray.data._internal.tensor_extensions.arrow import (
ArrowVariableShapedTensorType,
)

Expand Down Expand Up @@ -690,7 +690,7 @@ def concat(
"""
import pyarrow as pa

from ray.air.util.tensor_extensions.arrow import ArrowConversionError
from ray.data._internal.tensor_extensions.arrow import ArrowConversionError
from ray.data.extensions import (
ArrowPythonObjectType,
get_arrow_extension_tensor_types,
Expand Down Expand Up @@ -910,7 +910,7 @@ def combine_chunked_array(

import pyarrow as pa

from ray.air.util.transform_pyarrow import (
from ray.data._internal.utils.transform_pyarrow import (
_concatenate_extension_column,
_is_pa_extension_type,
)
Expand Down Expand Up @@ -993,7 +993,7 @@ def _try_combine_chunks_safe(

import pyarrow as pa

from ray.air.util.transform_pyarrow import _is_pa_extension_type
from ray.data._internal.utils.transform_pyarrow import _is_pa_extension_type

assert not _is_pa_extension_type(
array.type
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import sys
from typing import TYPE_CHECKING, Iterable, List, Optional, Union

from ray.air.util.tensor_extensions.arrow import pyarrow_table_from_pydict
from ray.data._internal.tensor_extensions.arrow import pyarrow_table_from_pydict
from ray.data._internal.util import _check_pyarrow_version
from ray.data.block import Block, BlockAccessor, BlockMetadata
from ray.data.dataset import Dataset
Expand Down
2 changes: 1 addition & 1 deletion python/ray/data/_internal/datasource/json_datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

import pandas as pd

from ray.air.util.tensor_extensions.arrow import pyarrow_table_from_pydict
from ray.data._internal.pandas_block import PandasBlockAccessor
from ray.data._internal.tensor_extensions.arrow import pyarrow_table_from_pydict
from ray.data.context import DataContext
from ray.data.datasource.file_based_datasource import FileBasedDatasource

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import pyarrow

from ray.air.util.tensor_extensions.arrow import pyarrow_table_from_pydict
from ray.data._internal.tensor_extensions.arrow import pyarrow_table_from_pydict
from ray.data.aggregate import AggregateFn
from ray.data.block import Block
from ray.data.datasource.file_based_datasource import FileBasedDatasource
Expand Down
2 changes: 1 addition & 1 deletion python/ray/data/_internal/execution/operators/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type

from ray._private.arrow_utils import get_pyarrow_version
from ray.air.util.transform_pyarrow import _is_pa_extension_type
from ray.data._internal.arrow_block import ArrowBlockAccessor, ArrowBlockBuilder
from ray.data._internal.arrow_ops.transform_pyarrow import (
MIN_PYARROW_VERSION_RUN_END_ENCODED_TYPES,
Expand All @@ -17,6 +16,7 @@
)
from ray.data._internal.logical.operators.join_operator import JoinType
from ray.data._internal.util import GiB, MiB
from ray.data._internal.utils.transform_pyarrow import _is_pa_extension_type
from ray.data.block import Block
from ray.data.context import DataContext

Expand Down
2 changes: 1 addition & 1 deletion python/ray/data/_internal/numpy_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import numpy as np

from ray.air.util.tensor_extensions.utils import (
from ray.data._internal.tensor_extensions.utils import (
create_ragged_ndarray,
is_ndarray_like,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pyarrow as pa
from packaging.version import parse as parse_version

import ray.air.util.object_extensions.pandas
import ray.data._internal.object_extensions.pandas
from ray._common.serialization import pickle_dumps
from ray._private.arrow_utils import _check_pyarrow_version, get_pyarrow_version
from ray.util.annotations import PublicAPI
Expand Down Expand Up @@ -67,7 +67,7 @@ def to_pandas_dtype(self):
to the Arrow type. See https://pandas.pydata.org/docs/development/extending.html
for more information.
"""
return ray.air.util.object_extensions.pandas.PythonObjectDtype()
return ray.data._internal.object_extensions.pandas.PythonObjectDtype()

def __reduce__(self):
# Earlier PyArrow versions require custom pickling behavior.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pandas._libs import lib
from pandas._typing import ArrayLike, Dtype, PositionalIndexer, TakeIndexer, npt

import ray.air.util.object_extensions.arrow
import ray.data._internal.object_extensions.arrow
from ray.util.annotations import PublicAPI


Expand Down Expand Up @@ -76,7 +76,7 @@ def nbytes(self) -> int:
return self.values.nbytes

def __arrow_array__(self, type=None):
return ray.air.util.object_extensions.arrow.ArrowPythonObjectArray.from_objects(
return ray.data._internal.object_extensions.arrow.ArrowPythonObjectArray.from_objects(
self.values
)

Expand Down
Loading