diff --git a/go.mod b/go.mod index 35a1857..1c6c15b 100644 --- a/go.mod +++ b/go.mod @@ -6,4 +6,4 @@ require ( github.com/sirupsen/logrus v1.6.0 ) -go 1.13 +go 1.15 diff --git a/groupcache.go b/groupcache.go index 5eda3cf..915e265 100644 --- a/groupcache.go +++ b/groupcache.go @@ -116,6 +116,7 @@ func newGroup(name string, cacheBytes int64, getter Getter, peers PeerPicker) *G peers: peers, cacheBytes: cacheBytes, loadGroup: &singleflight.Group{}, + setGroup: &singleflight.Group{}, removeGroup: &singleflight.Group{}, } if fn := newGroupHook; fn != nil { @@ -182,6 +183,10 @@ type Group struct { // concurrent callers. loadGroup flightGroup + // setGroup ensures that each added key is only added + // remotely once regardless of the number of concurrent callers. + setGroup flightGroup + // removeGroup ensures that each removed key is only removed // remotely once regardless of the number of concurrent callers. removeGroup flightGroup @@ -253,6 +258,34 @@ func (g *Group) Get(ctx context.Context, key string, dest Sink) error { return setSinkView(dest, value) } +func (g *Group) Set(ctx context.Context, key string, value []byte, expire time.Time, hotCache bool) error { + g.peersOnce.Do(g.initPeers) + + if key == "" { + return errors.New("empty Set() key not allowed") + } + + _, err := g.setGroup.Do(key, func() (interface{}, error) { + // If remote peer owns this key + owner, ok := g.peers.PickPeer(key) + if ok { + if err := g.setFromPeer(ctx, owner, key, value, expire); err != nil { + return nil, err + } + // TODO(thrawn01): Not sure if this is useful outside of tests... + // maybe we should ALWAYS update the local cache? + if hotCache { + g.localSet(key, value, expire, &g.hotCache) + } + return nil, nil + } + // We own this key + g.localSet(key, value, expire, &g.mainCache) + return nil, nil + }) + return err +} + // Remove clears the key from our cache then forwards the remove // request to all peers. func (g *Group) Remove(ctx context.Context, key string) error { @@ -425,6 +458,20 @@ func (g *Group) getFromPeer(ctx context.Context, peer ProtoGetter, key string) ( return value, nil } +func (g *Group) setFromPeer(ctx context.Context, peer ProtoGetter, k string, v []byte, e time.Time) error { + var expire int64 + if !e.IsZero() { + expire = e.UnixNano() + } + req := &pb.SetRequest{ + Expire: &expire, + Group: &g.name, + Key: &k, + Value: v, + } + return peer.Set(ctx, req) +} + func (g *Group) removeFromPeer(ctx context.Context, peer ProtoGetter, key string) error { req := &pb.GetRequest{ Group: &g.name, @@ -445,6 +492,22 @@ func (g *Group) lookupCache(key string) (value ByteView, ok bool) { return } +func (g *Group) localSet(key string, value []byte, expire time.Time, cache *cache) { + if g.cacheBytes <= 0 { + return + } + + bv := ByteView{ + b: value, + e: expire, + } + + // Ensure no requests are in flight + g.loadGroup.Lock(func() { + g.populateCache(key, bv, cache) + }) +} + func (g *Group) localRemove(key string) { // Clear key from our local cache if g.cacheBytes <= 0 { diff --git a/groupcache_test.go b/groupcache_test.go index a99a60e..e87c62a 100644 --- a/groupcache_test.go +++ b/groupcache_test.go @@ -263,6 +263,14 @@ func (p *fakePeer) Get(_ context.Context, in *pb.GetRequest, out *pb.GetResponse return nil } +func (p *fakePeer) Set(_ context.Context, in *pb.SetRequest) error { + p.hits++ + if p.fail { + return errors.New("simulated error from peer") + } + return nil +} + func (p *fakePeer) Remove(_ context.Context, in *pb.GetRequest) error { p.hits++ if p.fail { diff --git a/groupcachepb/groupcache.pb.go b/groupcachepb/groupcache.pb.go index 85b7c69..d6abd47 100644 --- a/groupcachepb/groupcache.pb.go +++ b/groupcachepb/groupcache.pb.go @@ -10,6 +10,7 @@ It is generated from these files: It has these top-level messages: GetRequest GetResponse + SetRequest */ package groupcachepb @@ -86,26 +87,69 @@ func (m *GetResponse) GetExpire() int64 { return 0 } +type SetRequest struct { + Group *string `protobuf:"bytes,1,req,name=group" json:"group,omitempty"` + Key *string `protobuf:"bytes,2,req,name=key" json:"key,omitempty"` + Value []byte `protobuf:"bytes,3,opt,name=value" json:"value,omitempty"` + Expire *int64 `protobuf:"varint,4,opt,name=expire" json:"expire,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *SetRequest) Reset() { *m = SetRequest{} } +func (m *SetRequest) String() string { return proto.CompactTextString(m) } +func (*SetRequest) ProtoMessage() {} +func (*SetRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} } + +func (m *SetRequest) GetGroup() string { + if m != nil && m.Group != nil { + return *m.Group + } + return "" +} + +func (m *SetRequest) GetKey() string { + if m != nil && m.Key != nil { + return *m.Key + } + return "" +} + +func (m *SetRequest) GetValue() []byte { + if m != nil { + return m.Value + } + return nil +} + +func (m *SetRequest) GetExpire() int64 { + if m != nil && m.Expire != nil { + return *m.Expire + } + return 0 +} + func init() { proto.RegisterType((*GetRequest)(nil), "groupcachepb.GetRequest") proto.RegisterType((*GetResponse)(nil), "groupcachepb.GetResponse") + proto.RegisterType((*SetRequest)(nil), "groupcachepb.SetRequest") } func init() { proto.RegisterFile("groupcache.proto", fileDescriptor0) } var fileDescriptor0 = []byte{ - // 197 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x12, 0x48, 0x2f, 0xca, 0x2f, - 0x2d, 0x48, 0x4e, 0x4c, 0xce, 0x48, 0xd5, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x41, 0x88, - 0x14, 0x24, 0x29, 0x99, 0x70, 0x71, 0xb9, 0xa7, 0x96, 0x04, 0xa5, 0x16, 0x96, 0xa6, 0x16, 0x97, - 0x08, 0x89, 0x70, 0xb1, 0x82, 0x65, 0x25, 0x18, 0x15, 0x98, 0x34, 0x38, 0x83, 0x20, 0x1c, 0x21, - 0x01, 0x2e, 0xe6, 0xec, 0xd4, 0x4a, 0x09, 0x26, 0xb0, 0x18, 0x88, 0xa9, 0x14, 0xc5, 0xc5, 0x0d, - 0xd6, 0x55, 0x5c, 0x90, 0x9f, 0x57, 0x9c, 0x0a, 0xd2, 0x56, 0x96, 0x98, 0x53, 0x9a, 0x2a, 0xc1, - 0xa8, 0xc0, 0xa8, 0xc1, 0x13, 0x04, 0xe1, 0x08, 0xc9, 0x72, 0x71, 0xe5, 0x66, 0xe6, 0x95, 0x96, - 0xa4, 0xc6, 0x17, 0x16, 0x14, 0x4b, 0x30, 0x29, 0x30, 0x6a, 0x30, 0x06, 0x71, 0x42, 0x44, 0x02, - 0x0b, 0x8a, 0x85, 0xc4, 0xb8, 0xd8, 0x52, 0x2b, 0x0a, 0x32, 0x8b, 0x52, 0x25, 0x98, 0x15, 0x18, - 0x35, 0x98, 0x83, 0xa0, 0x3c, 0x23, 0x2f, 0x2e, 0x2e, 0x77, 0x90, 0xb5, 0xce, 0x20, 0x17, 0x0a, - 0xd9, 0x70, 0x31, 0xbb, 0xa7, 0x96, 0x08, 0x49, 0xe8, 0x21, 0xbb, 0x5a, 0x0f, 0xe1, 0x64, 0x29, - 0x49, 0x2c, 0x32, 0x10, 0x67, 0x29, 0x31, 0x00, 0x02, 0x00, 0x00, 0xff, 0xff, 0xd4, 0x73, 0xe1, - 0xb8, 0xfe, 0x00, 0x00, 0x00, + // 215 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x94, 0x50, 0x31, 0x4b, 0xc5, 0x30, + 0x18, 0x34, 0x8d, 0x0a, 0xfd, 0xec, 0x50, 0x82, 0x48, 0x14, 0x84, 0x90, 0x29, 0x53, 0x07, 0x71, + 0x74, 0x73, 0x28, 0xb8, 0x19, 0x37, 0x17, 0x69, 0xcb, 0x87, 0x16, 0xb5, 0x49, 0x9b, 0x44, 0x7c, + 0xff, 0xfe, 0x91, 0xe6, 0x41, 0x3a, 0xbc, 0xe5, 0x6d, 0xb9, 0x3b, 0x2e, 0x77, 0xdf, 0x41, 0xfd, + 0xb9, 0x98, 0x60, 0x87, 0x6e, 0xf8, 0xc2, 0xc6, 0x2e, 0xc6, 0x1b, 0x56, 0x65, 0xc6, 0xf6, 0xf2, + 0x11, 0xa0, 0x45, 0xaf, 0x71, 0x0e, 0xe8, 0x3c, 0xbb, 0x86, 0x8b, 0x55, 0xe5, 0x44, 0x14, 0xaa, + 0xd4, 0x09, 0xb0, 0x1a, 0xe8, 0x37, 0xee, 0x78, 0xb1, 0x72, 0xf1, 0x29, 0xdf, 0xe1, 0x6a, 0x75, + 0x39, 0x6b, 0x26, 0x87, 0xd1, 0xf6, 0xd7, 0xfd, 0x04, 0xe4, 0x44, 0x10, 0x55, 0xe9, 0x04, 0xd8, + 0x3d, 0xc0, 0xef, 0x38, 0x05, 0x8f, 0x1f, 0xb3, 0x75, 0xbc, 0x10, 0x44, 0x11, 0x5d, 0x26, 0xe6, + 0xd5, 0x3a, 0x76, 0x03, 0x97, 0xf8, 0x6f, 0xc7, 0x05, 0x39, 0x15, 0x44, 0x51, 0x7d, 0x40, 0xb2, + 0x07, 0x78, 0x3b, 0xb9, 0x51, 0xae, 0x40, 0xb7, 0x15, 0x72, 0xc6, 0xf9, 0x36, 0xe3, 0xe1, 0x05, + 0xa0, 0x8d, 0x1f, 0x3d, 0xc7, 0x15, 0xd8, 0x13, 0xd0, 0x16, 0x3d, 0xe3, 0xcd, 0x76, 0x99, 0x26, + 0xcf, 0x72, 0x77, 0x7b, 0x44, 0x49, 0xa7, 0xcb, 0xb3, 0x7d, 0x00, 0x00, 0x00, 0xff, 0xff, 0x02, + 0x10, 0x64, 0xec, 0x62, 0x01, 0x00, 0x00, } diff --git a/groupcachepb/groupcache.proto b/groupcachepb/groupcache.proto index c14ab2c..a24b410 100644 --- a/groupcachepb/groupcache.proto +++ b/groupcachepb/groupcache.proto @@ -29,6 +29,13 @@ message GetResponse { optional int64 expire = 3; } +message SetRequest { + required string group = 1; + required string key = 2; + optional bytes value = 3; + optional int64 expire = 4; +} + service GroupCache { rpc Get(GetRequest) returns (GetResponse) { }; diff --git a/http.go b/http.go index c7fc689..b1dc10a 100644 --- a/http.go +++ b/http.go @@ -26,6 +26,7 @@ import ( "net/url" "strings" "sync" + "time" "github.com/golang/protobuf/proto" "github.com/mailgun/groupcache/v2/consistenthash" @@ -191,6 +192,34 @@ func (p *HTTPPool) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + // The read the body and set the key value + if r.Method == http.MethodPut { + defer r.Body.Close() + b := bufferPool.Get().(*bytes.Buffer) + b.Reset() + defer bufferPool.Put(b) + _, err := io.Copy(b, r.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + var out pb.SetRequest + err = proto.Unmarshal(b.Bytes(), &out) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + var expire time.Time + if out.Expire != nil && *out.Expire != 0 { + expire = time.Unix(*out.Expire/int64(time.Second), *out.Expire%int64(time.Second)) + } + + group.localSet(*out.Key, out.Value, expire, &group.mainCache) + return + } + var b []byte value := AllocatingByteSliceSink(&b) @@ -225,7 +254,6 @@ type httpGetter struct { baseURL string } -// GetURL func (p *httpGetter) GetURL() string { return p.baseURL } @@ -234,14 +262,19 @@ var bufferPool = sync.Pool{ New: func() interface{} { return new(bytes.Buffer) }, } -func (h *httpGetter) makeRequest(ctx context.Context, method string, in *pb.GetRequest, out *http.Response) error { +type request interface { + GetGroup() string + GetKey() string +} + +func (h *httpGetter) makeRequest(ctx context.Context, m string, in request, b io.Reader, out *http.Response) error { u := fmt.Sprintf( "%v%v/%v", h.baseURL, url.QueryEscape(in.GetGroup()), url.QueryEscape(in.GetKey()), ) - req, err := http.NewRequestWithContext(ctx, method, u, nil) + req, err := http.NewRequestWithContext(ctx, m, u, b) if err != nil { return err } @@ -261,7 +294,7 @@ func (h *httpGetter) makeRequest(ctx context.Context, method string, in *pb.GetR func (h *httpGetter) Get(ctx context.Context, in *pb.GetRequest, out *pb.GetResponse) error { var res http.Response - if err := h.makeRequest(ctx, http.MethodGet, in, &res); err != nil { + if err := h.makeRequest(ctx, http.MethodGet, in, nil, &res); err != nil { return err } defer res.Body.Close() @@ -282,9 +315,30 @@ func (h *httpGetter) Get(ctx context.Context, in *pb.GetRequest, out *pb.GetResp return nil } -func (h *httpGetter) Remove(ctx context.Context, in *pb.GetRequest) error { +func (h *httpGetter) Set(ctx context.Context, in *pb.SetRequest) error { + body, err := proto.Marshal(in) + if err != nil { + return fmt.Errorf("while marshaling SetRequest body: %w", err) + } var res http.Response - if err := h.makeRequest(ctx, http.MethodDelete, in, &res); err != nil { + if err := h.makeRequest(ctx, http.MethodPut, in, bytes.NewReader(body), &res); err != nil { + return err + } + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + body, err := ioutil.ReadAll(res.Body) + if err != nil { + return fmt.Errorf("while reading body response: %v", res.Status) + } + return fmt.Errorf("server returned status %d: %s", res.StatusCode, body) + } + return nil +} + +func (h *httpGetter) Remove(ctx context.Context, in *pb.GetRequest) error { + var res http.Response + if err := h.makeRequest(ctx, http.MethodDelete, in, nil, &res); err != nil { return err } defer res.Body.Close() diff --git a/http_test.go b/http_test.go index d3b40d4..86b6d34 100644 --- a/http_test.go +++ b/http_test.go @@ -17,6 +17,7 @@ limitations under the License. package groupcache import ( + "bytes" "context" "errors" "flag" @@ -75,6 +76,7 @@ func TestHTTPPool(t *testing.T) { "--test_server_addr="+ts.URL, ) cmds = append(cmds, cmd) + cmd.Stdout = os.Stdout wg.Add(1) if err := cmd.Start(); err != nil { t.Fatal("failed to start child process: ", err) @@ -151,6 +153,27 @@ func TestHTTPPool(t *testing.T) { if serverHits != 2 { t.Error("expected serverHits to be '2'") } + + key = "setMyTestKey" + setValue := []byte("test set") + // Add the key to the cache, optionally updating our local hot cache + if err := g.Set(ctx, key, setValue, time.Time{}, false); err != nil { + t.Fatal(err) + } + + // Get the key + var getValue ByteView + if err := g.Get(ctx, key, ByteViewSink(&getValue)); err != nil { + t.Fatal(err) + } + + if serverHits != 2 { + t.Errorf("expected serverHits to be '3' got '%d'", serverHits) + } + + if !bytes.Equal(setValue, getValue.ByteSlice()) { + t.Fatal(errors.New(fmt.Sprintf("incorrect value retrieved after set: %s", getValue))) + } } func testKeys(n int) (keys []string) { diff --git a/peers.go b/peers.go index 18d8659..39fd76a 100644 --- a/peers.go +++ b/peers.go @@ -28,6 +28,7 @@ import ( type ProtoGetter interface { Get(context context.Context, in *pb.GetRequest, out *pb.GetResponse) error Remove(context context.Context, in *pb.GetRequest) error + Set(context context.Context, in *pb.SetRequest) error // GetURL returns the peer URL GetURL() string } diff --git a/sinks.go b/sinks.go index 1c04ee8..894fecf 100644 --- a/sinks.go +++ b/sinks.go @@ -170,7 +170,6 @@ func ProtoSink(m proto.Message) Sink { type protoSink struct { dst proto.Message // authoritative value typ string - ttl time.Duration v ByteView // encoded }