diff --git a/auth/auth.go b/auth/auth.go index 29b6c4a..eb3fc53 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -3,6 +3,7 @@ package auth import ( "context" "errors" + "io" "net/http" "net/url" "strings" @@ -12,7 +13,7 @@ import ( type Auth interface { Validate(ctx context.Context, wr http.ResponseWriter, req *http.Request) (string, bool) - Stop() + io.Closer } func NewAuth(paramstr string, logger *clog.CondLogger) (Auth, error) { diff --git a/auth/basic.go b/auth/basic.go index 2f8cb35..d21877d 100644 --- a/auth/basic.go +++ b/auth/basic.go @@ -173,11 +173,13 @@ func (auth *BasicAuth) Validate(ctx context.Context, wr http.ResponseWriter, req return requireBasicAuth(ctx, wr, req, auth.hiddenDomain, auth.next) } -func (auth *BasicAuth) Stop() { +func (auth *BasicAuth) Close() error { + var err error auth.stopOnce.Do(func() { if auth.next != nil { - auth.next.Stop() + err = auth.next.Close() } close(auth.stopChan) }) + return err } diff --git a/auth/cert.go b/auth/cert.go index 2577837..b69c38b 100644 --- a/auth/cert.go +++ b/auth/cert.go @@ -16,6 +16,7 @@ import ( "time" clog "github.com/SenseUnit/dumbproxy/log" + "github.com/hashicorp/go-multierror" us "github.com/Snawoot/uniqueslice" ) @@ -108,16 +109,22 @@ func (auth *CertAuth) Validate(ctx context.Context, wr http.ResponseWriter, req ), true } -func (auth *CertAuth) Stop() { +func (auth *CertAuth) Close() error { + var err error auth.stopOnce.Do(func() { if auth.next != nil { - auth.next.Stop() + if closeErr := auth.next.Close(); closeErr != nil { + err = multierror.Append(err, closeErr) + } } if auth.reject != nil { - auth.reject.Stop() + if closeErr := auth.reject.Close(); closeErr != nil { + err = multierror.Append(err, closeErr) + } } close(auth.stopChan) }) + return err } func (auth *CertAuth) reload() error { diff --git a/auth/hmac.go b/auth/hmac.go index 4cbecc7..e48ee63 100644 --- a/auth/hmac.go +++ b/auth/hmac.go @@ -143,12 +143,14 @@ func (auth *HMACAuth) Validate(ctx context.Context, wr http.ResponseWriter, req return requireBasicAuth(ctx, wr, req, auth.hiddenDomain, auth.next) } -func (auth *HMACAuth) Stop() { +func (auth *HMACAuth) Close() error { + var err error auth.stopOnce.Do(func() { if auth.next != nil { - auth.next.Stop() + err = auth.next.Close() } }) + return err } func CalculateHMACSignature(secret []byte, username string, expire int64) []byte { diff --git a/auth/noauth.go b/auth/noauth.go index 4abf56a..168654f 100644 --- a/auth/noauth.go +++ b/auth/noauth.go @@ -11,4 +11,6 @@ func (_ NoAuth) Validate(_ context.Context, _ http.ResponseWriter, _ *http.Reque return "", true } -func (_ NoAuth) Stop() {} +func (_ NoAuth) Close() error { + return nil +} diff --git a/auth/redis.go b/auth/redis.go index c6759e7..264319c 100644 --- a/auth/redis.go +++ b/auth/redis.go @@ -13,6 +13,7 @@ import ( "time" clog "github.com/SenseUnit/dumbproxy/log" + "github.com/hashicorp/go-multierror" "github.com/redis/go-redis/v9" ) @@ -134,11 +135,17 @@ func (auth *RedisAuth) Validate(ctx context.Context, wr http.ResponseWriter, req return requireBasicAuth(ctx, wr, req, auth.hiddenDomain, auth.next) } -func (auth *RedisAuth) Stop() { +func (auth *RedisAuth) Close() error { + var err error auth.stopOnce.Do(func() { if auth.next != nil { - auth.next.Stop() + if closeErr := auth.next.Close(); closeErr != nil { + err = multierror.Append(err, closeErr) + } + } + if closeErr := auth.r.Close(); closeErr != nil { + err = multierror.Append(err, closeErr) } - auth.r.Close() }) + return err } diff --git a/auth/rejecthttp.go b/auth/rejecthttp.go index 699fff3..b7611d0 100644 --- a/auth/rejecthttp.go +++ b/auth/rejecthttp.go @@ -56,4 +56,6 @@ func (_ *RejectHTTPAuth) Valid(_, _, _ string) bool { return false } -func (_ *RejectHTTPAuth) Stop() {} +func (_ *RejectHTTPAuth) Close() error { + return nil +} diff --git a/auth/rejectstatic.go b/auth/rejectstatic.go index ea430aa..d052bb5 100644 --- a/auth/rejectstatic.go +++ b/auth/rejectstatic.go @@ -105,4 +105,6 @@ func (_ *StaticRejectAuth) Valid(_, _, _ string) bool { return false } -func (_ *StaticRejectAuth) Stop() {} +func (_ *StaticRejectAuth) Close() error { + return nil +} diff --git a/certcache/cryptobox.go b/certcache/cryptobox.go index fd23c5b..cf1a417 100644 --- a/certcache/cryptobox.go +++ b/certcache/cryptobox.go @@ -5,6 +5,7 @@ import ( "crypto/cipher" cryptorand "crypto/rand" "errors" + "io" "golang.org/x/crypto/acme/autocert" "golang.org/x/crypto/chacha20poly1305" @@ -64,3 +65,10 @@ func (c *EncryptedCache) Put(ctx context.Context, key string, data []byte) error func (c *EncryptedCache) Delete(ctx context.Context, key string) error { return c.next.Delete(ctx, key) } + +func (c *EncryptedCache) Close() error { + if cacheCloser, ok := c.next.(io.Closer); ok { + return cacheCloser.Close() + } + return nil +} diff --git a/certcache/local.go b/certcache/local.go index 44ec318..840c842 100644 --- a/certcache/local.go +++ b/certcache/local.go @@ -2,6 +2,7 @@ package certcache import ( "context" + "io" "sync" "time" @@ -16,10 +17,9 @@ type certCacheValue struct { } type LocalCertCache struct { - cache *ttlcache.Cache[certCacheKey, certCacheValue] - next autocert.Cache - startOnce sync.Once - stopOnce sync.Once + cache *ttlcache.Cache[certCacheKey, certCacheValue] + next autocert.Cache + stopOnce sync.Once } func NewLocalCertCache(next autocert.Cache, ttl, timeout time.Duration) *LocalCertCache { @@ -41,6 +41,7 @@ func NewLocalCertCache(next autocert.Cache, ttl, timeout time.Duration) *LocalCe nil), ), ) + go cache.Start() return &LocalCertCache{ cache: cache, next: next, @@ -62,14 +63,15 @@ func (cc *LocalCertCache) Delete(ctx context.Context, key string) error { return cc.next.Delete(ctx, key) } -func (cc *LocalCertCache) Start() { - cc.startOnce.Do(func() { - go cc.cache.Start() +func (cc *LocalCertCache) Close() error { + var err error + cc.stopOnce.Do(func() { + cc.cache.Stop() + if cacheCloser, ok := cc.next.(io.Closer); ok { + err = cacheCloser.Close() + } }) -} - -func (cc *LocalCertCache) Stop() { - cc.stopOnce.Do(cc.cache.Stop) + return err } var _ autocert.Cache = new(LocalCertCache) diff --git a/certcache/redis.go b/certcache/redis.go index 807e133..639d63d 100644 --- a/certcache/redis.go +++ b/certcache/redis.go @@ -2,17 +2,25 @@ package certcache import ( "context" + "io" + "sync" "github.com/redis/go-redis/v9" "golang.org/x/crypto/acme/autocert" ) +type CmdableCloser interface { + redis.Cmdable + io.Closer +} + type RedisCache struct { - r redis.Cmdable - pfx string + r CmdableCloser + pfx string + stopOnce sync.Once } -func NewRedisCache(r redis.Cmdable, prefix string) *RedisCache { +func NewRedisCache(r CmdableCloser, prefix string) *RedisCache { return &RedisCache{ r: r, pfx: prefix, @@ -38,6 +46,14 @@ func (r *RedisCache) Delete(ctx context.Context, key string) error { return r.r.Del(ctx, r.pfx+key).Err() } +func (r *RedisCache) Close() error { + var err error + r.stopOnce.Do(func() { + err = r.r.Close() + }) + return err +} + func RedisCacheFromURL(url string, prefix string) (*RedisCache, error) { opts, err := redis.ParseURL(url) if err != nil { diff --git a/main.go b/main.go index fbe62d8..ec347b4 100644 --- a/main.go +++ b/main.go @@ -561,7 +561,7 @@ func run() int { mainLogger.Critical("Failed to instantiate auth provider: %v", err) return 3 } - defer authProvider.Stop() + defer authProvider.Close() // setup access filters var filterRoot access.Filter = access.AlwaysAllow{} @@ -744,6 +744,9 @@ func run() int { mainLogger.Critical("redis cluster cache construction failed: %v", err) return 3 } + default: + mainLogger.Critical("unknown cert cache type %#v", args.autocertCache.kind) + return 3 } if len(args.autocertCacheEncKey.Value()) > 0 { certCache, err = certcache.NewEncryptedCache(args.autocertCacheEncKey.Value(), certCache) @@ -758,10 +761,11 @@ func run() int { args.autocertLocalCacheTTL, args.autocertLocalCacheTimeout, ) - lcc.Start() - defer lcc.Stop() certCache = lcc } + if cacheCloser, ok := certCache.(io.Closer); ok { + defer cacheCloser.Close() + } m := &autocert.Manager{ Cache: certCache,