diff --git a/Cargo.lock b/Cargo.lock index 5617eff..0060eb4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -200,6 +200,16 @@ version = "1.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ae44d1a3d5a19df61dd0c8beb138458ac2a53a7ac09eba97d55592540004306b" +[[package]] +name = "bytes" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "206fdffcfa2df7cbe15601ef46c813fce0965eb3286db6b56c583b814b51c81c" +dependencies = [ + "byteorder", + "iovec", +] + [[package]] name = "bytes" version = "1.0.1" @@ -227,6 +237,20 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "chrono" +version = "0.4.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "670ad68c9088c2a963aaa298cb369688cf3f9465ce5e2d4ca10e6e0098a1ce73" +dependencies = [ + "libc", + "num-integer", + "num-traits", + "serde", + "time", + "winapi", +] + [[package]] name = "cloudabi" version = "0.0.3" @@ -253,7 +277,7 @@ version = "4.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cc4369b5e4c0cddf64ad8981c0111e7df4f7078f4d6ba98fb31f2e17c4c57b7e" dependencies = [ - "bytes", + "bytes 1.0.1", "futures-util", "memchr", "pin-project-lite", @@ -295,7 +319,55 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dca26ee1f8d361640700bde38b2c37d8c22b3ce2d360e1fc1c74ea4b0aa7d775" dependencies = [ "cfg-if 1.0.0", - "crossbeam-utils", + "crossbeam-utils 0.8.1", +] + +[[package]] +name = "crossbeam-deque" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f02af974daeee82218205558e51ec8768b48cf524bd01d550abe5573a608285" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils 0.7.2", + "maybe-uninit", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "058ed274caafc1f60c4997b5fc07bf7dc7cca454af7c6e81edffe5f33f70dace" +dependencies = [ + "autocfg 1.0.1", + "cfg-if 0.1.10", + "crossbeam-utils 0.7.2", + "lazy_static", + "maybe-uninit", + "memoffset", + "scopeguard", +] + +[[package]] +name = "crossbeam-queue" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "774ba60a54c213d409d5353bda12d49cd68d14e45036a285234c8d6f91f92570" +dependencies = [ + "cfg-if 0.1.10", + "crossbeam-utils 0.7.2", + "maybe-uninit", +] + +[[package]] +name = "crossbeam-utils" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3c7c73a2d1e9fc0886a08b93e98eb643461230d5f1925e4036204d5f2e261a8" +dependencies = [ + "autocfg 1.0.1", + "cfg-if 0.1.10", + "lazy_static", ] [[package]] @@ -428,6 +500,12 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fed34cd105917e91daa4da6b3728c47b068749d6a62c59811f06ed2ac71d9da7" +[[package]] +name = "futures" +version = "0.1.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c7e4c2612746b0df8fed4ce0c69156021b704c9aefa360311c04e6e9e002eed" + [[package]] name = "futures" version = "0.3.12" @@ -557,7 +635,7 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6b67e66362108efccd8ac053abafc8b7a8d86a37e6e48fc4f6f7485eb5e9e6a5" dependencies = [ - "bytes", + "bytes 1.0.1", "fnv", "futures-core", "futures-sink", @@ -585,7 +663,7 @@ checksum = "62689dc57c7456e69712607ffcbd0aa1dfcccf9af73727e9b25bc1825375cac3" dependencies = [ "base64", "bitflags", - "bytes", + "bytes 1.0.1", "headers-core", "http", "mime", @@ -604,9 +682,9 @@ dependencies = [ [[package]] name = "hermit-abi" -version = "0.1.17" +version = "0.1.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5aca5565f760fb5b220e499d72710ed156fdb74e631659e99377d9ebfbd13ae8" +checksum = "322f4de77956e22ed0e5032c359a0f1273f1f7f0d79bfa3b8ffbc730d7fbcc5c" dependencies = [ "libc", ] @@ -617,7 +695,7 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7245cd7449cc792608c3c8a9eaf69bd4eabbabf802713748fd739c98b82f0747" dependencies = [ - "bytes", + "bytes 1.0.1", "fnv", "itoa", ] @@ -628,7 +706,7 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2861bd27ee074e5ee891e8b539837a9430012e249d7f0ca2d795650f579c1994" dependencies = [ - "bytes", + "bytes 1.0.1", "http", ] @@ -675,7 +753,7 @@ version = "0.14.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "12219dc884514cb4a6a03737f4413c0e01c23a1b059b0156004b23f1e19dccbe" dependencies = [ - "bytes", + "bytes 1.0.1", "futures-channel", "futures-core", "futures-util", @@ -735,7 +813,7 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f97967975f448f1a7ddb12b0bc41069d09ed6a1c161a92687e057325db35d413" dependencies = [ - "bytes", + "bytes 1.0.1", ] [[package]] @@ -747,6 +825,15 @@ dependencies = [ "cfg-if 1.0.0", ] +[[package]] +name = "iovec" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2b3ea6ff95e175473f8ffe6a7eb7c00d054240321b84c57051175fe3c1e075e" +dependencies = [ + "libc", +] + [[package]] name = "ipnet" version = "2.3.0" @@ -834,12 +921,27 @@ version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ffc5c5338469d4d3ea17d269fa8ea3512ad247247c30bd2df69e68309ed0a08" +[[package]] +name = "maybe-uninit" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60302e4db3a61da70c0cb7991976248362f30319e88850c487b9b95bbf059e00" + [[package]] name = "memchr" version = "2.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ee1c47aaa256ecabcaea351eae4a9b01ef39ed810004e298d2511ed284b1525" +[[package]] +name = "memoffset" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "043175f069eda7b85febe4a74abbaeff828d9f8b448515d3151a14a3542811aa" +dependencies = [ + "autocfg 1.0.1", +] + [[package]] name = "mime" version = "0.3.16" @@ -963,6 +1065,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "num-integer" +version = "0.1.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2cc698a63b549a70bc047073d2949cce27cd1c7b0a4a862d08a8031bc2801db" +dependencies = [ + "autocfg 1.0.1", + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.14" @@ -1440,7 +1552,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a6ddfecac9391fed21cce10e83c65fa4abafd77df05c98b1c647c65374ce9b3" dependencies = [ "async-trait", - "bytes", + "bytes 1.0.1", "combine", "dtoa", "futures-util", @@ -1502,7 +1614,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fd281b1030aa675fb90aa994d07187645bb3c8fc756ca766e7c3070b439de9de" dependencies = [ "base64", - "bytes", + "bytes 1.0.1", "encoding_rs", "futures-core", "futures-util", @@ -1579,6 +1691,19 @@ dependencies = [ "tokio-stream", ] +[[package]] +name = "rustacles-model" +version = "0.1.0" +source = "git+https://github.com/spec-tacles/rustacles#7957424348ac9f830ebdd14bd66759f73f33985a" +dependencies = [ + "chrono", + "serde", + "serde_derive", + "serde_json", + "serde_repr", + "tokio-fs", +] + [[package]] name = "rustls" version = "0.19.0" @@ -1793,7 +1918,8 @@ name = "spectacles-proxy" version = "0.1.0" dependencies = [ "anyhow", - "bytes", + "async-trait", + "bytes 1.0.1", "env_logger", "http", "humantime 2.1.0", @@ -1806,7 +1932,9 @@ dependencies = [ "reqwest", "rmp-serde", "rustacles-brokers", + "rustacles-model", "serde", + "serde_json", "serde_repr", "tokio", "tokio-stream", @@ -1940,7 +2068,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ca04cec6ff2474c638057b65798f60ac183e5e79d3448bb7163d36a39cff6ec" dependencies = [ "autocfg 1.0.1", - "bytes", + "bytes 1.0.1", "libc", "memchr", "mio", @@ -1952,6 +2080,38 @@ dependencies = [ "winapi", ] +[[package]] +name = "tokio-executor" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb2d1b8f4548dbf5e1f7818512e9c406860678f29c300cdf0ebac72d1a3a1671" +dependencies = [ + "crossbeam-utils 0.7.2", + "futures 0.1.30", +] + +[[package]] +name = "tokio-fs" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "297a1206e0ca6302a0eed35b700d292b275256f596e2f3fea7729d5e629b6ff4" +dependencies = [ + "futures 0.1.30", + "tokio-io", + "tokio-threadpool", +] + +[[package]] +name = "tokio-io" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57fc868aae093479e3131e3d165c93b1c7474109d13c90ec0dda2a1bbfff0674" +dependencies = [ + "bytes 0.4.12", + "futures 0.1.30", + "log", +] + [[package]] name = "tokio-macros" version = "1.0.0" @@ -1985,6 +2145,23 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-threadpool" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df720b6581784c118f0eb4310796b12b1d242a7eb95f716a8367855325c25f89" +dependencies = [ + "crossbeam-deque", + "crossbeam-queue", + "crossbeam-utils 0.7.2", + "futures 0.1.30", + "lazy_static", + "log", + "num_cpus", + "slab", + "tokio-executor", +] + [[package]] name = "tokio-tungstenite" version = "0.13.0" @@ -2004,7 +2181,7 @@ version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "12ae4751faa60b9f96dd8344d74592e5a17c0c9a220413dbc6942d14139bbfcc" dependencies = [ - "bytes", + "bytes 1.0.1", "futures-core", "futures-sink", "log", @@ -2073,7 +2250,7 @@ checksum = "8ada8297e8d70872fa9a551d93250a9f407beb9f37ef86494eb20012a2ff7c24" dependencies = [ "base64", "byteorder", - "bytes", + "bytes 1.0.1", "http", "httparse", "input_buffer", @@ -2193,8 +2370,8 @@ name = "warp" version = "0.2.5" source = "git+https://github.com/seanmonstar/warp.git#ffefea08050ffe8022d3391f4bd5e5ab4e95d7c9" dependencies = [ - "bytes", - "futures", + "bytes 1.0.1", + "futures 0.3.12", "headers", "http", "hyper", diff --git a/Cargo.toml b/Cargo.toml index b379fa7..18651f1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ path = "src/main.rs" [dependencies] anyhow = "1.0" +async-trait = "0.1" bytes = { version = "1.0", features = ["serde"] } env_logger = "0.7" http = "0.2" @@ -22,7 +23,9 @@ log = "0.4" prometheus = { version = "0.11", optional = true } rmp-serde = "0.14" rustacles-brokers = { git = "https://github.com/spec-tacles/rustacles" } +rustacles-model = { git = "https://github.com/spec-tacles/rustacles" } serde = "1.0" +serde_json = "1.0" serde_repr = "0.1" tokio-stream = "0.1" toml = "0.5" diff --git a/src/cache.rs b/src/cache.rs new file mode 100644 index 0000000..2024f69 --- /dev/null +++ b/src/cache.rs @@ -0,0 +1,20 @@ +use anyhow::Result; +use async_trait::async_trait; +use rustacles_model::{ + channel::Channel, guild::Guild, message::Message, presence::Presence, voice::VoiceState, + Snowflake, +}; + +pub mod redis; + +#[async_trait] +pub trait Cache { + async fn get(&self, id: Snowflake) -> Result>; + async fn save(&self, item: T) -> Result<()>; + async fn delete(&self, id: Snowflake) -> Result<()>; +} + +pub trait DiscordCache: + Cache + Cache + Cache + Cache + Cache +{ +} diff --git a/src/cache/redis.rs b/src/cache/redis.rs new file mode 100644 index 0000000..a75e863 --- /dev/null +++ b/src/cache/redis.rs @@ -0,0 +1,77 @@ +use super::{Cache, DiscordCache}; +use anyhow::Result; +use async_trait::async_trait; +use lazy_static::lazy_static; +use redis::{Client, Script}; +use rustacles_model::{channel::Channel, guild::Guild, Snowflake}; +use serde_json::{from_str, to_vec}; + +lazy_static! { + static ref SAVE_GUILD: Script = Script::new(include_str!("scripts/save_guild.lua")); + static ref DELETE_GUILD: Script = Script::new(include_str!("scripts/delete_guild.lua")); +} + +#[async_trait] +impl Cache for Client { + async fn get(&self, id: Snowflake) -> Result> { + let redis = self.clone(); + let guild_str: Option = redis::cmd("JSON.GET") + .arg(format!("guilds.{}", id)) + .arg(".") + .query_async(&mut redis.get_async_connection().await?) + .await?; + + Ok(guild_str.map(|s| from_str(&s)).transpose()?) + } + + async fn save(&self, item: Guild) -> Result<()> { + let redis = self.clone(); + let guild_vec = to_vec(&item)?; + let mut cmd = SAVE_GUILD.key(format!("guilds.{}", item.id)); + cmd.arg(guild_vec); + + for channel in item.channels { + let channel_vec = to_vec(&channel)?; + cmd.key(format!("channels.{}", channel.id)).arg(channel_vec); + } + + cmd.invoke_async::<_, redis::Value>(&mut redis.get_async_connection().await?) + .await?; + Ok(()) + } + + async fn delete(&self, id: Snowflake) -> Result<()> { + let redis = self.clone(); + + let maybe_guild: Option = Cache::::get(&redis, id).await?; + + if let Some(guild) = maybe_guild { + let mut cmd = DELETE_GUILD.key(format!("guilds.{}", guild.id)); + for channel in guild.channels { + cmd.key(format!("channels.{}", channel.id)); + } + + cmd.invoke_async::<_, redis::Value>(&mut redis.get_async_connection().await?) + .await?; + } + + Ok(()) + } +} + +#[async_trait] +impl Cache for Client { + async fn get(&self, id: Snowflake) -> Result> { + todo!() + } + + async fn save(&self, item: Channel) -> Result<()> { + todo!() + } + + async fn delete(&self, id: Snowflake) -> Result<()> { + todo!() + } +} + +// impl DiscordCache for RedisCache {} diff --git a/src/cache/scripts/delete_guild.lua b/src/cache/scripts/delete_guild.lua new file mode 100644 index 0000000..3e4b4be --- /dev/null +++ b/src/cache/scripts/delete_guild.lua @@ -0,0 +1,7 @@ +local guild_key = KEYS[1] + +redis.call("DEL", guild_key) + +for i = 1, table.maxn(ARGS) do + redis.call("DEL", KEYS[i]) +end diff --git a/src/cache/scripts/save_guild.lua b/src/cache/scripts/save_guild.lua new file mode 100644 index 0000000..2dad1df --- /dev/null +++ b/src/cache/scripts/save_guild.lua @@ -0,0 +1,8 @@ +local guild_key = KEYS[1] +local guild = ARGS[1] + +redis.call("JSON.SET", guild_key, ".", guild) + +for i = 2, table.maxn(ARGS) do + redis.call("JSON.SET", KEYS[i], ".", ARGS[i]) +end diff --git a/src/lib.rs b/src/lib.rs index 2ba0715..01e8d8d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ #![feature(iterator_fold_self)] +pub mod cache; #[cfg(feature = "metrics")] pub mod metrics; pub mod models; diff --git a/src/ratelimiter.rs b/src/ratelimiter.rs index bcf6629..49e4649 100644 --- a/src/ratelimiter.rs +++ b/src/ratelimiter.rs @@ -1,28 +1,29 @@ use anyhow::Result; +use async_trait::async_trait; use reqwest::{header::HeaderMap, Response}; -use std::{future::Future, ops::Deref, pin::Pin, str::FromStr}; +use std::{ops::Deref, str::FromStr}; pub mod local; #[cfg(feature = "redis-ratelimiter")] pub mod redis; -pub type FutureResult = Pin> + Send>>; - +#[async_trait] pub trait Ratelimiter { - fn claim(&self, bucket: String) -> FutureResult<()>; - fn release(&self, bucket: String, info: RatelimitInfo) -> FutureResult<()>; + async fn claim(&self, bucket: String) -> Result<()>; + async fn release(&self, bucket: String, info: RatelimitInfo) -> Result<()>; } +#[async_trait] impl Ratelimiter for T where - T: Deref, + T: Deref + Send + Sync + 'static, { - fn claim(&self, bucket: String) -> FutureResult<()> { - Ratelimiter::claim(self.deref(), bucket) + async fn claim(&self, bucket: String) -> Result<()> { + Ratelimiter::claim(self.deref(), bucket).await } - fn release(&self, bucket: String, info: RatelimitInfo) -> FutureResult<()> { - Ratelimiter::release(self.deref(), bucket, info) + async fn release(&self, bucket: String, info: RatelimitInfo) -> Result<()> { + Ratelimiter::release(self.deref(), bucket, info).await } } diff --git a/src/ratelimiter/local.rs b/src/ratelimiter/local.rs index 89dcb56..c39cfc0 100644 --- a/src/ratelimiter/local.rs +++ b/src/ratelimiter/local.rs @@ -1,5 +1,6 @@ -use super::{FutureResult, RatelimitInfo, Ratelimiter}; -use anyhow::anyhow; +use super::{RatelimitInfo, Ratelimiter}; +use anyhow::{anyhow, Result}; +use async_trait::async_trait; use log::debug; use std::{ collections::HashMap, @@ -40,96 +41,93 @@ pub struct LocalRatelimiter { buckets: Arc>>>, } +#[async_trait] impl Ratelimiter for LocalRatelimiter { - fn claim(&self, bucket_name: String) -> FutureResult<()> { + async fn claim(&self, bucket_name: String) -> Result<()> { let buckets = Arc::clone(&self.buckets); - Box::pin(async move { - let mut claim = buckets.write().await; - let bucket = Arc::clone(claim.entry(bucket_name.clone()).or_default()); - drop(claim); + let mut claim = buckets.write().await; + let bucket = Arc::clone(claim.entry(bucket_name.clone()).or_default()); + drop(claim); - bucket.ready.acquire().await?.forget(); + bucket.ready.acquire().await?.forget(); - debug!("Acquired lock for \"{}\"", &bucket_name); - Ok(()) - }) + debug!("Acquired lock for \"{}\"", &bucket_name); + Ok(()) } - fn release(&self, bucket_name: String, info: RatelimitInfo) -> FutureResult<()> { + async fn release(&self, bucket_name: String, info: RatelimitInfo) -> Result<()> { let buckets = Arc::clone(&self.buckets); let now = Instant::now(); - Box::pin(async move { - debug!("Releasing \"{}\"", &bucket_name); + debug!("Releasing \"{}\"", &bucket_name); - let bucket = Arc::clone( - buckets - .read() - .await - .get(&bucket_name) - .ok_or(anyhow!("Attempted to release before claim"))?, - ); + let bucket = Arc::clone( + buckets + .read() + .await + .get(&bucket_name) + .ok_or(anyhow!("Attempted to release before claim"))?, + ); - let mut maybe_sender = bucket.new_timeout.lock().await; + let mut maybe_sender = bucket.new_timeout.lock().await; - if let None = &*maybe_sender { - debug!("No timeout: releasing \"{}\" immediately", &bucket_name); - bucket.ready.add_permits(1); - } + if let None = &*maybe_sender { + debug!("No timeout: releasing \"{}\" immediately", &bucket_name); + bucket.ready.add_permits(1); + } + + if let Some(resets_in) = info.resets_in { + let duration = Duration::from_millis(resets_in); + + debug!( + "Received timeout of {:?} for \"{}\"", + duration, &bucket_name + ); - if let Some(resets_in) = info.resets_in { - let duration = Duration::from_millis(resets_in); - - debug!( - "Received timeout of {:?} for \"{}\"", - duration, &bucket_name - ); - - match &mut *maybe_sender { - Some(sender) => { - debug!("Resetting expiration for \"{}\"", &bucket_name); - sender.send(now + duration).await?; - } - None => { - debug!("Creating new expiration for \"{}\"", &bucket_name); - let mut delay = sleep(duration); - let (sender, mut receiver) = mpsc::channel(1); - let timeout_bucket = Arc::clone(&bucket); - let bucket_name = bucket_name.clone(); - spawn(async move { - loop { - select! { - Some(new_instant) = receiver.recv() => { - debug!("Updating timeout for \"{}\" to {:?}", &bucket_name, new_instant); - delay = sleep_until(new_instant); - }, - _ = delay => { - debug!("Releasing \"{}\" after timeout", &bucket_name); - let size = timeout_bucket.size.load(Ordering::SeqCst); - timeout_bucket.ready.add_permits(size); - *timeout_bucket.new_timeout.lock().await = None; - break; - } + match &mut *maybe_sender { + Some(sender) => { + debug!("Resetting expiration for \"{}\"", &bucket_name); + sender.send(now + duration).await?; + } + None => { + debug!("Creating new expiration for \"{}\"", &bucket_name); + let mut delay = sleep(duration); + let (sender, mut receiver) = mpsc::channel(1); + let timeout_bucket = Arc::clone(&bucket); + let bucket_name = bucket_name.clone(); + spawn(async move { + loop { + select! { + Some(new_instant) = receiver.recv() => { + debug!("Updating timeout for \"{}\" to {:?}", &bucket_name, new_instant); + delay = sleep_until(new_instant); + }, + _ = delay => { + debug!("Releasing \"{}\" after timeout", &bucket_name); + let size = timeout_bucket.size.load(Ordering::SeqCst); + timeout_bucket.ready.add_permits(size); + *timeout_bucket.new_timeout.lock().await = None; + break; } } - }); - *maybe_sender = Some(sender); - } + } + }); + *maybe_sender = Some(sender); } } + } - if let Some(size) = info.limit { - let old_size = bucket.size.swap(size, Ordering::SeqCst); - let diff = size - old_size; - debug!( - "New bucket size for \"{}\": {} (changing permits by {})", - &bucket_name, size, diff - ); - bucket.ready.add_permits(diff); - } + if let Some(size) = info.limit { + let old_size = bucket.size.swap(size, Ordering::SeqCst); + let diff = size - old_size; + debug!( + "New bucket size for \"{}\": {} (changing permits by {})", + &bucket_name, size, diff + ); + bucket.ready.add_permits(diff); + } - Ok(()) - }) + Ok(()) } } diff --git a/src/ratelimiter/redis.rs b/src/ratelimiter/redis.rs index 9591c2c..3050ee0 100644 --- a/src/ratelimiter/redis.rs +++ b/src/ratelimiter/redis.rs @@ -1,5 +1,6 @@ -use super::{FutureResult, RatelimitInfo, Ratelimiter}; +use super::{RatelimitInfo, Ratelimiter}; use anyhow::Result; +use async_trait::async_trait; use lazy_static::lazy_static; use log::debug; use redis::{ @@ -50,55 +51,52 @@ impl RedisRatelimiter { } } +#[async_trait] impl Ratelimiter for RedisRatelimiter { - fn claim(&self, bucket: String) -> FutureResult<()> { + async fn claim(&self, bucket: String) -> Result<()> { let mut conn = self.redis.clone(); let mut rcv = self.ready_publisher.subscribe(); - Box::pin(async move { - 'outer: loop { - let expiration: isize = CLAIM_SCRIPT - .key(&bucket) - .key(bucket.to_string() + "_size") - .invoke_async(&mut conn) - .await?; - - debug!("Received expiration of {}ms for \"{}\"", expiration, bucket); - - if expiration.is_positive() { - tokio::time::sleep(Duration::from_millis(expiration as u64)).await; - continue; - } + 'outer: loop { + let expiration: isize = CLAIM_SCRIPT + .key(&bucket) + .key(bucket.to_string() + "_size") + .invoke_async(&mut conn) + .await?; - if expiration == 0 { - break; - } + debug!("Received expiration of {}ms for \"{}\"", expiration, bucket); + + if expiration.is_positive() { + tokio::time::sleep(Duration::from_millis(expiration as u64)).await; + continue; + } - loop { - let opened_bucket = rcv.recv().await?; - if opened_bucket == bucket { - continue 'outer; - } + if expiration == 0 { + break; + } + + loop { + let opened_bucket = rcv.recv().await?; + if opened_bucket == bucket { + continue 'outer; } } + } - Ok(()) - }) + Ok(()) } - fn release(&self, bucket: String, info: RatelimitInfo) -> FutureResult<()> { + async fn release(&self, bucket: String, info: RatelimitInfo) -> Result<()> { let mut conn = self.redis.clone(); - Box::pin(async move { - RELEASE_SCRIPT - .key(&bucket) - .key(bucket.to_string() + "_size") - .key(NOTIFY_KEY) - .arg(info.limit.unwrap_or(0)) - .arg(info.resets_in.unwrap_or(0)) - .invoke_async(&mut conn) - .await?; - - Ok(()) - }) + RELEASE_SCRIPT + .key(&bucket) + .key(bucket.to_string() + "_size") + .key(NOTIFY_KEY) + .arg(info.limit.unwrap_or(0)) + .arg(info.resets_in.unwrap_or(0)) + .invoke_async(&mut conn) + .await?; + + Ok(()) } }