importer/twitter: refactor credentials isolation

-store in the importer account the temporary credentials needed
during setup.
-store in the importer account the access token, and the access token
secret, and guard them. They are needed at the end of the setup (to get
the user's info) and all along a run.
-not share the same oauth.Client with the whole package. Create a new
one everytime it's needed in the setup. Create a new one at the beginning
of a run too, and it's ok to reuse it all along, since it's used
read-only.

Change-Id: I4823b35abf8a5170f63ca502026f6c8f98e2c1e7
This commit is contained in:
mpl 2014-04-25 19:47:35 +02:00
parent bd52105da6
commit 991b540c8b
2 changed files with 100 additions and 68 deletions

View File

@ -25,7 +25,6 @@ import (
"net/url" "net/url"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
"camlistore.org/pkg/context" "camlistore.org/pkg/context"
@ -43,11 +42,15 @@ const (
userInfoAPIPath = "account/verify_credentials.json" userInfoAPIPath = "account/verify_credentials.json"
// Permanode attributes on account node: // Permanode attributes on account node:
acctAttrUserID = "twitterUserID" acctAttrUserID = "twitterUserID"
acctAttrScreenName = "twitterScreenName" acctAttrScreenName = "twitterScreenName"
acctAttrUserFirst = "twitterFirstName" acctAttrUserFirst = "twitterFirstName"
acctAttrUserLast = "twitterLastName" acctAttrUserLast = "twitterLastName"
acctAttrAccessToken = "oauthAccessToken" // TODO(mpl): refactor these 4 below into an oauth package when doing flickr.
acctAttrTempToken = "oauthTempToken"
acctAttrTempSecret = "oauthTempSecret"
acctAttrAccessToken = "oauthAccessToken"
acctAttrAccessTokenSecret = "oauthAccessTokenSecret"
tweetRequestLimit = 200 // max number of tweets we can get in a user_timeline request tweetRequestLimit = 200 // max number of tweets we can get in a user_timeline request
) )
@ -58,32 +61,7 @@ func init() {
var _ importer.ImporterSetupHTMLer = (*imp)(nil) var _ importer.ImporterSetupHTMLer = (*imp)(nil)
var ( type imp struct{}
oauthClient = &oauth.Client{
TemporaryCredentialRequestURI: temporaryCredentialRequestURL,
ResourceOwnerAuthorizationURI: resourceOwnerAuthorizationURL,
TokenRequestURI: tokenRequestURL,
}
)
type imp struct {
// cred are the various credentials passed around during OAuth. First the temporary
// ones, then the access token and secret.
mu sync.Mutex // guards credsVal
credsVal *oauth.Credentials
}
func (im *imp) creds() *oauth.Credentials {
im.mu.Lock()
defer im.mu.Unlock()
return im.credsVal
}
func (im *imp) setCreds(v *oauth.Credentials) {
im.mu.Lock()
defer im.mu.Unlock()
im.credsVal = v
}
func (im *imp) NeedsAPIKey() bool { return true } func (im *imp) NeedsAPIKey() bool { return true }
@ -129,13 +107,38 @@ func (im *imp) AccountSetupHTML(host *importer.Host) string {
// A run is our state for a given run of the importer. // A run is our state for a given run of the importer.
type run struct { type run struct {
*importer.RunContext *importer.RunContext
im *imp im *imp
oauthClient *oauth.Client // No need to guard, used read-only.
accessCreds *oauth.Credentials // No need to guard, used read-only.
} }
func (im *imp) Run(ctx *importer.RunContext) error { func (im *imp) Run(ctx *importer.RunContext) error {
clientId, secret, err := ctx.Credentials()
if err != nil {
return fmt.Errorf("no API credentials: %v", err)
}
accountNode := ctx.AccountNode()
accessToken := accountNode.Attr(acctAttrAccessToken)
accessSecret := accountNode.Attr(acctAttrAccessTokenSecret)
if accessToken == "" || accessSecret == "" {
return errors.New("access credentials not found")
}
r := &run{ r := &run{
RunContext: ctx, RunContext: ctx,
im: im, im: im,
oauthClient: &oauth.Client{
TemporaryCredentialRequestURI: temporaryCredentialRequestURL,
ResourceOwnerAuthorizationURI: resourceOwnerAuthorizationURL,
TokenRequestURI: tokenRequestURL,
Credentials: oauth.Credentials{
Token: clientId,
Secret: secret,
},
},
accessCreds: &oauth.Credentials{
Token: accessToken,
Secret: accessSecret,
},
} }
userID := ctx.AccountNode().Attr(acctAttrUserID) userID := ctx.AccountNode().Attr(acctAttrUserID)
if userID == "" { if userID == "" {
@ -165,7 +168,7 @@ func (r *run) importTweets(userID string) error {
} }
var resp []*tweetItem var resp []*tweetItem
if err := r.im.doAPI(r.Context, &resp, "statuses/user_timeline.json", if err := doAPI(r.Context, r.oauthClient, r.accessCreds, &resp, "statuses/user_timeline.json",
"user_id", userID, "user_id", userID,
"count", strconv.Itoa(tweetRequestLimit), "count", strconv.Itoa(tweetRequestLimit),
"max_id", maxId); err != nil { "max_id", maxId); err != nil {
@ -243,9 +246,9 @@ type userInfo struct {
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
} }
func (im *imp) getUserInfo(ctx *context.Context) (userInfo, error) { func getUserInfo(ctx *context.Context, oauthClient *oauth.Client, creds *oauth.Credentials) (userInfo, error) {
var ui userInfo var ui userInfo
if err := im.doAPI(ctx, &ui, userInfoAPIPath); err != nil { if err := doAPI(ctx, oauthClient, creds, &ui, userInfoAPIPath); err != nil {
return ui, err return ui, err
} }
if ui.ID == "" { if ui.ID == "" {
@ -254,13 +257,16 @@ func (im *imp) getUserInfo(ctx *context.Context) (userInfo, error) {
return ui, nil return ui, nil
} }
func (im *imp) doAPI(ctx *context.Context, result interface{}, apiPath string, keyval ...string) error { func doAPI(ctx *context.Context, oauthClient *oauth.Client, creds *oauth.Credentials, result interface{}, apiPath string, keyval ...string) error {
if len(keyval)%2 == 1 { if len(keyval)%2 == 1 {
panic("Incorrect number of keyval arguments. must be even.") panic("Incorrect number of keyval arguments. must be even.")
} }
if im.creds() == nil { if creds == nil {
return fmt.Errorf("No authentication creds") return errors.New("No authentication creds")
}
if oauthClient == nil {
return errors.New("No authentication client")
} }
form := url.Values{} form := url.Values{}
@ -271,7 +277,7 @@ func (im *imp) doAPI(ctx *context.Context, result interface{}, apiPath string, k
} }
fullURL := apiURL + apiPath fullURL := apiURL + apiPath
res, err := im.doGet(ctx, fullURL, form) res, err := doGet(ctx, oauthClient, creds, fullURL, form)
if err != nil { if err != nil {
return err return err
} }
@ -282,11 +288,13 @@ func (im *imp) doAPI(ctx *context.Context, result interface{}, apiPath string, k
return nil return nil
} }
func (im *imp) doGet(ctx *context.Context, url string, form url.Values) (*http.Response, error) { func doGet(ctx *context.Context, oauthClient *oauth.Client, creds *oauth.Credentials, url string, form url.Values) (*http.Response, error) {
creds := im.creds()
if creds == nil { if creds == nil {
return nil, errors.New("No OAuth credentials. Not logged in?") return nil, errors.New("No OAuth credentials. Not logged in?")
} }
if creds == nil {
return nil, errors.New("No OAuth client.")
}
res, err := oauthClient.Get(ctx.HTTPClient(), creds, url, form) res, err := oauthClient.Get(ctx.HTTPClient(), creds, url, form)
if err != nil { if err != nil {
return nil, fmt.Errorf("Error fetching %s: %v", url, err) return nil, fmt.Errorf("Error fetching %s: %v", url, err)
@ -297,31 +305,43 @@ func (im *imp) doGet(ctx *context.Context, url string, form url.Values) (*http.R
return res, nil return res, nil
} }
func auth(ctx *importer.SetupContext) (*oauth.Credentials, error) { func newOauthClient(ctx *importer.SetupContext) (*oauth.Client, error) {
clientId, secret, err := ctx.Credentials() clientId, secret, err := ctx.Credentials()
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &oauth.Credentials{ return &oauth.Client{
Token: clientId, TemporaryCredentialRequestURI: temporaryCredentialRequestURL,
Secret: secret, ResourceOwnerAuthorizationURI: resourceOwnerAuthorizationURL,
TokenRequestURI: tokenRequestURL,
Credentials: oauth.Credentials{
Token: clientId,
Secret: secret,
},
}, nil }, nil
} }
func (im *imp) ServeSetup(w http.ResponseWriter, r *http.Request, ctx *importer.SetupContext) error { func (im *imp) ServeSetup(w http.ResponseWriter, r *http.Request, ctx *importer.SetupContext) error {
cred, err := auth(ctx) oauthClient, err := newOauthClient(ctx)
if err != nil { if err != nil {
err = fmt.Errorf("Error getting API credentials: %v", err) err = fmt.Errorf("error getting OAuth client: %v", err)
httputil.ServeError(w, r, err) httputil.ServeError(w, r, err)
return err return err
} }
oauthClient.Credentials = *cred
tempCred, err := oauthClient.RequestTemporaryCredentials(ctx.HTTPClient(), ctx.CallbackURL(), nil) tempCred, err := oauthClient.RequestTemporaryCredentials(ctx.HTTPClient(), ctx.CallbackURL(), nil)
if err != nil { if err != nil {
err = fmt.Errorf("Error getting temp cred: %v", err) err = fmt.Errorf("Error getting temp cred: %v", err)
httputil.ServeError(w, r, err) httputil.ServeError(w, r, err)
return err
}
if err := ctx.AccountNode.SetAttrs(
acctAttrTempToken, tempCred.Token,
acctAttrTempSecret, tempCred.Secret,
); err != nil {
err = fmt.Errorf("Error saving temp creds: %v", err)
httputil.ServeError(w, r, err)
return err
} }
im.setCreds(tempCred)
authURL := oauthClient.AuthorizationURL(tempCred, nil) authURL := oauthClient.AuthorizationURL(tempCred, nil)
http.Redirect(w, r, authURL, 302) http.Redirect(w, r, authURL, 302)
@ -329,19 +349,32 @@ func (im *imp) ServeSetup(w http.ResponseWriter, r *http.Request, ctx *importer.
} }
func (im *imp) ServeCallback(w http.ResponseWriter, r *http.Request, ctx *importer.SetupContext) { func (im *imp) ServeCallback(w http.ResponseWriter, r *http.Request, ctx *importer.SetupContext) {
creds := im.creds() tempToken := ctx.AccountNode.Attr(acctAttrTempToken)
if creds == nil { tempSecret := ctx.AccountNode.Attr(acctAttrTempSecret)
log.Printf("twitter: nil creds in callback") if tempToken == "" || tempSecret == "" {
httputil.BadRequestError(w, "nil creds in callback") log.Printf("twitter: no temp creds in callback")
httputil.BadRequestError(w, "no temp creds in callback")
return return
} }
if creds.Token != r.FormValue("oauth_token") { if tempToken != r.FormValue("oauth_token") {
log.Printf("unexpected oauth_token: got %v, want %v", r.FormValue("oauth_token"), creds.Token) log.Printf("unexpected oauth_token: got %v, want %v", r.FormValue("oauth_token"), tempToken)
httputil.BadRequestError(w, "unexpected oauth_token") httputil.BadRequestError(w, "unexpected oauth_token")
return return
} }
oauthClient, err := newOauthClient(ctx)
tokenCred, vals, err := oauthClient.RequestToken(ctx.Context.HTTPClient(), creds, r.FormValue("oauth_verifier")) if err != nil {
err = fmt.Errorf("error getting OAuth client: %v", err)
httputil.ServeError(w, r, err)
return
}
tokenCred, vals, err := oauthClient.RequestToken(
ctx.Context.HTTPClient(),
&oauth.Credentials{
Token: tempToken,
Secret: tempSecret,
},
r.FormValue("oauth_verifier"),
)
if err != nil { if err != nil {
httputil.ServeError(w, r, fmt.Errorf("Error getting request token: %v ", err)) httputil.ServeError(w, r, fmt.Errorf("Error getting request token: %v ", err))
return return
@ -351,9 +384,15 @@ func (im *imp) ServeCallback(w http.ResponseWriter, r *http.Request, ctx *import
httputil.ServeError(w, r, fmt.Errorf("Couldn't get user id: %v", err)) httputil.ServeError(w, r, fmt.Errorf("Couldn't get user id: %v", err))
return return
} }
im.setCreds(tokenCred) if err := ctx.AccountNode.SetAttrs(
acctAttrAccessToken, tokenCred.Token,
acctAttrAccessTokenSecret, tokenCred.Secret,
); err != nil {
httputil.ServeError(w, r, fmt.Errorf("Error setting token attributes: %v", err))
return
}
u, err := im.getUserInfo(ctx.Context) u, err := getUserInfo(ctx.Context, oauthClient, tokenCred)
if err != nil { if err != nil {
httputil.ServeError(w, r, fmt.Errorf("Couldn't get user info: %v", err)) httputil.ServeError(w, r, fmt.Errorf("Couldn't get user info: %v", err))
return return
@ -370,7 +409,6 @@ func (im *imp) ServeCallback(w http.ResponseWriter, r *http.Request, ctx *import
acctAttrUserFirst, firstName, acctAttrUserFirst, firstName,
acctAttrUserLast, lastName, acctAttrUserLast, lastName,
acctAttrScreenName, u.ScreenName, acctAttrScreenName, u.ScreenName,
acctAttrAccessToken, tokenCred.Token,
); err != nil { ); err != nil {
httputil.ServeError(w, r, fmt.Errorf("Error setting attribute: %v", err)) httputil.ServeError(w, r, fmt.Errorf("Error setting attribute: %v", err))
return return

View File

@ -30,19 +30,13 @@ import (
) )
func TestGetUserID(t *testing.T) { func TestGetUserID(t *testing.T) {
im := &imp{
credsVal: &oauth.Credentials{
Token: "foo",
Secret: "bar",
},
}
ctx := context.New() ctx := context.New()
ctx.SetHTTPClient(&http.Client{ ctx.SetHTTPClient(&http.Client{
Transport: newFakeTransport(map[string]func() *http.Response{ Transport: newFakeTransport(map[string]func() *http.Response{
apiURL + userInfoAPIPath: fileResponder(filepath.FromSlash("testdata/verify_credentials-res.json")), apiURL + userInfoAPIPath: fileResponder(filepath.FromSlash("testdata/verify_credentials-res.json")),
}), }),
}) })
inf, err := im.getUserInfo(ctx) inf, err := getUserInfo(ctx, &oauth.Client{}, &oauth.Credentials{})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }