diff --git a/go.mod b/go.mod index 35a1857..1c6c15b 100644 --- a/go.mod +++ b/go.mod @@ -6,4 +6,4 @@ require ( github.com/sirupsen/logrus v1.6.0 ) -go 1.13 +go 1.15 diff --git a/groupcache.go b/groupcache.go index 88c4b02..915e265 100644 --- a/groupcache.go +++ b/groupcache.go @@ -116,7 +116,7 @@ func newGroup(name string, cacheBytes int64, getter Getter, peers PeerPicker) *G peers: peers, cacheBytes: cacheBytes, loadGroup: &singleflight.Group{}, - setGroup: &singleflight.Group{}, + setGroup: &singleflight.Group{}, removeGroup: &singleflight.Group{}, } if fn := newGroupHook; fn != nil { @@ -258,49 +258,30 @@ func (g *Group) Get(ctx context.Context, key string, dest Sink) error { return setSinkView(dest, value) } -func (g *Group) Set(ctx context.Context, key string, value []byte) error { +func (g *Group) Set(ctx context.Context, key string, value []byte, expire time.Time, hotCache bool) error { g.peersOnce.Do(g.initPeers) - _, err := g.setGroup.Do(key, func() (interface{}, error) { + if key == "" { + return errors.New("empty Set() key not allowed") + } - // Set to key owner first + _, err := g.setGroup.Do(key, func() (interface{}, error) { + // If remote peer owns this key owner, ok := g.peers.PickPeer(key) if ok { - if err := g.setFromPeer(ctx, owner, key, value); err != nil { + if err := g.setFromPeer(ctx, owner, key, value, expire); err != nil { return nil, err } - } - // Set to our cache next - g.localSet(key, value) - wg := sync.WaitGroup{} - errs := make(chan error) - - // Asynchronously add the key and value to all hot and main caches of peers - for _, peer := range g.peers.GetAll() { - // avoid adding to owner a second time - if peer == owner { - continue + // TODO(thrawn01): Not sure if this is useful outside of tests... + // maybe we should ALWAYS update the local cache? + if hotCache { + g.localSet(key, value, expire, &g.hotCache) } - - wg.Add(1) - go func(peer ProtoGetter) { - errs <- g.setFromPeer(ctx, peer, key, value) - wg.Done() - }(peer) + return nil, nil } - go func() { - wg.Wait() - close(errs) - }() - - // TODO(thrawn01): Should we report all errors? Reporting context - // cancelled error for each peer doesn't make much sense. - var err error - for e := range errs { - err = e - } - - return nil, err + // We own this key + g.localSet(key, value, expire, &g.mainCache) + return nil, nil }) return err } @@ -477,11 +458,16 @@ func (g *Group) getFromPeer(ctx context.Context, peer ProtoGetter, key string) ( return value, nil } -func (g *Group) setFromPeer(ctx context.Context, peer ProtoGetter, key string, value []byte) error { - req := &pb.GetRequest{ - Group: &g.name, - Key: &key, - Value: value, +func (g *Group) setFromPeer(ctx context.Context, peer ProtoGetter, k string, v []byte, e time.Time) error { + var expire int64 + if !e.IsZero() { + expire = e.UnixNano() + } + req := &pb.SetRequest{ + Expire: &expire, + Group: &g.name, + Key: &k, + Value: v, } return peer.Set(ctx, req) } @@ -506,19 +492,19 @@ func (g *Group) lookupCache(key string) (value ByteView, ok bool) { return } -func (g *Group) localSet(key string, value []byte) { +func (g *Group) localSet(key string, value []byte, expire time.Time, cache *cache) { if g.cacheBytes <= 0 { return } bv := ByteView{ b: value, - e: time.Time{}, + e: expire, } + // Ensure no requests are in flight g.loadGroup.Lock(func() { - g.hotCache.set(key, bv) - g.mainCache.set(key, bv) + g.populateCache(key, bv, cache) }) } @@ -640,15 +626,6 @@ func (c *cache) get(key string) (value ByteView, ok bool) { return vi.(ByteView), true } -func (c *cache) set(key string, value ByteView) { - c.mu.Lock() - defer c.mu.Unlock() - if c.lru == nil { - return - } - c.lru.Add(key, value, time.Now().Add(60*time.Minute)) // TODO: parameterize this -} - func (c *cache) remove(key string) { c.mu.Lock() defer c.mu.Unlock() diff --git a/groupcache_test.go b/groupcache_test.go index 0ed0a23..e87c62a 100644 --- a/groupcache_test.go +++ b/groupcache_test.go @@ -263,7 +263,7 @@ func (p *fakePeer) Get(_ context.Context, in *pb.GetRequest, out *pb.GetResponse return nil } -func (p *fakePeer) Set(_ context.Context, in *pb.GetRequest) error { +func (p *fakePeer) Set(_ context.Context, in *pb.SetRequest) error { p.hits++ if p.fail { return errors.New("simulated error from peer") diff --git a/groupcachepb/groupcache.pb.go b/groupcachepb/groupcache.pb.go index f6c1313..d6abd47 100644 --- a/groupcachepb/groupcache.pb.go +++ b/groupcachepb/groupcache.pb.go @@ -10,6 +10,7 @@ It is generated from these files: It has these top-level messages: GetRequest GetResponse + SetRequest */ package groupcachepb @@ -31,7 +32,6 @@ const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package type GetRequest struct { Group *string `protobuf:"bytes,1,req,name=group" json:"group,omitempty"` Key *string `protobuf:"bytes,2,req,name=key" json:"key,omitempty"` - Value []byte XXX_unrecognized []byte `json:"-"` } @@ -87,26 +87,69 @@ func (m *GetResponse) GetExpire() int64 { return 0 } +type SetRequest struct { + Group *string `protobuf:"bytes,1,req,name=group" json:"group,omitempty"` + Key *string `protobuf:"bytes,2,req,name=key" json:"key,omitempty"` + Value []byte `protobuf:"bytes,3,opt,name=value" json:"value,omitempty"` + Expire *int64 `protobuf:"varint,4,opt,name=expire" json:"expire,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *SetRequest) Reset() { *m = SetRequest{} } +func (m *SetRequest) String() string { return proto.CompactTextString(m) } +func (*SetRequest) ProtoMessage() {} +func (*SetRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} } + +func (m *SetRequest) GetGroup() string { + if m != nil && m.Group != nil { + return *m.Group + } + return "" +} + +func (m *SetRequest) GetKey() string { + if m != nil && m.Key != nil { + return *m.Key + } + return "" +} + +func (m *SetRequest) GetValue() []byte { + if m != nil { + return m.Value + } + return nil +} + +func (m *SetRequest) GetExpire() int64 { + if m != nil && m.Expire != nil { + return *m.Expire + } + return 0 +} + func init() { proto.RegisterType((*GetRequest)(nil), "groupcachepb.GetRequest") proto.RegisterType((*GetResponse)(nil), "groupcachepb.GetResponse") + proto.RegisterType((*SetRequest)(nil), "groupcachepb.SetRequest") } func init() { proto.RegisterFile("groupcache.proto", fileDescriptor0) } var fileDescriptor0 = []byte{ - // 197 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x12, 0x48, 0x2f, 0xca, 0x2f, - 0x2d, 0x48, 0x4e, 0x4c, 0xce, 0x48, 0xd5, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x41, 0x88, - 0x14, 0x24, 0x29, 0x99, 0x70, 0x71, 0xb9, 0xa7, 0x96, 0x04, 0xa5, 0x16, 0x96, 0xa6, 0x16, 0x97, - 0x08, 0x89, 0x70, 0xb1, 0x82, 0x65, 0x25, 0x18, 0x15, 0x98, 0x34, 0x38, 0x83, 0x20, 0x1c, 0x21, - 0x01, 0x2e, 0xe6, 0xec, 0xd4, 0x4a, 0x09, 0x26, 0xb0, 0x18, 0x88, 0xa9, 0x14, 0xc5, 0xc5, 0x0d, - 0xd6, 0x55, 0x5c, 0x90, 0x9f, 0x57, 0x9c, 0x0a, 0xd2, 0x56, 0x96, 0x98, 0x53, 0x9a, 0x2a, 0xc1, - 0xa8, 0xc0, 0xa8, 0xc1, 0x13, 0x04, 0xe1, 0x08, 0xc9, 0x72, 0x71, 0xe5, 0x66, 0xe6, 0x95, 0x96, - 0xa4, 0xc6, 0x17, 0x16, 0x14, 0x4b, 0x30, 0x29, 0x30, 0x6a, 0x30, 0x06, 0x71, 0x42, 0x44, 0x02, - 0x0b, 0x8a, 0x85, 0xc4, 0xb8, 0xd8, 0x52, 0x2b, 0x0a, 0x32, 0x8b, 0x52, 0x25, 0x98, 0x15, 0x18, - 0x35, 0x98, 0x83, 0xa0, 0x3c, 0x23, 0x2f, 0x2e, 0x2e, 0x77, 0x90, 0xb5, 0xce, 0x20, 0x17, 0x0a, - 0xd9, 0x70, 0x31, 0xbb, 0xa7, 0x96, 0x08, 0x49, 0xe8, 0x21, 0xbb, 0x5a, 0x0f, 0xe1, 0x64, 0x29, - 0x49, 0x2c, 0x32, 0x10, 0x67, 0x29, 0x31, 0x00, 0x02, 0x00, 0x00, 0xff, 0xff, 0xd4, 0x73, 0xe1, - 0xb8, 0xfe, 0x00, 0x00, 0x00, + // 215 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x94, 0x50, 0x31, 0x4b, 0xc5, 0x30, + 0x18, 0x34, 0x8d, 0x0a, 0xfd, 0xec, 0x50, 0x82, 0x48, 0x14, 0x84, 0x90, 0x29, 0x53, 0x07, 0x71, + 0x74, 0x73, 0x28, 0xb8, 0x19, 0x37, 0x17, 0x69, 0xcb, 0x87, 0x16, 0xb5, 0x49, 0x9b, 0x44, 0x7c, + 0xff, 0xfe, 0x91, 0xe6, 0x41, 0x3a, 0xbc, 0xe5, 0x6d, 0xb9, 0x3b, 0x2e, 0x77, 0xdf, 0x41, 0xfd, + 0xb9, 0x98, 0x60, 0x87, 0x6e, 0xf8, 0xc2, 0xc6, 0x2e, 0xc6, 0x1b, 0x56, 0x65, 0xc6, 0xf6, 0xf2, + 0x11, 0xa0, 0x45, 0xaf, 0x71, 0x0e, 0xe8, 0x3c, 0xbb, 0x86, 0x8b, 0x55, 0xe5, 0x44, 0x14, 0xaa, + 0xd4, 0x09, 0xb0, 0x1a, 0xe8, 0x37, 0xee, 0x78, 0xb1, 0x72, 0xf1, 0x29, 0xdf, 0xe1, 0x6a, 0x75, + 0x39, 0x6b, 0x26, 0x87, 0xd1, 0xf6, 0xd7, 0xfd, 0x04, 0xe4, 0x44, 0x10, 0x55, 0xe9, 0x04, 0xd8, + 0x3d, 0xc0, 0xef, 0x38, 0x05, 0x8f, 0x1f, 0xb3, 0x75, 0xbc, 0x10, 0x44, 0x11, 0x5d, 0x26, 0xe6, + 0xd5, 0x3a, 0x76, 0x03, 0x97, 0xf8, 0x6f, 0xc7, 0x05, 0x39, 0x15, 0x44, 0x51, 0x7d, 0x40, 0xb2, + 0x07, 0x78, 0x3b, 0xb9, 0x51, 0xae, 0x40, 0xb7, 0x15, 0x72, 0xc6, 0xf9, 0x36, 0xe3, 0xe1, 0x05, + 0xa0, 0x8d, 0x1f, 0x3d, 0xc7, 0x15, 0xd8, 0x13, 0xd0, 0x16, 0x3d, 0xe3, 0xcd, 0x76, 0x99, 0x26, + 0xcf, 0x72, 0x77, 0x7b, 0x44, 0x49, 0xa7, 0xcb, 0xb3, 0x7d, 0x00, 0x00, 0x00, 0xff, 0xff, 0x02, + 0x10, 0x64, 0xec, 0x62, 0x01, 0x00, 0x00, } diff --git a/groupcachepb/groupcache.proto b/groupcachepb/groupcache.proto index c14ab2c..a24b410 100644 --- a/groupcachepb/groupcache.proto +++ b/groupcachepb/groupcache.proto @@ -29,6 +29,13 @@ message GetResponse { optional int64 expire = 3; } +message SetRequest { + required string group = 1; + required string key = 2; + optional bytes value = 3; + optional int64 expire = 4; +} + service GroupCache { rpc Get(GetRequest) returns (GetResponse) { }; diff --git a/http.go b/http.go index 9e9d4ee..5a9d88b 100644 --- a/http.go +++ b/http.go @@ -26,6 +26,7 @@ import ( "net/url" "strings" "sync" + "time" "github.com/golang/protobuf/proto" "github.com/mailgun/groupcache/v2/consistenthash" @@ -191,6 +192,34 @@ func (p *HTTPPool) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + // The read the body and set the key value + if r.Method == http.MethodPut { + defer r.Body.Close() + b := bufferPool.Get().(*bytes.Buffer) + b.Reset() + defer bufferPool.Put(b) + _, err := io.Copy(b, r.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + var out pb.SetRequest + err = proto.Unmarshal(b.Bytes(), &out) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + var expire time.Time + if out.Expire != nil && *out.Expire != 0 { + expire = time.Unix(*out.Expire/int64(time.Second), *out.Expire%int64(time.Second)) + } + + group.localSet(*out.Key, out.Value, expire, &group.mainCache) + return + } + var b []byte value := AllocatingByteSliceSink(&b) @@ -225,7 +254,6 @@ type httpGetter struct { baseURL string } -// GetURL func (p *httpGetter) GetURL() string { return p.baseURL } @@ -234,14 +262,20 @@ var bufferPool = sync.Pool{ New: func() interface{} { return new(bytes.Buffer) }, } -func (h *httpGetter) makeRequest(ctx context.Context, method string, in *pb.GetRequest, out *http.Response) error { +type request interface { + GetGroup() string + GetKey() string +} + +func (h *httpGetter) makeRequest(ctx context.Context, m string, in request, b io.Reader, out *http.Response) error { u := fmt.Sprintf( "%v%v/%v", h.baseURL, url.QueryEscape(in.GetGroup()), url.QueryEscape(in.GetKey()), ) - req, err := http.NewRequest(method, u, nil) + + req, err := http.NewRequest(m, u, b) if err != nil { return err } @@ -264,7 +298,7 @@ func (h *httpGetter) makeRequest(ctx context.Context, method string, in *pb.GetR func (h *httpGetter) Get(ctx context.Context, in *pb.GetRequest, out *pb.GetResponse) error { var res http.Response - if err := h.makeRequest(ctx, http.MethodGet, in, &res); err != nil { + if err := h.makeRequest(ctx, http.MethodGet, in, nil, &res); err != nil { return err } defer res.Body.Close() @@ -285,9 +319,13 @@ func (h *httpGetter) Get(ctx context.Context, in *pb.GetRequest, out *pb.GetResp return nil } -func (h *httpGetter) Set(ctx context.Context, in *pb.GetRequest) error { +func (h *httpGetter) Set(ctx context.Context, in *pb.SetRequest) error { + body, err := proto.Marshal(in) + if err != nil { + return fmt.Errorf("while marshaling SetRequest body: %w", err) + } var res http.Response - if err := h.makeRequest(ctx, http.MethodPut, in, &res); err != nil { + if err := h.makeRequest(ctx, http.MethodPut, in, bytes.NewReader(body), &res); err != nil { return err } defer res.Body.Close() @@ -304,7 +342,7 @@ func (h *httpGetter) Set(ctx context.Context, in *pb.GetRequest) error { func (h *httpGetter) Remove(ctx context.Context, in *pb.GetRequest) error { var res http.Response - if err := h.makeRequest(ctx, http.MethodDelete, in, &res); err != nil { + if err := h.makeRequest(ctx, http.MethodDelete, in, nil, &res); err != nil { return err } defer res.Body.Close() diff --git a/http_test.go b/http_test.go index 55e6aef..86b6d34 100644 --- a/http_test.go +++ b/http_test.go @@ -17,6 +17,7 @@ limitations under the License. package groupcache import ( + "bytes" "context" "errors" "flag" @@ -75,6 +76,7 @@ func TestHTTPPool(t *testing.T) { "--test_server_addr="+ts.URL, ) cmds = append(cmds, cmd) + cmd.Stdout = os.Stdout wg.Add(1) if err := cmd.Start(); err != nil { t.Fatal("failed to start child process: ", err) @@ -152,20 +154,24 @@ func TestHTTPPool(t *testing.T) { t.Error("expected serverHits to be '2'") } - key = "setTestKey" - setValue := "test set" - var getValue string - // Add the key to the cache - if err := g.Set(ctx, key, []byte(setValue)); err != nil { + key = "setMyTestKey" + setValue := []byte("test set") + // Add the key to the cache, optionally updating our local hot cache + if err := g.Set(ctx, key, setValue, time.Time{}, false); err != nil { t.Fatal(err) } // Get the key - if err := g.Get(ctx, key, StringSink(&getValue)); err != nil { + var getValue ByteView + if err := g.Get(ctx, key, ByteViewSink(&getValue)); err != nil { t.Fatal(err) } - if setValue != getValue { + if serverHits != 2 { + t.Errorf("expected serverHits to be '3' got '%d'", serverHits) + } + + if !bytes.Equal(setValue, getValue.ByteSlice()) { t.Fatal(errors.New(fmt.Sprintf("incorrect value retrieved after set: %s", getValue))) } } diff --git a/peers.go b/peers.go index 299d060..39fd76a 100644 --- a/peers.go +++ b/peers.go @@ -28,7 +28,7 @@ import ( type ProtoGetter interface { Get(context context.Context, in *pb.GetRequest, out *pb.GetResponse) error Remove(context context.Context, in *pb.GetRequest) error - Set(context context.Context, in *pb.GetRequest) error + Set(context context.Context, in *pb.SetRequest) error // GetURL returns the peer URL GetURL() string } diff --git a/sinks.go b/sinks.go index 1c04ee8..894fecf 100644 --- a/sinks.go +++ b/sinks.go @@ -170,7 +170,6 @@ func ProtoSink(m proto.Message) Sink { type protoSink struct { dst proto.Message // authoritative value typ string - ttl time.Duration v ByteView // encoded }