Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -492,8 +492,6 @@ Usage of /home/user/go/bin/dumbproxy:
email used for ACME registration
-autocert-http string
listen address for HTTP-01 challenges handler of ACME
-autocert-local-cache-timeout duration
timeout for cert cache queries (default 10s)
-autocert-local-cache-ttl duration
enables in-memory cache for certificates
-autocert-whitelist value
Expand Down
75 changes: 37 additions & 38 deletions certcache/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,58 +3,61 @@ package certcache
import (
"context"
"io"
"sync"
"time"

"github.com/jellydator/ttlcache/v3"
"github.com/Snawoot/secache"
"golang.org/x/crypto/acme/autocert"
"golang.org/x/sync/singleflight"
)

type certCacheKey = string
type certCacheValue struct {
ts time.Time
res []byte
err error
}

type LocalCertCache struct {
cache *ttlcache.Cache[certCacheKey, certCacheValue]
next autocert.Cache
stopOnce sync.Once
cache secache.Cache[certCacheKey, *certCacheValue]
sf singleflight.Group
next autocert.Cache
}

func NewLocalCertCache(next autocert.Cache, ttl, timeout time.Duration) *LocalCertCache {
cache := ttlcache.New[certCacheKey, certCacheValue](
ttlcache.WithTTL[certCacheKey, certCacheValue](ttl),
ttlcache.WithLoader(
ttlcache.NewSuppressedLoader(
ttlcache.LoaderFunc[certCacheKey, certCacheValue](
func(c *ttlcache.Cache[certCacheKey, certCacheValue], key certCacheKey) *ttlcache.Item[certCacheKey, certCacheValue] {
ctx, cl := context.WithTimeout(context.Background(), timeout)
defer cl()
res, err := next.Get(ctx, key)
if err != nil {
return c.Set(key, certCacheValue{res, err}, -100)
}
return c.Set(key, certCacheValue{res, err}, 0)
},
),
nil),
),
)
go cache.Start()
func NewLocalCertCache(next autocert.Cache, ttl time.Duration) *LocalCertCache {
return &LocalCertCache{
cache: cache,
next: next,
cache: *(secache.New[certCacheKey, *certCacheValue](3, func(key certCacheKey, item *certCacheValue) bool {
return time.Now().Before(item.ts.Add(ttl))
})),
next: next,
}
}

func (cc *LocalCertCache) Get(_ context.Context, key string) ([]byte, error) {
resItem := cc.cache.Get(key).Value()
func (cc *LocalCertCache) Get(ctx context.Context, key string) ([]byte, error) {
resItem, ok := cc.cache.GetValidOrDelete(key)
if !ok {
v, _, _ := cc.sf.Do(key, func() (any, error) {
res, err := cc.next.Get(ctx, key)
item := &certCacheValue{
ts: time.Now(),
res: res,
err: err,
}
if ctx.Err() == nil {
cc.cache.Set(key, item)
}
return item, err
})
resItem = v.(*certCacheValue)
}
return resItem.res, resItem.err
}

func (cc *LocalCertCache) Put(ctx context.Context, key string, data []byte) error {
cc.cache.Set(key, certCacheValue{data, nil}, 0)
cc.cache.Set(key, &certCacheValue{
ts: time.Now(),
res: data,
err: nil,
})
return cc.next.Put(ctx, key, data)
}

Expand All @@ -64,14 +67,10 @@ func (cc *LocalCertCache) Delete(ctx context.Context, key string) error {
}

func (cc *LocalCertCache) Close() error {
var err error
cc.stopOnce.Do(func() {
cc.cache.Stop()
if cacheCloser, ok := cc.next.(io.Closer); ok {
err = cacheCloser.Close()
}
})
return err
if cacheCloser, ok := cc.next.(io.Closer); ok {
return cacheCloser.Close()
}
return nil
}

var _ autocert.Cache = new(LocalCertCache)
53 changes: 18 additions & 35 deletions dialer/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@ import (
"net/url"
"time"

"github.com/jellydator/ttlcache/v3"
"github.com/Snawoot/secache"
xproxy "golang.org/x/net/proxy"
"golang.org/x/sync/singleflight"
)

type dialerCacheKey struct {
Expand All @@ -17,20 +16,14 @@ type dialerCacheKey struct {
}

type dialerCacheValue struct {
dialer xproxy.Dialer
err error
expires time.Time
dialer xproxy.Dialer
err error
}

var (
dialerCache = ttlcache.New[dialerCacheKey, dialerCacheValue](
ttlcache.WithDisableTouchOnHit[dialerCacheKey, dialerCacheValue](),
)
dialerCacheSingleFlight = new(singleflight.Group)
)

func init() {
go dialerCache.Start()
}
var dialerCache = secache.New[dialerCacheKey, *dialerCacheValue](3, func(key dialerCacheKey, val *dialerCacheValue) bool {
return time.Now().Before(val.expires)
})

func GetCachedDialer(u *url.URL, next xproxy.Dialer) (xproxy.Dialer, error) {
params, err := url.ParseQuery(u.RawQuery)
Expand All @@ -51,29 +44,19 @@ func GetCachedDialer(u *url.URL, next xproxy.Dialer) (xproxy.Dialer, error) {
if err != nil {
return nil, fmt.Errorf("cached dialer: unable to parse TTL duration %q: %w", params.Get("ttl"), err)
}
cacheRes := dialerCache.Get(
item := dialerCache.GetOrCreate(
dialerCacheKey{
url: params.Get("url"),
next: next,
},
ttlcache.WithLoader[dialerCacheKey, dialerCacheValue](
ttlcache.NewSuppressedLoader[dialerCacheKey, dialerCacheValue](
ttlcache.LoaderFunc[dialerCacheKey, dialerCacheValue](
func(c *ttlcache.Cache[dialerCacheKey, dialerCacheValue], key dialerCacheKey) *ttlcache.Item[dialerCacheKey, dialerCacheValue] {
dialer, err := xproxy.FromURL(parsedURL, next)
return c.Set(
key,
dialerCacheValue{
dialer: dialer,
err: err,
},
ttl,
)
},
),
dialerCacheSingleFlight,
),
),
).Value()
return cacheRes.dialer, cacheRes.err
func() *dialerCacheValue {
dialer, err := xproxy.FromURL(parsedURL, next)
return &dialerCacheValue{
expires: time.Now().Add(ttl),
dialer: dialer,
err: err,
}
},
)
return item.dialer, item.err
}
104 changes: 52 additions & 52 deletions dialer/rescache.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@ import (
"net"
"net/netip"
"strings"
"sync"
"time"

"github.com/SenseUnit/dumbproxy/dialer/dto"
"github.com/Snawoot/secache"
"github.com/hashicorp/go-multierror"
"github.com/jellydator/ttlcache/v3"
"golang.org/x/sync/singleflight"

"github.com/SenseUnit/dumbproxy/dialer/dto"
)

type resolverCacheKey struct {
Expand All @@ -20,46 +21,36 @@ type resolverCacheKey struct {
}

type resolverCacheValue struct {
addrs []netip.Addr
err error
expires time.Time
addrs []netip.Addr
err error
}

type NameResolveCachingDialer struct {
cache *ttlcache.Cache[resolverCacheKey, resolverCacheValue]
next Dialer
startOnce sync.Once
stopOnce sync.Once
resolver Resolver
cache secache.Cache[resolverCacheKey, *resolverCacheValue]
sf singleflight.Group
posTTL time.Duration
negTTL time.Duration
timeout time.Duration
next Dialer
}

func NewNameResolveCachingDialer(next Dialer, resolver Resolver, posTTL, negTTL, timeout time.Duration) *NameResolveCachingDialer {
cache := ttlcache.New[resolverCacheKey, resolverCacheValue](
ttlcache.WithDisableTouchOnHit[resolverCacheKey, resolverCacheValue](),
ttlcache.WithLoader(
ttlcache.NewSuppressedLoader(
ttlcache.LoaderFunc[resolverCacheKey, resolverCacheValue](
func(c *ttlcache.Cache[resolverCacheKey, resolverCacheValue], key resolverCacheKey) *ttlcache.Item[resolverCacheKey, resolverCacheValue] {
ctx, cl := context.WithTimeout(context.Background(), timeout)
defer cl()
res, err := resolver.LookupNetIP(ctx, key.network, key.host)
for i := range res {
res[i] = res[i].Unmap()
}
setTTL := negTTL
if err == nil {
setTTL = posTTL
}
return c.Set(key, resolverCacheValue{
addrs: res,
err: err,
}, setTTL)
},
),
nil),
),
)
// func(c *ttlcache.Cache[resolverCacheKey, resolverCacheValue], key resolverCacheKey) *ttlcache.Item[resolverCacheKey, resolverCacheValue] {
// },
return &NameResolveCachingDialer{
cache: cache,
next: next,
resolver: resolver,
cache: *(secache.New[resolverCacheKey, *resolverCacheValue](
3,
func(key resolverCacheKey, item *resolverCacheValue) bool {
return time.Now().Before(item.expires)
},
)),
posTTL: posTTL,
negTTL: negTTL,
timeout: timeout,
next: next,
}
}

Expand Down Expand Up @@ -91,16 +82,35 @@ func (nrcd *NameResolveCachingDialer) DialContext(ctx context.Context, network,
}

host = strings.ToLower(host)

resItem := nrcd.cache.Get(resolverCacheKey{
key := resolverCacheKey{
network: resolveNetwork,
host: host,
})
if resItem == nil {
return nil, fmt.Errorf("cache lookup failed for pair <%q, %q>", resolveNetwork, host)
}

res := resItem.Value()
res, ok := nrcd.cache.GetValidOrDelete(key)
if !ok {
v, _, _ := nrcd.sf.Do(key.network+":"+key.host, func() (any, error) {
ctx, cl := context.WithTimeout(context.Background(), nrcd.timeout)
defer cl()
res, err := nrcd.resolver.LookupNetIP(ctx, key.network, key.host)
for i := range res {
res[i] = res[i].Unmap()
}
setTTL := nrcd.negTTL
if err == nil {
setTTL = nrcd.posTTL
}
item := &resolverCacheValue{
expires: time.Now().Add(setTTL),
addrs: res,
err: err,
}
nrcd.cache.Set(key, item)
return item, nil
})
res = v.(*resolverCacheValue)
}

if res.err != nil {
return nil, res.err
}
Expand Down Expand Up @@ -129,15 +139,5 @@ func (nrcd *NameResolveCachingDialer) WantsHostname(ctx context.Context, net, ad
return WantsHostname(ctx, net, address, nrcd.next)
}

func (nrcd *NameResolveCachingDialer) Start() {
nrcd.startOnce.Do(func() {
go nrcd.cache.Start()
})
}

func (nrcd *NameResolveCachingDialer) Stop() {
nrcd.stopOnce.Do(nrcd.cache.Stop)
}

var _ Dialer = new(NameResolveCachingDialer)
var _ HostnameWanter = new(NameResolveCachingDialer)
Loading