Update goauth2 to f06a85362aa5

Change-Id: I581d449099b6201dc78593c3394fa3ae0954e0c3
changeset:   75:f06a85362aa5
tag:         tip
user:        Brad Fitzpatrick <bradfitz@golang.org>
date:        Tue Aug 12 13:58:32 2014 -0700
summary:     oauth: clean up docs, code, fix data race, don't send client_secret in two places
This commit is contained in:
Brad Fitzpatrick 2014-08-12 14:03:22 -07:00
parent a1fc7e5aea
commit 1b22acca30
2 changed files with 95 additions and 44 deletions

View File

@ -2,8 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// The oauth package provides support for making // Package oauth supports making OAuth2-authenticated HTTP requests.
// OAuth2-authenticated HTTP requests.
// //
// Example usage: // Example usage:
// //
@ -39,15 +38,23 @@ package oauth
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt"
"io"
"io/ioutil" "io/ioutil"
"mime" "mime"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
"strconv"
"strings" "strings"
"sync"
"time" "time"
) )
// OAuthError is the error type returned by many operations.
//
// In retrospect it should not exist. Don't depend on it.
type OAuthError struct { type OAuthError struct {
prefix string prefix string
msg string msg string
@ -123,7 +130,16 @@ type Config struct {
// TokenCache allows tokens to be cached for subsequent requests. // TokenCache allows tokens to be cached for subsequent requests.
TokenCache Cache TokenCache Cache
AccessType string // Optional, "online" (default) or "offline", no refresh token if "online" // AccessType is an OAuth extension that gets sent as the
// "access_type" field in the URL from AuthCodeURL.
// See https://developers.google.com/accounts/docs/OAuth2WebServer.
// It may be "online" (the default) or "offline".
// If your application needs to refresh access tokens when the
// user is not present at the browser, then use offline. This
// will result in your application obtaining a refresh token
// the first time your application exchanges an authorization
// code for a user.
AccessType string
// ApprovalPrompt indicates whether the user should be // ApprovalPrompt indicates whether the user should be
// re-prompted for consent. If set to "auto" (default) the // re-prompted for consent. If set to "auto" (default) the
@ -141,11 +157,17 @@ type Token struct {
AccessToken string AccessToken string
RefreshToken string RefreshToken string
Expiry time.Time // If zero the token has no (known) expiry time. Expiry time.Time // If zero the token has no (known) expiry time.
Extra map[string]string // May be nil.
// Extra optionally contains extra metadata from the server
// when updating a token. The only current key that may be
// populated is "id_token". It may be nil and will be
// initialized as needed.
Extra map[string]string
} }
// Expired reports whether the token has expired or is invalid.
func (t *Token) Expired() bool { func (t *Token) Expired() bool {
if t.Expiry.IsZero() { if t.Expiry.IsZero() || t.AccessToken == "" {
return false return false
} }
return t.Expiry.Before(time.Now()) return t.Expiry.Before(time.Now())
@ -165,6 +187,9 @@ type Transport struct {
*Config *Config
*Token *Token
// mu guards modifying the token.
mu sync.Mutex
// Transport is the HTTP transport to use when making requests. // Transport is the HTTP transport to use when making requests.
// It will default to http.DefaultTransport if nil. // It will default to http.DefaultTransport if nil.
// (It should never be an oauth.Transport.) // (It should never be an oauth.Transport.)
@ -247,35 +272,48 @@ func (t *Transport) Exchange(code string) (*Token, error) {
// If the Token is invalid callers should expect HTTP-level errors, // If the Token is invalid callers should expect HTTP-level errors,
// as indicated by the Response's StatusCode. // as indicated by the Response's StatusCode.
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
accessToken, err := t.getAccessToken()
if err != nil {
return nil, err
}
// To set the Authorization header, we must make a copy of the Request
// so that we don't modify the Request we were given.
// This is required by the specification of http.RoundTripper.
req = cloneRequest(req)
req.Header.Set("Authorization", "Bearer "+accessToken)
// Make the HTTP request.
return t.transport().RoundTrip(req)
}
func (t *Transport) getAccessToken() (string, error) {
t.mu.Lock()
defer t.mu.Unlock()
if t.Token == nil { if t.Token == nil {
if t.Config == nil { if t.Config == nil {
return nil, OAuthError{"RoundTrip", "no Config supplied"} return "", OAuthError{"RoundTrip", "no Config supplied"}
} }
if t.TokenCache == nil { if t.TokenCache == nil {
return nil, OAuthError{"RoundTrip", "no Token supplied"} return "", OAuthError{"RoundTrip", "no Token supplied"}
} }
var err error var err error
t.Token, err = t.TokenCache.Token() t.Token, err = t.TokenCache.Token()
if err != nil { if err != nil {
return nil, err return "", err
} }
} }
// Refresh the Token if it has expired. // Refresh the Token if it has expired.
if t.Expired() { if t.Expired() {
if err := t.Refresh(); err != nil { if err := t.Refresh(); err != nil {
return nil, err return "", err
} }
} }
if t.AccessToken == "" {
// To set the Authorization header, we must make a copy of the Request return "", errors.New("no access token obtained from refresh")
// so that we don't modify the Request we were given. }
// This is required by the specification of http.RoundTripper. return t.AccessToken, nil
req = cloneRequest(req)
req.Header.Set("Authorization", "Bearer "+t.AccessToken)
// Make the HTTP request.
return t.transport().RoundTrip(req)
} }
// cloneRequest returns a clone of the provided *http.Request. // cloneRequest returns a clone of the provided *http.Request.
@ -329,9 +367,16 @@ func (t *Transport) AuthenticateClient() error {
return t.updateToken(t.Token, url.Values{"grant_type": {"client_credentials"}}) return t.updateToken(t.Token, url.Values{"grant_type": {"client_credentials"}})
} }
// updateToken mutates both tok and v.
func (t *Transport) updateToken(tok *Token, v url.Values) error { func (t *Transport) updateToken(tok *Token, v url.Values) error {
v.Set("client_id", t.ClientId) v.Set("client_id", t.ClientId)
v.Set("client_secret", t.ClientSecret) // Note that we're not setting v's client_secret to t.ClientSecret, due
// to https://code.google.com/p/goauth2/issues/detail?id=31
// Reddit only accepts client_secret in Authorization header.
// Dropbox accepts either, but not both.
// The spec requires servers to always support the Authorization header,
// so that's all we use.
client := &http.Client{Transport: t.transport()} client := &http.Client{Transport: t.transport()}
req, err := http.NewRequest("POST", t.TokenURL, strings.NewReader(v.Encode())) req, err := http.NewRequest("POST", t.TokenURL, strings.NewReader(v.Encode()))
if err != nil { if err != nil {
@ -345,22 +390,23 @@ func (t *Transport) updateToken(tok *Token, v url.Values) error {
} }
defer r.Body.Close() defer r.Body.Close()
if r.StatusCode != 200 { if r.StatusCode != 200 {
return OAuthError{"updateToken", r.Status} return OAuthError{"updateToken", "Unexpected HTTP status " + r.Status}
} }
var b struct { var b struct {
Access string `json:"access_token"` Access string `json:"access_token"`
Refresh string `json:"refresh_token"` Refresh string `json:"refresh_token"`
ExpiresIn time.Duration `json:"expires_in"` ExpiresIn int64 `json:"expires_in"` // seconds
Id string `json:"id_token"` Id string `json:"id_token"`
} }
body, err := ioutil.ReadAll(io.LimitReader(r.Body, 1<<20))
if err != nil {
return err
}
content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type")) content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type"))
switch content { switch content {
case "application/x-www-form-urlencoded", "text/plain": case "application/x-www-form-urlencoded", "text/plain":
body, err := ioutil.ReadAll(r.Body)
if err != nil {
return err
}
vals, err := url.ParseQuery(string(body)) vals, err := url.ParseQuery(string(body))
if err != nil { if err != nil {
return err return err
@ -368,25 +414,25 @@ func (t *Transport) updateToken(tok *Token, v url.Values) error {
b.Access = vals.Get("access_token") b.Access = vals.Get("access_token")
b.Refresh = vals.Get("refresh_token") b.Refresh = vals.Get("refresh_token")
b.ExpiresIn, _ = time.ParseDuration(vals.Get("expires_in") + "s") b.ExpiresIn, _ = strconv.ParseInt(vals.Get("expires_in"), 10, 64)
b.Id = vals.Get("id_token") b.Id = vals.Get("id_token")
default: default:
if err = json.NewDecoder(r.Body).Decode(&b); err != nil { if err = json.Unmarshal(body, &b); err != nil {
return err return fmt.Errorf("got bad response from server: %q", body)
} }
// The JSON parser treats the unitless ExpiresIn like 'ns' instead of 's' as above, }
// so compensate here. if b.Access == "" {
b.ExpiresIn *= time.Second return errors.New("received empty access token from authorization server")
} }
tok.AccessToken = b.Access tok.AccessToken = b.Access
// Don't overwrite `RefreshToken` with an empty value // Don't overwrite `RefreshToken` with an empty value
if len(b.Refresh) > 0 { if b.Refresh != "" {
tok.RefreshToken = b.Refresh tok.RefreshToken = b.Refresh
} }
if b.ExpiresIn == 0 { if b.ExpiresIn == 0 {
tok.Expiry = time.Time{} tok.Expiry = time.Time{}
} else { } else {
tok.Expiry = time.Now().Add(b.ExpiresIn) tok.Expiry = time.Now().Add(time.Duration(b.ExpiresIn) * time.Second)
} }
if b.Id != "" { if b.Id != "" {
if tok.Extra == nil { if tok.Extra == nil {

View File

@ -23,7 +23,7 @@ var requests = []struct {
}{ }{
{ {
path: "/token", path: "/token",
query: "grant_type=authorization_code&code=c0d3&client_id=cl13nt1d&client_secret=s3cr3t", query: "grant_type=authorization_code&code=c0d3&client_id=cl13nt1d",
contenttype: "application/json", contenttype: "application/json",
auth: "Basic Y2wxM250MWQ6czNjcjN0", auth: "Basic Y2wxM250MWQ6czNjcjN0",
body: ` body: `
@ -38,7 +38,7 @@ var requests = []struct {
{path: "/secure", auth: "Bearer token1", body: "first payload"}, {path: "/secure", auth: "Bearer token1", body: "first payload"},
{ {
path: "/token", path: "/token",
query: "grant_type=refresh_token&refresh_token=refreshtoken1&client_id=cl13nt1d&client_secret=s3cr3t", query: "grant_type=refresh_token&refresh_token=refreshtoken1&client_id=cl13nt1d",
contenttype: "application/json", contenttype: "application/json",
auth: "Basic Y2wxM250MWQ6czNjcjN0", auth: "Basic Y2wxM250MWQ6czNjcjN0",
body: ` body: `
@ -53,7 +53,7 @@ var requests = []struct {
{path: "/secure", auth: "Bearer token2", body: "second payload"}, {path: "/secure", auth: "Bearer token2", body: "second payload"},
{ {
path: "/token", path: "/token",
query: "grant_type=refresh_token&refresh_token=refreshtoken2&client_id=cl13nt1d&client_secret=s3cr3t", query: "grant_type=refresh_token&refresh_token=refreshtoken2&client_id=cl13nt1d",
contenttype: "application/x-www-form-urlencoded", contenttype: "application/x-www-form-urlencoded",
body: "access_token=token3&refresh_token=refreshtoken3&id_token=idtoken3&expires_in=3600", body: "access_token=token3&refresh_token=refreshtoken3&id_token=idtoken3&expires_in=3600",
auth: "Basic Y2wxM250MWQ6czNjcjN0", auth: "Basic Y2wxM250MWQ6czNjcjN0",
@ -61,7 +61,7 @@ var requests = []struct {
{path: "/secure", auth: "Bearer token3", body: "third payload"}, {path: "/secure", auth: "Bearer token3", body: "third payload"},
{ {
path: "/token", path: "/token",
query: "grant_type=client_credentials&client_id=cl13nt1d&client_secret=s3cr3t", query: "grant_type=client_credentials&client_id=cl13nt1d",
contenttype: "application/json", contenttype: "application/json",
auth: "Basic Y2wxM250MWQ6czNjcjN0", auth: "Basic Y2wxM250MWQ6czNjcjN0",
body: ` body: `
@ -171,10 +171,15 @@ func checkToken(t *testing.T, tok *Token, access, refresh, id string) {
if g, w := tok.Extra["id_token"], id; g != w { if g, w := tok.Extra["id_token"], id; g != w {
t.Errorf("Extra['id_token'] = %q, want %q", g, w) t.Errorf("Extra['id_token'] = %q, want %q", g, w)
} }
if tok.Expiry.IsZero() {
t.Errorf("Expiry is zero; want ~1 hour")
} else {
exp := tok.Expiry.Sub(time.Now()) exp := tok.Expiry.Sub(time.Now())
if (time.Hour-time.Second) > exp || exp > time.Hour { const slop = 3 * time.Second // time moving during test
if (time.Hour-slop) > exp || exp > time.Hour {
t.Errorf("Expiry = %v, want ~1 hour", exp) t.Errorf("Expiry = %v, want ~1 hour", exp)
} }
}
} }
func checkBody(t *testing.T, r *http.Response, body string) { func checkBody(t *testing.T, r *http.Response, body string) {