1use axum::{
5 extract::State,
6 http::{Request, Uri},
7 middleware::{self, Next},
8 response::Response,
9 routing::post,
10 Router,
11};
12use clap::Parser;
13use clap_verbosity_flag::Verbosity;
14use colored::Colorize;
15use std::net::{IpAddr, SocketAddr};
16use std::time::Duration;
17use tower_http::timeout::TimeoutLayer;
18
19pub mod chat_completions;
20pub mod responses;
21pub mod server_state;
22use crate::server_state::ServerState;
23
24#[derive(Parser, Clone)]
25#[command(name = "roy")]
26#[command(version = env!("CARGO_PKG_VERSION"))]
27#[command(
28 about = "A HTTP server compatible with the OpenAI platform format that simulates errors and rate limit data"
29)]
30pub struct Args {
31 #[command(flatten)]
32 pub verbosity: Verbosity,
33
34 #[arg(long, help = "Port to listen on", default_value = "8000")]
35 pub port: u16,
36
37 #[arg(long, help = "Address to listen on", default_value = "127.0.0.1")]
38 pub address: IpAddr,
39
40 #[arg(
41 long,
42 help = "Length of response (fixed number or range like '10:100')",
43 default_value = "250"
44 )]
45 pub response_length: Option<String>,
46
47 #[arg(long, help = "HTTP error code to return")]
48 pub error_code: Option<u16>,
49
50 #[arg(long, help = "Error rate percentage (0-100)")]
51 pub error_rate: Option<u32>,
52
53 #[arg(
54 long,
55 help = "Maximum number of requests per minute",
56 default_value = "500"
57 )]
58 pub rpm: u32,
59
60 #[arg(
61 long,
62 help = "Maximum number of tokens per minute",
63 default_value = "30000"
64 )]
65 pub tpm: u32,
66
67 #[arg(
68 long,
69 help = "Slowdown in milliseconds (fixed number or range like '10:100')"
70 )]
71 pub slowdown: Option<String>,
72
73 #[arg(long, help = "Timeout in milliseconds")]
74 pub timeout: Option<u64>,
75}
76
77pub async fn not_found(uri: Uri) -> (axum::http::StatusCode, String) {
78 log::warn!("Path not found: {}", uri.path());
79 (axum::http::StatusCode::NOT_FOUND, "Not Found".to_string())
80}
81
82pub async fn run(args: Args) -> anyhow::Result<()> {
83 let state = ServerState::new(args.clone());
84
85 async fn slowdown(
86 State(state): State<ServerState>,
87 req: Request<axum::body::Body>,
88 next: Next,
89 ) -> Response {
90 let slowdown = state.get_slodown_ms();
91 log::debug!("Slowing down request by {}ms", slowdown);
92 tokio::time::sleep(std::time::Duration::from_millis(slowdown)).await;
93 next.run(req).await
94 }
95
96 let mut app = Router::new()
97 .route(
98 "/v1/chat/completions",
99 post(chat_completions::chat_completions),
100 )
101 .route("/v1/responses", post(responses::responses))
102 .route_layer(middleware::from_fn_with_state(state.clone(), slowdown))
103 .fallback(not_found)
104 .with_state(state);
105
106 if let Some(timeout) = args.timeout {
107 app = app.layer(TimeoutLayer::new(Duration::from_millis(timeout)));
108 }
109
110 let addr = SocketAddr::new(args.address, args.port);
111 let listener = tokio::net::TcpListener::bind(addr).await?;
112
113 println!(
114 "Roy server running on {}",
115 format!("http://{}", addr).blue()
116 );
117
118 axum::serve(listener, app).await?;
119
120 Ok(())
121}