diff --git a/http.go b/http.go new file mode 100644 index 0000000..db3ac59 --- /dev/null +++ b/http.go @@ -0,0 +1,187 @@ +/* +Copyright 2013 Google Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package groupcache + +import ( + "fmt" + "hash/crc32" + "io/ioutil" + "net/http" + "net/url" + "strings" + "sync" + + "code.google.com/p/goprotobuf/proto" + + pb "github.com/golang/groupcache/groupcachepb" +) + +// TODO: make this configurable? +const defaultBasePath = "/_groupcache/" + +// HTTPPool implements PeerPicker for a pool of HTTP peers. +type HTTPPool struct { + // Context optionally specifies a context for the server to use when it + // receives a request. + // If nil, the server uses a nil Context. + Context func(*http.Request) Context + + // Transport optionally specifies an http.RoundTripper for the client + // to use when it makes a request. + // If nil, the client uses http.DefaultTransport. + Transport func(Context) http.RoundTripper + + // base path including leading and trailing slash, e.g. "/_groupcache/" + basePath string + + // this peer's base URL, e.g. "https://example.net:8000" + self string + + mu sync.Mutex + peers []string +} + +var httpPoolMade bool + +// NewHTTPPool initializes an HTTP pool of peers. +// It registers itself as a PeerPicker and as an HTTP handler with the +// http.DefaultServeMux. +// The self argument be a valid base URL that points to the current server, +// for example "http://example.net:8000". +func NewHTTPPool(self string) *HTTPPool { + if httpPoolMade { + panic("groupcache: NewHTTPPool must be called only once") + } + httpPoolMade = true + p := &HTTPPool{basePath: defaultBasePath, self: self} + RegisterPeerPicker(func() PeerPicker { return p }) + http.Handle(defaultBasePath, p) + return p +} + +// Set updates the pool's list of peers. +// Each peer value should be a valid base URL, +// for example "http://example.net:8000". +func (p *HTTPPool) Set(peers ...string) { + p.mu.Lock() + defer p.mu.Unlock() + p.peers = append([]string{}, peers...) +} + +func (p *HTTPPool) PickPeer(key string) (ProtoGetter, bool) { + // TODO: make checksum implementation pluggable + h := crc32.Checksum([]byte(key), crc32.IEEETable) + p.mu.Lock() + defer p.mu.Unlock() + if len(p.peers) == 0 { + return nil, false + } + if peer := p.peers[int(h)%len(p.peers)]; peer != p.self { + // TODO: pre-build a slice of *httpGetter when Set() + // is called to avoid these two allocations. + return &httpGetter{p.Transport, peer + p.basePath}, true + } + return nil, false +} + +func (p *HTTPPool) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Parse request. + if !strings.HasPrefix(r.URL.Path, p.basePath) { + panic("HTTPPool serving unexpected path: " + r.URL.Path) + } + parts := strings.SplitN(r.URL.Path[len(p.basePath):], "/", 2) + if len(parts) != 2 { + http.Error(w, "bad request", http.StatusBadRequest) + return + } + groupName, err := url.QueryUnescape(parts[0]) + if err != nil { + http.Error(w, "decoding group: "+err.Error(), http.StatusBadRequest) + return + } + key, err := url.QueryUnescape(parts[1]) + if err != nil { + http.Error(w, "decoding key: "+err.Error(), http.StatusBadRequest) + return + } + + // Fetch the value for this group/key. + group := GetGroup(groupName) + if group == nil { + http.Error(w, "no such group: "+groupName, http.StatusNotFound) + return + } + var ctx Context + if p.Context != nil { + ctx = p.Context(r) + } + var value []byte + err = group.Get(ctx, key, AllocatingByteSliceSink(&value)) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // Write the value to the response body as a proto message. + body, err := proto.Marshal(&pb.GetResponse{Value: value}) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/x-protobuf") + w.Write(body) +} + +type httpGetter struct { + transport func(Context) http.RoundTripper + baseURL string +} + +func (h *httpGetter) Get(context Context, in *pb.GetRequest, out *pb.GetResponse) error { + u := fmt.Sprintf( + "%v%v/%v", + h.baseURL, + url.QueryEscape(in.GetGroup()), + url.QueryEscape(in.GetKey()), + ) + req, err := http.NewRequest("GET", u, nil) + if err != nil { + return err + } + tr := http.DefaultTransport + if h.transport != nil { + tr = h.transport(context) + } + res, err := tr.RoundTrip(req) + if err != nil { + return err + } + if res.StatusCode != http.StatusOK { + return fmt.Errorf("server returned: %v", res.Status) + } + defer res.Body.Close() + // TODO: avoid this garbage. + b, err := ioutil.ReadAll(res.Body) + if err != nil { + return fmt.Errorf("reading response body: %v", err) + } + err = proto.Unmarshal(b, out) + if err != nil { + return fmt.Errorf("decoding response body: %v", err) + } + return nil +} diff --git a/http_test.go b/http_test.go new file mode 100644 index 0000000..279bcbf --- /dev/null +++ b/http_test.go @@ -0,0 +1,166 @@ +/* +Copyright 2013 Google Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package groupcache + +import ( + "errors" + "flag" + "log" + "net" + "net/http" + "os" + "os/exec" + "strconv" + "strings" + "sync" + "testing" + "time" +) + +var ( + peerAddrs = flag.String("test_peer_addrs", "", "Comma-separated list of peer addresses; used by TestHTTPPool") + peerIndex = flag.Int("test_peer_index", -1, "Index of which peer this child is; used by TestHTTPPool") + peerChild = flag.Bool("test_peer_child", false, "True if running as a child process; used by TestHTTPPool") +) + +func TestHTTPPool(t *testing.T) { + if *peerChild { + beChildForTestHTTPPool() + os.Exit(0) + } + + const ( + nChild = 4 + nGets = 100 + ) + + var childAddr []string + for i := 0; i < nChild; i++ { + childAddr = append(childAddr, pickFreeAddr(t)) + } + + var cmds []*exec.Cmd + var wg sync.WaitGroup + for i := 0; i < nChild; i++ { + cmd := exec.Command(os.Args[0], + "--test.run=TestHTTPPool", + "--test_peer_child", + "--test_peer_addrs="+strings.Join(childAddr, ","), + "--test_peer_index="+strconv.Itoa(i), + ) + cmds = append(cmds, cmd) + wg.Add(1) + if err := cmd.Start(); err != nil { + t.Fatal("failed to start child process: ", err) + } + go awaitAddrReady(t, childAddr[i], &wg) + } + defer func() { + for i := 0; i < nChild; i++ { + if cmds[i].Process != nil { + cmds[i].Process.Kill() + } + } + }() + wg.Wait() + + // Use a dummy self address so that we don't handle gets in-process. + p := NewHTTPPool("should-be-ignored") + p.Set(addrToURL(childAddr)...) + + // Dummy getter function. Gets should go to children only. + // The only time this process will handle a get is when the + // children can't be contacted for seome reason. + getter := GetterFunc(func(ctx Context, key string, dest Sink) error { + return errors.New("parent getter called; something's wrong") + }) + g := NewGroup("httpPoolTest", 1<<20, getter) + + for _, key := range testKeys(nGets) { + var value string + if err := g.Get(nil, key, StringSink(&value)); err != nil { + t.Fatal(err) + } + if suffix := ":" + key; !strings.HasSuffix(value, suffix) { + t.Errorf("Get(%q) = %q, want value ending in %q", key, value, suffix) + } + t.Logf("Get key=%q, value=%q (peer:key)", key, value) + } +} + +func testKeys(n int) (keys []string) { + keys = make([]string, n) + for i := range keys { + keys[i] = strconv.Itoa(i) + } + return +} + +func beChildForTestHTTPPool() { + addrs := strings.Split(*peerAddrs, ",") + + p := NewHTTPPool("http://" + addrs[*peerIndex]) + p.Set(addrToURL(addrs)...) + + getter := GetterFunc(func(ctx Context, key string, dest Sink) error { + dest.SetString(strconv.Itoa(*peerIndex) + ":" + key) + return nil + }) + NewGroup("httpPoolTest", 1<<20, getter) + + log.Fatal(http.ListenAndServe(addrs[*peerIndex], p)) +} + +// This is racy. Another process could swoop in and steal the port between the +// call to this function and the next listen call. Should be okay though. +// The proper way would be to pass the l.File() as ExtraFiles to the child +// process, and then close your copy once the child starts. +func pickFreeAddr(t *testing.T) string { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer l.Close() + return l.Addr().String() +} + +func addrToURL(addr []string) []string { + url := make([]string, len(addr)) + for i := range addr { + url[i] = "http://" + addr[i] + } + return url +} + +func awaitAddrReady(t *testing.T, addr string, wg *sync.WaitGroup) { + defer wg.Done() + const max = 1 * time.Second + tries := 0 + for { + tries++ + c, err := net.Dial("tcp", addr) + if err == nil { + c.Close() + return + } + delay := time.Duration(tries) * 25 * time.Millisecond + if delay > max { + delay = max + } + time.Sleep(delay) + } +}