Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions python/ray/data/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -1198,6 +1198,20 @@ py_test(
],
)

# Unit tests for the Iceberg-backed checkpoint backend in Ray Data.
# "exclusive" keeps this target from sharing a machine with other tests.
py_test(
    name = "test_iceberg_checkpoint",
    size = "medium",
    srcs = ["tests/test_iceberg_checkpoint.py"],
    tags = [
        "exclusive",
        "team:data",
    ],
    deps = [
        ":conftest",
        "//:ray_lib",
    ],
)

py_test(
name = "test_kafka",
size = "medium",
Expand Down
11 changes: 5 additions & 6 deletions python/ray/data/_internal/datasource/iceberg_datasink.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,7 @@ def __init__(
f"Removed '{invalid_param}' from overwrite_kwargs: {reason}"
)

if "name" in self._catalog_kwargs:
self._catalog_name = self._catalog_kwargs.pop("name")
else:
self._catalog_name = "default"
self._catalog_name = self._catalog_kwargs.get("name", "default")

self._table: "Table" = None
self._io: "FileIO" = None
Expand Down Expand Up @@ -176,10 +173,12 @@ def _with_retry(self, func: Callable, description: str) -> Any:
)

def _get_catalog(self) -> "Catalog":
from pyiceberg import catalog
from ray.data._internal.datasource.iceberg_datasource import (
_get_iceberg_catalog,
)

return self._with_retry(
lambda: catalog.load_catalog(self._catalog_name, **self._catalog_kwargs),
lambda: _get_iceberg_catalog(self._catalog_kwargs),
description=f"load Iceberg catalog '{self._catalog_name}'",
)

Expand Down
23 changes: 12 additions & 11 deletions python/ray/data/_internal/datasource/iceberg_datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,14 @@ def _generate_tables() -> Iterable[pa.Table]:
)


def _get_iceberg_catalog(catalog_kwargs: Optional[Dict[str, Any]] = None) -> "Catalog":
    """Load a pyiceberg catalog from user-supplied keyword arguments.

    The optional ``"name"`` key selects which catalog to load (falling back
    to ``"default"``); every remaining key is forwarded verbatim to
    ``pyiceberg.catalog.load_catalog``. The caller's dict is copied first,
    so it is never mutated.
    """
    from pyiceberg import catalog

    kwargs = dict(catalog_kwargs) if catalog_kwargs else {}
    name = kwargs.pop("name", "default")
    return catalog.load_catalog(name, **kwargs)


@DeveloperAPI
class IcebergDatasource(Datasource):
"""
Expand Down Expand Up @@ -286,17 +294,12 @@ def __init__(
_check_import(self, module="pyiceberg", package="pyiceberg")
from pyiceberg.expressions import AlwaysTrue

self._scan_kwargs = scan_kwargs if scan_kwargs is not None else {}
self._catalog_kwargs = catalog_kwargs if catalog_kwargs is not None else {}

if "name" in self._catalog_kwargs:
self._catalog_name = self._catalog_kwargs.pop("name")
else:
self._catalog_name = "default"
self._scan_kwargs = (scan_kwargs or {}).copy()
self._catalog_kwargs = (catalog_kwargs or {}).copy()

self.table_identifier = table_identifier

self._row_filter = row_filter if row_filter is not None else AlwaysTrue()

# Convert selected_fields to projection_map (identity mapping if specified)
# Note: Empty tuple () means no columns, None/"*" means all columns
if selected_fields is None or selected_fields == ("*",):
Expand All @@ -311,9 +314,7 @@ def __init__(
self._table = None

def _get_catalog(self) -> "Catalog":
from pyiceberg import catalog

return catalog.load_catalog(self._catalog_name, **self._catalog_kwargs)
return _get_iceberg_catalog(self._catalog_kwargs)

@property
def table(self) -> "Table":
Expand Down
7 changes: 3 additions & 4 deletions python/ray/data/_internal/planner/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,8 @@ def __init__(self):
def plan(self, logical_plan: LogicalPlan) -> PhysicalPlan:
"""Convert logical to physical operators recursively in post-order."""
checkpoint_config = logical_plan.context.checkpoint_config
if checkpoint_config is not None and self._check_supports_checkpointing(
logical_plan
):
supports_ckpt = self._check_supports_checkpointing(logical_plan)
if checkpoint_config is not None and supports_ckpt:
self._supports_checkpointing = True

checkpoint_callback = self._create_checkpoint_callback(checkpoint_config)
Expand All @@ -194,7 +193,7 @@ def plan(self, logical_plan: LogicalPlan) -> PhysicalPlan:
)

elif checkpoint_config is not None:
assert not self._check_supports_checkpointing(logical_plan)
assert not supports_ckpt
warnings.warn(
"You've enabled checkpointing, but the logical plan doesn't support "
"checkpointing. Checkpointing will be disabled."
Expand Down
80 changes: 56 additions & 24 deletions python/ray/data/checkpoint/checkpoint_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
from typing import List, Optional

import numpy
import pandas as pd
import pyarrow

import ray
from ray.data._internal.arrow_ops import transform_pyarrow
from ray.data._internal.execution.interfaces.ref_bundle import RefBundle
from ray.data.block import Block, BlockAccessor, BlockMetadata, DataBatch, Schema
from ray.data.checkpoint import CheckpointConfig
from ray.data.checkpoint.interfaces import CheckpointBackend, CheckpointConfig
from ray.data.datasource import PathPartitionFilter
from ray.data.datasource.path_util import _unwrap_protocol
from ray.types import ObjectRef
Expand Down Expand Up @@ -62,23 +63,20 @@ class CheckpointLoader:

def __init__(
self,
checkpoint_path: str,
filesystem: pyarrow.fs.FileSystem,
id_column: str,
config: CheckpointConfig,
checkpoint_path_partition_filter: Optional[PathPartitionFilter] = None,
):
"""Initialize the CheckpointLoader.

