From 478b3fdb185b5fb734d547cfb3eb327079369a04 Mon Sep 17 00:00:00 2001 From: William Bergeron-Drouin Date: Tue, 16 Apr 2024 13:51:45 -0400 Subject: [PATCH] Fix group.Set not propagating hot caches of n>2 peers --- groupcache.go | 78 ++++++++++++++++++++++++++++++++++++++++----- http.go | 7 +++- http_test.go | 49 ++++++++++++++++++++++++++-- integration_test.go | 9 ++++-- 4 files changed, 130 insertions(+), 13 deletions(-) diff --git a/groupcache.go b/groupcache.go index 541e89a..ca90d32 100644 --- a/groupcache.go +++ b/groupcache.go @@ -272,21 +272,78 @@ func (g *Group) Set(ctx context.Context, key string, value []byte, expire time.T } _, err := g.setGroup.Do(key, func() (interface{}, error) { + wg := sync.WaitGroup{} + errs := make(chan error) + // If remote peer owns this key owner, ok := g.peers.PickPeer(key) if ok { - if err := g.setFromPeer(ctx, owner, key, value, expire); err != nil { + if err := g.setFromPeer(ctx, owner, key, value, expire, false); err != nil { return nil, err } // TODO(thrawn01): Not sure if this is useful outside of tests... // maybe we should ALWAYS update the local cache? - if hotCache { - g.localSet(key, value, expire, &g.hotCache) + if !hotCache { + return nil, nil } - return nil, nil + + g.localSet(key, value, expire, &g.hotCache) + + for _, peer := range g.peers.GetAll() { + if peer == owner { + // Avoid setting to owner a second time + continue + } + wg.Add(1) + go func(peer ProtoGetter) { + errs <- g.setFromPeer(ctx, peer, key, value, expire, true) + wg.Done() + }(peer) + } + + go func() { + wg.Wait() + close(errs) + }() + + var err error + for e := range errs { + if e != nil { + err = errors.Join(err, e) + } + } + + return nil, err } // We own this key g.localSet(key, value, expire, &g.mainCache) + + if hotCache { + // Also set to the hot cache of all peers + + for _, peer := range g.peers.GetAll() { + wg.Add(1) + go func(peer ProtoGetter) { + errs <- g.setFromPeer(ctx, peer, key, value, expire, true) + wg.Done() + }(peer) + } + + go func() { + wg.Wait() + close(errs) + }() + + var err error + for e := range errs { + if e != nil { + err = errors.Join(err, e) + } + } + + return nil, err + } + return nil, nil }) return err @@ -329,11 +386,11 @@ func (g *Group) Remove(ctx context.Context, key string) error { close(errs) }() - // TODO(thrawn01): Should we report all errors? Reporting context - // cancelled error for each peer doesn't make much sense. var err error for e := range errs { - err = e + if e != nil { + err = errors.Join(err, e) + } } return nil, err @@ -473,7 +530,7 @@ func (g *Group) getFromPeer(ctx context.Context, peer ProtoGetter, key string) ( return value, nil } -func (g *Group) setFromPeer(ctx context.Context, peer ProtoGetter, k string, v []byte, e time.Time) error { +func (g *Group) setFromPeer(ctx context.Context, peer ProtoGetter, k string, v []byte, e time.Time, hotCache bool) error { var expire int64 if !e.IsZero() { expire = e.UnixNano() @@ -484,6 +541,11 @@ func (g *Group) setFromPeer(ctx context.Context, peer ProtoGetter, k string, v [ Key: &k, Value: v, } + + if hotCache { + req.HotCache = &hotCache + } + return peer.Set(ctx, req) } diff --git a/http.go b/http.go index 32e1f90..b4f6738 100644 --- a/http.go +++ b/http.go @@ -217,7 +217,12 @@ func (p *HTTPPool) ServeHTTP(w http.ResponseWriter, r *http.Request) { expire = time.Unix(*out.Expire/int64(time.Second), *out.Expire%int64(time.Second)) } - group.localSet(*out.Key, out.Value, expire, &group.mainCache) + c := &group.mainCache + if out.HotCache != nil && *out.HotCache { + c = &group.hotCache + } + + group.localSet(*out.Key, out.Value, expire, c) return } diff --git a/http_test.go b/http_test.go index 23f9a7b..4026ef0 100644 --- a/http_test.go +++ b/http_test.go @@ -95,7 +95,9 @@ func TestHTTPPool(t *testing.T) { wg.Wait() // Use a dummy self address so that we don't handle gets in-process. - p := NewHTTPPool("should-be-ignored") + p, mux := newTestHTTPPool("should-be-ignored") + defer mux.Close() + p.Set(addrToURL(childAddr)...) // Dummy getter function. Gets should go to children only. @@ -219,7 +221,8 @@ func testKeys(n int) (keys []string) { func beChildForTestHTTPPool(t *testing.T) { addrs := strings.Split(*peerAddrs, ",") - p := NewHTTPPool("http://" + addrs[*peerIndex]) + p, mux := newTestHTTPPool("http://" + addrs[*peerIndex]) + defer mux.Close() p.Set(addrToURL(addrs)...) getter := GetterFunc(func(ctx context.Context, key string, dest Sink) error { @@ -286,3 +289,45 @@ func awaitAddrReady(t *testing.T, addr string, wg *sync.WaitGroup) { time.Sleep(delay) } } + +type serveMux struct { + mux *http.ServeMux + handlers map[string]http.Handler +} + +func newTestHTTPPool(self string) (*HTTPPool, *serveMux) { + httpPoolMade, portPicker = false, nil // Testing only + + p := NewHTTPPoolOpts(self, nil) + sm := &serveMux{ + mux: http.NewServeMux(), + handlers: make(map[string]http.Handler), + } + + sm.handlers[p.opts.BasePath] = p + + return p, sm +} + +func (s *serveMux) Handle(pattern string, handler http.Handler) { + s.handlers[pattern] = handler + s.mux.Handle(pattern, handler) +} + +func (s *serveMux) Close() { + for pattern := range s.handlers { + delete(s.handlers, pattern) + } +} + +func (s *serveMux) RemoveHandle(pattern string) { + delete(s.handlers, pattern) +} + +func (s *serveMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if _, ok := s.handlers[r.URL.Path]; ok { + s.mux.ServeHTTP(w, r) + } else { + http.NotFound(w, r) + } +} diff --git a/integration_test.go b/integration_test.go index 6a09d73..f7f55a4 100644 --- a/integration_test.go +++ b/integration_test.go @@ -71,7 +71,9 @@ func TestManualSet(t *testing.T) { wg.Wait() // Use a dummy self address so that we don't handle gets in-process. - p := NewHTTPPool("should-be-ignored") + p, mux := newTestHTTPPool("should-be-ignored") + defer mux.Close() + p.Set(addrToURL(childAddr)...) // Dummy getter function. Gets should go to children only. @@ -146,8 +148,11 @@ var _ http.Handler = (*overwriteHttpPool)(nil) func beChildForIntegrationTest(t *testing.T) { addrs := strings.Split(*peerAddrs, ",") + p, mux := newTestHTTPPool("http://" + addrs[*peerIndex]) + defer mux.Close() + hp := overwriteHttpPool{ - p: NewHTTPPool("http://" + addrs[*peerIndex]), + p: p, } hp.p.Set(addrToURL(addrs)...)