Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion doc/source/ray-core/examples/lm/ray_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from fairseq_cli.train import main

import ray
from ray._common.network_utils import build_address

_original_save_checkpoint = fairseq.checkpoint_utils.save_checkpoint

Expand Down Expand Up @@ -112,7 +113,7 @@ def run_fault_tolerant_loop():
# fairseq distributed training.
ip = ray.get(workers[0].get_node_ip.remote())
port = ray.get(workers[0].find_free_port.remote())
address = "tcp://{ip}:{port}".format(ip=ip, port=port)
address = f"tcp://{build_address(ip, port)}"

# Start the remote processes, and check whether there are any process
# failures. If so, restart all the processes.
Expand Down
44 changes: 44 additions & 0 deletions python/ray/_common/network_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from typing import Optional, Tuple, Union


# Reviewer note from the PR thread: tests were requested for this helper, and a
# follow-up may explore combining it with the C++ version via Cython.
def parse_address(address: str) -> Optional[Tuple[str, str]]:
    """Split a network address string into its host and port parts.

    The split happens at the last ":" in the string. A host that itself
    contains ":" is only accepted when it is wrapped in square brackets
    (the IPv6 literal form); the brackets are removed from the result.

    Args:
        address: The address string to parse (e.g., "localhost:8000", "[::1]:8000").

    Returns:
        A ``(host, port)`` tuple of strings, or None when the string has no
        ":" separator or when the host part contains ":" without being
        enclosed in brackets (e.g. a bare IPv6 address such as "::1").
    """
    host, separator, port = address.rpartition(":")
    if not separator:
        # No colon anywhere, so there is nothing to split on.
        return None

    if ":" not in host:
        return (host, port)

    is_bracketed = host.startswith("[") and host.endswith("]")
    if not is_bracketed:
        # Invalid IPv6 (missing brackets), or the last colon belongs to the
        # address itself rather than separating host from port.
        return None

    return (host[1:-1], port)


def build_address(host: str, port: Union[int, str]) -> str:
    """Build a network address string from host and port.

    A host containing ":" is treated as an IPv6 literal and wrapped in
    square brackets so the port separator stays unambiguous. A host that is
    already bracketed (e.g. "[::1]") is left as-is instead of being wrapped
    a second time.

    Args:
        host: The hostname or IP address. A None host skips the IPv6 check
            and is interpolated into the result unchanged.
        port: The port number (int or string).

    Returns:
        Formatted address string (e.g., "localhost:8000" or "[::1]:8000").
    """
    if host is not None and ":" in host:
        if host.startswith("[") and host.endswith("]"):
            # Already a bracketed IPv6 literal; adding brackets again would
            # produce an invalid "[[...]]:port" string.
            return f"{host}:{port}"
        # Bare IPv6 address
        return f"[{host}]:{port}"
    # IPv4 address or hostname
    return f"{host}:{port}"
3 changes: 2 additions & 1 deletion python/ray/_common/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@


import ray
from ray._common.network_utils import build_address
import ray._private.utils
import ray._common.usage.usage_lib as ray_usage_lib

