diff --git a/build.pl b/build.pl index c54e52171..177a32529 100755 --- a/build.pl +++ b/build.pl @@ -307,6 +307,7 @@ TARGET: lib/go/blobserver/handlers - server/go/httputil - lib/go/blobserver - lib/go/httprange + - lib/go/testing TARGET: lib/go/blobserver/localdisk - lib/go/blobref diff --git a/doc/protocol/blob-enumerate-protocol.txt b/doc/protocol/blob-enumerate-protocol.txt index 7245fe016..b4d50e471 100644 --- a/doc/protocol/blob-enumerate-protocol.txt +++ b/doc/protocol/blob-enumerate-protocol.txt @@ -26,7 +26,8 @@ URL GET parameters: feature), then the server will return immediately if any blobs or available, else it will wait for this number of seconds. - This option isn't supported with 'after'. + It is an error to send this option with a non- + zero value along with the 'after' option. The server's reply must include "canLongPoll" set to true if the server supports this feature. Even if the server @@ -35,7 +36,6 @@ URL GET parameters: requested by the client. - Response: HTTP/1.1 200 OK diff --git a/lib/go/blobserver/handlers/enumerate.go b/lib/go/blobserver/handlers/enumerate.go index 032776ee1..1924aa1e7 100644 --- a/lib/go/blobserver/handlers/enumerate.go +++ b/lib/go/blobserver/handlers/enumerate.go @@ -41,15 +41,33 @@ func CreateEnumerateHandler(storage blobserver.Storage, partition blobserver.Par } } +const errMsgMaxWaitSecWithAfter = "Can't use 'maxwaitsec' with 'after'.\n" + func handleEnumerateBlobs(conn http.ResponseWriter, req *http.Request, storage blobserver.BlobEnumerator, partition blobserver.Partition) { - limit, err := strconv.Atoui(req.FormValue("limit")) - if err != nil || limit > maxEnumerate { - limit = maxEnumerate + // Potential input parameters + formValueLimit := req.FormValue("limit") + formValueMaxWaitSec := req.FormValue("maxwaitsec") + formValueAfter := req.FormValue("after") + + var ( + err os.Error + limit uint + ) + if formValueLimit != "" { + limit, err = strconv.Atoui(formValueLimit) + if err != nil || limit > maxEnumerate { + limit = maxEnumerate + } } waitSeconds := 0 - if waitStr := req.FormValue("maxwaitsec"); waitStr != "" { - waitSeconds, _ = strconv.Atoi(waitStr) + if formValueMaxWaitSec != "" { + waitSeconds, _ = strconv.Atoi(formValueMaxWaitSec) + if waitSeconds != 0 && formValueAfter != "" { + conn.WriteHeader(http.StatusBadRequest) + fmt.Fprintf(conn, errMsgMaxWaitSecWithAfter) + return + } switch { case waitSeconds < 0: waitSeconds = 0 @@ -67,7 +85,7 @@ func handleEnumerateBlobs(conn http.ResponseWriter, req *http.Request, storage b blobch := make(chan *blobref.SizedBlobRef, 100) resultch := make(chan os.Error, 1) go func() { - resultch <- storage.EnumerateBlobs(blobch, partition, req.FormValue("after"), limit+1, waitSeconds) + resultch <- storage.EnumerateBlobs(blobch, partition, formValueAfter, limit+1, waitSeconds) }() after := "" diff --git a/lib/go/blobserver/handlers/enumerate_test.go b/lib/go/blobserver/handlers/enumerate_test.go new file mode 100644 index 000000000..0872a3285 --- /dev/null +++ b/lib/go/blobserver/handlers/enumerate_test.go @@ -0,0 +1,136 @@ +/* +Copyright 2011 Google Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package handlers + +import ( + "bufio" + "bytes" + "camli/blobref" + "camli/blobserver" + . "camli/testing" + "http" + "io" + "os" + "testing" +) + +type responseWriterMethodCall struct { + method string + headerKey, headerValue string // if method == "SetHeader" + bytesWritten []byte // if method == "Write" + responseCode int // if method == "WriteHeader" +} + +type recordingResponseWriter struct { + log []*responseWriterMethodCall + status int + output *bytes.Buffer +} + +func (rw *recordingResponseWriter) RemoteAddr() string { + return "1.2.3.4" +} + +func (rw *recordingResponseWriter) UsingTLS() bool { + return false +} + +func (rw *recordingResponseWriter) SetHeader(k, v string) { + rw.log = append(rw.log, &responseWriterMethodCall{method: "SetHeader", headerKey: k, headerValue: v}) +} + +func (rw *recordingResponseWriter) Write(buf []byte) (int, os.Error) { + rw.log = append(rw.log, &responseWriterMethodCall{method: "Write", bytesWritten: buf}) + rw.output.Write(buf) + if rw.status == 0 { + rw.status = 200 + } + return len(buf), nil +} + +func (rw *recordingResponseWriter) WriteHeader(code int) { + rw.log = append(rw.log, &responseWriterMethodCall{method: "WriteHeader", responseCode: code}) + rw.status = code +} + +func (rw *recordingResponseWriter) Flush() { + rw.log = append(rw.log, &responseWriterMethodCall{method: "Flush"}) +} + +func (rw *recordingResponseWriter) Hijack() (io.ReadWriteCloser, *bufio.ReadWriter, os.Error) { + panic("Not supported") +} + +func NewRecordingWriter() *recordingResponseWriter { + return &recordingResponseWriter{ + output: &bytes.Buffer{}, + } +} + +func makeGetRequest(url string) *http.Request { + req := &http.Request{ + Method: "GET", + RawURL: url, + } + var err os.Error + req.URL, err = http.ParseURL(url) + if err != nil { + panic("Error parsing url: " + url) + } + return req +} + +type emptyEnumerator struct { +} + +func (ee *emptyEnumerator) EnumerateBlobs(dest chan *blobref.SizedBlobRef, + partition blobserver.Partition, + after string, + limit uint, + waitSeconds int) os.Error { + dest <- nil + return nil +} + +type enumerateInputTest struct { + name string + url string + expectedCode int + expectedBody string +} + +func TestEnumerateInput(t *testing.T) { + enumerator := &emptyEnumerator{} + + emptyOutput := "{\n \"blobs\": [\n\n ],\n \"canLongPoll\": true\n}\n" + + tests := []enumerateInputTest{ + {"no 'after' with 'maxwaitsec'", + "http://example.com/camli/enumerate-blobs?after=foo&maxwaitsec=1", 400, + errMsgMaxWaitSecWithAfter}, + {"'maxwaitsec' of 0 is okay with 'after'", + "http://example.com/camli/enumerate-blobs?after=foo&maxwaitsec=0", 200, + emptyOutput}, + } + for _, test := range tests { + wr := NewRecordingWriter() + req := makeGetRequest(test.url) + handleEnumerateBlobs(wr, req, enumerator, blobserver.DefaultPartition) + ExpectInt(t, test.expectedCode, wr.status, "response code for " + test.name) + ExpectString(t, test.expectedBody, wr.output.String(), "output for " + test.name) + } +} diff --git a/lib/go/testing/testing.go b/lib/go/testing/testing.go index 1607c9702..ac35b4a0e 100644 --- a/lib/go/testing/testing.go +++ b/lib/go/testing/testing.go @@ -34,13 +34,13 @@ func Assert(t *testing.T, got bool, what string) { func ExpectString(t *testing.T, expect, got string, what string) { if expect != got { - t.Errorf("%s: got %v; expected %v", what, got, expect) + t.Errorf("%s: got %q; expected %q", what, got, expect) } } func AssertString(t *testing.T, expect, got string, what string) { if expect != got { - t.Fatalf("%s: got %v; expected %v", what, got, expect) + t.Fatalf("%s: got %q; expected %q", what, got, expect) } }