diff --git a/http.go b/http.go index abc25a3..8f637d7 100644 --- a/http.go +++ b/http.go @@ -17,8 +17,9 @@ limitations under the License. package groupcache import ( + "bytes" "fmt" - "io/ioutil" + "io" "net/http" "net/url" "strings" @@ -51,8 +52,9 @@ type HTTPPool struct { // this peer's base URL, e.g. "https://example.net:8000" self string - mu sync.Mutex - peers *consistenthash.Map + mu sync.Mutex // guards peers and httpGetters + peers *consistenthash.Map + httpGetters map[string]*httpGetter // keyed by e.g. "http://10.0.0.2:8008" } // HTTPPoolOptions are the configurations of a HTTPPool. @@ -103,9 +105,10 @@ func NewHTTPPoolOpts(self string, o *HTTPPoolOptions) *HTTPPool { } p := &HTTPPool{ - basePath: opts.BasePath, - self: self, - peers: consistenthash.New(opts.Replicas, opts.HashFn), + basePath: opts.BasePath, + self: self, + peers: consistenthash.New(opts.Replicas, opts.HashFn), + httpGetters: make(map[string]*httpGetter), } RegisterPeerPicker(func() PeerPicker { return p }) return p @@ -119,6 +122,10 @@ func (p *HTTPPool) Set(peers ...string) { defer p.mu.Unlock() p.peers = consistenthash.New(defaultReplicas, nil) p.peers.Add(peers...) + p.httpGetters = make(map[string]*httpGetter, len(peers)) + for _, peer := range peers { + p.httpGetters[peer] = &httpGetter{transport: p.Transport, baseURL: peer + p.basePath} + } } func (p *HTTPPool) PickPeer(key string) (ProtoGetter, bool) { @@ -128,9 +135,7 @@ func (p *HTTPPool) PickPeer(key string) (ProtoGetter, bool) { return nil, false } if peer := p.peers.Get(key); peer != p.self { - // TODO: pre-build a slice of *httpGetter when Set() - // is called to avoid these two allocations. - return &httpGetter{p.Transport, peer + p.basePath}, true + return p.httpGetters[peer], true } return nil, false } @@ -182,6 +187,10 @@ type httpGetter struct { baseURL string } +var bufferPool = sync.Pool{ + New: func() interface{} { return new(bytes.Buffer) }, +} + func (h *httpGetter) Get(context Context, in *pb.GetRequest, out *pb.GetResponse) error { u := fmt.Sprintf( "%v%v/%v", @@ -205,12 +214,14 @@ func (h *httpGetter) Get(context Context, in *pb.GetRequest, out *pb.GetResponse if res.StatusCode != http.StatusOK { return fmt.Errorf("server returned: %v", res.Status) } - // TODO: avoid this garbage. - b, err := ioutil.ReadAll(res.Body) + b := bufferPool.Get().(*bytes.Buffer) + b.Reset() + defer bufferPool.Put(b) + _, err = io.Copy(b, res.Body) if err != nil { return fmt.Errorf("reading response body: %v", err) } - err = proto.Unmarshal(b, out) + err = proto.Unmarshal(b.Bytes(), out) if err != nil { return fmt.Errorf("decoding response body: %v", err) }