Expand Down Expand Up @@ -174,7 +175,7 @@ def simulate_s3_bucket(
os.environ["AWS_SECURITY_TOKEN"] = "testing"
os.environ["AWS_SESSION_TOKEN"] = "testing"

s3_server = f"http://localhost:{port}"
s3_server = f"http://{build_address('localhost', port)}"
server = ThreadedMotoServer(port=port)
server.start()
url = f"s3://{uuid.uuid4().hex}?region={region}&endpoint_override={s3_server}"
Expand Down
1 change: 1 addition & 0 deletions python/ray/_common/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ py_library(
py_test_module_list(
size = "small",
files = [
"test_network_utils.py",
"test_ray_option_utils.py",
"test_signal_semaphore_utils.py",
"test_signature.py",
Expand Down
71 changes: 71 additions & 0 deletions python/ray/_common/tests/test_network_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import pytest
import sys

from ray._common.network_utils import parse_address, build_address


class TestBuildAddress:
    """Tests for build_address.

    The original author states these cases match the C++ tests; that is not
    verifiable from this file alone.
    """

    @pytest.mark.parametrize(
        ("host", "port", "expected"),
        [
            # IPv4, with the port given as int and as str
            ("192.168.1.1", 8080, "192.168.1.1:8080"),
            ("192.168.1.1", "8080", "192.168.1.1:8080"),
            # IPv6 hosts are expected to come back wrapped in brackets
            ("::1", 8080, "[::1]:8080"),
            ("::1", "8080", "[::1]:8080"),
            ("2001:db8::1", 8080, "[2001:db8::1]:8080"),
            ("2001:db8::1", "8080", "[2001:db8::1]:8080"),
            # Plain hostname
            ("localhost", 9000, "localhost:9000"),
            ("localhost", "9000", "localhost:9000"),
        ],
    )
    def test_build_address(self, host, port, expected):
        """build_address(host, port) yields the expected address string."""
        assert build_address(host, port) == expected


class TestParseAddress:
    """Tests for parse_address.

    The original author states these cases match the C++ tests; that is not
    verifiable from this file alone.
    """

    @pytest.mark.parametrize(
        ("address", "expected"),
        [
            # IPv4
            ("192.168.1.1:8080", ("192.168.1.1", "8080")),
            # IPv6 loopback, bracketed
            ("[::1]:8080", ("::1", "8080")),
            # IPv6, bracketed
            ("[2001:db8::1]:8080", ("2001:db8::1", "8080")),
            # hostname:port
            ("localhost:9000", ("localhost", "9000")),
        ],
    )
    def test_parse_valid_addresses(self, address, expected):
        """Well-formed host:port strings split into a (host, port) tuple."""
        assert parse_address(address) == expected

    @pytest.mark.parametrize(
        "address",
        [
            # Bare IPs or hostnames carry no port, so parsing should give None.
            "::1",
            "2001:db8::1",
            "192.168.1.1",
            "localhost",
        ],
    )
    def test_parse_bare_addresses(self, address):
        """Addresses without a port parse to None."""
        assert parse_address(address) is None
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you need main function



# Allow running this test module directly as a script; pytest's return value
# becomes the process exit code.
if __name__ == "__main__":
    sys.exit(pytest.main(["-v", __file__]))
7 changes: 4 additions & 3 deletions python/ray/_private/internal_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import ray._private.services as services
import ray._private.utils as utils
import ray._private.worker
from ray._common.network_utils import build_address
from ray._private.state import GlobalState
from ray._raylet import GcsClientOptions
from ray.core.generated import common_pb2
Expand Down Expand Up @@ -68,11 +69,11 @@ def get_memory_info_reply(state, node_manager_address=None, node_manager_port=No
raylet = node
break
assert raylet is not None, "Every raylet is dead"
raylet_address = "{}:{}".format(
raylet_address = build_address(
raylet["NodeManagerAddress"], raylet["NodeManagerPort"]
)
else:
raylet_address = "{}:{}".format(node_manager_address, node_manager_port)
raylet_address = build_address(node_manager_address, node_manager_port)

channel = utils.init_grpc_channel(
raylet_address,
Expand All @@ -99,7 +100,7 @@ def node_stats(

# We can ask any Raylet for the global memory info.
assert node_manager_address is not None and node_manager_port is not None
raylet_address = "{}:{}".format(node_manager_address, node_manager_port)
raylet_address = build_address(node_manager_address, node_manager_port)
channel = utils.init_grpc_channel(
raylet_address,
options=[
Expand Down
3 changes: 2 additions & 1 deletion python/ray/_private/metrics_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
)

import ray
from ray._common.network_utils import build_address
from ray._private.ray_constants import env_bool
from ray._private.telemetry.metric_cardinality import (
WORKER_ID_TAG_KEY,
Expand Down Expand Up @@ -780,7 +781,7 @@ def get_file_discovery_content(self):
"""Return the content for Prometheus service discovery."""
nodes = ray.nodes()
metrics_export_addresses = [
"{}:{}".format(node["NodeManagerAddress"], node["MetricsExportPort"])
build_address(node["NodeManagerAddress"], node["MetricsExportPort"])
for node in nodes
if node["alive"] is True
]
Expand Down
22 changes: 12 additions & 10 deletions python/ray/_private/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import ray
import ray._private.ray_constants as ray_constants
import ray._private.services
from ray._common.network_utils import build_address, parse_address
from ray._common.ray_constants import LOGGING_ROTATE_BACKUP_COUNT, LOGGING_ROTATE_BYTES
from ray._common.utils import try_to_create_directory
from ray._private.resource_and_label_spec import ResourceAndLabelSpec
Expand Down Expand Up @@ -109,7 +110,6 @@ def __init__(
# instance provided.
if len(external_redis) == 1:
external_redis.append(external_redis[0])
[primary_redis_ip, port] = external_redis[0].rsplit(":", 1)
ray_params.external_addresses = external_redis
ray_params.num_redis_shards = len(external_redis) - 1

Expand Down Expand Up @@ -196,8 +196,8 @@ def __init__(
assert not self._default_worker
self._webui_url = ray._private.services.get_webui_url_from_internal_kv()
else:
self._webui_url = (
f"{ray_params.dashboard_host}:{ray_params.dashboard_port}"
self._webui_url = build_address(
ray_params.dashboard_host, ray_params.dashboard_port
)

# It creates a session_dir.
Expand Down Expand Up @@ -421,14 +421,14 @@ def check_persisted_session_name(self):
@staticmethod
def validate_ip_port(ip_port):
"""Validates the address is in the ip:port format"""
_, _, port = ip_port.rpartition(":")
if port == ip_port:
parts = parse_address(ip_port)
if parts is None:
raise ValueError(f"Port is not specified for address {ip_port}")
try:
_ = int(port)
_ = int(parts[1])
except ValueError:
raise ValueError(
f"Unable to parse port number from {port} (full address = {ip_port})"
f"Unable to parse port number from {parts[1]} (full address = {ip_port})"
)

def check_version_info(self):
Expand Down Expand Up @@ -633,7 +633,7 @@ def runtime_env_agent_port(self):
@property
def runtime_env_agent_address(self):
"""Get the address that exposes runtime env agent as http"""
return f"http://{self._raylet_ip_address}:{self._runtime_env_agent_port}"
return f"http://{build_address(self._raylet_ip_address, self._runtime_env_agent_port)}"

@property
def dashboard_agent_listen_port(self):
Expand Down Expand Up @@ -941,7 +941,9 @@ def _prepare_socket_file(self, socket_path: str, default_prefix: str):
result = socket_path
if sys.platform == "win32":
if socket_path is None:
result = f"tcp://{self._localhost}:{self._get_unused_port()}"
result = (
f"tcp://{build_address(self._localhost, self._get_unused_port())}"
)
else:
if socket_path is None:
result = self._make_inc_temp(
Expand Down Expand Up @@ -1161,7 +1163,7 @@ def start_gcs_server(self):
# e.g. https://github.com/ray-project/ray/issues/15780
# TODO(mwtian): figure out a way to use 127.0.0.1 for local connection
# when possible.
self._gcs_address = f"{self._node_ip_address}:" f"{gcs_server_port}"
self._gcs_address = build_address(self._node_ip_address, gcs_server_port)

def start_raylet(
self,
Expand Down
Loading