From d875fcad3042d5f8a3cf1cd26f66174ba087da6e Mon Sep 17 00:00:00 2001 From: Aaron Boodman Date: Tue, 24 Sep 2013 01:38:38 -0700 Subject: [PATCH] Clean up testing in share_test.go Change-Id: I481d440636590086f11d8e97d0eb6fcf31f4a097 --- pkg/server/share.go | 128 ++++++++++++++++++++++++++------------- pkg/server/share_test.go | 108 ++++++++++++++++++--------------- 2 files changed, 145 insertions(+), 91 deletions(-) diff --git a/pkg/server/share.go b/pkg/server/share.go index 5162ec28c..8edaae388 100644 --- a/pkg/server/share.go +++ b/pkg/server/share.go @@ -37,6 +37,48 @@ import ( "camlistore.org/pkg/schema" ) +type responseType int + +const ( + badRequest responseType = iota + unauthorizedRequest +) + +type errorCode int + +const ( + assembleNonTransitive errorCode = iota + invalidMethod + invalidURL + invalidVia + shareBlobInvalid + shareBlobTooLarge + shareExpired + shareFetchFailed + shareReadFailed + shareTargetInvalid + shareNotTransitive + viaChainFetchFailed + viaChainInvalidLink + viaChainReadFailed +) + +type shareError struct { + code errorCode + response responseType + message string +} + +func (e *shareError) Error() string { + return e.message +} + +func unauthorized(code errorCode, format string, args ...interface{}) *shareError { + return &shareError{ + code: code, response: unauthorizedRequest, message: fmt.Sprintf(format, args...), + } +} + const fetchFailureDelay = 200 * time.Millisecond // ShareHandler handles the requests for "share" (and shared) blobs. @@ -72,10 +114,9 @@ func newShareFromConfig(ld blobserver.Loader, conf jsonconfig.Obj) (h http.Handl // Unauthenticated user. Be paranoid. func handleGetViaSharing(conn http.ResponseWriter, req *http.Request, - blobRef blob.Ref, fetcher blob.StreamingFetcher) { + blobRef blob.Ref, fetcher blob.StreamingFetcher) error { if !httputil.IsGet(req) { - httputil.BadRequestError(conn, "Invalid method") - return + return &shareError{code: invalidMethod, response: badRequest, message: "Invalid method"} } viaPathOkay := false @@ -94,8 +135,7 @@ func handleGetViaSharing(conn http.ResponseWriter, req *http.Request, if br, ok := blob.Parse(vs); ok { viaBlobs = append(viaBlobs, br) } else { - httputil.BadRequestError(conn, "Malformed blobref in via param") - return + return &shareError{code: invalidVia, response: badRequest, message: "Malformed blobref in via param"} } } } @@ -109,44 +149,31 @@ func handleGetViaSharing(conn http.ResponseWriter, req *http.Request, case 0: file, size, err := fetcher.FetchStreaming(br) if err != nil { - log.Printf("Fetch chain 0 of %s failed: %v", br.String(), err) - auth.SendUnauthorized(conn, req) - return + return unauthorized(shareFetchFailed, "Fetch chain 0 of %s failed: %v", br, err) } defer file.Close() if size > schema.MaxSchemaBlobSize { - log.Printf("Fetch chain 0 of %s too large", br.String()) - auth.SendUnauthorized(conn, req) - return + return unauthorized(shareBlobTooLarge, "Fetch chain 0 of %s too large", br) } blob, err := schema.BlobFromReader(br, file) if err != nil { - log.Printf("Can't create a blob from %v: %v", br.String(), err) - auth.SendUnauthorized(conn, req) - return + return unauthorized(shareReadFailed, "Can't create a blob from %v: %v", br, err) } share, ok := blob.AsShare() if !ok { - log.Printf("Fetch chain 0 of %s wasn't a valid Share", br.String()) - auth.SendUnauthorized(conn, req) - return + return unauthorized(shareBlobInvalid, "Fetch chain 0 of %s wasn't a valid Share", br) } if share.IsExpired() { - log.Print("Share is expired") - auth.SendUnauthorized(conn, req) - return + return unauthorized(shareExpired, "Share is expired") } if len(fetchChain) > 1 && fetchChain[1].String() != share.Target().String() { - log.Printf("Fetch chain 0->1 (%s -> %q) unauthorized, expected hop to %q", - br.String(), fetchChain[1].String(), share.Target().String()) - auth.SendUnauthorized(conn, req) - return + return unauthorized(shareTargetInvalid, + "Fetch chain 0->1 (%s -> %q) unauthorized, expected hop to %q", + br, fetchChain[1], share.Target()) } isTransitive = share.IsTransitive() if len(fetchChain) > 2 && !isTransitive { - log.Print("Share is not transitive") - auth.SendUnauthorized(conn, req) - return + return unauthorized(shareNotTransitive, "Share is not transitive") } case len(fetchChain) - 1: // Last one is fine (as long as its path up to here has been proven, and it's @@ -155,32 +182,26 @@ func handleGetViaSharing(conn http.ResponseWriter, req *http.Request, default: file, _, err := fetcher.FetchStreaming(br) if err != nil { - log.Printf("Fetch chain %d of %s failed: %v", i, br.String(), err) - auth.SendUnauthorized(conn, req) - return + return unauthorized(viaChainFetchFailed, "Fetch chain %d of %s failed: %v", i, br, err) } defer file.Close() lr := io.LimitReader(file, schema.MaxSchemaBlobSize) slurpBytes, err := ioutil.ReadAll(lr) if err != nil { - log.Printf("Fetch chain %d of %s failed in slurp: %v", i, br.String(), err) - auth.SendUnauthorized(conn, req) - return + return unauthorized(viaChainReadFailed, + "Fetch chain %d of %s failed in slurp: %v", i, br, err) } saught := fetchChain[i+1].String() if bytes.Index(slurpBytes, []byte(saught)) == -1 { - log.Printf("Fetch chain %d of %s failed; no reference to %s", - i, br.String(), saught) - auth.SendUnauthorized(conn, req) - return + return unauthorized(viaChainInvalidLink, + "Fetch chain %d of %s failed; no reference to %s", i, br, saught) } } } if assemble, _ := strconv.ParseBool(req.FormValue("assemble")); assemble { if !isTransitive { - auth.SendUnauthorized(conn, req) - return + return unauthorized(assembleNonTransitive, "Cannot assemble non-transitive share") } dh := &DownloadHandler{ Fetcher: fetcher, @@ -191,15 +212,36 @@ func handleGetViaSharing(conn http.ResponseWriter, req *http.Request, gethandler.ServeBlobRef(conn, req, blobRef, fetcher) } viaPathOkay = true + return nil } -func (h *shareHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { +func (h *shareHandler) serveHTTP(rw http.ResponseWriter, req *http.Request) error { + var err error pathSuffix := httputil.PathSuffix(req) + if len(pathSuffix) == 0 { + // This happens during testing because we don't go through PrefixHandler + pathSuffix = strings.TrimLeft(req.URL.Path, "/") + } pathParts := strings.SplitN(pathSuffix, "/", 2) blobRef, ok := blob.Parse(pathParts[0]) if !ok { - http.Error(rw, fmt.Sprintf("Malformed share pathSuffix: %s", pathSuffix), 400) - return + err = &shareError{code: invalidURL, response: badRequest, + message: fmt.Sprintf("Malformed share pathSuffix: %s", pathSuffix)} + } else { + err = handleGetViaSharing(rw, req, blobRef, h.fetcher) } - handleGetViaSharing(rw, req, blobRef, h.fetcher) + if se, ok := err.(*shareError); ok { + switch se.response { + case badRequest: + httputil.BadRequestError(rw, err.Error()) + case unauthorizedRequest: + log.Print(err) + auth.SendUnauthorized(rw, req) + } + } + return err +} + +func (h *shareHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + h.serveHTTP(rw, req) } diff --git a/pkg/server/share_test.go b/pkg/server/share_test.go index 8cbb0a3ef..73fefa4cc 100644 --- a/pkg/server/share_test.go +++ b/pkg/server/share_test.go @@ -1,5 +1,5 @@ /* -Copyright 2013 Google Inc. +Copyright 2013 The Camlistore Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,7 +17,7 @@ limitations under the License. package server import ( - "log" + "fmt" "net/http" "net/http/httptest" "strings" @@ -25,25 +25,34 @@ import ( "time" "camlistore.org/pkg/blob" - "camlistore.org/pkg/httputil" + "camlistore.org/pkg/blobserver" "camlistore.org/pkg/schema" "camlistore.org/pkg/test" - . "camlistore.org/pkg/test/asserts" ) func TestHandleGetViaSharing(t *testing.T) { - // TODO(aa): It would be good if we could test that we are failing for - // the right reason for all of these (some kind of internal error code). - sto := &test.Fetcher{} - handler := &httputil.PrefixHandler{"/", &shareHandler{sto}} - wr := &httptest.ResponseRecorder{} + handler := &shareHandler{fetcher: sto} + var wr *httptest.ResponseRecorder - get := func(path string) *httptest.ResponseRecorder { + putRaw := func(ref blob.Ref, data string) { + if _, err := blobserver.Receive(sto, ref, strings.NewReader(data)); err != nil { + t.Fatal(err) + } + } + + put := func(blob *schema.Blob) { + putRaw(blob.BlobRef(), blob.JSON()) + } + + get := func(path string) *shareError { wr = httptest.NewRecorder() req, _ := http.NewRequest("GET", "http://unused/"+path, nil) - handler.ServeHTTP(wr, req) - return wr + err := handler.serveHTTP(wr, req) + if err != nil { + return err.(*shareError) + } + return nil } content := "monkey" @@ -59,53 +68,56 @@ func TestHandleGetViaSharing(t *testing.T) { SetSigner(blob.SHA1FromString("irrelevant")). SetRawStringField("camliSig", "alsounused") - log.Print("Should fail because first link does not exist") - get(share.Blob().BlobRef().String()) - ExpectInt(t, 401, wr.Code, "") + var err *shareError - log.Print("Should fail because share target does not match next link") - sto.ReceiveBlob(share.Blob().BlobRef(), strings.NewReader(share.Blob().JSON())) - get(contentRef.String() + "?via=" + share.Blob().BlobRef().String()) - ExpectInt(t, 401, wr.Code, "") + if err = get(share.Blob().BlobRef().String()); err == nil || err.code != shareFetchFailed { + t.Error("Expected missing blob error") + } - log.Print("Should fail because first link is not a share") - sto.ReceiveBlob(linkRef, strings.NewReader(link)) - get(linkRef.String()) - ExpectInt(t, 401, wr.Code, "") - log.Print("Should successfully fetch share") - get(share.Blob().BlobRef().String()) - ExpectInt(t, 200, wr.Code, "") + put(share.Blob()) + if err = get(fmt.Sprintf("%s?via=%s", contentRef, share.Blob().BlobRef())); err == nil || err.code != shareTargetInvalid { + t.Error("Expected invalid target error") + } - log.Print("Should successfully fetch link via share") - get(linkRef.String() + "?via=" + share.Blob().BlobRef().String()) - ExpectInt(t, 200, wr.Code, "") + putRaw(linkRef, link) + if err = get(linkRef.String()); err == nil || err.code != shareReadFailed { + t.Error("Expected invalid share blob error") + } - log.Print("Should fail because share is not transitive") - get(contentRef.String() + "?via=" + share.Blob().BlobRef().String() + "," + linkRef.String()) - ExpectInt(t, 401, wr.Code, "") + if err = get(share.Blob().BlobRef().String()); err != nil { + t.Error("Expected to successfully fetch share, but got: %s", err) + } + + if err = get(fmt.Sprintf("%s?via=%s", linkRef, share.Blob().BlobRef())); err != nil { + t.Error("Expected to successfully fetch link via share, but got: %s", err) + } + + if err = get(fmt.Sprintf("%s?via=%s,%s", contentRef, share.Blob().BlobRef(), linkRef)); err == nil || err.code != shareNotTransitive { + t.Error("Expected share not transitive error") + } - log.Print("Should fail because link content does not contain target") share.SetShareIsTransitive(true) - sto.ReceiveBlob(share.Blob().BlobRef(), strings.NewReader(share.Blob().JSON())) - get(linkRef.String() + "?via=" + share.Blob().BlobRef().String() + "," + linkRef.String()) - ExpectInt(t, 401, wr.Code, "") + put(share.Blob()) + if err = get(fmt.Sprintf("%s?via=%s,%s", linkRef, share.Blob().BlobRef(), linkRef)); err == nil || err.code != viaChainInvalidLink { + t.Error("Expected via chain invalid link err") + } - log.Print("Should successfully fetch content via link via share") - sto.ReceiveBlob(contentRef, strings.NewReader(content)) - get(contentRef.String() + "?via=" + share.Blob().BlobRef().String() + "," + linkRef.String()) - ExpectInt(t, 200, wr.Code, "") + putRaw(contentRef, content) + if err = get(fmt.Sprintf("%s?via=%s,%s", contentRef, share.Blob().BlobRef(), linkRef)); err != nil { + t.Error("Expected to succesfully fetch via link via share, but got: %s", err) + } - log.Print("Should fail because share is expired") share.SetShareExpiration(time.Now().Add(-time.Duration(10) * time.Minute)) - sto.ReceiveBlob(share.Blob().BlobRef(), strings.NewReader(share.Blob().JSON())) - get(contentRef.String() + "?via=" + share.Blob().BlobRef().String() + "," + linkRef.String()) - ExpectInt(t, 401, wr.Code, "") + put(share.Blob()) + if err = get(fmt.Sprintf("%s?via=%s,%s", contentRef, share.Blob().BlobRef(), linkRef)); err == nil || err.code != shareExpired { + t.Error("Expected share expired error") + } - log.Print("Should succeed because share has not expired") share.SetShareExpiration(time.Now().Add(time.Duration(10) * time.Minute)) - sto.ReceiveBlob(share.Blob().BlobRef(), strings.NewReader(share.Blob().JSON())) - get(contentRef.String() + "?via=" + share.Blob().BlobRef().String() + "," + linkRef.String()) - ExpectInt(t, 200, wr.Code, "") + put(share.Blob()) + if err = get(fmt.Sprintf("%s?via=%s,%s", contentRef, share.Blob().BlobRef(), linkRef)); err != nil { + t.Error("Expected to successfully fetch unexpired share, but got: %s", err) + } // TODO(aa): assemble }