diff --git a/README.md b/README.md index 512f3f9..813eec9 100644 --- a/README.md +++ b/README.md @@ -510,8 +510,6 @@ Usage of /home/user/go/bin/dumbproxy: Unix domain socket to listen to, overrides bind-address if set. -bw-limit uint per-user bandwidth limit in bytes per second - -bw-limit-buckets uint - number of buckets of bandwidth limit (default 1048576) -bw-limit-burst int allowed burst size for bandwidth limit, how many "tokens" can fit into leaky bucket -bw-limit-separate diff --git a/forward/bwlimit.go b/forward/bwlimit.go index e3ec15b..81567af 100644 --- a/forward/bwlimit.go +++ b/forward/bwlimit.go @@ -4,44 +4,54 @@ import ( "context" "errors" "io" + "math/rand/v2" + "sync" "time" - "github.com/zeebo/xxh3" + "github.com/ajwerner/orderstat" "github.com/SenseUnit/dumbproxy/rate" ) const copyChunkSize = 128 * 1024 +type treeItem struct { + key string + mux sync.RWMutex + ul *rate.Limiter + dl *rate.Limiter +} + +func (i *treeItem) Less(other orderstat.Item) bool { + return other.(*treeItem).key > i.key +} + +func (i *treeItem) rLock() { + i.mux.RLock() +} + +func (i *treeItem) rUnlock() { + i.mux.RUnlock() +} + +func (i *treeItem) tryLock() bool { + return i.mux.TryLock() +} + type BWLimit struct { - limit rate.Limit - burst int64 - d []rate.Limiter - u []rate.Limiter + mux sync.Mutex + m *orderstat.Tree + bps float64 + burst int64 + separate bool } -func NewBWLimit(bytesPerSecond float64, burst int64, buckets uint, separate bool) *BWLimit { - if buckets == 0 { - buckets = 1 - } - burst = max(copyChunkSize, burst) - lim := *(rate.NewLimiter(burst)) - d := make([]rate.Limiter, buckets) - for i := range d { - d[i] = lim - } - u := d - if separate { - u = make([]rate.Limiter, buckets) - for i := range u { - u[i] = lim - } - } +func NewBWLimit(bytesPerSecond float64, burst int64, separate bool) *BWLimit { return &BWLimit{ - limit: rate.Limit(bytesPerSecond), - burst: burst, - d: d, - u: u, + m: orderstat.NewTree(), + bps: bytesPerSecond, + burst: burst, + 
separate: separate, } } @@ -53,7 +63,7 @@ func (l *BWLimit) copy(ctx context.Context, rl *rate.Limiter, dst io.Writer, src var n int64 for { t := time.Now() - r := rl.ReserveN(l.limit, l.burst, t, copyChunkSize) + r := rl.ReserveN(t, copyChunkSize) if !r.OK() { err = errors.New("can't get rate limit reservation") return @@ -70,9 +80,9 @@ n, err = io.Copy(dst, lim) written += n if n < copyChunkSize { - r.CancelAt(l.limit, l.burst, t) + r.CancelAt(t) if n > 0 { - rl.ReserveN(l.limit, l.burst, t, n) + rl.ReserveN(t, n) } } if err != nil { @@ -104,21 +114,71 @@ func (l *BWLimit) futureCopyAndCloseWrite(ctx context.Context, c chan<- error, r close(c) } -func (l *BWLimit) getRatelimiters(username string) (*rate.Limiter, *rate.Limiter) { - idx := int(hashUsername(username, uint64(len(l.d)))) - return &(l.d[idx]), &(l.u[idx]) +func (l *BWLimit) newTreeItem(username string) *treeItem { + ul := rate.NewLimiter(rate.Limit(l.bps), max(copyChunkSize, l.burst)) + dl := ul + if l.separate { + dl = rate.NewLimiter(rate.Limit(l.bps), max(copyChunkSize, l.burst)) + } + return &treeItem{ + key: username, + ul: ul, + dl: dl, + } +} + +const randomEvictions = 2 + +func (l *BWLimit) evictRandom() { + for range randomEvictions { + n := l.m.Len() + if n == 0 { + return + } + item := l.m.Select(rand.IntN(n)) + if item == nil { + panic("random tree sampling failed") + } + ti := item.(*treeItem) + if ti.tryLock() { + if ti.ul.Tokens() >= float64(ti.ul.Burst()) && ti.dl.Tokens() >= float64(ti.dl.Burst()) { + // RL is full and nobody touches it. Removing... 
+ l.m.Delete(item) + } + ti.mux.Unlock() + } +} + +func (l *BWLimit) getRatelimiters(username string) *treeItem { + l.mux.Lock() + defer l.mux.Unlock() + item := l.m.Get(&treeItem{ + key: username, + }) + if item == nil { + ti := l.newTreeItem(username) + ti.rLock() + l.m.ReplaceOrInsert(ti) + l.evictRandom() + return ti + } + ti := item.(*treeItem) + ti.rLock() + return ti } func (l *BWLimit) PairConnections(ctx context.Context, username string, incoming, outgoing io.ReadWriteCloser) error { - dl, ul := l.getRatelimiters(username) + ti := l.getRatelimiters(username) + defer ti.rUnlock() var err error i2oErr := make(chan error, 1) o2iErr := make(chan error, 1) ctxErr := ctx.Done() - go l.futureCopyAndCloseWrite(ctx, i2oErr, ul, outgoing, incoming) - go l.futureCopyAndCloseWrite(ctx, o2iErr, dl, incoming, outgoing) + go l.futureCopyAndCloseWrite(ctx, i2oErr, ti.ul, outgoing, incoming) + go l.futureCopyAndCloseWrite(ctx, o2iErr, ti.dl, incoming, outgoing) // do while we're listening to children channels for i2oErr != nil || o2iErr != nil { @@ -145,33 +205,3 @@ return err } - -func hashUsername(s string, nslots uint64) uint64 { - if nslots == 0 { - panic("number of slots can't be zero") - } - - hash := xxh3.New() - iv := []byte{0} - - if nslots&(nslots-1) == 0 { - hash.Write(iv) - hash.Write([]byte(s)) - return hash.Sum64() & (nslots - 1) - } - - minBiased := -((-nslots) % nslots) // == 2**64 - (2**64%nslots) - - var hv uint64 - for { - hash.Write(iv) - hash.Write([]byte(s)) - hv = hash.Sum64() - if hv < minBiased { - break - } - iv[0]++ - hash.Reset() - } - return hv % nslots -} diff --git a/go.mod b/go.mod index 9441c9d..a968ec6 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ toolchain go1.24.6 require ( github.com/Snawoot/uniqueslice v0.1.1 + github.com/ajwerner/orderstat v0.0.0-20200914031159-0ebfd67afbea github.com/coreos/go-systemd/v22 v22.6.0 github.com/dop251/goja
v0.0.0-20251103141225-af2ceb9156d7 github.com/hashicorp/go-multierror v1.1.1 @@ -17,7 +18,6 @@ require ( github.com/refraction-networking/utls v1.8.1 github.com/tg123/go-htpasswd v1.2.4 github.com/things-go/go-socks5 v0.1.0 - github.com/zeebo/xxh3 v1.0.2 golang.org/x/crypto v0.44.0 golang.org/x/crypto/x509roots/fallback v0.0.0-20251112184832-bcf6a849efcf golang.org/x/net v0.47.0 @@ -34,7 +34,6 @@ require ( github.com/google/pprof v0.0.0-20251114195745-4902fdda35c8 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/klauspost/compress v1.18.1 // indirect - github.com/klauspost/cpuid/v2 v2.3.0 // indirect golang.org/x/sys v0.38.0 // indirect golang.org/x/term v0.37.0 // indirect golang.org/x/text v0.31.0 // indirect diff --git a/go.sum b/go.sum index 2412b2c..ba0875b 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,8 @@ github.com/Masterminds/semver/v3 v3.2.1 h1:RN9w6+7QoMeJVGyfmbcgs28Br8cvmnucEXnY0 github.com/Masterminds/semver/v3 v3.2.1/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ= github.com/Snawoot/uniqueslice v0.1.1 h1:KEfv3FtAXiNEoxvcc79pFQDhnqwYXQyZIkxOM4e/qpw= github.com/Snawoot/uniqueslice v0.1.1/go.mod h1:K9zIaHO43FGLHbqm6WCDFeY6+CN/du5eiio/vxvDVC8= +github.com/ajwerner/orderstat v0.0.0-20200914031159-0ebfd67afbea h1:eCQW3axgFSgzerNCCUc9E3W8sHIsol3D84SoxqLtRd0= +github.com/ajwerner/orderstat v0.0.0-20200914031159-0ebfd67afbea/go.mod h1:lBcZZIGVJZdjAenFzeOISQY+Gr2lTavurpq4QuJCdog= github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= @@ -14,6 +16,7 @@ github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UF github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/coreos/go-systemd/v22 v22.6.0 h1:aGVa/v8B7hpb0TKl0MWoAavPDmHvobFe5R5zn0bCJWo= 
github.com/coreos/go-systemd/v22 v22.6.0/go.mod h1:iG+pp635Fo7ZmV/j14KUcmEyWF+0X7Lua8rrTWzYgWU= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= @@ -24,6 +27,7 @@ github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7 h1:jxmXU5V9tXxJnydU5v/ github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7/go.mod h1:MxLav0peU43GgvwVgNbLAj1s/bSGboKkhuULvq/7hx4= github.com/go-sourcemap/sourcemap v2.1.4+incompatible h1:a+iTbH5auLKxaNwQFg0B+TCYl6lbukKPc7b5x0n1s6Q= github.com/go-sourcemap/sourcemap v2.1.4+incompatible/go.mod h1:F8jJfvm2KbVjc5NqelyYJmf/v5J0dwNLS2mL4sNA1Jg= +github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/pprof v0.0.0-20251114195745-4902fdda35c8 h1:3DsUAV+VNEQa2CUVLxCY3f87278uWfIDhJnbdvDjvmE= github.com/google/pprof v0.0.0-20251114195745-4902fdda35c8/go.mod h1:I6V7YzU0XDpsHqbsyrghnFZLO1gwK6NPTNvmetQIk9U= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= @@ -35,12 +39,11 @@ github.com/jellydator/ttlcache/v3 v3.4.0 h1:YS4P125qQS0tNhtL6aeYkheEaB/m8HCqdMMP github.com/jellydator/ttlcache/v3 v3.4.0/go.mod h1:Hw9EgjymziQD3yGsQdf1FqFdpp7YjFMd4Srg5EJlgD4= github.com/klauspost/compress v1.18.1 h1:bcSGx7UbpBqMChDtsF28Lw6v/G94LPrrbMbdC3JH2co= github.com/klauspost/compress v1.18.1/go.mod h1:ZQFFVG+MdnR0P+l6wpXgIL4NTtwiKIdBnrBd8Nrxr+0= -github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= -github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= github.com/libp2p/go-reuseport v0.4.0 h1:nR5KU7hD0WxXCJbmw7r2rhRYruNRl2koHw8fQscQm2s= github.com/libp2p/go-reuseport v0.4.0/go.mod 
h1:ZtI03j/wO5hZVDFo2jKywN6bYKWLOy8Se6DrI2E1cLU= github.com/ncruces/go-dns v1.2.7 h1:NMA7vFqXUl+nBhGFlleLyo2ni3Lqv3v+qFWZidzRemI= github.com/ncruces/go-dns v1.2.7/go.mod h1:SqmhVMBd8Wr7hsu3q6yTt6/Jno/xLMrbse/JLOMBo1Y= +github.com/petar/GoLLRB v0.0.0-20130427215148-53be0d36a84c/go.mod h1:HUpKUBZnpzkdx0kD/+Yfuft+uD3zHGtXF/XJB14TUr4= github.com/pires/go-proxyproto v0.8.1 h1:9KEixbdJfhrbtjpz/ZwCdWDD2Xem0NZ38qMYaASJgp0= github.com/pires/go-proxyproto v0.8.1/go.mod h1:ZKAAyp3cgy5Y5Mo4n9AlScrkCZwUy0g3Jf+slqQVcuU= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -49,6 +52,8 @@ github.com/redis/go-redis/v9 v9.16.0 h1:OotgqgLSRCmzfqChbQyG1PHC3tLNR89DG4jdOERS github.com/redis/go-redis/v9 v9.16.0/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370= github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkrMJI0pRUOCAo= github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.11.0 h1:ib4sjIrwZKxE5u/Japgo/7SJV3PvgjGiRNAvTVGqQl8= github.com/stretchr/testify v1.11.0/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/tg123/go-htpasswd v1.2.4 h1:HgH8KKCjdmo7jjXWN9k1nefPBd7Be3tFCTjc2jPraPU= @@ -57,10 +62,6 @@ github.com/things-go/go-socks5 v0.1.0 h1:4f5dz0iMQ6cA4wseFmyLmCHmg3SWJTW92ndrKS6 github.com/things-go/go-socks5 v0.1.0/go.mod h1:Riabiyu52kLsla0YmJqunt1c1JEl6iXSr4bRd7swFEA= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= -github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ= -github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0= -github.com/zeebo/xxh3 v1.0.2 
h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= -github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= golang.org/x/crypto v0.44.0 h1:A97SsFvM3AIwEEmTBiaxPPTYpDC47w720rdiiUvgoAU= diff --git a/main.go b/main.go index ec347b4..b58d731 100644 --- a/main.go +++ b/main.go @@ -321,7 +321,6 @@ type CLIArgs struct { tlsALPNEnabled bool bwLimit uint64 bwBurst int64 - bwBuckets uint bwSeparate bool dnsServers []string dnsPreferAddress dnsPreferenceArg @@ -447,7 +446,6 @@ func parse_args() *CLIArgs { flag.BoolVar(&args.tlsALPNEnabled, "tls-alpn-enabled", true, "enable application protocol negotiation with TLS ALPN extension") flag.Uint64Var(&args.bwLimit, "bw-limit", 0, "per-user bandwidth limit in bytes per second") flag.Int64Var(&args.bwBurst, "bw-limit-burst", 0, "allowed burst size for bandwidth limit, how many \"tokens\" can fit into leaky bucket") - flag.UintVar(&args.bwBuckets, "bw-limit-buckets", 1024*1024, "number of buckets of bandwidth limit") flag.BoolVar(&args.bwSeparate, "bw-limit-separate", false, "separate upload and download bandwidth limits") flag.Func("dns-server", "nameserver specification (udp://..., tcp://..., https://..., tls://..., doh://..., dot://..., default://). Option can be used multiple times for parallel use of multiple nameservers. Empty string resets the list", func(p string) error { if p == "" { @@ -646,7 +644,6 @@ func run() int { forwarder = forward.NewBWLimit( float64(args.bwLimit), args.bwBurst, - args.bwBuckets, args.bwSeparate, ).PairConnections } diff --git a/rate/rate.go b/rate/rate.go index 71bdb43..c70f502 100644 --- a/rate/rate.go +++ b/rate/rate.go @@ -56,6 +56,8 @@ func Every(interval time.Duration) Limit { // Limiter is safe for simultaneous use by multiple goroutines. 
type Limiter struct { mu sync.Mutex + limit Limit + burst int64 tokens float64 // last is the last time the limiter's tokens field was updated last int64 @@ -63,37 +65,56 @@ type Limiter struct { lastEvent int64 } +// Limit returns the maximum overall event rate. +func (lim *Limiter) Limit() Limit { + lim.mu.Lock() + defer lim.mu.Unlock() + return lim.limit +} + +// Burst returns the maximum burst size. Burst is the maximum number of tokens +// that can be consumed in a single call to Allow, Reserve, or Wait, so higher +// Burst values allow more events to happen at once. +// A zero Burst allows no events, unless limit == Inf. +func (lim *Limiter) Burst() int64 { + lim.mu.Lock() + defer lim.mu.Unlock() + return lim.burst +} + // TokensAt returns the number of tokens available at time t. -func (lim *Limiter) TokensAt(limit Limit, burst int64, t time.Time) float64 { +func (lim *Limiter) TokensAt(t time.Time) float64 { lim.mu.Lock() - tokens := lim.advance(limit, burst, t) // does not mutate lim + tokens := lim.advance(t) // does not mutate lim lim.mu.Unlock() return tokens } // Tokens returns the number of tokens available now. -func (lim *Limiter) Tokens(limit Limit, burst int64) float64 { - return lim.TokensAt(limit, burst, time.Now()) +func (lim *Limiter) Tokens() float64 { + return lim.TokensAt(time.Now()) } // NewLimiter returns a new Limiter that allows events up to rate r and permits // bursts of at most b tokens. -func NewLimiter(b int64) *Limiter { +func NewLimiter(r Limit, b int64) *Limiter { return &Limiter{ + limit: r, + burst: b, tokens: float64(b), } } // Allow reports whether an event may happen now. -func (lim *Limiter) Allow(limit Limit, burst int64) bool { - return lim.AllowN(limit, burst, time.Now(), 1) +func (lim *Limiter) Allow() bool { + return lim.AllowN(time.Now(), 1) } // AllowN reports whether n events may happen at time t. // Use this method if you intend to drop / skip events that exceed the rate limit. // Otherwise use Reserve or Wait. 
-func (lim *Limiter) AllowN(limit Limit, burst int64, t time.Time, n int64) bool { - return lim.reserveN(limit, burst, t, n, 0).ok +func (lim *Limiter) AllowN(t time.Time, n int64) bool { + return lim.reserveN(t, n, 0).ok } // A Reservation holds information about events that are permitted by a Limiter to happen after a delay. @@ -138,14 +159,14 @@ func (r *Reservation) DelayFrom(t time.Time) time.Duration { } // Cancel is shorthand for CancelAt(time.Now()). -func (r *Reservation) Cancel(limit Limit, burst int64) { - r.CancelAt(limit, burst, time.Now()) +func (r *Reservation) Cancel() { + r.CancelAt(time.Now()) } // CancelAt indicates that the reservation holder will not perform the reserved action // and reverses the effects of this Reservation on the rate limit as much as possible, // considering that other reservations may have already been made. -func (r *Reservation) CancelAt(limit Limit, burst int64, t time.Time) { +func (r *Reservation) CancelAt(t time.Time) { if !r.ok { return } @@ -153,7 +174,7 @@ func (r *Reservation) CancelAt(limit Limit, burst int64, t time.Time) { r.lim.mu.Lock() defer r.lim.mu.Unlock() - if limit == Inf || r.tokens == 0 || time.Unix(0, r.timeToAct).Before(t) { + if r.lim.limit == Inf || r.tokens == 0 || time.Unix(0, r.timeToAct).Before(t) { return } @@ -165,10 +186,10 @@ func (r *Reservation) CancelAt(limit Limit, burst int64, t time.Time) { return } // advance time to now - tokens := r.lim.advance(limit, burst, t) + tokens := r.lim.advance(t) // calculate new number of tokens tokens += restoreTokens - if burst := float64(burst); tokens > burst { + if burst := float64(r.lim.burst); tokens > burst { tokens = burst } // update state @@ -183,8 +204,8 @@ func (r *Reservation) CancelAt(limit Limit, burst int64, t time.Time) { } // Reserve is shorthand for ReserveN(time.Now(), 1). 
-func (lim *Limiter) Reserve(limit Limit, burst int64) *Reservation { - return lim.ReserveN(limit, burst, time.Now(), 1) +func (lim *Limiter) Reserve() *Reservation { + return lim.ReserveN(time.Now(), 1) } // ReserveN returns a Reservation that indicates how long the caller must wait before n events happen. @@ -203,21 +224,21 @@ func (lim *Limiter) Reserve(limit Limit, burst int64) *Reservation { // Use this method if you wish to wait and slow down in accordance with the rate limit without dropping events. // If you need to respect a deadline or cancel the delay, use Wait instead. // To drop or skip events exceeding rate limit, use Allow instead. -func (lim *Limiter) ReserveN(limit Limit, burst int64, t time.Time, n int64) *Reservation { - r := lim.reserveN(limit, burst, t, n, InfDuration) +func (lim *Limiter) ReserveN(t time.Time, n int64) *Reservation { + r := lim.reserveN(t, n, InfDuration) return &r } // Wait is shorthand for WaitN(ctx, 1). -func (lim *Limiter) Wait(ctx context.Context, limit Limit, burst int64) (err error) { - return lim.WaitN(ctx, limit, burst, 1) +func (lim *Limiter) Wait(ctx context.Context) (err error) { + return lim.WaitN(ctx, 1) } // WaitN blocks until lim permits n events to happen. // It returns an error if n exceeds the Limiter's burst size, the Context is // canceled, or the expected wait time exceeds the Context's Deadline. // The burst limit is ignored if the rate limit is Inf. -func (lim *Limiter) WaitN(ctx context.Context, limit Limit, burst int64, n int64) (err error) { +func (lim *Limiter) WaitN(ctx context.Context, n int64) (err error) { // The test code calls lim.wait with a fake timer generator. // This is the real timer generator. 
newTimer := func(d time.Duration) (<-chan time.Time, func() bool, func()) { @@ -225,11 +246,16 @@ func (lim *Limiter) WaitN(ctx context.Context, limit Limit, burst int64, n int64 return timer.C, timer.Stop, func() {} } - return lim.wait(ctx, limit, burst, n, time.Now(), newTimer) + return lim.wait(ctx, n, time.Now(), newTimer) } // wait is the internal implementation of WaitN. -func (lim *Limiter) wait(ctx context.Context, limit Limit, burst int64, n int64, t time.Time, newTimer func(d time.Duration) (<-chan time.Time, func() bool, func())) error { +func (lim *Limiter) wait(ctx context.Context, n int64, t time.Time, newTimer func(d time.Duration) (<-chan time.Time, func() bool, func())) error { + lim.mu.Lock() + burst := lim.burst + limit := lim.limit + lim.mu.Unlock() + if n > burst && limit != Inf { return fmt.Errorf("rate: Wait(n=%d) exceeds limiter's burst %d", n, burst) } @@ -245,7 +271,7 @@ func (lim *Limiter) wait(ctx context.Context, limit Limit, burst int64, n int64, waitLimit = deadline.Sub(t) } // Reserve - r := lim.reserveN(limit, burst, t, n, waitLimit) + r := lim.reserveN(t, n, waitLimit) if !r.ok { return fmt.Errorf("rate: Wait(n=%d) would exceed context deadline", n) } @@ -264,19 +290,55 @@ func (lim *Limiter) wait(ctx context.Context, limit Limit, burst int64, n int64, case <-ctx.Done(): // Context was canceled before we could proceed. Cancel the // reservation, which may permit other events to proceed sooner. - r.Cancel(limit, burst) + r.Cancel() return ctx.Err() } } +// SetLimit is shorthand for SetLimitAt(time.Now(), newLimit). +func (lim *Limiter) SetLimit(newLimit Limit) { + lim.SetLimitAt(time.Now(), newLimit) +} + +// SetLimitAt sets a new Limit for the limiter. The new Limit, and Burst, may be violated +// or underutilized by those which reserved (using Reserve or Wait) but did not yet act +// before SetLimitAt was called. 
+func (lim *Limiter) SetLimitAt(t time.Time, newLimit Limit) { + lim.mu.Lock() + defer lim.mu.Unlock() + + tokens := lim.advance(t) + + lim.last = t.UnixNano() + lim.tokens = tokens + lim.limit = newLimit +} + +// SetBurst is shorthand for SetBurstAt(time.Now(), newBurst). +func (lim *Limiter) SetBurst(newBurst int64) { + lim.SetBurstAt(time.Now(), newBurst) +} + +// SetBurstAt sets a new burst size for the limiter. +func (lim *Limiter) SetBurstAt(t time.Time, newBurst int64) { + lim.mu.Lock() + defer lim.mu.Unlock() + + tokens := lim.advance(t) + + lim.last = t.UnixNano() + lim.tokens = tokens + lim.burst = newBurst +} + // reserveN is a helper method for AllowN, ReserveN, and WaitN. // maxFutureReserve specifies the maximum reservation wait duration allowed. // reserveN returns Reservation, not *Reservation, to avoid allocation in AllowN and WaitN. -func (lim *Limiter) reserveN(limit Limit, burst int64, t time.Time, n int64, maxFutureReserve time.Duration) Reservation { +func (lim *Limiter) reserveN(t time.Time, n int64, maxFutureReserve time.Duration) Reservation { lim.mu.Lock() defer lim.mu.Unlock() - if limit == Inf { + if lim.limit == Inf { return Reservation{ ok: true, lim: lim, @@ -285,7 +347,7 @@ func (lim *Limiter) reserveN(limit Limit, burst int64, t time.Time, n int64, max } } - tokens := lim.advance(limit, burst, t) + tokens := lim.advance(t) // Calculate the remaining number of tokens resulting from the request. 
tokens -= float64(n) @@ -293,17 +355,17 @@ func (lim *Limiter) reserveN(limit Limit, burst int64, t time.Time, n int64, max // Calculate the wait duration var waitDuration time.Duration if tokens < 0 { - waitDuration = limit.durationFromTokens(-tokens) + waitDuration = lim.limit.durationFromTokens(-tokens) } // Decide result - ok := n <= burst && waitDuration <= maxFutureReserve + ok := n <= lim.burst && waitDuration <= maxFutureReserve // Prepare reservation r := Reservation{ ok: ok, lim: lim, - limit: limit, + limit: lim.limit, } if ok { r.tokens = n @@ -322,7 +384,7 @@ func (lim *Limiter) reserveN(limit Limit, burst int64, t time.Time, n int64, max // resulting from the passage of time. // lim is not changed. // advance requires that lim.mu is held. -func (lim *Limiter) advance(limit Limit, burst int64, t time.Time) (newTokens float64) { +func (lim *Limiter) advance(t time.Time) (newTokens float64) { last := time.Unix(0, lim.last) if t.Before(last) { last = t @@ -330,9 +392,9 @@ func (lim *Limiter) advance(limit Limit, burst int64, t time.Time) (newTokens fl // Calculate the new number of tokens, due to time that passed. 
elapsed := t.Sub(last) - delta := limit.tokensFromDuration(elapsed) + delta := lim.limit.tokensFromDuration(elapsed) tokens := lim.tokens + delta - if burst := float64(burst); tokens > burst { + if burst := float64(lim.burst); tokens > burst { tokens = burst } return tokens diff --git a/rate/rate_test.go b/rate/rate_test.go index c08c3c6..b4c50f6 100644 --- a/rate/rate_test.go +++ b/rate/rate_test.go @@ -71,14 +71,14 @@ type allow struct { ok bool } -func run(t *testing.T, limit Limit, burst int64, lim *Limiter, allows []allow) { +func run(t *testing.T, lim *Limiter, allows []allow) { t.Helper() for i, allow := range allows { - if toks := lim.TokensAt(limit, burst, allow.t); toks != allow.toks { + if toks := lim.TokensAt(allow.t); toks != allow.toks { t.Errorf("step %d: lim.TokensAt(%v) = %v want %v", i, allow.t, toks, allow.toks) } - ok := lim.AllowN(limit, burst, allow.t, allow.n) + ok := lim.AllowN(allow.t, allow.n) if ok != allow.ok { t.Errorf("step %d: lim.AllowN(%v, %v) = %v want %v", i, allow.t, allow.n, ok, allow.ok) @@ -87,7 +87,7 @@ func run(t *testing.T, limit Limit, burst int64, lim *Limiter, allows []allow) { } func TestLimiterBurst1(t *testing.T) { - run(t, Limit(10), 1, NewLimiter(1), []allow{ + run(t, NewLimiter(10, 1), []allow{ {t0, 1, 1, true}, {t0, 0, 1, false}, {t0, 0, 1, false}, @@ -101,7 +101,7 @@ func TestLimiterBurst1(t *testing.T) { } func TestLimiterBurst3(t *testing.T) { - run(t, Limit(10), 3, NewLimiter(3), []allow{ + run(t, NewLimiter(10, 3), []allow{ {t0, 3, 2, true}, {t0, 1, 2, false}, {t0, 1, 1, true}, @@ -119,7 +119,7 @@ func TestLimiterBurst3(t *testing.T) { } func TestLimiterJumpBackwards(t *testing.T) { - run(t, Limit(10), 3, NewLimiter(3), []allow{ + run(t, NewLimiter(10, 3), []allow{ {t1, 3, 1, true}, // start at t1 {t0, 2, 1, true}, // jump back to t0, two tokens remain {t0, 1, 1, true}, @@ -138,7 +138,7 @@ func TestLimiterJumpBackwards(t *testing.T) { // rounding errors by truncating nanoseconds. 
// See golang.org/issues/34861. func TestLimiter_noTruncationErrors(t *testing.T) { - if !NewLimiter(1).Allow(0.7692307692307693, 1) { + if !NewLimiter(0.7692307692307693, 1).Allow() { t.Fatal("expected true") } } @@ -242,13 +242,13 @@ func TestSimultaneousRequests(t *testing.T) { ) // Very slow replenishing bucket. - lim := NewLimiter(burst) + lim := NewLimiter(limit, burst) // Tries to take a token, atomically updates the counter and decreases the wait // group counter. f := func() { defer wg.Done() - if ok := lim.Allow(limit, burst); ok { + if ok := lim.Allow(); ok { atomic.AddUint32(&numOK, 1) } } @@ -275,12 +275,12 @@ func TestLongRunningQPS(t *testing.T) { tt = makeTestTime(t) ) - lim := NewLimiter(burst) + lim := NewLimiter(limit, burst) start := tt.now() end := start.Add(5 * time.Second) for tt.now().Before(end) { - if ok := lim.AllowN(limit, burst, tt.now(), 1); ok { + if ok := lim.AllowN(tt.now(), 1); ok { numOK++ } @@ -322,17 +322,17 @@ func dSince(t time.Time) int { return dFromDuration(t.Sub(t0)) } -func runReserve(t *testing.T, limit Limit, burst int64, lim *Limiter, req request) *Reservation { +func runReserve(t *testing.T, lim *Limiter, req request) *Reservation { t.Helper() - return runReserveMax(t, limit, burst, lim, req, InfDuration) + return runReserveMax(t, lim, req, InfDuration) } // runReserveMax attempts to reserve req.n tokens at time req.t, limiting the delay until action to // maxReserve. It checks whether the response matches req.act and req.ok. If not, it reports a test // error including the difference from expected durations in multiples of d (global constant). 
-func runReserveMax(t *testing.T, limit Limit, burst int64, lim *Limiter, req request, maxReserve time.Duration) *Reservation { +func runReserveMax(t *testing.T, lim *Limiter, req request, maxReserve time.Duration) *Reservation { t.Helper() - r := lim.reserveN(limit, burst, req.t, req.n, maxReserve) + r := lim.reserveN(req.t, req.n, maxReserve) if r.ok && (dSince(time.Unix(0, r.timeToAct)) != dSince(req.act)) || r.ok != req.ok { t.Errorf("lim.reserveN(t%d, %v, %v) = (t%d, %v) want (t%d, %v)", dSince(req.t), req.n, maxReserve, dSince(time.Unix(0, r.timeToAct)), r.ok, dSince(req.act), req.ok) @@ -341,114 +341,142 @@ func runReserveMax(t *testing.T, limit Limit, burst int64, lim *Limiter, req req } func TestSimpleReserve(t *testing.T) { - lim := NewLimiter(2) + lim := NewLimiter(10, 2) - runReserve(t, 10, 2, lim, request{t0, 2, t0, true}) - runReserve(t, 10, 2, lim, request{t0, 2, t2, true}) - runReserve(t, 10, 2, lim, request{t3, 2, t4, true}) + runReserve(t, lim, request{t0, 2, t0, true}) + runReserve(t, lim, request{t0, 2, t2, true}) + runReserve(t, lim, request{t3, 2, t4, true}) } func TestMix(t *testing.T) { - lim := NewLimiter(2) + lim := NewLimiter(10, 2) - runReserve(t, 10, 2, lim, request{t0, 3, t1, false}) // should return false because n > Burst - runReserve(t, 10, 2, lim, request{t0, 2, t0, true}) - run(t, 10, 2, lim, []allow{{t1, 1, 2, false}}) // not enough tokens - don't allow - runReserve(t, 10, 2, lim, request{t1, 2, t2, true}) - run(t, 10, 2, lim, []allow{{t1, -1, 1, false}}) // negative tokens - don't allow - run(t, 10, 2, lim, []allow{{t3, 1, 1, true}}) + runReserve(t, lim, request{t0, 3, t1, false}) // should return false because n > Burst + runReserve(t, lim, request{t0, 2, t0, true}) + run(t, lim, []allow{{t1, 1, 2, false}}) // not enough tokens - don't allow + runReserve(t, lim, request{t1, 2, t2, true}) + run(t, lim, []allow{{t1, -1, 1, false}}) // negative tokens - don't allow + run(t, lim, []allow{{t3, 1, 1, true}}) } func 
TestCancelInvalid(t *testing.T) { - lim := NewLimiter(2) + lim := NewLimiter(10, 2) - runReserve(t, 10, 2, lim, request{t0, 2, t0, true}) - r := runReserve(t, 10, 2, lim, request{t0, 3, t3, false}) - r.CancelAt(10, 2, t0) // should have no effect - runReserve(t, 10, 2, lim, request{t0, 2, t2, true}) // did not get extra tokens + runReserve(t, lim, request{t0, 2, t0, true}) + r := runReserve(t, lim, request{t0, 3, t3, false}) + r.CancelAt(t0) // should have no effect + runReserve(t, lim, request{t0, 2, t2, true}) // did not get extra tokens } func TestCancelLast(t *testing.T) { - lim := NewLimiter(2) + lim := NewLimiter(10, 2) - runReserve(t, 10, 2, lim, request{t0, 2, t0, true}) - r := runReserve(t, 10, 2, lim, request{t0, 2, t2, true}) - r.CancelAt(10, 2, t1) // got 2 tokens back - runReserve(t, 10, 2, lim, request{t1, 2, t2, true}) + runReserve(t, lim, request{t0, 2, t0, true}) + r := runReserve(t, lim, request{t0, 2, t2, true}) + r.CancelAt(t1) // got 2 tokens back + runReserve(t, lim, request{t1, 2, t2, true}) } func TestCancelTooLate(t *testing.T) { - lim := NewLimiter(2) + lim := NewLimiter(10, 2) - runReserve(t, 10, 2, lim, request{t0, 2, t0, true}) - r := runReserve(t, 10, 2, lim, request{t0, 2, t2, true}) - r.CancelAt(10, 2, t3) // too late to cancel - should have no effect - runReserve(t, 10, 2, lim, request{t3, 2, t4, true}) + runReserve(t, lim, request{t0, 2, t0, true}) + r := runReserve(t, lim, request{t0, 2, t2, true}) + r.CancelAt(t3) // too late to cancel - should have no effect + runReserve(t, lim, request{t3, 2, t4, true}) } func TestCancel0Tokens(t *testing.T) { - lim := NewLimiter(2) + lim := NewLimiter(10, 2) - runReserve(t, 10, 2, lim, request{t0, 2, t0, true}) - r := runReserve(t, 10, 2, lim, request{t0, 1, t1, true}) - runReserve(t, 10, 2, lim, request{t0, 1, t2, true}) - r.CancelAt(10, 2, t0) // got 0 tokens back - runReserve(t, 10, 2, lim, request{t0, 1, t3, true}) + runReserve(t, lim, request{t0, 2, t0, true}) + r := runReserve(t, lim, 
request{t0, 1, t1, true}) + runReserve(t, lim, request{t0, 1, t2, true}) + r.CancelAt(t0) // got 0 tokens back + runReserve(t, lim, request{t0, 1, t3, true}) } func TestCancel1Token(t *testing.T) { - lim := NewLimiter(2) + lim := NewLimiter(10, 2) - runReserve(t, 10, 2, lim, request{t0, 2, t0, true}) - r := runReserve(t, 10, 2, lim, request{t0, 2, t2, true}) - runReserve(t, 10, 2, lim, request{t0, 1, t3, true}) - r.CancelAt(10, 2, t2) // got 1 token back - runReserve(t, 10, 2, lim, request{t2, 2, t4, true}) + runReserve(t, lim, request{t0, 2, t0, true}) + r := runReserve(t, lim, request{t0, 2, t2, true}) + runReserve(t, lim, request{t0, 1, t3, true}) + r.CancelAt(t2) // got 1 token back + runReserve(t, lim, request{t2, 2, t4, true}) } func TestCancelMulti(t *testing.T) { - lim := NewLimiter(4) + lim := NewLimiter(10, 4) - runReserve(t, 10, 4, lim, request{t0, 4, t0, true}) - rA := runReserve(t, 10, 4, lim, request{t0, 3, t3, true}) - runReserve(t, 10, 4, lim, request{t0, 1, t4, true}) - rC := runReserve(t, 10, 4, lim, request{t0, 1, t5, true}) - rC.CancelAt(10, 4, t1) // get 1 token back - rA.CancelAt(10, 4, t1) // get 2 tokens back, as if C was never reserved - runReserve(t, 10, 4, lim, request{t1, 3, t5, true}) + runReserve(t, lim, request{t0, 4, t0, true}) + rA := runReserve(t, lim, request{t0, 3, t3, true}) + runReserve(t, lim, request{t0, 1, t4, true}) + rC := runReserve(t, lim, request{t0, 1, t5, true}) + rC.CancelAt(t1) // get 1 token back + rA.CancelAt(t1) // get 2 tokens back, as if C was never reserved + runReserve(t, lim, request{t1, 3, t5, true}) } func TestReserveJumpBack(t *testing.T) { - lim := NewLimiter(2) + lim := NewLimiter(10, 2) - runReserve(t, 10, 2, lim, request{t1, 2, t1, true}) // start at t1 - runReserve(t, 10, 2, lim, request{t0, 1, t1, true}) // should violate Limit,Burst - runReserve(t, 10, 2, lim, request{t2, 2, t3, true}) + runReserve(t, lim, request{t1, 2, t1, true}) // start at t1 + runReserve(t, lim, request{t0, 1, t1, true}) // 
should violate Limit,Burst + runReserve(t, lim, request{t2, 2, t3, true}) // burst size is 2, so n=3 always fails, and the state of lim should not be changed - runReserve(t, 10, 2, lim, request{t0, 3, time.Time{}, false}) - runReserve(t, 10, 2, lim, request{t2, 1, t4, true}) + runReserve(t, lim, request{t0, 3, time.Time{}, false}) + runReserve(t, lim, request{t2, 1, t4, true}) // the maxReserve is not enough so it fails, and the state of lim should not be changed - runReserveMax(t, 10, 2, lim, request{t0, 2, time.Time{}, false}, d) - runReserve(t, 10, 2, lim, request{t2, 1, t5, true}) + runReserveMax(t, lim, request{t0, 2, time.Time{}, false}, d) + runReserve(t, lim, request{t2, 1, t5, true}) } func TestReserveJumpBackCancel(t *testing.T) { - lim := NewLimiter(2) + lim := NewLimiter(10, 2) - runReserve(t, 10, 2, lim, request{t1, 2, t1, true}) // start at t1 - r := runReserve(t, 10, 2, lim, request{t1, 2, t3, true}) - runReserve(t, 10, 2, lim, request{t1, 1, t4, true}) - r.CancelAt(10, 2, t0) // cancel at t0, get 1 token back - runReserve(t, 10, 2, lim, request{t1, 2, t4, true}) // should violate Limit,Burst + runReserve(t, lim, request{t1, 2, t1, true}) // start at t1 + r := runReserve(t, lim, request{t1, 2, t3, true}) + runReserve(t, lim, request{t1, 1, t4, true}) + r.CancelAt(t0) // cancel at t0, get 1 token back + runReserve(t, lim, request{t1, 2, t4, true}) // should violate Limit,Burst +} + +func TestReserveSetLimit(t *testing.T) { + lim := NewLimiter(5, 2) + + runReserve(t, lim, request{t0, 2, t0, true}) + runReserve(t, lim, request{t0, 2, t4, true}) + lim.SetLimitAt(t2, 10) + runReserve(t, lim, request{t2, 1, t4, true}) // violates Limit and Burst +} + +func TestReserveSetBurst(t *testing.T) { + lim := NewLimiter(5, 2) + + runReserve(t, lim, request{t0, 2, t0, true}) + runReserve(t, lim, request{t0, 2, t4, true}) + lim.SetBurstAt(t3, 4) + runReserve(t, lim, request{t0, 4, t9, true}) // violates Limit and Burst +} + +func TestReserveSetLimitCancel(t 
*testing.T) { + lim := NewLimiter(5, 2) + + runReserve(t, lim, request{t0, 2, t0, true}) + r := runReserve(t, lim, request{t0, 2, t4, true}) + lim.SetLimitAt(t2, 10) + r.CancelAt(t2) // 2 tokens back + runReserve(t, lim, request{t2, 2, t3, true}) } func TestReserveMax(t *testing.T) { - lim := NewLimiter(2) + lim := NewLimiter(10, 2) maxT := d - runReserveMax(t, 10, 2, lim, request{t0, 2, t0, true}, maxT) - runReserveMax(t, 10, 2, lim, request{t0, 1, t1, true}, maxT) // reserve for close future - runReserveMax(t, 10, 2, lim, request{t0, 1, t2, false}, maxT) // time to act too far in the future + runReserveMax(t, lim, request{t0, 2, t0, true}, maxT) + runReserveMax(t, lim, request{t0, 1, t1, true}, maxT) // reserve for close future + runReserveMax(t, lim, request{t0, 1, t2, false}, maxT) // time to act too far in the future } type wait struct { @@ -459,10 +487,10 @@ type wait struct { nilErr bool } -func runWait(t *testing.T, tt *testTime, limit Limit, burst int64, lim *Limiter, w wait) { +func runWait(t *testing.T, tt *testTime, lim *Limiter, w wait) { t.Helper() start := tt.now() - err := lim.wait(w.ctx, limit, burst, w.n, start, tt.newTimer) + err := lim.wait(w.ctx, w.n, start, tt.newTimer) delay := tt.since(start) if (w.nilErr && err != nil) || (!w.nilErr && err == nil) || !waitDelayOk(w.delay, delay) { @@ -510,100 +538,120 @@ func waitDelayOk(wantD int, got time.Duration) bool { func TestWaitSimple(t *testing.T) { tt := makeTestTime(t) - lim := NewLimiter(3) + lim := NewLimiter(10, 3) ctx, cancel := context.WithCancel(context.Background()) cancel() - runWait(t, tt, 10, 3, lim, wait{"already-cancelled", ctx, 1, 0, false}) + runWait(t, tt, lim, wait{"already-cancelled", ctx, 1, 0, false}) - runWait(t, tt, 10, 3, lim, wait{"exceed-burst-error", context.Background(), 4, 0, false}) + runWait(t, tt, lim, wait{"exceed-burst-error", context.Background(), 4, 0, false}) - runWait(t, tt, 10, 3, lim, wait{"act-now", context.Background(), 2, 0, true}) - runWait(t, tt, 10, 3, 
lim, wait{"act-later", context.Background(), 3, 2, true}) + runWait(t, tt, lim, wait{"act-now", context.Background(), 2, 0, true}) + runWait(t, tt, lim, wait{"act-later", context.Background(), 3, 2, true}) } func TestWaitCancel(t *testing.T) { tt := makeTestTime(t) - lim := NewLimiter(3) + lim := NewLimiter(10, 3) ctx, cancel := context.WithCancel(context.Background()) - runWait(t, tt, 10, 3, lim, wait{"act-now", ctx, 2, 0, true}) // after this lim.tokens = 1 + runWait(t, tt, lim, wait{"act-now", ctx, 2, 0, true}) // after this lim.tokens = 1 ch, _, _ := tt.newTimer(d) go func() { <-ch cancel() }() - runWait(t, tt, 10, 3, lim, wait{"will-cancel", ctx, 3, 1, false}) + runWait(t, tt, lim, wait{"will-cancel", ctx, 3, 1, false}) // should get 3 tokens back, and have lim.tokens = 2 t.Logf("tokens:%v last:%v lastEvent:%v", lim.tokens, lim.last, lim.lastEvent) - runWait(t, tt, 10, 3, lim, wait{"act-now-after-cancel", context.Background(), 2, 0, true}) + runWait(t, tt, lim, wait{"act-now-after-cancel", context.Background(), 2, 0, true}) } func TestWaitTimeout(t *testing.T) { tt := makeTestTime(t) - lim := NewLimiter(3) + lim := NewLimiter(10, 3) ctx, cancel := context.WithTimeout(context.Background(), d) defer cancel() - runWait(t, tt, 10, 3, lim, wait{"act-now", ctx, 2, 0, true}) - runWait(t, tt, 10, 3, lim, wait{"w-timeout-err", ctx, 3, 0, false}) + runWait(t, tt, lim, wait{"act-now", ctx, 2, 0, true}) + runWait(t, tt, lim, wait{"w-timeout-err", ctx, 3, 0, false}) } func TestWaitInf(t *testing.T) { tt := makeTestTime(t) - lim := NewLimiter(0) + lim := NewLimiter(Inf, 0) - runWait(t, tt, Inf, 0, lim, wait{"exceed-burst-no-error", context.Background(), 3, 0, true}) + runWait(t, tt, lim, wait{"exceed-burst-no-error", context.Background(), 3, 0, true}) } func BenchmarkAllowN(b *testing.B) { - lim := NewLimiter(1) + lim := NewLimiter(Every(1*time.Second), 1) now := time.Now() b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { - 
lim.AllowN(Every(1*time.Second), 1, now, 1) + lim.AllowN(now, 1) } }) } func BenchmarkWaitNNoDelay(b *testing.B) { - lim := NewLimiter(int64(b.N)) + lim := NewLimiter(Limit(b.N), int64(b.N)) ctx := context.Background() b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { - lim.WaitN(ctx, Limit(b.N), int64(b.N), 1) + lim.WaitN(ctx, 1) } } func TestZeroLimit(t *testing.T) { - r := NewLimiter(1) - if !r.Allow(0, 1) { + r := NewLimiter(0, 1) + if !r.Allow() { t.Errorf("Limit(0, 1) want true when first used") } - if r.Allow(0, 1) { + if r.Allow() { t.Errorf("Limit(0, 1) want false when already used") } } +func TestSetAfterZeroLimit(t *testing.T) { + lim := NewLimiter(0, 1) + // The limiter should start off full, so even though our rate limit is 0, our first request + // should be allowed… + if !lim.Allow() { + t.Errorf("Limit(0, 1) want true when first used") + } + // …the token bucket is not being replenished though, so the second request should not succeed + if lim.Allow() { + t.Errorf("Limit(0, 1) want false when already used") + } + + lim.SetLimit(10) + + tt := makeTestTime(t) + + // We set the limit to 10/s so expect to get another token in 100ms + runWait(t, tt, lim, wait{"wait-after-set-nonzero-after-zero", context.Background(), 1, 1, true}) +} + // TestTinyLimit tests that a limiter does not allow more than burst, when the rate is tiny. // Prior to resolution of issue 71154, this test // would fail on amd64 due to overflow in durationFromTokens. func TestTinyLimit(t *testing.T) { - lim := NewLimiter(1) + lim := NewLimiter(1e-10, 1) // The limiter starts with 1 burst token, so the first request should succeed - if !lim.Allow(10e-10, 1) { + if !lim.Allow() { t.Errorf("Limit(1e-10, 1) want true when first used") } // The limiter should not have replenished the token bucket yet, so the second request should fail - if lim.Allow(10e-10, 1) { + if lim.Allow() { t.Errorf("Limit(1e-10, 1) want false when already used") } }