diff --git a/doc/source/cluster/kubernetes/k8s-ecosystem.md b/doc/source/cluster/kubernetes/k8s-ecosystem.md index cc3b8fe2f390..8d8b2fb759fc 100644 --- a/doc/source/cluster/kubernetes/k8s-ecosystem.md +++ b/doc/source/cluster/kubernetes/k8s-ecosystem.md @@ -13,6 +13,7 @@ k8s-ecosystem/volcano k8s-ecosystem/yunikorn k8s-ecosystem/kueue k8s-ecosystem/istio +k8s-ecosystem/scheduler-plugins ``` * {ref}`kuberay-ingress` @@ -23,3 +24,4 @@ k8s-ecosystem/istio * {ref}`kuberay-yunikorn` * {ref}`kuberay-kueue` * {ref}`kuberay-istio` +* {ref}`kuberay-scheduler-plugins` diff --git a/doc/source/cluster/kubernetes/k8s-ecosystem/scheduler-plugins.md b/doc/source/cluster/kubernetes/k8s-ecosystem/scheduler-plugins.md new file mode 100644 index 000000000000..ed59aa8c52de --- /dev/null +++ b/doc/source/cluster/kubernetes/k8s-ecosystem/scheduler-plugins.md @@ -0,0 +1,60 @@ +(kuberay-scheduler-plugins)= +# KubeRay integration with scheduler plugins + +The [kubernetes-sigs/scheduler-plugins](https://github.com/kubernetes-sigs/scheduler-plugins) repository provides out-of-tree scheduler plugins based on the scheduler framework. + +Starting with KubeRay v1.4.0, KubeRay integrates with the [PodGroup API](https://github.com/kubernetes-sigs/scheduler-plugins/blob/93126eabdf526010bf697d5963d849eab7e8e898/site/content/en/docs/plugins/coscheduling.md) provided by scheduler plugins to support gang scheduling for RayCluster custom resources. + +## Step 1: Create a Kubernetes cluster with Kind + +```sh +kind create cluster --image=kindest/node:v1.26.0 +``` + +## Step 2: Install scheduler plugins + +Follow the [installation guide](https://scheduler-plugins.sigs.k8s.io/docs/user-guide/installation/) in the scheduler-plugins repository to install the scheduler plugins. + +:::{note} + +There are two modes for installing the scheduler plugins: *single scheduler mode* and *second scheduler mode*. + +KubeRay v1.4.0 only supports the *single scheduler mode*. +You need to have the access to configure Kubernetes control plane to replace the default scheduler with the scheduler plugins. + +::: + +## Step 3: Install KubeRay operator with scheduler plugins enabled + +KubeRay v1.4.0 and later versions support scheduler plugins. + +```sh +helm install kuberay-operator kuberay/kuberay-operator --version 1.4.0 --set batchScheduler.name=scheduler-plugins +``` + +## Step 4: Deploy a RayCluster with gang scheduling + +```sh +# Configure the RayCluster with label `ray.io/gang-scheduling-enabled: "true"` +# to enable gang scheduling. +kubectl apply -f https://raw.githubusercontent.com/ray-project/kuberay/release-1.4/ray-operator/config/samples/ray-cluster.scheduler-plugins.yaml +``` + +## Step 5: Verify Ray Pods and PodGroup + +Note that if you use "second scheduler mode," which KubeRay currently doesn't support, the following commands still show similar results. +However, the Ray Pods don't get scheduled in a gang scheduling manner. +Make sure to use "single scheduler mode" to enable gang scheduling. + +```sh +kubectl get podgroups.scheduling.x-k8s.io +# NAME PHASE MINMEMBER RUNNING SUCCEEDED FAILED AGE +# test-podgroup-0 Running 3 3 2m25s + +# All Ray Pods (1 head and 2 workers) belong to the same PodGroup. +kubectl get pods -L scheduling.x-k8s.io/pod-group +# NAME READY STATUS RESTARTS AGE POD-GROUP +# test-podgroup-0-head 1/1 Running 0 3m30s test-podgroup-0 +# test-podgroup-0-worker-worker-4vc6j 1/1 Running 0 3m30s test-podgroup-0 +# test-podgroup-0-worker-worker-ntm9f 1/1 Running 0 3m30s test-podgroup-0 +``` \ No newline at end of file diff --git a/doc/source/serve/production-guide/fault-tolerance.md b/doc/source/serve/production-guide/fault-tolerance.md index c4d19f856a9a..e5afbb295406 100644 --- a/doc/source/serve/production-guide/fault-tolerance.md +++ b/doc/source/serve/production-guide/fault-tolerance.md @@ -37,6 +37,16 @@ You can also use the deployment options to customize how frequently Serve runs t :language: python ``` +In this example, `check_health` raises an error if the connection to an external database is lost. The Serve controller periodically calls this method on each replica of the deployment. If the method raises an exception for a replica, Serve marks that replica as unhealthy and restarts it. Health checks are configured and performed on a per-replica basis. + +:::{note} +You shouldn't call ``check_health`` directly through a deployment handle (e.g., ``await deployment_handle.check_health.remote()``). This would invoke the health check on a single, arbitrary replica. The ``check_health`` method is designed as an interface for the Serve controller, not for direct user calls. +::: + +:::{note} +In a composable deployment graph, each deployment is responsible for its own health, independent of the other deployments it's bound to. For example, in an application defined by ``app = ParentDeployment.bind(ChildDeployment.bind())``, ``ParentDeployment`` doesn't restart if ``ChildDeployment`` replicas fail their health checks. When the ``ChildDeployment`` replicas recover, the handle in ``ParentDeployment`` updates automatically to route requests to the healthy replicas. +::: + ### Worker node recovery :::{admonition} KubeRay Required diff --git a/doc/source/train/examples/pytorch/distributing-pytorch/README.ipynb b/doc/source/train/examples/pytorch/distributing-pytorch/README.ipynb index 0d7fb52fbd67..4c914515c9a4 100644 --- a/doc/source/train/examples/pytorch/distributing-pytorch/README.ipynb +++ b/doc/source/train/examples/pytorch/distributing-pytorch/README.ipynb @@ -37,6 +37,7 @@ }, "outputs": [], "source": [ + "%%bash\n", "pip install torch torchvision" ] }, diff --git a/doc/source/train/examples/pytorch/distributing-pytorch/README.md b/doc/source/train/examples/pytorch/distributing-pytorch/README.md index 089c9ca8e5c8..a29e85c32485 100644 --- a/doc/source/train/examples/pytorch/distributing-pytorch/README.md +++ b/doc/source/train/examples/pytorch/distributing-pytorch/README.md @@ -16,7 +16,8 @@ In this step you train a PyTorch VisionTransformer model to recognize objects us First, install and import the required Python modules. -```python +```bash +%%bash pip install torch torchvision ``` diff --git a/python/ray/_private/telemetry/open_telemetry_metric_recorder.py b/python/ray/_private/telemetry/open_telemetry_metric_recorder.py index 04442af9a4e5..78c531650034 100644 --- a/python/ray/_private/telemetry/open_telemetry_metric_recorder.py +++ b/python/ray/_private/telemetry/open_telemetry_metric_recorder.py @@ -57,14 +57,46 @@ def callback(options): callbacks=[callback], ) self._registered_instruments[name] = instrument + self._observations_by_name[name] = {} + + def register_counter_metric(self, name: str, description: str) -> None: + """ + Register a counter metric with the given name and description. + """ + with self._lock: + if name in self._registered_instruments: + # Counter with the same name is already registered. + return + + instrument = self.meter.create_counter( + name=f"{NAMESPACE}_{name}", + description=description, + unit="1", + ) + self._registered_instruments[name] = instrument def set_metric_value(self, name: str, tags: dict, value: float): """ - Set the value of a metric with the given name and tags. - This will create a gauge if it does not exist. + Set the value of a metric with the given name and tags. If the metric is not + registered, it lazily records the value for observable metrics or is a no-op for + synchronous metrics. """ with self._lock: - self._observations_by_name[name][frozenset(tags.items())] = value + if self._observations_by_name.get(name) is not None: + # Set the value of an observable metric with the given name and tags. It + # lazily records the metric value by storing it in a dictionary until + # the value actually gets exported by OpenTelemetry. + self._observations_by_name[name][frozenset(tags.items())] = value + else: + # Set the value of a synchronous metric with the given name and tags. + # It is a no-op if the metric is not registered. + instrument = self._registered_instruments.get(name) + if isinstance(instrument, metrics.Counter): + instrument.add(value, attributes=tags) + else: + logger.warning( + f"Unsupported synchronous instrument type for metric: {name}." + ) def record_and_export(self, records: List[Record], global_tags=None): """ @@ -84,7 +116,7 @@ def record_and_export(self, records: List[Record], global_tags=None): f"Failed to record metric {gauge.name} with value {value} with tags {tags!r} and global tags {global_tags!r} due to: {e!r}" ) - def _get_metric_value(self, name: str, tags: dict) -> Optional[float]: + def _get_observable_metric_value(self, name: str, tags: dict) -> Optional[float]: """ Get the value of a metric with the given name and tags. This method is mainly used for testing purposes. diff --git a/python/ray/dashboard/modules/reporter/reporter_agent.py b/python/ray/dashboard/modules/reporter/reporter_agent.py index 652cff520995..65a098ec4913 100644 --- a/python/ray/dashboard/modules/reporter/reporter_agent.py +++ b/python/ray/dashboard/modules/reporter/reporter_agent.py @@ -570,10 +570,20 @@ async def Export( for resource_metrics in request.resource_metrics: for scope_metrics in resource_metrics.scope_metrics: for metric in scope_metrics.metrics: - self._open_telemetry_metric_recorder.register_gauge_metric( - metric.name, metric.description or "" - ) - for data_point in metric.gauge.data_points: + data_points = [] + # gauge metrics + if metric.WhichOneof("data") == "gauge": + self._open_telemetry_metric_recorder.register_gauge_metric( + metric.name, metric.description or "" + ) + data_points = metric.gauge.data_points + # counter metrics + if metric.WhichOneof("data") == "sum" and metric.sum.is_monotonic: + self._open_telemetry_metric_recorder.register_counter_metric( + metric.name, metric.description or "" + ) + data_points = metric.sum.data_points + for data_point in data_points: self._open_telemetry_metric_recorder.set_metric_value( metric.name, { diff --git a/python/ray/data/_internal/execution/operators/hash_shuffle.py b/python/ray/data/_internal/execution/operators/hash_shuffle.py index 9f1fdf376a02..fbc293d14872 100644 --- a/python/ray/data/_internal/execution/operators/hash_shuffle.py +++ b/python/ray/data/_internal/execution/operators/hash_shuffle.py @@ -4,6 +4,7 @@ import logging import math import threading +import time from collections import defaultdict, deque from dataclasses import dataclass from typing import ( @@ -404,6 +405,7 @@ def __init__( partition_size_hint=partition_size_hint, ) ), + data_context=data_context, ) self._input_block_transformer = input_block_transformer @@ -452,6 +454,7 @@ def start(self, options: ExecutionOptions) -> None: self._aggregator_pool.start() def _add_input_inner(self, input_bundle: RefBundle, input_index: int) -> None: + # TODO move to base class self._metrics.on_input_queued(input_bundle) try: @@ -961,11 +964,13 @@ def __init__( num_aggregators: int, aggregation_factory: StatefulShuffleAggregationFactory, aggregator_ray_remote_args: Dict[str, Any], + data_context: DataContext, ): assert ( num_partitions >= 1 ), f"Number of partitions has to be >= 1 (got {num_partitions})" + self._data_context = data_context self._num_partitions = num_partitions self._num_aggregators: int = num_aggregators self._aggregator_partition_map: Dict[ @@ -987,7 +992,24 @@ def __init__( self._aggregator_partition_map, ) + # Resource monitoring state + self._started_at: Optional[float] = None + + # Add last warning timestamp for health checks + self._last_health_warning_time: Optional[float] = None + self._health_warning_interval_s: float = ( + self._data_context.hash_shuffle_aggregator_health_warning_interval_s + ) + # Track readiness refs for non-blocking health checks + self._pending_aggregators_refs: Optional[List[ObjectRef]] = None + def start(self): + # Record start time for monitoring + self._started_at = time.time() + + # Check cluster resources before starting aggregators + self._check_cluster_resources() + for aggregator_id in range(self._num_aggregators): target_partition_ids = self._aggregator_partition_map[aggregator_id] @@ -999,6 +1021,147 @@ def start(self): self._aggregators.append(aggregator) + def _check_cluster_resources(self) -> None: + """Check if cluster has enough resources to schedule all aggregators. + Raises: + ValueError: If cluster doesn't have sufficient resources. + """ + try: + cluster_resources = ray.cluster_resources() + available_resources = ray.available_resources() + except Exception as e: + logger.warning(f"Failed to get cluster resources: {e}") + return + + # Calculate required resources for all aggregators + required_cpus = ( + self._aggregator_ray_remote_args.get("num_cpus", 1) * self._num_aggregators + ) + required_memory = ( + self._aggregator_ray_remote_args.get("memory", 0) * self._num_aggregators + ) + + # Check CPU resources + total_cpus = cluster_resources.get("CPU", 0) + available_cpus = available_resources.get("CPU", 0) + + if required_cpus > total_cpus: + logger.warning( + f"Insufficient CPU resources in cluster for hash shuffle operation. " + f"Required: {required_cpus} CPUs for {self._num_aggregators} aggregators, " + f"but cluster only has {total_cpus} total CPUs. " + f"Consider either increasing the cluster size or reducing the number of aggregators via `DataContext.max_hash_shuffle_aggregators`." + ) + + if required_cpus > available_cpus: + logger.warning( + f"Limited available CPU resources for hash shuffle operation. " + f"Required: {required_cpus} CPUs, available: {available_cpus} CPUs. " + f"Aggregators may take longer to start due to contention for resources." + ) + + # Check memory resources if specified + if required_memory > 0: + total_memory = cluster_resources.get("memory", 0) + available_memory = available_resources.get("memory", 0) + + if required_memory > total_memory: + logger.warning( + f"Insufficient memory resources in cluster for hash shuffle operation. " + f"Required: {required_memory / GiB:.2f} GiB for {self._num_aggregators} aggregators, " + f"but cluster only has {total_memory / GiB:.2f} GiB total memory. " + f"Consider reducing the number of partitions or increasing cluster size." + ) + + if required_memory > available_memory: + logger.warning( + f"Limited available memory resources for hash shuffle operation. " + f"Required: {required_memory / GiB:.2f} GiB, available: {available_memory / GiB:.2f} GiB. " + f"Aggregators may take longer to start due to resource contention." + ) + + logger.debug( + f"Resource check passed for hash shuffle operation: " + f"required CPUs={required_cpus}, available CPUs={available_cpus}, " + f"required memory={required_memory / GiB:.2f} GiB, available memory={available_memory / GiB:.2f} GiB" + ) + + def _check_aggregator_health(self) -> None: + """Check if all aggregators are up and running after a timeout period. + Uses non-blocking ray.wait to check actor readiness. + Will warn every 10 seconds (configurable via `DataContext.hash_shuffle_aggregator_health_warning_interval_s`) if aggregators remain unhealthy. + """ + min_wait_time = self._data_context.min_hash_shuffle_aggregator_wait_time_in_s + if self._started_at is None or time.time() - self._started_at < min_wait_time: + return + + try: + # Initialize readiness refs the first time. + if self._pending_aggregators_refs is None: + self._pending_aggregators_refs = [ + aggregator.__ray_ready__.remote() + for aggregator in self._aggregators + ] + + if len(self._pending_aggregators_refs) == 0: + self._last_health_warning_time = None + logger.debug( + f"All {self._num_aggregators} hash shuffle aggregators " + f"are now healthy" + ) + return + + # Use ray.wait to check readiness in non-blocking fashion + _, unready_refs = ray.wait( + self._pending_aggregators_refs, + num_returns=len(self._pending_aggregators_refs), + timeout=0, # Short timeout to avoid blocking + ) + + # Update readiness refs to only track the unready ones + self._pending_aggregators_refs = unready_refs + + current_time = time.time() + should_warn = unready_refs and ( # If any refs are not ready + self._last_health_warning_time is None + or current_time - self._last_health_warning_time + >= self._health_warning_interval_s + ) + + if should_warn: + # Get cluster resource information for better diagnostics + available_resources = ray.available_resources() + available_cpus = available_resources.get("CPU", 0) + cluster_resources = ray.cluster_resources() + total_memory = cluster_resources.get("memory", 0) + available_memory = available_resources.get("memory", 0) + + required_cpus = ( + self._aggregator_ray_remote_args.get("num_cpus", 1) + * self._num_aggregators + ) + + ready_aggregators = self._num_aggregators - len(unready_refs) + + logger.warning( + f"Only {ready_aggregators} out of {self._num_aggregators} hash-shuffle aggregators are ready after {min_wait_time:.1f} secs. " + f"This might indicate resource contention for cluster resources (available CPUs: {available_cpus}, required CPUs: {required_cpus}). " + f"Cluster only has {available_memory / GiB:.2f} GiB available memory, {total_memory / GiB:.2f} GiB total memory. " + f"Consider increasing cluster size or reducing the number of aggregators via `DataContext.max_hash_shuffle_aggregators`. " + f"Will continue checking every {self._health_warning_interval_s}s." + ) + self._last_health_warning_time = current_time + elif not unready_refs and self._last_health_warning_time is not None: + # All aggregators are ready + self._last_health_warning_time = None + logger.debug( + f"All {self._num_aggregators} hash shuffle aggregators " + f"are now healthy" + ) + + except Exception as e: + logger.warning(f"Failed to check aggregator health: {e}") + @property def num_partitions(self): return self._num_partitions @@ -1008,6 +1171,7 @@ def num_aggregators(self): return self._num_aggregators def get_aggregator_for_partition(self, partition_id: int) -> ActorHandle: + self._check_aggregator_health() return self._aggregators[self._get_aggregator_id_for_partition(partition_id)] def _allocate_partitions(self, *, num_partitions: int): diff --git a/python/ray/data/_internal/logical/optimizers.py b/python/ray/data/_internal/logical/optimizers.py index 103c3bf9ea40..071e1edde07d 100644 --- a/python/ray/data/_internal/logical/optimizers.py +++ b/python/ray/data/_internal/logical/optimizers.py @@ -75,9 +75,9 @@ def get_execution_plan(logical_plan: LogicalPlan) -> PhysicalPlan: (2) planning: convert logical to physical operators. (3) physical optimization: optimize physical operators. """ - from ray.data._internal.planner.planner import Planner + from ray.data._internal.planner import create_planner optimized_logical_plan = LogicalOptimizer().optimize(logical_plan) logical_plan._dag = optimized_logical_plan.dag - physical_plan = Planner().plan(optimized_logical_plan) + physical_plan = create_planner().plan(optimized_logical_plan) return PhysicalOptimizer().optimize(physical_plan) diff --git a/python/ray/data/_internal/planner/__init__.py b/python/ray/data/_internal/planner/__init__.py index e69de29bb2d1..f25267b9cf82 100644 --- a/python/ray/data/_internal/planner/__init__.py +++ b/python/ray/data/_internal/planner/__init__.py @@ -0,0 +1,14 @@ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ray.data._internal.planner.planner import Planner + + +def create_planner() -> "Planner": + # Import here to avoid circular import. + from ray.data._internal.planner.planner import Planner + + return Planner() + + +__all__ = ["create_planner"] diff --git a/python/ray/data/_internal/planner/planner.py b/python/ray/data/_internal/planner/planner.py index 89425e066f5c..de55604a83e7 100644 --- a/python/ray/data/_internal/planner/planner.py +++ b/python/ray/data/_internal/planner/planner.py @@ -18,7 +18,7 @@ ] # A list of registered plan functions for logical operators. -PLAN_LOGICAL_OP_FNS: List[Tuple[Type[LogicalOperator], PlanLogicalOpFn]] = [] +_PLAN_LOGICAL_OP_FNS: Dict[Type[LogicalOperator], PlanLogicalOpFn] = {} @DeveloperAPI @@ -27,7 +27,12 @@ def register_plan_logical_op_fn( plan_fn: PlanLogicalOpFn, ): """Register a plan function for a logical operator type.""" - PLAN_LOGICAL_OP_FNS.append((logical_op_type, plan_fn)) + _PLAN_LOGICAL_OP_FNS[logical_op_type] = plan_fn + + +@DeveloperAPI +def get_plan_logical_op_fns(): + return _PLAN_LOGICAL_OP_FNS.copy() def _register_default_plan_logical_op_fns(): @@ -163,52 +168,68 @@ class Planner: done by physical optimizer. """ - def __init__(self): - self._physical_op_to_logical_op: Dict[PhysicalOperator, LogicalOperator] = {} - def plan(self, logical_plan: LogicalPlan) -> PhysicalPlan: """Convert logical to physical operators recursively in post-order.""" - physical_dag = self._plan(logical_plan.dag, logical_plan.context) - physical_plan = PhysicalPlan( - physical_dag, - self._physical_op_to_logical_op, - logical_plan.context, + plan_fns = get_plan_logical_op_fns() + physical_dag, op_map = plan_recursively( + logical_plan.dag, plan_fns, logical_plan.context ) + physical_plan = PhysicalPlan(physical_dag, op_map, logical_plan.context) return physical_plan - def _plan( - self, logical_op: LogicalOperator, data_context: DataContext - ) -> PhysicalOperator: - # Plan the input dependencies first. - physical_children = [] - for child in logical_op.input_dependencies: - physical_children.append(self._plan(child, data_context)) - - physical_op = None - for op_type, plan_fn in PLAN_LOGICAL_OP_FNS: - if isinstance(logical_op, op_type): - # We will call `set_logical_operators()` in the following for-loop, - # no need to do it here. - physical_op = plan_fn(logical_op, physical_children, data_context) - break - - if physical_op is None: - raise ValueError( - f"Found unknown logical operator during planning: {logical_op}" - ) - - # Traverse up the DAG, and set the mapping from physical to logical operators. - # At this point, all physical operators without logical operators set - # must have been created by the current logical operator. - queue = [physical_op] - while queue: - curr_physical_op = queue.pop() - # Once we find an operator with a logical operator set, we can stop. - if curr_physical_op._logical_operators: - break - - curr_physical_op.set_logical_operators(logical_op) - queue.extend(physical_op.input_dependencies) - - self._physical_op_to_logical_op[physical_op] = logical_op - return physical_op + +@DeveloperAPI +def plan_recursively( + logical_op: LogicalOperator, + plan_fns: Dict[Type[LogicalOperator], PlanLogicalOpFn], + data_context: DataContext, +) -> Tuple[PhysicalOperator, Dict[PhysicalOperator, LogicalOperator]]: + """Plan a logical operator and its input dependencies recursively. + + Args: + logical_op: The logical operator to plan. + plan_fns: A dictionary of planning functions for different logical operator + types. + data_context: The data context. + + Returns: + A tuple of the physical operator corresponding to the logical operator, and + a mapping from physical to logical operators. + """ + op_map: Dict[PhysicalOperator, LogicalOperator] = {} + + # Plan the input dependencies first. + physical_children = [] + for child in logical_op.input_dependencies: + physical_child, child_op_map = plan_recursively(child, plan_fns, data_context) + physical_children.append(physical_child) + op_map.update(child_op_map) + + physical_op = None + for op_type, plan_fn in plan_fns.items(): + if isinstance(logical_op, op_type): + # We will call `set_logical_operators()` in the following for-loop, + # no need to do it here. + physical_op = plan_fn(logical_op, physical_children, data_context) + break + + if physical_op is None: + raise ValueError( + f"Found unknown logical operator during planning: {logical_op}" + ) + + # Traverse up the DAG, and set the mapping from physical to logical operators. + # At this point, all physical operators without logical operators set + # must have been created by the current logical operator. + queue = [physical_op] + while queue: + curr_physical_op = queue.pop() + # Once we find an operator with a logical operator set, we can stop. + if curr_physical_op._logical_operators: + break + + curr_physical_op.set_logical_operators(logical_op) + queue.extend(physical_op.input_dependencies) + + op_map[physical_op] = logical_op + return physical_op, op_map diff --git a/python/ray/data/context.py b/python/ray/data/context.py index 03c0406e91cc..a2979db84faf 100644 --- a/python/ray/data/context.py +++ b/python/ray/data/context.py @@ -202,6 +202,14 @@ class ShuffleStrategy(str, enum.Enum): int(os.environ.get("RAY_DATA_PER_NODE_METRICS", "0")) ) +DEFAULT_MIN_HASH_SHUFFLE_AGGREGATOR_WAIT_TIME_IN_S = env_integer( + "RAY_DATA_MIN_HASH_SHUFFLE_AGGREGATOR_WAIT_TIME_IN_S", 300 +) + +DEFAULT_HASH_SHUFFLE_AGGREGATOR_HEALTH_WARNING_INTERVAL_S = env_integer( + "RAY_DATA_HASH_SHUFFLE_AGGREGATOR_HEALTH_WARNING_INTERVAL_S", 30 +) + def _execution_options_factory() -> "ExecutionOptions": # Lazily import to avoid circular dependencies. @@ -369,6 +377,15 @@ class DataContext: # # When unset defaults to `DataContext.min_parallelism` max_hash_shuffle_aggregators: Optional[int] = DEFAULT_MAX_HASH_SHUFFLE_AGGREGATORS + + min_hash_shuffle_aggregator_wait_time_in_s: int = ( + DEFAULT_MIN_HASH_SHUFFLE_AGGREGATOR_WAIT_TIME_IN_S + ) + + hash_shuffle_aggregator_health_warning_interval_s: int = ( + DEFAULT_HASH_SHUFFLE_AGGREGATOR_HEALTH_WARNING_INTERVAL_S + ) + # Max number of *concurrent* hash-shuffle finalization tasks running # at the same time. This config is helpful to control concurrency of # finalization tasks to prevent single aggregator running multiple tasks diff --git a/python/ray/data/tests/test_execution_optimizer.py b/python/ray/data/tests/test_execution_optimizer.py index 3c685c8e9414..c224889dea20 100644 --- a/python/ray/data/tests/test_execution_optimizer.py +++ b/python/ray/data/tests/test_execution_optimizer.py @@ -47,8 +47,8 @@ from ray.data._internal.logical.rules.configure_map_task_memory import ( ConfigureMapTaskMemoryUsingOutputSize, ) +from ray.data._internal.planner import create_planner from ray.data._internal.planner.exchange.sort_task_spec import SortKey -from ray.data._internal.planner.planner import Planner from ray.data._internal.stats import DatasetStats from ray.data.aggregate import Count from ray.data.block import BlockMetadata @@ -77,7 +77,7 @@ def _check_valid_plan_and_result( def test_read_operator(ray_start_regular_shared_2_cpus): ctx = DataContext.get_current() - planner = Planner() + planner = create_planner() op = get_parquet_read_logical_op() plan = LogicalPlan(op, ctx) physical_op = planner.plan(plan).dag @@ -113,7 +113,7 @@ def read_fn(): def test_split_blocks_operator(ray_start_regular_shared_2_cpus): ctx = DataContext.get_current() - planner = Planner() + planner = create_planner() op = get_parquet_read_logical_op(parallelism=10) logical_plan = LogicalPlan(op, ctx) physical_plan = planner.plan(logical_plan) @@ -156,7 +156,7 @@ def test_from_operators(ray_start_regular_shared_2_cpus): FromPandas, ] for op_cls in op_classes: - planner = Planner() + planner = create_planner() op = op_cls([], []) plan = LogicalPlan(op, ctx) physical_op = planner.plan(plan).dag @@ -227,7 +227,7 @@ def method(self, x): def test_map_batches_operator(ray_start_regular_shared_2_cpus): ctx = DataContext.get_current() - planner = Planner() + planner = create_planner() read_op = get_parquet_read_logical_op() op = MapBatches( read_op, @@ -255,7 +255,7 @@ def test_map_batches_e2e(ray_start_regular_shared_2_cpus): def test_map_rows_operator(ray_start_regular_shared_2_cpus): ctx = DataContext.get_current() - planner = Planner() + planner = create_planner() read_op = get_parquet_read_logical_op() op = MapRows( read_op, @@ -282,7 +282,7 @@ def test_map_rows_e2e(ray_start_regular_shared_2_cpus): def test_filter_operator(ray_start_regular_shared_2_cpus): ctx = DataContext.get_current() - planner = Planner() + planner = create_planner() read_op = get_parquet_read_logical_op() op = Filter( read_op, @@ -324,7 +324,7 @@ def test_project_operator_select(ray_start_regular_shared_2_cpus): assert isinstance(op, Project), op.name assert op.cols == cols - physical_plan = Planner().plan(logical_plan) + physical_plan = create_planner().plan(logical_plan) physical_plan = PhysicalOptimizer().optimize(physical_plan) physical_op = physical_plan.dag assert isinstance(physical_op, TaskPoolMapOperator) @@ -348,7 +348,7 @@ def test_project_operator_rename(ray_start_regular_shared_2_cpus): assert not op.cols assert op.cols_rename == cols_rename - physical_plan = Planner().plan(logical_plan) + physical_plan = create_planner().plan(logical_plan) physical_plan = PhysicalOptimizer().optimize(physical_plan) physical_op = physical_plan.dag assert isinstance(physical_op, TaskPoolMapOperator) @@ -358,7 +358,7 @@ def test_project_operator_rename(ray_start_regular_shared_2_cpus): def test_flat_map(ray_start_regular_shared_2_cpus): ctx = DataContext.get_current() - planner = Planner() + planner = create_planner() read_op = get_parquet_read_logical_op() op = FlatMap( read_op, @@ -423,7 +423,7 @@ def ensure_sample_size_close(dataset, sample_percent=0.5): def test_random_shuffle_operator(ray_start_regular_shared_2_cpus): ctx = DataContext.get_current() - planner = Planner() + planner = create_planner() read_op = get_parquet_read_logical_op() op = RandomShuffle( read_op, @@ -462,7 +462,7 @@ def test_random_shuffle_e2e(ray_start_regular_shared_2_cpus, configure_shuffle_m def test_repartition_operator(ray_start_regular_shared_2_cpus, shuffle): ctx = DataContext.get_current() - planner = Planner() + planner = create_planner() read_op = get_parquet_read_logical_op() op = Repartition(read_op, num_outputs=5, shuffle=shuffle) plan = LogicalPlan(op, ctx) @@ -543,7 +543,7 @@ def test_write_operator(ray_start_regular_shared_2_cpus, tmp_path): ctx = DataContext.get_current() concurrency = 2 - planner = Planner() + planner = create_planner() datasink = ParquetDatasink(tmp_path) read_op = get_parquet_read_logical_op() op = Write( @@ -569,7 +569,7 @@ def test_sort_operator( ): ctx = DataContext.get_current() - planner = Planner() + planner = create_planner() read_op = get_parquet_read_logical_op() op = Sort( read_op, @@ -708,7 +708,7 @@ def test_batch_format_on_aggregate(ray_start_regular_shared_2_cpus): def test_aggregate_operator(ray_start_regular_shared_2_cpus): ctx = DataContext.get_current() - planner = Planner() + planner = create_planner() read_op = get_parquet_read_logical_op() op = Aggregate( read_op, @@ -778,7 +778,7 @@ def test_aggregate_validate_keys(ray_start_regular_shared_2_cpus): def test_zip_operator(ray_start_regular_shared_2_cpus): ctx = DataContext.get_current() - planner = Planner() + planner = create_planner() read_op1 = get_parquet_read_logical_op() read_op2 = get_parquet_read_logical_op() op = Zip(read_op1, read_op2) diff --git a/python/ray/data/tests/test_operator_fusion.py b/python/ray/data/tests/test_operator_fusion.py index d6dd69b092a6..389ea3990fea 100644 --- a/python/ray/data/tests/test_operator_fusion.py +++ b/python/ray/data/tests/test_operator_fusion.py @@ -25,7 +25,7 @@ from ray.data._internal.logical.operators.read_operator import Read from ray.data._internal.logical.optimizers import PhysicalOptimizer, get_execution_plan from ray.data._internal.plan import ExecutionPlan -from ray.data._internal.planner.planner import Planner +from ray.data._internal.planner import create_planner from ray.data._internal.stats import DatasetStats from ray.data.context import DataContext from ray.data.tests.conftest import * # noqa @@ -38,7 +38,7 @@ def test_read_map_batches_operator_fusion(ray_start_regular_shared_2_cpus): ctx = DataContext.get_current() # Test that Read is fused with MapBatches. - planner = Planner() + planner = create_planner() read_op = get_parquet_read_logical_op(parallelism=1) op = MapBatches( read_op, @@ -67,7 +67,7 @@ def test_read_map_chain_operator_fusion(ray_start_regular_shared_2_cpus): ctx = DataContext.get_current() # Test that a chain of different map operators are fused. - planner = Planner() + planner = create_planner() read_op = get_parquet_read_logical_op(parallelism=1) map1 = MapRows(read_op, lambda x: x) map2 = MapBatches(map1, lambda x: x) @@ -117,7 +117,7 @@ def test_read_map_batches_operator_fusion_compatible_remote_args( ({"scheduling_strategy": "SPREAD"}, {}), ] for up_remote_args, down_remote_args in compatiple_remote_args_pairs: - planner = Planner() + planner = create_planner() read_op = get_parquet_read_logical_op( ray_remote_args={"resources": {"non-existent": 1}}, parallelism=1, @@ -164,7 +164,7 @@ def test_read_map_batches_operator_fusion_incompatible_remote_args( ({"scheduling_strategy": "SPREAD"}, {"scheduling_strategy": "PACK"}), ] for up_remote_args, down_remote_args in incompatible_remote_args_pairs: - planner = Planner() + planner = create_planner() read_op = get_parquet_read_logical_op( ray_remote_args={"resources": {"non-existent": 1}} ) @@ -198,7 +198,7 @@ def test_read_map_batches_operator_fusion_compute_tasks_to_actors( # Test that a task-based map operator is fused into an actor-based map operator when # the former comes before the latter. - planner = Planner() + planner = create_planner() read_op = get_parquet_read_logical_op(parallelism=1) op = MapBatches(read_op, lambda x: x) op = MapBatches(op, lambda x: x, compute=ray.data.ActorPoolStrategy()) @@ -220,7 +220,7 @@ def test_read_map_batches_operator_fusion_compute_read_to_actors( ctx = DataContext.get_current() # Test that reads fuse into an actor-based map operator. - planner = Planner() + planner = create_planner() read_op = get_parquet_read_logical_op(parallelism=1) op = MapBatches(read_op, lambda x: x, compute=ray.data.ActorPoolStrategy()) logical_plan = LogicalPlan(op, ctx) @@ -241,7 +241,7 @@ def test_read_map_batches_operator_fusion_incompatible_compute( ctx = DataContext.get_current() # Test that map operators are not fused when compute strategies are incompatible. - planner = Planner() + planner = create_planner() read_op = get_parquet_read_logical_op(parallelism=1) op = MapBatches(read_op, lambda x: x, compute=ray.data.ActorPoolStrategy()) op = MapBatches(op, lambda x: x) @@ -715,7 +715,7 @@ def test_zero_copy_fusion_eliminate_build_output_blocks( ctx = DataContext.get_current() # Test the EliminateBuildOutputBlocks optimization rule. - planner = Planner() + planner = create_planner() read_op = get_parquet_read_logical_op() op = MapBatches(read_op, lambda x: x) logical_plan = LogicalPlan(op, ctx) diff --git a/python/ray/data/tests/test_randomize_block_order.py b/python/ray/data/tests/test_randomize_block_order.py index c03d47a125a1..ef099af92503 100644 --- a/python/ray/data/tests/test_randomize_block_order.py +++ b/python/ray/data/tests/test_randomize_block_order.py @@ -14,7 +14,7 @@ from ray.data._internal.logical.operators.read_operator import Read from ray.data._internal.logical.optimizers import LogicalOptimizer from ray.data._internal.logical.rules.randomize_blocks import ReorderRandomizeBlocksRule -from ray.data._internal.planner.planner import Planner +from ray.data._internal.planner import create_planner from ray.data.context import DataContext from ray.data.tests.test_util import get_parquet_read_logical_op from ray.data.tests.util import extract_values @@ -23,7 +23,7 @@ def test_randomize_blocks_operator(ray_start_regular_shared): ctx = DataContext.get_current() - planner = Planner() + planner = create_planner() read_op = get_parquet_read_logical_op() op = RandomizeBlocks( read_op, diff --git a/python/ray/serve/_private/http_util.py b/python/ray/serve/_private/http_util.py index 65825fcbd16c..c09d6e122796 100644 --- a/python/ray/serve/_private/http_util.py +++ b/python/ray/serve/_private/http_util.py @@ -305,7 +305,16 @@ async def fetch_until_disconnect(self): pickled_messages = await self._receive_asgi_messages( self._request_metadata ) - for message in pickle.loads(pickled_messages): + if isinstance(pickled_messages, bytes): + messages = pickle.loads(pickled_messages) + else: + messages = ( + pickled_messages + if isinstance(pickled_messages, list) + else [pickled_messages] + ) + + for message in messages: self._queue.put_nowait(message) if message["type"] in {"http.disconnect", "websocket.disconnect"}: diff --git a/python/ray/serve/_private/replica.py b/python/ray/serve/_private/replica.py index b4b519475bce..dc2cdedb6fea 100644 --- a/python/ray/serve/_private/replica.py +++ b/python/ray/serve/_private/replica.py @@ -88,6 +88,7 @@ ) from ray.serve._private.version import DeploymentVersion from ray.serve.config import AutoscalingConfig +from ray.serve.context import _get_in_flight_requests from ray.serve.deployment import Deployment from ray.serve.exceptions import ( BackPressureError, @@ -891,17 +892,22 @@ async def _on_initialized(self): self._initialization_latency = time.time() - self._initialization_start_time def _on_request_cancelled( - self, request_metadata: RequestMetadata, e: asyncio.CancelledError + self, metadata: RequestMetadata, e: asyncio.CancelledError ): """Recursively cancels child requests.""" requests_pending_assignment = ( ray.serve.context._get_requests_pending_assignment( - request_metadata.internal_request_id + metadata.internal_request_id ) ) for task in requests_pending_assignment.values(): task.cancel() + # Cancel child requests that have already been assigned. + in_flight_requests = _get_in_flight_requests(metadata.internal_request_id) + for replica_result in in_flight_requests.values(): + replica_result.cancel() + def _on_request_failed(self, request_metadata: RequestMetadata, e: Exception): if ray.util.pdb._is_ray_debugger_post_mortem_enabled(): ray.util.pdb._post_mortem() diff --git a/python/ray/serve/_private/replica_result.py b/python/ray/serve/_private/replica_result.py index b3429f6fa5bc..780b4a44bd45 100644 --- a/python/ray/serve/_private/replica_result.py +++ b/python/ray/serve/_private/replica_result.py @@ -8,7 +8,7 @@ import ray from ray.serve._private.common import RequestMetadata -from ray.serve._private.utils import calculate_remaining_timeout +from ray.serve._private.utils import calculate_remaining_timeout, generate_request_id from ray.serve.exceptions import RequestCancelledError @@ -75,6 +75,19 @@ def __init__( self._obj_ref_gen is not None ), "An ObjectRefGenerator must be passed for streaming requests." + request_context = ray.serve.context._get_serve_request_context() + if request_context.cancel_on_parent_request_cancel: + # Keep track of in-flight requests. + self._response_id = generate_request_id() + ray.serve.context._add_in_flight_request( + request_context._internal_request_id, self._response_id, self + ) + self.add_done_callback( + lambda _: ray.serve.context._remove_in_flight_request( + request_context._internal_request_id, self._response_id + ) + ) + @property def _object_ref_or_gen_asyncio_lock(self) -> asyncio.Lock: """Lazy `asyncio.Lock` object.""" @@ -96,7 +109,7 @@ async def async_wrapper(self, *args, **kwargs): try: return await f(self, *args, **kwargs) except ray.exceptions.TaskCancelledError: - raise RequestCancelledError(self._request_id) + raise asyncio.CancelledError() if inspect.iscoroutinefunction(f): return async_wrapper diff --git a/python/ray/serve/_private/test_utils.py b/python/ray/serve/_private/test_utils.py index d717ab1dfdf8..9b42db4e2661 100644 --- a/python/ray/serve/_private/test_utils.py +++ b/python/ray/serve/_private/test_utils.py @@ -739,11 +739,13 @@ def get_application_urls( for target_group in target_groups: for target in target_group.targets: if protocol == RequestProtocol.HTTP: - urls.append(f"http://{target.ip}:{target.port}{route_prefix}") + url = f"http://{target.ip}:{target.port}{route_prefix}" elif protocol == RequestProtocol.GRPC: - urls.append(f"{target.ip}:{target.port}") + url = f"{target.ip}:{target.port}" else: raise ValueError(f"Unsupported protocol: {protocol}") + url = url.rstrip("/") + urls.append(url) return urls diff --git a/python/ray/serve/context.py b/python/ray/serve/context.py index 736e80e7bb8d..ecb412aab37b 100644 --- a/python/ray/serve/context.py +++ b/python/ray/serve/context.py @@ -20,6 +20,7 @@ SERVE_LOGGER_NAME, SERVE_NAMESPACE, ) +from ray.serve._private.replica_result import ReplicaResult from ray.serve.exceptions import RayServeException from ray.serve.grpc_util import RayServegRPCContext from ray.util.annotations import DeveloperAPI @@ -179,6 +180,7 @@ class _RequestContext: multiplexed_model_id: str = "" grpc_context: Optional[RayServegRPCContext] = None is_http_request: bool = False + cancel_on_parent_request_cancel: bool = False _serve_request_context = contextvars.ContextVar( @@ -260,3 +262,37 @@ def _remove_request_pending_assignment(parent_request_id: str, response_id: str) if len(_requests_pending_assignment[parent_request_id]) == 0: del _requests_pending_assignment[parent_request_id] + + +# `_in_flight_requests` is a map from request ID to a dictionary of replica results. +# The request ID points to an ongoing Serve request, and the replica results are +# in-flight child requests that have been assigned to a downstream replica. + +# A dictionary is used over a set to track the replica results for more +# efficient addition and deletion time complexity. A uniquely generated +# `response_id` is used to identify each replica result. + +_in_flight_requests: Dict[str, Dict[str, ReplicaResult]] = defaultdict(dict) + +# Note that the functions below that manipulate `_in_flight_requests` +# are NOT thread-safe. They are only expected to be called from the +# same thread/asyncio event-loop. + + +def _get_in_flight_requests(parent_request_id): + if parent_request_id in _in_flight_requests: + return _in_flight_requests[parent_request_id] + + return {} + + +def _add_in_flight_request(parent_request_id, response_id, replica_result): + _in_flight_requests[parent_request_id][response_id] = replica_result + + +def _remove_in_flight_request(parent_request_id, response_id): + if response_id in _in_flight_requests[parent_request_id]: + del _in_flight_requests[parent_request_id][response_id] + + if len(_in_flight_requests[parent_request_id]) == 0: + del _in_flight_requests[parent_request_id] diff --git a/python/ray/serve/handle.py b/python/ray/serve/handle.py index 6f140c3e93fb..a3c686553f99 100644 --- a/python/ray/serve/handle.py +++ b/python/ray/serve/handle.py @@ -279,15 +279,9 @@ async def _fetch_future_result_async(self) -> ReplicaResult: if self._replica_result is None: # Use `asyncio.wrap_future` so `self._replica_result_future` can be awaited # safely from any asyncio loop. - try: - self._replica_result = await asyncio.wrap_future( - self._replica_result_future - ) - except asyncio.CancelledError: - if self._cancelled: - raise RequestCancelledError(self.request_id) from None - else: - raise asyncio.CancelledError from None + self._replica_result = await asyncio.wrap_future( + self._replica_result_future + ) return self._replica_result @@ -408,9 +402,15 @@ async def __call__(self, start: int) -> int: def __await__(self): """Yields the final result of the deployment handle call.""" - replica_result = yield from self._fetch_future_result_async().__await__() - result = yield from replica_result.get_async().__await__() - return result + try: + replica_result = yield from self._fetch_future_result_async().__await__() + result = yield from replica_result.get_async().__await__() + return result + except asyncio.CancelledError: + if self._cancelled: + raise RequestCancelledError(self.request_id) from None + else: + raise asyncio.CancelledError from None def __reduce__(self): raise RayServeException( @@ -571,8 +571,14 @@ def __aiter__(self) -> AsyncIterator[Any]: return self async def __anext__(self) -> Any: - replica_result = await self._fetch_future_result_async() - return await replica_result.__anext__() + try: + replica_result = await self._fetch_future_result_async() + return await replica_result.__anext__() + except asyncio.CancelledError: + if self._cancelled: + raise RequestCancelledError(self.request_id) from None + else: + raise asyncio.CancelledError from None def __iter__(self) -> Iterator[Any]: return self diff --git a/python/ray/serve/tests/test_api.py b/python/ray/serve/tests/test_api.py index 2415894d9617..1d12ad801140 100644 --- a/python/ray/serve/tests/test_api.py +++ b/python/ray/serve/tests/test_api.py @@ -1119,9 +1119,9 @@ def f(): serve.run(f.bind()) controller_details = ray.get(serve_instance._controller.get_actor_details.remote()) node_ip = controller_details.node_ip - assert get_application_urls() == [f"http://{node_ip}:8000/"] + assert get_application_urls() == [f"http://{node_ip}:8000"] assert get_application_urls("gRPC") == [f"{node_ip}:9000"] - assert get_application_urls(RequestProtocol.HTTP) == [f"http://{node_ip}:8000/"] + assert get_application_urls(RequestProtocol.HTTP) == [f"http://{node_ip}:8000"] assert get_application_urls(RequestProtocol.GRPC) == [f"{node_ip}:9000"] @@ -1133,7 +1133,7 @@ def f(): serve.run(f.bind(), name="app1", route_prefix="/") controller_details = ray.get(serve_instance._controller.get_actor_details.remote()) node_ip = controller_details.node_ip - assert get_application_urls("HTTP", app_name="app1") == [f"http://{node_ip}:8000/"] + assert get_application_urls("HTTP", app_name="app1") == [f"http://{node_ip}:8000"] assert get_application_urls("gRPC", app_name="app1") == [f"{node_ip}:9000"] diff --git a/python/ray/serve/tests/test_fastapi.py b/python/ray/serve/tests/test_fastapi.py index d784247214da..f8b749124ca2 100644 --- a/python/ray/serve/tests/test_fastapi.py +++ b/python/ray/serve/tests/test_fastapi.py @@ -30,6 +30,7 @@ from ray.serve._private.client import ServeControllerClient from ray.serve._private.constants import SERVE_DEFAULT_APP_NAME from ray.serve._private.http_util import make_fastapi_class_based_view +from ray.serve._private.test_utils import get_application_url from ray.serve.exceptions import RayServeException from ray.serve.handle import DeploymentHandle @@ -48,10 +49,12 @@ class FastAPIApp: serve.run(FastAPIApp.bind()) - resp = httpx.get("http://localhost:8000/100") + url = get_application_url("HTTP") + + resp = httpx.get(f"{url}/100") assert resp.json() == {"result": 100} - resp = httpx.get("http://localhost:8000/not-number") + resp = httpx.get(f"{url}/not-number") assert resp.status_code == 422 # Unprocessable Entity # Pydantic 1.X returns `type_error.integer`, 2.X returns `int_parsing`. assert resp.json()["detail"][0]["type"] in {"type_error.integer", "int_parsing"} @@ -71,7 +74,8 @@ class App: serve.run(App.bind(), route_prefix="/api") - resp = httpx.get("http://localhost:8000/api/100") + url = get_application_url("HTTP") + resp = httpx.get(f"{url}/100") assert resp.json() == {"result": 100} @@ -102,11 +106,12 @@ def other(self, msg: str): serve.run(A.bind()) # Test HTTP calls. - resp = httpx.get("http://localhost:8000/calc/41") + url = get_application_url("HTTP") + resp = httpx.get(f"{url}/calc/41") assert resp.json() == 42 - resp = httpx.post("http://localhost:8000/calc/41") + resp = httpx.post(f"{url}/calc/41") assert resp.json() == 40 - resp = httpx.get("http://localhost:8000/other") + resp = httpx.get(f"{url}/other") assert resp.json() == "hello" # Test handle calls. @@ -257,7 +262,7 @@ class Worker: serve.run(Worker.bind()) - url = "http://localhost:8000" + url = get_application_url("HTTP") resp = httpx.get(f"{url}/") assert resp.status_code == 404 assert "x-process-time" in resp.headers @@ -351,7 +356,8 @@ class A: serve.run(A.bind(), route_prefix="/api") - assert httpx.get("http://localhost:8000/api/mounted/hi").json() == "world" + url = get_application_url("HTTP") + assert httpx.get(f"{url}/mounted/hi").json() == "world" def test_fastapi_init_lifespan_should_not_shutdown(serve_instance): @@ -413,15 +419,17 @@ def ignored(): serve.run(App1.bind(), name="app1", route_prefix="/api/v1") serve.run(App2.bind(), name="app2", route_prefix="/api/v2") + app1_url = get_application_url("HTTP", app_name="app1") + app2_url = get_application_url("HTTP", app_name="app2") - resp = httpx.get("http://localhost:8000/api/v1", follow_redirects=True) + resp = httpx.get(app1_url, follow_redirects=True) assert resp.json() == "first" - resp = httpx.get("http://localhost:8000/api/v2", follow_redirects=True) + resp = httpx.get(app2_url, follow_redirects=True) assert resp.json() == "second" - for version in ["v1", "v2"]: - resp = httpx.get(f"http://localhost:8000/api/{version}/ignored") + for version in [app1_url, app2_url]: + resp = httpx.get(f"{version}/ignored") assert resp.status_code == 404 @@ -438,7 +446,8 @@ class MyApp: serve.run(MyApp.bind()) - resp = httpx.get("http://localhost:8000/") + url = get_application_url("HTTP") + resp = httpx.get(url) assert resp.json() == {"hello": "world"} @@ -458,14 +467,16 @@ def func1(self, arg: str): serve.run(App.bind(), route_prefix=input_route_prefix) - r = httpx.get(f"http://localhost:8000{expected_route_prefix}openapi.json") + url = get_application_url("HTTP") + assert expected_route_prefix.rstrip("/") in url + r = httpx.get(f"{url}/openapi.json") assert r.status_code == 200 assert len(r.json()["paths"]) == 1 assert "/" in r.json()["paths"] assert len(r.json()["paths"]["/"]) == 1 assert "get" in r.json()["paths"]["/"] - r = httpx.get(f"http://localhost:8000{expected_route_prefix}docs") + r = httpx.get(f"{url}/docs") assert r.status_code == 200 @serve.deployment @@ -481,7 +492,9 @@ def func2(self, arg: int): serve.run(App.bind(), route_prefix=input_route_prefix) - r = httpx.get(f"http://localhost:8000{expected_route_prefix}openapi.json") + url = get_application_url("HTTP") + assert expected_route_prefix.rstrip("/") in url + r = httpx.get(f"{url}/openapi.json") assert r.status_code == 200 assert len(r.json()["paths"]) == 2 assert "/" in r.json()["paths"] @@ -491,7 +504,7 @@ def func2(self, arg: int): assert len(r.json()["paths"]["/hello"]) == 1 assert "post" in r.json()["paths"]["/hello"] - r = httpx.get(f"http://localhost:8000{expected_route_prefix}docs") + r = httpx.get(f"{url}/docs") assert r.status_code == 200 @@ -512,7 +525,8 @@ class FastAPIApp: serve.run(FastAPIApp.bind()) - resp = httpx.get("http://localhost:8000/") + url = get_application_url("HTTP") + resp = httpx.get(url) assert dict(resp.cookies) == {"a": "b", "c": "d"} @@ -548,13 +562,14 @@ def test_endpoint_3(self): serve.run(TestDeployment.bind()) - resp = httpx.get("http://localhost:8000/") + url = get_application_url("HTTP") + resp = httpx.get(url) assert resp.json() == {"a": "a", "b": ["b"]} - resp = httpx.get("http://localhost:8000/inner") + resp = httpx.get(f"{url}/inner") assert resp.json() == {"a": "a", "b": ["b"]} - resp = httpx.get("http://localhost:8000/inner2") + resp = httpx.get(f"{url}/inner2") assert resp.json() == [{"a": "a", "b": ["b"]}] @@ -589,7 +604,8 @@ def root(self): return self.test_passed serve.run(TestDeployment.bind()) - resp = httpx.get("http://localhost:8000/") + url = get_application_url("HTTP") + resp = httpx.get(url) assert resp.json() @@ -654,8 +670,9 @@ def method(self): # noqa: F811 method redefinition return "hi post" serve.run(A.bind(), route_prefix="/a") - assert httpx.get("http://localhost:8000/a/").json() == "hi get" - assert httpx.post("http://localhost:8000/a/").json() == "hi post" + url = get_application_url("HTTP") + assert httpx.get(f"{url}/").json() == "hi get" + assert httpx.post(f"{url}/").json() == "hi post" def test_fastapi_same_app_multiple_deployments(serve_instance): @@ -687,23 +704,26 @@ def decr2(self): serve.run(CounterDeployment1.bind(), name="app1", route_prefix="/app1") serve.run(CounterDeployment2.bind(), name="app2", route_prefix="/app2") + app1_url = get_application_url("HTTP", app_name="app1") + app2_url = get_application_url("HTTP", app_name="app2") + should_work = [ - ("/app1/incr", "incr"), - ("/app1/decr", "decr"), - ("/app2/incr2", "incr2"), - ("/app2/decr2", "decr2"), + (app1_url, "/incr", "incr"), + (app1_url, "/decr", "decr"), + (app2_url, "/incr2", "incr2"), + (app2_url, "/decr2", "decr2"), ] - for path, resp in should_work: - assert httpx.get("http://localhost:8000" + path).json() == resp, (path, resp) + for url, path, resp in should_work: + assert httpx.get(f"{url}{path}").json() == resp, (path, resp) should_404 = [ - "/app2/incr", - "/app2/decr", - "/app1/incr2", - "/app1/decr2", + (app1_url, "/incr2", 404), + (app1_url, "/decr2", 404), + (app2_url, "/incr", 404), + (app2_url, "/decr", 404), ] - for path in should_404: - assert httpx.get("http://localhost:8000" + path).status_code == 404, path + for url, path, status_code in should_404: + assert httpx.get(f"{url}{path}").status_code == status_code, (path, status_code) @pytest.mark.parametrize("two_fastapi", [True, False]) @@ -818,11 +838,12 @@ def class_route(self): return "hello class route" serve.run(ASGIIngress.bind()) - assert httpx.get("http://localhost:8000/").json() == "hello" - assert httpx.get("http://localhost:8000/f2").json() == "hello f2" - assert httpx.get("http://localhost:8000/class_route").json() == "hello class route" - assert httpx.get("http://localhost:8000/error").status_code == 500 - assert httpx.get("http://localhost:8000/error").json() == {"error": "fake-error"} + url = get_application_url("HTTP") + assert httpx.get(url).json() == "hello" + assert httpx.get(f"{url}/f2").json() == "hello f2" + assert httpx.get(f"{url}/class_route").json() == "hello class route" + assert httpx.get(f"{url}/error").status_code == 500 + assert httpx.get(f"{url}/error").json() == {"error": "fake-error"} # get the docs path from the controller docs_path = ray.get(serve_instance._controller.get_docs_path.remote("default")) @@ -835,10 +856,11 @@ def test_ingress_with_fastapi_with_no_deployment_class(serve_instance): ingress_deployment = serve.deployment(serve.ingress(app)()) assert ingress_deployment.name == "ASGIIngressDeployment" serve.run(ingress_deployment.bind()) - assert httpx.get("http://localhost:8000/").json() == "hello" - assert httpx.get("http://localhost:8000/f2").json() == "hello f2" - assert httpx.get("http://localhost:8000/error").status_code == 500 - assert httpx.get("http://localhost:8000/error").json() == {"error": "fake-error"} + url = get_application_url("HTTP") + assert httpx.get(url).json() == "hello" + assert httpx.get(f"{url}/f2").json() == "hello f2" + assert httpx.get(f"{url}/error").status_code == 500 + assert httpx.get(f"{url}/error").json() == {"error": "fake-error"} # get the docs path from the controller docs_path = ray.get(serve_instance._controller.get_docs_path.remote("default")) @@ -849,15 +871,16 @@ def test_ingress_with_fastapi_builder_function(serve_instance): ingress_deployment = serve.deployment(serve.ingress(fastapi_builder)()) serve.run(ingress_deployment.bind()) - resp = httpx.get("http://localhost:8000/") + url = get_application_url("HTTP") + resp = httpx.get(url) assert resp.json() == "hello" assert resp.headers["X-Custom-Middleware"] == "fake-middleware" - resp = httpx.get("http://localhost:8000/f2") + resp = httpx.get(f"{url}/f2") assert resp.json() == "hello f2" assert resp.headers["X-Custom-Middleware"] == "fake-middleware" - resp = httpx.get("http://localhost:8000/error") + resp = httpx.get(f"{url}/error") assert resp.status_code == 500 assert resp.json() == {"error": "fake-error"} @@ -874,13 +897,14 @@ def __init__(self): serve.run(ASGIIngress.bind()) - resp = httpx.get("http://localhost:8000/") + url = get_application_url("HTTP") + resp = httpx.get(url) assert resp.json() == "hello" - resp = httpx.get("http://localhost:8000/f2") + resp = httpx.get(f"{url}/f2") assert resp.json() == "hello f2" - resp = httpx.get("http://localhost:8000/error") + resp = httpx.get(f"{url}/error") assert resp.status_code == 500 assert resp.json() == {"error": "fake-error"} @@ -941,7 +965,8 @@ def __init__(self, sub_deployment: DeploymentHandle): serve.run(ASGIIngress.bind(sub_deployment().bind())) - resp = httpx.get("http://localhost:8000/sub_deployment?a=2") + url = get_application_url("HTTP") + resp = httpx.get(f"{url}/sub_deployment?a=2") assert resp.json() == {"a": 3} @@ -952,7 +977,8 @@ def test_deployment_composition_with_builder_function_without_decorator(serve_in # and passes them to the deployment constructor serve.run(app.bind(sub_deployment().bind())) - resp = httpx.get("http://localhost:8000/sub_deployment?a=2") + url = get_application_url("HTTP") + resp = httpx.get(f"{url}/sub_deployment?a=2") assert resp.json() == {"a": 3} @@ -1016,15 +1042,16 @@ def test_ingress_with_starlette_app_with_no_deployment_class(serve_instance): ingress_deployment = serve.deployment(serve.ingress(starlette_builder())()) serve.run(ingress_deployment.bind()) - resp = httpx.get("http://localhost:8000/") + url = get_application_url("HTTP") + resp = httpx.get(url) assert resp.json() == "hello" assert resp.headers["X-Custom-Middleware"] == "fake-middleware" - resp = httpx.get("http://localhost:8000/f2") + resp = httpx.get(f"{url}/f2") assert resp.json() == "hello f2" assert resp.headers["X-Custom-Middleware"] == "fake-middleware" - resp = httpx.get("http://localhost:8000/error") + resp = httpx.get(f"{url}/error") assert resp.status_code == 500 assert resp.json() == {"error": "fake-error"} @@ -1036,15 +1063,16 @@ def test_ingress_with_starlette_builder_with_no_deployment_class(serve_instance) ingress_deployment = serve.deployment(serve.ingress(starlette_builder)()) serve.run(ingress_deployment.bind()) - resp = httpx.get("http://localhost:8000/") + url = get_application_url("HTTP") + resp = httpx.get(url) assert resp.json() == "hello" assert resp.headers["X-Custom-Middleware"] == "fake-middleware" - resp = httpx.get("http://localhost:8000/f2") + resp = httpx.get(f"{url}/f2") assert resp.json() == "hello f2" assert resp.headers["X-Custom-Middleware"] == "fake-middleware" - resp = httpx.get("http://localhost:8000/error") + resp = httpx.get(f"{url}/error") assert resp.status_code == 500 assert resp.json() == {"error": "fake-error"} @@ -1061,15 +1089,16 @@ def __init__(self): serve.run(ASGIIngress.bind()) - resp = httpx.get("http://localhost:8000/") + url = get_application_url("HTTP") + resp = httpx.get(url) assert resp.json() == "hello" assert resp.headers["X-Custom-Middleware"] == "fake-middleware" - resp = httpx.get("http://localhost:8000/f2") + resp = httpx.get(f"{url}/f2") assert resp.json() == "hello f2" assert resp.headers["X-Custom-Middleware"] == "fake-middleware" - resp = httpx.get("http://localhost:8000/error") + resp = httpx.get(f"{url}/error") assert resp.status_code == 500 assert resp.json() == {"error": "fake-error"} diff --git a/python/ray/serve/tests/test_handle_cancellation.py b/python/ray/serve/tests/test_handle_cancellation.py index 4fc4c5684198..6900a2044f13 100644 --- a/python/ray/serve/tests/test_handle_cancellation.py +++ b/python/ray/serve/tests/test_handle_cancellation.py @@ -214,7 +214,7 @@ async def __call__(self, *args): g.cancel() with pytest.raises(RequestCancelledError): - assert await g.__anext__() == "hi" + await g.__anext__() await signal_actor.wait.remote() diff --git a/python/ray/serve/tests/test_http_cancellation.py b/python/ray/serve/tests/test_http_cancellation.py index c4150476d9fb..a255f30be75a 100644 --- a/python/ray/serve/tests/test_http_cancellation.py +++ b/python/ray/serve/tests/test_http_cancellation.py @@ -10,7 +10,10 @@ from ray import serve from ray._common.test_utils import SignalActor, wait_for_condition from ray._private.test_utils import Collector -from ray.serve._private.test_utils import send_signal_on_cancellation +from ray.serve._private.test_utils import ( + get_application_url, + send_signal_on_cancellation, +) from ray.serve.exceptions import RequestCancelledError @@ -58,7 +61,7 @@ async def __call__(self, request: Request): # Intentionally time out on the client, causing it to disconnect. with pytest.raises(httpx.ReadTimeout): - httpx.get("http://localhost:8000", timeout=0.5) + httpx.get(get_application_url("HTTP"), timeout=0.5) # Both the HTTP handler and the inner deployment handle call should be cancelled. ray.get(inner_signal_actor.wait.remote(), timeout=10) @@ -89,7 +92,7 @@ async def __call__(self, *args): # Intentionally time out on the client, causing it to disconnect. with pytest.raises(httpx.ReadTimeout): - httpx.get("http://localhost:8000", timeout=0.5) + httpx.get(get_application_url("HTTP"), timeout=0.5) # Now signal the initial request to finish and check that the request sent via HTTP # never reaches the replica. @@ -125,7 +128,7 @@ async def __call__(self): try: await self.child.remote() except asyncio.CancelledError: - await collector.add.remote("Parent_CancelledError") + await collector.add.remote("Parent_AsyncioCancelledError") raise except RequestCancelledError: await collector.add.remote("Parent_RequestCancelledError") @@ -135,13 +138,13 @@ async def __call__(self): # Make a request with short timeout that will cause disconnection try: - await httpx.AsyncClient(timeout=0.5).get("http://localhost:8000/") + await httpx.AsyncClient(timeout=0.5).get(get_application_url("HTTP")) except httpx.ReadTimeout: pass wait_for_condition( lambda: set(ray.get(collector.get.remote())) - == {"Child_CancelledError", "Parent_CancelledError"} + == {"Child_CancelledError", "Parent_AsyncioCancelledError"} ) @@ -171,7 +174,7 @@ async def __call__(self): try: await self.child.remote() except asyncio.CancelledError: - await collector.add.remote("Parent_CancelledError") + await collector.add.remote("Parent_AsyncioCancelledError") raise except RequestCancelledError: await collector.add.remote("Parent_RequestCancelledError") @@ -185,12 +188,12 @@ async def __call__(self): # Make a second request with short timeout that will cause disconnection try: - await httpx.AsyncClient(timeout=0.5).get("http://localhost:8000/") + await httpx.AsyncClient(timeout=0.5).get(get_application_url("HTTP")) except httpx.ReadTimeout: pass wait_for_condition( - lambda: ray.get(collector.get.remote()) == ["Parent_CancelledError"] + lambda: ray.get(collector.get.remote()) == ["Parent_AsyncioCancelledError"] ) # Clean up first request diff --git a/python/ray/serve/tests/test_streaming_response.py b/python/ray/serve/tests/test_streaming_response.py index 2cdba1a8986e..1b1bc693d3d6 100644 --- a/python/ray/serve/tests/test_streaming_response.py +++ b/python/ray/serve/tests/test_streaming_response.py @@ -1,6 +1,6 @@ import asyncio import os -from typing import AsyncGenerator +from typing import AsyncGenerator, Optional import httpx import pytest @@ -11,13 +11,17 @@ import ray from ray import serve from ray._common.test_utils import SignalActor +from ray.serve._private.test_utils import get_application_url, get_application_urls from ray.serve.handle import DeploymentHandle @ray.remote class StreamingRequester: - async def make_request(self) -> AsyncGenerator[str, None]: - with httpx.stream("GET", "http://localhost:8000") as r: + async def make_request( + self, url: Optional[str] = None + ) -> AsyncGenerator[str, None]: + url = url or get_application_url("HTTP") + with httpx.stream("GET", url) as r: r.raise_for_status() for chunk in r.iter_text(): yield chunk @@ -56,7 +60,8 @@ def __call__(self, request: Request) -> StreamingResponse: serve.run(SimpleGenerator.bind()) - with httpx.stream("GET", "http://localhost:8000") as r: + url = get_application_url("HTTP") + with httpx.stream("GET", url) as r: r.raise_for_status() for i, chunk in enumerate(r.iter_text()): assert chunk == f"hi_{i}" @@ -111,9 +116,15 @@ def __call__(self, request: Request) -> StreamingResponse: ).bind() ) + urls = get_application_urls("HTTP") + requester = StreamingRequester.remote() - gen1 = requester.make_request.options(num_returns="streaming").remote() - gen2 = requester.make_request.options(num_returns="streaming").remote() + if len(urls) == 2: + gen1 = requester.make_request.options(num_returns="streaming").remote(urls[0]) + gen2 = requester.make_request.options(num_returns="streaming").remote(urls[1]) + else: + gen1 = requester.make_request.options(num_returns="streaming").remote() + gen2 = requester.make_request.options(num_returns="streaming").remote() # Check that we get the first responses before the signal is sent # (so the generator is still hanging after the first yield). @@ -186,7 +197,8 @@ def __call__(self, request: Request) -> StreamingResponse: serve.run(SimpleGenerator.bind()) - with httpx.stream("GET", "http://localhost:8000") as r: + url = get_application_url("HTTP") + with httpx.stream("GET", url) as r: assert r.status_code == 301 assert r.headers["hello"] == "world" assert r.headers["content-type"] == "foo/bar" @@ -226,7 +238,8 @@ def __call__(self, request: Request) -> StreamingResponse: serve.run(SimpleGenerator.bind()) - with httpx.stream("GET", "http://localhost:8000") as r: + url = get_application_url("HTTP") + with httpx.stream("GET", url) as r: r.raise_for_status() stream_iter = r.iter_text() assert next(stream_iter) == "first result" @@ -284,7 +297,8 @@ def __call__(self, request: Request) -> StreamingResponse: serve.run(SimpleGenerator.bind(Streamer.bind())) - with httpx.stream("GET", "http://localhost:8000") as r: + url = get_application_url("HTTP") + with httpx.stream("GET", url) as r: r.raise_for_status() for i, chunk in enumerate(r.iter_text()): assert chunk == f"hi_{i}" @@ -309,7 +323,8 @@ async def wait_for_disconnect(): serve.run(SimpleGenerator.bind()) - with httpx.stream("GET", "http://localhost:8000"): + url = get_application_url("HTTP") + with httpx.stream("GET", url): with pytest.raises(TimeoutError): _ = ray.get(signal_actor.wait.remote(), timeout=1) diff --git a/python/ray/tests/test_open_telemetry_metric_recorder.py b/python/ray/tests/test_open_telemetry_metric_recorder.py index b386b7de210a..59fb45da0e5d 100644 --- a/python/ray/tests/test_open_telemetry_metric_recorder.py +++ b/python/ray/tests/test_open_telemetry_metric_recorder.py @@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch import pytest +from opentelemetry.metrics import NoOpCounter from ray._private.telemetry.open_telemetry_metric_recorder import ( OpenTelemetryMetricRecorder, @@ -28,7 +29,7 @@ def test_register_gauge_metric(mock_get_meter, mock_set_meter_provider): value=42.0, ) assert ( - recorder._get_metric_value( + recorder._get_observable_metric_value( name="test_gauge", tags={"label_key": "label_value"}, ) @@ -36,6 +37,31 @@ def test_register_gauge_metric(mock_get_meter, mock_set_meter_provider): ) +@patch("ray._private.telemetry.open_telemetry_metric_recorder.logger.warning") +@patch("opentelemetry.metrics.set_meter_provider") +@patch("opentelemetry.metrics.get_meter") +def test_register_counter_metric( + mock_get_meter, mock_set_meter_provider, mock_logger_warning +): + """ + Test the register_counter_metric method of OpenTelemetryMetricRecorder. + - Test that it registers a counter metric with the correct name and description. + - Test that a value can be set for the counter metric successfully without warnings. + """ + mock_meter = MagicMock() + mock_meter.create_counter.return_value = NoOpCounter(name="test_counter") + mock_get_meter.return_value = mock_meter + recorder = OpenTelemetryMetricRecorder() + recorder.register_counter_metric(name="test_counter", description="Test Counter") + assert "test_counter" in recorder._registered_instruments + recorder.set_metric_value( + name="test_counter", + tags={"label_key": "label_value"}, + value=10.0, + ) + mock_logger_warning.assert_not_called() + + @patch("opentelemetry.metrics.set_meter_provider") @patch("opentelemetry.metrics.get_meter") def test_record_and_export(mock_get_meter, mock_set_meter_provider): diff --git a/release/nightly_tests/dataset/join_benchmark.py b/release/nightly_tests/dataset/join_benchmark.py new file mode 100644 index 000000000000..87b2f797e953 --- /dev/null +++ b/release/nightly_tests/dataset/join_benchmark.py @@ -0,0 +1,72 @@ +import ray +import argparse + +from benchmark import Benchmark + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument( + "--left_dataset", required=True, type=str, help="Path to the left dataset" + ) + parser.add_argument( + "--right_dataset", required=True, type=str, help="Path to the right dataset" + ) + parser.add_argument( + "--num_partitions", + required=True, + type=int, + help="Number of partitions to use for the join", + ) + parser.add_argument( + "--left_join_keys", + required=True, + nargs="+", + type=str, + help="Join keys for the left dataset", + ) + parser.add_argument( + "--right_join_keys", + required=True, + nargs="+", + type=str, + help="Join keys for the right dataset", + ) + parser.add_argument( + "--join_type", + required=True, + choices=["inner", "left_outer", "right_outer", "full_outer"], + help="Type of join operation", + ) + return parser.parse_args() + + +def main(args): + benchmark = Benchmark() + + def benchmark_fn(): + left_ds = ray.data.read_parquet(args.left_dataset) + right_ds = ray.data.read_parquet(args.right_dataset) + # Check if join keys match; if not, rename right join keys + if len(args.left_join_keys) != len(args.right_join_keys): + raise ValueError("Number of left and right join keys must match.") + + # Perform join + joined_ds = left_ds.join( + right_ds, + num_partitions=args.num_partitions, + on=args.left_join_keys, + right_on=args.right_join_keys, + join_type=args.join_type, + ) + + # Process joined_ds if needed + print(f"Join completed with {joined_ds.count()} records.") + + benchmark.run_fn(str(vars(args)), benchmark_fn) + benchmark.write_result() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/release/release_data_tests.yaml b/release/release_data_tests.yaml index fb3aacd96174..5f15d3d3b256 100644 --- a/release/release_data_tests.yaml +++ b/release/release_data_tests.yaml @@ -180,6 +180,49 @@ python groupby_benchmark.py --sf 10 --map-groups --group-by {{columns}} --shuffle-strategy {{shuffle_strategy}} +############### +# Join tests +############### + +# NOTE: +# Joining on Benchmark TPCH parquet datasets +# Left dataset 'LINEITEM' = SF*6M rows +# Right dataset 'ORDERS' = SF*1.5M rows +# Join key = 'l_orderkey', 'o_orderkey' respectively from 'LINEITEM', 'ORDERS' dataset. In the generated dataset, +# * For 'LINEITEM' dataset, 'column_00' corresponds to l_orderkey +# * For 'ORDERS' dataset, 'column_0' corresponds to o_orderkey. +# Join type = inner, left_join, right_join and full_join +# +# Dataset TPCH Scale Factor (SF) for CSV files. Note that parquet files will be low smaller with column compression. +# SF1 = 1GB +# SF10 = 10GB +# SF100 = 100GB +# SF1000 = 1TB +# SF10000 = 10TB +# +# Do adjust timeout below based on SF above. +# + +- name: joins_{{dataset}}_{{join_type}} + + cluster: + cluster_compute: fixed_size_100_cpu_compute.yaml + + matrix: + setup: + dataset: [sf100] + join_type: [inner, left_outer, right_outer, full_outer] + + run: + timeout: 3600 + script: > + python join_benchmark.py + --left_dataset s3://ray-benchmark-data/tpch/parquet/{{dataset}}/lineitem + --right_dataset s3://ray-benchmark-data/tpch/parquet/{{dataset}}/orders + --left_join_keys column00 + --right_join_keys column0 + --join_type {{join_type}} + --num_partitions 50 ####################### # Streaming split tests diff --git a/src/ray/common/task/task_spec.cc b/src/ray/common/task/task_spec.cc index 984c35d994f2..f484493fa7dc 100644 --- a/src/ray/common/task/task_spec.cc +++ b/src/ray/common/task/task_spec.cc @@ -475,6 +475,10 @@ WorkerID TaskSpecification::CallerWorkerId() const { return WorkerID::FromBinary(message_->caller_address().worker_id()); } +std::string TaskSpecification::CallerWorkerIdBinary() const { + return message_->caller_address().worker_id(); +} + NodeID TaskSpecification::CallerNodeId() const { return NodeID::FromBinary(message_->caller_address().raylet_id()); } diff --git a/src/ray/common/task/task_spec.h b/src/ray/common/task/task_spec.h index a2097edb3a7e..4132ecfabb6b 100644 --- a/src/ray/common/task/task_spec.h +++ b/src/ray/common/task/task_spec.h @@ -486,6 +486,8 @@ class TaskSpecification : public MessageWrapper { WorkerID CallerWorkerId() const; + std::string CallerWorkerIdBinary() const; + NodeID CallerNodeId() const; uint64_t SequenceNumber() const; diff --git a/src/ray/core_worker/transport/actor_task_submitter.cc b/src/ray/core_worker/transport/actor_task_submitter.cc index b3322580c238..b6fda9515d76 100644 --- a/src/ray/core_worker/transport/actor_task_submitter.cc +++ b/src/ray/core_worker/transport/actor_task_submitter.cc @@ -926,7 +926,7 @@ Status ActorTaskSubmitter::CancelTask(TaskSpecification task_spec, bool recursiv request.set_intended_task_id(task_spec.TaskIdBinary()); request.set_force_kill(force_kill); request.set_recursive(recursive); - request.set_caller_worker_id(task_spec.CallerWorkerId().Binary()); + request.set_caller_worker_id(task_spec.CallerWorkerIdBinary()); client->CancelTask(request, [this, task_spec = std::move(task_spec), recursive, task_id]( const Status &status, const rpc::CancelTaskReply &reply) { diff --git a/src/ray/core_worker/transport/normal_task_submitter.cc b/src/ray/core_worker/transport/normal_task_submitter.cc index 24d76227b32d..381582c630d9 100644 --- a/src/ray/core_worker/transport/normal_task_submitter.cc +++ b/src/ray/core_worker/transport/normal_task_submitter.cc @@ -760,7 +760,7 @@ Status NormalTaskSubmitter::CancelTask(TaskSpecification task_spec, request.set_intended_task_id(task_spec.TaskIdBinary()); request.set_force_kill(force_kill); request.set_recursive(recursive); - request.set_caller_worker_id(task_spec.CallerWorkerId().Binary()); + request.set_caller_worker_id(task_spec.CallerWorkerIdBinary()); client->CancelTask( request, [this, diff --git a/src/ray/stats/metric.h b/src/ray/stats/metric.h index 6426ca308fdc..630ce4e8800b 100644 --- a/src/ray/stats/metric.h +++ b/src/ray/stats/metric.h @@ -266,9 +266,15 @@ void RegisterView(const std::string &name, .set_description(description) .set_measure(name) .set_aggregation(I::Aggregation(buckets)); - if (T == GAUGE && - ::RayConfig::instance().experimental_enable_open_telemetry_on_core()) { - OpenTelemetryMetricRecorder::GetInstance().RegisterGaugeMetric(name, description); + + if (::RayConfig::instance().experimental_enable_open_telemetry_on_core()) { + if (T == GAUGE) { + OpenTelemetryMetricRecorder::GetInstance().RegisterGaugeMetric(name, description); + } else if (T == COUNT) { + OpenTelemetryMetricRecorder::GetInstance().RegisterCounterMetric(name, description); + } else { + internal::RegisterAsView(view_descriptor, tag_keys); + } } else { internal::RegisterAsView(view_descriptor, tag_keys); } diff --git a/src/ray/stats/tests/metric_with_open_telemetry_test.cc b/src/ray/stats/tests/metric_with_open_telemetry_test.cc index 15732058dae4..1e625e75cc78 100644 --- a/src/ray/stats/tests/metric_with_open_telemetry_test.cc +++ b/src/ray/stats/tests/metric_with_open_telemetry_test.cc @@ -23,8 +23,16 @@ namespace telemetry { using OpenTelemetryMetricRecorder = ray::telemetry::OpenTelemetryMetricRecorder; using StatsConfig = ray::stats::StatsConfig; -DECLARE_stats(metric_test); -DEFINE_stats(metric_test, "A test gauge metric", ("Tag1", "Tag2"), (), ray::stats::GAUGE); +DECLARE_stats(metric_gauge_test); +DEFINE_stats( + metric_gauge_test, "A test gauge metric", ("Tag1", "Tag2"), (), ray::stats::GAUGE); + +DECLARE_stats(metric_counter_test); +DEFINE_stats(metric_counter_test, + "A test counter metric", + ("Tag1", "Tag2"), + (), + ray::stats::COUNT); class MetricTest : public ::testing::Test { public: @@ -41,12 +49,21 @@ class MetricTest : public ::testing::Test { TEST_F(MetricTest, TestGaugeMetric) { ASSERT_TRUE( - OpenTelemetryMetricRecorder::GetInstance().IsMetricRegistered("metric_test")); - STATS_metric_test.Record(42.0, {{"Tag1", "Value1"}, {"Tag2", "Value2"}}); + OpenTelemetryMetricRecorder::GetInstance().IsMetricRegistered("metric_gauge_test")); + STATS_metric_gauge_test.Record(42.0, {{"Tag1", "Value1"}, {"Tag2", "Value2"}}); ASSERT_EQ(OpenTelemetryMetricRecorder::GetInstance().GetObservableMetricValue( - "metric_test", {{"Tag1", "Value1"}, {"Tag2", "Value2"}}), + "metric_gauge_test", {{"Tag1", "Value1"}, {"Tag2", "Value2"}}), 42.0); } +TEST_F(MetricTest, TestCounterMetric) { + ASSERT_TRUE(OpenTelemetryMetricRecorder::GetInstance().IsMetricRegistered( + "metric_counter_test")); + // We only test that recording is not crashing. The actual value is not checked + // because open telemetry does not provide a way to retrieve the value of a counter. + // Checking value is performed via e2e tests instead (e.g., in test_metrics_agent.py). + STATS_metric_counter_test.Record(100.0, {{"Tag1", "Value1"}, {"Tag2", "Value2"}}); +} + } // namespace telemetry } // namespace ray diff --git a/src/ray/telemetry/open_telemetry_metric_recorder.cc b/src/ray/telemetry/open_telemetry_metric_recorder.cc index 67472ff0e47d..29a1ae2ffb7c 100644 --- a/src/ray/telemetry/open_telemetry_metric_recorder.cc +++ b/src/ray/telemetry/open_telemetry_metric_recorder.cc @@ -58,6 +58,12 @@ void OpenTelemetryMetricRecorder::RegisterGrpcExporter( // Create an OTLP exporter opentelemetry::exporter::otlp::OtlpGrpcMetricExporterOptions exporter_options; exporter_options.endpoint = endpoint; + // This line ensures that only the delta values for count and sum are exported during + // each collection interval. This is necessary because the dashboard agent already + // accumulates these metrics—re-accumulating them during export would lead to double + // counting. + exporter_options.aggregation_temporality = + opentelemetry::exporter::otlp::PreferredAggregationTemporality::kDelta; auto exporter = std::make_unique( exporter_options); diff --git a/src/ray/telemetry/open_telemetry_metric_recorder.h b/src/ray/telemetry/open_telemetry_metric_recorder.h index 01a7aadd6f0c..f82199976669 100644 --- a/src/ray/telemetry/open_telemetry_metric_recorder.h +++ b/src/ray/telemetry/open_telemetry_metric_recorder.h @@ -66,10 +66,6 @@ class OpenTelemetryMetricRecorder { absl::flat_hash_map &&tags, double value); - // Get the value of a metric given the tags. - std::optional GetObservableMetricValue( - const std::string &name, const absl::flat_hash_map &tags); - // Helper function to collect gauge metric values. This function is called only once // per interval for each metric. It collects the values from the observations_by_name_ // map and passes them to the observer. @@ -120,9 +116,16 @@ class OpenTelemetryMetricRecorder { absl::flat_hash_map &&tags, double value); + std::optional GetObservableMetricValue( + const std::string &name, const absl::flat_hash_map &tags); + opentelemetry::nostd::shared_ptr GetMeter() { return meter_provider_->GetMeter("ray"); } + + // Declare the test class as a friend to allow access to private members for testing. + friend class MetricTest_TestGaugeMetric_Test; + friend class OpenTelemetryMetricRecorderTest_TestGaugeMetric_Test; }; } // namespace telemetry } // namespace ray