roy_cli/
lib.rs

1// Copyright 2025 Massimiliano Pippi
2// SPDX-License-Identifier: MIT
3
4use axum::{
5    extract::State,
6    http::{Request, Uri},
7    middleware::{self, Next},
8    response::Response,
9    routing::post,
10    Router,
11};
12use clap::Parser;
13use clap_verbosity_flag::Verbosity;
14use colored::Colorize;
15use std::net::{IpAddr, SocketAddr};
16use std::time::Duration;
17use tower_http::timeout::TimeoutLayer;
18
19pub mod chat_completions;
20pub mod responses;
21pub mod server_state;
22use crate::server_state::ServerState;
23
24#[derive(Parser, Clone)]
25#[command(name = "roy")]
26#[command(version = env!("CARGO_PKG_VERSION"))]
27#[command(
28    about = "A HTTP server compatible with the OpenAI platform format that simulates errors and rate limit data"
29)]
30pub struct Args {
31    #[command(flatten)]
32    pub verbosity: Verbosity,
33
34    #[arg(long, help = "Port to listen on", default_value = "8000")]
35    pub port: u16,
36
37    #[arg(long, help = "Address to listen on", default_value = "127.0.0.1")]
38    pub address: IpAddr,
39
40    #[arg(
41        long,
42        help = "Length of response (fixed number or range like '10:100')",
43        default_value = "250"
44    )]
45    pub response_length: Option<String>,
46
47    #[arg(long, help = "HTTP error code to return")]
48    pub error_code: Option<u16>,
49
50    #[arg(long, help = "Error rate percentage (0-100)")]
51    pub error_rate: Option<u32>,
52
53    #[arg(
54        long,
55        help = "Maximum number of requests per minute",
56        default_value = "500"
57    )]
58    pub rpm: u32,
59
60    #[arg(
61        long,
62        help = "Maximum number of tokens per minute",
63        default_value = "30000"
64    )]
65    pub tpm: u32,
66
67    #[arg(
68        long,
69        help = "Slowdown in milliseconds (fixed number or range like '10:100')"
70    )]
71    pub slowdown: Option<String>,
72
73    #[arg(long, help = "Timeout in milliseconds")]
74    pub timeout: Option<u64>,
75}
76
77pub async fn not_found(uri: Uri) -> (axum::http::StatusCode, String) {
78    log::warn!("Path not found: {}", uri.path());
79    (axum::http::StatusCode::NOT_FOUND, "Not Found".to_string())
80}
81
82pub async fn run(args: Args) -> anyhow::Result<()> {
83    let state = ServerState::new(args.clone());
84
85    async fn slowdown(
86        State(state): State<ServerState>,
87        req: Request<axum::body::Body>,
88        next: Next,
89    ) -> Response {
90        let slowdown = state.get_slodown_ms();
91        log::debug!("Slowing down request by {}ms", slowdown);
92        tokio::time::sleep(std::time::Duration::from_millis(slowdown)).await;
93        next.run(req).await
94    }
95
96    let mut app = Router::new()
97        .route(
98            "/v1/chat/completions",
99            post(chat_completions::chat_completions),
100        )
101        .route("/v1/responses", post(responses::responses))
102        .route_layer(middleware::from_fn_with_state(state.clone(), slowdown))
103        .fallback(not_found)
104        .with_state(state);
105
106    if let Some(timeout) = args.timeout {
107        app = app.layer(TimeoutLayer::new(Duration::from_millis(timeout)));
108    }
109
110    let addr = SocketAddr::new(args.address, args.port);
111    let listener = tokio::net::TcpListener::bind(addr).await?;
112
113    println!(
114        "Roy server running on {}",
115        format!("http://{}", addr).blue()
116    );
117
118    axum::serve(listener, app).await?;
119
120    Ok(())
121}