diff --git a/groupcache.go b/groupcache.go index b68a4eb..586fb79 100644 --- a/groupcache.go +++ b/groupcache.go @@ -97,11 +97,12 @@ func newGroup(name string, cacheBytes int64, getter Getter, peers PeerPicker) *G panic("duplicate registration of group " + name) } g := &Group{ - name: name, - getter: getter, - peers: peers, - cacheBytes: cacheBytes, - loadGroup: &singleflight.Group{}, + name: name, + getter: getter, + peers: peers, + cacheBytes: cacheBytes, + loadGroup: &singleflight.Group{}, + removeGroup: &singleflight.Group{}, } if fn := newGroupHook; fn != nil { fn(g) @@ -167,6 +168,10 @@ type Group struct { // concurrent callers. loadGroup flightGroup + // removeGroup ensures that each removed key is only removed + // remotely once regardless of the number of concurrent callers. + removeGroup flightGroup + _ int32 // force Stats to be 8-byte aligned on 32-bit platforms // Stats are statistics on the group. @@ -177,8 +182,8 @@ type Group struct { // satisfies. We define this so that we may test with an alternate // implementation. type flightGroup interface { - // Done is called when Do is done. Do(key string, fn func() (interface{}, error)) (interface{}, error) + Lock(fn func()) } // Stats are per-group statistics. @@ -233,6 +238,53 @@ func (g *Group) Get(ctx Context, key string, dest Sink) error { return setSinkView(dest, value) } +// Remove clears the key from our cache then forwards the remove +// request to all peers. +func (g *Group) Remove(ctx Context, key string) error { + _, err := g.removeGroup.Do(key, func() (interface{}, error) { + + // Remove from key owner first + owner, ok := g.peers.PickPeer(key) + if ok { + if err := g.removeFromPeer(ctx, owner, key); err != nil { + return nil, err + } + } + // Remove from our cache first in case we are owner + g.localRemove(key) + wg := sync.WaitGroup{} + errs := make(chan error) + + // Asynchronously clear the key from all hot and main caches of peers + for _, peer := range g.peers.GetAll() { + // avoid deleting from owner a second time + if peer == owner { + continue + } + + wg.Add(1) + go func() { + errs <- g.removeFromPeer(ctx, peer, key) + wg.Done() + }() + } + go func() { + wg.Wait() + close(errs) + }() + + // TODO(thrawn01): Should we report all errors? Reporting context + // cancelled error for each peer doesn't make much sense. + var err error + for e := range errs { + err = e + } + + return nil, err + }) + return err +} + // load loads key either by invoking the getter locally or by sending it to another machine. func (g *Group) load(ctx Context, key string, dest Sink) (value ByteView, destPopulated bool, err error) { g.Stats.Loads.Add(1) @@ -330,6 +382,14 @@ func (g *Group) getFromPeer(ctx Context, peer ProtoGetter, key string) (ByteView return value, nil } +func (g *Group) removeFromPeer(ctx Context, peer ProtoGetter, key string) error { + req := &pb.GetRequest{ + Group: &g.name, + Key: &key, + } + return peer.Remove(ctx, req) +} + func (g *Group) lookupCache(key string) (value ByteView, ok bool) { if g.cacheBytes <= 0 { return @@ -342,6 +402,19 @@ func (g *Group) lookupCache(key string) (value ByteView, ok bool) { return } +func (g *Group) localRemove(key string) { + // Clear key from our local cache + if g.cacheBytes <= 0 { + return + } + + // Ensure no requests are in flight + g.loadGroup.Lock(func() { + g.hotCache.remove(key) + g.mainCache.remove(key) + }) +} + func (g *Group) populateCache(key string, value ByteView, cache *cache) { if g.cacheBytes <= 0 { return @@ -447,6 +520,15 @@ func (c *cache) get(key string) (value ByteView, ok bool) { return vi.(ByteView), true } +func (c *cache) remove(key string) { + c.mu.Lock() + defer c.mu.Unlock() + if c.lru == nil { + return + } + c.lru.Remove(key) +} + func (c *cache) removeOldest() { c.mu.Lock() defer c.mu.Unlock() diff --git a/groupcache_test.go b/groupcache_test.go index f76d270..da24f28 100644 --- a/groupcache_test.go +++ b/groupcache_test.go @@ -263,6 +263,14 @@ func (p *fakePeer) Get(_ Context, in *pb.GetRequest, out *pb.GetResponse) error return nil } +func (p *fakePeer) Remove(_ Context, in *pb.GetRequest) error { + p.hits++ + if p.fail { + return errors.New("simulated error from peer") + } + return nil +} + type fakePeers []ProtoGetter func (p fakePeers) PickPeer(key string) (peer ProtoGetter, ok bool) { @@ -273,6 +281,10 @@ func (p fakePeers) PickPeer(key string) (peer ProtoGetter, ok bool) { return p[n], p[n] != nil } +func (p fakePeers) GetAll() []ProtoGetter { + return p +} + // tests that peers (virtual, in-process) are hit, and how much. func TestPeers(t *testing.T) { once.Do(testSetup) @@ -406,6 +418,10 @@ func (g *orderedFlightGroup) Do(key string, fn func() (interface{}, error)) (int return g.orig.Do(key, fn) } +func (g *orderedFlightGroup) Lock(fn func()) { + fn() +} + // TestNoDedup tests invariants on the cache size when singleflight is // unable to dedup calls. func TestNoDedup(t *testing.T) { diff --git a/http.go b/http.go index f99658a..568da49 100644 --- a/http.go +++ b/http.go @@ -126,6 +126,20 @@ func (p *HTTPPool) Set(peers ...string) { } } +// GetAll returns all the peers in the pool +func (p *HTTPPool) GetAll() []ProtoGetter { + p.mu.Lock() + defer p.mu.Unlock() + + var i int + res := make([]ProtoGetter, len(p.httpGetters)) + for _, v := range p.httpGetters { + res[i] = v + i++ + } + return res +} + func (p *HTTPPool) PickPeer(key string) (ProtoGetter, bool) { p.mu.Lock() defer p.mu.Unlock() @@ -163,6 +177,13 @@ func (p *HTTPPool) ServeHTTP(w http.ResponseWriter, r *http.Request) { } group.Stats.ServerRequests.Add(1) + + // Delete the key and return 200 + if r.Method == http.MethodDelete { + group.localRemove(key) + return + } + var b []byte value := AllocatingByteSliceSink(&b) @@ -201,14 +222,14 @@ var bufferPool = sync.Pool{ New: func() interface{} { return new(bytes.Buffer) }, } -func (h *httpGetter) Get(context Context, in *pb.GetRequest, out *pb.GetResponse) error { +func (h *httpGetter) makeRequest(context Context, method string, in *pb.GetRequest, out *http.Response) error { u := fmt.Sprintf( "%v%v/%v", h.baseURL, url.QueryEscape(in.GetGroup()), url.QueryEscape(in.GetKey()), ) - req, err := http.NewRequest("GET", u, nil) + req, err := http.NewRequest(method, u, nil) if err != nil { return err } @@ -220,6 +241,15 @@ func (h *httpGetter) Get(context Context, in *pb.GetRequest, out *pb.GetResponse if err != nil { return err } + *out = *res + return nil +} + +func (h *httpGetter) Get(ctx Context, in *pb.GetRequest, out *pb.GetResponse) error { + var res http.Response + if err := h.makeRequest(ctx, http.MethodGet, in, &res); err != nil { + return err + } defer res.Body.Close() if res.StatusCode != http.StatusOK { return fmt.Errorf("server returned: %v", res.Status) @@ -227,7 +257,7 @@ func (h *httpGetter) Get(context Context, in *pb.GetRequest, out *pb.GetResponse b := bufferPool.Get().(*bytes.Buffer) b.Reset() defer bufferPool.Put(b) - _, err = io.Copy(b, res.Body) + _, err := io.Copy(b, res.Body) if err != nil { return fmt.Errorf("reading response body: %v", err) } @@ -237,3 +267,15 @@ func (h *httpGetter) Get(context Context, in *pb.GetRequest, out *pb.GetResponse } return nil } + +func (h *httpGetter) Remove(ctx Context, in *pb.GetRequest) error { + var res http.Response + if err := h.makeRequest(ctx, http.MethodDelete, in, &res); err != nil { + return err + } + res.Body.Close() + if res.StatusCode != http.StatusOK { + return fmt.Errorf("server returned: %v", res.Status) + } + return nil +} diff --git a/http_test.go b/http_test.go index e782541..0453999 100644 --- a/http_test.go +++ b/http_test.go @@ -19,9 +19,11 @@ package groupcache import ( "errors" "flag" + "fmt" "log" "net" "net/http" + "net/http/httptest" "os" "os/exec" "strconv" @@ -32,14 +34,15 @@ import ( ) var ( - peerAddrs = flag.String("test_peer_addrs", "", "Comma-separated list of peer addresses; used by TestHTTPPool") - peerIndex = flag.Int("test_peer_index", -1, "Index of which peer this child is; used by TestHTTPPool") - peerChild = flag.Bool("test_peer_child", false, "True if running as a child process; used by TestHTTPPool") + peerAddrs = flag.String("test_peer_addrs", "", "Comma-separated list of peer addresses; used by TestHTTPPool") + peerIndex = flag.Int("test_peer_index", -1, "Index of which peer this child is; used by TestHTTPPool") + peerChild = flag.Bool("test_peer_child", false, "True if running as a child process; used by TestHTTPPool") + serverAddr = flag.String("test_server_addr", "", "Address of the server Child Getters will hit ; used by TestHTTPPool") ) func TestHTTPPool(t *testing.T) { if *peerChild { - beChildForTestHTTPPool() + beChildForTestHTTPPool(t) os.Exit(0) } @@ -48,6 +51,13 @@ func TestHTTPPool(t *testing.T) { nGets = 100 ) + var serverHits int + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello") + serverHits++ + })) + defer ts.Close() + var childAddr []string for i := 0; i < nChild; i++ { childAddr = append(childAddr, pickFreeAddr(t)) @@ -61,6 +71,7 @@ func TestHTTPPool(t *testing.T) { "--test_peer_child", "--test_peer_addrs="+strings.Join(childAddr, ","), "--test_peer_index="+strconv.Itoa(i), + "--test_server_addr="+ts.URL, ) cmds = append(cmds, cmd) wg.Add(1) @@ -100,6 +111,41 @@ func TestHTTPPool(t *testing.T) { } t.Logf("Get key=%q, value=%q (peer:key)", key, value) } + + if serverHits != nGets { + t.Error("expected serverHits to equal nGets") + } + serverHits = 0 + + var value string + var key = "removeTestKey" + + // Multiple gets on the same key + for i := 0; i < 2; i++ { + if err := g.Get(nil, key, StringSink(&value)); err != nil { + t.Fatal(err) + } + } + + // Should result in only 1 server get + if serverHits != 1 { + t.Error("expected serverHits to be '1'") + } + + // Remove the key from the cache and we should see another server hit + if err := g.Remove(nil, key); err != nil { + t.Fatal(err) + } + + // Get the key again + if err := g.Get(nil, key, StringSink(&value)); err != nil { + t.Fatal(err) + } + + // Should register another server get + if serverHits != 2 { + t.Error("expected serverHits to be '2'") + } } func testKeys(n int) (keys []string) { @@ -110,13 +156,17 @@ func testKeys(n int) (keys []string) { return } -func beChildForTestHTTPPool() { +func beChildForTestHTTPPool(t *testing.T) { addrs := strings.Split(*peerAddrs, ",") p := NewHTTPPool("http://" + addrs[*peerIndex]) p.Set(addrToURL(addrs)...) getter := GetterFunc(func(ctx Context, key string, dest Sink) error { + if _, err := http.Get(*serverAddr); err != nil { + t.Logf("HTTP request from getter failed with '%s'", err) + } + dest.SetString(strconv.Itoa(*peerIndex)+":"+key, time.Time{}) return nil }) diff --git a/peers.go b/peers.go index aff08d3..4bf9e62 100644 --- a/peers.go +++ b/peers.go @@ -30,6 +30,7 @@ type Context interface{} // ProtoGetter is the interface that must be implemented by a peer. type ProtoGetter interface { Get(context Context, in *pb.GetRequest, out *pb.GetResponse) error + Remove(context Context, in *pb.GetRequest) error } // PeerPicker is the interface that must be implemented to locate @@ -39,12 +40,14 @@ type PeerPicker interface { // and true to indicate that a remote peer was nominated. // It returns nil, false if the key owner is the current peer. PickPeer(key string) (peer ProtoGetter, ok bool) + GetAll() []ProtoGetter } // NoPeers is an implementation of PeerPicker that never finds a peer. type NoPeers struct{} func (NoPeers) PickPeer(key string) (peer ProtoGetter, ok bool) { return } +func (NoPeers) GetAll() []ProtoGetter { return []ProtoGetter{} } var ( portPicker func(groupName string) PeerPicker diff --git a/singleflight/singleflight.go b/singleflight/singleflight.go index ff2c2ee..a2a6326 100644 --- a/singleflight/singleflight.go +++ b/singleflight/singleflight.go @@ -62,3 +62,12 @@ func (g *Group) Do(key string, fn func() (interface{}, error)) (interface{}, err return c.val, c.err } + +// Lock prevents single flights from occurring for the duration +// of the provided function. This allows users to clear caches +// or preform some operation in between running flights. +func (g *Group) Lock(fn func()) { + g.mu.Lock() + defer g.mu.Unlock() + fn() +}