diff --git a/examples/endpoints/streaming.py b/examples/endpoints/streaming.py index 363727e0..538a2e6a 100644 --- a/examples/endpoints/streaming.py +++ b/examples/endpoints/streaming.py @@ -3,7 +3,7 @@ import runpod # Set your global API key with `runpod config` or uncomment the line below: -# runpod.api_key = "YOUR_RUNPOD_API_KEY" +runpod.api_key = "YOUR_RUNPOD_API_KEY" endpoint = runpod.Endpoint("gwp4kx5yd3nur1") diff --git a/examples/endpoints/sxwl_test.py b/examples/endpoints/sxwl_test.py new file mode 100644 index 00000000..d740da48 --- /dev/null +++ b/examples/endpoints/sxwl_test.py @@ -0,0 +1,53 @@ +import runpod + +# 设置runpod的api key +runpod.api_key = "YOUR_SXWL_API_KEY"  # SECURITY: real JWT credential redacted — never commit live tokens + +endpoint = runpod.Endpoint("INFERENCE") + +# 异步方法 +# run_request = endpoint.run( +# { +# "input": { +# "gpu_model": "NVIDIA-GeForce-RTX-3090", +# "model_category": "chat", +# "gpu_count": 1, +# "model_id": "model-storage-0ce92f029254ff34", +# "model_name":"google/gemma-2b-it", +# "model_size": 15065904829, +# "model_is_public": True, +# "model_template": "gemma", +# "min_instances": 1, +# "model_meta": "{\"template\":\"gemma\",\"category\":\"chat\", \"can_finetune\":true,\"can_inference\":true}", +# "max_instances": 1 +# } +# } +# ) + +# 打印推理的状态 +# print(run_request.status()) + +# 打印推理的结果,结果返回推理接口的地址 +# print(run_request.output()) + + + +run_request = endpoint.run_sync( + { + "input": { + "gpu_model": "NVIDIA-GeForce-RTX-3090", + "model_category": "chat", + "gpu_count": 1, + "model_id": "model-storage-0ce92f029254ff34", + "model_name":"google/gemma-2b-it", + "model_size": 15065904829, + "model_is_public": True, + "model_template": "gemma", + "min_instances": 1, 
+ "model_meta": "{\"template\":\"gemma\",\"category\":\"chat\", \"can_finetune\":true,\"can_inference\":true}", + "max_instances": 1 + } + } +) + +print(run_request.output()) \ No newline at end of file diff --git a/runpod/endpoint/__init__.py b/runpod/endpoint/__init__.py index 8002ed17..56ec9275 100644 --- a/runpod/endpoint/__init__.py +++ b/runpod/endpoint/__init__.py @@ -2,4 +2,5 @@ from .asyncio.asyncio_runner import Endpoint as AsyncioEndpoint from .asyncio.asyncio_runner import Job as AsyncioJob -from .runner import Endpoint, Job +# from .runner import Endpoint, Job +from .sxwl import Endpoint diff --git a/runpod/endpoint/runner.py b/runpod/endpoint/runner.py index e4d93384..a3066ef7 100644 --- a/runpod/endpoint/runner.py +++ b/runpod/endpoint/runner.py @@ -1,254 +1,254 @@ -""" -RunPod | Python | Endpoint Runner -""" - -import time -from typing import Any, Dict, Optional - -import requests -from requests.adapters import HTTPAdapter, Retry - -from runpod.endpoint.helpers import ( - API_KEY_NOT_SET_MSG, - FINAL_STATES, - UNAUTHORIZED_MSG, - is_completed, -) - - -# ---------------------------------------------------------------------------- # -# Client # -# ---------------------------------------------------------------------------- # -class RunPodClient: - """A client for running endpoint calls.""" - - def __init__(self): - """ - Initialize a RunPodClient instance. - - Raises: - RuntimeError: If the API key has not been initialized. 
- """ - from runpod import ( # pylint: disable=import-outside-toplevel, cyclic-import - api_key, - endpoint_url_base, - ) - - if api_key is None: - raise RuntimeError(API_KEY_NOT_SET_MSG) - - self.rp_session = requests.Session() - retries = Retry(total=5, backoff_factor=1, status_forcelist=[408, 429]) - self.rp_session.mount("http://", HTTPAdapter(max_retries=retries)) - - self.headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {api_key}", - } - - self.endpoint_url_base = endpoint_url_base - - def _request( - self, method: str, endpoint: str, data: Optional[dict] = None, timeout: int = 10 - ): - """ - Make a request to the specified endpoint using the given HTTP method. - - Args: - method: The HTTP method to use ('GET' or 'POST'). - endpoint: The endpoint path to which the request will be made. - data: The JSON payload to send with the request. - timeout: The number of seconds to wait for the server to send data before giving up. - - Returns: - The JSON response from the server. - - Raises: - RuntimeError: If the response returns a 401 Unauthorized status. - requests.HTTPError: If the response contains an unsuccessful status code. 
- """ - url = f"{self.endpoint_url_base}/{endpoint}" - response = self.rp_session.request( - method, url, headers=self.headers, json=data, timeout=timeout - ) - - if response.status_code == 401: - raise RuntimeError(UNAUTHORIZED_MSG) - - response.raise_for_status() - return response.json() - - def post(self, endpoint: str, data: dict, timeout: int = 10): - """Post to the endpoint.""" - return self._request("POST", endpoint, data, timeout) - - def get(self, endpoint: str, timeout: int = 10): - """Get from the endpoint.""" - return self._request("GET", endpoint, timeout=timeout) - - -# ---------------------------------------------------------------------------- # -# Job # -# ---------------------------------------------------------------------------- # -class Job: - """Represents a job to be run on the RunPod service.""" - - def __init__(self, endpoint_id: str, job_id: str, client: RunPodClient): - """ - Initialize a Job instance with the given endpoint ID and job ID. - - Args: - endpoint_id: The identifier for the endpoint. - job_id: The identifier for the job. - client: An instance of the RunPodClient to make requests with. - """ - self.endpoint_id = endpoint_id - self.job_id = job_id - self.rp_client = client - - self.job_status = None - self.job_output = None - - def _fetch_job(self, source: str = "status") -> Dict[str, Any]: - """Returns the raw json of the status, raises an exception if invalid""" - status_url = f"{self.endpoint_id}/{source}/{self.job_id}" - job_state = self.rp_client.get(endpoint=status_url) - - if is_completed(job_state["status"]): - self.job_status = job_state["status"] - self.job_output = job_state.get("output", None) - - return job_state - - def status(self): - """Returns the status of the job request.""" - if self.job_status is not None: - return self.job_status - - return self._fetch_job()["status"] - - def output(self, timeout: int = 0) -> Any: - """ - Returns the output of the job request. 
- - Args: - timeout: The number of seconds to wait for the server to send data before giving up. - """ - if timeout > 0: - while not is_completed(self.status()): - time.sleep(1) - timeout -= 1 - if timeout <= 0: - raise TimeoutError("Job timed out.") - - if self.job_output is not None: - return self.job_output - - return self._fetch_job().get("output", None) - - def stream(self) -> Any: - """Returns a generator that yields the output of the job request.""" - while True: - time.sleep(1) - stream_partial = self._fetch_job(source="stream") - if ( - stream_partial["status"] not in FINAL_STATES - or len(stream_partial["stream"]) > 0 - ): - for chunk in stream_partial.get("stream", []): - yield chunk["output"] - elif stream_partial["status"] in FINAL_STATES: - break - - def cancel(self, timeout: int = 3) -> Any: - """ - Cancels the job and returns the result of the cancellation request. - - Args: - timeout: The number of seconds to wait for the server to respond before giving up. - """ - return self.rp_client.post( - f"{self.endpoint_id}/cancel/{self.job_id}", data=None, timeout=timeout - ) - - -# ---------------------------------------------------------------------------- # -# Endpoint # -# ---------------------------------------------------------------------------- # -class Endpoint: - """Manages an endpoint to run jobs on the RunPod service.""" - - def __init__(self, endpoint_id: str): - """ - Initialize an Endpoint instance with the given endpoint ID. - - Args: - endpoint_id: The identifier for the endpoint. - - Example: - >>> endpoint = runpod.Endpoint("ENDPOINT_ID") - >>> run_request = endpoint.run({"your_model_input_key": "your_model_input_value"}) - >>> print(run_request.status()) - >>> print(run_request.output()) - """ - self.endpoint_id = endpoint_id - self.rp_client = RunPodClient() - - def run(self, request_input: Dict[str, Any]) -> Job: - """ - Run the endpoint with the given input. - - Args: - request_input: The input to pass into the endpoint. 
- - Returns: - A Job instance for the run request. - """ - if not request_input.get("input"): - request_input = {"input": request_input} - - job_request = self.rp_client.post(f"{self.endpoint_id}/run", request_input) - return Job(self.endpoint_id, job_request["id"], self.rp_client) - - def run_sync( - self, request_input: Dict[str, Any], timeout: int = 86400 - ) -> Dict[str, Any]: - """ - Run the endpoint with the given input synchronously. - - Args: - request_input: The input to pass into the endpoint. - """ - if not request_input.get("input"): - request_input = {"input": request_input} - - job_request = self.rp_client.post( - f"{self.endpoint_id}/runsync", request_input, timeout=timeout - ) - - if job_request["status"] in FINAL_STATES: - return job_request.get("output", None) - - return Job(self.endpoint_id, job_request["id"], self.rp_client).output( - timeout=timeout - ) - - def health(self, timeout: int = 3) -> Dict[str, Any]: - """ - Check the health of the endpoint (number/state of workers, number/state of requests). - - Args: - timeout: The number of seconds to wait for the server to respond before giving up. - """ - return self.rp_client.get(f"{self.endpoint_id}/health", timeout=timeout) - - def purge_queue(self, timeout: int = 3) -> Dict[str, Any]: - """ - Purges the endpoint's job queue and returns the result of the purge request. - - Args: - timeout: The number of seconds to wait for the server to respond before giving up. 
- """ - return self.rp_client.post( - f"{self.endpoint_id}/purge-queue", data=None, timeout=timeout - ) +# """ +# RunPod | Python | Endpoint Runner +# """ + +# import time +# from typing import Any, Dict, Optional + +# import requests +# from requests.adapters import HTTPAdapter, Retry + +# from runpod.endpoint.helpers import ( +# API_KEY_NOT_SET_MSG, +# FINAL_STATES, +# UNAUTHORIZED_MSG, +# is_completed, +# ) + + +# # ---------------------------------------------------------------------------- # +# # Client # +# # ---------------------------------------------------------------------------- # +# class RunPodClient: +# """A client for running endpoint calls.""" + +# def __init__(self): +# """ +# Initialize a RunPodClient instance. + +# Raises: +# RuntimeError: If the API key has not been initialized. +# """ +# from runpod import ( # pylint: disable=import-outside-toplevel, cyclic-import +# api_key, +# endpoint_url_base, +# ) + +# if api_key is None: +# raise RuntimeError(API_KEY_NOT_SET_MSG) + +# self.rp_session = requests.Session() +# retries = Retry(total=5, backoff_factor=1, status_forcelist=[408, 429]) +# self.rp_session.mount("http://", HTTPAdapter(max_retries=retries)) + +# self.headers = { +# "Content-Type": "application/json", +# "Authorization": f"Bearer {api_key}", +# } + +# self.endpoint_url_base = endpoint_url_base + +# def _request( +# self, method: str, endpoint: str, data: Optional[dict] = None, timeout: int = 10 +# ): +# """ +# Make a request to the specified endpoint using the given HTTP method. + +# Args: +# method: The HTTP method to use ('GET' or 'POST'). +# endpoint: The endpoint path to which the request will be made. +# data: The JSON payload to send with the request. +# timeout: The number of seconds to wait for the server to send data before giving up. + +# Returns: +# The JSON response from the server. + +# Raises: +# RuntimeError: If the response returns a 401 Unauthorized status. 
+# requests.HTTPError: If the response contains an unsuccessful status code. +# """ +# url = f"{self.endpoint_url_base}/{endpoint}" +# response = self.rp_session.request( +# method, url, headers=self.headers, json=data, timeout=timeout +# ) + +# if response.status_code == 401: +# raise RuntimeError(UNAUTHORIZED_MSG) + +# response.raise_for_status() +# return response.json() + +# def post(self, endpoint: str, data: dict, timeout: int = 10): +# """Post to the endpoint.""" +# return self._request("POST", endpoint, data, timeout) + +# def get(self, endpoint: str, timeout: int = 10): +# """Get from the endpoint.""" +# return self._request("GET", endpoint, timeout=timeout) + + +# # ---------------------------------------------------------------------------- # +# # Job # +# # ---------------------------------------------------------------------------- # +# class Job: +# """Represents a job to be run on the RunPod service.""" + +# def __init__(self, endpoint_id: str, job_id: str, client: RunPodClient): +# """ +# Initialize a Job instance with the given endpoint ID and job ID. + +# Args: +# endpoint_id: The identifier for the endpoint. +# job_id: The identifier for the job. +# client: An instance of the RunPodClient to make requests with. 
+# """ +# self.endpoint_id = endpoint_id +# self.job_id = job_id +# self.rp_client = client + +# self.job_status = None +# self.job_output = None + +# def _fetch_job(self, source: str = "status") -> Dict[str, Any]: +# """Returns the raw json of the status, raises an exception if invalid""" +# status_url = f"{self.endpoint_id}/{source}/{self.job_id}" +# job_state = self.rp_client.get(endpoint=status_url) + +# if is_completed(job_state["status"]): +# self.job_status = job_state["status"] +# self.job_output = job_state.get("output", None) + +# return job_state + +# def status(self): +# """Returns the status of the job request.""" +# if self.job_status is not None: +# return self.job_status + +# return self._fetch_job()["status"] + +# def output(self, timeout: int = 0) -> Any: +# """ +# Returns the output of the job request. + +# Args: +# timeout: The number of seconds to wait for the server to send data before giving up. +# """ +# if timeout > 0: +# while not is_completed(self.status()): +# time.sleep(1) +# timeout -= 1 +# if timeout <= 0: +# raise TimeoutError("Job timed out.") + +# if self.job_output is not None: +# return self.job_output + +# return self._fetch_job().get("output", None) + +# def stream(self) -> Any: +# """Returns a generator that yields the output of the job request.""" +# while True: +# time.sleep(1) +# stream_partial = self._fetch_job(source="stream") +# if ( +# stream_partial["status"] not in FINAL_STATES +# or len(stream_partial["stream"]) > 0 +# ): +# for chunk in stream_partial.get("stream", []): +# yield chunk["output"] +# elif stream_partial["status"] in FINAL_STATES: +# break + +# def cancel(self, timeout: int = 3) -> Any: +# """ +# Cancels the job and returns the result of the cancellation request. + +# Args: +# timeout: The number of seconds to wait for the server to respond before giving up. 
+# """ +# return self.rp_client.post( +# f"{self.endpoint_id}/cancel/{self.job_id}", data=None, timeout=timeout +# ) + + +# # ---------------------------------------------------------------------------- # +# # Endpoint # +# # ---------------------------------------------------------------------------- # +# class Endpoint: +# """Manages an endpoint to run jobs on the RunPod service.""" + +# def __init__(self, endpoint_id: str): +# """ +# Initialize an Endpoint instance with the given endpoint ID. + +# Args: +# endpoint_id: The identifier for the endpoint. + +# Example: +# >>> endpoint = runpod.Endpoint("ENDPOINT_ID") +# >>> run_request = endpoint.run({"your_model_input_key": "your_model_input_value"}) +# >>> print(run_request.status()) +# >>> print(run_request.output()) +# """ +# self.endpoint_id = endpoint_id +# self.rp_client = RunPodClient() + +# def run(self, request_input: Dict[str, Any]) -> Job: +# """ +# Run the endpoint with the given input. + +# Args: +# request_input: The input to pass into the endpoint. + +# Returns: +# A Job instance for the run request. +# """ +# if not request_input.get("input"): +# request_input = {"input": request_input} + +# job_request = self.rp_client.post(f"{self.endpoint_id}/run", request_input) +# return Job(self.endpoint_id, job_request["id"], self.rp_client) + +# def run_sync( +# self, request_input: Dict[str, Any], timeout: int = 86400 +# ) -> Dict[str, Any]: +# """ +# Run the endpoint with the given input synchronously. + +# Args: +# request_input: The input to pass into the endpoint. 
+# """ +# if not request_input.get("input"): +# request_input = {"input": request_input} + +# job_request = self.rp_client.post( +# f"{self.endpoint_id}/runsync", request_input, timeout=timeout +# ) + +# if job_request["status"] in FINAL_STATES: +# return job_request.get("output", None) + +# return Job(self.endpoint_id, job_request["id"], self.rp_client).output( +# timeout=timeout +# ) + +# def health(self, timeout: int = 3) -> Dict[str, Any]: +# """ +# Check the health of the endpoint (number/state of workers, number/state of requests). + +# Args: +# timeout: The number of seconds to wait for the server to respond before giving up. +# """ +# return self.rp_client.get(f"{self.endpoint_id}/health", timeout=timeout) + +# def purge_queue(self, timeout: int = 3) -> Dict[str, Any]: +# """ +# Purges the endpoint's job queue and returns the result of the purge request. + +# Args: +# timeout: The number of seconds to wait for the server to respond before giving up. +# """ +# return self.rp_client.post( +# f"{self.endpoint_id}/purge-queue", data=None, timeout=timeout +# ) diff --git a/runpod/endpoint/sxwl.py b/runpod/endpoint/sxwl.py new file mode 100644 index 00000000..a04341b3 --- /dev/null +++ b/runpod/endpoint/sxwl.py @@ -0,0 +1,302 @@ +import json +import time +import requests +from typing import Optional, Dict, Any, List +from dataclasses import dataclass + + +# 全局变量用于存储需要清理的资源 +resources_to_cleanup = { + 'inference_services': set(), + 'finetune_jobs': set() +} + + +@dataclass +class APIConfig: + base_url: str + token: str + headers: Dict[str, str] + + @classmethod + def create_default(cls) -> 'APIConfig': + from runpod import ( # pylint: disable=import-outside-toplevel, cyclic-import + api_key, + endpoint_url_base, + ) + token = api_key + if token and isinstance(token, bytes): + # 如果是 bytes,解码为字符串 + token = token.decode('utf-8') + headers = { + 'Accept': 'application/json, text/plain, */*', + 'Authorization': f'Bearer {token}', + 'Content-Type': 'application/json', + 
'Origin': endpoint_url_base, + 'Connection': 'keep-alive', + 'User-Agent': 'github-actions' + } + return cls(endpoint_url_base, token, headers) + +class SXWLClient: + def __init__(self, config: APIConfig): + self.config = config + + def _make_request(self, method: str, endpoint: str, **kwargs) -> requests.Response: + url = f"{self.config.base_url}/api{endpoint}" + response = requests.request(method, url, headers=self.config.headers, **kwargs) + response.raise_for_status() + return response + + def get_models(self) -> List[Dict[str, Any]]: + """获取可用的模型列表""" + try: + response = self._make_request('GET', '/resource/models') + data = response.json() + models = data.get('public_list', []) + data.get('user_list', []) + print(f"获取到 {len(models)} 个模型", flush=True) + return models + except Exception as e: + print(f"获取模型列表失败: {str(e)}", flush=True) + return [] + + def delete_inference_service(self, service_name: str) -> None: + try: + self._make_request('DELETE', '/job/inference', params={'service_name': service_name}) + print("推理服务删除成功", flush=True) + resources_to_cleanup['inference_services'].discard(service_name) + except Exception as e: + print(f"删除推理服务失败: {str(e)}", flush=True) + + def delete_finetune_job(self, finetune_id: str) -> None: + try: + self._make_request('POST', '/userJob/job_del', json={'job_id': finetune_id}) + print("微调任务删除成功", flush=True) + resources_to_cleanup['finetune_jobs'].discard(finetune_id) + except Exception as e: + print(f"删除微调任务失败: {str(e)}", flush=True) + +class InferenceService: + def __init__(self, client: SXWLClient): + self.client = client + self.service_name: Optional[str] = None + self.api_endpoint: Optional[str] = None + + def deploy(self, model_config: Dict[str, Any]) -> 'InferenceService': + response = self.client._make_request('POST', '/job/inference', json=model_config) + self.service_name = response.json()['service_name'] + print(f"服务名称: {self.service_name}", flush=True) + # 添加到需要清理的资源列表 + 
resources_to_cleanup['inference_services'].add(self.service_name) + return self + + def wait_until_complete(self) -> Dict[str, Any]: + """等待服务部署完成并返回结果""" + self._wait_for_ready() + return { + "service_name": self.service_name, + "api_endpoint": self.api_endpoint, + "status": "running" + } + + def status(self) -> str: + response = self.client._make_request('GET', '/job/inference') + status_json = response.json() + for item in status_json.get('data', []): + if item['service_name'] == self.service_name: + return item['status'] + return 'unknown' + + def output(self) -> Dict[str, Any]: + response = self.client._make_request('GET', '/job/inference') + result = {} + status_json = response.json() + for item in status_json.get('data', []): + if item['service_name'] == self.service_name: + if item['status'] == 'running': + result['chat_url'] = item['api'] + return result + return {'status': item['status']} + return {'status': 'unknown'} + + def _wait_for_ready(self, max_retries: int = 60, retry_interval: int = 30) -> None: + for attempt in range(max_retries): + response = self.client._make_request('GET', '/job/inference') + status_json = response.json() + + for item in status_json.get('data', []): + if item['service_name'] == self.service_name: + if item['status'] == 'running': + self.api_endpoint = item['api'] + print(f"服务已就绪: {item}", flush=True) + return + break + + print(f"服务启动中... 
({attempt + 1}/{max_retries})", flush=True) + time.sleep(retry_interval) + + raise TimeoutError("服务启动超时") + + def chat(self, messages: list) -> Dict[str, Any]: + if not self.api_endpoint: + raise RuntimeError("服务尚未就绪") + + chat_url = self.api_endpoint + headers = {'accept': 'application/json', 'Content-Type': 'application/json'} + data = {"model": "/mnt/models", "messages": messages} + + response = requests.post(chat_url, headers=headers, json=data) + response.raise_for_status() + return response.json() + +class FinetuneJob: + def __init__(self, client: SXWLClient): + self.client = client + self.job_id: Optional[str] = None + self.adapter_id: Optional[str] = None + + def start(self, finetune_config: Dict[str, Any]) -> 'FinetuneJob': + response = self.client._make_request('POST', '/job/finetune', json=finetune_config) + self.job_id = response.json()['job_id'] + print(f"微调任务ID: {self.job_id}", flush=True) + # 添加到需要清理的资源列表 + resources_to_cleanup['finetune_jobs'].add(self.job_id) + return self + + def wait_until_complete(self) -> Dict[str, Any]: + """等待任务完成并返回结果""" + self._wait_for_completion() + self._get_adapter_id() + return { + "job_id": self.job_id, + "adapter_id": self.adapter_id, + "status": "succeeded" + } + + def _wait_for_completion(self, max_retries: int = 60, retry_interval: int = 30) -> None: + for _ in range(max_retries): + print(f"正在检查微调任务状态... 
(第 {_ + 1}/{max_retries} 次尝试)", flush=True) + response = self.client._make_request('GET', '/job/training', + params={'current': 1, 'size': 1000}) + + print(f"API响应: {response.json()}", flush=True) + for job in response.json().get('content', []): + if job['jobName'] == self.job_id: + status = job['status'] + print(f"微调状态: {status}", flush=True) + + if status == 'succeeded': + return + elif status in ['failed', 'error']: + raise RuntimeError("微调任务失败") + break + + time.sleep(retry_interval) + raise TimeoutError("微调任务超时") + + def _get_adapter_id(self) -> None: + response = self.client._make_request('GET', '/resource/adapters') + + for adapter in response.json().get('user_list', []): + try: + meta = json.loads(adapter.get('meta', '{}')) + if meta.get('finetune_id') == self.job_id: + self.adapter_id = adapter['id'] + print(f"适配器ID: {self.adapter_id}", flush=True) + return + except json.JSONDecodeError: + continue + + raise ValueError(f"未找到对应的适配器") + +# ---------------------------------------------------------------------------- # +# Endpoint # +# ---------------------------------------------------------------------------- # +class Endpoint: + """Manages an endpoint to run jobs on the sxwl.""" + + def __init__(self, endpoint_id: str): + """ + Initialize an Endpoint instance with the given endpoint ID. + + Args: + endpoint_id: The identifier for the endpoint. + + Example: + >>> endpoint = runpod.Endpoint("INFERENCE") + >>> run_request = endpoint.run({"your_model_input_key": "your_model_input_value"}) + >>> print(run_request.status()) + >>> print(run_request.output()) + """ + self.endpoint_id = endpoint_id + self.rp_client = SXWLClient(APIConfig.create_default()) + + def run(self, request_input: Dict[str, Any]) -> InferenceService: + """ + Run the endpoint with the given input. + + Args: + request_input: The input to pass into the endpoint. + + Returns: + An InferenceService instance for the run request. 
+ """ + # if not request_input.get("input"): + # request_input = {"input": request_input} + + + inference = InferenceService(self.rp_client) + return inference.deploy(request_input.get("input")) + + def run_sync( + self, request_input: Dict[str, Any], timeout: int = 86400 + ) -> Dict[str, Any]: + """ + Run the endpoint with the given input synchronously. + + Args: + request_input: The input to pass into the endpoint. + timeout: The maximum time to wait for the result in seconds. + + Returns: + The output of the completed job. + """ + + config = APIConfig.create_default() + client = SXWLClient(config) + + inference = InferenceService(client) + # 启动任务 + inference = inference.deploy(request_input.get("input")) + # 等待任务完成 + return inference.wait_until_complete() + + def health(self, timeout: int = 3) -> Dict[str, Any]: + """ + Check the health of the endpoint (number/state of workers, number/state of requests). + + Args: + timeout: The number of seconds to wait for the server to respond before giving up. + """ + config = APIConfig.create_default() + client = SXWLClient(config) + try: + response = client._make_request('GET', '/job/health', timeout=timeout) + return response.json() + except Exception as e: + return {"status": "error", "message": str(e)} + + def purge_queue(self, timeout: int = 3) -> Dict[str, Any]: + """ + Purges the endpoint's job queue and returns the result of the purge request. + + Args: + timeout: The number of seconds to wait for the server to respond before giving up. + """ + config = APIConfig.create_default() + client = SXWLClient(config) + try: + response = client._make_request('POST', '/job/purge-queue', timeout=timeout) + return response.json() + except Exception as e: + return {"status": "error", "message": str(e)} \ No newline at end of file