diff --git a/cmd/camget/camget.go b/cmd/camget/camget.go index 9d7f1ac43..88650d6f5 100644 --- a/cmd/camget/camget.go +++ b/cmd/camget/camget.go @@ -125,6 +125,19 @@ func main() { if *flagVerbose { log.Printf("Using temp blob cache directory %s", diskCacheFetcher.Root) } + if *flagShared != "" { + diskCacheFetcher.SetCacheHitHook(func(br blob.Ref, rc io.ReadCloser) (io.ReadCloser, error) { + var buf bytes.Buffer + if err := cl.UpdateShareChain(br, io.TeeReader(rc, &buf)); err != nil { + rc.Close() + return nil, err + } + return struct { + io.Reader + io.Closer + }{io.MultiReader(&buf, rc), rc}, nil + }) + } for _, br := range items { if *flagGraph { diff --git a/pkg/cacher/cacher.go b/pkg/cacher/cacher.go index 5ae03b215..6bcab97fb 100644 --- a/pkg/cacher/cacher.go +++ b/pkg/cacher/cacher.go @@ -42,13 +42,35 @@ func NewCachingFetcher(cache blobserver.Cache, fetcher blob.Fetcher) *CachingFet type CachingFetcher struct { c blobserver.Cache sf blob.Fetcher + // cacheHitHook, if set, is called right after a cache hit. It is meant to add + // potential side-effects from calling the Fetcher that would have happened + // if we had had a cache miss. It is the responsibility of cacheHitHook to return + // a ReadCloser equivalent to the state that rc was given in. + cacheHitHook func(br blob.Ref, rc io.ReadCloser) (io.ReadCloser, error) g singleflight.Group } -func (cf *CachingFetcher) Fetch(br blob.Ref) (file io.ReadCloser, size uint32, err error) { - file, size, err = cf.c.Fetch(br) +// SetCacheHitHook sets a function that will modify the return values from Fetch +// in the case of a cache hit. +// Its purpose is to add potential side-effects from calling the Fetcher that would +// have happened if we had had a cache miss. It is the responsibility of fn to +// return a ReadCloser equivalent to the state that rc was given in. +func (cf *CachingFetcher) SetCacheHitHook(fn func(br blob.Ref, rc io.ReadCloser) (io.ReadCloser, error)) { + cf.cacheHitHook = fn +} + +func (cf *CachingFetcher) Fetch(br blob.Ref) (content io.ReadCloser, size uint32, err error) { + content, size, err = cf.c.Fetch(br) if err == nil { + if cf.cacheHitHook != nil { + rc, err := cf.cacheHitHook(br, content) + if err != nil { + content.Close() + return nil, 0, err + } + content = rc + } return } if err = cf.faultIn(br); err != nil { diff --git a/pkg/client/client.go b/pkg/client/client.go index 5b33ec6ab..860563942 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -126,7 +126,8 @@ type Client struct { // via maps the access path from a share root to a desired target. // It is non-nil when in "sharing" mode, where the Client is fetching // a share. - via map[string]string // target => via (target is referenced from via) + viaMu sync.RWMutex + via map[blob.Ref]blob.Ref // target => via (target is referenced from via) log *log.Logger // not nil httpGate *syncutil.Gate @@ -319,7 +320,7 @@ func NewFromShareRoot(shareBlobURL string, opts ...ClientOption) (c *Client, tar c.prefixv = m[1] c.isSharePrefix = true c.authMode = auth.None{} - c.via = make(map[string]string) + c.via = make(map[blob.Ref]blob.Ref) root = m[2] req := c.newRequest("GET", shareBlobURL, nil) @@ -329,7 +330,11 @@ func NewFromShareRoot(shareBlobURL string, opts ...ClientOption) (c *Client, tar } defer res.Body.Close() var buf bytes.Buffer - b, err := schema.BlobFromReader(blob.ParseOrZero(root), io.TeeReader(res.Body, &buf)) + rootbr, ok := blob.Parse(root) + if !ok { + return nil, blob.Ref{}, fmt.Errorf("invalid root blob ref for sharing: %q", root) + } + b, err := schema.BlobFromReader(rootbr, io.TeeReader(res.Body, &buf)) if err != nil { return nil, blob.Ref{}, fmt.Errorf("error parsing JSON from %s: %v , with response: %q", shareBlobURL, err, buf.Bytes()) } @@ -340,7 +345,7 @@ func NewFromShareRoot(shareBlobURL string, opts ...ClientOption) (c *Client, tar if !target.Valid() { return nil, blob.Ref{}, fmt.Errorf("no target.") } - c.via[target.String()] = root + c.via[target] = rootbr return c, target, nil } diff --git a/pkg/client/get.go b/pkg/client/get.go index c3beccbc8..212fbd985 100644 --- a/pkg/client/get.go +++ b/pkg/client/get.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" "io" + "log" "math" "net/http" "os" @@ -44,22 +45,21 @@ func (c *Client) FetchSchemaBlob(b blob.Ref) (*schema.Blob, error) { } func (c *Client) Fetch(b blob.Ref) (io.ReadCloser, uint32, error) { - return c.FetchVia(b, c.viaPathTo(b)) + return c.fetchVia(b, c.viaPathTo(b)) } func (c *Client) viaPathTo(b blob.Ref) (path []blob.Ref) { - if c.via == nil { - return nil - } - it := b.String() + c.viaMu.RLock() + defer c.viaMu.RUnlock() // Append path backwards first, + key := b for { - v := c.via[it] - if v == "" { + v, ok := c.via[key] + if !ok { break } - path = append(path, blob.MustParse(v)) - it = v + key = v + path = append(path, key) } // Then reverse it for i := 0; i < len(path)/2; i++ { @@ -70,7 +70,7 @@ func (c *Client) viaPathTo(b blob.Ref) (path []blob.Ref) { var blobsRx = regexp.MustCompile(blob.Pattern) -func (c *Client) FetchVia(b blob.Ref, v []blob.Ref) (body io.ReadCloser, size uint32, err error) { +func (c *Client) fetchVia(b blob.Ref, v []blob.Ref) (body io.ReadCloser, size uint32, err error) { if c.sto != nil { if len(v) > 0 { return nil, 0, errors.New("FetchVia not supported in non-HTTP mode") @@ -113,8 +113,7 @@ func (c *Client) FetchVia(b blob.Ref, v []blob.Ref) (body io.ReadCloser, size ui return nil, 0, fmt.Errorf("Got status code %d from blobserver for %s", resp.StatusCode, b) } - var buf bytes.Buffer - var reader io.Reader = io.MultiReader(&buf, resp.Body) + var reader io.Reader = resp.Body var closer io.Closer = resp.Body if resp.ContentLength > 0 { if resp.ContentLength > math.MaxUint32 { @@ -122,6 +121,7 @@ func (c *Client) FetchVia(b blob.Ref, v []blob.Ref) (body io.ReadCloser, size ui } size = uint32(resp.ContentLength) } else { + var buf bytes.Buffer size = 0 // Might be compressed. Slurp it to memory. n, err := io.CopyN(&buf, resp.Body, constants.MaxBlobSize+1) @@ -138,31 +138,54 @@ func (c *Client) FetchVia(b blob.Ref, v []blob.Ref) (body io.ReadCloser, size ui } } + var buf bytes.Buffer + if err := c.UpdateShareChain(b, io.TeeReader(reader, &buf)); err != nil { + if err != ErrNotSharing { + return nil, 0, err + } + } + mr := io.MultiReader(&buf, reader) var rc io.ReadCloser = struct { io.Reader io.Closer - }{reader, closer} + }{mr, closer} + return rc, size, nil +} + +// ErrNotSharing is returned when a client that was not created with +// NewFromShareRoot tries to access shared blobs. +var ErrNotSharing = errors.New("Client can not deal with shared blobs. Create it with NewFromShareRoot.") + +// UpdateShareChain reads the schema of b from r, and instructs the client that +// all blob refs found in this schema should use b as a preceding chain link, in +// all subsequent shared blobs fetches. If the client was not created with +// NewFromShareRoot, ErrNotSharing is returned. +func (c *Client) UpdateShareChain(b blob.Ref, r io.Reader) error { + c.viaMu.Lock() + defer c.viaMu.Unlock() if c.via == nil { // Not in sharing mode, so return immediately. - return rc, size, nil + return ErrNotSharing } - // Slurp 1 MB to find references to other blobrefs for the via path. - if buf.Len() == 0 { - const maxSlurp = 1 << 20 - _, err = io.Copy(&buf, io.LimitReader(resp.Body, maxSlurp)) - if err != nil { - return nil, 0, err - } + var buf bytes.Buffer + const maxSlurp = 1 << 20 + if _, err := io.Copy(&buf, io.LimitReader(r, maxSlurp)); err != nil { + return err } // If it looks like a JSON schema blob (starts with '{') if schema.LikelySchemaBlob(buf.Bytes()) { for _, blobstr := range blobsRx.FindAllString(buf.String(), -1) { - c.via[blobstr] = b.String() + br, ok := blob.Parse(blobstr) + if !ok { + log.Printf("Invalid blob ref %q noticed in schema of %v", blobstr, b) + continue + } + c.via[br] = b } } - return rc, size, nil + return nil } func (c *Client) ReceiveBlob(br blob.Ref, source io.Reader) (blob.SizedRef, error) {