Args:
checkpoint_path: The path to the checkpoint
filesystem: The filesystem to use
id_column: The name of the ID column
config: The checkpoint configuration
checkpoint_path_partition_filter: Filter for checkpoint files to load during
restoration when reading from `checkpoint_path`.
"""
self.checkpoint_path = checkpoint_path
self.filesystem = filesystem
self.id_column = id_column
self.ckpt_config = config
self.checkpoint_path = config.checkpoint_path
self.filesystem = config.filesystem
self.id_column = config.id_column
self.checkpoint_path_partition_filter = checkpoint_path_partition_filter

def load_checkpoint(self) -> ObjectRef[Block]:
Expand All @@ -89,12 +87,28 @@ def load_checkpoint(self) -> ObjectRef[Block]:
"""
start_t = time.time()

# Load the checkpoint data
checkpoint_ds: ray.data.Dataset = ray.data.read_parquet(
self.checkpoint_path,
filesystem=self.filesystem,
partition_filter=self.checkpoint_path_partition_filter,
)
if self.ckpt_config.backend == CheckpointBackend.ICEBERG:
from ray.data._internal.datasource.iceberg_datasource import (
_get_iceberg_catalog,
)

catalog = _get_iceberg_catalog(self.ckpt_config.catalog_kwargs)
if catalog.table_exists(self.checkpoint_path):
checkpoint_ds = ray.data.read_iceberg(
table_identifier=self.checkpoint_path,
selected_fields=(self.id_column,),
catalog_kwargs=self.ckpt_config.catalog_kwargs,
)
else:
arrow_tbl = pyarrow.Table.from_pydict({self.id_column: []})
checkpoint_ds = ray.data.from_arrow(arrow_tbl)
else:
# Load the checkpoint data
checkpoint_ds: ray.data.Dataset = ray.data.read_parquet(
self.checkpoint_path,
filesystem=self.filesystem,
partition_filter=self.checkpoint_path_partition_filter,
)

# Manually disable checkpointing for loading the checkpoint metadata
# to avoid recursively restoring checkpoints.
Expand Down Expand Up @@ -174,22 +188,32 @@ def _preprocess_data_pipeline(
class BatchBasedCheckpointFilter(CheckpointFilter):
"""CheckpointFilter for batch-based backends."""

def __init__(self, config: CheckpointConfig):
    """Initialize the batch-based checkpoint filter.

    Args:
        config: The checkpoint configuration; it is forwarded both to the
            base ``CheckpointFilter`` and to the id-column loader.
    """
    super().__init__(config)

    # Construct the loader once here so every load_checkpoint() call
    # reuses the same instance instead of rebuilding it.
    self._loader = IdColumnCheckpointLoader(
        config=config,
        checkpoint_path_partition_filter=config.checkpoint_path_partition_filter,
    )

def load_checkpoint(self) -> ObjectRef[Block]:
"""Load checkpointed ids as a sorted block.

Returns:
ObjectRef[Block]: ObjectRef to the checkpointed IDs block.
"""
loader = IdColumnCheckpointLoader(
checkpoint_path=self.checkpoint_path,
filesystem=self.filesystem,
id_column=self.id_column,
checkpoint_path_partition_filter=self.ckpt_config.checkpoint_path_partition_filter,
)
return loader.load_checkpoint()
return self._loader.load_checkpoint()

def delete_checkpoint(self) -> None:
self.filesystem.delete_dir(self.checkpoint_path_unwrapped)
if self.ckpt_config.backend == CheckpointBackend.ICEBERG:
from ray.data._internal.datasource.iceberg_datasource import (
_get_iceberg_catalog,
)

catalog = _get_iceberg_catalog(self.ckpt_config.catalog_kwargs)
catalog.drop_table(self.checkpoint_path)
else:
self.filesystem.delete_dir(self.checkpoint_path_unwrapped)

def filter_rows_for_block(
self,
Expand All @@ -210,6 +234,10 @@ def filter_rows_for_block(
if len(checkpointed_ids) == 0 or len(block) == 0:
return block

is_pandas_block = isinstance(block, pd.DataFrame)
if is_pandas_block:
block = pyarrow.Table.from_pandas(block)

assert isinstance(block, pyarrow.Table)
assert isinstance(checkpointed_ids, pyarrow.Table)

Expand Down Expand Up @@ -255,6 +283,10 @@ def filter_with_ckpt_chunk(ckpt_chunk: pyarrow.ChunkedArray) -> numpy.ndarray:
# Convert the final mask to a PyArrow array and filter the block.
mask_array = pyarrow.array(final_mask)
filtered_block = block.filter(mask_array)

if is_pandas_block:
return filtered_block.to_pandas()

return filtered_block

def filter_rows_for_batch(
Expand Down
Loading