Skip to content
Merged
4 changes: 0 additions & 4 deletions ci/lint/pydoclint-baseline.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1484,10 +1484,6 @@ python/ray/llm/_internal/batch/processor/base.py
DOC101: Method `ProcessorBuilder.build`: Docstring contains fewer arguments than in function signature.
DOC103: Method `ProcessorBuilder.build`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [**kwargs: ].
--------------------
python/ray/llm/_internal/batch/processor/vllm_engine_proc.py
DOC101: Function `build_vllm_engine_processor`: Docstring contains fewer arguments than in function signature.
DOC103: Function `build_vllm_engine_processor`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [telemetry_agent: Optional[TelemetryAgent]].
--------------------
python/ray/llm/_internal/batch/stages/base.py
DOC405: Method `StatefulStageUDF.__call__` has both "return" and "yield" statements. Please use Generator[YieldType, SendType, ReturnType] as the return type annotation, and put your yield type in YieldType and return type in ReturnType. More details in https://jsh9.github.io/pydoclint/notes_generator_vs_iterator.html
--------------------
Expand Down
23 changes: 19 additions & 4 deletions python/ray/data/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ class ProcessorConfig(_ProcessorConfig):
accelerator_type: The accelerator type used by the LLM stage in a processor.
Default to None, meaning that only the CPU will be used.
concurrency: The number of workers for data parallelism. Default to 1.
If ``concurrency`` is a ``tuple`` ``(m, n)``, Ray creates an autoscaling
actor pool that scales between ``m`` and ``n`` workers (``1 <= m <= n``).
If ``concurrency`` is an ``int`` ``n``, Ray uses either a fixed pool of ``n``
workers or an autoscaling pool from ``1`` to ``n`` workers, depending on
the processor and stage.
"""

pass
Expand All @@ -41,7 +46,9 @@ class HttpRequestProcessorConfig(_HttpRequestProcessorConfig):
batch_size: The batch size to send to the HTTP request.
url: The URL to send the HTTP request to.
headers: The headers to send with the HTTP request.
concurrency: The number of concurrent requests to send.
concurrency: The number of concurrent requests to send. Default to 1.
If ``concurrency`` is a ``tuple`` ``(m, n)``, an autoscaling
strategy that scales between ``m`` and ``n`` workers is used (``1 <= m <= n``).

Examples:
.. testcode::
Expand Down Expand Up @@ -116,6 +123,10 @@ class vLLMEngineProcessorConfig(_vLLMEngineProcessorConfig):
accelerator_type: The accelerator type used by the LLM stage in a processor.
Default to None, meaning that only the CPU will be used.
concurrency: The number of workers for data parallelism. Default to 1.
If ``concurrency`` is a tuple ``(m, n)``, Ray creates an autoscaling
actor pool that scales between ``m`` and ``n`` workers (``1 <= m <= n``).
If ``concurrency`` is an ``int`` ``n``, CPU stages use an autoscaling
pool that scales from ``1`` to ``n`` workers, while GPU stages use a fixed pool of ``n`` workers.

Examples:

Expand Down Expand Up @@ -177,7 +188,7 @@ class SGLangEngineProcessorConfig(_SGLangEngineProcessorConfig):

Args:
model_source: The model source to use for the SGLang engine.
batch_size: The batch size to send to the vLLM engine. Large batch sizes are
batch_size: The batch size to send to the SGLang engine. Large batch sizes are
likely to saturate the compute resources and could achieve higher throughput.
On the other hand, small batch sizes are more fault-tolerant and could
reduce bubbles in the data pipeline. You can tune the batch size to balance
Expand All @@ -197,12 +208,16 @@ class SGLangEngineProcessorConfig(_SGLangEngineProcessorConfig):
apply_chat_template: Whether to apply chat template.
chat_template: The chat template to use. This is usually not needed if the
model checkpoint already contains the chat template.
tokenize: Whether to tokenize the input before passing it to the vLLM engine.
If not, vLLM will tokenize the prompt in the engine.
tokenize: Whether to tokenize the input before passing it to the SGLang engine.
If not, SGLang will tokenize the prompt in the engine.
detokenize: Whether to detokenize the output.
accelerator_type: The accelerator type used by the LLM stage in a processor.
Default to None, meaning that only the CPU will be used.
concurrency: The number of workers for data parallelism. Default to 1.
If ``concurrency`` is a tuple ``(m, n)``, Ray creates an autoscaling
actor pool that scales between ``m`` and ``n`` workers (``1 <= m <= n``).
If ``concurrency`` is an ``int`` ``n``, CPU stages use an autoscaling
pool that scales from ``1`` to ``n`` workers, while GPU stages use a fixed pool of ``n`` workers.

Examples:
.. testcode::
Expand Down
80 changes: 75 additions & 5 deletions python/ray/llm/_internal/batch/processor/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import logging
from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional, Type
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

from pydantic import Field
from pydantic import Field, field_validator

import ray
from ray.data import Dataset
Expand Down Expand Up @@ -45,9 +45,14 @@ class ProcessorConfig(BaseModelExtended):
description="The accelerator type used by the LLM stage in a processor. "
"Default to None, meaning that only the CPU will be used.",
)
concurrency: Optional[int] = Field(
concurrency: Union[int, Tuple[int, int]] = Field(
default=1,
description="The number of workers for data parallelism. Default to 1.",
description="The number of workers for data parallelism. Default to 1. "
"If ``concurrency`` is a ``tuple`` ``(m, n)``, Ray creates an autoscaling "
"actor pool that scales between ``m`` and ``n`` workers (``1 <= m <= n``). "
"If ``concurrency`` is an ``int`` ``n``, Ray uses either a fixed pool of ``n`` "
"workers or an autoscaling pool from ``1`` to ``n`` workers, depending on "
"the processor and stage.",
)

experimental: Dict[str, Any] = Field(
Expand All @@ -57,6 +62,71 @@ class ProcessorConfig(BaseModelExtended):
"`max_tasks_in_flight_per_actor`: The maximum number of tasks in flight per actor. Default to 4.",
)

@field_validator("concurrency")
def validate_concurrency(
    cls, concurrency: Union[int, Tuple[int, int]]
) -> Union[int, Tuple[int, int]]:
    """Ensure ``concurrency`` is a positive int or a valid ``(min, max)`` pair.

    Accepted forms:
      * a positive ``int``, or
      * a 2-tuple ``(min, max)`` of positive ints with ``min <= max``.

    Raises:
        ValueError: If the value violates any constraint above.
    """
    if isinstance(concurrency, int):
        if concurrency <= 0:
            raise ValueError(
                f"A positive integer for `concurrency` is expected! Got: `{concurrency}`."
            )
    elif isinstance(concurrency, tuple):
        if not all(c > 0 for c in concurrency):
            raise ValueError(
                f"`concurrency` tuple items must be positive integers! Got: `{concurrency}`."
            )
        # Pydantic has already coerced the value to a 2-tuple of ints here.
        lower_bound, upper_bound = concurrency
        if lower_bound > upper_bound:
            raise ValueError(f"min > max in the concurrency tuple `{concurrency}`!")
    return concurrency

def get_concurrency(self, autoscaling_enabled: bool = True) -> Tuple[int, int]:
    """Normalize ``self.concurrency`` into a ``(min, max)`` worker range.

    An integer ``n`` becomes ``(1, n)`` when ``autoscaling_enabled`` is True
    (autoscaling pool), or ``(n, n)`` when False (fixed-size pool). A 2-tuple
    ``(m, n)`` is returned unchanged and ``autoscaling_enabled`` is ignored.

    Args:
        autoscaling_enabled: Whether an integer ``concurrency`` should be
            treated as an autoscaling range ``(1, n)`` rather than a fixed
            pool ``(n, n)``. Defaults to True.

    Returns:
        The allowed worker range ``(min, max)``.
    """
    concurrency = self.concurrency
    if not isinstance(concurrency, int):
        # Already an explicit (min, max) range; pass it through.
        return concurrency
    return (1, concurrency) if autoscaling_enabled else (concurrency, concurrency)

class Config:
validate_assignment = True
arbitrary_types_allowed = True
Expand Down Expand Up @@ -263,7 +333,7 @@ class ProcessorBuilder:

@classmethod
def register(cls, config_type: Type[ProcessorConfig], builder: Callable) -> None:
"""A decorator to assoicate a particular pipeline config
"""A decorator to associate a particular pipeline config
with its build function.
"""
type_name = config_type.__name__
Expand Down
10 changes: 5 additions & 5 deletions python/ray/llm/_internal/batch/processor/sglang_engine_proc.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def build_sglang_engine_processor(
),
map_batches_kwargs=dict(
zero_copy_batch=True,
concurrency=(1, config.concurrency),
concurrency=config.get_concurrency(),
batch_size=config.batch_size,
runtime_env=config.runtime_env,
),
Expand All @@ -100,7 +100,7 @@ def build_sglang_engine_processor(
),
map_batches_kwargs=dict(
zero_copy_batch=True,
concurrency=(1, config.concurrency),
concurrency=config.get_concurrency(),
batch_size=config.batch_size,
runtime_env=config.runtime_env,
),
Expand All @@ -123,8 +123,8 @@ def build_sglang_engine_processor(
# which initiates enough many overlapping UDF calls per actor, to
# saturate `max_concurrency`.
compute=ray.data.ActorPoolStrategy(
min_size=config.concurrency,
max_size=config.concurrency,
min_size=config.get_concurrency(autoscaling_enabled=False)[0],
max_size=config.get_concurrency(autoscaling_enabled=False)[1],
max_tasks_in_flight_per_actor=config.experimental.get(
"max_tasks_in_flight_per_actor", DEFAULT_MAX_TASKS_IN_FLIGHT
),
Expand All @@ -148,7 +148,7 @@ def build_sglang_engine_processor(
),
map_batches_kwargs=dict(
zero_copy_batch=True,
concurrency=(1, config.concurrency),
concurrency=config.get_concurrency(),
batch_size=config.batch_size,
runtime_env=config.runtime_env,
),
Expand Down
23 changes: 7 additions & 16 deletions python/ray/llm/_internal/batch/processor/vllm_engine_proc.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,28 +80,21 @@ def build_vllm_engine_processor(
required fields for the following processing stages.
postprocess: An optional lambda function that takes a row (dict) as input
and returns a postprocessed row (dict).
telemetry_agent: An optional telemetry agent for collecting usage telemetry.

Returns:
The constructed processor.
"""
ray.init(runtime_env=config.runtime_env, ignore_reinit_error=True)

stages = []
if isinstance(config.concurrency, int):
# For CPU-only stages, we leverage auto-scaling to recycle resources.
processor_concurrency = (1, config.concurrency)
else:
raise ValueError(
"``concurrency`` is expected to be set as an integer,"
f" but got: {config.concurrency}."
)

if config.has_image:
stages.append(
PrepareImageStage(
map_batches_kwargs=dict(
zero_copy_batch=True,
concurrency=processor_concurrency,
concurrency=config.get_concurrency(),
batch_size=config.batch_size,
),
)
Expand All @@ -115,7 +108,7 @@ def build_vllm_engine_processor(
),
map_batches_kwargs=dict(
zero_copy_batch=True,
concurrency=processor_concurrency,
concurrency=config.get_concurrency(),
batch_size=config.batch_size,
runtime_env=config.runtime_env,
),
Expand All @@ -130,7 +123,7 @@ def build_vllm_engine_processor(
),
map_batches_kwargs=dict(
zero_copy_batch=True,
concurrency=processor_concurrency,
concurrency=config.get_concurrency(),
batch_size=config.batch_size,
runtime_env=config.runtime_env,
),
Expand All @@ -157,10 +150,8 @@ def build_vllm_engine_processor(
# which initiates enough many overlapping UDF calls per actor, to
# saturate `max_concurrency`.
compute=ray.data.ActorPoolStrategy(
# vLLM start up time is significant, so if user give fixed
# concurrency, start all instances without auto-scaling.
min_size=config.concurrency,
max_size=config.concurrency,
min_size=config.get_concurrency(autoscaling_enabled=False)[0],
max_size=config.get_concurrency(autoscaling_enabled=False)[1],
max_tasks_in_flight_per_actor=config.experimental.get(
"max_tasks_in_flight_per_actor", DEFAULT_MAX_TASKS_IN_FLIGHT
),
Expand All @@ -184,7 +175,7 @@ def build_vllm_engine_processor(
),
map_batches_kwargs=dict(
zero_copy_batch=True,
concurrency=processor_concurrency,
concurrency=config.get_concurrency(),
batch_size=config.batch_size,
runtime_env=config.runtime_env,
),
Expand Down
80 changes: 75 additions & 5 deletions python/ray/llm/tests/batch/cpu/processor/test_processor_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,17 +190,87 @@ def overrider(name: str, stage: StatefulStage):

class TestProcessorConfig:
    """Tests for `ProcessorConfig.concurrency` validation and normalization."""

    def test_valid_concurrency(self):
        # A (min, max) tuple is accepted and stored unchanged.
        # NOTE: removed a stale leftover `pytest.raises(..., match="should be a
        # valid integer")` block that asserted this same construction fails —
        # it contradicted the assertion below and would itself fail because
        # tuple concurrency is now valid.
        config = vLLMEngineProcessorConfig(
            model_source="unsloth/Llama-3.2-1B-Instruct",
            concurrency=(1, 2),
        )
        assert config.concurrency == (1, 2)

        # Omitting `concurrency` falls back to the default of 1.
        config = vLLMEngineProcessorConfig(
            model_source="unsloth/Llama-3.2-1B-Instruct",
        )
        assert config.concurrency == 1

    def test_invalid_concurrency(self):
        # Non-integral scalars are rejected.
        with pytest.raises(pydantic.ValidationError):
            vLLMEngineProcessorConfig(
                model_source="unsloth/Llama-3.2-1B-Instruct",
                concurrency=1.1,
            )

        # Sequences longer than two items are rejected.
        with pytest.raises(pydantic.ValidationError):
            vLLMEngineProcessorConfig(
                model_source="unsloth/Llama-3.2-1B-Instruct",
                concurrency=[1, 2, 3],
            )

    @pytest.mark.parametrize("n", [1, 2, 10])
    def test_positive_int_not_fail(self, n):
        conf = ProcessorConfig(concurrency=n)
        assert conf.concurrency == n

    def test_positive_int_unusual_not_fail(self):
        # Pydantic coerces numeric strings and integral floats to int.
        assert ProcessorConfig(concurrency="1").concurrency == 1
        assert ProcessorConfig(concurrency=1.0).concurrency == 1
        assert ProcessorConfig(concurrency="1.0").concurrency == 1

    @pytest.mark.parametrize("pair", [(1, 1), (1, 2), (2, 8)])
    def test_valid_tuple_not_fail(self, pair):
        conf = ProcessorConfig(concurrency=pair)
        assert conf.concurrency == pair

    def test_valid_tuple_unusual_not_fail(self):
        # Items are coerced to int, and list inputs are coerced to tuples.
        assert ProcessorConfig(concurrency=("1", 2)).concurrency == (1, 2)
        assert ProcessorConfig(concurrency=(1, "2")).concurrency == (1, 2)
        assert ProcessorConfig(concurrency=[1, "2"]).concurrency == (1, 2)

    @pytest.mark.parametrize(
        "bad,msg_part",
        [
            (0, "positive integer"),
            (-5, "positive integer"),
            ((1, 2, 3), "at most 2 items"),
            ((0, 1), "positive integers"),
            ((1, 0), "positive integers"),
            ((-1, 2), "positive integers"),
            ((1, -2), "positive integers"),
            ((1, 2.5), "a number with a fractional part"),
            ("2.1", "unable to parse string"),
            ((5, 2), "min > max"),
        ],
    )
    def test_invalid_inputs_raise(self, bad, msg_part):
        # Each invalid value must raise, and the message must explain why.
        with pytest.raises(pydantic.ValidationError) as e:
            ProcessorConfig(concurrency=bad)
        assert msg_part in str(e.value)

    @pytest.mark.parametrize(
        "n,expected", [(1, (1, 1)), (4, (1, 4)), (10, (1, 10)), ("10", (1, 10))]
    )
    def test_with_int_concurrency_scaling(self, n, expected):
        # Integer concurrency maps to an autoscaling range (1, n) by default.
        conf = ProcessorConfig(concurrency=n)
        assert conf.get_concurrency() == expected

    @pytest.mark.parametrize("n,expected", [(1, (1, 1)), (4, (4, 4)), (10, (10, 10))])
    def test_with_int_concurrency_fixed(self, n, expected):
        # With autoscaling disabled, integer concurrency maps to a fixed (n, n).
        conf = ProcessorConfig(concurrency=n)
        assert conf.get_concurrency(autoscaling_enabled=False) == expected

    @pytest.mark.parametrize("pair", [(1, 1), (1, 3), (2, 8)])
    def test_with_tuple_concurrency(self, pair):
        # Tuple concurrency is passed through unchanged.
        conf = ProcessorConfig(concurrency=pair)
        assert conf.get_concurrency() == pair


if __name__ == "__main__":
    # Allow running this test module directly (outside the Bazel/pytest harness).
    sys.exit(pytest.main(["-v", __file__]))
7 changes: 5 additions & 2 deletions release/llm_tests/batch/test_batch_vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,12 @@ def add_buffer_time_between_tests():
"""Add buffer time after each test to avoid resource conflicts, which cause
flakiness.
"""
yield # Test runs here
# yield # test runs
# time.sleep(10)
import gc

time.sleep(10)
gc.collect()
time.sleep(15)


def test_chat_template_with_vllm():
Expand Down